├── mtg ├── __init__.py ├── ml │ ├── __init__.py │ ├── utils.py │ ├── README.md │ ├── nn.py │ ├── trainer.py │ ├── layers.py │ ├── generator.py │ ├── display.py │ └── models.py ├── obj │ ├── __init__.py │ ├── README.md │ ├── scryfall_utils.py │ ├── cards.py │ ├── dataloading_utils.py │ └── expansion.py ├── _version.py └── scripts │ ├── preprocess.py │ ├── README.md │ ├── train_builder.py │ └── train_drafter.py ├── requirements.txt ├── setup.py ├── README.md ├── .gitignore └── LICENSE /mtg/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mtg/ml/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mtg/obj/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mtg/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.0.0' 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | scikit-learn 4 | ipykernel 5 | jupyter 6 | matplotlib 7 | tensorflow 8 | tqdm 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from pkg_resources import parse_requirements 4 | import pathlib 5 | from setuptools import find_packages, setup 6 | 7 | README_FILE = 'README.md' 8 | REQUIREMENTS_FILE = 'requirements.txt' 9 | VERSION_FILE = 'mtg/_version.py' 10 | VERSION_REGEXP = r'^__version__ = \'(\d+\.\d+\.\d+)\'' 11 | 12 | r = re.search(VERSION_REGEXP, open(VERSION_FILE).read(), re.M) 13 | if r is None: 14 | raise RuntimeError(f'Unable to find version string in {VERSION_FILE}.') 15 | 16 | version = r.group(1) 17 | long_description = open(README_FILE, encoding='utf-8').read() 18 | install_requires = [str(r) for r in parse_requirements(open(REQUIREMENTS_FILE, 'rt'))] 19 | 20 | setup( 21 | name='mtg', 22 | version=version, 23 | description='mtg is a collection of data science and ml projects for Magic:the Gathering', 24 | long_description=long_description, 25 | long_description_content_type='text/markdown', 26 | author='Ryan Saxe', 27 | author_email='ryancsaxe@gmail.com', 28 | url='https://github.com/RyanSaxe/mtg', 29 | packages=find_packages(), 30 | install_requires=install_requires, 31 | ) 32 | -------------------------------------------------------------------------------- /mtg/scripts/preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from mtg.obj.expansion import get_expansion_obj_from_name 3 | import pickle 4 | 5 | 6 | def main(): 7 | EXPANSION = get_expansion_obj_from_name(FLAGS.expansion) 8 | expansion = EXPANSION(bo1=FLAGS.game_data, draft=FLAGS.draft_data, ml_data=True) 9 | with open(FLAGS.expansion_fname, "wb") as f: 10 | pickle.dump(expansion, f) 11 | 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | "--expansion", 17 | type=str, 18 | default="VOW", 19 | help="name of magic expansion corresponding to data files", 20 | ) 21 | parser.add_argument( 22 | 
"--game_data", type=str, default=None, help="path to bo1 game data" 23 | ) 24 | parser.add_argument( 25 | "--draft_data", type=str, default=None, help="path to bo1 draft data" 26 | ) 27 | parser.add_argument( 28 | "--expansion_fname", 29 | type=str, 30 | default="expansion.pkl", 31 | help="path/to/fname.pkl for where we should store the expansion object", 32 | ) 33 | FLAGS, unparsed = parser.parse_known_args() 34 | main() 35 | -------------------------------------------------------------------------------- /mtg/obj/README.md: -------------------------------------------------------------------------------- 1 | # obj 2 | 3 | This part of the project is responsible for card and data type objects. 4 | 5 | ## CardSet 6 | 7 | A `CardSet` object is meant to download card information from scryfall.com to easily integrate that information with other Magic data. 8 | 9 | ```python 10 | from mtg.obj.cards import CardSet 11 | #this object contains all cards in Crimson Vow with 12 | #cmc 4 or greater, that are present in booster packs 13 | VOW_expensive = CardSet([ 14 | "set=vow", 15 | "cmc>=4", 16 | "is:booster", 17 | ]) 18 | ``` 19 | 20 | Generally, to work with `CardSet` data, it is best to use a pandas DataFrame. So the CardSet object has a `to_dataframe` function for that conversion. 21 | 22 | ## Expansion 23 | 24 | Different expansions have different custom rules and datasets. The `Expansion` object will automatically pull the proper statistical data from 17lands.com, and integrate that with information from scryfall.com using the `CardSet` object. 25 | 26 | ```python 27 | from mtg.obj.expansion import VOW 28 | #use_ml_data specifies to get the 17lands stat data 29 | VOW_expansion = VOW(use_ml_data=True) 30 | ``` 31 | 32 | Additionally, you can pass `bo1` and `draft` arguments to any `Expansion` class to tell it to load 17lands bo1 game data or draft data. Currently I have not used the replay data or bo3 data, so there is no custom preprocessing for that. 33 | 34 | When working with data from a new (or old set), create a child of the `Expansion` object accordingly. 
35 | -------------------------------------------------------------------------------- /mtg/obj/scryfall_utils.py: -------------------------------------------------------------------------------- 1 | def merge_card_faces(row): 2 | nans = row.isna() 3 | if nans["card_faces"]: 4 | return row 5 | card_faces = row["card_faces"] 6 | face_1_keys = card_faces[0].keys() 7 | face_2_keys = card_faces[1].keys() 8 | face = dict() 9 | for key in face_1_keys: 10 | if key in ["power", "toughness"]: 11 | try: 12 | val = int(card_faces[0][key]) 13 | except: 14 | val = card_faces[0][key] 15 | else: 16 | val = card_faces[0][key] 17 | face[key] = val 18 | for key in face_2_keys: 19 | if key in ["power", "toughness"]: 20 | try: 21 | val = int(card_faces[1][key]) 22 | except: 23 | val = card_faces[1][key] 24 | else: 25 | val = card_faces[1][key] 26 | if key not in face.keys(): 27 | face[key] = val 28 | else: 29 | if key in ["oracle_text", "flavor_text", "type_line"]: 30 | face[key] = face[key] + "\n//\n" + val 31 | elif key == "colors": 32 | face[key] = list(set(face[key]).union(set(val))) 33 | for key, val in face.items(): 34 | if key not in nans.index: 35 | continue 36 | if nans[key]: 37 | row[key] = val 38 | return row 39 | 40 | 41 | def produce_for_splash(row): 42 | nans = row.isna() 43 | if nans["produced_mana"]: 44 | return [] 45 | return list(set(row["produced_mana"]) - {"C"} - set(row["colors"])) 46 | 47 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mtg 2 | 3 | State-of-the-art Magic: the Gathering Draft and DeckBuilder AI. 4 | 5 | ## achievements 6 | 7 | ![mythicbot](https://user-images.githubusercontent.com/2286292/149001531-9c983259-4ac6-4ed3-b54a-b0705fb57124.PNG) 8 | 9 | This repository contains an algorithm for automated drafting and building for Magic: the Gathering. I used this algorithm to achieve the highest rank (Mythic) on Magic Arena. I did so in 23 drafts, with a 66% win-rate, which is comparable to how I perform on my normal account in which I do not use any algorithms. The highest rank within Mythic I have hit so far is #27. As far as I know, this is the first time anybody has achieved results of this caliber using an AI in Magic: the Gathering. 10 | 11 | ## architecture 12 | 13 | Below is a general diagram of the transformer architecture for the Draft AI, intended to be easier to grok than reading through the code. 14 | 15 | ![transformer](https://user-images.githubusercontent.com/2286292/158926118-86d8301e-8c0e-43c2-a21b-cced4f785b97.jpg) 16 | 17 | ## installation 18 | 19 | - Step 1: clone this repository, and cd into it. 20 | - Step 2: create a virtual environment using whatever your favorite tool is (e.g. `conda create -n my_env` -> `conda activate my_env`). 21 | - Step 3: `pip install .` will install this repo such that you can use `from mtg.xxx.yyy import zzz`. The full sequence of commands is shown below the note. 22 | 23 | **NOTE:** I am not currently providing a pretrained instance of the Draft AI or DeckBuilder AI in this repository. That means you cannot simply install this codebase, launch Magic Arena, and use the bot like I do. If you would like to do that, you need to use this code to train it yourself following [these instructions](mtg/scripts). A non-cleaned version of the UI I use that interacts with Magic Arena can be found [here](https://github.com/RyanSaxe/MTGA_Draft_17Lands), and it will eventually be cleaned and added to this repository under mtg/app/.
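Putting the three installation steps together, a minimal sketch of the commands (the environment name is arbitrary, matching the example in step 2; `python` is added to the conda command so the environment gets its own pip):

```
git clone https://github.com/RyanSaxe/mtg.git
cd mtg
conda create -n my_env python
conda activate my_env
pip install .
```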
24 | 25 | ## documentation 26 | 27 | Documentation for each part of the project lives in the README of its corresponding folder. 28 | 29 | ## TODO 30 | 31 | - Integrate deckbuilder and drafter in one end-to-end pipeline. 32 | - Add mtg/viz/ as a folder for containing 17lands data visualizations, explorations, and useful insights. 33 | - Add mtg/app/ as a folder to contain the application UI for running on live arena drafts. 34 | -------------------------------------------------------------------------------- /mtg/scripts/README.md: -------------------------------------------------------------------------------- 1 | # scripts 2 | 3 | Scripts for preprocessing data and training a `DraftBot` and a `DeckBuilder`. 4 | 5 | ## instructions 6 | 7 | First, download bo1 game data and bo1 draft data for an expansion from https://www.17lands.com/public_datasets 8 | 9 | The following will preprocess the data and store it: 10 | 11 | ``` 12 | >>> python preprocess.py --expansion VOW \ 13 | --game_data path/to/game/data.csv \ 14 | --draft_data path/to/draft/data.csv \ 15 | --expansion_fname path/to/expansion.pkl 16 | ``` 17 | 18 | Now, you can run the script to train the draft model using the preprocessed data. 19 | 20 | ``` 21 | >>> python train_drafter.py --expansion_fname path/to/expansion.pkl \ 22 | --model_name path/to/draft_model 23 | ``` 24 | 25 | And the same for the deckbuilder model. Note that it is advised to train the draft model first so that you can use its embeddings in the deckbuilder model: 26 | 27 | ``` 28 | >>> python train_builder.py --expansion_fname path/to/expansion.pkl \ 29 | --draft_model path/to/draft_model \ 30 | --model_name path/to/build_model 31 | ``` 32 | 33 | If you want to train any of these models with different hyperparameters, please check the flags specified in the corresponding scripts. 34 | 35 | ## usage 36 | 37 | Once you've trained your instances of these models, you can load them in python and see how they would build decks and make decisions given 17lands logs. If you don't currently have an account at 17lands.com, please make one, as you need an API token for this part.
Below is an example showing how the model would make deckbuilding and draft decisions given a 17lands log: 38 | 39 | ```python 40 | from mtg.ml.utils import load_model 41 | from mtg.ml.display import draft_log_ai 42 | import pickle 43 | 44 | draft_model, attrs = load_model("path/to/draft_model") 45 | build_model, cards = load_model("path/to/build_model", extra_pickle="cards.pkl") 46 | expansion = pickle.load(open("path/to/expansion.pkl", "rb")) 47 | 48 | log = 'https://www.17lands.com/draft/[draft_id]' 49 | token = '[your API token]' 50 | # log_url[0] will be a link to a 17lands draft log 51 | # log_url[1] will be a link to a sealeddeck.tech deck build 52 | log_url = draft_log_ai( 53 | log, 54 | draft_model, 55 | expansion=expansion, 56 | token=token, 57 | build_model=build_model, 58 | ) 59 | ``` 60 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Don't upload data 132 | data/ 133 | 134 | # Don't upload vscode specs 135 | .vscode 136 | 137 | # Mac files 138 | .DS_Store 139 | -------------------------------------------------------------------------------- /mtg/ml/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import pickle 4 | import os 5 | 6 | 7 | class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule): 8 | """ 9 | learning rate scheduling 10 | """ 11 | 12 | def __init__(self, d_model, warmup_steps=1000): 13 | super(CustomSchedule, self).__init__() 14 | 15 | self.d_model = d_model 16 | self.d_model = tf.cast(self.d_model, tf.float32) 17 | 18 | self.warmup_steps = warmup_steps 19 | 20 | def __call__(self, step): 21 | arg1 = tf.math.rsqrt(step) 22 | arg2 = step * (self.warmup_steps ** -1.5) 23 | 24 | return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2) 25 | 26 | 27 | def importance_weighting(df, minim=0.1, maxim=1.0): 28 | rank_to_score = { 29 | "bronze": 0.01, 30 | "silver": 0.1, 31 | "gold": 0.25, 32 | "platinum": 0.5, 33 | "diamond": 0.75, 34 | "mythic": 1.0, 35 | } 36 | # decrease exponentiation by larger amounts for higher 37 | # ranks such that rank and win-rate matter together 38 | rank_addition = df["rank"].apply(lambda x: rank_to_score.get(x, 0.5)) 39 | scaled_win_rate = np.clip( 40 | df["user_win_rate_bucket"].fillna(0.5) ** (2 - rank_addition), 41 | a_min=minim, 42 | a_max=maxim, 43 | ) 44 | 45 | last = df["date"].max() 46 | # increase importance factor for recent data points according to number of weeks from most recent data point 47 | n_weeks = df["date"].apply(lambda x: (last - x).days // 7) 48 | # lower the value of pxp11 + 49 | if "position" in df.columns: 50 | pack_size = (df["position"].max() + 1) / 3 51 | pick_nums = df["position"] % pack_size + 1 52 | # alpha default to this (~0.54) because it places highest 53 | # importance in the beginning of the pack, but lower on PxP1 54 | # to help reduce rare drafting. 
Nice property of PxP1 ~= PxP8 55 | alpha = np.e / 5.0 56 | position_scale = pick_nums.apply(lambda x: (np.log(x) + 1) / np.power(x, alpha)) 57 | else: 58 | position_scale = 1.0 59 | return ( 60 | position_scale 61 | * scaled_win_rate 62 | * np.clip(df["won"], a_min=0.5, a_max=1.0) 63 | * 0.9 ** n_weeks 64 | ) 65 | 66 | 67 | def load_model(location, extra_pickle="attrs.pkl"): 68 | model_loc = os.path.join(location, "model") 69 | data_loc = os.path.join(location, extra_pickle) 70 | model = tf.saved_model.load(model_loc) 71 | try: 72 | with open(data_loc, "rb") as f: 73 | extra = pickle.load(f) 74 | return (model, extra) 75 | except: 76 | return model 77 | -------------------------------------------------------------------------------- /mtg/ml/README.md: -------------------------------------------------------------------------------- 1 | # ml 2 | 3 | This part of the project is dedicated to Machine Learning implementations for Magic: the Gathering. Below is a quick description of each file. If you would like to train your own instances of the models in this folder, please refer to the `train_xxx.py` files in `mtg/scripts/`. 4 | 5 | ## layers.py 6 | 7 | This file contains implementations of layers, such as `Dense`, `MultiHeadAttention`, `LayerNormalization`, and more! 8 | 9 | ## nn.py 10 | 11 | This file contains implementations of module blocks that models can use, such as `MLP` and `TransformerBlock`. Here is an example of building an autoencoder using the `MLP` block. 12 | 13 | ```python 14 | import tensorflow as tf 15 | from mtg.ml.nn import MLP 16 | 17 | class AutoEncoder(tf.Module): 18 | def __init__(self, in_dim, emb_dim, name=None): 19 | super().__init__(name=name) 20 | self.encoder = MLP( 21 | in_dim = in_dim, 22 | start_dim = in_dim // 2, 23 | n_h_layers = 2, 24 | out_dim = emb_dim, 25 | style="bottleneck", 26 | ) 27 | self.decoder = MLP( 28 | in_dim = emb_dim, 29 | start_dim = emb_dim * 2, 30 | n_h_layers = 2, 31 | out_dim = in_dim, 32 | style="reverse_bottleneck", 33 | ) 34 | 35 | def __call__(self, x, training=None): 36 | embedding = self.encoder(x, training=training) 37 | return self.decoder(embedding, training=training) 38 | ``` 39 | 40 | ## models.py 41 | 42 | This file contains implementations of projects to apply Machine Learning to Magic: the Gathering. Currently it contains a model for drafting, `DraftBot`, and a model for deckbuilding, `DeckBuilder`. 43 | 44 | ## generator.py 45 | 46 | This file contains data generator objects for batching 17lands data properly to feed into models in `models.py`. 47 | 48 | ## trainer.py 49 | 50 | This file contains the custom training object used to train models from `models.py` using generators from `generator.py`. 51 | 52 | ## display.py 53 | 54 | This file contains different ways to visualize and run pretrained models. Here is an example of a common use case for debugging: 55 | 56 | ```python 57 | from mtg.ml.display import draft_sim 58 | 59 | # assume draft_model and build_model are pretrained instances of those MTG models 60 | # assume expansion is a loaded instance of the expansion object containing the 61 | # data corresponding to draft_model and build_model 62 | # then, draft_sim as run below will spin up a table of 8 bots and run them through a draft. 63 | # it returns links to 8 corresponding 17lands draft logs and sealeddeck.tech deck builds.
64 | 65 | token = "abcdefghijk1234567890" #replace this with your 17lands API token 66 | bot_table = draft_sim(expansion, draft_model, token=token, build_model=build_model) 67 | ``` 68 | 69 | ## utils.py 70 | 71 | This file contains utility functions needed for the models such as learning rate schedulers, and model loading functions. 72 | 73 | ## TODO: 74 | 75 | - integrate the deckbuilder model to be a part of the drafting model. 76 | - update MLP implementation such that `n_h_layers` actually corresponds to the number of hidden layers (at the moment, it's technically 1 more) 77 | -------------------------------------------------------------------------------- /mtg/scripts/train_builder.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from mtg.ml.generator import DeckGenerator, create_train_and_val_gens 3 | from mtg.ml.models import DeckBuilder 4 | import pickle 5 | from mtg.ml.trainer import Trainer 6 | import numpy as np 7 | from mtg.ml.display import build_decks 8 | from mtg.ml.utils import load_model 9 | 10 | 11 | def main(): 12 | with open(FLAGS.expansion_fname, "rb") as f: 13 | expansion = pickle.load(f) 14 | 15 | decks = expansion.get_bo1_decks() 16 | train_gen, val_gen = create_train_and_val_gens( 17 | decks, 18 | expansion.cards.copy(), 19 | train_p=FLAGS.train_p, 20 | id_col="draft_id", 21 | train_batch_size=FLAGS.batch_size, 22 | generator=DeckGenerator, 23 | include_val=True, 24 | mask_decks=True, 25 | ) 26 | 27 | if FLAGS.draft_model is not None: 28 | _, attrs = load_model(FLAGS.draft_model) 29 | embeddings = attrs["embeddings"] 30 | else: 31 | embeddings = FLAGS.emb_dim 32 | 33 | model = DeckBuilder( 34 | expansion.cards["idx"].max() - 4, 35 | dropout=FLAGS.dropout, 36 | embeddings=embeddings, 37 | name="DeckBuilder", 38 | ) 39 | 40 | model.compile( 41 | learning_rate={"warmup_steps": FLAGS.lr_warmup}, 42 | cmc_lambda=FLAGS.cmc_lambda, 43 | card_data=expansion.card_data_for_ML.iloc[:-1, :], 44 | ) 45 | trainer = Trainer( 46 | model, 47 | generator=train_gen, 48 | val_generator=val_gen, 49 | ) 50 | trainer.train( 51 | FLAGS.epochs, 52 | verbose=FLAGS.verbose, 53 | ) 54 | # we run inference once before saving the model in order to serialize it with the right input parameters for inference 55 | # and we do it with train_gen because val_gen can be None, and this isn't used for validation but serialization 56 | x, y, z = train_gen[0] 57 | pid = 0 58 | pool = np.expand_dims(x[0][pid, 0, :], 0) 59 | basics, spells, n_basics = build_decks(model, pool, cards=expansion.cards) 60 | 61 | model.save(expansion.cards, FLAGS.model_name) 62 | 63 | 64 | if __name__ == "__main__": 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument( 67 | "--expansion_fname", 68 | type=str, 69 | default="expansion.pkl", 70 | help="path/to/fname.pkl for where we should load the expansion object", 71 | ) 72 | parser.add_argument( 73 | "--batch_size", type=int, default=32, help="training batch size" 74 | ) 75 | parser.add_argument( 76 | "--train_p", type=float, default=1.0, help="number in [0,1] for train-val split" 77 | ) 78 | parser.add_argument( 79 | "--emb_dim", 80 | type=int, 81 | default=None, 82 | help="card embedding dimension. If None, embeddings aren't used. 
If we pass a Draft Model in --draft_model, we use the embeddings from that model instead.", 83 | ) 84 | parser.add_argument( 85 | "--draft_model", 86 | type=str, 87 | default=None, 88 | help="path/to/model so we can use embeddings learned from an existing pretrained draft model", 89 | ) 90 | parser.add_argument( 91 | "--dropout", 92 | type=float, 93 | default=0.2, 94 | help="dropout rate to apply to the dense layers in the encoders", 95 | ) 96 | parser.add_argument( 97 | "--lr_warmup", 98 | type=float, 99 | default=4000, 100 | help="number of warmup steps in the classic transformer learning rate scheduler", 101 | ) 102 | parser.add_argument( 103 | "--cmc_lambda", 104 | type=float, 105 | default=0.1, 106 | help="regularization coefficient for helping the model build comparable curves to humans", 107 | ) 108 | parser.add_argument( 109 | "--epochs", type=int, default=1, help="number of epochs to train the model" 110 | ) 111 | parser.add_argument( 112 | "--verbose", 113 | type=bool, 114 | default=True, 115 | help="If True, tqdm will display during training", 116 | ) 117 | 118 | parser.add_argument( 119 | "--model_name", 120 | type=str, 121 | default="draft_model", 122 | help="path/to/deck_model where the model will be stored", 123 | ) 124 | 125 | FLAGS, unparsed = parser.parse_known_args() 126 | main() 127 | -------------------------------------------------------------------------------- /mtg/obj/cards.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import urllib 3 | import json as js 4 | import pandas as pd 5 | import mtg.obj.scryfall_utils as scry_utils 6 | 7 | 8 | class CardSet: 9 | """ 10 | Run a scryfall search for a list of cards according to a query: 11 | 12 | parameters: 13 | 14 | query_args: a list of queries to hit the scryfall api with 15 | 16 | usage: 17 | 18 | KHM_expensive = CardSet([ 19 | "set=khm", 20 | "cmc>=4", 21 | ]) 22 | #this gets you all Kaldheim cards with cmc 4 or greater 23 | 24 | """ 25 | 26 | def __init__(self, query_args, json_files=[]): 27 | self.search_q = "https://api.scryfall.com/cards/search?q=" 28 | if isinstance(query_args, str): 29 | self.search_q += urllib.parse.quote(query_args) 30 | else: 31 | self.search_q += urllib.parse.quote( 32 | " & ".join([query for query in query_args]) 33 | ) 34 | response = requests.get(self.search_q) 35 | self._json = response.json() 36 | self.cards = set() 37 | self._build_card_list_query() 38 | self._build_card_list_json(json_files) 39 | 40 | def _build_card_list_query(self): 41 | """ 42 | store cards from the query in self.cards 43 | """ 44 | json = self._json 45 | while json.get("has_more", False): 46 | self.cards = self.cards.union({Card(card) for card in json.get("data", [])}) 47 | if json.get("next_page", None) is not None: 48 | next_page = requests.get(json["next_page"]) 49 | json = next_page.json() 50 | self.cards = self.cards.union({Card(card) for card in json.get("data", [])}) 51 | 52 | def _build_card_list_json(self, json_files): 53 | """ 54 | store cards in the json_files in self.cards 55 | """ 56 | for json_f in json_files: 57 | json = js.load(json_f) 58 | self.cards = self.cards.union({Card(card) for card in json.get("data", [])}) 59 | 60 | def to_dataframe(self): 61 | card_data = [card.__dict__ for card in self.cards] 62 | df = pd.DataFrame(card_data) 63 | df = self.scryfall_modifications(df) 64 | # modify so that basics have the first 5 idxs 65 | basics = ["plains", "island", "swamp", "mountain", "forest"] 66 | card_names = [x for x in
df["name"].unique()] 67 | for basic in basics: 68 | if basic in card_names: card_names.remove(basic) 69 | card_names = basics + card_names 70 | id_to_name = {i: card_name for i, card_name in enumerate(card_names)} 71 | name_to_id = {name: idx for idx, name in id_to_name.items()} 72 | df["idx"] = df["name"].apply(lambda x: name_to_id[x]) 73 | return df 74 | 75 | def scryfall_modifications(self, df): 76 | if "card_faces" in df.columns: 77 | df = df.apply(scry_utils.merge_card_faces, axis=1) 78 | df["produces_for_splash"] = df.apply(scry_utils.produce_for_splash, axis=1) 79 | df["name"] = df["name"].apply(lambda x: x.split("//")[0].strip().lower()) 80 | return df 81 | 82 | def union(self, cardset2): 83 | return self.cards | cardset2.cards 84 | 85 | def intersection(self, cardset2): 86 | return self.cards & cardset2.cards 87 | 88 | def difference(self, cardset2): 89 | return self.cards - cardset2.cards 90 | 91 | def simdiff(self, cardset2): 92 | return self.cards ^ cardset2.cards 93 | 94 | 95 | class Card: 96 | def __init__(self, *args, **kwargs): 97 | for dictionary in args: 98 | for key in dictionary: 99 | setattr(self, key, dictionary[key]) 100 | for key in kwargs: 101 | setattr(self, key, kwargs[key]) 102 | if hasattr(self, "name"): 103 | self.name = self.name.lower() 104 | 105 | self.colnames = { 106 | "deck": "deck_" + self.name, 107 | "hand": "opening_hand_" + self.name, 108 | "drawn": "drawn_" + self.name, 109 | "sideboard": "sideboard_" + self.name, 110 | } 111 | 112 | def __hash__(self): 113 | return hash(self.name) 114 | 115 | def __eq__(self, card2): 116 | return self.name == card2.name 117 | 118 | def __str__(self): 119 | return self.name 120 | 121 | def __repr__(self): 122 | return self.name 123 | -------------------------------------------------------------------------------- /mtg/scripts/train_drafter.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from mtg.ml.generator import DraftGenerator, create_train_and_val_gens 3 | from mtg.ml.models import DraftBot 4 | import pickle 5 | from mtg.ml.trainer import Trainer 6 | import tensorflow as tf 7 | 8 | 9 | def main(): 10 | with open(FLAGS.expansion_fname, "rb") as f: 11 | expansion = pickle.load(f) 12 | 13 | train_gen, val_gen = create_train_and_val_gens( 14 | expansion.draft, 15 | expansion.cards.copy(), 16 | train_p=FLAGS.train_p, 17 | id_col="draft_id", 18 | train_batch_size=FLAGS.batch_size, 19 | generator=DraftGenerator, 20 | include_val=True, 21 | ) 22 | 23 | model = DraftBot( 24 | expansion=expansion, 25 | emb_dim=FLAGS.emb_dim, 26 | num_encoder_heads=FLAGS.num_encoder_heads, 27 | num_decoder_heads=FLAGS.num_decoder_heads, 28 | pointwise_ffn_width=FLAGS.pointwise_ffn_width, 29 | num_encoder_layers=FLAGS.num_encoder_layers, 30 | num_decoder_layers=FLAGS.num_decoder_layers, 31 | emb_dropout=FLAGS.emb_dropout, 32 | memory_dropout=FLAGS.transformer_dropout, 33 | name="DraftBot", 34 | ) 35 | 36 | model.compile( 37 | learning_rate={"warmup_steps": FLAGS.lr_warmup}, 38 | margin=FLAGS.emb_margin, 39 | emb_lambda=FLAGS.emb_lambda, 40 | rare_lambda=FLAGS.rare_lambda, 41 | cmc_lambda=FLAGS.cmc_lambda, 42 | ) 43 | 44 | trainer = Trainer( 45 | model, 46 | generator=train_gen, 47 | val_generator=val_gen, 48 | ) 49 | trainer.train( 50 | FLAGS.epochs, 51 | print_keys=["prediction_loss", "embedding_loss", "rare_loss", "cmc_loss"], 52 | verbose=FLAGS.verbose, 53 | ) 54 | # we run inference once before saving the model in order to serialize it with the right input parameters for 
inference 55 | # and we do it with train_gen because val_gen can be None, and this isn't used for validation but serialization 56 | x, y, z = train_gen[0] 57 | (packs, shifted_picks, positions) = x 58 | model_input = ( 59 | tf.expand_dims(packs[0], 0), 60 | tf.expand_dims(shifted_picks[0], 0), 61 | tf.expand_dims(positions[0], 0), 62 | ) 63 | output, attention = model(model_input, training=False, return_attention=True) 64 | model.save(FLAGS.model_name) 65 | 66 | 67 | if __name__ == "__main__": 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument( 70 | "--expansion_fname", 71 | type=str, 72 | default="expansion.pkl", 73 | help="path/to/fname.pkl for where we should load the expansion object", 74 | ) 75 | parser.add_argument( 76 | "--batch_size", type=int, default=32, help="training batch size" 77 | ) 78 | parser.add_argument( 79 | "--train_p", type=float, default=1.0, help="number in [0,1] for train-val split" 80 | ) 81 | parser.add_argument( 82 | "--emb_dim", type=int, default=128, help="card embedding dimension" 83 | ) 84 | parser.add_argument( 85 | "--num_encoder_heads", 86 | type=int, 87 | default=8, 88 | help="number of heads in the encoder blocks of transformer", 89 | ) 90 | parser.add_argument( 91 | "--num_decoder_heads", 92 | type=int, 93 | default=8, 94 | help="number of heads in the decoder blocks of transformer", 95 | ) 96 | parser.add_argument( 97 | "--pointwise_ffn_width", 98 | type=int, 99 | default=512, 100 | help="each transformer block has a pointwise_ffn with this width as latent space", 101 | ) 102 | parser.add_argument( 103 | "--num_encoder_layers", 104 | type=int, 105 | default=2, 106 | help="number of transformer blocks for the encoder", 107 | ) 108 | parser.add_argument( 109 | "--num_decoder_layers", 110 | type=int, 111 | default=2, 112 | help="number of transformer blocks for the decoder", 113 | ) 114 | parser.add_argument( 115 | "--emb_dropout", 116 | type=float, 117 | default=0.0, 118 | help="dropout rate to apply to embeddings before passed to encoder", 119 | ) 120 | parser.add_argument( 121 | "--transformer_dropout", 122 | type=float, 123 | default=0.1, 124 | help="dropout rate inside each transformer block", 125 | ) 126 | parser.add_argument( 127 | "--lr_warmup", 128 | type=float, 129 | default=2000, 130 | help="number of warmup steps in the classic transformer learning rate scheduler", 131 | ) 132 | parser.add_argument( 133 | "--emb_margin", 134 | type=float, 135 | default=1.0, 136 | help="margin for triplet loss penalty on the embeddings", 137 | ) 138 | parser.add_argument( 139 | "--emb_lambda", 140 | type=float, 141 | default=0.5, 142 | help="regularization coefficient for triplet loss on embeddings", 143 | ) 144 | parser.add_argument( 145 | "--rare_lambda", 146 | type=float, 147 | default=10.0, 148 | help="regularization coefficient for penalizing the model for taking rares when human doesn't", 149 | ) 150 | parser.add_argument( 151 | "--cmc_lambda", 152 | type=float, 153 | default=0.1, 154 | help="regularization coefficient for penalizing the model for taking expensive cards when human doesn't", 155 | ) 156 | parser.add_argument( 157 | "--epochs", type=int, default=1, help="number of epochs to train the model" 158 | ) 159 | parser.add_argument( 160 | "--verbose", 161 | type=bool, 162 | default=True, 163 | help="If True, tqdm will display during training", 164 | ) 165 | parser.add_argument( 166 | "--model_name", 167 | type=str, 168 | default="draft_model", 169 | help="path/to/draft_model where the model will be stored", 170 | ) 171 | 172 | FLAGS, 
unparsed = parser.parse_known_args() 173 | main() 174 | -------------------------------------------------------------------------------- /mtg/obj/dataloading_utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import requests 3 | import re 4 | 5 | 6 | def load_data(filename, cards, name=None): 7 | if name == "draft": 8 | return load_draft_data(filename, cards) 9 | elif name == "bo1": 10 | return load_bo1_data(filename, cards) 11 | else: 12 | return pd.read_csv(filename) 13 | 14 | 15 | def sort_cols_by_card_idxs(df, card_col_prefixes, cards): 16 | # initialize columns to start with the non-card columns 17 | column_order = [ 18 | c 19 | for c in df.columns 20 | if not any([c.startswith(prefix) for prefix in card_col_prefixes]) 21 | ] 22 | card_names = cards.sort_values(by="idx", ascending=True)["name"].tolist() 23 | for prefix in card_col_prefixes: 24 | prefix_columns = [prefix + "_" + name for name in card_names] 25 | column_order += prefix_columns 26 | # reorder dataframe to abide by new column ordering 27 | # this is just so df[self.deck_cols].to_numpy() 28 | # yields a comparable matrix to df[self.sideboard_cols].to_numpy() 29 | df = df[column_order] 30 | return df 31 | 32 | 33 | def load_bo1_data(filename, cards): 34 | COLUMN_REGEXES = { 35 | re.compile(r"user_game_win_rate_bucket"): "float16", 36 | re.compile(r"rank"): "str", 37 | re.compile(r"draft_id"): "str", 38 | re.compile(r"draft_time"): "str", 39 | re.compile(r"expansion"): "str", 40 | re.compile(r"event_type"): "str", 41 | re.compile(r"deck_.*"): "int8", 42 | re.compile(r"sideboard_.*"): "int8", 43 | re.compile(r"drawn_.*"): "int8", 44 | re.compile(r"sideboard_.*"): "int8", 45 | re.compile(r"opening_hand_.*"): "int8", 46 | re.compile(r"on_play"): "int8", 47 | re.compile(r"won"): "int8", 48 | re.compile(r"num_turns"): "int8", 49 | re.compile(r"num_mulligans"): "int8", 50 | re.compile(r"opp_num_mulligans"): "int8", 51 | } 52 | col_names = pd.read_csv(filename, nrows=0).columns 53 | data_types = {} 54 | draft_cols = [] 55 | for c in col_names: 56 | if any( 57 | [ 58 | c.startswith(prefix) 59 | for prefix in ["sideboard_", "deck_", "drawn_", "opening_hand_"] 60 | ] 61 | ): 62 | draft_cols.append(c) 63 | for (r, t) in COLUMN_REGEXES.items(): 64 | if r.match(c): 65 | data_types[c] = t 66 | 67 | df = pd.read_csv( 68 | filename, 69 | dtype=data_types, 70 | usecols=[ 71 | "draft_id", 72 | "draft_time", 73 | "won", 74 | "user_game_win_rate_bucket", 75 | "rank", 76 | "on_play", 77 | "num_turns", 78 | "num_mulligans", 79 | "opp_num_mulligans" 80 | # ... 
81 | ] 82 | + draft_cols, 83 | ) 84 | rename_cols = { 85 | "user_game_win_rate_bucket": "user_win_rate_bucket", 86 | "draft_time": "date", 87 | } 88 | df.columns = [ 89 | x.lower() if x not in rename_cols else rename_cols[x] for x in df.columns 90 | ] 91 | df["won"] = df["won"].astype(float) 92 | df["date"] = pd.to_datetime(df["date"]) 93 | card_col_prefixes = ["deck", "opening_hand", "drawn", "sideboard"] 94 | df = sort_cols_by_card_idxs(df, card_col_prefixes, cards) 95 | return df 96 | 97 | 98 | def load_draft_data(filename, cards): 99 | COLUMN_REGEXES = { 100 | re.compile(r"user_game_win_rate_bucket"): "float16", 101 | re.compile(r"user_n_games_bucket"): "int8", 102 | re.compile(r"rank"): "str", 103 | re.compile(r"draft_id"): "str", 104 | re.compile(r"draft_time"): "str", 105 | re.compile(r"expansion"): "str", 106 | re.compile(r"event_type"): "str", 107 | re.compile(r"event_match_wins"): "int8", 108 | re.compile(r"event_match_losses"): "int8", 109 | re.compile(r"pack_number"): "int8", 110 | re.compile(r"pick_number"): "int8", 111 | re.compile(r"pick$"): "str", 112 | re.compile(r"pick_maindeck_rate"): "float16", 113 | re.compile(r"pick_sideboard_in_rate"): "float16", 114 | re.compile(r"pool_.*"): "int8", 115 | re.compile(r"pack_card_.*"): "int8", 116 | } 117 | col_names = pd.read_csv(filename, nrows=0).columns 118 | data_types = {} 119 | draft_cols = [] 120 | for c in col_names: 121 | if c.startswith("pack_card_"): 122 | draft_cols.append(c) 123 | elif c == "pick": 124 | draft_cols.append(c) 125 | elif c.startswith("pool_"): 126 | draft_cols.append(c) 127 | for (r, t) in COLUMN_REGEXES.items(): 128 | if r.match(c): 129 | data_types[c] = t 130 | 131 | df = pd.read_csv( 132 | filename, 133 | dtype=data_types, 134 | usecols=[ 135 | "draft_id", 136 | "draft_time", 137 | "event_match_losses", 138 | "event_match_wins", 139 | "pack_number", 140 | "pick_number", 141 | "user_n_games_bucket", 142 | "user_game_win_rate_bucket", 143 | "rank" 144 | # ... 
145 | ] 146 | + draft_cols, 147 | ) 148 | rename_cols = { 149 | "user_game_win_rate_bucket": "user_win_rate_bucket", 150 | "draft_time": "date", 151 | } 152 | df.columns = [ 153 | x.lower() if x not in rename_cols else rename_cols[x] for x in df.columns 154 | ] 155 | n_picks = df.groupby("draft_id")["pick"].count() 156 | t = n_picks.max() 157 | bad_draft_ids = n_picks[n_picks < t].index.tolist() 158 | df = df[~df["draft_id"].isin(bad_draft_ids)] 159 | df["pick"] = df["pick"].str.lower() 160 | df["date"] = pd.to_datetime(df["date"]) 161 | df["won"] = ( 162 | df["event_match_wins"] / (df["event_match_wins"] + df["event_match_losses"]) 163 | ).fillna(0.0) 164 | card_col_prefixes = ["pack_card", "pool"] 165 | df = sort_cols_by_card_idxs(df, card_col_prefixes, cards) 166 | df["position"] = ( 167 | df["pack_number"] * (df["pick_number"].max() + 1) + df["pick_number"] 168 | ) 169 | df = df.sort_values(by=["draft_id", "position"]) 170 | return df 171 | 172 | 173 | def get_card_rating_data(expansion, endpoint=None, start=None, end=None, colors=None): 174 | if endpoint is None: 175 | endpoint = f"https://www.17lands.com/card_ratings/data?expansion={expansion.upper()}&format=PremierDraft" 176 | if start is not None: 177 | endpoint += f"&start_date={start}" 178 | if end is not None: 179 | endpoint += f"&end_date={end}" 180 | if colors is not None: 181 | endpoint += f"&colors={colors}" 182 | card_json = requests.get(endpoint).json() 183 | card_df = pd.DataFrame(card_json).fillna(0.0) 184 | numerical_cols = card_df.columns[card_df.dtypes != object] 185 | card_df["name"] = card_df["name"].str.lower() 186 | card_df = card_df.set_index("name") 187 | return card_df[numerical_cols] 188 | 189 | 190 | def get_draft_json(draft_log_url, stream=False): 191 | if not stream: 192 | base_url = "https://www.17lands.com/data/draft?draft_id=" 193 | else: 194 | base_url = "https://www.17lands.com/data/draft/stream/?draft_id=" 195 | draft_ext = draft_log_url.split("/")[-1].strip() 196 | log_json_url = base_url + draft_ext 197 | response = requests.get(log_json_url, stream=stream) 198 | if not stream: 199 | response = response.json() 200 | return response 201 | -------------------------------------------------------------------------------- /mtg/ml/nn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from mtg.ml.layers import Dense, LayerNormalization, MultiHeadAttention 3 | 4 | 5 | class MLP(tf.Module): 6 | def __init__( 7 | self, 8 | in_dim, 9 | start_dim, 10 | out_dim, 11 | n_h_layers, 12 | dropout=0.0, 13 | noise=0.0, 14 | start_act=tf.nn.relu, 15 | middle_act=tf.nn.relu, 16 | out_act=tf.nn.relu, 17 | style="bottleneck", 18 | name=None, 19 | ): 20 | assert style in ["bottleneck", "flat", "reverse_bottleneck"] 21 | super().__init__(name=name) 22 | self.noise = noise 23 | self.dropout = dropout 24 | self.layers = [ 25 | Dense(in_dim, start_dim, activation=start_act, name=self.name + "_0") 26 | ] 27 | last_dim = start_dim 28 | for i in range(n_h_layers): 29 | if style == "bottleneck": 30 | dim = last_dim // 2 31 | elif style == "reverse_bottleneck": 32 | dim = last_dim * 2 33 | else: 34 | dim = last_dim 35 | self.layers.append( 36 | Dense( 37 | last_dim, 38 | dim, 39 | activation=middle_act, 40 | name=self.name + "_" + str(i + 1), 41 | ) 42 | ) 43 | last_dim = dim 44 | self.layers.append( 45 | Dense(last_dim, out_dim, activation=out_act, name=self.name + "_out") 46 | ) 47 | 48 | @tf.function 49 | def __call__(self, x, training=None): 50 | if self.noise > 
0.0 and training: 51 | x = tf.nn.dropout(x, rate=self.noise) 52 | for layer in self.layers: 53 | x = layer(x) 54 | if self.dropout > 0.0 and training: 55 | x = tf.nn.dropout(x, rate=self.dropout) 56 | return x 57 | 58 | 59 | class ConcatEmbedding(tf.Module): 60 | """ 61 | Lets say you want an embedding that is a concatenation of the abstract object and data about the object 62 | 63 | so we learn a normal one hot embedding, and then have an MLP process the data about the object and concatenate the two. 64 | """ 65 | 66 | def __init__( 67 | self, 68 | num_items, 69 | emb_dim, 70 | item_data, 71 | dropout=0.0, 72 | n_h_layers=1, 73 | initializer=tf.initializers.GlorotNormal(), 74 | name=None, 75 | activation=None, 76 | start_act=None, 77 | middle_act=None, 78 | out_act=None, 79 | ): 80 | super().__init__(name=name) 81 | assert item_data.shape[0] == num_items 82 | self.item_data = item_data 83 | self.item_MLP = MLP( 84 | in_dim=item_data.shape[1], 85 | start_dim=item_data.shape[1] // 2, 86 | out_dim=emb_dim // 2, 87 | n_h_layers=n_h_layers, 88 | dropout=dropout, 89 | name="item_data_mlp", 90 | start_act=start_act, 91 | middle_act=middle_act, 92 | out_act=out_act, 93 | style="bottleneck", 94 | ) 95 | self.embedding = tf.Variable( 96 | initializer(shape=(num_items, emb_dim // 2)), 97 | dtype=tf.float32, 98 | name=self.name + "_embedding", 99 | ) 100 | self.activation = activation 101 | 102 | @tf.function 103 | def __call__(self, x, training=None): 104 | item_embeddings = tf.gather(self.embedding, x) 105 | data_embeddings = tf.gather( 106 | self.item_MLP(self.item_data, training=training), 107 | x, 108 | ) 109 | embeddings = tf.concat([item_embeddings, data_embeddings], axis=-1) 110 | if self.activation is not None: 111 | embeddings = self.activation(embeddings) 112 | return embeddings 113 | 114 | 115 | class TransformerBlock(tf.Module): 116 | """ 117 | Transformer Block implementation. Rather than having a separate class for the encoder 118 | block and decoder block, instead there is a `decode` flag to determine if the extra 119 | processing step is necessary 120 | """ 121 | 122 | def __init__( 123 | self, 124 | emb_dim, 125 | num_heads, 126 | pointwise_ffn_width, 127 | dropout=0.0, 128 | decode=False, 129 | name=None, 130 | ): 131 | super().__init__(name=name) 132 | self.dropout = dropout 133 | self.decode = decode 134 | # kdim and dmodel are the same because the embedding dimension of the non-attended 135 | # embeddings are the same as the attention embeddings. 
136 | self.attention = MultiHeadAttention( 137 | emb_dim, emb_dim, num_heads, name=self.name + "_attention" 138 | ) 139 | self.expand_attention = Dense( 140 | emb_dim, 141 | pointwise_ffn_width, 142 | activation=tf.nn.relu, 143 | name=self.name + "_pointwise_in", 144 | ) 145 | self.compress_expansion = Dense( 146 | pointwise_ffn_width, 147 | emb_dim, 148 | activation=None, 149 | name=self.name + "_pointwise_out", 150 | ) 151 | self.final_layer_norm = LayerNormalization( 152 | emb_dim, name=self.name + "_out_norm" 153 | ) 154 | self.attention_layer_norm = LayerNormalization( 155 | emb_dim, name=self.name + "_attention_norm" 156 | ) 157 | if self.decode: 158 | self.decode_attention = MultiHeadAttention( 159 | emb_dim, emb_dim, num_heads, name=self.name + "_decode_attention" 160 | ) 161 | self.decode_layer_norm = LayerNormalization( 162 | emb_dim, name=self.name + "_decode_norm" 163 | ) 164 | 165 | def pointwise_fnn(self, x, training=None): 166 | x = self.expand_attention(x, training=training) 167 | return self.compress_expansion(x, training=training) 168 | 169 | @tf.function 170 | def __call__(self, x, mask, encoder_output=None, training=None): 171 | attention_emb, attention_weights = self.attention( 172 | x, x, x, mask, training=training 173 | ) 174 | if training and self.dropout > 0: 175 | attention_emb = tf.nn.dropout(attention_emb, rate=self.dropout) 176 | residual_emb_w_memory = self.attention_layer_norm( 177 | x + attention_emb, training=training 178 | ) 179 | if self.decode: 180 | assert encoder_output is not None 181 | decode_attention_emb, decode_attention_weights = self.decode_attention( 182 | encoder_output, 183 | encoder_output, 184 | residual_emb_w_memory, 185 | mask, 186 | training=training, 187 | ) 188 | if training and self.dropout > 0: 189 | decode_attention_emb = tf.nn.dropout( 190 | decode_attention_emb, rate=self.dropout 191 | ) 192 | residual_emb_w_memory = self.decode_layer_norm( 193 | residual_emb_w_memory + decode_attention_emb, training=training 194 | ) 195 | attention_weights = (attention_weights, decode_attention_weights) 196 | process_emb = self.pointwise_fnn(residual_emb_w_memory, training=training) 197 | if training and self.dropout > 0: 198 | process_emb = tf.nn.dropout(process_emb, rate=self.dropout) 199 | return ( 200 | self.final_layer_norm( 201 | residual_emb_w_memory + process_emb, training=training 202 | ), 203 | attention_weights, 204 | ) 205 | -------------------------------------------------------------------------------- /mtg/ml/trainer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import sys 3 | from tqdm.auto import tqdm 4 | import numpy as np 5 | 6 | 7 | class Trainer: 8 | def __init__( 9 | self, 10 | model, 11 | generator=None, 12 | val_generator=None, 13 | features=None, 14 | target=None, 15 | weights=None, 16 | val_features=None, 17 | val_target=None, 18 | val_weights=None, 19 | clip=5.0, 20 | loss_agg_f=lambda x: np.sum(x), 21 | ): 22 | self.generator = generator 23 | self.val_generator = val_generator 24 | self.features = features 25 | self.target = target 26 | self.model = model 27 | self.epoch_n = 0 28 | self.clip = clip 29 | self.loss_agg_f = loss_agg_f 30 | if self.target is not None: 31 | self.batch_ids = np.arange(len(self.target)) 32 | else: 33 | self.batch_ids = None 34 | self.weights = weights 35 | self.val_features = val_features 36 | self.val_target = val_target 37 | self.val_weights = val_weights 38 | 39 | if self.generator is not None: 40 | assert self.features 
is None 41 | assert self.target is None 42 | assert self.weights is None 43 | else: 44 | assert self.features is not None 45 | assert self.target is not None 46 | 47 | if self.val_generator is not None: 48 | assert self.generator is not None 49 | assert self.val_features is None 50 | assert self.val_target is None 51 | assert self.val_weights is None 52 | 53 | def _step( 54 | self, batch_features, batch_target, batch_weights, only_val_metrics=False 55 | ): 56 | with tf.GradientTape() as tape: 57 | output = self.model(batch_features, training=True) 58 | loss = self.model.loss( 59 | batch_target, output, sample_weight=batch_weights, training=True 60 | ) 61 | # if NaN loss ever happens, raise an error inside the gradient tape so that 62 | # pdb can be used for live debugging with access to the gradients 63 | if tf.math.is_nan(loss): 64 | raise ValueError("NaN Loss!!") 65 | if len(self.model.metric_names) > 0 and not only_val_metrics: 66 | metrics = self.model.compute_metrics( 67 | batch_target, output, sample_weight=batch_weights, training=True 68 | ) 69 | else: 70 | metrics = dict() 71 | grads = tape.gradient(loss, self.model.trainable_variables) 72 | if self.clip: 73 | grads, _ = tf.clip_by_global_norm(grads, self.clip) 74 | self.model.optimizer.apply_gradients(zip(grads, self.model.trainable_variables)) 75 | return loss, metrics 76 | 77 | def train( 78 | self, 79 | n_epochs, 80 | batch_size=32, 81 | verbose=True, 82 | print_keys=[], 83 | only_val_metrics=False, 84 | ): 85 | n_batches = ( 86 | len(self.batch_ids) // batch_size 87 | if self.generator is None 88 | else len(self.generator) 89 | ) 90 | end_at = self.epoch_n + n_epochs 91 | has_val = self.val_generator is not None or self.val_features is not None 92 | extra_metric_keys = self.model.metric_names[:] 93 | if has_val: 94 | if only_val_metrics: 95 | extra_metric_keys = [ 96 | "val_" + metric_key for metric_key in extra_metric_keys 97 | ] 98 | else: 99 | extra_metric_keys += [ 100 | "val_" + metric_key for metric_key in extra_metric_keys 101 | ] 102 | for _ in range(n_epochs): 103 | self.epoch_n += 1 104 | if self.batch_ids is not None: 105 | np.random.shuffle(self.batch_ids) 106 | if verbose: 107 | progress = tqdm( 108 | total=n_batches, desc=f"Epoch {self.epoch_n}/{end_at}", unit="Batch" 109 | ) 110 | extras = {k: [] for k in print_keys} 111 | losses = [] 112 | val_losses = [] 113 | extra_metrics = {k: [] for k in extra_metric_keys} 114 | for i in range(n_batches): 115 | val_loss = None 116 | if self.generator is None: 117 | batch_idx = self.batch_ids[i * batch_size : (i + 1) * batch_size] 118 | batch_features = self.features[batch_idx, :] 119 | batch_target = self.target[batch_idx, :] 120 | if self.weights is not None: 121 | batch_weights = self.weights[batch_idx] 122 | batch_weights = batch_weights / batch_weights.sum() 123 | else: 124 | batch_weights = None 125 | else: 126 | batch_features, batch_target, batch_weights = self.generator[i] 127 | loss, metrics = self._step( 128 | batch_features, 129 | batch_target, 130 | batch_weights, 131 | only_val_metrics=only_val_metrics, 132 | ) 133 | for m_key, m_val in metrics.items(): 134 | if len(m_val.shape) > 1: 135 | m_val = self.loss_agg_f(m_val) 136 | extra_metrics[m_key].append(m_val) 137 | losses.append(self.loss_agg_f(loss)) 138 | for attr_name in extras.keys(): 139 | attr = getattr(self.model, attr_name, None) 140 | if len(attr.shape) > 1: 141 | attr = self.loss_agg_f(attr) 142 | extras[attr_name].append(attr) 143 | 144 | if self.val_generator is not None: 145 | val_features, 
val_target, val_weights = self.val_generator[i] 146 | # must get attention here to serialize the input for saving 147 | val_output = self.model(val_features, training=False) 148 | val_loss = self.model.loss( 149 | val_target, 150 | val_output, 151 | sample_weight=val_weights, 152 | training=False, 153 | ) 154 | if len(self.model.metric_names) > 0: 155 | val_metrics = self.model.compute_metrics( 156 | val_target, 157 | val_output, 158 | sample_weight=val_weights, 159 | training=False, 160 | ) 161 | else: 162 | val_metrics = dict() 163 | for m_key, m_val in val_metrics.items(): 164 | if len(m_val.shape) > 1: 165 | m_val = self.loss_agg_f(m_val) 166 | extra_metrics["val_" + m_key].append(m_val) 167 | val_losses.append(self.loss_agg_f(val_loss)) 168 | if verbose: 169 | extra_to_show = { 170 | **{k: np.average(v) for k, v in extras.items()}, 171 | **{k: np.average(v) for k, v in extra_metrics.items()}, 172 | } 173 | if len(val_losses) > 0: 174 | progress.set_postfix( 175 | loss=np.average(losses), 176 | val_loss=np.average(val_losses), 177 | **extra_to_show, 178 | ) 179 | else: 180 | progress.set_postfix(loss=np.average(losses), **extra_to_show) 181 | progress.update(1) 182 | if verbose: 183 | # run model as if not training on validation data to get out of sample performance 184 | if self.val_features is not None: 185 | val_out = self.model(self.val_features, training=False) 186 | val_loss = self.model.loss( 187 | self.val_target, 188 | val_out, 189 | sample_weight=self.val_weights, 190 | training=False, 191 | ) 192 | progress.set_postfix( 193 | loss=np.average(losses), 194 | val_loss=self.loss_agg_f(val_loss), 195 | **extra_to_show, 196 | ) 197 | progress.close() 198 | if self.generator is not None: 199 | self.generator.on_epoch_end() 200 | if self.val_generator is not None: 201 | self.val_generator.on_epoch_end() 202 | -------------------------------------------------------------------------------- /mtg/ml/layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | class Dense(tf.Module): 6 | def __init__( 7 | self, 8 | in_dim, 9 | out_dim, 10 | name=None, 11 | initializer=tf.initializers.GlorotNormal(), 12 | activation=tf.nn.relu, 13 | use_bias=True, 14 | ): 15 | super().__init__(name=name) 16 | 17 | self.activation = activation 18 | self.use_bias = use_bias 19 | 20 | self.w = tf.Variable( 21 | initializer([in_dim, out_dim]), 22 | dtype=tf.float32, 23 | name=self.name + "_w", 24 | trainable=True, 25 | ) 26 | if self.use_bias: 27 | self.b = tf.Variable( 28 | tf.zeros([out_dim]), 29 | dtype=tf.float32, 30 | trainable=True, 31 | name=self.name + "_b", 32 | ) 33 | 34 | @tf.function 35 | def __call__(self, x, training=None): 36 | rank = x.shape.rank 37 | if rank == 2 or rank is None: 38 | y = tf.matmul(x, self.w) 39 | else: 40 | y = tf.tensordot(x, self.w, [[rank - 1], [0]]) 41 | if not tf.executing_eagerly(): 42 | shape = x.shape.as_list() 43 | output_shape = shape[:-1] + [self.w.shape[-1]] 44 | y.set_shape(output_shape) 45 | if self.use_bias: 46 | y = tf.nn.bias_add(y, self.b) 47 | if self.activation is not None: 48 | y = self.activation(y) 49 | return y 50 | 51 | 52 | class LayerNormalization(tf.Module): 53 | def __init__( 54 | self, 55 | last_dim, 56 | epsilon=1e-6, 57 | center=True, 58 | scale=True, 59 | name=None, 60 | ): 61 | super().__init__(name=name) 62 | self.center = center 63 | self.epsilon = epsilon 64 | # current implementation can only normalize off last axis 65 | self.axis = -1 66 | if 
scale: 67 | self.gamma = tf.Variable( 68 | tf.ones(last_dim), 69 | dtype=tf.float32, 70 | trainable=True, 71 | name=self.name + "_gamma", 72 | ) 73 | else: 74 | self.gamma = None 75 | if center: 76 | self.beta = tf.Variable( 77 | tf.zeros(last_dim), 78 | dtype=tf.float32, 79 | trainable=True, 80 | name=self.name + "_beta", 81 | ) 82 | else: 83 | self.beta = None 84 | 85 | def __call__(self, x, training=None): 86 | mu, sigma = tf.nn.moments(x, self.axis, keepdims=True) 87 | # Compute layer normalization using the batch_normalization function. 88 | outputs = tf.nn.batch_normalization( 89 | x, 90 | mu, 91 | sigma, 92 | offset=self.beta, 93 | scale=self.gamma, 94 | variance_epsilon=self.epsilon, 95 | ) 96 | # If some components of the shape got lost due to adjustments, fix that. 97 | outputs.set_shape(x.shape) 98 | return outputs 99 | 100 | 101 | class MultiHeadAttention(tf.Module): 102 | """ 103 | tf implementation of multi-headed attention. 104 | 105 | d_model is the final dimension for the embedding representation post-attention 106 | 107 | num_heads is the number of contextual ways to look at the information 108 | 109 | k_dim will be equal to the number of time steps in a draft 110 | (e.g. 45), and the mask will prevent lookahead (e.g. the mask for P1P3 will look 111 | like [0, 0, 0, 1, 1, . . ., 1]), meaning that only information at P1P1, P1P2, and 112 | P1P3 can be used to make a prediction. 113 | 114 | implementation guided via: https://www.tensorflow.org/text/tutorials/transformer 115 | """ 116 | 117 | def __init__(self, d_model, k_dim, num_heads, v_dim=None, name=None): 118 | super().__init__(name=name) 119 | self.num_heads = num_heads 120 | self.d_model = d_model 121 | 122 | assert d_model % self.num_heads == 0 123 | 124 | self.depth = d_model // self.num_heads 125 | v_dim = k_dim if v_dim is None else v_dim 126 | 127 | self.wq = Dense(k_dim, d_model, activation=None, name=self.name + "_wq") 128 | self.wk = Dense(k_dim, d_model, activation=None, name=self.name + "_wk") 129 | self.wv = Dense(v_dim, d_model, activation=None, name=self.name + "_wv") 130 | 131 | self.dense = Dense(k_dim, d_model, activation=None, name=self.name + "_attout") 132 | 133 | def split_heads(self, x, batch_size): 134 | """Split the last dimension into (num_heads, depth). 
135 | Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth) 136 | """ 137 | x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth)) 138 | return tf.transpose(x, perm=[0, 2, 1, 3]) 139 | 140 | @tf.function 141 | def __call__(self, v, k, q, mask, training=None): 142 | batch_size = tf.shape(q)[0] 143 | 144 | q = self.wq(q, training=training) # (batch_size, seq_len, d_model) 145 | k = self.wk(k, training=training) # (batch_size, seq_len, d_model) 146 | v = self.wv(v, training=training) # (batch_size, seq_len, d_model) 147 | 148 | q = self.split_heads(q, batch_size) # (batch_size, num_heads, seq_len_q, depth) 149 | k = self.split_heads(k, batch_size) # (batch_size, num_heads, seq_len_k, depth) 150 | v = self.split_heads(v, batch_size) # (batch_size, num_heads, seq_len_v, depth) 151 | 152 | # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth) 153 | # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k) 154 | scaled_attention, attention_weights = self.scaled_dot_product_attention( 155 | q, k, v, mask 156 | ) 157 | 158 | scaled_attention = tf.transpose( 159 | scaled_attention, perm=[0, 2, 1, 3] 160 | ) # (batch_size, seq_len_q, num_heads, depth) 161 | 162 | concat_attention = tf.reshape( 163 | scaled_attention, (batch_size, -1, self.d_model) 164 | ) # (batch_size, seq_len_q, d_model) 165 | 166 | output = self.dense( 167 | concat_attention, training=training 168 | ) # (batch_size, seq_len_q, d_model) 169 | 170 | return output, attention_weights 171 | 172 | def scaled_dot_product_attention(self, q, k, v, mask): 173 | """Calculate the attention weights. 174 | q, k, v must have matching leading dimensions. 175 | k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v. 176 | The mask has different shapes depending on its type(padding or look ahead) 177 | but it must be broadcastable for addition. 178 | 179 | Args: 180 | q: query shape == (..., seq_len_q, depth) 181 | k: key shape == (..., seq_len_k, depth) 182 | v: value shape == (..., seq_len_v, depth_v) 183 | mask: Float tensor with shape broadcastable 184 | to (..., seq_len_q, seq_len_k). Defaults to None. 185 | 186 | Returns: 187 | output, attention_weights 188 | """ 189 | 190 | matmul_qk = tf.matmul(q, k, transpose_b=True) # (..., seq_len_q, seq_len_k) 191 | 192 | # scale matmul_qk 193 | dk = tf.cast(tf.shape(k)[-1], tf.float32) 194 | scaled_attention_logits = matmul_qk / tf.math.sqrt(dk) 195 | 196 | # add the mask to the scaled tensor. 197 | if mask is not None: 198 | # expand mask dimension to allow for addition on all attention heads 199 | scaled_attention_logits += tf.expand_dims(mask, 1) * -1e9 200 | 201 | # softmax is normalized on the last axis (seq_len_k) so that the scores 202 | # add up to 1. 
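# in other words, each head computes softmax(Q K^T / sqrt(d_k) + mask * -1e9) V,
# so any position the mask marks as disallowed receives ~0 attention weight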
203 | attention_weights = tf.nn.softmax( 204 | scaled_attention_logits, axis=-1 205 | ) # (..., seq_len_q, seq_len_k) 206 | 207 | output = tf.matmul(attention_weights, v) # (..., seq_len_q, depth_v) 208 | 209 | return output, attention_weights 210 | 211 | 212 | class Embedding(tf.Module): 213 | def __init__( 214 | self, 215 | num_items, 216 | emb_dim, 217 | initializer=tf.initializers.GlorotNormal(), 218 | name=None, 219 | activation=None, 220 | ): 221 | super().__init__(name=name) 222 | self.embedding = tf.Variable( 223 | initializer(shape=(num_items, emb_dim)), 224 | dtype=tf.float32, 225 | name=self.name + "_embedding", 226 | ) 227 | self.activation = activation 228 | 229 | @tf.function 230 | def __call__(self, x, training=None): 231 | embeddings = tf.gather(self.embedding, x) 232 | if self.activation is not None: 233 | embeddings = self.activation(embeddings) 234 | return embeddings 235 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /mtg/ml/generator.py: -------------------------------------------------------------------------------- 1 | import gc 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import tensorflow as tf 6 | from mtg.ml.utils import importance_weighting 7 | from tensorflow.keras.utils import Sequence 8 | 9 | 10 | class MTGDataGenerator(Sequence): 11 | def __init__( 12 | self, 13 | data, 14 | cards, 15 | card_col_prefixes, 16 | batch_size=32, 17 | shuffle=True, 18 | to_fit=True, 19 | exclude_basics=True, 20 | store_basics=False, 21 | ): 22 | self.cards = cards.sort_values(by="idx", ascending=True) 23 | self.card_col_prefixes = card_col_prefixes 24 | self.exclude_basics = exclude_basics 25 | self.store_basics = store_basics 26 | if self.exclude_basics: 27 | self.cards = self.cards.iloc[5:, :] 28 | self.cards["idx"] = self.cards["idx"] - 5 29 | self.batch_size = batch_size 30 | self.shuffle = shuffle 31 | self.to_fit = to_fit 32 | self.n_cards = self.cards.shape[0] 33 | self.generate_global_data(data) 34 | self.size = data.shape[0] 35 | # generate initial indices for batching the data 36 | self.reset_indices() 37 | 38 | def __len__(self): 39 | """ 40 | return: number of batches per epoch 41 | """ 42 | return self.size // self.batch_size 43 | 44 | def reset_indices(self): 45 | self.indices = np.arange(self.size) 46 | if self.shuffle == True: 47 | np.random.shuffle(self.indices) 48 | 49 | def on_epoch_end(self): 50 | """ 51 | Update indices after each epoch 52 | """ 53 | self.reset_indices() 54 | gc.collect() 55 | 56 | def card_name_to_idx(self, card_name, exclude_basics=True): 57 | return self.cards[self.cards["name"] == card_name]["idx"].iloc[0] 58 | 59 | def card_idx_to_name(self, card_idx, exclude_basics=True): 60 | return self.cards[self.cards["idx"] == card_idx]["name"].iloc[0] 61 | 62 | def generate_global_data(self, data): 63 | self.all_cards = [col.split("_", 1)[-1] for col in data.columns if col.startswith(self.card_col_prefixes[0])] 64 | basics = ["plains", "island", "swamp", "mountain", "forest"] 65 | if self.exclude_basics: 66 | exclude_cards = basics 67 | else: 68 | exclude_cards = [] 69 | for prefix in self.card_col_prefixes: 70 | prefix_size = len(prefix + "_") 71 | cols = [ 72 | col 73 | for col in data.columns 74 | if col.startswith(prefix + "_") and not any([x == col[prefix_size:] for x in exclude_cards]) 75 | ] 76 | setattr(self, prefix, data[cols].values) 77 | if self.store_basics: 78 | basic_cols = [col for col in data.columns if any([x == col[prefix_size:] for x in basics])] 79 | setattr(self, prefix + "_basics", data[basic_cols].values) 80 | if "ml_weights" in data.columns: 81 | self.weights = data["ml_weights"].values 82 | else: 83 | self.weights = None 84 | 85 | def __getitem__(self, batch_number): 86 | """ 87 | Generates a data mini-batch 88 | param batch_number: which batch to generate 89 | return: X and y when fitting. 
X only when predicting 90 | """ 91 | indices = self.indices[batch_number * self.batch_size : (batch_number + 1) * self.batch_size] 92 | X, y, weights = self.generate_data(indices) 93 | 94 | if self.to_fit: 95 | return X, y, weights 96 | else: 97 | return X 98 | 99 | def generate_data(self, indices): 100 | raise NotImplementedError 101 | 102 | 103 | class DraftGenerator(MTGDataGenerator): 104 | def __init__( 105 | self, 106 | data, 107 | cards, 108 | batch_size=32, 109 | shuffle=True, 110 | to_fit=True, 111 | exclude_basics=True, 112 | store_basics=False, 113 | ): 114 | super().__init__( 115 | data, 116 | cards, 117 | card_col_prefixes=["pack_card"], 118 | batch_size=batch_size, 119 | shuffle=shuffle, 120 | to_fit=to_fit, 121 | exclude_basics=exclude_basics, 122 | store_basics=store_basics, 123 | ) 124 | # overwrite the size to make sure we always sample full drafts 125 | self.size = len(self.draft_ids) 126 | self.reset_indices() 127 | 128 | def generate_global_data(self, data): 129 | self.draft_ids = data["draft_id"].unique() 130 | self.t = data["position"].max() + 1 131 | data = data.set_index(["draft_id", "position"]) 132 | # NOTE: the next chunk of lines is close to identical to the super() call of this 133 | # function. There is a difference in accessing .values instead of the dataframe 134 | # directly. In the future, clean this up such that it can just call super 135 | self.all_cards = [col.split("_", 1)[-1] for col in data.columns if col.startswith(self.card_col_prefixes[0])] 136 | basics = ["plains", "island", "swamp", "mountain", "forest"] 137 | if self.exclude_basics: 138 | exclude_cards = basics 139 | else: 140 | exclude_cards = [] 141 | for prefix in self.card_col_prefixes: 142 | prefix_size = len(prefix + "_") 143 | cols = [ 144 | col 145 | for col in data.columns 146 | if col.startswith(prefix + "_") and not any([x == col[prefix_size:] for x in exclude_cards]) 147 | ] 148 | setattr(self, prefix, data[cols]) 149 | if self.store_basics: 150 | basic_cols = [col for col in data.columns if any([x == col[prefix_size:] for x in basics])] 151 | setattr(self, prefix + "_basics", data[basic_cols]) 152 | if "ml_weights" in data.columns: 153 | self.weights = data["ml_weights"] 154 | else: 155 | self.weights = None 156 | name_to_idx_mapping = { 157 | k.split("//")[0].strip().lower(): v for k, v in self.cards.set_index("name")["idx"].to_dict().items() 158 | } 159 | self.pick = data["pick"].apply(lambda x: name_to_idx_mapping[x]) 160 | self.shifted_pick = self.pick.groupby(level=0).shift(1).fillna(self.n_cards) 161 | self.position = data["pack_number"] * (data["pick_number"].max() + 1) + data["pick_number"] 162 | 163 | def generate_data(self, indices): 164 | draft_ids = self.draft_ids[indices] 165 | packs = self.pack_card.loc[draft_ids].values.reshape(len(indices), self.t, len(self.pack_card.columns)) 166 | # pools = self.pool.loc[draft_ids].values.reshape(len(indices), self.t, len(self.pack_card.columns)) 167 | picks = self.pick.loc[draft_ids].values.reshape(len(indices), self.t) 168 | shifted_picks = self.shifted_pick.loc[draft_ids].values.reshape(len(indices), self.t) 169 | positions = self.position.loc[draft_ids].values.reshape(len(indices), self.t) 170 | # draft_info = np.concatenate([packs, pools], axis=-1) 171 | if self.weights is not None: 172 | # comment below is if weights sum to 1 for each draft rather than for each batch 173 | # weights = (self.weights.loc[draft_ids]/self.weights.loc[draft_ids].groupby(level=0).sum()).values.reshape(len(indices), self.t) 174 | weights = 
(self.weights.loc[draft_ids] / self.weights.loc[draft_ids].sum()).values.reshape( 175 | len(indices), self.t 176 | ) 177 | else: 178 | weights = None 179 | # convert to tensor needed for #tf.function 180 | packs = tf.convert_to_tensor(packs.astype(np.float32), dtype=tf.float32) 181 | positions = tf.convert_to_tensor(positions.astype(np.int32), dtype=tf.int32) 182 | picks = tf.convert_to_tensor(picks.astype(np.float32), dtype=tf.int32) 183 | shifted_picks = tf.convert_to_tensor(shifted_picks.astype(np.float32), dtype=tf.int32) 184 | return (packs, shifted_picks, positions), picks, weights 185 | 186 | 187 | class DeckGenerator(MTGDataGenerator): 188 | def __init__( 189 | self, 190 | data, 191 | cards, 192 | batch_size=32, 193 | shuffle=True, 194 | to_fit=True, 195 | exclude_basics=True, 196 | store_basics=True, 197 | pos_neg_sample=False, 198 | mask_decks=False, 199 | ): 200 | super().__init__( 201 | data, 202 | cards, 203 | card_col_prefixes=["deck", "sideboard"], 204 | batch_size=batch_size, 205 | shuffle=shuffle, 206 | to_fit=to_fit, 207 | exclude_basics=exclude_basics, 208 | store_basics=store_basics, 209 | ) 210 | self.pos_neg_sample = pos_neg_sample 211 | self.mask_decks = mask_decks 212 | 213 | def generate_data(self, indices): 214 | decks = self.deck[indices, :] 215 | sideboards = self.sideboard[indices, :] 216 | basics = self.deck_basics[indices, :] 217 | if self.mask_decks: 218 | max_n_non_basics = np.max(decks.sum(axis=1)) 219 | n = max_n_non_basics + 2 220 | basics = np.repeat(basics[:, None, :], n, axis=1) 221 | masked_decks = self.create_masked_objects(decks, n=n) 222 | # this is set up so the first element in masked decks has an empty 223 | # deck to predict from the whole pool, and the last element has a fully 224 | # built deck where the only thing to predict is the basics 225 | masked_decks[:, -1, :] = decks 226 | masked_decks = masked_decks.astype(np.float32) 227 | cards_to_add = (decks[:, None, :] - masked_decks).astype(np.float32) 228 | modified_sideboards = (sideboards[:, None, :] + cards_to_add).astype(np.float32) 229 | X = (modified_sideboards, masked_decks) 230 | Y = (basics.astype(np.float32), cards_to_add) 231 | else: 232 | X = (decks + sideboards).astype(np.float32) 233 | Y = (basics.astype(np.float32), decks.astype(np.float32)) 234 | if self.weights is not None: 235 | if self.mask_decks: 236 | weights = self.weights[indices][:, None] * np.ones((len(indices), n)) 237 | else: 238 | weights = self.weights[indices] 239 | weights = weights / weights.sum() 240 | else: 241 | weights = None 242 | if self.pos_neg_sample: 243 | anchor, pos, neg = self.sample_card_pairs(decks, sideboards) 244 | return (*X, anchor, pos, neg), Y, weights 245 | return X, Y, weights 246 | 247 | def create_masked_objects(self, decks, n): 248 | masked_decks = np.zeros((decks.shape[0], n, decks.shape[1])) 249 | for i in range(1, n): 250 | masked_decks[:, i, :] = self.get_vectorized_sample(decks.copy(), n=i, uniform=True) 251 | return masked_decks 252 | 253 | def get_vectorized_sample(self, mtx, n=1, uniform=True, return_mtx=True, modify_mtx=True): 254 | if uniform: 255 | clip_mtx = np.clip(mtx, 0, 1) 256 | probabilities = clip_mtx / (clip_mtx.sum(1, keepdims=True) + 1e-9) 257 | else: 258 | probabilities = mtx / (mtx.sum(1, keepdims=True) + 1e-9) 259 | live_idxs = np.where(mtx.sum(1) != 0) 260 | cumulative_dist = probabilities.cumsum(axis=1) 261 | random_bin = np.random.rand(len(cumulative_dist), 1) 262 | sample = (random_bin < cumulative_dist).argmax(axis=1) 263 | if modify_mtx: 264 | 
mtx[live_idxs, sample[live_idxs]] -= 1 265 | if n > 1: 266 | cts_sample = self.get_vectorized_sample(mtx, n=n - 1, uniform=uniform, return_mtx=False) 267 | if len(cts_sample.shape) == 1: 268 | cts_sample = np.expand_dims(cts_sample, 1) 269 | sample = np.concatenate([sample[:, None], cts_sample], axis=1) 270 | if return_mtx: 271 | return mtx 272 | return sample 273 | 274 | def sample_card_pairs(self, decks, sideboards): 275 | anchors = self.get_vectorized_sample(decks, uniform=False, return_mtx=False, modify_mtx=False) 276 | 277 | # never sample the same card as the anchor as the positive or negative example 278 | decks_without_anchors = decks.copy() 279 | decks_without_anchors[np.arange(decks.shape[0]), anchors] = 0 280 | sideboards_without_anchors = sideboards.copy() 281 | sideboards_without_anchors[np.arange(decks.shape[0]), anchors] = 0 282 | 283 | positive_samples = self.get_vectorized_sample( 284 | decks_without_anchors, uniform=False, return_mtx=False, modify_mtx=False 285 | ) 286 | negative_samples = self.get_vectorized_sample( 287 | sideboards_without_anchors, 288 | uniform=False, 289 | return_mtx=False, 290 | modify_mtx=False, 291 | ) 292 | 293 | return anchors, positive_samples, negative_samples 294 | 295 | 296 | def create_train_and_val_gens( 297 | data, 298 | cards, 299 | id_col=None, 300 | train_p=1.0, 301 | weights=True, 302 | train_batch_size=32, 303 | shuffle=True, 304 | to_fit=True, 305 | exclude_basics=True, 306 | generator=MTGDataGenerator, 307 | include_val=True, 308 | **kwargs, 309 | ): 310 | if weights and "ml_weights" not in data.columns: 311 | data["ml_weights"] = importance_weighting(data) 312 | if train_p < 1.0: 313 | if id_col is None: 314 | idxs = np.arange(data.shape[0]) 315 | train_idxs = np.random.choice(idxs, int(len(idxs) * train_p), replace=False) 316 | test_idxs = np.asarray(list(set(idxs.flatten()) - set(train_idxs.flatten()))) 317 | train_data, test_data = data.iloc[train_idxs], data.iloc[test_idxs] 318 | n_test = len(test_idxs) 319 | else: 320 | idxs = data[id_col].unique() 321 | train_idxs = np.random.choice(idxs, int(len(idxs) * train_p), replace=False) 322 | train_data = data[data[id_col].isin(train_idxs)] 323 | test_data = data[~data[id_col].isin(train_idxs)] 324 | n_train = int(len(idxs) * train_p) 325 | n_test = len(idxs) - n_train 326 | else: 327 | train_data = data 328 | test_data = None 329 | train_gen = generator( 330 | train_data, 331 | cards.copy(), 332 | batch_size=train_batch_size, 333 | shuffle=shuffle, 334 | to_fit=to_fit, 335 | exclude_basics=exclude_basics, 336 | **kwargs, 337 | ) 338 | if test_data is not None and include_val: 339 | n_train_batches = len(train_gen) 340 | val_batch_size = n_test // n_train_batches 341 | val_gen = generator( 342 | test_data, 343 | cards.copy(), 344 | batch_size=val_batch_size, 345 | shuffle=shuffle, 346 | to_fit=to_fit, 347 | exclude_basics=exclude_basics, 348 | **kwargs, 349 | ) 350 | else: 351 | val_gen = None 352 | return train_gen, val_gen 353 | -------------------------------------------------------------------------------- /mtg/obj/expansion.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from mtg.obj.cards import CardSet 7 | from mtg.obj.dataloading_utils import get_card_rating_data, load_data 8 | 9 | 10 | class Expansion: 11 | def __init__( 12 | self, 13 | expansion, 14 | bo1=None, 15 | bo3=None, 16 | quick=None, 17 | draft=None, 18 | replay=None, 19 | ml_data=True, 20 | idx_to_name=None, 21 | ):
22 | self.expansion = expansion 23 | self.cards = self.get_cards_from_scryfall() 24 | self.clean_card_df(idx_to_name) 25 | self.bo1 = self.process_data(bo1, name="bo1") 26 | self.bo3 = self.process_data(bo3, name="bo3") 27 | self.quick = self.process_data(quick, name="quick") 28 | self.draft = self.process_data(draft, name="draft") 29 | self.replay = self.process_data(replay, name="replay") 30 | if ml_data: 31 | self.card_data_for_ML = self.get_card_data_for_ML() 32 | else: 33 | self.card_data_for_ML = None 34 | self.create_data_dependent_attributes() 35 | 36 | def get_cards_from_scryfall(self): 37 | return CardSet([f"set={self.expansion}", "is:booster"]).to_dataframe() 38 | 39 | @property 40 | def types(self): 41 | return [ 42 | "instant", 43 | "sorcery", 44 | "creature", 45 | "planeswalker", 46 | "artifact", 47 | "enchantment", 48 | "land", 49 | ] 50 | 51 | def process_data(self, file_or_df, name=None): 52 | if isinstance(file_or_df, str): 53 | if name is None: 54 | df = pd.read_csv(file_or_df) 55 | else: 56 | df = load_data(file_or_df, self.cards.copy(), name=name) 57 | else: 58 | df = file_or_df 59 | return df 60 | 61 | def clean_card_df(self, idx_to_name=None): 62 | if idx_to_name is not None: 63 | if "plains" not in idx_to_name.keys(): 64 | idx_to_name = {k + 5: v for k, v in idx_to_name.items()} 65 | basics = ["plains", "island", "swamp", "mountain", "forest"] 66 | for basic_idx, basic in enumerate(basics): 67 | idx_to_name[basic_idx] = basic 68 | name_to_idx = {v: k for k, v in idx_to_name.items()} 69 | self.cards["idx"] = self.cards["name"].apply(lambda x: name_to_idx[x]) 70 | # set it so ramp spells that search for basics are seen as rainbow producers 71 | # logic to subset by basic implemented where needed 72 | search_check = lambda x: "search your library" in x["oracle_text"].lower() 73 | basic_check = lambda x: "basic land" in x["oracle_text"].lower() 74 | self.cards["basic_land_search"] = self.cards.apply(lambda x: search_check(x) and basic_check(x), axis=1) 75 | # TODO: at the moment, flip cards are any non-normal cards. 
Consider 76 | # ways to handle other layouts like split cards too 77 | self.cards["flip"] = self.cards["layout"].apply(lambda x: 0.0 if x == "normal" else 1.0) 78 | self.cards = self.cards.sort_values(by="idx") 79 | 80 | def get_card_data_for_ML(self, return_df=True): 81 | ml_data = self.get_card_stats() 82 | colors = list("WUBRG") 83 | cards = self.cards.set_index("name").copy() 84 | # Power/Toughness sometimes has "*" instead of numbers, so need to 85 | # convert variable P/Ts to unique integers so that it can feed to the model 86 | cards = cards.replace(to_replace="1+*", value=-1) 87 | cards = cards.replace(to_replace="*", value=-1) 88 | copy_from_scryfall = ["power", "toughness", "basic_land_search", "flip", "cmc"] 89 | for column in copy_from_scryfall: 90 | ml_data[column] = cards[column].astype(float) 91 | keywords = list(set(cards["keywords"].sum())) 92 | keyword_df = pd.DataFrame(index=cards.index, columns=keywords).fillna(0) 93 | for card_idx, keys in cards["keywords"].to_dict().items(): 94 | keyword_df.loc[card_idx, keys] = 1.0 95 | ml_data = pd.concat([ml_data, keyword_df], axis=1) 96 | for color in colors: 97 | ml_data[color + " pips"] = cards["mana_cost"].apply(lambda x: x.count(color)) 98 | ml_data["produces " + color] = cards["produced_mana"].apply( 99 | lambda x: 0.0 if not isinstance(x, list) else int(color in x) 100 | ) 101 | for cardtype in self.types: 102 | cardtype = cardtype.lower() 103 | ml_data[cardtype] = ( 104 | cards["type_line"].str.lower().apply(lambda x: 0.0 if not isinstance(x, str) else int(cardtype in x)) 105 | ) 106 | rarities = cards["rarity"].unique() 107 | for rarity in rarities: 108 | ml_data[rarity] = cards["rarity"].apply(lambda x: int(x == rarity)) 109 | ml_data["produces C"] = cards["produced_mana"].apply(lambda x: 0 if not isinstance(x, list) else int("C" in x)) 110 | ml_data.columns = [x.lower() for x in ml_data.columns] 111 | count_cols = [x for x in ml_data.columns if "_count" in x] 112 | # 0-1 normalize data representing counts 113 | ml_data[count_cols] = ml_data[count_cols].apply(lambda x: x / x.max(), axis=0) 114 | ml_data["idx"] = cards["idx"] 115 | # the way our embeddings work is we always have an embedding that represents the lack of a card. This helps the model 116 | # represent stuff like generic format information. 
Hence we make this a one-hot vector that gets used in Draft when 117 | # the pack is empty, but have that concept "on" for every single card so it can affect the learned representations 118 | ml_data.loc["bias", :] = 0.0 119 | ml_data.loc["bias", "idx"] = cards["idx"].max() + 1 120 | ml_data["bias"] = 1.0 121 | ml_data = ml_data.fillna(0).sort_values("idx").reset_index(drop=True) 122 | ml_data = ml_data.drop("idx", axis=1) 123 | if return_df: 124 | return ml_data 125 | return ml_data.values 126 | 127 | def get_card_stats(self): 128 | all_colors = [ 129 | None, 130 | "W", 131 | "U", 132 | "B", 133 | "R", 134 | "G", 135 | "WU", 136 | "WB", 137 | "WR", 138 | "WG", 139 | "UB", 140 | "UR", 141 | "UG", 142 | "BR", 143 | "BG", 144 | "RG", 145 | "WUB", 146 | "WUR", 147 | "WUG", 148 | "WBR", 149 | "WBG", 150 | "WRG", 151 | "UBR", 152 | "UBG", 153 | "URG", 154 | "BRG", 155 | "WUBR", 156 | "WUBG", 157 | "WURG", 158 | "WBRG", 159 | "UBRG", 160 | "WUBRG", 161 | ] 162 | card_df = pd.DataFrame() 163 | for colors in all_colors: 164 | time.sleep(1) 165 | card_data_df = get_card_rating_data(self.expansion, colors=colors) 166 | extension = "" if colors is None else "_" + colors 167 | card_data_df.columns = [col + extension for col in card_data_df.columns] 168 | card_df = pd.concat([card_df, card_data_df], axis=1).fillna(0.0) 169 | return card_df 170 | 171 | def get_bo1_decks(self): 172 | d = {column: "last" for column in self.bo1.columns if column not in ["opp_colors"]} 173 | d.update( 174 | { 175 | "won": "sum", 176 | "on_play": "mean", 177 | "num_mulligans": "mean", 178 | "opp_num_mulligans": "mean", 179 | "num_turns": "mean", 180 | } 181 | ) 182 | decks = self.bo1.groupby("draft_id").agg(d) 183 | deck_cols = [x for x in decks.columns if x.startswith("deck_")] 184 | decks = decks[decks[deck_cols].sum(1) == 40] 185 | return decks 186 | 187 | def create_data_dependent_attributes(self): 188 | if self.draft is not None: 189 | self.t = self.draft["position"].max() + 1 190 | 191 | def get_mapping(self, key, value, include_basics=False): 192 | assert key != value, "key and value must be different" 193 | mapping = self.cards.set_index(key)[value].to_dict() 194 | if not include_basics: 195 | if key == "idx": 196 | mapping = {k - 5: v for k, v in mapping.items() if k >= 5} 197 | elif value == "idx": 198 | mapping = {k: v - 5 for k, v in mapping.items() if v >= 5} 199 | return mapping 200 | 201 | def generate_pack(self, exclude_basics=True, name_to_idx=None, return_names=False): 202 | """ 203 | generate random pack of MTG cards 204 | """ 205 | cards = self.cards.copy() 206 | if exclude_basics: 207 | cards = cards[cards["idx"] >= 5].copy() 208 | cards["idx"] = cards["idx"] - 5 209 | if name_to_idx is None: 210 | name_to_idx = cards.set_index("name")["idx"].to_dict() 211 | if np.random.random() < 1 / 8: 212 | rare = random.sample( 213 | cards[(cards["rarity"] == "mythic")]["name"].tolist(), 214 | 1, 215 | ) 216 | else: 217 | rare = random.sample( 218 | cards[(cards["rarity"] == "rare")]["name"].tolist(), 219 | 1, 220 | ) 221 | uncommons = random.sample( 222 | cards[(cards["rarity"] == "uncommon")]["name"].tolist(), 223 | 3, 224 | ) 225 | commons = [] 226 | # make sure at least one common of each color 227 | for color in list("WUBRG"): 228 | color_common = random.sample( 229 | cards[ 230 | (cards["rarity"] == "common") 231 | & (cards["mana_cost"].str.contains(color)) 232 | & (~cards["name"].isin(commons)) 233 | ]["name"].tolist(), 234 | 1, 235 | ) 236 | commons += color_common 237 | other_commons = random.sample( 238 | 
cards[((cards["rarity"] == "common")) & (~cards["name"].isin(commons))]["name"].tolist(), 239 | 5, 240 | ) 241 | commons += other_commons 242 | names = rare + uncommons + commons 243 | if return_names: 244 | return names 245 | idxs = [name_to_idx[name] for name in names] 246 | pack = np.zeros(len(cards)) 247 | pack[idxs] = 1 248 | return pack 249 | 250 | 251 | class VOW(Expansion): 252 | def __init__( 253 | self, 254 | bo1=None, 255 | bo3=None, 256 | quick=None, 257 | draft=None, 258 | replay=None, 259 | ml_data=True, 260 | idx_to_name=None, 261 | ): 262 | super().__init__( 263 | expansion="vow", 264 | bo1=bo1, 265 | bo3=bo3, 266 | quick=quick, 267 | draft=draft, 268 | replay=replay, 269 | ml_data=ml_data, 270 | idx_to_name=idx_to_name, 271 | ) 272 | 273 | def generate_pack(self, exclude_basics=True, name_to_idx=None, return_names=False): 274 | """ 275 | special handling for flip cards 276 | """ 277 | cards = self.cards.copy() 278 | if exclude_basics: 279 | cards = cards[cards["idx"] >= 5].copy() 280 | cards["idx"] = cards["idx"] - 5 281 | if name_to_idx is None: 282 | name_to_idx = cards.set_index("name")["idx"].to_dict() 283 | uncommon_or_rare_flip = random.sample( 284 | cards[(cards["rarity"].isin(["mythic", "rare", "uncommon"])) & (cards["flip"] == 1)]["name"].tolist(), 285 | 1, 286 | )[0] 287 | common_flip = random.sample( 288 | cards[(cards["rarity"] == "common") & (cards["flip"] == 1)]["name"].tolist(), 289 | 1, 290 | )[0] 291 | upper_rarity = cards[cards["name"] == uncommon_or_rare_flip]["rarity"].values[0] 292 | if upper_rarity == "uncommon": 293 | if np.random.random() < 1 / 8: 294 | rare = random.sample( 295 | cards[(cards["rarity"] == "mythic") & (cards["flip"] == 0)]["name"].tolist(), 296 | 1, 297 | ) 298 | else: 299 | rare = random.sample( 300 | cards[(cards["rarity"] == "rare") & (cards["flip"] == 0)]["name"].tolist(), 301 | 1, 302 | ) 303 | uncommons = random.sample( 304 | cards[(cards["rarity"] == "uncommon") & (cards["flip"] == 0)]["name"].tolist(), 305 | 2, 306 | ) + [uncommon_or_rare_flip] 307 | else: 308 | uncommons = random.sample( 309 | cards[(cards["rarity"] == "uncommon") & (cards["flip"] == 0)]["name"].tolist(), 310 | 3, 311 | ) 312 | rare = [uncommon_or_rare_flip] 313 | commons = [common_flip] 314 | # make sure at least one common of each color 315 | for color in list("WUBRG"): 316 | color_common = random.sample( 317 | cards[ 318 | (cards["rarity"] == "common") 319 | & (cards["flip"] == 0) 320 | & (cards["mana_cost"].str.contains(color)) 321 | & (~cards["name"].isin(commons)) 322 | ]["name"].tolist(), 323 | 1, 324 | ) 325 | commons += color_common 326 | other_commons = random.sample( 327 | cards[((cards["rarity"] == "common")) & (cards["flip"] == 0) & (~cards["name"].isin(commons))][ 328 | "name" 329 | ].tolist(), 330 | 4, 331 | ) 332 | commons += other_commons 333 | names = rare + uncommons + commons 334 | if return_names: 335 | return names 336 | idxs = [name_to_idx[name] for name in names] 337 | pack = np.zeros(len(cards)) 338 | pack[idxs] = 1 339 | return pack 340 | 341 | @property 342 | def types(self): 343 | types = super().types 344 | return types + ["human", "zombie", "wolf", "werewolf", "spirit", "aura"] 345 | 346 | 347 | class SNC(Expansion): 348 | def __init__( 349 | self, 350 | bo1=None, 351 | bo3=None, 352 | quick=None, 353 | draft=None, 354 | replay=None, 355 | ml_data=True, 356 | idx_to_name=None, 357 | ): 358 | super().__init__( 359 | expansion="snc", 360 | bo1=bo1, 361 | bo3=bo3, 362 | quick=quick, 363 | draft=draft, 364 | replay=replay, 365 | 
ml_data=ml_data, 366 | idx_to_name=idx_to_name, 367 | ) 368 | 369 | @property 370 | def types(self): 371 | types = super().types 372 | return types + ["citizen"] 373 | 374 | 375 | class DMU(Expansion): 376 | def __init__( 377 | self, 378 | bo1=None, 379 | bo3=None, 380 | quick=None, 381 | draft=None, 382 | replay=None, 383 | ml_data=True, 384 | idx_to_name=None, 385 | ): 386 | super().__init__( 387 | expansion="dmu", 388 | bo1=bo1, 389 | bo3=bo3, 390 | quick=quick, 391 | draft=draft, 392 | replay=replay, 393 | ml_data=ml_data, 394 | idx_to_name=idx_to_name, 395 | ) 396 | 397 | @property 398 | def types(self): 399 | types = super().types 400 | return types + ["citizen"] 401 | 402 | 403 | class BRO(Expansion): 404 | def __init__( 405 | self, 406 | bo1=None, 407 | bo3=None, 408 | quick=None, 409 | draft=None, 410 | replay=None, 411 | ml_data=True, 412 | idx_to_name=None, 413 | ): 414 | super().__init__( 415 | expansion="bro", 416 | bo1=bo1, 417 | bo3=bo3, 418 | quick=quick, 419 | draft=draft, 420 | replay=replay, 421 | ml_data=ml_data, 422 | idx_to_name=idx_to_name, 423 | ) 424 | 425 | def get_cards_from_scryfall(self): 426 | bro = CardSet([f"set=bro", 427 | "is:booster", 428 | "-name:\"urza, planeswalker\"", 429 | "-name:\"titania, gaea incarnate\"", 430 | "-name:\"mishra, lost to phyrexia\""]) 431 | brr = CardSet([f"set=brr"]) 432 | all_cards = bro.union(brr) 433 | # overwrite cards in bro to be the Union 434 | # NOTE: allow creating a cardset from a set of cards in the future 435 | bro.cards = all_cards 436 | return bro.to_dataframe() 437 | 438 | @property 439 | def types(self): 440 | types = super().types 441 | return types 442 | 443 | 444 | EXPANSIONS = [VOW, SNC, DMU, BRO] 445 | 446 | 447 | def get_expansion_obj_from_name(expansion): 448 | for exp in EXPANSIONS: 449 | if exp.__name__.lower() == expansion.lower(): 450 | return exp 451 | raise ValueError(f"{expansion} does not have a corresponding Expansion object.") 452 | -------------------------------------------------------------------------------- /mtg/ml/display.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import requests 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import warnings 6 | from mpl_toolkits.axes_grid1 import make_axes_locatable 7 | import os 8 | import pathlib 9 | from mtg.obj.dataloading_utils import get_draft_json 10 | 11 | 12 | def display_deck(pool, basics, spells, cards, return_url=False): 13 | """ 14 | given deckbuilder model output, return either the text of the build or a link 15 | to sealeddeck.tech 16 | """ 17 | pool = np.squeeze(pool) 18 | basics = np.squeeze(basics) 19 | spells = np.squeeze(spells) 20 | deck = np.concatenate([basics, spells]) 21 | idx_to_name = cards.set_index("idx")["name"].to_dict() 22 | sb_text = "SIDEBOARD\n\n" 23 | deck_text = "DECK\n\n" 24 | deck_json = {"sideboard": [], "deck": []} 25 | for idx, count in enumerate(deck): 26 | name = idx_to_name[idx] 27 | if idx >= 5: 28 | sb_count = pool[idx - 5] - count 29 | else: 30 | sb_count = 0 31 | if sb_count > 0: 32 | sb_text += str(int(sb_count)) + " " + name + "\n" 33 | deck_json["sideboard"].append({"name": name, "count": int(sb_count)}) 34 | if count == 0: 35 | continue 36 | deck_text += str(int(count)) + " " + name + "\n" 37 | deck_json["deck"].append({"name": name, "count": int(count)}) 38 | if return_url: 39 | r = requests.post(url="https://www.sealeddeck.tech/api/pools", json=deck_json) 40 | r_js = r.json() 41 | output = r_js["url"] 42 | else: 43 
| output = deck_text + "\n" + sb_text 44 | return output 45 | 46 | 47 | def draft_sim( 48 | expansion, 49 | model, 50 | token="", 51 | build_model=None, 52 | basic_prior=True, 53 | ): 54 | """ 55 | run a draft table with 8 copies of bots 56 | """ 57 | t = expansion.t 58 | idx_to_name = expansion.get_mapping("idx", "name", include_basics=False) 59 | name_to_idx = expansion.get_mapping("name", "idx", include_basics=False) 60 | arena_mapping = expansion.get_mapping("name", "arena_id", include_basics=False) 61 | cards = expansion.cards.copy() 62 | seats = 8 63 | n_packs = 3 64 | n_cards = len(idx_to_name) 65 | n_picks = t // n_packs 66 | 67 | js = { 68 | idx: { 69 | "expansion": expansion.expansion.upper(), 70 | "token": f"{token}", 71 | "picks": [], 72 | } 73 | for idx in range(seats) 74 | } 75 | 76 | idx_to_js = {i: arena_mapping[idx_to_name[i]] for i in range(n_cards)} 77 | 78 | # index circular shuffle per iteration 79 | pack_shuffle_right = [7, 0, 1, 2, 3, 4, 5, 6] 80 | pack_shuffle_left = [1, 2, 3, 4, 5, 6, 7, 0] 81 | # initialize 82 | pick_data = np.ones((seats, t), dtype=np.int32) * n_cards 83 | pack_data = np.ones((seats, t, n_cards), dtype=np.float32) 84 | pool_data = np.ones((seats, t, n_cards), dtype=np.float32) 85 | final_pools = np.zeros((seats, n_cards), dtype=np.float32) 86 | positions = np.tile(np.arange(t, dtype=np.int32), [seats, 1]) 87 | cur_pos = 0 88 | for pack_number in range(n_packs): 89 | # generate packs for this round 90 | packs = [ 91 | expansion.generate_pack(name_to_idx=name_to_idx) for pack in range(seats) 92 | ] 93 | for pick_number in range(n_picks): 94 | pack_data[:, cur_pos, :] = np.vstack(packs) 95 | # draft_info = np.concatenate([pack_data, pool_data], axis=-1) 96 | for idx in range(seats): 97 | # model doesnt get serialized with 8 seats as an option so 98 | # we have to do it individually --- will ensure serialization 99 | # in the future 100 | data = (pack_data[[idx]], pick_data[[idx]], positions[[idx]]) 101 | # make pick 102 | predictions, _ = model(data, training=False, return_attention=True) 103 | bot_pick = tf.math.argmax(predictions[0, cur_pos]).numpy() 104 | final_pools[idx][bot_pick] += 1 105 | if cur_pos + 1 < t: 106 | pick_data[idx][cur_pos + 1] = bot_pick 107 | pool_data[idx][cur_pos + 1][bot_pick] += 1 108 | pick_js = { 109 | "pack_number": pack_number, 110 | "pick_number": pick_number, 111 | "pack_cards": [idx_to_js[x] for x in np.where(packs[idx] == 1)[0]], 112 | "pick": idx_to_js[bot_pick], 113 | } 114 | js[idx]["picks"].append(pick_js) 115 | # the bot picked the card, so remove it from the pack for the next person 116 | packs[idx][bot_pick] = 0 117 | # pass the packs (left, right, left) 118 | if pack_number % 2 == 1: 119 | packs = [packs[idx] for idx in pack_shuffle_right] 120 | else: 121 | packs = [packs[idx] for idx in pack_shuffle_left] 122 | cur_pos += 1 123 | draft_logs = [] 124 | for idx in range(seats): 125 | if build_model is not None: 126 | pool = np.expand_dims(final_pools[idx], 0) 127 | basics, spells, _ = build_decks( 128 | build_model, pool.copy(), cards=cards if basic_prior else None 129 | ) 130 | deck_url = display_deck(pool, basics, spells, cards, return_url=True) 131 | else: 132 | deck_url = None 133 | r = requests.post(url="https://www.17lands.com/api/submit_draft", json=js[idx]) 134 | r_js = r.json() 135 | draft_id = r_js["id"] 136 | output = f"https://www.17lands.com/submitted_draft/{draft_id}" 137 | if deck_url is not None: 138 | output = (output, deck_url) 139 | draft_logs.append(output) 140 | return draft_logs 141 | 
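# Illustrative usage sketch for draft_sim (not part of the original module): the
# expansion needs its draft data loaded so that expansion.t is defined, and the
# paths/model below are hypothetical placeholders for a trained DraftBot
# (see mtg/ml/models.py).
#
#   from mtg.obj.expansion import VOW
#
#   expansion = VOW(draft="path/to/17lands_draft_data.csv", ml_data=True)
#   draft_model = ...  # restore a trained DraftBot however it was saved
#   logs = draft_sim(expansion, draft_model, token="YOUR_17LANDS_TOKEN")
#   for log in logs:
#       # one 17lands submitted-draft URL per bot seat; a (draft_url, deck_url)
#       # tuple instead if a deckbuilding model is passed via build_model
#       print(log)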
142 | 143 | def draft_log_ai( 144 | draft_log_url, 145 | model, 146 | expansion, 147 | batch_size=1, 148 | token="", 149 | build_model=None, 150 | mod_lookup=dict(), 151 | basic_prior=True, 152 | att_folder=None, 153 | ): 154 | """ 155 | given a draft log, create a copy of that log that highlights what the bot would do 156 | 157 | att_folder: directory for storing attention visualizations 158 | basic_prior: heuristic update of manabase in deckbuilder 159 | mod_lookup: dictionary that lets you modify the data to prod the model and see if it 160 | changes decisions. Use it as such: 161 | 162 | { 163 | 'PxPy':{ 164 | 'pack':{ 165 | #change cardA to cardB in PxPy 166 | 'cardA':'cardB' 167 | }, 168 | #change the pick to cardC 169 | 'pick': 'cardC' 170 | } 171 | 'pool':{ 172 | # remove two copies of cardD from the pool and replace 173 | # it with a copy of cardE and a copy of cardD 174 | 'cardD':-2, 175 | 'cardE':1, 176 | 'cardF':1 177 | } 178 | } 179 | """ 180 | t = expansion.t 181 | idx_to_name = expansion.get_mapping("idx", "name", include_basics=False) 182 | name_to_idx = expansion.get_mapping("name", "idx", include_basics=False) 183 | arena_mapping = expansion.get_mapping("name", "arena_id", include_basics=False) 184 | cards = expansion.cards.copy() 185 | picks = get_draft_json(draft_log_url)["picks"] 186 | n_picks_per_pack = t / 3 187 | n_cards = len(name_to_idx) 188 | pool = np.zeros(n_cards, dtype=np.float32) 189 | draft_info = np.zeros((batch_size, t, n_cards * 2)) 190 | positions = np.tile( 191 | np.expand_dims(np.arange(t, dtype=np.int32), 0), batch_size 192 | ).reshape(batch_size, t) 193 | actual_pick = [] 194 | position_to_pxpy = dict() 195 | js = {"expansion": expansion.expansion.upper(), "token": f"{token}", "picks": []} 196 | for pick in picks: 197 | pxpy = "P" + str(pick["pack_number"] + 1) + "P" + str(pick["pick_number"] + 1) 198 | pack_mod = mod_lookup.get(pxpy, dict()).get("pack", dict()) 199 | pick_mod = mod_lookup.get(pxpy, dict()).get("pick", None) 200 | for i, option in enumerate(pick["available"]): 201 | cardname = option["name"].lower().split("//")[0].strip() 202 | if cardname in pack_mod: 203 | pick["available"][i]["name"] = pack_mod[cardname] 204 | 205 | position = int(pick["pack_number"] * n_picks_per_pack + pick["pick_number"]) 206 | if pick_mod is not None: 207 | correct_pick = pick_mod 208 | else: 209 | correct_pick = pick["pick"]["name"].lower().split("//")[0].strip() 210 | position_to_pxpy[position] = pxpy 211 | pick_idx = name_to_idx[correct_pick] 212 | names_pack = [ 213 | x["name"].lower().split("//")[0].strip() for x in pick["available"] 214 | ] 215 | idxs = [name_to_idx[name] for name in names_pack] 216 | arena_ids_in_pack = [arena_mapping[name] for name in names_pack] 217 | unique, counts = np.unique(idxs, return_counts=True) 218 | pack = np.zeros(len(name_to_idx)) 219 | pack[unique] += counts 220 | draft_info[0, position, :n_cards] = pack 221 | draft_info[0, position, n_cards:] = pool 222 | pool[pick_idx] += 1 223 | actual_pick.append(correct_pick) 224 | pick_js = { 225 | "pack_number": pick["pack_number"], 226 | "pick_number": pick["pick_number"], 227 | "pack_cards": arena_ids_in_pack, 228 | "pick": arena_mapping[correct_pick], 229 | } 230 | js["picks"].append(pick_js) 231 | pool_mod = mod_lookup.get("pool", dict()) 232 | for cardname, n_change in pool_mod.items(): 233 | card_idx = name_to_idx[cardname] 234 | pool[card_idx] += n_change 235 | # insert n_cards idx to shift the picks passed into the model to prevent seeing the correct pick 236 | np_pick = 
np.tile( 237 | np.expand_dims( 238 | np.asarray([n_cards] + [name_to_idx[name] for name in actual_pick[:-1]]), 0 239 | ), 240 | batch_size, 241 | ).reshape(batch_size, t) 242 | model_input = ( 243 | tf.convert_to_tensor(draft_info[:, :, :n_cards], dtype=tf.float32), 244 | tf.convert_to_tensor(np_pick, dtype=tf.int32), 245 | tf.convert_to_tensor(positions, dtype=tf.int32), 246 | ) 247 | # we get the first element in anything we return to handle the case where the model couldn't properly serialize 248 | # and we hence need to copy the data to be the same shape as the batch size in order to run a stored model 249 | output, attention = model(model_input, training=False, return_attention=True) 250 | output = output[0] 251 | 252 | if att_folder is not None: 253 | draft_id = draft_log_url.split("/")[-1] 254 | location = os.path.join(att_folder, draft_id) 255 | att = {"pack": attention[0], "pick": attention[1][0], "final": attention[1][1]} 256 | for att_name, att_vec in att.items(): 257 | # plot attention, shifted right if we're visualizing pick attention 258 | att_loc = os.path.join(location, att_name) 259 | # index because shape is (1, n_heads, seq, seq) 260 | save_att_to_dir(att_vec[0], att_loc, shift=att_name == "pick") 261 | 262 | predictions = tf.math.top_k(output, k=3).indices.numpy() 263 | predicted_picks = [idx_to_name[pred[0]] for pred in predictions] 264 | for i, js_obj in enumerate(js["picks"]): 265 | js_obj["suggested_pick"] = arena_mapping[predicted_picks[i]] 266 | r = requests.post(url="https://www.17lands.com/api/submit_draft", json=js) 267 | r_js = r.json() 268 | if build_model is not None: 269 | pool = np.expand_dims(pool, 0) 270 | basics, spells, _ = build_decks( 271 | build_model, pool.copy(), cards=cards if basic_prior else None 272 | ) 273 | deck_url = display_deck(pool, basics, spells, cards, return_url=True) 274 | else: 275 | deck_url = None 276 | try: 277 | draft_id = r_js["id"] 278 | output = f"https://www.17lands.com/submitted_draft/{draft_id}" 279 | if deck_url is not None: 280 | output = (output, deck_url) 281 | return output 282 | except Exception: 283 | warnings.warn("Draft Log Upload Failed. Returning sent JSON to help debug.") 284 | return (js, r) 285 | 286 | 287 | def save_att_to_dir(attention, location, shift=False): 288 | """ 289 | create and store images showing each attention head's activations for 290 | the different places in models using attention.
291 | 292 | This aligns the heads such that it's easier to recognize patterns related 293 | to which head learns to process what 294 | """ 295 | pathlib.Path(location).mkdir(parents=True, exist_ok=True) 296 | if shift: 297 | pxpy = ["BIAS"] 298 | else: 299 | pxpy = [] 300 | seq_l = attention.shape[-1] 301 | n_picks = (seq_l) / 3 302 | for i in range(seq_l): 303 | pack = i // n_picks + 1 304 | pick = (i % n_picks) + 1 305 | pxpy.append("P" + str(int(pack)) + "P" + str(int(pick))) 306 | if shift: 307 | # if we shift right, we exclude the last pick of pack 3 308 | pxpy = pxpy[:-1] 309 | for i, pick in enumerate(pxpy): 310 | img_loc = os.path.join(location, pick + ".png") 311 | attention_weights = attention[:, i, : i + 1] 312 | xlabels = pxpy[: i + 1] 313 | fig = plt.figure(figsize=(900 / 96, 600 / 96), dpi=96) 314 | plt.grid() 315 | ax = plt.gca() 316 | mat = ax.matshow(attention_weights) 317 | ax.set_xticks(range(attention_weights.shape[-1])) 318 | ax.set_yticks(range(attention_weights.shape[0])) 319 | divider = make_axes_locatable(ax) 320 | cax = divider.append_axes("right", size="5%", pad=0.05) 321 | plt.colorbar(mat, cax=cax) 322 | ax.set_xticklabels(xlabels, rotation=90) 323 | plt.tight_layout() 324 | plt.savefig(img_loc) 325 | plt.clf() 326 | 327 | 328 | def build_decks(model, pool, cards=None): 329 | """ 330 | iteratively call the model to build the deck from a card pool 331 | """ 332 | pool = np.expand_dims(pool, 0) 333 | deck_out = np.zeros_like(pool) 334 | masked_flag = len(deck_out.shape) == 3 335 | spells_added = 0 336 | while True: 337 | basics, spells, n_non_basics = model((pool, deck_out), training=False) 338 | if np.round(n_non_basics) <= spells_added: 339 | break 340 | spells = spells.numpy() 341 | basics = basics.numpy() 342 | n_non_basics = n_non_basics.numpy()[0][0] 343 | card_to_add = np.squeeze(np.argmax(spells, axis=-1)) 344 | if not masked_flag: 345 | idx = np.arange(deck_out.shape[0]), card_to_add 346 | else: 347 | idx = ( 348 | np.arange(deck_out.shape[0]), 349 | np.arange(deck_out.shape[0]), 350 | card_to_add, 351 | ) 352 | deck_out[idx] += 1 353 | pool[idx] -= 1 354 | spells_added += 1 355 | # overwrite basics prediction using the actual discrete deck 356 | # not continuous representation 357 | basics = model.basic_decoder(deck_out) * (40 - spells_added) 358 | basics = basics.numpy() 359 | basics_out = np.zeros((*deck_out.shape[: len(deck_out.shape) - 1], 5)) 360 | for _ in range(40 - spells_added): 361 | card_to_add = np.squeeze(np.argmax(basics, axis=-1)) 362 | if not masked_flag: 363 | idx = np.arange(deck_out.shape[0]), card_to_add 364 | else: 365 | idx = ( 366 | np.arange(deck_out.shape[0]), 367 | np.arange(deck_out.shape[0]), 368 | card_to_add, 369 | ) 370 | basics_out[idx] += 1 371 | basics[idx] -= 1 372 | deck_out = np.concatenate([basics_out, deck_out], axis=-1) 373 | if cards is not None: 374 | deck_out = recalibrate_basics(np.squeeze(deck_out), cards) 375 | deck_out = deck_out[None, :] 376 | else: 377 | deck_out = deck_out[0] 378 | return deck_out[:, :5], deck_out[:, 5:], 40 - spells_added 379 | 380 | 381 | def recalibrate_basics(built_deck, cards, verbose=False): 382 | """ 383 | heuristic modification of basics in deckbuild to avoid OOD yielding 384 | weird manabases (e.g. 
basic that cant cast anything) 385 | 386 | --> eventually this will not be necessary, once deckbuilder improves 387 | """ 388 | color_to_idx = ( 389 | cards[cards["idx"] < 5] 390 | .set_index("idx")["produced_mana"] 391 | .apply(lambda x: x[0]) 392 | .reset_index() 393 | .set_index("produced_mana") 394 | .to_dict()["idx"] 395 | ) 396 | 397 | pip_count = {c: 0 for c in list("WUBRG")} 398 | # don't count a green mana dork that produces G as a G source, but if it produces other colors, it can count as a source 399 | basic_adds_extra_sources = {c: 0 for c in list("WUBRG")} 400 | splash_produces_count = {c: 0 for c in list("WUBRG")} 401 | for card_idx, count in enumerate(built_deck): 402 | if count == 0: 403 | continue 404 | card = cards[cards["idx"] == card_idx] 405 | basic_special_case_flag = (card["basic_land_search"]).iloc[0] 406 | mc = card["mana_cost"].iloc[0] 407 | splash_produce = ( 408 | list( 409 | set(card["produced_mana"].iloc[0]) - {"C"} - set(card["colors"].iloc[0]) 410 | ) 411 | if not card["produced_mana"].isna().iloc[0] 412 | else [] 413 | ) 414 | for color in pip_count.keys(): 415 | pip_count[color] += count * mc.count(color) 416 | if basic_special_case_flag: 417 | basic_count = built_deck[color_to_idx[color]] 418 | if basic_count == 0: 419 | basic_adds_extra_sources[color] += count 420 | else: 421 | splash_produces_count[color] += count 422 | elif color in splash_produce: 423 | splash_produces_count[color] += count 424 | min_produces_map = { 425 | 0: 0, 426 | 1: 3, 427 | 2: 4, 428 | 3: 4, 429 | 4: 5, 430 | } 431 | 432 | add_basics_dict = {c: 0 for c in list("WUBRG")} 433 | 434 | cut_basics_dict = {c: 0 for c in list("WUBRG")} 435 | 436 | basic_cut_limit = {c: 0 for c in list("WUBRG")} 437 | 438 | for color in list("WUBRG"): 439 | pips = pip_count[color] 440 | if pips == 0: 441 | # ensure we cut basics that dont do anything 442 | idx_for_basic = color_to_idx[color] 443 | basic_count_in_deck = built_deck[idx_for_basic] 444 | cut_basics_dict[color] += basic_count_in_deck 445 | if pips > 0 and basic_adds_extra_sources[color] > 0: 446 | min_add = 1 447 | else: 448 | min_add = 0 449 | mana_req = min_produces_map.get(pips, 6) 450 | produces = splash_produces_count[color] 451 | produces_diff = produces - mana_req 452 | if produces_diff < 0: 453 | add_basics_dict[color] += ( 454 | abs(produces_diff) - basic_adds_extra_sources[color] 455 | ) 456 | else: 457 | basic_cut_limit[color] = max(produces_diff, 0) 458 | if add_basics_dict[color] < min_add: 459 | add_basics_dict[color] = min_add 460 | 461 | # now ad_basics_dict is the number of basics per color that needs to be added 462 | # the following logic determines what basics need to be cut 463 | # get number of basics in the deck, but if that basic is required to be added, don't allow it to be cut 464 | basics_that_can_be_cut = { 465 | c: min(built_deck[color_to_idx[c]], basic_cut_limit[c]) if n == 0 else 0 466 | for c, n in add_basics_dict.items() 467 | } 468 | # this is used for making swaps when adding basics. 
If we are forcing some basics to be cut, don't let them be added 469 | basics_that_can_be_cut = { 470 | c: np.clip(v - cut_basics_dict[c], 0, np.inf) 471 | for c, v in basics_that_can_be_cut.items() 472 | } 473 | total_basics_to_cut = sum([x for x in add_basics_dict.values()]) 474 | if total_basics_to_cut > sum([x for x in basics_that_can_be_cut.values()]): 475 | if verbose: 476 | print("This manabase is not salvageable") 477 | cur_color_idx = 0 478 | colors_to_add = [c for c, n in add_basics_dict.items() if n > 0] 479 | check_bug = 0 480 | # if we are cutting lands, make sure to cut basics corresponding to 481 | # lower pips in the deck. This could have problems if there's already too 482 | # high an allocation to that, but empirically it seems more often balanced 483 | colors = sorted(list("WUBRG"), key=lambda color: pip_count[color]) 484 | added_already = [] 485 | while ( 486 | sum([x for x in add_basics_dict.values()]) > 0 487 | or sum([x for x in cut_basics_dict.values()]) > 0 488 | ): 489 | if len(colors_to_add) == 0: 490 | if sum([x for x in add_basics_dict.values()]) > 0: 491 | colors_to_add = [c for c, n in add_basics_dict.items() if n > 0] 492 | else: 493 | if sum([x for x in cut_basics_dict.values()]) <= 0: 494 | # nothing to add or cut! 495 | break 496 | else: 497 | colors_to_add = [ 498 | c 499 | for c, n in basics_that_can_be_cut.items() 500 | if n > 0 and c not in added_already 501 | ] 502 | if len(colors_to_add) == 0: 503 | colors_to_add = [ 504 | c for c, n in basics_that_can_be_cut.items() if n > 0 505 | ] 506 | if len(colors_to_add) == 0: 507 | if verbose: 508 | print("Nothing else is allowed to be cut, bad manabase") 509 | break 510 | c = colors[cur_color_idx % 5] 511 | # this is the actual idx in the deck built, not the fake one used to cycle through colors 512 | idx = color_to_idx[c] 513 | ad_c = colors_to_add[0] 514 | ad_idx = color_to_idx[ad_c] 515 | if sum([x for x in cut_basics_dict.values()]) > 0: 516 | if cut_basics_dict[c] > 0: 517 | built_deck[idx] -= 1 518 | built_deck[ad_idx] += 1 519 | basics_that_can_be_cut[c] -= 1 520 | cut_basics_dict[c] -= 1 521 | add_basics_dict[ad_c] -= 1 522 | else: 523 | if basics_that_can_be_cut[c] > 0: 524 | built_deck[idx] -= 1 525 | built_deck[ad_idx] += 1 526 | basics_that_can_be_cut[c] -= 1 527 | cut_basics_dict[c] -= 1 528 | add_basics_dict[ad_c] -= 1 529 | 530 | cur_color_idx += 1 531 | check_bug += 1 532 | if check_bug > 100: 533 | print("BUG") 534 | break 535 | return built_deck 536 | -------------------------------------------------------------------------------- /mtg/ml/models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import pickle 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from mtg.ml import nn 8 | from mtg.ml.layers import Embedding 9 | from mtg.ml.utils import CustomSchedule 10 | 11 | 12 | class DraftBot(tf.Module): 13 | """ 14 | Custom Tensorflow Model for Magic: the Gathering Draft AI 15 | 16 | This algorithm is a transformer that functions on draft data modified to 17 | work in a sequence-to-sequence manner. 
Given a sequence of packs and 18 | picks, as well as a contextual pack that determines your options, the 19 | goal is to select the card from the context that is best (best determined 20 | via "what a human did", with weighting towards more experienced humans) 21 | 22 | -------------------------------------------------------------------------- 23 | 24 | expansion: Expansion object instance from mtg/obj/expansion.py 25 | emb_dim: The embedding dimension to use for card embeddings 26 | num_encoder_heads: Number of heads in each encoder transformer block 27 | num_decoder_heads: Number of heads in each decoder transformer block 28 | num_encoder_layers: Number of transformer blocks in the encoder 29 | num_decoder_layers: Number of transformer blocks in the decoder 30 | pointwise_ffn_width: The width of the pointwise feedforward NN projection 31 | in each transformer block 32 | emb_dropout: Dropout rate to be applied to card embeddings 33 | memory_dropout: Dropout rate to be applied to the transformer blocks 34 | out_dropout: Dropout rate to be applied to the hidden layers in the 35 | MLP that converts the output from the transformer 36 | decoder to the prediction of what card to take 37 | 38 | Lastly, I would like to address that I know offering the ability to have different 39 | numbers of heads and layers in the encoder and decoder blocks is not commonplace. 40 | Generally, I use the same numbers for these, but in my experiments looking at 41 | the activations of the attention vectors, it seems like the representation of packs 42 | needs less processing (fewer heads) than the representation of picks. This makes sense 43 | from my domain expertise because an overwhelming majority of the reason to make a 44 | pick is from pool context, not pack context. Hence, I offer the ability to modify 45 | these numbers to experiment with how attention is different across those concepts. 46 | """ 47 | 48 | def __init__( 49 | self, 50 | expansion, 51 | emb_dim, 52 | num_encoder_heads, 53 | num_decoder_heads, 54 | num_encoder_layers, 55 | num_decoder_layers, 56 | pointwise_ffn_width=None, 57 | emb_dropout=0.0, 58 | memory_dropout=0.0, 59 | out_dropout=0.0, 60 | name=None, 61 | ): 62 | super().__init__(name=name) 63 | if pointwise_ffn_width is None: 64 | pointwise_ffn_width = emb_dim * 4 65 | # get some information from the expansion object for storage later. This is 66 | # because we don't want to store the expansion object (it's big), and in 67 | # case we lose it, we need to be able to initialize a new one with the 68 | # same exact card to id mappings for proper inference. 69 | self.idx_to_name = expansion.get_mapping("idx", "name", include_basics=False) 70 | self.n_cards = len(self.idx_to_name) 71 | # self.t is the number of picks in a draft 72 | self.t = expansion.t 73 | # the first five elements will be card data on basics, which is irrelevant 74 | # for drafting, so we get rid of them 75 | self.card_data = expansion.card_data_for_ML[5:] 76 | self.emb_dim = tf.Variable(emb_dim, dtype=tf.float32, trainable=False, name="emb_dim") 77 | self.dropout = emb_dropout 78 | # positional embedding allows deviation given temporal context 79 | self.positional_embedding = Embedding(self.t, emb_dim, name="positional_embedding") 80 | # lookahead mask to prevent the algorithm from seeing information it isn't 81 | # allowed to (e.g. 
at P1P5 you cannot look at P1P6-P3P14) 82 | self.positional_mask = 1 - tf.linalg.band_part(tf.ones((self.t, self.t)), -1, 0) 83 | # transformer encoder block for processing pack information 84 | self.encoder_layers = [ 85 | nn.TransformerBlock( 86 | emb_dim, 87 | num_encoder_heads, 88 | pointwise_ffn_width, 89 | dropout=memory_dropout, 90 | name=f"memory_encoder_{i}", 91 | ) 92 | for i in range(num_encoder_layers) 93 | ] 94 | # extra embedding as representation of bias before the draft starts. This is grabbed as the 95 | # representation for the "previous pick" that goes into the decoder for P1P1 96 | # additionally, we use a "concatEmbedding", which means we do the following: 97 | # 1. project a one_hot_vector to an embedding of dimension emb_dim//2 98 | # 2. use an MLP on the data about each card (self.card_data) to yield an 99 | # emb_dim//2 dimension embedding 100 | # 3. The embedding we use for cards is the concatenation of 1. and 2. 101 | self.card_embedding = nn.ConcatEmbedding( 102 | self.n_cards + 1, 103 | emb_dim, 104 | tf.convert_to_tensor(self.card_data, dtype=tf.float32), 105 | name="card_embedding", 106 | activation=None, 107 | ) 108 | # transformer decoder block for processing the pool with respect to the pack 109 | self.decoder_layers = [ 110 | nn.TransformerBlock( 111 | emb_dim, 112 | num_decoder_heads, 113 | pointwise_ffn_width, 114 | dropout=memory_dropout, 115 | name=f"memory_decoder_{i}", 116 | decode=True, 117 | ) 118 | for i in range(num_decoder_layers) 119 | ] 120 | # convert transformer decoder output to projection of what card to pick 121 | self.output_decoder = nn.MLP( 122 | in_dim=emb_dim, 123 | start_dim=emb_dim * 2, 124 | out_dim=self.n_cards, 125 | n_h_layers=1, 126 | dropout=out_dropout, 127 | name="output_decoder", 128 | start_act=tf.nn.relu, 129 | middle_act=tf.nn.relu, 130 | out_act=None, 131 | style="reverse_bottleneck", 132 | ) 133 | 134 | @tf.function 135 | def __call__( 136 | self, 137 | features, 138 | training=None, 139 | return_attention=False, 140 | ): 141 | packs, picks, positions = features 142 | # store last data batch in case specific batch of data causes an issue 143 | self.last_packs = packs 144 | self.last_picks = picks 145 | self.last_positions = positions 146 | # get the positional mask, which is a lookahead mask for autoregressive predictions. 147 | # effectively, to make a decision a P1P5, we make sure the model can never see P1P6 148 | # or later 149 | positional_masks = tf.gather(self.positional_mask, positions) 150 | # to make sure the model can differentiate context of a pool and pack at different time 151 | # steps, we have positional embeddings 152 | # (e.g. representation of card A at P1P1 is different than P1P8) 153 | self.positional_embeddings = self.positional_embedding(positions, training=training) 154 | self.all_card_embeddings = self.card_embedding(tf.range(self.n_cards), training=training) 155 | # TODO: represent packs as 15 indices for each card in the pack rather than a 156 | # binary vector. 
It's more computationally efficient and doesn't require 157 | # the step below 158 | self.pack_card_embeddings = packs[:, :, :, None] * self.all_card_embeddings[None, None, :, :] 159 | # get the number of cards in each pack 160 | self.n_options = tf.reduce_sum(packs, axis=-1, keepdims=True) 161 | # the pack_embedding is the average of the embeddings of the cards in the pack 162 | self.pack_embeddings = tf.reduce_sum(self.pack_card_embeddings, axis=2) / self.n_options 163 | # add the positional information to the pack embeddings 164 | self.embs = self.pack_embeddings * tf.math.sqrt(self.emb_dim) + self.positional_embeddings 165 | 166 | if training and self.dropout > 0.0: 167 | self.embs = tf.nn.dropout(self.embs, rate=self.dropout) 168 | 169 | # we run the transformer encoder on the pack information. This is where the 170 | # bot learns how to predict the wheel. Search for improvements on how 171 | # this informs color distribution/expectation and pivots 172 | self.encoder_holder = [] 173 | for memory_layer in self.encoder_layers: 174 | self.embs, attention_weights_pack = memory_layer( 175 | self.embs, positional_masks, training=training 176 | ) # (batch_size, t, emb_dim) 177 | self.encoder_holder.append((self.embs, attention_weights_pack)) 178 | 179 | # we run the transformer decoder on the pick information. So, at the P1P5 decision, 180 | # the transformer gets passed what the human took at P1P4. Attention with a 181 | # lookahead mask lets the pick information represent the whole pool, because 182 | # the algorithm attends to prior picks, so at P1P5 the decoder looks at 183 | # P1P1-P1P4, which is the pool. 184 | # 185 | # NOTE: at P1P1, we represent the pick (since there's no prior info) with a 186 | # vector representation that is meant to describe the bias at the beginning 187 | # of the draft. 188 | # TODO: explore adding positional information to the picks here. Should it be the 189 | # same positional embedding, or a different one? 190 | self.dec_embs = self.card_embedding(picks, training=training) 191 | if training and self.dropout > 0.0: 192 | self.dec_embs = tf.nn.dropout(self.dec_embs, rate=self.dropout) 193 | 194 | self.decoder_holder = [] 195 | for memory_layer in self.decoder_layers: 196 | self.dec_embs, attention_weights_pick = memory_layer( 197 | self.dec_embs, 198 | positional_masks, 199 | encoder_output=self.embs, 200 | training=training, 201 | ) # (batch_size, t, emb_dim) 202 | self.decoder_holder.append((self.dec_embs, attention_weights_pick)) 203 | # in order to remove all cards in the set not in the pack as options, we create a 204 | # mask that will guarantee the values will be zero when applying softmax 205 | self.mask_for_softmax = 1e9 * (1 - packs) 206 | self.card_rankings = ( 207 | self.output_decoder(self.dec_embs, training=training) * packs - self.mask_for_softmax 208 | ) # (batch_size, t, n_cards) 209 | # compute the Euclidean distance between each card embedding from the pack and 210 | # the output of the transformer decoder. This is used to regularize the network 211 | # by saying "the embedding for the correct pick should be close to the output 212 | # from the transformer, and far from the other cards in the pack". Conceptually 213 | # taken from this paper: https://ieee-cog.org/2021/assets/papers/paper_75.pdf. 
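# As a shape sketch of the distance computation below (assuming packs is (batch_size, t, n_cards)):
# pack_card_embeddings is (batch_size, t, n_cards, emb_dim), dec_embs[:, :, None, :] broadcasts
# against it, and the reduce_sum over the last axis leaves emb_dists with shape (batch_size, t, n_cards).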
214 | # NOTE: I tested the direct implementation of this paper where, rather than using 215 | # `self.output_decoder` to determine the card rankings, you just directly pick 216 | # the card with the closest distance to the output of the context (transformer 217 | # decoder). This consistently lagged behind using the decoder on validation 218 | # performance. Still a lot to experiment with in the embedding space. 219 | self.emb_dists = ( 220 | tf.sqrt( 221 | tf.reduce_sum( 222 | tf.square(self.pack_card_embeddings - self.dec_embs[:, :, None, :]), 223 | -1, 224 | ) 225 | ) 226 | * packs 227 | + self.mask_for_softmax 228 | ) 229 | self.output = tf.nn.softmax(self.card_rankings) 230 | 231 | if return_attention: 232 | return self.output, (attention_weights_pack, attention_weights_pick) 233 | return self.output, self.emb_dists 234 | 235 | def compile( 236 | self, 237 | optimizer=None, 238 | learning_rate=0.001, 239 | margin=0.1, 240 | emb_lambda=1.0, 241 | pred_lambda=1.0, 242 | bad_behavior_lambda=1.0, 243 | rare_lambda=10.0, 244 | cmc_lambda=1.0, 245 | cmc_margin=1.0, 246 | metric_names=["top1", "top2", "top3"], 247 | ): 248 | """ 249 | After initializing the model, we want to compile it by setting parameters for training 250 | 251 | optimizer: optimizer to use for minimizing the objective function (default is Adam) 252 | learning_rate: learning rate for the optimizer. If passing {'lr_warmup':N}, it will use 253 | Adam with a scheduler that warms up the LR. This is recommended. 254 | margin: the minimal distance margin for triplet loss on the card embeddings 255 | emb_lambda: the regularization coefficient for triplet loss on the card embeddings 256 | pred_lambda: the coefficient for the main prediction task of the loss function 257 | bad_behavior_lambda: the regularization coefficient to be applied to all penalties that 258 | are structured as expert priors to avoid learning unwanted behavior 259 | such as rare drafting 260 | rare_lambda: the regularization coefficient for the penalty on taking rares when the 261 | bot shouldn't 262 | cmc_lambda: the regularization coefficient for the penalty that asks the model to bias 263 | towards cheaper cards. 264 | cmc_margin: the margin added to the cmc difference before the penalty for taking expensive 265 | cards kicks in (the penalty is cmc_lambda * max(pred_cmc - true_cmc + cmc_margin, 0)). 266 | For example: with cmc_margin = 1.5, if the bot confidently wants a 5-drop but the human 267 | takes a 3-drop, the penalty is 3.5; it only reaches zero when the bot's expected pick is at least 1.5 cmc cheaper than the human's. 268 | metric_names: a list of which metrics to use to help debug. 
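Example (a sketch with illustrative hyperparameter values; assumes a DraftBot instance named `model`):
    model.compile(learning_rate={'lr_warmup': 1000}, margin=0.1, rare_lambda=10.0)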
269 | """ 270 | if optimizer is None: 271 | if isinstance(learning_rate, dict): 272 | learning_rate = CustomSchedule(self.emb_dim, **learning_rate) 273 | else: 274 | learning_rate = learning_rate 275 | 276 | self.optimizer = tf.keras.optimizers.Adam( 277 | learning_rate=learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9 278 | ) 279 | else: 280 | self.optimizer = optimizer 281 | # because our output is a softmax and the targets are integer card indices, SparseCategoricalCrossentropy is the proper loss function 282 | self.loss_f = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.SUM) 283 | self.margin = margin 284 | self.emb_lambda = emb_lambda 285 | self.pred_lambda = pred_lambda 286 | self.bad_behavior_lambda = bad_behavior_lambda 287 | self.rare_lambda = rare_lambda 288 | self.cmc_lambda = cmc_lambda 289 | self.cmc_margin = cmc_margin 290 | self.set_card_params(self.card_data.iloc[:-1, :]) 291 | self.metric_names = metric_names 292 | 293 | def set_card_params(self, card_data): 294 | """ 295 | Create attributes that allow us to do computations with card level data 296 | """ 297 | # this flag allows us to incur extra regularization penalty when it appears like 298 | # the model has improperly learned to rare draft 299 | self.rare_flag = (card_data["mythic"] + card_data["rare"]).values[None, None, :] 300 | # this lets us easily convert packs or pools to cmc representations to help the 301 | # model bias towards cheaper cards in general 302 | self.cmc = card_data["cmc"].values[None, None, :] 303 | 304 | def loss(self, true, pred, sample_weight=None, training=None, **kwargs): 305 | """ 306 | implementation of the loss function. 307 | """ 308 | pred, emb_dists = pred 309 | # store inputs in case data causes an issue on a specific batch 310 | self.last_true = true 311 | self.last_sample_weight = sample_weight 312 | # cross-entropy loss applied to what the human took vs the softmax output 313 | self.prediction_loss = self.loss_f(true, pred, sample_weight=sample_weight) 314 | # get the one hot representation of what the human picked 315 | correct_one_hot = tf.one_hot(true, self.n_cards) 316 | # get the distance between the incorrect picks and the transformer output 317 | dist_of_not_correct = emb_dists * (1 - correct_one_hot) 318 | # get the distance between the correct pick and the transformer output 319 | dist_of_correct = tf.reduce_sum(emb_dists * correct_one_hot, axis=-1, keepdims=True) 320 | # we want the distance from the correct pick to the transformer output to be smaller 321 | # than the incorrect picks to the transformer output. So we do the following 322 | # subtraction because `dist_loss` will be negative in that case. This is where 323 | # the `margin` parameter comes into play. We want this subtraction to yield 324 | # *at most* -self.margin. Otherwise, we incur a penalty. 325 | dist_loss = dist_of_correct - dist_of_not_correct 326 | sample_weight = 1 if sample_weight is None else sample_weight 327 | self.embedding_loss = tf.reduce_sum( 328 | tf.reduce_sum(tf.maximum(dist_loss + self.margin, 0.0), axis=-1) * sample_weight 329 | ) 330 | # compute loss from expert priors (e.g. 
no rare drafting, take cheaper cards) 331 | self.bad_behavior_loss = self.determine_bad_behavior(true, pred, sample_weight=sample_weight) 332 | 333 | return ( 334 | self.pred_lambda * self.prediction_loss 335 | + self.emb_lambda * self.embedding_loss 336 | + self.bad_behavior_lambda * self.bad_behavior_loss 337 | ) 338 | 339 | def determine_bad_behavior(self, true, pred, sample_weight=None): 340 | true_one_hot = tf.one_hot(true, self.n_cards) 341 | # penalize for taking more expensive cards than what the human took 342 | # basically, if you're going to make a mistake, bias towards low cmc cards 343 | true_cmc = tf.reduce_sum(true_one_hot * self.cmc, axis=-1) 344 | pred_cmc = tf.reduce_sum(pred * self.cmc, axis=-1) 345 | cmc_loss = tf.maximum(pred_cmc - true_cmc + self.cmc_margin, 0.0) * self.cmc_lambda 346 | self.cmc_loss = tf.reduce_sum(cmc_loss * sample_weight) 347 | # penalize taking rares when the human doesn't. This helps with generalization. Think 348 | # about it like this: people *love* taking rares. This means, when they choose not 349 | # to take a rare, that pick is likely important and full of information we want to 350 | # learn. Hence, we incur a *massive* penalty (this is why the default rare_lambda=10.0) 351 | # to tell the model "when a person doesn't take a rare, you better pay attention". 352 | # Additionally, this prevents the model from learning to rare draft! 353 | human_took_rare = tf.reduce_sum(true_one_hot * self.rare_flag, axis=-1) 354 | pred_rare_val = tf.reduce_sum(pred * self.rare_flag, axis=-1) 355 | rare_loss = (1 - human_took_rare) * pred_rare_val * self.rare_lambda 356 | self.rare_loss = tf.reduce_sum(rare_loss * sample_weight) 357 | return self.cmc_loss + self.rare_loss 358 | 359 | def compute_metrics(self, true, pred, sample_weight=None, **kwargs): 360 | """ 361 | compute top1, top2, top3 accuracy to display as metrics during training when verbose=True 362 | """ 363 | if sample_weight is None: 364 | sample_weight = tf.ones_like(true.shape) / (true.shape[0] * true.shape[1]) 365 | # TODO: this caused a shape error, but didn't previously. Look into this in detail later; comment out for now. 
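# NOTE: sparse_top_k_categorical_accuracy scores each pick against the last axis of pred;
# depending on the TF version it may flatten the leading (batch, t) axes into one, so the
# sample_weight used below has to broadcast against that flattened shape.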
366 | # sample_weight = sample_weight.flatten() 367 | pred, _ = pred 368 | top1 = tf.reduce_sum(tf.keras.metrics.sparse_top_k_categorical_accuracy(true, pred, 1) * sample_weight) 369 | top2 = tf.reduce_sum(tf.keras.metrics.sparse_top_k_categorical_accuracy(true, pred, 2) * sample_weight) 370 | top3 = tf.reduce_sum(tf.keras.metrics.sparse_top_k_categorical_accuracy(true, pred, 3) * sample_weight) 371 | return {"top1": top1, "top2": top2, "top3": top3} 372 | 373 | def save(self, location): 374 | """ 375 | store the trained model and important attributes from the model to a file 376 | """ 377 | pathlib.Path(location).mkdir(parents=True, exist_ok=True) 378 | model_loc = os.path.join(location, "model") 379 | tf.saved_model.save(self, model_loc) 380 | data_loc = os.path.join(location, "attrs.pkl") 381 | with open(data_loc, "wb") as f: 382 | attrs = { 383 | "t": self.t, 384 | "idx_to_name": self.idx_to_name, 385 | "n_cards": self.n_cards, 386 | "embeddings": self.card_embedding(tf.range(self.n_cards), training=False), 387 | } 388 | pickle.dump(attrs, f) 389 | 390 | 391 | class DeckBuilder(tf.Module): 392 | """ 393 | Custom Tensorflow Model for Magic: the Gathering DeckBuilder AI 394 | 395 | This algorithm is a Denoising AutoEncoder: 396 | Deckbuilding in Limited is about taking a card pool, and yielding a 397 | subset of that pool as a deck, which is effectively "denoising" the pool 398 | by removing the sideboard cards from it. 399 | 400 | However, a plain Denoising AutoEncoder has a few problems. 401 | 1. It doesn't address basics 402 | 2. It has difficulties during inference because inference is a discrete 403 | problem, and training is continuous 404 | 405 | Addressing basics: 406 | Observe that adding basics to a deck is a function of the final deck, and 407 | not the direct card pool. So, let DeckBuilder(pool) -> deck_projection. Then, 408 | we want to learn an additional function F(deck_projection) -> basics. 409 | 410 | Addressing inference: 411 | If we just iteratively take the argmax of the output from DeckBuilder(pool), 412 | then multiples of cards are treated poorly. If a pool contains two copies of 413 | Card A, and one copy of Card B, how should you determine what to add to the 414 | deck when the model says "add 1.7 copies of card A and 0.75 copies of card B"? 415 | If you only have a few slots left, should you add two copies of A and 0 of B? 416 | 1 and 1? It's unclear, and often yields issues (empirically it did at least). 417 | 418 | So, we modify the problem such that an iterative argmax makes sense. Rather than 419 | having the input just be a pool, we pass the model the available pool, and the 420 | current deck, where the current deck can be of any size and simply represents 421 | "cards in the pool that MUST be added to the deck at the end". This way, at 422 | inference, we can do the following: 423 | 424 | 1. Pass a full pool and an empty deck 425 | 2. Allocate the argmax of the output to the deck, and subtract it from the pool 426 | 3. Run the model again, and repeat until the model says to stop adding cards 427 | 4. Take the final allocation, and pass it to F mentioned in the Addressing basics 428 | section to yield the basics corresponding to the deck. 429 | 430 | In order to accomplish this, we generate deck data as follows: 431 | 1. Sample a data point, which contains a deck and a sideboard 432 | 2. Sample N cards from the deck, set that to your target 433 | 3. Remove those N cards from the deck, and add them to the sideboard 434 | 4. 
Now the sideboard is the set of options, and the deck is the currently allocated cards! 435 | 5. If you'd like to view the code, look at DeckGenerator in mtg/ml/generator.py 436 | 437 | ------------------------------------------------------------------------------------------- 438 | 439 | n_cards: number of cards in the set, EXCLUDING basics 440 | dropout: Dropout rate for the encoders for the pool and partial deck 441 | latent_dim: The input (pool, partial deck) pair gets projected to a latent 442 | space of this dimension 443 | embeddings: The dimension for card embeddings. If a matrix is passed, it is 444 | treated as pretrained card embeddings and frozen. 445 | """ 446 | 447 | def __init__( 448 | self, 449 | n_cards, 450 | dropout=0.0, 451 | latent_dim=32, 452 | embeddings=128, 453 | name=None, 454 | ): 455 | super().__init__(name=name) 456 | self.n_cards = n_cards 457 | if isinstance(embeddings, int): 458 | emb_trainable = True 459 | initializer = tf.initializers.glorot_normal() 460 | emb_init = initializer(shape=(self.n_cards, embeddings)) 461 | else: 462 | emb_trainable = False 463 | emb_init = embeddings 464 | self.card_embeddings = tf.Variable(emb_init, trainable=emb_trainable) 465 | # we use card embeddings to project the pool and partial deck to a vector 466 | # space; we concatenate them and then project to `latent_dim`, so concat 467 | # dim is always the embedding dimension * 2 468 | concat_dim = self.card_embeddings.shape[1] * 2 469 | # MLP that takes the latent representation and decodes to the projection of 470 | # what cards to add to the deck. 471 | self.card_decoder = nn.MLP( 472 | in_dim=latent_dim, 473 | start_dim=latent_dim * 2, 474 | out_dim=self.n_cards, 475 | n_h_layers=2, 476 | dropout=0.0, 477 | name="card_decoder", 478 | noise=0.0, 479 | start_act=tf.nn.relu, 480 | middle_act=tf.nn.relu, 481 | out_act=tf.nn.sigmoid, 482 | style="reverse_bottleneck", 483 | ) 484 | # MLP that takes the projection of the final deck and adds the basics 485 | self.basic_decoder = nn.MLP( 486 | in_dim=self.n_cards, 487 | start_dim=self.n_cards // 2, 488 | out_dim=5, 489 | n_h_layers=2, 490 | dropout=0.0, 491 | name="basic_decoder", 492 | noise=0.0, 493 | start_act=tf.nn.relu, 494 | middle_act=tf.nn.relu, 495 | out_act=tf.nn.softmax, 496 | style="reverse_bottleneck", 497 | ) 498 | # We learn to determine the number of non-basics from a fully built deck to 499 | # properly allocate the number of basics, as well as terminate the iterative 500 | # process during inference. 501 | # TODO: experiment with using a sum along the last axis rather than a Dense layer, 502 | # although I expect this to have issues because the output is not discrete 503 | # NOTE: activation of relu + 22 -> a minimum of 22 non-basics. This does mean that it 504 | # is impossible to play more than 18 lands without cards like evolving wilds, and 505 | # may improperly bias away from playing 18 lands when necessary. 
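# (worked example of the activation below: relu(-1.3) + 22.0 = 22.0 and relu(1.5) + 22.0 = 23.5,
# so the predicted non-basic count is always >= 22 and basics = 40 - n_non_basics is capped at 18)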
506 | self.determine_n_non_basics = nn.Dense( 507 | self.n_cards, 508 | 1, 509 | activation=lambda x: tf.nn.relu(x) + 22.0, 510 | name="determine_n_non_basics", 511 | ) 512 | # Dense layer that takes the concatenated pool and partial deck embeddings and 513 | # projects it to the latent representation of the deck 514 | self.merge_deck_and_pool = nn.Dense(concat_dim, latent_dim, activation=None, name="merge_deck_and_pool") 515 | self.dropout = dropout 516 | 517 | # TODO: change input data to not require relaxed shape, and change from 518 | # vector of n_cards size to vector of size cards in pool for efficiency 519 | @tf.function(experimental_relax_shapes=True) 520 | def __call__(self, features, training=None): 521 | # batch x sample x n_cards 522 | pools, decks = features 523 | # project pool and partial deck to their respective latent space as sums of 524 | # card embeddings 525 | self.latent_rep_pool = tf.reduce_sum(pools[:, :, :, None] * self.card_embeddings[None, None, :, :], axis=2) 526 | self.latent_rep_deck = tf.reduce_sum(decks[:, :, :, None] * self.card_embeddings[None, None, :, :], axis=2) 527 | # concatenate representation of pool and partial deck 528 | concat_emb = tf.concat([self.latent_rep_deck, self.latent_rep_pool], axis=-1) 529 | if self.dropout > 0.0 and training: 530 | concat_emb = tf.nn.dropout(concat_emb, self.dropout) 531 | # yield final latent representation of deck 532 | self.latent_rep = self.merge_deck_and_pool(concat_emb, training=training) 533 | # compute the cards to add from the available pool 534 | self.cards_to_add = self.card_decoder(self.latent_rep, training=training) * pools 535 | # the final built deck is equal to the cards we want to allocate from the pool 536 | # added to the partial deck input of cards already allocated to the deck 537 | built_deck = self.cards_to_add + decks 538 | # given the deck, determine how many non-basics, and hence basics, we want 539 | self.n_non_basics = self.determine_n_non_basics(built_deck, training=training) 540 | n_basics = 40 - self.n_non_basics 541 | # finally, add the basics to the deck! 542 | self.basics_to_add = self.basic_decoder(built_deck, training=training) * n_basics 543 | 544 | return self.basics_to_add, self.cards_to_add, self.n_non_basics 545 | 546 | def compile( 547 | self, 548 | card_data, 549 | learning_rate=0.001, 550 | basic_lambda=1.0, 551 | built_lambda=1.0, 552 | cmc_lambda=0.01, 553 | optimizer=None, 554 | metric_names=["basics_off", "spells_off"], 555 | ): 556 | """ 557 | After initializing the model, we want to compile it by setting parameters for training 558 | 559 | optimizer: optimizer to use for minimizing the objective function (default is Adam) 560 | learning_rate: learning rate for the optimizer. If passing {'lr_warmup':N}, it will use 561 | Adam with a scheduler that warms up the LR. This is recommended. 562 | basic_lambda: the coefficient for matching the basics the human chose to add 563 | built_lambda: the coefficient for matching the non-basics the human chose to add 564 | cmc_lambda: the regularization coefficient for the penalty that asks the model to bias 565 | towards building decks with similar curves to humans (needs improvement) 566 | metric_names: a list of which metrics to use to help debug. 
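Example (a sketch with illustrative values; assumes a DeckBuilder instance named `deck_builder` and that `card_data` is the card-level DataFrame built for the expansion):
    deck_builder.compile(card_data, learning_rate={'lr_warmup': 1000}, cmc_lambda=0.01)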
567 | """ 568 | if optimizer is None: 569 | if isinstance(learning_rate, dict): 570 | learning_rate = CustomSchedule(500, **learning_rate) 571 | else: 572 | learning_rate = learning_rate 573 | 574 | self.optimizer = tf.keras.optimizers.Adam( 575 | learning_rate=learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9 576 | ) 577 | else: 578 | self.optimizer = optimizer 579 | 580 | self.basic_lambda = basic_lambda 581 | self.built_lambda = built_lambda 582 | 583 | self.cmc_lambda = cmc_lambda 584 | self.set_card_params(card_data) 585 | self.metric_names = metric_names 586 | 587 | def set_card_params(self, card_data): 588 | """ 589 | Create attributes that allow us to do computations with card level data 590 | """ 591 | # this lets us easily compute the distribution of cmc in any deck to help 592 | # the model yield similar curves to how humans approach deckbuilding 593 | self.cmc_map = card_data["cmc"].to_numpy(dtype=np.float32) 594 | 595 | def loss(self, true, pred, sample_weight=None, **kwargs): 596 | """ 597 | implementation of the loss function. It is currently using MSE instead of MAE. MAE is 598 | intuitively a better fit because it would yield more sparse predictions. However, 599 | I empirically found this yielded problematic generalization results because it pushed 600 | the value of all cards in consideration to 1, which meant taking the argmax was often 601 | insufficient for building decks during inference. 602 | """ 603 | true_basics, true_built = true 604 | pred_basics, pred_built, _ = pred 605 | # penalize the model for improperly allocating basic lands 606 | self.basic_loss = tf.reduce_sum( 607 | tf.reduce_sum(tf.math.square(pred_basics - true_basics), axis=-1) * sample_weight 608 | ) 609 | # penalize the model for improperly allocating non-basic lands and spells 610 | self.built_loss = tf.reduce_sum(tf.reduce_sum(tf.math.square(pred_built - true_built), axis=-1) * sample_weight) 611 | # penalize the model for deviating from the average curve of the deck a person built 612 | if self.cmc_lambda > 0: 613 | # TODO: test replacing this with KL-Divergence on the distribution of the curve 614 | # hopefully this helps the model play worse two-drops when needed, which it 615 | # is currently not great at (it definitely does it, but not enough) 616 | self.pred_curve_average = tf.reduce_mean( 617 | tf.multiply(pred_built, tf.expand_dims(self.cmc_map[5:], 0)), axis=-1 618 | ) 619 | self.true_curve_average = tf.reduce_mean( 620 | tf.multiply(true_built, tf.expand_dims(self.cmc_map[5:], 0)), axis=-1 621 | ) 622 | self.curve_incentive = tf.reduce_sum(abs(self.pred_curve_average - self.true_curve_average) * sample_weight) 623 | else: 624 | self.curve_incentive = 0.0 625 | 626 | return ( 627 | self.basic_lambda * self.basic_loss 628 | + self.built_lambda * self.built_loss 629 | + self.cmc_lambda * self.curve_incentive 630 | ) 631 | 632 | def compute_metrics(self, true, pred, sample_weight=None, training=None, **kwargs): 633 | pred_basics, pred_built, _ = pred 634 | true_basics, true_decks = true 635 | if sample_weight is None: 636 | sample_weight = 1.0 / true_decks.shape[0] 637 | # compute the average number of basics by which the model differs from human builds 638 | basic_diff = tf.reduce_sum(tf.reduce_sum(tf.math.abs(pred_basics - true_basics), axis=-1) * sample_weight) 639 | # compute the average number of non-basics and spells by which the model differs from human builds 640 | deck_diff = tf.reduce_sum(tf.reduce_sum(tf.math.abs(pred_built - true_decks), axis=-1) * sample_weight) 641 | return {"basics_off": 
basic_diff, "spells_off": deck_diff} 642 | 643 | def save(self, cards, location): 644 | """ 645 | store the trained model and card object to a file 646 | """ 647 | pathlib.Path(location).mkdir(parents=True, exist_ok=True) 648 | model_loc = os.path.join(location, "model") 649 | data_loc = os.path.join(location, "cards.pkl") 650 | tf.saved_model.save(self, model_loc) 651 | with open(data_loc, "wb") as f: 652 | pickle.dump(cards, f) 653 | --------------------------------------------------------------------------------
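As a usage note, both models persist themselves with `tf.saved_model.save` plus a pickle of auxiliary attributes (see the `save` methods above). Below is a minimal sketch of reloading a trained DraftBot for inference, assuming `location` is the directory that was passed to `DraftBot.save` (the path shown is a placeholder):

```python
import os
import pickle

import tensorflow as tf

location = "path/to/saved_draft_bot"  # placeholder: directory passed to DraftBot.save
draft_model = tf.saved_model.load(os.path.join(location, "model"))
with open(os.path.join(location, "attrs.pkl"), "rb") as f:
    attrs = pickle.load(f)
idx_to_name = attrs["idx_to_name"]  # card index -> card name mapping stored at save time
n_picks = attrs["t"]  # number of picks in a draft for this expansion
```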