├── .gitignore ├── Changelog.md ├── LICENSE ├── Readme.md ├── cocob.py ├── data ├── 2017-08-15_2017-09-11.csv.zip └── placeholder ├── extractor.py ├── feeder.py ├── how_it_works.md ├── hparams.py ├── images ├── attention.xml ├── autocorr.png ├── encoder-decoder.png ├── encoder-decoder.xml ├── from_past.png ├── lagged_data.png ├── losses_0.png ├── losses_1.png ├── predictions.png ├── split.png ├── training.png └── validation-split.xml ├── input_pipe.py ├── make_features.py ├── model.py ├── requirements.txt ├── submission-final.ipynb └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/* 2 | 3 | # Jupyter Notebook 4 | .ipynb_checkpoints 5 | 6 | data/cpt 7 | data/logs 8 | data/vars 9 | data/*.pkl 10 | data/*.zip 11 | data/submission.csv.gz 12 | !data/2017-08-15_2017-09-11.csv.zip 13 | 14 | -------------------------------------------------------------------------------- /Changelog.md: -------------------------------------------------------------------------------- 1 | 2018-10-15 2 | - Model updated to work with modern Tensorflow (>=1.10) 3 | - Switched to Adam instead of COCOB (COCOB doesn't work with TF > 1.4) 4 | - No parameter tuning was performed for Adam, therefore the model probably has a 5 | suboptimal learning rate and doesn't reproduce the exact result from the competition -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Artur Suilin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # Kaggle Web Traffic Time Series Forecasting 2 | 1st place solution 3 | 4 | ![predictions](images/predictions.png) 5 | 6 | Main files: 7 | * `make_features.py` - builds features from source data 8 | * `input_pipe.py` - TF data preprocessing pipeline (assembles features 9 | into training/evaluation tensors, performs some sampling and normalisation) 10 | * `model.py` - the model 11 | * `trainer.py` - trains the model(s) 12 | * `hparams.py` - hyperparameter sets. 13 | * `submission-final.ipynb` - generates predictions for submission 14 | 15 | How to reproduce competition results: 16 | 1. 
Download the input files from https://www.kaggle.com/c/web-traffic-time-series-forecasting/data : 17 | `key_2.csv.zip`, `train_2.csv.zip`, and put them into the `data` directory. 18 | 2. Run `python make_features.py data/vars --add_days=63`. It will 19 | extract data and features from the input files and put them into 20 | `data/vars` as a Tensorflow checkpoint. 21 | 3. Run the trainer: 22 | `python trainer.py --name s32 --hparam_set=s32 --n_models=3 --no_eval --no_forward_split 23 | --asgd_decay=0.99 --max_steps=11500 --save_from_step=10500`. This command 24 | will simultaneously train 3 models on different seeds (in a single TF graph) 25 | and save 10 checkpoints from step 10500 to step 11500 to `data/cpt`. 26 | __Note:__ training requires a GPU, because of cuDNN usage. CPU training will not work. 27 | If you have 3 or more GPUs, add the `--multi_gpu` flag to speed up training. One can also try different 28 | hyperparameter sets (described in `hparams.py`): `--hparam_set=definc`, 29 | `--hparam_set=inst81`, etc. 30 | Don't be afraid of the NaN losses displayed during training. This is normal, 31 | because we do the training in blind mode, without any evaluation of model performance. 32 | 4. Run `submission-final.ipynb` in a standard Jupyter notebook environment and 33 | execute all cells. Prediction will take some time, because it has to 34 | load and evaluate 30 different sets of model weights. At the end, 35 | you'll get a `submission.csv.gz` file in the `data` directory. 36 | 37 | See also the [detailed model description](how_it_works.md) 38 | -------------------------------------------------------------------------------- /cocob.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Francesco Orabona. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | 17 | """ 18 | COntinuos COin Betting (COCOB) optimizer 19 | See 'Training Deep Networks without Learning Rates Through Coin Betting' 20 | https://arxiv.org/abs/1705.07795 21 | """ 22 | 23 | from tensorflow.python.framework import ops 24 | from tensorflow.python.ops import state_ops 25 | from tensorflow.python.ops import control_flow_ops 26 | from tensorflow.python.framework import constant_op 27 | from tensorflow.python.training.optimizer import Optimizer 28 | import tensorflow as tf 29 | 30 | 31 | 32 | class COCOB(Optimizer): 33 | def __init__(self, alpha=100, use_locking=False, name='COCOB'): 34 | ''' 35 | constructs a new COCOB optimizer 36 | ''' 37 | super(COCOB, self).__init__(use_locking, name) 38 | self._alpha = alpha 39 | 40 | def _create_slots(self, var_list): 41 | for v in var_list: 42 | with ops.colocate_with(v): 43 | gradients_sum = constant_op.constant(0, 44 | shape=v.get_shape(), 45 | dtype=v.dtype.base_dtype) 46 | grad_norm_sum = constant_op.constant(0, 47 | shape=v.get_shape(), 48 | dtype=v.dtype.base_dtype) 49 | L = constant_op.constant(1e-8, shape=v.get_shape(), dtype=v.dtype.base_dtype) 50 | tilde_w = constant_op.constant(0.0, shape=v.get_shape(), dtype=v.dtype.base_dtype) 51 | reward = constant_op.constant(0.0, shape=v.get_shape(), dtype=v.dtype.base_dtype) 52 | 53 | self._get_or_make_slot(v, L, "L", self._name) 54 | self._get_or_make_slot(v, grad_norm_sum, "grad_norm_sum", self._name) 55 | self._get_or_make_slot(v, gradients_sum, "gradients_sum", self._name) 56 | self._get_or_make_slot(v, tilde_w, "tilde_w", self._name) 57 | self._get_or_make_slot(v, reward, "reward", self._name) 58 | 59 | def _apply_dense(self, grad, var): 60 | gradients_sum = self.get_slot(var, "gradients_sum") 61 | grad_norm_sum = self.get_slot(var, "grad_norm_sum") 62 | tilde_w = self.get_slot(var, "tilde_w") 63 | L = self.get_slot(var, "L") 64 | reward = self.get_slot(var, "reward") 65 | 66 | L_update = tf.maximum(L, tf.abs(grad)) 67 | gradients_sum_update = gradients_sum + grad 68 | grad_norm_sum_update = grad_norm_sum + tf.abs(grad) 69 | reward_update = tf.maximum(reward - grad * tilde_w, 0) 70 | new_w = -gradients_sum_update / ( 71 | L_update * (tf.maximum(grad_norm_sum_update + L_update, self._alpha * L_update))) * (reward_update + L_update) 72 | var_update = var - tilde_w + new_w 73 | tilde_w_update = new_w 74 | 75 | gradients_sum_update_op = state_ops.assign(gradients_sum, gradients_sum_update) 76 | grad_norm_sum_update_op = state_ops.assign(grad_norm_sum, grad_norm_sum_update) 77 | var_update_op = state_ops.assign(var, var_update) 78 | tilde_w_update_op = state_ops.assign(tilde_w, tilde_w_update) 79 | L_update_op = state_ops.assign(L, L_update) 80 | reward_update_op = state_ops.assign(reward, reward_update) 81 | 82 | return control_flow_ops.group(*[gradients_sum_update_op, 83 | var_update_op, 84 | grad_norm_sum_update_op, 85 | tilde_w_update_op, 86 | reward_update_op, 87 | L_update_op]) 88 | 89 | def _apply_sparse(self, grad, var): 90 | return self._apply_dense(grad, var) 91 | 92 | def _resource_apply_dense(self, grad, handle): 93 | return self._apply_dense(grad, handle) 94 | -------------------------------------------------------------------------------- /data/2017-08-15_2017-09-11.csv.zip: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Arturus/kaggle-web-traffic/a9abb80c800409abf0ece21ea244ef779f758f96/data/2017-08-15_2017-09-11.csv.zip -------------------------------------------------------------------------------- /data/placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arturus/kaggle-web-traffic/a9abb80c800409abf0ece21ea244ef779f758f96/data/placeholder -------------------------------------------------------------------------------- /extractor.py: -------------------------------------------------------------------------------- 1 | import re 2 | import pandas as pd 3 | import numpy as np 4 | 5 | term_pat = re.compile('(.+?):(.+)') 6 | pat = re.compile( 7 | '(.+)_([a-z][a-z]\.)?((?:wikipedia\.org)|(?:commons\.wikimedia\.org)|(?:www\.mediawiki\.org))_([a-z_-]+?)$') 8 | 9 | # Debug output to ensure pattern still works 10 | # print(pat.fullmatch('BLEACH_zh.wikipedia.org_all-accessspider').groups()) 11 | # print(pat.fullmatch('Accueil_commons.wikimedia.org_all-access_spider').groups()) 12 | 13 | 14 | def extract(source) -> pd.DataFrame: 15 | """ 16 | Extracts features from url. Features: agent, site, country, term, marker 17 | :param source: urls 18 | :return: DataFrame, one column per feature 19 | """ 20 | if isinstance(source, pd.Series): 21 | source = source.values 22 | agents = np.full_like(source, np.NaN) 23 | sites = np.full_like(source, np.NaN) 24 | countries = np.full_like(source, np.NaN) 25 | terms = np.full_like(source, np.NaN) 26 | markers = np.full_like(source, np.NaN) 27 | 28 | for i in range(len(source)): 29 | l = source[i] 30 | match = pat.fullmatch(l) 31 | assert match, "Non-matched string %s" % l 32 | term = match.group(1) 33 | country = match.group(2) 34 | if country: 35 | countries[i] = country[:-1] 36 | site = match.group(3) 37 | sites[i] = site 38 | agents[i] = match.group(4) 39 | if site != 'wikipedia.org': 40 | term_match = term_pat.match(term) 41 | if term_match: 42 | markers[i] = term_match.group(1) 43 | term = term_match.group(2) 44 | terms[i] = term 45 | 46 | return pd.DataFrame({ 47 | 'agent': agents, 48 | 'site': sites, 49 | 'country': countries, 50 | 'term': terms, 51 | 'marker': markers, 52 | 'page': source 53 | }) 54 | -------------------------------------------------------------------------------- /feeder.py: -------------------------------------------------------------------------------- 1 | from collections import UserList, UserDict 2 | from typing import Union, Iterable, Tuple, Dict, Any 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | import pandas as pd 7 | import pickle 8 | import os.path 9 | 10 | 11 | def _meta_file(path): 12 | return os.path.join(path, 'feeder_meta.pkl') 13 | 14 | 15 | class VarFeeder: 16 | """ 17 | Helper to avoid feed_dict and manual batching. Maybe I had to use TFRecords instead. 18 | Builds temporary TF graph, injects variables into, and saves variables to TF checkpoint. 
19 | In a train time, variables can be built by build_vars() and content restored by FeederVars.restore() 20 | """ 21 | def __init__(self, path: str, 22 | tensor_vars: Dict[str, Union[pd.DataFrame, pd.Series, np.ndarray]] = None, 23 | plain_vars: Dict[str, Any] = None): 24 | """ 25 | :param path: dir to store data 26 | :param tensor_vars: Variables to save as Tensors (pandas DataFrames/Series or numpy arrays) 27 | :param plain_vars: Variables to save as Python objects 28 | """ 29 | tensor_vars = tensor_vars or dict() 30 | 31 | def get_values(v): 32 | v = v.values if hasattr(v, 'values') else v 33 | if not isinstance(v, np.ndarray): 34 | v = np.array(v) 35 | if v.dtype == np.float64: 36 | v = v.astype(np.float32) 37 | return v 38 | 39 | values = [get_values(var) for var in tensor_vars.values()] 40 | 41 | self.shapes = [var.shape for var in values] 42 | self.dtypes = [v.dtype for v in values] 43 | self.names = list(tensor_vars.keys()) 44 | self.path = path 45 | self.plain_vars = plain_vars 46 | 47 | if not os.path.exists(path): 48 | os.mkdir(path) 49 | 50 | with open(_meta_file(path), mode='wb') as file: 51 | pickle.dump(self, file) 52 | 53 | with tf.Graph().as_default(): 54 | tensor_vars = self._build_vars() 55 | placeholders = [tf.placeholder(tf.as_dtype(dtype), shape=shape) for dtype, shape in 56 | zip(self.dtypes, self.shapes)] 57 | assigners = [tensor_var.assign(placeholder) for tensor_var, placeholder in 58 | zip(tensor_vars, placeholders)] 59 | feed = {ph: v for ph, v in zip(placeholders, values)} 60 | saver = tf.train.Saver(self._var_dict(tensor_vars), max_to_keep=1) 61 | init = tf.global_variables_initializer() 62 | 63 | with tf.Session(config=tf.ConfigProto(device_count={'GPU': 0})) as sess: 64 | sess.run(init) 65 | sess.run(assigners, feed_dict=feed) 66 | save_path = os.path.join(path, 'feeder.cpt') 67 | saver.save(sess, save_path, write_meta_graph=False, write_state=False) 68 | 69 | def _var_dict(self, variables): 70 | return {name: var for name, var in zip(self.names, variables)} 71 | 72 | def _build_vars(self): 73 | def make_tensor(shape, dtype, name): 74 | tf_type = tf.as_dtype(dtype) 75 | if tf_type == tf.string: 76 | empty = '' 77 | elif tf_type == tf.bool: 78 | empty = False 79 | else: 80 | empty = 0 81 | init = tf.constant(empty, shape=shape, dtype=tf_type) 82 | return tf.get_local_variable(name=name, initializer=init, dtype=tf_type) 83 | 84 | with tf.device("/cpu:0"): 85 | with tf.name_scope('feeder_vars'): 86 | return [make_tensor(shape, dtype, name) for shape, dtype, name in 87 | zip(self.shapes, self.dtypes, self.names)] 88 | 89 | def create_vars(self): 90 | """ 91 | Builds variable list to use in current graph. 
Should be called during graph building stage 92 | :return: variable list with additional restore and create_saver methods 93 | """ 94 | return FeederVars(self._var_dict(self._build_vars()), self.plain_vars, self.path) 95 | 96 | @staticmethod 97 | def read_vars(path): 98 | with open(_meta_file(path), mode='rb') as file: 99 | feeder = pickle.load(file) 100 | assert feeder.path == path 101 | return feeder.create_vars() 102 | 103 | 104 | class FeederVars(UserDict): 105 | def __init__(self, tensors: dict, plain_vars: dict, path): 106 | variables = dict(tensors) 107 | if plain_vars: 108 | variables.update(plain_vars) 109 | super().__init__(variables) 110 | self.path = path 111 | self.saver = tf.train.Saver(tensors, name='varfeeder_saver') 112 | for var in variables: 113 | if var not in self.__dict__: 114 | self.__dict__[var] = variables[var] 115 | 116 | def restore(self, session): 117 | """ 118 | Restores variable content 119 | :param session: current session 120 | :return: variable list 121 | """ 122 | self.saver.restore(session, os.path.join(self.path, 'feeder.cpt')) 123 | return self 124 | -------------------------------------------------------------------------------- /how_it_works.md: -------------------------------------------------------------------------------- 1 | # How it works 2 | __TL;DR__ this is a seq2seq model with some additions to utilize year-to-year 3 | and quarter-to-quarter seasonality in the data. 4 | ___ 5 | There are two main information sources for prediction: 6 | 1. Local features. If we see a trend, we expect that it will continue 7 | (AutoRegressive model), if we see a traffic spike, it will gradually decay (Moving Average model), 8 | if we see more traffic on holidays, we expect to have more traffic on 9 | holidays in the future (seasonal model). 10 | 2. Global features. If we look at the autocorrelation plot, we'll notice strong 11 | year-to-year autocorrelation and some quarter-to-quarter autocorrelation. 12 | 13 | ![autocorrelation](images/autocorr.png "Autocorrelation") 14 | 15 | A good model should use both global and local features, combining them 16 | in an intelligent way. 17 | 18 | I decided to use an RNN seq2seq model for prediction, because: 19 | 1. An RNN can be thought of as a natural extension of well-studied ARIMA models, but much more 20 | flexible and expressive. 21 | 2. An RNN is non-parametric, which greatly simplifies learning. 22 | Imagine working with different ARIMA parameters for 145K timeseries. 23 | 3. Any exogenous feature (numerical or categorical, time-dependent or series-dependent) 24 | can be easily injected into the model. 25 | 4. seq2seq seems natural for this task: we predict the next values, conditioned on the joint 26 | probability of previous values, including our past predictions. Use of past predictions 27 | stabilizes the model: it learns to be conservative, because error accumulates at each step, 28 | and an extreme prediction at one step can ruin prediction quality for all subsequent steps. 29 | 5. Deep Learning is all the hype nowadays 30 | 31 | ## Feature engineering 32 | I tried to be minimalistic, because an RNN is powerful enough to discover 33 | and learn features on its own. 34 | Model feature list: 35 | * *pageviews* (spelled as 'hits' in the model code, because of my web-analytics background). 36 | Raw values are transformed by log1p() to get a more-or-less normal intra-series value distribution, 37 | instead of a skewed one. 38 | * *agent*, *country*, *site* - these features are extracted from page urls and one-hot encoded 39 | * *day of week* - to capture weekly seasonality 40 | * *year-to-year autocorrelation*, *quarter-to-quarter autocorrelation* - to capture yearly and quarterly seasonality strength. 41 | * *page popularity* - High traffic and low traffic pages have different traffic change patterns; 42 | this feature (median of pageviews) helps to capture traffic scale. 43 | This scale information is lost in the *pageviews* feature, because each pageviews series is 44 | independently normalized to zero mean and unit variance. 45 | * *lagged pageviews* - I'll describe this feature later 46 | 47 | ## Feature preprocessing 48 | All features (including one-hot encoded) are normalized to zero mean and unit variance. Each *pageviews* 49 | series is normalized independently. 50 | 51 | Time-independent features (autocorrelations, country, etc) are "stretched" to the timeseries length, 52 | i.e. repeated for each day by the `tf.tile()` command. 53 | 54 | The model trains on random fixed-length samples from the original timeseries. For example, 55 | if the original timeseries length is 600 days and we use 200-day samples for training, 56 | we'll have a choice of 400 days to start the sample. 57 | 58 | This sampling works as an effective data augmentation mechanism: 59 | the training code randomly chooses a starting point for each timeseries on each 60 | step, generating an endless stream of almost non-repeating data. 61 | 62 | 
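A minimal numpy sketch of this preprocessing idea (log1p transform, per-series normalization and random window sampling). The real implementation lives in the TF pipeline (`make_features.py` / `input_pipe.py`); the function names and toy data here are illustrative only:

```python
import numpy as np

def normalize_series(raw_hits: np.ndarray) -> np.ndarray:
    """log1p-transform, then z-score each series (row) independently."""
    log_hits = np.log1p(raw_hits)
    mean = log_hits.mean(axis=1, keepdims=True)
    std = log_hits.std(axis=1, keepdims=True) + 1e-8  # guard against constant series
    return (log_hits - mean) / std

def random_training_window(series, train_window, predict_window, rng):
    """Randomly crop one (train, target) sample; every call yields a different crop."""
    free_space = series.shape[-1] - train_window - predict_window
    start = rng.integers(0, free_space + 1)
    x = series[..., start:start + train_window]
    y = series[..., start + train_window:start + train_window + predict_window]
    return x, y

# 600 days of history and 200-day training samples -> roughly 400 possible starting
# points (a bit fewer once the prediction window is reserved at the end).
rng = np.random.default_rng(0)
hits = rng.poisson(lam=50, size=(4, 600)).astype(np.float64)  # 4 fake pages
x, y = random_training_window(normalize_series(hits), train_window=200, predict_window=63, rng=rng)
print(x.shape, y.shape)  # (4, 200) (4, 63)
```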
63 | ## Model core 64 | The model has two main parts: encoder and decoder. 65 | 66 | ![seq2seq](images/encoder-decoder.png "Encoder-decoder") 67 | 68 | The encoder is a [cuDNN GRU](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/cudnn_rnn/CudnnGRU). cuDNN works much faster (5x-10x) than native Tensorflow RNNCells, at the cost 69 | of being less convenient to use and poorly documented. 70 | 71 | The decoder is a TF `GRUBlockCell`, wrapped in a `tf.while_loop()` construct. Code 72 | inside the loop gets the prediction from the previous step and 73 | appends it to the input features for the current step. 74 | 75 | ## Working with long timeseries 76 | LSTM/GRU is a great solution for relatively short sequences, up to 100-300 items. 77 | On longer sequences LSTM/GRU still works, but it can gradually forget information from the oldest items. 78 | The competition timeseries are up to 700 days long, so I had to find some method 79 | to "strengthen" the GRU memory. 80 | 81 | My first method was to use some kind of *[attention](https://distill.pub/2016/augmented-rnns)*. 82 | Attention can bring useful information from the distant past to the current 83 | RNN cell. The simplest yet effective attention method for our problem is 84 | fixed-weight sliding-window attention. 85 | There are two especially important points in the distant past (taking into 86 | account long-term seasonality): 1) a year ago, 2) a quarter ago. 87 | 88 | ![from_past](images/from_past.png) 89 | 90 | I can just take the 91 | encoder outputs from the `current_day - 365` and `current_day - 90` timepoints, 92 | pass them through an FC layer to reduce dimensionality and append the result to the input 93 | features for the decoder. This solution, despite being simple, considerably 94 | lowered prediction error. 95 | 96 | Then I averaged the important points with their neighbors to reduce noise and 97 | compensate for uneven intervals (leap years, different month lengths): 98 | `attn_365 = 0.25 * day_364 + 0.5 * day_365 + 0.25 * day_366` 99 | 100 | Then I realized that `0.25,0.5,0.25` is a 1D convolutional kernel (length=3) 101 | and that I could automatically learn a bigger kernel to detect important points in the past. 102 | 
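A tiny sketch of this fixed-weight scheme, assuming a 1D array of (normalized) hits indexed by day; the names are illustrative, and the learned-kernel version used in the real model is considerably more involved:

```python
import numpy as np

def fixed_window_attention(hits, current_day, lag, kernel=(0.25, 0.5, 0.25)):
    """Weighted average of the point `lag` days back and its two neighbors.
    Equivalent to applying a fixed length-3 1D convolution kernel at that position."""
    center = current_day - lag
    window = hits[center - 1:center + 2]  # the lagged point and its neighbors
    return float(np.dot(kernel, window))

hits = np.arange(800, dtype=np.float64)  # toy series: hits[d] == d
attn_365 = fixed_window_attention(hits, current_day=700, lag=365)  # 0.25*day_364 + 0.5*day_365 + 0.25*day_366
attn_90 = fixed_window_attention(hits, current_day=700, lag=90)
print(attn_365, attn_90)  # 335.0 610.0
```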
103 | I ended up with a monstrous attention mechanism: it looks at the 'fingerprint' 104 | of each timeseries (the fingerprint is produced by a small ConvNet), decides 105 | which points to attend to and produces the weights for a big convolution kernel. 106 | This big kernel, applied to decoder outputs, 107 | produces attention features for each prediction day. This monster is still 108 | alive and can be found in the model code. 109 | 110 | Note that I didn't use a classical attention scheme (Bahdanau or Luong attention), 111 | because classical attention has to be recalculated from scratch on every prediction step, 112 | using all historical datapoints. This would take too much time for our long (~2 years) timeseries. 113 | My scheme, one convolution over all datapoints, uses the same attention 114 | weights for all prediction steps (that's a drawback), but is much faster to compute. 115 | 116 | Unsatisfied with the complexity of the attention mechanics, I tried to remove attention 117 | completely and just take the important (year ago, half-year ago, quarter ago) datapoints 118 | from the past and use them as additional 119 | features for the encoder and decoder. That worked surprisingly well, even slightly 120 | surpassing attention in prediction quality. My best public score was 121 | achieved using only lagged datapoints, without attention. 122 | 123 | ![lagged_data](images/lagged_data.png "Lagged datapoints") 124 | 125 | An additional important benefit of lagged datapoints: the model can use a much shorter 126 | encoder without fear of losing information from the past, because this 127 | information is now explicitly contained in the features. Even a 60-90 day long 128 | encoder still gives acceptable results, in contrast to the 300-400 days 129 | required for previous models. Shorter encoder = faster training and less 130 | loss of information. 131 | 132 | ## Losses and regularization 133 | [SMAPE](https://en.wikipedia.org/wiki/Symmetric_mean_absolute_percentage_error) 134 | (the target loss for the competition) can't be used directly, because of unstable 135 | behavior near zero values (the loss is a step function if the true value is zero, 136 | and undefined if the predicted value is also zero). 137 | 138 | I used a smoothed differentiable SMAPE variant, which is well-behaved at all real numbers: 139 | ```python 140 | epsilon = 0.1 141 | summ = tf.maximum(tf.abs(true) + tf.abs(predicted) + epsilon, 0.5 + epsilon) 142 | smape = tf.abs(predicted - true) / summ * 2.0 143 | ``` 144 | ![losses_0](images/losses_0.png "Losses for true value=0") 145 | ![losses_1](images/losses_1.png "Losses for true value=1") 146 | 147 | Another possible choice is MAE loss on `log1p(data)`; it's smooth almost everywhere 148 | and close enough to SMAPE for training purposes. 149 | 150 | Final predictions were rounded to the closest integer, and negative predictions were clipped at zero. 151 | 
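For reference, a minimal numpy sketch of that alternative loss and of the final post-processing step (the repo's actual loss is the TF expression above; names and data here are illustrative):

```python
import numpy as np

def mae_log1p(true_visits, predicted_log1p):
    """MAE between the prediction (already in log1p space) and log1p of the true visits."""
    return float(np.mean(np.abs(predicted_log1p - np.log1p(true_visits))))

def postprocess(predicted_log1p):
    """Undo the log1p transform, clip negative predictions at zero, round to the closest integer."""
    visits = np.expm1(predicted_log1p)
    return np.round(np.clip(visits, 0.0, None))

true = np.array([0.0, 3.0, 120.0])
pred = np.array([-0.2, 1.5, 4.9])   # model output in log1p space
print(mae_log1p(true, pred))        # smooth and well-defined even where true == 0
print(postprocess(pred))            # [  0.   3. 133.]
```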
152 | I tried to use the RNN activation regularizations from the paper 153 | ["Regularizing RNNs by Stabilizing Activations"](https://arxiv.org/abs/1511.08400), 154 | because the internal weights in cuDNN GRU can't be directly regularized 155 | (or I did not find the right way to do this). 156 | Stability loss didn't work at all; activation loss gave a very 157 | slight improvement for low (1e-06..1e-05) loss weights. 158 | 159 | ## Training and validation 160 | I used the COCOB optimizer (see the paper [Training Deep Networks without Learning Rates Through Coin Betting](https://arxiv.org/abs/1705.07795)) for training, in combination with gradient clipping. 161 | COCOB tries to predict the optimal learning rate for every training step, so 162 | I don't have to tune the learning rate at all. It also converges considerably 163 | faster than traditional momentum-based optimizers, especially in the first 164 | epochs, allowing me to stop unsuccessful experiments early. 165 | 166 | There are two ways to split a timeseries into training and validation datasets: 167 | 1. *Walk-forward split*. This is not actually a split: we train on the full dataset 168 | and validate on the full dataset, using different timeframes. The timeframe 169 | for validation is shifted forward by one prediction interval relative to 170 | the timeframe for training. 171 | 2. *Side-by-side split*. This is the traditional split for mainstream machine 172 | learning. The dataset is split into independent parts, one part used strictly 173 | for training and the other part used strictly for validation. 174 | 175 | ![split](images/split.png "Split variants") 176 | 177 | I tried both ways. 178 | 179 | Walk-forward is preferable, because it directly relates to the competition goal: 180 | predict future values using historical values. But this split consumes 181 | datapoints at the end of the timeseries, making it hard to train the model to 182 | precisely predict the future. 183 | 184 | Let's explain with an example: we have 300 days of historical data 185 | and want to predict the next 100 days. If we choose a walk-forward split, we'll have to use the 186 | first 100 days for real training, the next 100 days for training-mode prediction 187 | (run the decoder and calculate losses), the next 100 days for validation and the 188 | next 100 days for the actual prediction of future values. 189 | So we can actually use only 1/3 of the available datapoints for training 190 | and will have a 200-day gap between the last training datapoint 191 | and the first prediction datapoint. That's too much, because prediction quality 192 | falls exponentially as we move away from the training data (uncertainty grows). 193 | A model trained with a 100-day gap (instead of 200) would have considerably 194 | better quality. 195 | 196 | A side-by-side split is more economical, as it doesn't consume datapoints at the 197 | end. That was the good news. Now the bad news: for our data, model performance 198 | on the validation dataset is strongly correlated with performance on the training dataset, 199 | and almost uncorrelated with the actual model performance in the future. In other words, 200 | a side-by-side split is useless for our problem: it just duplicates the 201 | model loss observed on the training data. 202 | 203 | The takeaway? 204 | 205 | I used validation (with a walk-forward split) only for model tuning. 206 | The final model for predicting future values was trained in blind mode, without any validation. 207 | 208 | 209 | 210 | ## Reducing model variance 211 | The model inevitably has high variance due to very noisy input data. To be fair, 212 | I was surprised that the RNN learns anything at all on such noisy inputs. 213 | 214 | The same model trained on different seeds can have different performance; 215 | sometimes the model even diverges on "unfortunate" seeds. During training, performance also 216 | fluctuates wildly from step to step. I can't just rely on pure luck (landing on the right 217 | seed and stopping at the right training step) to win the competition, 218 | so I had to take action to reduce variance. 219 | 220 | ![training](images/training.png "Training") 221 | 1. I don't know which training step would be best for predicting the future 222 | (the validation result on current data is very weakly correlated with the 223 | result on future data), so I can't use early stopping. But I know the 224 | approximate region where the model is (possibly) trained well enough, 225 | but has (possibly) not yet started to overfit. I decided to set the bounds of this optimal region 226 | to 10500..11500 training steps and save 10 checkpoints, one every 100th step, 227 | in this region. 228 | 2. Similarly, I decided to train 3 models on different seeds and save checkpoints 229 | from each model. So I have 30 checkpoints in total. 230 | 3. One widely known method for reducing variance and improving model performance 231 | is SGD averaging (ASGD). The method is very simple and well supported 232 | in [Tensorflow](https://www.tensorflow.org/versions/r0.12/api_docs/python/train/moving_averages) - 233 | we maintain moving averages of the network weights during training and use these 234 | averaged weights, instead of the original ones, during inference (see the sketch after this list). 235 | 
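A minimal TF 1.x sketch of method 3, using a toy stand-in model rather than the repo's real trainer (which enables this averaging via the `--asgd_decay=0.99` flag shown in the Readme); variable names are illustrative:

```python
import tensorflow as tf

# Toy stand-in model: a single trainable weight fitted with Adam.
x = tf.placeholder(tf.float32, shape=[None])
y = tf.placeholder(tf.float32, shape=[None])
w = tf.get_variable('w', shape=[], initializer=tf.zeros_initializer())
loss = tf.reduce_mean(tf.squared_difference(w * x, y))
train_step = tf.train.AdamOptimizer(0.1).minimize(loss)

# Maintain exponential moving averages of the trainable weights (cf. --asgd_decay=0.99).
ema = tf.train.ExponentialMovingAverage(decay=0.99)
with tf.control_dependencies([train_step]):
    train_op = ema.apply(tf.trainable_variables())  # update the averages after every step

# At inference time, a Saver built this way restores the *averaged* weights
# in place of the raw ones.
eval_saver = tf.train.Saver(ema.variables_to_restore())

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(200):
        sess.run(train_op, feed_dict={x: [1., 2., 3.], y: [2., 4., 6.]})
    print(sess.run([w, ema.average(w)]))  # raw weight vs. its smoothed (averaged) version
```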
236 | The combination of all three methods (averaging the predictions from 30 checkpoints, 237 | using the averaged model weights in each checkpoint) worked well: I got 238 | roughly the same SMAPE error on the leaderboard (for future data) 239 | as for validation on historical data. 240 | 241 | Theoretically, one can also consider the first two methods as a kind of ensemble 242 | learning, and that's right, but I used them mainly for variance reduction. 243 | 244 | ## Hyperparameter tuning 245 | There are many model parameters (number of layers, layer depths, 246 | activation functions, dropout coefficients, etc) that can be (and should be) tuned to 247 | achieve optimal model performance. Manual tuning is a tedious and 248 | time-consuming process, so I decided to automate it and use the [SMAC3](https://automl.github.io/SMAC3/stable/) package for hyperparameter search. 249 | Some benefits of SMAC3: 250 | * Support for conditional parameters (e.g. jointly tune the number of layers 251 | and the dropout for each layer; dropout on the second layer will be tuned only if 252 | n_layers > 1) 253 | * Explicit handling of model variance. SMAC trains several instances 254 | of each model on different seeds, and compares models only if the instances were 255 | trained on the same seed. One model wins if it's better than the other model on all equal seeds. 256 | 257 | Contrary to my expectations, the hyperparameter search did not find a well-defined global minimum. 258 | All the best models had roughly the same performance, but different parameters. 259 | Probably the RNN model is too expressive for this task, and the best model score 260 | depends more on the data signal-to-noise ratio than on the model architecture. 
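The winning parameter sets are stored as named dictionaries in `hparams.py` (shown later in this repository) and can be loaded through its `build_from_set` helper; a minimal usage sketch:

```python
from hparams import build_from_set

hparams = build_from_set('s32')  # one of: s32, definc, foundinc, inst81, inst83
print(hparams.train_window, hparams.rnn_depth, hparams.use_attn)  # 283 267 False
```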
261 | 262 | Anyway, best parameters sets can be found in `hparams.py` file 263 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | import tensorflow.contrib.training as training 2 | import re 3 | 4 | # Manually selected params 5 | params_s32 = dict( 6 | batch_size=256, 7 | #train_window=380, 8 | train_window=283, 9 | train_skip_first=0, 10 | rnn_depth=267, 11 | use_attn=False, 12 | attention_depth=64, 13 | attention_heads=1, 14 | encoder_readout_dropout=0.4768781146510798, 15 | 16 | encoder_rnn_layers=1, 17 | decoder_rnn_layers=1, 18 | 19 | # decoder_state_dropout_type=['outside','outside'], 20 | decoder_input_dropout=[1.0, 1.0, 1.0], 21 | decoder_output_dropout=[0.975, 1.0, 1.0], # min 0.95 22 | decoder_state_dropout=[0.99, 0.995, 0.995], # min 0.95 23 | decoder_variational_dropout=[False, False, False], 24 | # decoder_candidate_l2=[0.0, 0.0], 25 | # decoder_gates_l2=[0.0, 0.0], 26 | #decoder_state_dropout_type='outside', 27 | #decoder_input_dropout=1.0, 28 | #decoder_output_dropout=1.0, 29 | #decoder_state_dropout=0.995, #0.98, # min 0.95 30 | # decoder_variational_dropout=False, 31 | decoder_candidate_l2=0.0, 32 | decoder_gates_l2=0.0, 33 | 34 | fingerprint_fc_dropout=0.8232342370695286, 35 | gate_dropout=0.9967589439360334,#0.9786, 36 | gate_activation='none', 37 | encoder_dropout=0.030490422531402273, 38 | encoder_stability_loss=0.0, # max 100 39 | encoder_activation_loss=1e-06, # max 0.001 40 | decoder_stability_loss=0.0, # max 100 41 | decoder_activation_loss=5e-06, # max 0.001 42 | ) 43 | 44 | # Default incumbent on last smac3 search 45 | params_definc = dict( 46 | batch_size=256, 47 | train_window=100, 48 | train_skip_first=0, 49 | rnn_depth=128, 50 | use_attn=True, 51 | attention_depth=64, 52 | attention_heads=1, 53 | encoder_readout_dropout=0.4768781146510798, 54 | 55 | encoder_rnn_layers=1, 56 | decoder_rnn_layers=1, 57 | 58 | decoder_input_dropout=[1.0, 1.0, 1.0], 59 | decoder_output_dropout=[1.0, 1.0, 1.0], 60 | decoder_state_dropout=[0.995, 0.995, 0.995], 61 | decoder_variational_dropout=[False, False, False], 62 | decoder_candidate_l2=0.0, 63 | decoder_gates_l2=0.0, 64 | fingerprint_fc_dropout=0.8232342370695286, 65 | gate_dropout=0.8961710392091516, 66 | gate_activation='none', 67 | encoder_dropout=0.030490422531402273, 68 | encoder_stability_loss=0.0, 69 | encoder_activation_loss=1e-05, 70 | decoder_stability_loss=0.0, 71 | decoder_activation_loss=5e-05, 72 | ) 73 | 74 | # Found incumbent 0.35503610596060753 75 | #"decoder_activation_loss='1e-05'", "decoder_output_dropout:0='1.0'", "decoder_rnn_layers='1'", "decoder_state_dropout:0='0.995'", "encoder_activation_loss='1e-05'", "encoder_rnn_layers='1'", "gate_dropout='0.7934826952854418'", "rnn_depth='243'", "train_window='135'", "use_attn='1'", "attention_depth='17'", "attention_heads='2'", "encoder_readout_dropout='0.7711751356092252'", "fingerprint_fc_dropout='0.9693950737901414'" 76 | params_foundinc = dict( 77 | batch_size=256, 78 | train_window=135, 79 | train_skip_first=0, 80 | rnn_depth=243, 81 | use_attn=True, 82 | attention_depth=17, 83 | attention_heads=2, 84 | encoder_readout_dropout=0.7711751356092252, 85 | 86 | encoder_rnn_layers=1, 87 | decoder_rnn_layers=1, 88 | 89 | decoder_input_dropout=[1.0, 1.0, 1.0], 90 | decoder_output_dropout=[1.0, 1.0, 1.0], 91 | decoder_state_dropout=[0.995, 0.995, 0.995], 92 | decoder_variational_dropout=[False, False, False], 93 | 
decoder_candidate_l2=0.0, 94 | decoder_gates_l2=0.0, 95 | fingerprint_fc_dropout=0.9693950737901414, 96 | gate_dropout=0.7934826952854418, 97 | gate_activation='none', 98 | encoder_dropout=0.0, 99 | encoder_stability_loss=0.0, 100 | encoder_activation_loss=1e-05, 101 | decoder_stability_loss=0.0, 102 | decoder_activation_loss=1e-05, 103 | ) 104 | 105 | # 81 on smac_run0 (0.3552077534247418 x 7) 106 | #{'decoder_activation_loss': 0.0, 'decoder_output_dropout:0': 0.85, 'decoder_rnn_layers': 2, 'decoder_state_dropout:0': 0.995, 107 | # 'encoder_activation_loss': 0.0, 'encoder_rnn_layers': 2, 'gate_dropout': 0.7665920904244501, 'rnn_depth': 201, 108 | # 'train_window': 143, 'use_attn': 1, 'attention_depth': 17, 'attention_heads': 2, 'decoder_output_dropout:1': 0.975, 109 | # 'decoder_state_dropout:1': 0.99, 'encoder_dropout': 0.0304904225, 'encoder_readout_dropout': 0.4444295965935664, 'fingerprint_fc_dropout': 0.26412480387331017} 110 | params_inst81 = dict( 111 | batch_size=256, 112 | train_window=143, 113 | train_skip_first=0, 114 | rnn_depth=201, 115 | use_attn=True, 116 | attention_depth=17, 117 | attention_heads=2, 118 | encoder_readout_dropout=0.4444295965935664, 119 | 120 | encoder_rnn_layers=2, 121 | decoder_rnn_layers=2, 122 | 123 | decoder_input_dropout=[1.0, 1.0, 1.0], 124 | decoder_output_dropout=[0.85, 0.975, 1.0], 125 | decoder_state_dropout=[0.995, 0.99, 0.995], 126 | decoder_variational_dropout=[False, False, False], 127 | decoder_candidate_l2=0.0, 128 | decoder_gates_l2=0.0, 129 | fingerprint_fc_dropout=0.26412480387331017, 130 | gate_dropout=0.7665920904244501, 131 | gate_activation='none', 132 | encoder_dropout=0.0304904225, 133 | encoder_stability_loss=0.0, 134 | encoder_activation_loss=0.0, 135 | decoder_stability_loss=0.0, 136 | decoder_activation_loss=0.0, 137 | ) 138 | # 121 on smac_run0 (0.3548671560628074 x 3) 139 | # {'decoder_activation_loss': 1e-05, 'decoder_output_dropout:0': 0.975, 'decoder_rnn_layers': 2, 'decoder_state_dropout:0': 1.0, 140 | # 'encoder_activation_loss': 1e-05, 'encoder_rnn_layers': 1, 'gate_dropout': 0.8631496699358483, 'rnn_depth': 122, 141 | # 'train_window': 269, 'use_attn': 1, 'attention_depth': 29, 'attention_heads': 4, 'decoder_output_dropout:1': 0.975, 142 | # 'decoder_state_dropout:1': 0.975, 'encoder_readout_dropout': 0.9835390239895767, 'fingerprint_fc_dropout': 0.7452161827064421} 143 | 144 | # 83 on smac_run1 (0.355050330259362 x 7) 145 | # {'decoder_activation_loss': 1e-06, 'decoder_output_dropout:0': 0.925, 'decoder_rnn_layers': 2, 'decoder_state_dropout:0': 0.98, 146 | # 'encoder_activation_loss': 1e-06, 'encoder_rnn_layers': 1, 'gate_dropout': 0.9275441207192259, 'rnn_depth': 138, 147 | # 'train_window': 84, 'use_attn': 1, 'attention_depth': 52, 'attention_heads': 2, 'decoder_output_dropout:1': 0.925, 148 | # 'decoder_state_dropout:1': 0.98, 'encoder_readout_dropout': 0.6415488109353416, 'fingerprint_fc_dropout': 0.2581296623398802} 149 | 150 | 151 | params_inst83 = dict( 152 | batch_size=256, 153 | train_window=84, 154 | train_skip_first=0, 155 | rnn_depth=138, 156 | use_attn=True, 157 | attention_depth=52, 158 | attention_heads=2, 159 | encoder_readout_dropout=0.6415488109353416, 160 | 161 | encoder_rnn_layers=1, 162 | decoder_rnn_layers=2, 163 | 164 | decoder_input_dropout=[1.0, 1.0, 1.0], 165 | decoder_output_dropout=[0.925, 0.925, 1.0], 166 | decoder_state_dropout=[0.98, 0.98, 0.995], 167 | decoder_variational_dropout=[False, False, False], 168 | decoder_candidate_l2=0.0, 169 | decoder_gates_l2=0.0, 170 | 
fingerprint_fc_dropout=0.2581296623398802, 171 | gate_dropout=0.9275441207192259, 172 | gate_activation='none', 173 | encoder_dropout=0.0, 174 | encoder_stability_loss=0.0, 175 | encoder_activation_loss=1e-06, 176 | decoder_stability_loss=0.0, 177 | decoder_activation_loss=1e-06, 178 | ) 179 | 180 | def_params = params_s32 181 | 182 | sets = { 183 | 's32':params_s32, 184 | 'definc':params_definc, 185 | 'foundinc':params_foundinc, 186 | 'inst81':params_inst81, 187 | 'inst83':params_inst83, 188 | } 189 | 190 | 191 | def build_hparams(params=def_params): 192 | return training.HParams(**params) 193 | 194 | 195 | def build_from_set(set_name): 196 | return build_hparams(sets[set_name]) 197 | 198 | 199 | -------------------------------------------------------------------------------- /images/attention.xml: -------------------------------------------------------------------------------- 1 | 7V3Rcps4FP0av3ZAQoAf66TpPmxnutOd3e2jArJNi5EH4ybp168wkm0EGDkGgRLSmY4RQiDOkXTu5WDP4N3m+XOKt+svNCTxDFjh8wzezwDw5y77Py94KQpc2yoKVmkUFkX2qeBb9JvwQlFtH4VkV6qYURpn0bZcGNAkIUFWKsNpSp/K1ZY0Lp91i1ekUvAtwHG19N8ozNa8W8g6lf9BotVanNm2+J5HHPxcpXSf8PPNAFwe/ordGyza4vV3axzSp7Mi+GkG71JKs+LT5vmOxPmtFbetOO6hYe/xulOSZCoH+Kg44heO90Rc8uHCshdxMw7dIfkB1gwuntZRRr5tcZDvfWLos7J1tonZls0+rmK82/GqAd1EAf8c40cSL453547GNGW7Epqwdha7LKU/iShkN809/B33CBDyMyyjOJYOX6U4jFiPpeIlTbIHvIninIX/kDTECebFnHI24Ntnp7YOf6wcx9EqYWVpAfQixLv14TbkV8FvG0kz8tx47+0jomygELohWfrCqvADHEECPkiAx7efTpQDDi9bn9FN1MOc5atj0yek2QcOdgPwsB14Rs5t/nEXJauYfMzHVTsBymzRS4d8tKH8nxr4rgL4AQOTpHlB3n1xYuuDBV3oOp7jWnPoIbGft2x9sC2AAHJsB7qWP3c6YgyyJMZAVGGM7dcwxu6AMZ5XIQgJ2azJN2maremKJjj+dCpdBPv013HMnKhhS3T4sd9sRUOcA2eUIs9R9t/htiK+9Z0fx+5b+nK2K9/8zo/aZTgVUBd1Q07h+yBnZRT8vY6SYsdDFB9Ploj1KKfHD5JlL3wb7zPKik4d/ZPS7fFkVzD2NmYWKOS3/jKLGFJ0nwa8lufwZRSnK5KJVbmebSmJcRb9Krd/E3X8iTpXUOfO/7R4GBF1rCp1kKeJOmI2m6ijQh0fLOCoZh1YM+tYN1KHH/qVRuzMp7XRLa+NUF7ziuviR51L4paGgCs1VHSm0hCDGb+cVdvmFXYVrh97rDZzuhP91el/7y5cNCb6g5qZc65r0UUTdQzWa3YNdfwuZs5e5inHUQ8n4yj5WSaMydDXZBCOq/D5kCkPJlVSVME+C/VQTaQnym5cTaG0CNrwlaup3JCIWFsW09dQ0JsoeJmCYycd8jsindyQo6bgXkE6pJA/ndJor06jWdABLvI8CzkIekw7oVkniTTbKxPEq+bRXNhTHs1WWCrZEdF2R9ppgnfb4vnHMnrO2TIob25Kux9nrAplus67QyiBX5N2r8EedIC9gPoS9jkk28Zu8ide+FFUt67tvu23dt+ue+rQRf8VHjqMjfqBTx6XKlNmiIm/DHrgvxLlizvZCPpcH8dd8zDuZVnUh7EGUG3bOFRD97ExGSShulyCYKwjF+obubbC6jQykH3w2JjwlqZnRPzQGSnIrkYJYhkHsunzszgCaERZQWn1LTQdazihCQwMskxRmnAsShMoJF/GBrIpU1kDyDpQNS+AMEZqXh66GqUmUEjWjwxkY6TmZZB1Sk3fOJBNn58HkJpQIaDoW2q6YDipCc1LjZgiNWGDVUy/1ITmpUZMmcqaQNaBqnnPI0yRmi1DV6PUhOalAkyRmi0ga5SacEoFDASyTqkpJNxVltYyeLe6WAvTV7OLtWRWMNkNfYlIJH48dCJnThTgeCYZhwuUrrW8imeM55ZX2BDndGV5rfePASmcEBxv8pu5sluyXJ99KK7g1dZGFfVZ4VaZzhIx9XHlwmC50o+uTisVCunxNDpOmRpzaRpUtTQC+a0U2Rt59Vspatdr21dS3+6Y+uZJ9DfmlpOtkhqX+zcr6gzBHqHhsDfvKdAbw34+HPaveV383cr82975HUDmgxqN5nQh86uiSXJ6H7+m51r15XotDXWkvgC6MvCY9xx4zBVm4ZsCjx7Zqz/wGCzI8LoJMmQ6ye10FWPIo6ktxpBZ3nGMIVA4Y/l3gtO8Ryta4TtbUrOZ9O7gOanF+l5d8itrtphYP/IdmygM4yY906xZSiNMXcB08l7zTRoDWKgErFPVGPOeNAaqPrX9a8+WcTLB3jvsjt0Kuxji3eNeTaV9TUkYBVlEkwn0/kB34YCgw6pqmQKKJkl22/dP6Qgoug8U/G4ElAe0CCj5ctvCBPmyug4TkMoj9pvChB45+X7CBPnrI0fOcvly28KECsu7DhMUUpJ9Gx2RXe5jnZGmN6cjMs+0bYrTseDWGJyOaDJt6wZZB6oqibxxoWqK07Fl6Gp0OopH8QaBbIrTsQVkjU5H18DXDQyfnwdwOrojeH+7ojV9jVrTNc+1bYrWdEfzArf7Zg0+g89lTSDrQNU8644pWrNl6OrUmublAkzRmi0g69SaUy5gIJC1ak2F3IBurXn89QYdWlPYUUyiuSlas+HXDvRrTW+Km3WDrANVYByqxmjNy0NXo9b0zHsHyBiteRlkjVrTMzDhY/j8PIDW9KurcK9WW/HbrR257zo13NUA3anL0mmX1G5fMFeX5b7dte8ZaYVEfV8+ar86bU9+2gPoPftpkeR2AlZNxNzbNF5N8T4QnO1TsnvXmF8a6GuaRr/Zvtz2et+VqVp6X6zuLc2OfkGcbZ5+lr5wz7H7vP5CQ5LX+B8= 
-------------------------------------------------------------------------------- /images/autocorr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arturus/kaggle-web-traffic/a9abb80c800409abf0ece21ea244ef779f758f96/images/autocorr.png -------------------------------------------------------------------------------- /images/encoder-decoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arturus/kaggle-web-traffic/a9abb80c800409abf0ece21ea244ef779f758f96/images/encoder-decoder.png -------------------------------------------------------------------------------- /images/encoder-decoder.xml: -------------------------------------------------------------------------------- 1 | 7VxLc6M4EP41PiYFCDA+xklm5jBbldrU7k6OMig2Mxh5hZzY8+tXgMRDAvzg4XhDLoFGL9Rft7o/yUzA/Xr3lcDN6g/soWBiaN5uAh4mhqFrlsn+xZJ9KpnqTipYEt/jhXLBs/8biZpcuvU9FJUKUowD6m/KQheHIXJpSQYJwe/lYq84KPe6gUukCJ5dGKjSf3yPrlKpY2m5/BvylyvRs67xJwvo/loSvA15fxMDvCZ/6eM1FG3x8tEKevi9IAKPE3BPMKbp1Xp3j4J4bsW0pfW+1DzNxk1QSI+pAIy0xhsMtkgMORkY3YvJQB6bG34boeWaNf2Yi+Yo9O7iKWdP3QBGke8y4YquAybQ2WVECf6F7nGASdIecDT7AQD25BWH9Atc+0GMkr8R8WAIuZhDQjf4faG6lvwxeTrQeHS1L69nU8qgivAaUbJnRd5znQqVrgrqFDKCAkj9t3LzkENrmTWX9fCEfdaxoXEzAKIdbgSGo5WbiPCWuIjXKipJash0pIZMqSEKyRJRpSGmFrgvFNvEBaKGAdtSP7Z22rjK5dlFOgJxV9BBLkoAWQNOBZtf//xLgWdibcjjaHtf+RQ9b6AbP31n3qkMxmWMUHYd48fF6xiryXUAFyiYZ8Yr4BbiEFVC2FgA286eCB8R9/DqB0GhpGchxzPPBrtjzNOOYOAvw9jEGLYRydD/hghFu8mJ+OcVZPUJJ1YwD7vCPGT0FS2hoPtG1ZqjantVrS6b8oC6tUbd9qrbTEkX0K19cd2y2X2wHp0H8/+rYGBcTsHTinjQDiifgZKa7X+3WDy4iZK5umMFdGOzyx+yq2X8/4doJtouhEwXMjaogrggTfsUYjku7TPwtBsCT6ZGsv8Rw/TWErcvvNMVJv5vVg+KUSQwfsKRT31cAkoMEZ+lHN+lAgtMKV5XQUvUuOMPKN6cEAazFCoJEwux1aUiY13vKDLW5RB7elxkfEYw6lQYhgRIlsxtkiTJD5cB4tg85P5yX6l17Aw9e2FbRzlDlp8artveMARiSYqYeZKKi361W+Dojj6zbd20HMD+LFGEt63dmtZMm2kmS50cMHNMs6NgSF4wdaD4U6MC7brV3p8KpI+4uT7cADlBGhI3eu1C7Plvleswe1F6w6cyXoj5uqEsxdlKvIGhLDtpkU+qSY0/hi724l5ZG5Ci2rU8X/grRsHEyUvWLPzCXtj7xtOBt4vkSdkqhssq+rGXJuZLiQw8nyCXxw9s7Yz7L0ciOV/WyiJMrWwR+lSNTEGVRWgdhKa6ShqMrvRKXKkhpzRDulKVkBhxcyW4AdrlcGPMFJgUt0IwoSu8xCEMirsh5QWoABa082khZ2R3LxwgTenkT0Tpns8x3FLMRHm/33GcAPa7yfJzu96IN+YL4OTohDNN47iIU0fFJDRNqlTlH51eHu0CVN6qn2jqXHokYdHaxESXCHxm9hTAo7wU0j0LTc9G453taHOnKvDpIqaROYQB2Ta9im4bl6arWJpMcMGQRmWjssSrBj7ulgT7OWHKjr1vr/ip8xeZRpuhw6t3hJrGHElmVDPu9WTWtsjFdrkpm23klTKsCliBDtxRxY57f+y/MRnZ/+HY/4tR/bOuqH65oSMPwZxB9VecTujPDMKbcRvsMxhCtjXV1hCUhno0hCrmpDdDGM3gE5hBZ4cilYZ6NAOVPXg55TjDkEht4nLqKagzw92DmFz7nhe05YwuBVUpEs9Ox7Y9vqv3iNSRT7haPsGSD8UMSXVXnW4Z+dGRH41xKfNcA/KjoIqQGP3ZdfgzOWEf0J+B+gx+9Gef3J/Z8jo7pD8bjyJcrT+zZd5lSH827l+P/qwOl/I6O6Q/U/PNKmakdqtnZEbOZEbqTs4Mw4yYSmh3JjVi2YNRI0BNcaug2rAdM4L1w7PLdt1PI08FptxQn8CcHQfMEZZXC8tp3Q86T4Wl3FCPsDR1FVp9nH79iCdcm7b02qzX0k8wzWpMtgSbsqY6Tjc+MHuJjj87og64+bMjcgJ/avlDnzVR3rvjz5qYas4/WlYrL/5RogbFz/f0oR4lOjlkAbN25Q9ZjPLeXVuMug3wgMbjrkcfdw1ju28KoM4/ljJpT2AoB/Arf1TY05FXU90peCLI85OpixR40fQ3p7XxtFCmqt/jo9kqrHYB0CMQZjYgrAdNZ9+BOKRpowtNqxzqNz+imMRKYHIPUjjqu1N9yydoKpjJrtTNbvOvJqZLTP5pSvD4Hw== -------------------------------------------------------------------------------- /images/from_past.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arturus/kaggle-web-traffic/a9abb80c800409abf0ece21ea244ef779f758f96/images/from_past.png -------------------------------------------------------------------------------- /images/lagged_data.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Arturus/kaggle-web-traffic/a9abb80c800409abf0ece21ea244ef779f758f96/images/lagged_data.png -------------------------------------------------------------------------------- /images/losses_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arturus/kaggle-web-traffic/a9abb80c800409abf0ece21ea244ef779f758f96/images/losses_0.png -------------------------------------------------------------------------------- /images/losses_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arturus/kaggle-web-traffic/a9abb80c800409abf0ece21ea244ef779f758f96/images/losses_1.png -------------------------------------------------------------------------------- /images/predictions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arturus/kaggle-web-traffic/a9abb80c800409abf0ece21ea244ef779f758f96/images/predictions.png -------------------------------------------------------------------------------- /images/split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arturus/kaggle-web-traffic/a9abb80c800409abf0ece21ea244ef779f758f96/images/split.png -------------------------------------------------------------------------------- /images/training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Arturus/kaggle-web-traffic/a9abb80c800409abf0ece21ea244ef779f758f96/images/training.png -------------------------------------------------------------------------------- /images/validation-split.xml: -------------------------------------------------------------------------------- 1 | 7Vpdb6s2GP41kbaLHBkbCLls0ma7mVSp1TnbpQMusUowM/Qk2a+fje0EbHJCO0i3tanU4tcf2M/zfvltJmi53f/CcbH5jSUkm0CQ7CfodgJhNA/Fbyk4KEHoASVIOU2UyDsJHuhfRAvNsBeakLI1sGIsq2jRFsYsz0lctWSYc7ZrD3tiWfutBU6JI3iIceZKv9Gk2uhjBeAk/5XQdGPe7AHds8bxc8rZS67fN4Hoqf6o7i02a+nx5QYnbNcQobsJWnLGKvW03S9JJqE1sKl5qzO9x31zkld9JvhQzfiOsxditlxvrDoYMOrjEDkBTNBit6EVeShwLHt3gn0h21TbTLQ88ZhmuCz10ASXm3qe7IjZlsa6I8Nrki2OUC1ZxrjoylkuFl2UFWfPxAgFgmH9ET1PNMsa8qdA/kg5y6sV3tJMqttXwhOcYy3WuuVB3W5Mj0B4i5CQawgIr8j+LI7ekR2h9IRtScUPYoiZEGpCtcJ7huDdSX2gr2WbhupAozlYq2x6XPtEm3jQzHWzODSJFs6r1Qoul9cDyos6gPKGAGo+LlC3wV106w8D1NFFHKz2VYDyQgepbzh7nj4xvsM8kZAVGa0c9MR5qzZEbWvWJt7ETYtwRtNcNGOBDxHyhUSPCpd8ozu2NEnkazo5abPWwx9Epq33DkbQ7XlvHzAAYTOHsAcRQ6frw7QUfz8J6yLM99+RsMgh7JFuiUOQyBAK+VjSPM3IjUxuXu2X7kL50z9Q1imUSXzAFz9EIYoQCMLQByEMzAg9BXwBgR/NQuAFPhCj0DwaJ1Yg5PITBB38+EPE1MDh556ThMb/YRuCA4WmuRXDQ5eXsczGjUt3ecwS127+1aQMYRzRZRK60oMhSBBZ88U8iuSJdlYG19Z1oOms+iku2dPq99rdhOFMt//QS8jne8KF/6wZq6EXOPKDmRAagZwxFR4LIiOx56mTkMS5IVrMiNOyFx6TVmZZYZ6SqiFy+Ws6rw5+jIyTDFf0e3sTXaTpN9wzKrZ3Cm3ASh+RxbvavJ7VvCFaCyE7YZ9ZC6kjOwvVOnQ8dj+1co37CmrVUJMZbKoJOKchDU2cwR6a+DaNgq5GwffUKBh5bUWYv1GjjlUeo5r2QgNqVNeFL8xkYCgLnLdUK/zzRVZdFqcSzjRWgeFGgpeufxLJxkS6T/FaUFdOzDPwfz7NF0+p/PvIMc1F0ibaeCsjQ74uC9WfqXErtQc1fIgA1qyJmCjGlSL8kyDWq6bTw9ykEy6Mbcg3SnTQLbpYEVK9jCeE27UitcojK/Q2yuOyAyU7PcLsbKQw63uX/aHEqzh7TF0fxWszHLz2+NC6IsGO83tdXmcIADqyvY+Tg/sX8j/vekk3cosL/9us+xzsV0ir/R5p9dj2bpdEjrZ9qSYyTNmxh8e7XiVb8fFDVeiEws6Q34QEHBeJ15WqzyFxaB/4EjD2HeRNMaEHMGNeEmZB845gUv5mn74N9Mz8jdE3U39jBs3c33vX5H+KAisPsB1e3+x/iqyb6YgXSviR4tYZE9239cBYbOBY7FhRDbqF7g+UxP2YFW9+kZaxcrqO/4V+2oZC5nq2gdx859M2zthGBy1j2YbvuqyvAsVERFCWj8DMmDUa2bjHlWA+ryUQ9Pw2TjNJ7Fncsb5gA+rPMGUYzy4ajvc/J9E8fQFLZR2nL7mhu78B 
-------------------------------------------------------------------------------- /input_pipe.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from feeder import VarFeeder 4 | from enum import Enum 5 | from typing import List, Iterable 6 | import numpy as np 7 | import pandas as pd 8 | 9 | 10 | class ModelMode(Enum): 11 | TRAIN = 0 12 | EVAL = 1, 13 | PREDICT = 2 14 | 15 | 16 | class Split: 17 | def __init__(self, test_set: List[tf.Tensor], train_set: List[tf.Tensor], test_size: int, train_size: int): 18 | self.test_set = test_set 19 | self.train_set = train_set 20 | self.test_size = test_size 21 | self.train_size = train_size 22 | 23 | 24 | class Splitter: 25 | def cluster_pages(self, cluster_idx: tf.Tensor): 26 | """ 27 | Shuffles pages so all user_agents of each unique pages stays together in a shuffled list 28 | :param cluster_idx: Tensor[uniq_pages, n_agents], each value is index of pair (uniq_page, agent) in other page tensors 29 | :return: list of page indexes for use in a global page tensors 30 | """ 31 | size = cluster_idx.shape[0].value 32 | random_idx = tf.random_shuffle(tf.range(0, size, dtype=tf.int32), self.seed) 33 | shuffled_pages = tf.gather(cluster_idx, random_idx) 34 | # Drop non-existent (uniq_page, agent) pairs. Non-existent pair has index value = -1 35 | mask = shuffled_pages >= 0 36 | page_idx = tf.boolean_mask(shuffled_pages, mask) 37 | return page_idx 38 | 39 | def __init__(self, tensors: List[tf.Tensor], cluster_indexes: tf.Tensor, n_splits, seed, train_sampling=1.0, 40 | test_sampling=1.0): 41 | size = tensors[0].shape[0].value 42 | self.seed = seed 43 | clustered_index = self.cluster_pages(cluster_indexes) 44 | index_len = tf.shape(clustered_index)[0] 45 | assert_op = tf.assert_equal(index_len, size, message='n_pages is not equals to size of clustered index') 46 | with tf.control_dependencies([assert_op]): 47 | split_nitems = int(round(size / n_splits)) 48 | split_size = [split_nitems] * n_splits 49 | split_size[-1] = size - (n_splits - 1) * split_nitems 50 | splits = tf.split(clustered_index, split_size) 51 | complements = [tf.random_shuffle(tf.concat(splits[:i] + splits[i + 1:], axis=0), seed) for i in 52 | range(n_splits)] 53 | splits = [tf.random_shuffle(split, seed) for split in splits] 54 | 55 | def mk_name(prefix, tensor): 56 | return prefix + '_' + tensor.name[:-2] 57 | 58 | def prepare_split(i): 59 | test_size = split_size[i] 60 | train_size = size - test_size 61 | test_sampled_size = int(round(test_size * test_sampling)) 62 | train_sampled_size = int(round(train_size * train_sampling)) 63 | test_idx = splits[i][:test_sampled_size] 64 | train_idx = complements[i][:train_sampled_size] 65 | test_set = [tf.gather(tensor, test_idx, name=mk_name('test', tensor)) for tensor in tensors] 66 | tran_set = [tf.gather(tensor, train_idx, name=mk_name('train', tensor)) for tensor in tensors] 67 | return Split(test_set, tran_set, test_sampled_size, train_sampled_size) 68 | 69 | self.splits = [prepare_split(i) for i in range(n_splits)] 70 | 71 | 72 | class FakeSplitter: 73 | def __init__(self, tensors: List[tf.Tensor], n_splits, seed, test_sampling=1.0): 74 | total_pages = tensors[0].shape[0].value 75 | n_pages = int(round(total_pages * test_sampling)) 76 | 77 | def mk_name(prefix, tensor): 78 | return prefix + '_' + tensor.name[:-2] 79 | 80 | def prepare_split(i): 81 | idx = tf.random_shuffle(tf.range(0, n_pages, dtype=tf.int32), seed + i) 82 | train_tensors = [tf.gather(tensor, idx, 
name=mk_name('shfl', tensor)) for tensor in tensors] 83 | if test_sampling < 1.0: 84 | sampled_idx = idx[:n_pages] 85 | test_tensors = [tf.gather(tensor, sampled_idx, name=mk_name('shfl_test', tensor)) for tensor in tensors] 86 | else: 87 | test_tensors = train_tensors 88 | return Split(test_tensors, train_tensors, n_pages, total_pages) 89 | 90 | self.splits = [prepare_split(i) for i in range(n_splits)] 91 | 92 | 93 | class InputPipe: 94 | def cut(self, hits, start, end): 95 | """ 96 | Cuts [start:end] diapason from input data 97 | :param hits: hits timeseries 98 | :param start: start index 99 | :param end: end index 100 | :return: tuple (train_hits, test_hits, dow, lagged_hits) 101 | """ 102 | # Pad hits to ensure we have enough array length for prediction 103 | hits = tf.concat([hits, tf.fill([self.predict_window], np.NaN)], axis=0) 104 | cropped_hit = hits[start:end] 105 | 106 | # cut day of week 107 | cropped_dow = self.inp.dow[start:end] 108 | 109 | # Cut lagged hits 110 | # gather() accepts only int32 indexes 111 | cropped_lags = tf.cast(self.inp.lagged_ix[start:end], tf.int32) 112 | # Mask for -1 (no data) lag indexes 113 | lag_mask = cropped_lags < 0 114 | # Convert -1 to 0 for gather(), it don't accept anything exotic 115 | cropped_lags = tf.maximum(cropped_lags, 0) 116 | # Translate lag indexes to hit values 117 | lagged_hit = tf.gather(hits, cropped_lags) 118 | # Convert masked (see above) or NaN lagged hits to zeros 119 | lag_zeros = tf.zeros_like(lagged_hit) 120 | lagged_hit = tf.where(lag_mask | tf.is_nan(lagged_hit), lag_zeros, lagged_hit) 121 | 122 | # Split for train and test 123 | x_hits, y_hits = tf.split(cropped_hit, [self.train_window, self.predict_window], axis=0) 124 | 125 | # Convert NaN to zero in for train data 126 | x_hits = tf.where(tf.is_nan(x_hits), tf.zeros_like(x_hits), x_hits) 127 | return x_hits, y_hits, cropped_dow, lagged_hit 128 | 129 | def cut_train(self, hits, *args): 130 | """ 131 | Cuts a segment of time series for training. Randomly chooses starting point. 132 | :param hits: hits timeseries 133 | :param args: pass-through data, will be appended to result 134 | :return: result of cut() + args 135 | """ 136 | n_days = self.predict_window + self.train_window 137 | # How much free space we have to choose starting day 138 | free_space = self.inp.data_days - n_days - self.back_offset - self.start_offset 139 | if self.verbose: 140 | lower_train_start = self.inp.data_start + pd.Timedelta(self.start_offset, 'D') 141 | lower_test_end = lower_train_start + pd.Timedelta(n_days, 'D') 142 | lower_test_start = lower_test_end - pd.Timedelta(self.predict_window, 'D') 143 | upper_train_start = self.inp.data_start + pd.Timedelta(free_space - 1, 'D') 144 | upper_test_end = upper_train_start + pd.Timedelta(n_days, 'D') 145 | upper_test_start = upper_test_end - pd.Timedelta(self.predict_window, 'D') 146 | print(f"Free space for training: {free_space} days.") 147 | print(f" Lower train {lower_train_start}, prediction {lower_test_start}..{lower_test_end}") 148 | print(f" Upper train {upper_train_start}, prediction {upper_test_start}..{upper_test_end}") 149 | # Random starting point 150 | offset = tf.random_uniform((), self.start_offset, free_space, dtype=tf.int32, seed=self.rand_seed) 151 | end = offset + n_days 152 | # Cut all the things 153 | return self.cut(hits, offset, end) + args 154 | 155 | def cut_eval(self, hits, *args): 156 | """ 157 | Cuts segment of time series for evaluation. 
158 | Always cuts train_window + predict_window length segment beginning at start_offset point 159 | :param hits: hits timeseries 160 | :param args: pass-through data, will be appended to result 161 | :return: result of cut() + args 162 | """ 163 | end = self.start_offset + self.train_window + self.predict_window 164 | return self.cut(hits, self.start_offset, end) + args 165 | 166 | def reject_filter(self, x_hits, y_hits, *args): 167 | """ 168 | Rejects timeseries having too many zero datapoints (more than self.max_train_empty) 169 | """ 170 | if self.verbose: 171 | print("max empty %d train %d predict" % (self.max_train_empty, self.max_predict_empty)) 172 | zeros_x = tf.reduce_sum(tf.to_int32(tf.equal(x_hits, 0.0))) 173 | keep = zeros_x <= self.max_train_empty 174 | return keep 175 | 176 | def make_features(self, x_hits, y_hits, dow, lagged_hits, pf_agent, pf_country, pf_site, page_ix, 177 | page_popularity, year_autocorr, quarter_autocorr): 178 | """ 179 | Main method. Assembles input data into final tensors 180 | """ 181 | # Split day of week to train and test 182 | x_dow, y_dow = tf.split(dow, [self.train_window, self.predict_window], axis=0) 183 | 184 | # Normalize hits 185 | mean = tf.reduce_mean(x_hits) 186 | std = tf.sqrt(tf.reduce_mean(tf.squared_difference(x_hits, mean))) 187 | norm_x_hits = (x_hits - mean) / std 188 | norm_y_hits = (y_hits - mean) / std 189 | norm_lagged_hits = (lagged_hits - mean) / std 190 | 191 | # Split lagged hits to train and test 192 | x_lagged, y_lagged = tf.split(norm_lagged_hits, [self.train_window, self.predict_window], axis=0) 193 | 194 | # Combine all page features into single tensor 195 | stacked_features = tf.stack([page_popularity, quarter_autocorr, year_autocorr]) 196 | flat_page_features = tf.concat([pf_agent, pf_country, pf_site, stacked_features], axis=0) 197 | page_features = tf.expand_dims(flat_page_features, 0) 198 | 199 | # Train features 200 | x_features = tf.concat([ 201 | # [n_days] -> [n_days, 1] 202 | tf.expand_dims(norm_x_hits, -1), 203 | x_dow, 204 | x_lagged, 205 | # Stretch page_features to all training days 206 | # [1, features] -> [n_days, features] 207 | tf.tile(page_features, [self.train_window, 1]) 208 | ], axis=1) 209 | 210 | # Test features 211 | y_features = tf.concat([ 212 | # [n_days] -> [n_days, 1] 213 | y_dow, 214 | y_lagged, 215 | # Stretch page_features to all testing days 216 | # [1, features] -> [n_days, features] 217 | tf.tile(page_features, [self.predict_window, 1]) 218 | ], axis=1) 219 | 220 | return x_hits, x_features, norm_x_hits, x_lagged, y_hits, y_features, norm_y_hits, mean, std, flat_page_features, page_ix 221 | 222 | def __init__(self, inp: VarFeeder, features: Iterable[tf.Tensor], n_pages: int, mode: ModelMode, n_epoch=None, 223 | batch_size=127, runs_in_burst=1, verbose=True, predict_window=60, train_window=500, 224 | train_completeness_threshold=1, predict_completeness_threshold=1, back_offset=0, 225 | train_skip_first=0, rand_seed=None): 226 | """ 227 | Create data preprocessing pipeline 228 | :param inp: Raw input data 229 | :param features: Features tensors (subset of data in inp) 230 | :param n_pages: Total number of pages 231 | :param mode: Train/Predict/Eval mode selector 232 | :param n_epoch: Number of epochs. Generates endless data stream if None 233 | :param batch_size: 234 | :param runs_in_burst: How many batches can be consumed at short time interval (burst). 
Multiplier for prefetch() 235 | :param verbose: Print additional information during graph construction 236 | :param predict_window: Number of days to predict 237 | :param train_window: Use train_window days for training 238 | :param train_completeness_threshold: Minimum required share of non-zero datapoints in the train timeseries. 239 | :param predict_completeness_threshold: Minimum required share of non-zero datapoints in the test/predict timeseries. 240 | :param back_offset: Don't use back_offset days at the end of timeseries 241 | :param train_skip_first: Don't use train_skip_first days at the beginning of timeseries 242 | :param rand_seed: Random seed for the training-window sampling 243 | 244 | """ 245 | self.n_pages = n_pages 246 | self.inp = inp 247 | self.batch_size = batch_size 248 | self.rand_seed = rand_seed 249 | self.back_offset = back_offset 250 | if verbose: 251 | print("Mode:%s, data days:%d, Data start:%s, data end:%s, features end:%s " % ( 252 | mode, inp.data_days, inp.data_start, inp.data_end, inp.features_end)) 253 | 254 | if mode == ModelMode.TRAIN: 255 | # reserve predict_window at the end for validation 256 | assert inp.data_days - predict_window > predict_window + train_window, \ 257 | "Predict+train window length (+predict window for validation) is larger than total number of days in dataset" 258 | self.start_offset = train_skip_first 259 | elif mode == ModelMode.EVAL or mode == ModelMode.PREDICT: 260 | self.start_offset = inp.data_days - train_window - back_offset 261 | if verbose: 262 | train_start = inp.data_start + pd.Timedelta(self.start_offset, 'D') 263 | eval_start = train_start + pd.Timedelta(train_window, 'D') 264 | end = eval_start + pd.Timedelta(predict_window - 1, 'D') 265 | print("Train start %s, predict start %s, end %s" % (train_start, eval_start, end)) 266 | assert self.start_offset >= 0 267 | 268 | self.train_window = train_window 269 | self.predict_window = predict_window 270 | self.attn_window = train_window - predict_window + 1 271 | self.max_train_empty = int(round(train_window * (1 - train_completeness_threshold))) 272 | self.max_predict_empty = int(round(predict_window * (1 - predict_completeness_threshold))) 273 | self.mode = mode 274 | self.verbose = verbose 275 | 276 | # Reserve more processing threads for eval/predict because of larger batches 277 | num_threads = 3 if mode == ModelMode.TRAIN else 6 278 | 279 | # Choose the right cutter function for the current ModelMode 280 | cutter = {ModelMode.TRAIN: self.cut_train, ModelMode.EVAL: self.cut_eval, ModelMode.PREDICT: self.cut_eval} 281 | # Create dataset, transform features and assemble batches 282 | root_ds = tf.data.Dataset.from_tensor_slices(tuple(features)).repeat(n_epoch) 283 | batch = (root_ds 284 | .map(cutter[mode]) 285 | .filter(self.reject_filter) 286 | .map(self.make_features, num_parallel_calls=num_threads) 287 | .batch(batch_size) 288 | .prefetch(runs_in_burst * 2) 289 | ) 290 | 291 | self.iterator = batch.make_initializable_iterator() 292 | it_tensors = self.iterator.get_next() 293 | 294 | # Assign all tensors to class variables 295 | self.true_x, self.time_x, self.norm_x, self.lagged_x, self.true_y, self.time_y, self.norm_y, self.norm_mean, \ 296 | self.norm_std, self.page_features, self.page_ix = it_tensors 297 | 298 | self.encoder_features_depth = self.time_x.shape[2].value 299 | 300 | def load_vars(self, session): 301 | self.inp.restore(session) 302 | 303 | def init_iterator(self, session): 304 | session.run(self.iterator.initializer) 305 | 306 | 307 | def page_features(inp: VarFeeder): 308 | return (inp.hits, inp.pf_agent, inp.pf_country, 
inp.pf_site, 309 | inp.page_ix, inp.page_popularity, inp.year_autocorr, inp.quarter_autocorr) 310 | -------------------------------------------------------------------------------- /make_features.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import os.path 4 | import os 5 | import argparse 6 | 7 | import extractor 8 | from feeder import VarFeeder 9 | import numba 10 | from typing import Tuple, Dict, Collection, List 11 | 12 | 13 | def read_cached(name) -> pd.DataFrame: 14 | """ 15 | Reads csv file (maybe zipped) from data directory and caches it's content as a pickled DataFrame 16 | :param name: file name without extension 17 | :return: file content 18 | """ 19 | cached = 'data/%s.pkl' % name 20 | sources = ['data/%s.csv' % name, 'data/%s.csv.zip' % name] 21 | if os.path.exists(cached): 22 | return pd.read_pickle(cached) 23 | else: 24 | for src in sources: 25 | if os.path.exists(src): 26 | df = pd.read_csv(src) 27 | df.to_pickle(cached) 28 | return df 29 | 30 | 31 | def read_all() -> pd.DataFrame: 32 | """ 33 | Reads source data for training/prediction 34 | """ 35 | def read_file(file): 36 | df = read_cached(file).set_index('Page') 37 | df.columns = df.columns.astype('M8[D]') 38 | return df 39 | 40 | # Path to cached data 41 | path = os.path.join('data', 'all.pkl') 42 | if os.path.exists(path): 43 | df = pd.read_pickle(path) 44 | else: 45 | # Official data 46 | df = read_file('train_2') 47 | # Scraped data 48 | scraped = read_file('2017-08-15_2017-09-11') 49 | # Update last two days by scraped data 50 | df[pd.Timestamp('2017-09-10')] = scraped['2017-09-10'] 51 | df[pd.Timestamp('2017-09-11')] = scraped['2017-09-11'] 52 | 53 | df = df.sort_index() 54 | # Cache result 55 | df.to_pickle(path) 56 | return df 57 | 58 | # todo:remove 59 | def make_holidays(tagged, start, end) -> pd.DataFrame: 60 | def read_df(lang): 61 | result = pd.read_pickle('data/holidays/%s.pkl' % lang) 62 | return result[~result.dw].resample('D').size().rename(lang) 63 | 64 | holidays = pd.DataFrame([read_df(lang) for lang in ['de', 'en', 'es', 'fr', 'ja', 'ru', 'zh']]) 65 | holidays = holidays.loc[:, start:end].fillna(0) 66 | result =tagged[['country']].join(holidays, on='country').drop('country', axis=1).fillna(0).astype(np.int8) 67 | result.columns = pd.DatetimeIndex(result.columns.values) 68 | return result 69 | 70 | 71 | def read_x(start, end) -> pd.DataFrame: 72 | """ 73 | Gets source data from start to end date. Any date can be None 74 | """ 75 | df = read_all() 76 | # User GoogleAnalitycsRoman has really bad data with huge traffic spikes in all incarnations. 
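# (Descriptive note: every variant of these pages shares the "User:GoogleAnalitycsRoman" prefix, so the single startswith() mask below removes all of them at once.)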
77 | # Wikipedia banned him, we'll ban it too 78 | bad_roman = df.index.str.startswith("User:GoogleAnalitycsRoman") 79 | df = df[~bad_roman] 80 | if start and end: 81 | return df.loc[:, start:end] 82 | elif end: 83 | return df.loc[:, :end] 84 | else: 85 | return df 86 | 87 | 88 | @numba.jit(nopython=True) 89 | def single_autocorr(series, lag): 90 | """ 91 | Autocorrelation for single data series 92 | :param series: traffic series 93 | :param lag: lag, days 94 | :return: 95 | """ 96 | s1 = series[lag:] 97 | s2 = series[:-lag] 98 | ms1 = np.mean(s1) 99 | ms2 = np.mean(s2) 100 | ds1 = s1 - ms1 101 | ds2 = s2 - ms2 102 | divider = np.sqrt(np.sum(ds1 * ds1)) * np.sqrt(np.sum(ds2 * ds2)) 103 | return np.sum(ds1 * ds2) / divider if divider != 0 else 0 104 | 105 | 106 | @numba.jit(nopython=True) 107 | def batch_autocorr(data, lag, starts, ends, threshold, backoffset=0): 108 | """ 109 | Calculate autocorrelation for batch (many time series at once) 110 | :param data: Time series, shape [n_pages, n_days] 111 | :param lag: Autocorrelation lag 112 | :param starts: Start index for each series 113 | :param ends: End index for each series 114 | :param threshold: Minimum support (ratio of time series length to lag) to calculate meaningful autocorrelation. 115 | :param backoffset: Offset from the series end, days. 116 | :return: autocorrelation, shape [n_series]. If series is too short (support less than threshold), 117 | autocorrelation value is NaN 118 | """ 119 | n_series = data.shape[0] 120 | n_days = data.shape[1] 121 | max_end = n_days - backoffset 122 | corr = np.empty(n_series, dtype=np.float64) 123 | support = np.empty(n_series, dtype=np.float64) 124 | for i in range(n_series): 125 | series = data[i] 126 | end = min(ends[i], max_end) 127 | real_len = end - starts[i] 128 | support[i] = real_len/lag 129 | if support[i] > threshold: 130 | series = series[starts[i]:end] 131 | c_365 = single_autocorr(series, lag) 132 | c_364 = single_autocorr(series, lag-1) 133 | c_366 = single_autocorr(series, lag+1) 134 | # Average value between exact lag and two nearest neighborhs for smoothness 135 | corr[i] = 0.5 * c_365 + 0.25 * c_364 + 0.25 * c_366 136 | else: 137 | corr[i] = np.NaN 138 | return corr #, support 139 | 140 | 141 | @numba.jit(nopython=True) 142 | def find_start_end(data: np.ndarray): 143 | """ 144 | Calculates start and end of real traffic data. 
Start is an index of first non-zero, non-NaN value, 145 | end is index of last non-zero, non-NaN value 146 | :param data: Time series, shape [n_pages, n_days] 147 | :return: 148 | """ 149 | n_pages = data.shape[0] 150 | n_days = data.shape[1] 151 | start_idx = np.full(n_pages, -1, dtype=np.int32) 152 | end_idx = np.full(n_pages, -1, dtype=np.int32) 153 | for page in range(n_pages): 154 | # scan from start to the end 155 | for day in range(n_days): 156 | if not np.isnan(data[page, day]) and data[page, day] > 0: 157 | start_idx[page] = day 158 | break 159 | # reverse scan, from end to start 160 | for day in range(n_days - 1, -1, -1): 161 | if not np.isnan(data[page, day]) and data[page, day] > 0: 162 | end_idx[page] = day 163 | break 164 | return start_idx, end_idx 165 | 166 | 167 | def prepare_data(start, end, valid_threshold) -> Tuple[pd.DataFrame, pd.DataFrame, np.ndarray, np.ndarray]: 168 | """ 169 | Reads source data, calculates start and end of each series, drops bad series, calculates log1p(series) 170 | :param start: start date of effective time interval, can be None to start from beginning 171 | :param end: end date of effective time interval, can be None to return all data 172 | :param valid_threshold: minimal ratio of series real length to entire (end-start) interval. Series dropped if 173 | ratio is less than threshold 174 | :return: tuple(log1p(series), nans, series start, series end) 175 | """ 176 | df = read_x(start, end) 177 | starts, ends = find_start_end(df.values) 178 | # boolean mask for bad (too short) series 179 | page_mask = (ends - starts) / df.shape[1] < valid_threshold 180 | print("Masked %d pages from %d" % (page_mask.sum(), len(df))) 181 | inv_mask = ~page_mask 182 | df = df[inv_mask] 183 | nans = pd.isnull(df) 184 | return np.log1p(df.fillna(0)), nans, starts[inv_mask], ends[inv_mask] 185 | 186 | 187 | def lag_indexes(begin, end) -> List[pd.Series]: 188 | """ 189 | Calculates indexes for 3, 6, 9, 12 months backward lag for the given date range 190 | :param begin: start of date range 191 | :param end: end of date range 192 | :return: List of 4 Series, one for each lag. For each Series, index is date in range(begin, end), value is an index 193 | of target (lagged) date in a same Series. If target date is out of (begin,end) range, index is -1 194 | """ 195 | dr = pd.date_range(begin, end) 196 | # key is date, value is day index 197 | base_index = pd.Series(np.arange(0, len(dr)), index=dr) 198 | 199 | def lag(offset): 200 | dates = dr - offset 201 | return pd.Series(data=base_index.loc[dates].fillna(-1).astype(np.int16).values, index=dr) 202 | 203 | return [lag(pd.DateOffset(months=m)) for m in (3, 6, 9, 12)] 204 | 205 | 206 | def make_page_features(pages: np.ndarray) -> pd.DataFrame: 207 | """ 208 | Calculates page features (site, country, agent, etc) from urls 209 | :param pages: Source urls 210 | :return: DataFrame with features as columns and urls as index 211 | """ 212 | tagged = extractor.extract(pages).set_index('page') 213 | # Drop useless features 214 | features: pd.DataFrame = tagged.drop(['term', 'marker'], axis=1) 215 | return features 216 | 217 | 218 | def uniq_page_map(pages:Collection): 219 | """ 220 | Finds agent types (spider, desktop, mobile, all) for each unique url, i.e. groups pages by agents 221 | :param pages: all urls (must be presorted) 222 | :return: array[num_unique_urls, 4], where each column corresponds to agent type and each row corresponds to unique url. 223 | Value is an index of page in source pages array. 
If agent is missing, value is -1 224 | """ 225 | import re 226 | result = np.full([len(pages), 4], -1, dtype=np.int32) 227 | pat = re.compile( 228 | '(.+(?:(?:wikipedia\.org)|(?:commons\.wikimedia\.org)|(?:www\.mediawiki\.org)))_([a-z_-]+?)') 229 | prev_page = None 230 | num_page = -1 231 | agents = {'all-access_spider': 0, 'desktop_all-agents': 1, 'mobile-web_all-agents': 2, 'all-access_all-agents': 3} 232 | for i, entity in enumerate(pages): 233 | match = pat.fullmatch(entity) 234 | assert match 235 | page = match.group(1) 236 | agent = match.group(2) 237 | if page != prev_page: 238 | prev_page = page 239 | num_page += 1 240 | result[num_page, agents[agent]] = i 241 | return result[:num_page+1] 242 | 243 | 244 | def encode_page_features(df) -> Dict[str, pd.DataFrame]: 245 | """ 246 | Applies one-hot encoding to page features and normalises result 247 | :param df: page features DataFrame (one column per feature) 248 | :return: dictionary feature_name:encoded_values. Encoded values is [n_pages,n_values] array 249 | """ 250 | def encode(column) -> pd.DataFrame: 251 | one_hot = pd.get_dummies(df[column], drop_first=False) 252 | # noinspection PyUnresolvedReferences 253 | return (one_hot - one_hot.mean()) / one_hot.std() 254 | 255 | return {str(column): encode(column) for column in df} 256 | 257 | 258 | def normalize(values: np.ndarray): 259 | return (values - values.mean()) / np.std(values) 260 | 261 | 262 | def run(): 263 | parser = argparse.ArgumentParser(description='Prepare data') 264 | parser.add_argument('data_dir') 265 | parser.add_argument('--valid_threshold', default=0.0, type=float, help="Series minimal length threshold (pct of data length)") 266 | parser.add_argument('--add_days', default=64, type=int, help="Add N days in a future for prediction") 267 | parser.add_argument('--start', help="Effective start date. Data before the start is dropped") 268 | parser.add_argument('--end', help="Effective end date. 
Data past the end is dropped") 269 | parser.add_argument('--corr_backoffset', default=0, type=int, help='Offset for correlation calculation') 270 | args = parser.parse_args() 271 | 272 | # Get the data 273 | df, nans, starts, ends = prepare_data(args.start, args.end, args.valid_threshold) 274 | 275 | # Our working date range 276 | data_start, data_end = df.columns[0], df.columns[-1] 277 | 278 | # We have to project some date-dependent features (day of week, etc) to the future dates for prediction 279 | features_end = data_end + pd.Timedelta(args.add_days, unit='D') 280 | print(f"start: {data_start}, end:{data_end}, features_end:{features_end}") 281 | 282 | # Group unique pages by agents 283 | assert df.index.is_monotonic_increasing 284 | page_map = uniq_page_map(df.index.values) 285 | 286 | # Yearly(annual) autocorrelation 287 | raw_year_autocorr = batch_autocorr(df.values, 365, starts, ends, 1.5, args.corr_backoffset) 288 | year_unknown_pct = np.sum(np.isnan(raw_year_autocorr))/len(raw_year_autocorr) # type: float 289 | 290 | # Quarterly autocorrelation 291 | raw_quarter_autocorr = batch_autocorr(df.values, int(round(365.25/4)), starts, ends, 2, args.corr_backoffset) 292 | quarter_unknown_pct = np.sum(np.isnan(raw_quarter_autocorr)) / len(raw_quarter_autocorr) # type: float 293 | 294 | print("Percent of undefined autocorr = yearly:%.3f, quarterly:%.3f" % (year_unknown_pct, quarter_unknown_pct)) 295 | 296 | # Normalise all the things 297 | year_autocorr = normalize(np.nan_to_num(raw_year_autocorr)) 298 | quarter_autocorr = normalize(np.nan_to_num(raw_quarter_autocorr)) 299 | 300 | # Calculate and encode page features 301 | page_features = make_page_features(df.index.values) 302 | encoded_page_features = encode_page_features(page_features) 303 | 304 | # Make time-dependent features 305 | features_days = pd.date_range(data_start, features_end) 306 | #dow = normalize(features_days.dayofweek.values) 307 | week_period = 7 / (2 * np.pi) 308 | dow_norm = features_days.dayofweek.values / week_period 309 | dow = np.stack([np.cos(dow_norm), np.sin(dow_norm)], axis=-1) 310 | 311 | # Assemble indices for quarterly lagged data 312 | lagged_ix = np.stack(lag_indexes(data_start, features_end), axis=-1) 313 | 314 | page_popularity = df.median(axis=1) 315 | page_popularity = (page_popularity - page_popularity.mean()) / page_popularity.std() 316 | 317 | # Put NaNs back 318 | df[nans] = np.NaN 319 | 320 | # Assemble final output 321 | tensors = dict( 322 | hits=df, 323 | lagged_ix=lagged_ix, 324 | page_map=page_map, 325 | page_ix=df.index.values, 326 | pf_agent=encoded_page_features['agent'], 327 | pf_country=encoded_page_features['country'], 328 | pf_site=encoded_page_features['site'], 329 | page_popularity=page_popularity, 330 | year_autocorr=year_autocorr, 331 | quarter_autocorr=quarter_autocorr, 332 | dow=dow, 333 | ) 334 | plain = dict( 335 | features_days=len(features_days), 336 | data_days=len(df.columns), 337 | n_pages=len(df), 338 | data_start=data_start, 339 | data_end=data_end, 340 | features_end=features_end 341 | 342 | ) 343 | 344 | # Store data to the disk 345 | VarFeeder(args.data_dir, tensors, plain) 346 | 347 | 348 | if __name__ == '__main__': 349 | run() 350 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | import tensorflow.contrib.cudnn_rnn as cudnn_rnn 4 | import tensorflow.contrib.rnn as rnn 5 | import tensorflow.contrib.layers 
as layers 6 | from tensorflow.python.util import nest 7 | 8 | from input_pipe import InputPipe, ModelMode 9 | 10 | GRAD_CLIP_THRESHOLD = 10 11 | RNN = cudnn_rnn.CudnnGRU 12 | # RNN = tf.contrib.cudnn_rnn.CudnnLSTM 13 | # RNN = tf.contrib.cudnn_rnn.CudnnRNNRelu 14 | 15 | 16 | def default_init(seed): 17 | # replica of tf.glorot_uniform_initializer(seed=seed) 18 | return layers.variance_scaling_initializer(factor=1.0, 19 | mode="FAN_AVG", 20 | uniform=True, 21 | seed=seed) 22 | 23 | 24 | def selu(x): 25 | """ 26 | SELU activation 27 | https://arxiv.org/abs/1706.02515 28 | :param x: 29 | :return: 30 | """ 31 | with tf.name_scope('elu') as scope: 32 | alpha = 1.6732632423543772848170429916717 33 | scale = 1.0507009873554804934193349852946 34 | return scale * tf.where(x >= 0.0, x, alpha * tf.nn.elu(x)) 35 | 36 | 37 | def make_encoder(time_inputs, encoder_features_depth, is_train, hparams, seed, transpose_output=True): 38 | """ 39 | Builds encoder, using CUDA RNN 40 | :param time_inputs: Input tensor, shape [batch, time, features] 41 | :param encoder_features_depth: Static size for features dimension 42 | :param is_train: 43 | :param hparams: 44 | :param seed: 45 | :param transpose_output: Transform RNN output to batch-first shape 46 | :return: 47 | """ 48 | 49 | def build_rnn(): 50 | return RNN(num_layers=hparams.encoder_rnn_layers, num_units=hparams.rnn_depth, 51 | #input_size=encoder_features_depth, 52 | kernel_initializer=tf.initializers.random_uniform(minval=-0.05, maxval=0.05, 53 | seed=seed + 1 if seed else None), 54 | direction='unidirectional', 55 | dropout=hparams.encoder_dropout if is_train else 0, seed=seed) 56 | 57 | cuda_model = build_rnn() 58 | 59 | # [batch, time, features] -> [time, batch, features] 60 | time_first = tf.transpose(time_inputs, [1, 0, 2]) 61 | rnn_time_input = time_first 62 | if RNN == tf.contrib.cudnn_rnn.CudnnLSTM: 63 | rnn_out, (rnn_state, c_state) = cuda_model(inputs=rnn_time_input) 64 | else: 65 | rnn_out, (rnn_state,) = cuda_model(inputs=rnn_time_input) 66 | c_state = None 67 | if transpose_output: 68 | rnn_out = tf.transpose(rnn_out, [1, 0, 2]) 69 | return rnn_out, rnn_state, c_state 70 | 71 | 72 | def compressed_readout(rnn_out, hparams, dropout, seed): 73 | """ 74 | FC compression layer, reduces RNN output depth to hparams.attention_depth 75 | :param rnn_out: 76 | :param hparams: 77 | :param dropout: 78 | :param seed: 79 | :return: 80 | """ 81 | if dropout < 1.0: 82 | rnn_out = tf.nn.dropout(rnn_out, dropout, seed=seed) 83 | return tf.layers.dense(rnn_out, hparams.attention_depth, 84 | use_bias=True, 85 | activation=selu, 86 | kernel_initializer=layers.variance_scaling_initializer(factor=1.0, seed=seed), 87 | name='compress_readout' 88 | ) 89 | 90 | 91 | def make_fingerprint(x, is_train, fc_dropout, seed): 92 | """ 93 | Calculates 'fingerprint' of timeseries, to feed into attention layer 94 | :param x: 95 | :param is_train: 96 | :param fc_dropout: 97 | :param seed: 98 | :return: 99 | """ 100 | with tf.variable_scope("fingerpint"): 101 | # x = tf.expand_dims(x, -1) 102 | with tf.variable_scope('convnet', initializer=layers.variance_scaling_initializer(seed=seed)): 103 | c11 = tf.layers.conv1d(x, filters=16, kernel_size=7, activation=tf.nn.relu, padding='same') 104 | c12 = tf.layers.conv1d(c11, filters=16, kernel_size=3, activation=tf.nn.relu, padding='same') 105 | pool1 = tf.layers.max_pooling1d(c12, 2, 2, padding='same') 106 | c21 = tf.layers.conv1d(pool1, filters=32, kernel_size=3, activation=tf.nn.relu, padding='same') 107 | c22 = tf.layers.conv1d(c21, 
filters=32, kernel_size=3, activation=tf.nn.relu, padding='same') 108 | pool2 = tf.layers.max_pooling1d(c22, 2, 2, padding='same') 109 | c31 = tf.layers.conv1d(pool2, filters=64, kernel_size=3, activation=tf.nn.relu, padding='same') 110 | c32 = tf.layers.conv1d(c31, filters=64, kernel_size=3, activation=tf.nn.relu, padding='same') 111 | pool3 = tf.layers.max_pooling1d(c32, 2, 2, padding='same') 112 | dims = pool3.shape.dims 113 | pool3 = tf.reshape(pool3, [-1, dims[1].value * dims[2].value]) 114 | if is_train and fc_dropout < 1.0: 115 | cnn_out = tf.nn.dropout(pool3, fc_dropout, seed=seed) 116 | else: 117 | cnn_out = pool3 118 | with tf.variable_scope('fc_convnet', 119 | initializer=layers.variance_scaling_initializer(factor=1.0, mode='FAN_IN', seed=seed)): 120 | fc_encoder = tf.layers.dense(cnn_out, 512, activation=selu, name='fc_encoder') 121 | out_encoder = tf.layers.dense(fc_encoder, 16, activation=selu, name='out_encoder') 122 | return out_encoder 123 | 124 | 125 | def attn_readout_v3(readout, attn_window, attn_heads, page_features, seed): 126 | # input: [n_days, batch, readout_depth] 127 | # [n_days, batch, readout_depth] -> [batch(readout_depth), width=n_days, channels=batch] 128 | readout = tf.transpose(readout, [2, 0, 1]) 129 | # [batch(readout_depth), width, channels] -> [batch, height=1, width, channels] 130 | inp = readout[:, tf.newaxis, :, :] 131 | 132 | # attn_window = train_window - predict_window + 1 133 | # [batch, attn_window * n_heads] 134 | filter_logits = tf.layers.dense(page_features, attn_window * attn_heads, name="attn_focus", 135 | kernel_initializer=default_init(seed) 136 | # kernel_initializer=layers.variance_scaling_initializer(uniform=True) 137 | # activation=selu, 138 | # kernel_initializer=layers.variance_scaling_initializer(factor=1.0, mode='FAN_IN') 139 | ) 140 | # [batch, attn_window * n_heads] -> [batch, attn_window, n_heads] 141 | filter_logits = tf.reshape(filter_logits, [-1, attn_window, attn_heads]) 142 | 143 | # attns_max = tf.nn.softmax(filter_logits, dim=1) 144 | attns_max = filter_logits / tf.reduce_sum(filter_logits, axis=1, keep_dims=True) 145 | # [batch, attn_window, n_heads] -> [width(attn_window), channels(batch), n_heads] 146 | attns_max = tf.transpose(attns_max, [1, 0, 2]) 147 | 148 | # [width(attn_window), channels(batch), n_heads] -> [height(1), width(attn_window), channels(batch), multiplier(n_heads)] 149 | attn_filter = attns_max[tf.newaxis, :, :, :] 150 | # [batch(readout_depth), height=1, width=n_days, channels=batch] -> [batch(readout_depth), height=1, width=predict_window, channels=batch*n_heads] 151 | averaged = tf.nn.depthwise_conv2d_native(inp, attn_filter, [1, 1, 1, 1], 'VALID') 152 | # [batch, height=1, width=predict_window, channels=readout_depth*n_neads] -> [batch(depth), predict_window, batch*n_heads] 153 | attn_features = tf.squeeze(averaged, 1) 154 | # [batch(depth), predict_window, batch*n_heads] -> [batch*n_heads, predict_window, depth] 155 | attn_features = tf.transpose(attn_features, [2, 1, 0]) 156 | # [batch * n_heads, predict_window, depth] -> n_heads * [batch, predict_window, depth] 157 | heads = [attn_features[head_no::attn_heads] for head_no in range(attn_heads)] 158 | # n_heads * [batch, predict_window, depth] -> [batch, predict_window, depth*n_heads] 159 | result = tf.concat(heads, axis=-1) 160 | # attn_diag = tf.unstack(attns_max, axis=-1) 161 | return result, None 162 | 163 | 164 | def calc_smape_rounded(true, predicted, weights): 165 | """ 166 | Calculates SMAPE on rounded submission values. 
Should be close to official SMAPE in competition 167 | :param true: 168 | :param predicted: 169 | :param weights: Weights mask to exclude some values 170 | :return: 171 | """ 172 | n_valid = tf.reduce_sum(weights) 173 | true_o = tf.round(tf.expm1(true)) 174 | pred_o = tf.maximum(tf.round(tf.expm1(predicted)), 0.0) 175 | summ = tf.abs(true_o) + tf.abs(pred_o) 176 | zeros = summ < 0.01 177 | raw_smape = tf.abs(pred_o - true_o) / summ * 2.0 178 | smape = tf.where(zeros, tf.zeros_like(summ, dtype=tf.float32), raw_smape) 179 | return tf.reduce_sum(smape * weights) / n_valid 180 | 181 | 182 | def smape_loss(true, predicted, weights): 183 | """ 184 | Differentiable SMAPE loss 185 | :param true: Truth values 186 | :param predicted: Predicted values 187 | :param weights: Weights mask to exclude some values 188 | :return: 189 | """ 190 | epsilon = 0.1 # Smoothing factor, helps SMAPE to be well-behaved near zero 191 | true_o = tf.expm1(true) 192 | pred_o = tf.expm1(predicted) 193 | summ = tf.maximum(tf.abs(true_o) + tf.abs(pred_o) + epsilon, 0.5 + epsilon) 194 | smape = tf.abs(pred_o - true_o) / summ * 2.0 195 | return tf.losses.compute_weighted_loss(smape, weights, loss_collection=None) 196 | 197 | 198 | def decode_predictions(decoder_readout, inp: InputPipe): 199 | """ 200 | Converts normalized prediction values to log1p(pageviews), e.g. reverts normalization 201 | :param decoder_readout: Decoder output, shape [n_days, batch] 202 | :param inp: Input tensors 203 | :return: 204 | """ 205 | # [n_days, batch] -> [batch, n_days] 206 | batch_readout = tf.transpose(decoder_readout) 207 | batch_std = tf.expand_dims(inp.norm_std, -1) 208 | batch_mean = tf.expand_dims(inp.norm_mean, -1) 209 | return batch_readout * batch_std + batch_mean 210 | 211 | 212 | def calc_loss(predictions, true_y, additional_mask=None): 213 | """ 214 | Calculates losses, ignoring NaN true values (assigning zero loss to them) 215 | :param predictions: Predicted values 216 | :param true_y: True values 217 | :param additional_mask: 218 | :return: MAE loss, differentiable SMAPE loss, competition SMAPE loss 219 | """ 220 | # Take into account NaN's in true values 221 | mask = tf.is_finite(true_y) 222 | # Fill NaNs by zeros (can use any value) 223 | true_y = tf.where(mask, true_y, tf.zeros_like(true_y)) 224 | # Assign zero weight to NaNs 225 | weights = tf.to_float(mask) 226 | if additional_mask is not None: 227 | weights = weights * tf.expand_dims(additional_mask, axis=0) 228 | 229 | mae_loss = tf.losses.absolute_difference(labels=true_y, predictions=predictions, weights=weights) 230 | return mae_loss, smape_loss(true_y, predictions, weights), calc_smape_rounded(true_y, predictions, 231 | weights), tf.size(true_y) 232 | 233 | 234 | def make_train_op(loss, ema_decay=None, prefix=None): 235 | optimizer = tf.train.AdamOptimizer() 236 | glob_step = tf.train.get_global_step() 237 | 238 | # Add regularization losses 239 | reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) 240 | total_loss = loss + reg_losses if reg_losses else loss 241 | 242 | # Clip gradients 243 | grads_and_vars = optimizer.compute_gradients(total_loss) 244 | gradients, variables = zip(*grads_and_vars) 245 | clipped_gradients, glob_norm = tf.clip_by_global_norm(gradients, GRAD_CLIP_THRESHOLD) 246 | sgd_op, glob_norm = optimizer.apply_gradients(zip(clipped_gradients, variables)), glob_norm 247 | 248 | # Apply SGD averaging 249 | if ema_decay: 250 | ema = tf.train.ExponentialMovingAverage(decay=ema_decay, num_updates=glob_step) 251 | if prefix: 252 | # Some 
magic to handle multiple models trained in single graph 253 | ema_vars = [var for var in variables if var.name.startswith(prefix)] 254 | else: 255 | ema_vars = variables 256 | update_ema = ema.apply(ema_vars) 257 | with tf.control_dependencies([sgd_op]): 258 | training_op = tf.group(update_ema) 259 | else: 260 | training_op = sgd_op 261 | ema = None 262 | return training_op, glob_norm, ema 263 | 264 | 265 | def convert_cudnn_state_v2(h_state, hparams, seed, c_state=None, dropout=1.0): 266 | """ 267 | Converts RNN state tensor from cuDNN representation to TF RNNCell compatible representation. 268 | :param h_state: tensor [num_layers, batch_size, depth] 269 | :param c_state: LSTM additional state, should be same shape as h_state 270 | :return: TF cell representation matching RNNCell.state_size structure for compatible cell 271 | """ 272 | 273 | def squeeze(seq): 274 | return tuple(seq) if len(seq) > 1 else seq[0] 275 | 276 | def wrap_dropout(structure): 277 | if dropout < 1.0: 278 | return nest.map_structure(lambda x: tf.nn.dropout(x, keep_prob=dropout, seed=seed), structure) 279 | else: 280 | return structure 281 | 282 | # Cases: 283 | # decoder_layer = encoder_layers, straight mapping 284 | # encoder_layers > decoder_layers: get outputs of upper encoder layers 285 | # encoder_layers < decoder_layers: feed encoder outputs to lower decoder layers, feed zeros to top layers 286 | h_layers = tf.unstack(h_state) 287 | if hparams.encoder_rnn_layers >= hparams.decoder_rnn_layers: 288 | return squeeze(wrap_dropout(h_layers[hparams.encoder_rnn_layers - hparams.decoder_rnn_layers:])) 289 | else: 290 | lower_inputs = wrap_dropout(h_layers) 291 | upper_inputs = [tf.zeros_like(h_layers[0]) for _ in 292 | range(hparams.decoder_rnn_layers - hparams.encoder_rnn_layers)] 293 | return squeeze(lower_inputs + upper_inputs) 294 | 295 | 296 | def rnn_stability_loss(rnn_output, beta): 297 | """ 298 | REGULARIZING RNNS BY STABILIZING ACTIVATIONS 299 | https://arxiv.org/pdf/1511.08400.pdf 300 | :param rnn_output: [time, batch, features] 301 | :return: loss value 302 | """ 303 | if beta == 0.0: 304 | return 0.0 305 | # [time, batch, features] -> [time, batch] 306 | l2 = tf.sqrt(tf.reduce_sum(tf.square(rnn_output), axis=-1)) 307 | # [time, batch] -> [] 308 | return beta * tf.reduce_mean(tf.square(l2[1:] - l2[:-1])) 309 | 310 | 311 | def rnn_activation_loss(rnn_output, beta): 312 | """ 313 | REGULARIZING RNNS BY STABILIZING ACTIVATIONS 314 | https://arxiv.org/pdf/1511.08400.pdf 315 | :param rnn_output: [time, batch, features] 316 | :return: loss value 317 | """ 318 | if beta == 0.0: 319 | return 0.0 320 | return tf.nn.l2_loss(rnn_output) * beta 321 | 322 | 323 | class Model: 324 | def __init__(self, inp: InputPipe, hparams, is_train, seed, graph_prefix=None, asgd_decay=None, loss_mask=None): 325 | """ 326 | Encoder-decoder prediction model 327 | :param inp: Input tensors 328 | :param hparams: 329 | :param is_train: 330 | :param seed: 331 | :param graph_prefix: Subgraph prefix for multi-model graph 332 | :param asgd_decay: Decay for SGD averaging 333 | :param loss_mask: Additional mask for losses calculation (one value for each prediction day), shape=[predict_window] 334 | """ 335 | self.is_train = is_train 336 | self.inp = inp 337 | self.hparams = hparams 338 | self.seed = seed 339 | self.inp = inp 340 | 341 | encoder_output, h_state, c_state = make_encoder(inp.time_x, inp.encoder_features_depth, is_train, hparams, seed, 342 | transpose_output=False) 343 | # Encoder activation losses 344 | enc_stab_loss = 
rnn_stability_loss(encoder_output, hparams.encoder_stability_loss / inp.train_window) 345 | enc_activation_loss = rnn_activation_loss(encoder_output, hparams.encoder_activation_loss / inp.train_window) 346 | 347 | # Convert state from cuDNN representation to TF RNNCell-compatible representation 348 | encoder_state = convert_cudnn_state_v2(h_state, hparams, c_state, 349 | dropout=hparams.gate_dropout if is_train else 1.0) 350 | 351 | # Attention calculations 352 | # Compress encoder outputs 353 | enc_readout = compressed_readout(encoder_output, hparams, 354 | dropout=hparams.encoder_readout_dropout if is_train else 1.0, seed=seed) 355 | # Calculate fingerprint from input features 356 | fingerprint_inp = tf.concat([inp.lagged_x, tf.expand_dims(inp.norm_x, -1)], axis=-1) 357 | fingerprint = make_fingerprint(fingerprint_inp, is_train, hparams.fingerprint_fc_dropout, seed) 358 | # Calculate attention vector 359 | attn_features, attn_weights = attn_readout_v3(enc_readout, inp.attn_window, hparams.attention_heads, 360 | fingerprint, seed=seed) 361 | 362 | # Run decoder 363 | decoder_targets, decoder_outputs = self.decoder(encoder_state, 364 | attn_features if hparams.use_attn else None, 365 | inp.time_y, inp.norm_x[:, -1]) 366 | # Decoder activation losses 367 | dec_stab_loss = rnn_stability_loss(decoder_outputs, hparams.decoder_stability_loss / inp.predict_window) 368 | dec_activation_loss = rnn_activation_loss(decoder_outputs, hparams.decoder_activation_loss / inp.predict_window) 369 | 370 | # Get final denormalized predictions 371 | self.predictions = decode_predictions(decoder_targets, inp) 372 | 373 | # Calculate losses and build training op 374 | if inp.mode == ModelMode.PREDICT: 375 | # Pseudo-apply ema to get variable names later in ema.variables_to_restore() 376 | # This is copypaste from make_train_op() 377 | if asgd_decay: 378 | self.ema = tf.train.ExponentialMovingAverage(decay=asgd_decay) 379 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) 380 | if graph_prefix: 381 | ema_vars = [var for var in variables if var.name.startswith(graph_prefix)] 382 | else: 383 | ema_vars = variables 384 | self.ema.apply(ema_vars) 385 | else: 386 | self.mae, smape_loss, self.smape, self.loss_item_count = calc_loss(self.predictions, inp.true_y, 387 | additional_mask=loss_mask) 388 | if is_train: 389 | # Sum all losses 390 | total_loss = smape_loss + enc_stab_loss + dec_stab_loss + enc_activation_loss + dec_activation_loss 391 | self.train_op, self.glob_norm, self.ema = make_train_op(total_loss, asgd_decay, prefix=graph_prefix) 392 | 393 | 394 | 395 | def default_init(self, seed_add=0): 396 | return default_init(self.seed + seed_add) 397 | 398 | def decoder(self, encoder_state, attn_features, prediction_inputs, previous_y): 399 | """ 400 | :param encoder_state: shape [batch_size, encoder_rnn_depth] 401 | :param prediction_inputs: features for prediction days, tensor[batch_size, time, input_depth] 402 | :param previous_y: Last day pageviews, shape [batch_size] 403 | :param attn_features: Additional features from attention layer, shape [batch, predict_window, readout_depth*n_heads] 404 | :return: decoder rnn output 405 | """ 406 | hparams = self.hparams 407 | 408 | def build_cell(idx): 409 | with tf.variable_scope('decoder_cell', initializer=self.default_init(idx)): 410 | cell = rnn.GRUBlockCell(self.hparams.rnn_depth) 411 | has_dropout = hparams.decoder_input_dropout[idx] < 1 \ 412 | or hparams.decoder_state_dropout[idx] < 1 or hparams.decoder_output_dropout[idx] < 1 413 | 414 | if 
self.is_train and has_dropout: 415 | attn_depth = attn_features.shape[-1].value if attn_features is not None else 0 416 | input_size = attn_depth + prediction_inputs.shape[-1].value + 1 if idx == 0 else self.hparams.rnn_depth 417 | cell = rnn.DropoutWrapper(cell, dtype=tf.float32, input_size=input_size, 418 | variational_recurrent=hparams.decoder_variational_dropout[idx], 419 | input_keep_prob=hparams.decoder_input_dropout[idx], 420 | output_keep_prob=hparams.decoder_output_dropout[idx], 421 | state_keep_prob=hparams.decoder_state_dropout[idx], seed=self.seed + idx) 422 | return cell 423 | 424 | if hparams.decoder_rnn_layers > 1: 425 | cells = [build_cell(idx) for idx in range(hparams.decoder_rnn_layers)] 426 | cell = rnn.MultiRNNCell(cells) 427 | else: 428 | cell = build_cell(0) 429 | 430 | nest.assert_same_structure(encoder_state, cell.state_size) 431 | predict_days = self.inp.predict_window 432 | assert prediction_inputs.shape[1] == predict_days 433 | 434 | # [batch_size, time, input_depth] -> [time, batch_size, input_depth] 435 | inputs_by_time = tf.transpose(prediction_inputs, [1, 0, 2]) 436 | 437 | # Return raw outputs for RNN losses calculation 438 | return_raw_outputs = self.hparams.decoder_stability_loss > 0.0 or self.hparams.decoder_activation_loss > 0.0 439 | 440 | # Stop condition for decoding loop 441 | def cond_fn(time, prev_output, prev_state, array_targets: tf.TensorArray, array_outputs: tf.TensorArray): 442 | return time < predict_days 443 | 444 | # FC projecting layer to get single predicted value from RNN output 445 | def project_output(tensor): 446 | return tf.layers.dense(tensor, 1, name='decoder_output_proj', kernel_initializer=self.default_init()) 447 | 448 | def loop_fn(time, prev_output, prev_state, array_targets: tf.TensorArray, array_outputs: tf.TensorArray): 449 | """ 450 | Main decoder loop 451 | :param time: Day number 452 | :param prev_output: Output(prediction) from previous step 453 | :param prev_state: RNN state tensor from previous step 454 | :param array_targets: Predictions, each step will append new value to this array 455 | :param array_outputs: Raw RNN outputs (for regularization losses) 456 | :return: 457 | """ 458 | # RNN inputs for current step 459 | features = inputs_by_time[time] 460 | 461 | # [batch, predict_window, readout_depth * n_heads] -> [batch, readout_depth * n_heads] 462 | if attn_features is not None: 463 | # [batch_size, 1] + [batch_size, input_depth] 464 | attn = attn_features[:, time, :] 465 | # Append previous predicted value + attention vector to input features 466 | next_input = tf.concat([prev_output, features, attn], axis=1) 467 | else: 468 | # Append previous predicted value to input features 469 | next_input = tf.concat([prev_output, features], axis=1) 470 | 471 | # Run RNN cell 472 | output, state = cell(next_input, prev_state) 473 | # Make prediction from RNN outputs 474 | projected_output = project_output(output) 475 | # Append step results to the buffer arrays 476 | if return_raw_outputs: 477 | array_outputs = array_outputs.write(time, output) 478 | array_targets = array_targets.write(time, projected_output) 479 | # Increment time and return 480 | return time + 1, projected_output, state, array_targets, array_outputs 481 | 482 | # Initial values for loop 483 | loop_init = [tf.constant(0, dtype=tf.int32), 484 | tf.expand_dims(previous_y, -1), 485 | encoder_state, 486 | tf.TensorArray(dtype=tf.float32, size=predict_days), 487 | tf.TensorArray(dtype=tf.float32, size=predict_days) if return_raw_outputs else tf.constant(0)] 
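# Descriptive note on the decoding loop below: decoding is autoregressive. cond_fn() stops the loop after
# predict_days steps; on every step loop_fn() concatenates the previous prediction (prev_output) with the
# current day's features (plus the attention slice, when attention is enabled), runs the GRU cell,
# projects the cell output to a single predicted value, and writes that projection (and the raw RNN
# output, if regularization losses are enabled) into the TensorArrays initialized in loop_init above.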
488 | # Run the loop 489 | _, _, _, targets_ta, outputs_ta = tf.while_loop(cond_fn, loop_fn, loop_init) 490 | 491 | # Get final tensors from buffer arrays 492 | targets = targets_ta.stack() 493 | # [time, batch_size, 1] -> [time, batch_size] 494 | targets = tf.squeeze(targets, axis=-1) 495 | raw_outputs = outputs_ta.stack() if return_raw_outputs else None 496 | return targets, raw_outputs 497 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | tqdm 4 | matplotlib 5 | tensorflow-gpu>=1.10 6 | 7 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import shutil 3 | import sys 4 | import numpy as np 5 | import tensorflow as tf 6 | from tqdm import trange 7 | from typing import List, Tuple 8 | import heapq 9 | import logging 10 | import pandas as pd 11 | from enum import Enum 12 | 13 | from hparams import build_from_set, build_hparams 14 | from feeder import VarFeeder 15 | from input_pipe import InputPipe, ModelMode, Splitter,FakeSplitter, page_features 16 | from model import Model 17 | import argparse 18 | 19 | 20 | log = logging.getLogger('trainer') 21 | 22 | class Ema: 23 | def __init__(self, k=0.99): 24 | self.k = k 25 | self.state = None 26 | self.steps = 0 27 | 28 | def __call__(self, *args, **kwargs): 29 | v = args[0] 30 | self.steps += 1 31 | if self.state is None: 32 | self.state = v 33 | else: 34 | eff_k = min(1 - 1 / self.steps, self.k) 35 | self.state = eff_k * self.state + (1 - eff_k) * v 36 | return self.state 37 | 38 | 39 | class Metric: 40 | def __init__(self, name: str, op, smoothness: float = None): 41 | self.name = name 42 | self.op = op 43 | self.smoother = Ema(smoothness) if smoothness else None 44 | self.epoch_values = [] 45 | self.best_value = np.Inf 46 | self.best_step = 0 47 | self.last_epoch = -1 48 | self.improved = False 49 | self._top = [] 50 | 51 | @property 52 | def avg_epoch(self): 53 | return np.mean(self.epoch_values) 54 | 55 | @property 56 | def best_epoch(self): 57 | return np.min(self.epoch_values) 58 | 59 | @property 60 | def last(self): 61 | return self.epoch_values[-1] if self.epoch_values else np.nan 62 | 63 | @property 64 | def top(self): 65 | return -np.mean(self._top) 66 | 67 | 68 | def update(self, value, epoch, step): 69 | if self.smoother: 70 | value = self.smoother(value) 71 | if epoch > self.last_epoch: 72 | self.epoch_values = [] 73 | self.last_epoch = epoch 74 | self.epoch_values.append(value) 75 | if value < self.best_value: 76 | self.best_value = value 77 | self.best_step = step 78 | self.improved = True 79 | else: 80 | self.improved = False 81 | if len(self._top) >= 5: 82 | heapq.heappushpop(self._top, -value) 83 | else: 84 | heapq.heappush(self._top, -value) 85 | 86 | 87 | class AggMetric: 88 | def __init__(self, metrics: List[Metric]): 89 | self.metrics = metrics 90 | 91 | def _mean(self, fun) -> float: 92 | # noinspection PyTypeChecker 93 | return np.mean([fun(metric) for metric in self.metrics]) 94 | 95 | @property 96 | def avg_epoch(self): 97 | return self._mean(lambda m: m.avg_epoch) 98 | 99 | @property 100 | def best_epoch(self): 101 | return self._mean(lambda m: m.best_epoch) 102 | 103 | @property 104 | def last(self): 105 | return self._mean(lambda m: m.last) 106 | 107 | @property 108 | def top(self): 109 | return self._mean(lambda m: 
m.top) 110 | 111 | @property 112 | def improved(self): 113 | return np.any([metric.improved for metric in self.metrics]) 114 | 115 | 116 | class DummyMetric: 117 | @property 118 | def avg_epoch(self): 119 | return np.nan 120 | 121 | @property 122 | def best_epoch(self): 123 | return np.nan 124 | 125 | @property 126 | def last(self): 127 | return np.nan 128 | 129 | @property 130 | def top(self): 131 | return np.nan 132 | 133 | @property 134 | def improved(self): 135 | return False 136 | 137 | @property 138 | def metrics(self): 139 | return [] 140 | 141 | 142 | class Stage(Enum): 143 | TRAIN = 0 144 | EVAL_SIDE = 1 145 | EVAL_FRWD = 2 146 | EVAL_SIDE_EMA = 3 147 | EVAL_FRWD_EMA = 4 148 | 149 | 150 | class ModelTrainerV2: 151 | def __init__(self, train_model: Model, eval: List[Tuple[Stage, Model]], model_no=0, 152 | patience=None, stop_metric=None, summary_writer=None): 153 | self.train_model = train_model 154 | if eval: 155 | self.eval_stages, self.eval_models = zip(*eval) 156 | else: 157 | self.eval_stages, self.eval_models = [], [] 158 | self.stopped = False 159 | self.model_no = model_no 160 | self.patience = patience 161 | self.best_metric = np.inf 162 | self.bad_epochs = 0 163 | self.stop_metric = stop_metric 164 | self.summary_writer = summary_writer 165 | 166 | def std_metrics(model: Model, smoothness): 167 | return [Metric('SMAPE', model.smape, smoothness), Metric('MAE', model.mae, smoothness)] 168 | 169 | self._metrics = {Stage.TRAIN: std_metrics(train_model, 0.9) + [Metric('GrNorm', train_model.glob_norm)]} 170 | for stage, model in eval: 171 | self._metrics[stage] = std_metrics(model, None) 172 | self.dict_metrics = {key: {metric.name: metric for metric in metrics} for key, metrics in self._metrics.items()} 173 | 174 | def init(self, sess): 175 | for model in list(self.eval_models) + [self.train_model]: 176 | model.inp.init_iterator(sess) 177 | 178 | @property 179 | def metrics(self): 180 | return self._metrics 181 | 182 | @property 183 | def train_ops(self): 184 | model = self.train_model 185 | return [model.train_op] # , model.summaries 186 | 187 | def metric_ops(self, key): 188 | return [metric.op for metric in self._metrics[key]] 189 | 190 | def process_metrics(self, key, run_results, epoch, step): 191 | metrics = self._metrics[key] 192 | summaries = [] 193 | for result, metric in zip(run_results, metrics): 194 | metric.update(result, epoch, step) 195 | summaries.append(tf.Summary.Value(tag=f"{key.name}/{metric.name}_0", simple_value=result)) 196 | return summaries 197 | 198 | def end_epoch(self): 199 | if self.stop_metric: 200 | best_metric = self.stop_metric(self.dict_metrics)# self.dict_metrics[Stage.EVAL_FRWD]['SMAPE'].avg_epoch 201 | if self.best_metric > best_metric: 202 | self.best_metric = best_metric 203 | self.bad_epochs = 0 204 | else: 205 | self.bad_epochs += 1 206 | if self.bad_epochs > self.patience: 207 | self.stopped = True 208 | 209 | 210 | class MultiModelTrainer: 211 | def __init__(self, trainers: List[ModelTrainerV2], inc_step_op, 212 | misc_global_ops=None): 213 | self.trainers = trainers 214 | self.inc_step = inc_step_op 215 | self.global_ops = misc_global_ops or [] 216 | self.eval_stages = trainers[0].eval_stages 217 | 218 | def active(self): 219 | return [trainer for trainer in self.trainers if not trainer.stopped] 220 | 221 | def _metric_step(self, stage, initial_ops, sess: tf.Session, epoch: int, step=None, repeats=1, summary_every=1): 222 | ops = initial_ops 223 | offsets, lengths = [], [] 224 | trainers = self.active() 225 | for trainer in trainers: 
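# Descriptive note: metric ops from every active trainer are appended to one flat `ops` list; `offsets`
# and `lengths` record where each trainer's slice starts and how long it is, so the single (possibly
# repeated and averaged) sess.run() below can be split back into per-trainer chunks for process_metrics().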
226 | offsets.append(len(ops)) 227 | metric_ops = trainer.metric_ops(stage) 228 | lengths.append(len(metric_ops)) 229 | ops.extend(metric_ops) 230 | if repeats > 1: 231 | all_results = np.stack([np.array(sess.run(ops)) for _ in range(repeats)]) 232 | results = np.mean(all_results, axis=0) 233 | else: 234 | results = sess.run(ops) 235 | if step is None: 236 | step = results[0] 237 | 238 | for trainer, offset, length in zip(trainers, offsets, lengths): 239 | chunk = results[offset: offset + length] 240 | summaries = trainer.process_metrics(stage, chunk, epoch, step) 241 | if trainer.summary_writer and step > 200 and (step % summary_every == 0): 242 | summary = tf.Summary(value=summaries) 243 | trainer.summary_writer.add_summary(summary, global_step=step) 244 | return results 245 | 246 | def train_step(self, sess: tf.Session, epoch: int): 247 | ops = [self.inc_step] + self.global_ops 248 | for trainer in self.active(): 249 | ops.extend(trainer.train_ops) 250 | results = self._metric_step(Stage.TRAIN, ops, sess, epoch, summary_every=20) 251 | #return results[:len(self.global_ops) + 1] # step, grad_norm 252 | return results[0] 253 | 254 | def eval_step(self, sess: tf.Session, epoch: int, step, n_batches, stages:List[Stage]=None): 255 | target_stages = stages if stages is not None else self.eval_stages 256 | for stage in target_stages: 257 | self._metric_step(stage, [], sess, epoch, step, repeats=n_batches) 258 | 259 | def metric(self, stage, name): 260 | return AggMetric([trainer.dict_metrics[stage][name] for trainer in self.trainers]) 261 | 262 | def end_epoch(self): 263 | for trainer in self.active(): 264 | trainer.end_epoch() 265 | 266 | def has_active(self): 267 | return len(self.active()) 268 | 269 | 270 | class ModelTrainer: 271 | def __init__(self, train_model, eval_model, model_no=0, summary_writer=None, keep_best=5, patience=None): 272 | self.train_model = train_model 273 | self.eval_model = eval_model 274 | self.stopped = False 275 | self.smooth_train_mae = Ema() 276 | self.smooth_train_smape = Ema() 277 | self.smooth_eval_mae = Ema(0.5) 278 | self.smooth_eval_smape = Ema(0.5) 279 | self.smooth_grad = Ema(0.9) 280 | self.summary_writer = summary_writer 281 | self.model_no = model_no 282 | self.best_top_n_loss = [] 283 | self.keep_best = keep_best 284 | self.best_step = 0 285 | self.patience = patience 286 | self.train_pipe = train_model.inp 287 | self.eval_pipe = eval_model.inp 288 | self.epoch_mae = [] 289 | self.epoch_smape = [] 290 | self.last_epoch = -1 291 | 292 | @property 293 | def train_ops(self): 294 | model = self.train_model 295 | return [model.train_op, model.update_ema, model.summaries, model.mae, model.smape, model.glob_norm] 296 | 297 | def process_train_results(self, run_results, offset, global_step, write_summary): 298 | offset += 2 299 | summaries, mae, smape, glob_norm = run_results[offset:offset + 4] 300 | results = self.smooth_train_mae(mae), self.smooth_train_smape(smape), self.smooth_grad(glob_norm) 301 | if self.summary_writer and write_summary: 302 | self.summary_writer.add_summary(summaries, global_step=global_step) 303 | return np.array(results) 304 | 305 | @property 306 | def eval_ops(self): 307 | model = self.eval_model 308 | return [model.mae, model.smape] 309 | 310 | @property 311 | def eval_len(self): 312 | return len(self.eval_ops) 313 | 314 | @property 315 | def train_len(self): 316 | return len(self.train_ops) 317 | 318 | @property 319 | def best_top_loss(self): 320 | return -np.array(self.best_top_n_loss).mean() 321 | 322 | @property 323 | def 
best_epoch_mae(self): 324 | return min(self.epoch_mae) if self.epoch_mae else np.NaN 325 | 326 | @property 327 | def mean_epoch_mae(self): 328 | return np.mean(self.epoch_mae) if self.epoch_mae else np.NaN 329 | 330 | @property 331 | def mean_epoch_smape(self): 332 | return np.mean(self.epoch_smape) if self.epoch_smape else np.NaN 333 | 334 | @property 335 | def best_epoch_smape(self): 336 | return min(self.epoch_smape) if self.epoch_smape else np.NaN 337 | 338 | def remember_for_epoch(self, epoch, mae, smape): 339 | if epoch > self.last_epoch: 340 | self.last_epoch = epoch 341 | self.epoch_mae = [] 342 | self.epoch_smape = [] 343 | self.epoch_mae.append(mae) 344 | self.epoch_smape.append(smape) 345 | 346 | @property 347 | def best_epoch_metrics(self): 348 | return np.array([self.best_epoch_mae, self.best_epoch_smape]) 349 | 350 | @property 351 | def mean_epoch_metrics(self): 352 | return np.array([self.mean_epoch_mae, self.mean_epoch_smape]) 353 | 354 | def process_eval_results(self, run_results, offset, global_step, epoch): 355 | totals = np.zeros(self.eval_len, np.float) 356 | for result in run_results: 357 | items = np.array(result[offset:offset + self.eval_len]) 358 | totals += items 359 | results = totals / len(run_results) 360 | mae, smape = results 361 | if self.summary_writer and global_step > 200: 362 | summary = tf.Summary(value=[ 363 | tf.Summary.Value(tag=f"test/MAE_{self.model_no}", simple_value=mae), 364 | tf.Summary.Value(tag=f"test/SMAPE_{self.model_no}", simple_value=smape), 365 | ]) 366 | self.summary_writer.add_summary(summary, global_step=global_step) 367 | smooth_mae = self.smooth_eval_mae(mae) 368 | smooth_smape = self.smooth_eval_smape(smape) 369 | self.remember_for_epoch(epoch, mae, smape) 370 | 371 | current_loss = -smooth_smape 372 | 373 | prev_best_n = np.mean(self.best_top_n_loss) if self.best_top_n_loss else -np.inf 374 | if self.best_top_n_loss: 375 | log.debug("Current loss=%.3f, old best=%.3f, wait steps=%d", -current_loss, 376 | -max(self.best_top_n_loss), global_step - self.best_step) 377 | 378 | if len(self.best_top_n_loss) >= self.keep_best: 379 | heapq.heappushpop(self.best_top_n_loss, current_loss) 380 | else: 381 | heapq.heappush(self.best_top_n_loss, current_loss) 382 | log.debug("Best loss=%.3f, top_5 avg loss=%.3f, top_5=%s", 383 | -max(self.best_top_n_loss), -np.mean(self.best_top_n_loss), 384 | ",".join(["%.3f" % -mae for mae in self.best_top_n_loss])) 385 | new_best_n = np.mean(self.best_top_n_loss) 386 | 387 | new_best = new_best_n > prev_best_n 388 | if new_best: 389 | self.best_step = global_step 390 | log.debug("New best step %d, current loss=%.3f", global_step, -current_loss) 391 | else: 392 | step_count = global_step - self.best_step 393 | if step_count > self.patience: 394 | self.stopped = True 395 | 396 | return mae, smape, new_best, smooth_mae, smooth_smape 397 | 398 | 399 | def train(name, hparams, multi_gpu=False, n_models=1, train_completeness_threshold=0.01, 400 | seed=None, logdir='data/logs', max_epoch=100, patience=2, train_sampling=1.0, 401 | eval_sampling=1.0, eval_memsize=5, gpu=0, gpu_allow_growth=False, save_best_model=False, 402 | forward_split=False, write_summaries=False, verbose=False, asgd_decay=None, tqdm=True, 403 | side_split=True, max_steps=None, save_from_step=None, do_eval=True, predict_window=63): 404 | 405 | eval_k = int(round(26214 * eval_memsize / n_models)) 406 | eval_batch_size = int( 407 | eval_k / (hparams.rnn_depth * hparams.encoder_rnn_layers)) # 128 -> 1024, 256->512, 512->256 408 | eval_pct = 0.1 409 
| batch_size = hparams.batch_size 410 | train_window = hparams.train_window 411 | tf.reset_default_graph() 412 | if seed: 413 | tf.set_random_seed(seed) 414 | 415 | with tf.device("/cpu:0"): 416 | inp = VarFeeder.read_vars("data/vars") 417 | if side_split: 418 | splitter = Splitter(page_features(inp), inp.page_map, 3, train_sampling=train_sampling, 419 | test_sampling=eval_sampling, seed=seed) 420 | else: 421 | splitter = FakeSplitter(page_features(inp), 3, seed=seed, test_sampling=eval_sampling) 422 | 423 | real_train_pages = splitter.splits[0].train_size 424 | real_eval_pages = splitter.splits[0].test_size 425 | 426 | items_per_eval = real_eval_pages * eval_pct 427 | eval_batches = int(np.ceil(items_per_eval / eval_batch_size)) 428 | steps_per_epoch = real_train_pages // batch_size 429 | eval_every_step = int(round(steps_per_epoch * eval_pct)) 430 | # eval_every_step = int(round(items_per_eval * train_sampling / batch_size)) 431 | 432 | global_step = tf.train.get_or_create_global_step() 433 | inc_step = tf.assign_add(global_step, 1) 434 | 435 | 436 | all_models: List[ModelTrainerV2] = [] 437 | 438 | def create_model(scope, index, prefix, seed): 439 | 440 | with tf.variable_scope('input') as inp_scope: 441 | with tf.device("/cpu:0"): 442 | split = splitter.splits[index] 443 | pipe = InputPipe(inp, features=split.train_set, n_pages=split.train_size, 444 | mode=ModelMode.TRAIN, batch_size=batch_size, n_epoch=None, verbose=verbose, 445 | train_completeness_threshold=train_completeness_threshold, 446 | predict_completeness_threshold=train_completeness_threshold, train_window=train_window, 447 | predict_window=predict_window, 448 | rand_seed=seed, train_skip_first=hparams.train_skip_first, 449 | back_offset=predict_window if forward_split else 0) 450 | inp_scope.reuse_variables() 451 | if side_split: 452 | side_eval_pipe = InputPipe(inp, features=split.test_set, n_pages=split.test_size, 453 | mode=ModelMode.EVAL, batch_size=eval_batch_size, n_epoch=None, 454 | verbose=verbose, predict_window=predict_window, 455 | train_completeness_threshold=0.01, predict_completeness_threshold=0, 456 | train_window=train_window, rand_seed=seed, runs_in_burst=eval_batches, 457 | back_offset=predict_window * (2 if forward_split else 1)) 458 | else: 459 | side_eval_pipe = None 460 | if forward_split: 461 | forward_eval_pipe = InputPipe(inp, features=split.test_set, n_pages=split.test_size, 462 | mode=ModelMode.EVAL, batch_size=eval_batch_size, n_epoch=None, 463 | verbose=verbose, predict_window=predict_window, 464 | train_completeness_threshold=0.01, predict_completeness_threshold=0, 465 | train_window=train_window, rand_seed=seed, runs_in_burst=eval_batches, 466 | back_offset=predict_window) 467 | else: 468 | forward_eval_pipe = None 469 | avg_sgd = asgd_decay is not None 470 | #asgd_decay = 0.99 if avg_sgd else None 471 | train_model = Model(pipe, hparams, is_train=True, graph_prefix=prefix, asgd_decay=asgd_decay, seed=seed) 472 | scope.reuse_variables() 473 | 474 | eval_stages = [] 475 | if side_split: 476 | side_eval_model = Model(side_eval_pipe, hparams, is_train=False, 477 | #loss_mask=np.concatenate([np.zeros(50, dtype=np.float32), np.ones(10, dtype=np.float32)]), 478 | seed=seed) 479 | eval_stages.append((Stage.EVAL_SIDE, side_eval_model)) 480 | if avg_sgd: 481 | eval_stages.append((Stage.EVAL_SIDE_EMA, side_eval_model)) 482 | if forward_split: 483 | forward_eval_model = Model(forward_eval_pipe, hparams, is_train=False, seed=seed) 484 | eval_stages.append((Stage.EVAL_FRWD, forward_eval_model)) 485 | if 
avg_sgd: 486 | eval_stages.append((Stage.EVAL_FRWD_EMA, forward_eval_model)) 487 | 488 | if write_summaries: 489 | summ_path = f"{logdir}/{name}_{index}" 490 | if os.path.exists(summ_path): 491 | shutil.rmtree(summ_path) 492 | summ_writer = tf.summary.FileWriter(summ_path) # , graph=tf.get_default_graph() 493 | else: 494 | summ_writer = None 495 | if do_eval and forward_split: 496 | stop_metric = lambda metrics: metrics[Stage.EVAL_FRWD]['SMAPE'].avg_epoch 497 | else: 498 | stop_metric = None 499 | return ModelTrainerV2(train_model, eval_stages, index, patience=patience, 500 | stop_metric=stop_metric, 501 | summary_writer=summ_writer) 502 | 503 | 504 | if n_models == 1: 505 | with tf.device(f"/gpu:{gpu}"): 506 | scope = tf.get_variable_scope() 507 | all_models = [create_model(scope, 0, None, seed=seed)] 508 | else: 509 | for i in range(n_models): 510 | device = f"/gpu:{i}" if multi_gpu else f"/gpu:{gpu}" 511 | with tf.device(device): 512 | prefix = f"m_{i}" 513 | with tf.variable_scope(prefix) as scope: 514 | all_models.append(create_model(scope, i, prefix=prefix, seed=seed + i)) 515 | trainer = MultiModelTrainer(all_models, inc_step) 516 | if save_best_model or save_from_step: 517 | saver_path = f'data/cpt/{name}' 518 | if os.path.exists(saver_path): 519 | shutil.rmtree(saver_path) 520 | os.makedirs(saver_path) 521 | saver = tf.train.Saver(max_to_keep=10, name='train_saver') 522 | else: 523 | saver = None 524 | avg_sgd = asgd_decay is not None 525 | if avg_sgd: 526 | from itertools import chain 527 | def ema_vars(model): 528 | ema = model.train_model.ema 529 | return {ema.average_name(v):v for v in model.train_model.ema._averages} 530 | 531 | ema_names = dict(chain(*[ema_vars(model).items() for model in all_models])) 532 | #ema_names = all_models[0].train_model.ema.variables_to_restore() 533 | ema_loader = tf.train.Saver(var_list=ema_names, max_to_keep=1, name='ema_loader') 534 | ema_saver = tf.train.Saver(max_to_keep=1, name='ema_saver') 535 | else: 536 | ema_loader = None 537 | 538 | init = tf.global_variables_initializer() 539 | 540 | if forward_split and do_eval: 541 | eval_smape = trainer.metric(Stage.EVAL_FRWD, 'SMAPE') 542 | eval_mae = trainer.metric(Stage.EVAL_FRWD, 'MAE') 543 | else: 544 | eval_smape = DummyMetric() 545 | eval_mae = DummyMetric() 546 | 547 | if side_split and do_eval: 548 | eval_mae_side = trainer.metric(Stage.EVAL_SIDE, 'MAE') 549 | eval_smape_side = trainer.metric(Stage.EVAL_SIDE, 'SMAPE') 550 | else: 551 | eval_mae_side = DummyMetric() 552 | eval_smape_side = DummyMetric() 553 | 554 | train_smape = trainer.metric(Stage.TRAIN, 'SMAPE') 555 | train_mae = trainer.metric(Stage.TRAIN, 'MAE') 556 | grad_norm = trainer.metric(Stage.TRAIN, 'GrNorm') 557 | eval_stages = [] 558 | ema_eval_stages = [] 559 | if forward_split and do_eval: 560 | eval_stages.append(Stage.EVAL_FRWD) 561 | ema_eval_stages.append(Stage.EVAL_FRWD_EMA) 562 | if side_split and do_eval: 563 | eval_stages.append(Stage.EVAL_SIDE) 564 | ema_eval_stages.append(Stage.EVAL_SIDE_EMA) 565 | 566 | # gpu_options=tf.GPUOptions(allow_growth=False), 567 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, 568 | gpu_options=tf.GPUOptions(allow_growth=gpu_allow_growth))) as sess: 569 | sess.run(init) 570 | # pipe.load_vars(sess) 571 | inp.restore(sess) 572 | for model in all_models: 573 | model.init(sess) 574 | # if beholder: 575 | # visualizer = Beholder(session=sess, logdir=summ_path) 576 | step = 0 577 | prev_top = np.inf 578 | best_smape = np.inf 579 | # Contains best value (first item) and 
subsequent values 580 | best_epoch_smape = [] 581 | 582 | for epoch in range(max_epoch): 583 | 584 | # n_steps = pusher.n_pages // batch_size 585 | if tqdm: 586 | tqr = trange(steps_per_epoch, desc="%2d" % (epoch + 1), leave=False) 587 | else: 588 | tqr = range(steps_per_epoch) 589 | 590 | for _ in tqr: 591 | try: 592 | step = trainer.train_step(sess, epoch) 593 | except tf.errors.OutOfRangeError: 594 | break 595 | # if beholder: 596 | # if step % 5 == 0: 597 | # noinspection PyUnboundLocalVariable 598 | # visualizer.update() 599 | if step % eval_every_step == 0: 600 | if eval_stages: 601 | trainer.eval_step(sess, epoch, step, eval_batches, stages=eval_stages) 602 | 603 | if save_best_model and epoch > 0 and eval_smape.last < best_smape: 604 | best_smape = eval_smape.last 605 | saver.save(sess, f'data/cpt/{name}/cpt', global_step=step) 606 | if save_from_step and step >= save_from_step: 607 | saver.save(sess, f'data/cpt/{name}/cpt', global_step=step) 608 | 609 | if avg_sgd and ema_eval_stages: 610 | ema_saver.save(sess, 'data/cpt_tmp/ema', write_meta_graph=False) 611 | # restore ema-backed vars 612 | ema_loader.restore(sess, 'data/cpt_tmp/ema') 613 | 614 | trainer.eval_step(sess, epoch, step, eval_batches, stages=ema_eval_stages) 615 | # restore normal vars 616 | ema_saver.restore(sess, 'data/cpt_tmp/ema') 617 | 618 | MAE = "%.3f/%.3f/%.3f" % (eval_mae.last, eval_mae_side.last, train_mae.last) 619 | improvement = '↑' if eval_smape.improved else ' ' 620 | SMAPE = "%s%.3f/%.3f/%.3f" % (improvement, eval_smape.last, eval_smape_side.last, train_smape.last) 621 | if tqdm: 622 | tqr.set_postfix(gr=grad_norm.last, MAE=MAE, SMAPE=SMAPE) 623 | if not trainer.has_active() or (max_steps and step > max_steps): 624 | break 625 | 626 | if tqdm: 627 | tqr.close() 628 | trainer.end_epoch() 629 | if not best_epoch_smape or eval_smape.avg_epoch < best_epoch_smape[0]: 630 | best_epoch_smape = [eval_smape.avg_epoch] 631 | else: 632 | best_epoch_smape.append(eval_smape.avg_epoch) 633 | 634 | current_top = eval_smape.top 635 | if prev_top > current_top: 636 | prev_top = current_top 637 | has_best_indicator = '↑' 638 | else: 639 | has_best_indicator = ' ' 640 | status = "%2d: Best top SMAPE=%.3f%s (%s)" % ( 641 | epoch + 1, current_top, has_best_indicator, 642 | ",".join(["%.3f" % m.top for m in eval_smape.metrics])) 643 | 644 | if trainer.has_active(): 645 | status += ", frwd/side best MAE=%.3f/%.3f, SMAPE=%.3f/%.3f; avg MAE=%.3f/%.3f, SMAPE=%.3f/%.3f, %d am" % \ 646 | (eval_mae.best_epoch, eval_mae_side.best_epoch, eval_smape.best_epoch, eval_smape_side.best_epoch, 647 | eval_mae.avg_epoch, eval_mae_side.avg_epoch, eval_smape.avg_epoch, eval_smape_side.avg_epoch, 648 | trainer.has_active()) 649 | print(status, file=sys.stderr) 650 | else: 651 | print(status, file=sys.stderr) 652 | print("Early stopping!", file=sys.stderr) 653 | break 654 | if max_steps and step > max_steps: 655 | print("Max steps calculated", file=sys.stderr) 656 | break 657 | sys.stderr.flush() 658 | 659 | # noinspection PyUnboundLocalVariable 660 | return np.mean(best_epoch_smape, dtype=np.float64) 661 | 662 | 663 | def predict(checkpoints, hparams, return_x=False, verbose=False, predict_window=6, back_offset=0, n_models=1, 664 | target_model=0, asgd=False, seed=1, batch_size=1024): 665 | with tf.variable_scope('input') as inp_scope: 666 | with tf.device("/cpu:0"): 667 | inp = VarFeeder.read_vars("data/vars") 668 | pipe = InputPipe(inp, page_features(inp), inp.n_pages, mode=ModelMode.PREDICT, batch_size=batch_size, 669 | n_epoch=1, 
verbose=verbose, 670 | train_completeness_threshold=0.01, 671 | predict_window=predict_window, 672 | predict_completeness_threshold=0.0, train_window=hparams.train_window, 673 | back_offset=back_offset) 674 | asgd_decay = 0.99 if asgd else None 675 | if n_models == 1: 676 | model = Model(pipe, hparams, is_train=False, seed=seed, asgd_decay=asgd_decay) 677 | else: 678 | models = [] 679 | for i in range(n_models): 680 | prefix = f"m_{i}" 681 | with tf.variable_scope(prefix) as scope: 682 | models.append(Model(pipe, hparams, is_train=False, seed=seed, asgd_decay=asgd_decay, graph_prefix=prefix)) 683 | model = models[target_model] 684 | 685 | if asgd: 686 | var_list = model.ema.variables_to_restore() 687 | prefix = f"m_{target_model}" 688 | for var in list(var_list.keys()): 689 | if var.endswith('ExponentialMovingAverage') and not var.startswith(prefix): 690 | del var_list[var] 691 | else: 692 | var_list = None 693 | saver = tf.train.Saver(name='eval_saver', var_list=var_list) 694 | x_buffer = [] 695 | predictions = None 696 | with tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))) as sess: 697 | pipe.load_vars(sess) 698 | for checkpoint in checkpoints: 699 | pred_buffer = [] 700 | pipe.init_iterator(sess) 701 | saver.restore(sess, checkpoint) 702 | cnt = 0 703 | while True: 704 | try: 705 | if return_x: 706 | pred, x, pname = sess.run([model.predictions, model.inp.true_x, model.inp.page_ix]) 707 | else: 708 | pred, pname = sess.run([model.predictions, model.inp.page_ix]) 709 | utf_names = [str(name, 'utf-8') for name in pname] 710 | pred_df = pd.DataFrame(index=utf_names, data=np.expm1(pred)) 711 | pred_buffer.append(pred_df) 712 | if return_x: 713 | # noinspection PyUnboundLocalVariable 714 | x_values = pd.DataFrame(index=utf_names, data=np.round(np.expm1(x)).astype(np.int64)) 715 | x_buffer.append(x_values) 716 | newline = cnt % 80 == 0 717 | if cnt > 0: 718 | print('.', end='\n' if newline else '', flush=True) 719 | if newline: 720 | print(cnt, end='') 721 | cnt += 1 722 | except tf.errors.OutOfRangeError: 723 | print('🎉') 724 | break 725 | cp_predictions = pd.concat(pred_buffer) 726 | if predictions is None: 727 | predictions = cp_predictions 728 | else: 729 | predictions += cp_predictions 730 | predictions /= len(checkpoints) 731 | offset = pd.Timedelta(back_offset, 'D') 732 | start_prediction = inp.data_end + pd.Timedelta('1D') - offset 733 | end_prediction = start_prediction + pd.Timedelta(predict_window - 1, 'D') 734 | predictions.columns = pd.date_range(start_prediction, end_prediction) 735 | if return_x: 736 | x = pd.concat(x_buffer) 737 | start_data = inp.data_end - pd.Timedelta(hparams.train_window - 1, 'D') - back_offset 738 | end_data = inp.data_end - back_offset 739 | x.columns = pd.date_range(start_data, end_data) 740 | return predictions, x 741 | else: 742 | return predictions 743 | 744 | 745 | if __name__ == '__main__': 746 | parser = argparse.ArgumentParser(description='Train the model') 747 | parser.add_argument('--name', default='s32', help='Model name to identify different logs/checkpoints') 748 | parser.add_argument('--hparam_set', default='s32', help="Hyperparameters set to use (see hparams.py for available sets)") 749 | parser.add_argument('--n_models', default=1, type=int, help="Jointly train n models with different seeds") 750 | parser.add_argument('--multi_gpu', default=False, action='store_true', help="Use multiple GPUs for multi-model training, one GPU per model") 751 | parser.add_argument('--seed', default=5, type=int, help="Random 
seed") 752 | parser.add_argument('--logdir', default='data/logs', help="Directory for summary logs") 753 | parser.add_argument('--max_epoch', type=int, default=100, help="Max number of epochs") 754 | parser.add_argument('--patience', type=int, default=2, help="Early stopping: stop after N epochs without improvement. Requires do_eval=True") 755 | parser.add_argument('--train_sampling', type=float, default=1.0, help="Sample this percent of data for training") 756 | parser.add_argument('--eval_sampling', type=float, default=1.0, help="Sample this percent of data for evaluation") 757 | parser.add_argument('--eval_memsize', type=int, default=5, help="Approximate amount of avalable memory on GPU, used for calculation of optimal evaluation batch size") 758 | parser.add_argument('--gpu', default=0, type=int, help='GPU instance to use') 759 | parser.add_argument('--gpu_allow_growth', default=False, action='store_true', help='Allow to gradually increase GPU memory usage instead of grabbing all available memory at start') 760 | parser.add_argument('--save_best_model', default=False, action='store_true', help='Save best model during training. Requires do_eval=True') 761 | parser.add_argument('--no_forward_split', default=True, dest='forward_split', action='store_false', help='Use walk-forward split for model evaluation. Requires do_eval=True') 762 | parser.add_argument('--side_split', default=False, action='store_true', help='Use side split for model evaluation. Requires do_eval=True') 763 | parser.add_argument('--no_eval', default=True, dest='do_eval', action='store_false', help="Don't evaluate model quality during training") 764 | parser.add_argument('--no_summaries', default=True, dest='write_summaries', action='store_false', help="Don't Write Tensorflow summaries") 765 | parser.add_argument('--verbose', default=False, action='store_true', help='Print additional information during graph construction') 766 | parser.add_argument('--asgd_decay', type=float, help="EMA decay for averaged SGD. Not use ASGD if not set") 767 | parser.add_argument('--no_tqdm', default=True, dest='tqdm', action='store_false', help="Don't use tqdm for status display during training") 768 | parser.add_argument('--max_steps', type=int, help="Stop training after max steps") 769 | parser.add_argument('--save_from_step', type=int, help="Save model on each evaluation (10 evals per epoch), starting from this step") 770 | parser.add_argument('--predict_window', default=63, type=int, help="Number of days to predict") 771 | args = parser.parse_args() 772 | 773 | param_dict = dict(vars(args)) 774 | param_dict['hparams'] = build_from_set(args.hparam_set) 775 | del param_dict['hparam_set'] 776 | train(**param_dict) 777 | 778 | # hparams = build_hparams() 779 | # result = train("definc_attn", hparams, n_models=1, train_sampling=1.0, eval_sampling=1.0, patience=5, multi_gpu=True, 780 | # save_best_model=False, gpu=0, eval_memsize=15, seed=5, verbose=True, forward_split=False, 781 | # write_summaries=True, side_split=True, do_eval=False, predict_window=63, asgd_decay=None, max_steps=11500, 782 | # save_from_step=10500) 783 | 784 | # print("Training result:", result) 785 | # preds = predict('data/cpt/fair_365-15428', 380, hparams, verbose=True, back_offset=60, n_models=3) 786 | # print(preds) 787 | --------------------------------------------------------------------------------