├── .gitignore
├── Changelog.md
├── LICENSE
├── Readme.md
├── cocob.py
├── data
│   ├── 2017-08-15_2017-09-11.csv.zip
│   └── placeholder
├── extractor.py
├── feeder.py
├── how_it_works.md
├── hparams.py
├── images
│   ├── attention.xml
│   ├── autocorr.png
│   ├── encoder-decoder.png
│   ├── encoder-decoder.xml
│   ├── from_past.png
│   ├── lagged_data.png
│   ├── losses_0.png
│   ├── losses_1.png
│   ├── predictions.png
│   ├── split.png
│   ├── training.png
│   └── validation-split.xml
├── input_pipe.py
├── make_features.py
├── model.py
├── requirements.txt
├── submission-final.ipynb
└── trainer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/*
2 |
3 | # Jupyter Notebook
4 | .ipynb_checkpoints
5 |
6 | data/cpt
7 | data/logs
8 | data/vars
9 | data/*.pkl
10 | data/*.zip
11 | data/submission.csv.gz
12 | !data/2017-08-15_2017-09-11.csv.zip
13 |
14 |
--------------------------------------------------------------------------------
/Changelog.md:
--------------------------------------------------------------------------------
1 | 2018-10-15
2 | - Model updated to work with a modern Tensorflow (>=1.10)
3 | - Switched to Adam instead of COCOB (COCOB doesn't work with TF > 1.4)
4 | - No parameter tuning was performed for Adam, therefore the model probably has
5 | a suboptimal learning rate and didn't reproduce the exact result from the competition
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Artur Suilin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | # Kaggle Web Traffic Time Series Forecasting
2 | 1st place solution
3 |
4 | 
5 |
6 | Main files:
7 | * `make_features.py` - builds features from source data
8 | * `input_pipe.py` - TF data preprocessing pipeline (assembles features
9 | into training/evaluation tensors, performs some sampling and normalisation)
10 | * `model.py` - the model
11 | * `trainer.py` - trains the model(s)
12 | * `hparams.py` - hyperparameter sets.
13 | * `submission-final.ipynb` - generates predictions for submission
14 |
15 | How to reproduce competition results:
16 | 1. Download input files from https://www.kaggle.com/c/web-traffic-time-series-forecasting/data :
17 | `key_2.csv.zip`, `train_2.csv.zip`, put them into `data` directory.
18 | 2. Run `python make_features.py data/vars --add_days=63`. It will
19 | extract data and features from the input files and put them into
20 | `data/vars` as a Tensorflow checkpoint.
21 | 3. Run trainer:
22 | `python trainer.py --name s32 --hparam_set=s32 --n_models=3 --no_eval --no_forward_split
23 | --asgd_decay=0.99 --max_steps=11500 --save_from_step=10500`. This command
24 | will simultaneously train 3 models on different seeds (on a single TF graph)
25 | and save 10 checkpoints from step 10500 to step 11500 to `data/cpt`.
26 | __Note:__ training requires a GPU, because of cuDNN usage. CPU training will not work.
27 | If you have 3 or more GPUs, add the `--multi_gpu` flag to speed up the training. One can also try different
28 | hyperparameter sets (described in `hparams.py`): `--hparam_set=definc`,
29 | `--hparam_set=inst81`, etc.
30 | Don't be afraid of the NaN losses displayed during training. This is normal,
31 | because we train in blind mode, without any evaluation of model performance.
32 | 4. Run `submission-final.ipynb` in a standard jupyter notebook environment,
33 | execute all cells. Prediction will take some time, because it has to
34 | load and evaluate 30 different model weights. At the end,
35 | you'll get the `submission.csv.gz` file in the `data` directory.
36 |
37 | See also [detailed model description](how_it_works.md)
38 |
--------------------------------------------------------------------------------
/cocob.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Francesco Orabona. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 |
17 | """
18 | COntinuous COin Betting (COCOB) optimizer
19 | See 'Training Deep Networks without Learning Rates Through Coin Betting'
20 | https://arxiv.org/abs/1705.07795
21 | """
22 |
23 | from tensorflow.python.framework import ops
24 | from tensorflow.python.ops import state_ops
25 | from tensorflow.python.ops import control_flow_ops
26 | from tensorflow.python.framework import constant_op
27 | from tensorflow.python.training.optimizer import Optimizer
28 | import tensorflow as tf
29 |
30 |
31 |
32 | class COCOB(Optimizer):
33 | def __init__(self, alpha=100, use_locking=False, name='COCOB'):
34 | '''
35 | constructs a new COCOB optimizer
36 | '''
37 | super(COCOB, self).__init__(use_locking, name)
38 | self._alpha = alpha
39 |
40 | def _create_slots(self, var_list):
41 | for v in var_list:
42 | with ops.colocate_with(v):
43 | gradients_sum = constant_op.constant(0,
44 | shape=v.get_shape(),
45 | dtype=v.dtype.base_dtype)
46 | grad_norm_sum = constant_op.constant(0,
47 | shape=v.get_shape(),
48 | dtype=v.dtype.base_dtype)
49 | L = constant_op.constant(1e-8, shape=v.get_shape(), dtype=v.dtype.base_dtype)
50 | tilde_w = constant_op.constant(0.0, shape=v.get_shape(), dtype=v.dtype.base_dtype)
51 | reward = constant_op.constant(0.0, shape=v.get_shape(), dtype=v.dtype.base_dtype)
52 |
53 | self._get_or_make_slot(v, L, "L", self._name)
54 | self._get_or_make_slot(v, grad_norm_sum, "grad_norm_sum", self._name)
55 | self._get_or_make_slot(v, gradients_sum, "gradients_sum", self._name)
56 | self._get_or_make_slot(v, tilde_w, "tilde_w", self._name)
57 | self._get_or_make_slot(v, reward, "reward", self._name)
58 |
59 | def _apply_dense(self, grad, var):
60 | gradients_sum = self.get_slot(var, "gradients_sum")
61 | grad_norm_sum = self.get_slot(var, "grad_norm_sum")
62 | tilde_w = self.get_slot(var, "tilde_w")
63 | L = self.get_slot(var, "L")
64 | reward = self.get_slot(var, "reward")
65 |
66 | L_update = tf.maximum(L, tf.abs(grad))
67 | gradients_sum_update = gradients_sum + grad
68 | grad_norm_sum_update = grad_norm_sum + tf.abs(grad)
69 | reward_update = tf.maximum(reward - grad * tilde_w, 0)
70 | new_w = -gradients_sum_update / (
71 | L_update * (tf.maximum(grad_norm_sum_update + L_update, self._alpha * L_update))) * (reward_update + L_update)
72 | var_update = var - tilde_w + new_w
73 | tilde_w_update = new_w
74 |
75 | gradients_sum_update_op = state_ops.assign(gradients_sum, gradients_sum_update)
76 | grad_norm_sum_update_op = state_ops.assign(grad_norm_sum, grad_norm_sum_update)
77 | var_update_op = state_ops.assign(var, var_update)
78 | tilde_w_update_op = state_ops.assign(tilde_w, tilde_w_update)
79 | L_update_op = state_ops.assign(L, L_update)
80 | reward_update_op = state_ops.assign(reward, reward_update)
81 |
82 | return control_flow_ops.group(*[gradients_sum_update_op,
83 | var_update_op,
84 | grad_norm_sum_update_op,
85 | tilde_w_update_op,
86 | reward_update_op,
87 | L_update_op])
88 |
89 | def _apply_sparse(self, grad, var):
90 | return self._apply_dense(grad, var)
91 |
92 | def _resource_apply_dense(self, grad, handle):
93 | return self._apply_dense(grad, handle)
94 |
--------------------------------------------------------------------------------
/data/2017-08-15_2017-09-11.csv.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arturus/kaggle-web-traffic/a9abb80c800409abf0ece21ea244ef779f758f96/data/2017-08-15_2017-09-11.csv.zip
--------------------------------------------------------------------------------
/data/placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arturus/kaggle-web-traffic/a9abb80c800409abf0ece21ea244ef779f758f96/data/placeholder
--------------------------------------------------------------------------------
/extractor.py:
--------------------------------------------------------------------------------
1 | import re
2 | import pandas as pd
3 | import numpy as np
4 |
5 | term_pat = re.compile('(.+?):(.+)')
6 | pat = re.compile(
7 | '(.+)_([a-z][a-z]\.)?((?:wikipedia\.org)|(?:commons\.wikimedia\.org)|(?:www\.mediawiki\.org))_([a-z_-]+?)$')
8 |
9 | # Debug output to ensure pattern still works
10 | # print(pat.fullmatch('BLEACH_zh.wikipedia.org_all-accessspider').groups())
11 | # print(pat.fullmatch('Accueil_commons.wikimedia.org_all-access_spider').groups())
12 |
13 |
14 | def extract(source) -> pd.DataFrame:
15 | """
16 | Extracts features from URLs. Features: agent, site, country, term, marker
17 | :param source: urls
18 | :return: DataFrame, one column per feature
19 | """
20 | if isinstance(source, pd.Series):
21 | source = source.values
22 | agents = np.full_like(source, np.NaN)
23 | sites = np.full_like(source, np.NaN)
24 | countries = np.full_like(source, np.NaN)
25 | terms = np.full_like(source, np.NaN)
26 | markers = np.full_like(source, np.NaN)
27 |
28 | for i in range(len(source)):
29 | l = source[i]
30 | match = pat.fullmatch(l)
31 | assert match, "Non-matched string %s" % l
32 | term = match.group(1)
33 | country = match.group(2)
34 | if country:
35 | countries[i] = country[:-1]
36 | site = match.group(3)
37 | sites[i] = site
38 | agents[i] = match.group(4)
39 | if site != 'wikipedia.org':
40 | term_match = term_pat.match(term)
41 | if term_match:
42 | markers[i] = term_match.group(1)
43 | term = term_match.group(2)
44 | terms[i] = term
45 |
46 | return pd.DataFrame({
47 | 'agent': agents,
48 | 'site': sites,
49 | 'country': countries,
50 | 'term': terms,
51 | 'marker': markers,
52 | 'page': source
53 | })
54 |
--------------------------------------------------------------------------------
/feeder.py:
--------------------------------------------------------------------------------
1 | from collections import UserList, UserDict
2 | from typing import Union, Iterable, Tuple, Dict, Any
3 |
4 | import tensorflow as tf
5 | import numpy as np
6 | import pandas as pd
7 | import pickle
8 | import os.path
9 |
10 |
11 | def _meta_file(path):
12 | return os.path.join(path, 'feeder_meta.pkl')
13 |
14 |
15 | class VarFeeder:
16 | """
17 | Helper to avoid feed_dict and manual batching. Maybe I should have used TFRecords instead.
18 | Builds a temporary TF graph, injects variables into it, and saves the variables to a TF checkpoint.
19 | At train time, variables can be rebuilt by create_vars() and their content restored by FeederVars.restore()
20 | """
21 | def __init__(self, path: str,
22 | tensor_vars: Dict[str, Union[pd.DataFrame, pd.Series, np.ndarray]] = None,
23 | plain_vars: Dict[str, Any] = None):
24 | """
25 | :param path: dir to store data
26 | :param tensor_vars: Variables to save as Tensors (pandas DataFrames/Series or numpy arrays)
27 | :param plain_vars: Variables to save as Python objects
28 | """
29 | tensor_vars = tensor_vars or dict()
30 |
31 | def get_values(v):
32 | v = v.values if hasattr(v, 'values') else v
33 | if not isinstance(v, np.ndarray):
34 | v = np.array(v)
35 | if v.dtype == np.float64:
36 | v = v.astype(np.float32)
37 | return v
38 |
39 | values = [get_values(var) for var in tensor_vars.values()]
40 |
41 | self.shapes = [var.shape for var in values]
42 | self.dtypes = [v.dtype for v in values]
43 | self.names = list(tensor_vars.keys())
44 | self.path = path
45 | self.plain_vars = plain_vars
46 |
47 | if not os.path.exists(path):
48 | os.mkdir(path)
49 |
50 | with open(_meta_file(path), mode='wb') as file:
51 | pickle.dump(self, file)
52 |
53 | with tf.Graph().as_default():
54 | tensor_vars = self._build_vars()
55 | placeholders = [tf.placeholder(tf.as_dtype(dtype), shape=shape) for dtype, shape in
56 | zip(self.dtypes, self.shapes)]
57 | assigners = [tensor_var.assign(placeholder) for tensor_var, placeholder in
58 | zip(tensor_vars, placeholders)]
59 | feed = {ph: v for ph, v in zip(placeholders, values)}
60 | saver = tf.train.Saver(self._var_dict(tensor_vars), max_to_keep=1)
61 | init = tf.global_variables_initializer()
62 |
63 | with tf.Session(config=tf.ConfigProto(device_count={'GPU': 0})) as sess:
64 | sess.run(init)
65 | sess.run(assigners, feed_dict=feed)
66 | save_path = os.path.join(path, 'feeder.cpt')
67 | saver.save(sess, save_path, write_meta_graph=False, write_state=False)
68 |
69 | def _var_dict(self, variables):
70 | return {name: var for name, var in zip(self.names, variables)}
71 |
72 | def _build_vars(self):
73 | def make_tensor(shape, dtype, name):
74 | tf_type = tf.as_dtype(dtype)
75 | if tf_type == tf.string:
76 | empty = ''
77 | elif tf_type == tf.bool:
78 | empty = False
79 | else:
80 | empty = 0
81 | init = tf.constant(empty, shape=shape, dtype=tf_type)
82 | return tf.get_local_variable(name=name, initializer=init, dtype=tf_type)
83 |
84 | with tf.device("/cpu:0"):
85 | with tf.name_scope('feeder_vars'):
86 | return [make_tensor(shape, dtype, name) for shape, dtype, name in
87 | zip(self.shapes, self.dtypes, self.names)]
88 |
89 | def create_vars(self):
90 | """
91 | Builds variable list to use in current graph. Should be called during graph building stage
92 | :return: variable list with additional restore and create_saver methods
93 | """
94 | return FeederVars(self._var_dict(self._build_vars()), self.plain_vars, self.path)
95 |
96 | @staticmethod
97 | def read_vars(path):
98 | with open(_meta_file(path), mode='rb') as file:
99 | feeder = pickle.load(file)
100 | assert feeder.path == path
101 | return feeder.create_vars()
102 |
103 |
104 | class FeederVars(UserDict):
105 | def __init__(self, tensors: dict, plain_vars: dict, path):
106 | variables = dict(tensors)
107 | if plain_vars:
108 | variables.update(plain_vars)
109 | super().__init__(variables)
110 | self.path = path
111 | self.saver = tf.train.Saver(tensors, name='varfeeder_saver')
112 | for var in variables:
113 | if var not in self.__dict__:
114 | self.__dict__[var] = variables[var]
115 |
116 | def restore(self, session):
117 | """
118 | Restores variable content
119 | :param session: current session
120 | :return: variable list
121 | """
122 | self.saver.restore(session, os.path.join(self.path, 'feeder.cpt'))
123 | return self
124 |
--------------------------------------------------------------------------------
/how_it_works.md:
--------------------------------------------------------------------------------
1 | # How it works
2 | __TL;DR__ this is a seq2seq model with some additions to utilize year-to-year
3 | and quarter-to-quarter seasonality in the data.
4 | ___
5 | There are two main information sources for prediction:
6 | 1. Local features. If we see a trend, we expect that it will continue
7 | (AutoRegressive model); if we see a traffic spike, it will gradually decay (Moving Average model);
8 | if we see more traffic on holidays, we expect to have more traffic on
9 | holidays in the future (seasonal model).
10 | 2. Global features. If we look at the autocorrelation plot, we'll notice strong
11 | year-to-year autocorrelation and some quarter-to-quarter autocorrelation.
12 |
13 | 
14 |
15 | A good model should use both global and local features, combining them
16 | in an intelligent way.
17 |
18 | I decided to use an RNN seq2seq model for prediction, because:
19 | 1. RNN can be thought of as a natural extension of well-studied ARIMA models, but much more
20 | flexible and expressive.
21 | 2. RNN is non-parametric, which greatly simplifies learning.
22 | Imagine working with different ARIMA parameters for 145K timeseries.
23 | 3. Any exogenous feature (numerical or categorical, time-dependent or series-dependent)
24 | can be easily injected into the model.
25 | 4. seq2seq seems natural for this task: we predict the next values, conditioning on the joint
26 | probability of previous values, including our past predictions. Using past predictions
27 | stabilizes the model: it learns to be conservative, because error accumulates at each step,
28 | and an extreme prediction at one step can ruin prediction quality for all subsequent steps.
29 | 5. Deep Learning is all the hype nowadays
30 |
31 | ## Feature engineering
32 | I tried to be minimalistic, because RNN is powerful enough to discover
33 | and learn features on its own.
34 | Model feature list:
35 | * *pageviews* (spelled as 'hits' in the model code, because of my web-analytics background).
36 | Raw values are transformed by log1p() to get a more-or-less normal intra-series value distribution,
37 | instead of a skewed one.
38 | * *agent*, *country*, *site* - these features are extracted from page urls and one-hot encoded
39 | * *day of week* - to capture weekly seasonality
40 | * *year-to-year autocorrelation*, *quarter-to-quarter autocorrelation* - to capture yearly and quarterly seasonality strength (see the sketch after this list).
41 | * *page popularity* - High traffic and low traffic pages have different traffic change patterns,
42 | this feature (the median of pageviews) helps to capture the traffic scale.
43 | This scale information is lost in the *pageviews* feature, because each pageviews series is
44 | independently normalized to zero mean and unit variance.
45 | * *lagged pageviews* - I'll describe this feature later
46 |
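A rough standalone sketch of what the autocorrelation features can look like (plain numpy, not the repository's actual feature code; the NaN handling and lag values are assumptions):

```python
import numpy as np

def lagged_autocorr(series: np.ndarray, lag: int) -> float:
    """Pearson correlation between a series and itself shifted by `lag` days."""
    a, b = series[lag:], series[:-lag]
    mask = ~(np.isnan(a) | np.isnan(b))   # ignore days with missing data
    a, b = a[mask] - a[mask].mean(), b[mask] - b[mask].mean()
    return float(np.sum(a * b) / np.sqrt(np.sum(a * a) * np.sum(b * b)))

# year_autocorr = lagged_autocorr(np.log1p(hits), 365)
# quarter_autocorr = lagged_autocorr(np.log1p(hits), 91)
```
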
47 | ## Feature preprocessing
48 | All features (including one-hot encoded) are normalized to zero mean and unit variance. Each *pageviews*
49 | series is normalized independently.
50 |
51 | Time-independent features (autocorrelations, country, etc.) are "stretched" to the timeseries length,
52 | i.e. repeated for each day by the `tf.tile()` operation.
53 |
54 | The model trains on random fixed-length samples from the original timeseries. For example,
55 | if the original timeseries length is 600 days and we use 200-day samples for training,
56 | we'll have a choice of 400 days to start the sample.
57 |
58 | This sampling works as an effective data augmentation mechanism:
59 | the training code randomly chooses a starting point for each timeseries on each
60 | step, generating an endless stream of almost non-repeating data.
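
A minimal standalone sketch of that sampling idea (illustrative numpy version; the real sampling is done with TF ops inside `input_pipe.py`):

```python
import numpy as np

def sample_training_window(series: np.ndarray, train_window: int, predict_window: int,
                           rng: np.random.RandomState):
    """Randomly crop one (encoder input, decoder target) pair from a full timeseries."""
    n_days = train_window + predict_window
    free_space = len(series) - n_days        # e.g. 600 - 200 = 400 possible starting days
    start = rng.randint(0, free_space)
    x = series[start:start + train_window]                 # encoder input
    y = series[start + train_window:start + n_days]        # decoder target
    return x, y
```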
61 |
62 |
63 | ## Model core
64 | The model has two main parts: an encoder and a decoder.
65 |
66 | 
67 |
68 | The encoder is a [cuDNN GRU](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/cudnn_rnn/CudnnGRU). cuDNN works much faster (5x-10x) than native Tensorflow RNNCells, at the cost
69 | of being less convenient to use and poorly documented.
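
For orientation, the encoder call looks roughly like this (a sketch of the TF 1.x cuDNN layer API, not the exact `model.py` code; sizes are illustrative):

```python
import tensorflow as tf

train_window, batch_size, n_features, rnn_depth = 283, 256, 20, 267   # illustrative sizes
# cuDNN layers expect time-major input: [time, batch, features]
encoder_inputs = tf.placeholder(tf.float32, [train_window, batch_size, n_features])

cudnn_gru = tf.contrib.cudnn_rnn.CudnnGRU(num_layers=1, num_units=rnn_depth)
# rnn_out: [time, batch, rnn_depth]; rnn_states holds the final hidden state
rnn_out, rnn_states = cudnn_gru(encoder_inputs)
```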
70 |
71 | The decoder is a TF `GRUBlockCell`, wrapped in a `tf.while_loop()` construct. Code
72 | inside the loop gets the prediction from the previous step and
73 | appends it to the input features for the current step.
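
A condensed sketch of that loop (assumed shapes and names; the actual `model.py` decoder is more involved):

```python
import tensorflow as tf

predict_window, batch_size, n_features, rnn_depth = 63, 256, 20, 267   # illustrative sizes
decoder_features = tf.placeholder(tf.float32, [predict_window, batch_size, n_features])
encoder_state = tf.placeholder(tf.float32, [batch_size, rnn_depth])

cell = tf.contrib.rnn.GRUBlockCell(rnn_depth)
w_out = tf.get_variable('decoder_output_w', [rnn_depth, 1])
b_out = tf.get_variable('decoder_output_b', [1], initializer=tf.zeros_initializer())

def condition(step, *_):
    return step < predict_window

def body(step, prev_prediction, state, outputs_ta):
    # The previous prediction becomes an extra input feature for the current step
    step_input = tf.concat([decoder_features[step], prev_prediction], axis=1)
    cell_output, new_state = cell(step_input, state)
    prediction = tf.matmul(cell_output, w_out) + b_out      # [batch_size, 1]
    return step + 1, prediction, new_state, outputs_ta.write(step, prediction)

predictions_ta = tf.TensorArray(tf.float32, size=predict_window)
_, _, _, predictions_ta = tf.while_loop(
    condition, body,
    loop_vars=[tf.constant(0, tf.int32), tf.zeros([batch_size, 1]), encoder_state, predictions_ta])
predictions = tf.squeeze(predictions_ta.stack(), axis=-1)   # [predict_window, batch_size]
```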
74 |
75 | ## Working with long timeseries
76 | LSTM/GRU is a great solution for relatively short sequences, up to 100-300 items.
77 | On longer sequences LSTM/GRU still works, but can gradually forget information from the oldest items.
78 | The competition timeseries are up to 700 days long, so I had to find some method
79 | to "strengthen" GRU memory.
80 |
81 | My first method was to use some kind of *[attention](https://distill.pub/2016/augmented-rnns)*.
82 | Attention can bring useful information from the distant past to the current
83 | RNN cell. The simplest, yet effective, attention method for our problem is a
84 | fixed-weight sliding-window attention.
85 | The two most important points in the distant past (taking
86 | long-term seasonality into account) are: 1) a year ago, 2) a quarter ago.
87 |
88 | 
89 |
90 | I can just take
91 | encoder outputs from the `current_day - 365` and `current_day - 90` timepoints,
92 | pass them through an FC layer to reduce dimensionality and append the result to the input
93 | features for the decoder. This solution, despite being simple, considerably
94 | lowered the prediction error.
95 |
96 | Then I averaged the important points with their neighbors to reduce noise and
97 | compensate for uneven intervals (leap years, different month lengths):
98 | `attn_365 = 0.25 * day_364 + 0.5 * day_365 + 0.25 * day_366`
99 |
100 | Then I realized that `0.25,0.5,0.25` is a 1D convolutional kernel (length=3)
101 | and I can automatically learn a bigger kernel to detect important points in the past.
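
A sketch of that observation, applied for brevity to a single-channel series (in the model the kernel is bigger and its weights are learned, as described next):

```python
import tensorflow as tf

batch_size, n_days = 256, 500                                    # illustrative sizes
hits = tf.placeholder(tf.float32, [batch_size, n_days])

# The 0.25/0.5/0.25 smoothing written as a length-3 1D convolution over the day axis
kernel = tf.constant([[[0.25]], [[0.5]], [[0.25]]], tf.float32)  # [width=3, in=1, out=1]
smoothed = tf.nn.conv1d(hits[:, :, tf.newaxis], kernel, stride=1, padding='SAME')
# smoothed[:, day - 365, 0] plays the role of attn_365 from the formula above
```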
102 |
103 | I ended up with a monstrous attention mechanism: it looks at the 'fingerprint'
104 | of each timeseries (the fingerprint is produced by a small ConvNet), decides
105 | which points to attend to and produces the weights for a big convolution kernel.
106 | This big kernel, applied to decoder outputs,
107 | produces attention features for each prediction day. This monster is still
108 | alive and can be found in the model code.
109 |
110 | Note that I didn't use a classical attention scheme (Bahdanau or Luong attention),
111 | because classical attention has to be recalculated from scratch on every prediction step,
112 | using all historical datapoints. This would take too much time for our long (~2 years) timeseries.
113 | My scheme, one convolution over all datapoints, uses the same attention
114 | weights for all prediction steps (that's a drawback), but is much faster to compute.
115 |
116 | Unsatisfied with the complexity of the attention mechanics, I tried to remove attention
117 | completely and just take the important datapoints (a year, half a year and a quarter ago)
118 | from the past and use them as additional
119 | features for the encoder and decoder. That worked surprisingly well, even slightly
120 | surpassing attention in prediction quality. My best public score was
121 | achieved using only lagged datapoints, without attention.
122 |
123 | 
124 |
125 | An additional important benefit of lagged datapoints: the model can use a much shorter
126 | encoder without fear of losing information from the past, because this
127 | information is now explicitly contained in the features. Even a 60-90 days long
128 | encoder still gives acceptable results, in contrast to the 300-400 days
129 | required for previous models. Shorter encoder = faster training and less
130 | loss of information.
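
A rough standalone pandas sketch of building such lagged features (the lag values and function name are illustrative, not taken from the repository's feature-extraction code):

```python
import pandas as pd

def add_lagged_pageviews(df: pd.DataFrame, lags=(91, 182, 365)) -> dict:
    """df: rows = pages, columns = consecutive days.
    Returns one shifted copy per lag: the value a quarter / half a year / a year
    before each day (NaN where that history does not exist)."""
    return {lag: df.shift(lag, axis=1) for lag in lags}

# lagged = add_lagged_pageviews(pageviews_df)
# lagged[365].iloc[:, -1]  # what each page looked like one year before the last day
```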
131 |
132 | ## Losses and regularization
133 | [SMAPE](https://en.wikipedia.org/wiki/Symmetric_mean_absolute_percentage_error)
134 | (the target loss for the competition) can't be used directly, because of its unstable
135 | behavior near zero values (the loss is a step function if the true value is zero,
136 | and not defined if the predicted value is also zero).
137 |
138 | I used a smoothed differentiable SMAPE variant, which is well-behaved for all real numbers:
139 | ```python
140 | epsilon = 0.1
141 | summ = tf.maximum(tf.abs(true) + tf.abs(predicted) + epsilon, 0.5 + epsilon)
142 | smape = tf.abs(predicted - true) / summ * 2.0
143 | ```
144 | 
145 | 
146 |
147 | Another possible choice is MAE loss on `log1p(data)`; it's smooth almost everywhere
148 | and close enough to SMAPE for training purposes.
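
For reference, that alternative is just (with `true` and `predicted` as raw counts, mirroring the snippet above):

```python
mae_log1p = tf.reduce_mean(tf.abs(tf.log1p(predicted) - tf.log1p(true)))
```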
149 |
150 | Final predictions were rounded to the closest integer, and negative predictions were clipped at zero.
151 |
152 | I tried to use RNN activation regularizations from the paper
153 | ["Regularizing RNNs by Stabilizing Activations"](https://arxiv.org/abs/1511.08400),
154 | because the internal weights of the cuDNN GRU can't be directly regularized
155 | (or I did not find a right way to do this).
156 | The stability loss didn't work at all; the activation loss gave a very
157 | slight improvement for low (1e-06..1e-05) loss weights.
158 |
159 | ## Training and validation
160 | I used the COCOB optimizer (see the paper [Training Deep Networks without Learning Rates Through Coin Betting](https://arxiv.org/abs/1705.07795)) for training, in combination with gradient clipping.
161 | COCOB tries to predict the optimal learning rate for every training step, so
162 | I didn't have to tune the learning rate at all. It also converges considerably
163 | faster than traditional momentum-based optimizers, especially during the first
164 | epochs, allowing me to stop unsuccessful experiments early.
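
A minimal sketch of wiring this together, using the `COCOB` class from `cocob.py` above (the stand-in loss and the clipping threshold are illustrative assumptions):

```python
import tensorflow as tf
from cocob import COCOB

w = tf.get_variable('w', [10])
loss = tf.reduce_mean(tf.square(w))                  # stand-in for the model's SMAPE loss

optimizer = COCOB()
grads, variables = zip(*optimizer.compute_gradients(loss))
clipped, _ = tf.clip_by_global_norm(grads, 10.0)     # clipping threshold is illustrative
train_op = optimizer.apply_gradients(zip(clipped, variables))
```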
165 |
166 | There are two ways to split timeseries into training and validation datasets:
167 | 1. *Walk-forward split*. This is not actually a split: we train on the full dataset
168 | and validate on the full dataset, using different timeframes. The timeframe
169 | for validation is shifted forward by one prediction interval relative to
170 | the timeframe for training.
171 | 2. *Side-by-side split*. This is the traditional split for mainstream machine
172 | learning. The dataset is split into independent parts, one part used strictly
173 | for training and another part used strictly for validation.
174 |
175 | 
176 |
177 | I tried both ways.
178 |
179 | Walk-forward is preferable, because it directly relates to the competition goal:
180 | predict future values using historical values. But this split consumes
181 | datapoints at the end of the timeseries, thus making it hard to train the model to
182 | precisely predict the future.
183 |
184 | Let's explain: for example, we have 300 days of historical data
185 | and want to predict the next 100 days. If we choose a walk-forward split, we'll have to use the
186 | first 100 days for real training, the next 100 days for training-mode prediction
187 | (running the decoder and calculating losses), the next 100 days for validation and the
188 | next 100 days for the actual prediction of future values.
189 | So we can actually use only 1/3 of the available datapoints for training
190 | and will have a 200-day gap between the last training datapoint
191 | and the first prediction datapoint. That's too much, because prediction quality
192 | falls exponentially as we move away from the training data (uncertainty grows).
193 | A model trained with a 100-day gap (instead of 200) would have considerably
194 | better quality.
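
The same arithmetic, written out (hypothetical day indices for this 300+100 day example, not code from the repository):

```python
train_encoder   = range(0, 100)      # "real" training input
train_decoder   = range(100, 200)    # training-mode prediction, losses are computed here
validation      = range(200, 300)    # validation prediction interval
future_forecast = range(300, 400)    # actual prediction of the unseen future

gap = future_forecast[0] - train_encoder[-1] - 1     # = 200 days
```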
195 |
196 | Side-by-side split is more economical, as it doesn't consume datapoints at the
197 | end. That was the good news. Now the bad news: for our data, model performance
198 | on the validation dataset is strongly correlated with performance on the training dataset,
199 | and almost uncorrelated with the actual model performance in the future. In other words,
200 | a side-by-side split is useless for our problem: it just duplicates
201 | the model loss observed on the training data.
202 |
203 | The bottom line?
204 |
205 | I used validation (with walk-forward split) only for model tuning.
206 | The final model used to predict future values was trained in blind mode, without any validation.
207 |
208 |
209 |
210 | ## Reducing model variance
211 | The model inevitably has high variance due to the very noisy input data. To be fair,
212 | I was surprised that the RNN learns anything at all from such noisy inputs.
213 |
214 | The same model trained on different seeds can have different performance, and
215 | sometimes the model even diverges on "unfortunate" seeds. During training, performance also
216 | fluctuates wildly from step to step. I couldn't just rely on pure luck (being on the right
217 | seed and stopping at the right training step) to win the competition,
218 | so I had to take actions to reduce variance.
219 |
220 | 
221 | 1. I don't know which training step would be best for predicting the future
222 | (the validation result on current data is very weakly correlated with the
223 | result on future data), so I can't use early stopping. But I know the
224 | approximate region where the model is (possibly) trained well enough,
225 | but has (possibly) not yet started to overfit. I decided to set the bounds of this optimal
226 | region to 10500..11500 training steps and save 10 checkpoints, one every 100th step,
227 | in this region.
228 | 2. Similarly, I decided to train 3 models on different seeds and save checkpoints
229 | from each model. So I have 30 checkpoints total.
230 | 3. One widely known method for reducing variance and improving model performance
231 | is SGD averaging (ASGD). The method is very simple and well supported
232 | in [Tensorflow](https://www.tensorflow.org/versions/r0.12/api_docs/python/train/moving_averages) -
233 | we maintain moving averages of the network weights during training and use these
234 | averaged weights, instead of the original ones, during inference (see the sketch below).
235 |
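A sketch of that weight-averaging setup with `tf.train.ExponentialMovingAverage` (the decay matches the `--asgd_decay=0.99` flag from the Readme; the stand-in training op is an assumption):

```python
import tensorflow as tf

w = tf.get_variable('w', [10])
loss = tf.reduce_mean(tf.square(w))
train_op = tf.train.AdamOptimizer().minimize(loss)   # stand-in for the real training op

# Maintain shadow (averaged) copies of all weights, updated after every training step
ema = tf.train.ExponentialMovingAverage(decay=0.99)
with tf.control_dependencies([train_op]):
    train_op = ema.apply(tf.trainable_variables())

# At inference time, restore the averaged weights instead of the raw ones
inference_saver = tf.train.Saver(ema.variables_to_restore())
```
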
236 | The combination of all three methods (averaging the predictions from 30 checkpoints,
237 | using averaged model weights in each checkpoint) worked well: I got
238 | roughly the same SMAPE error on the leaderboard (for future data)
239 | as on validation on historical data.
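
The final ensembling step is then just an average over the per-checkpoint predictions, roughly as below (`predict_with_checkpoint` and `checkpoint_paths` are hypothetical placeholders, not functions from this repository):

```python
import numpy as np

# per_checkpoint: list of [n_pages, predict_window] prediction arrays, one per checkpoint
per_checkpoint = [predict_with_checkpoint(path) for path in checkpoint_paths]
final_prediction = np.mean(np.stack(per_checkpoint), axis=0)
# round to integers and clip negatives, as described in the losses section
final_prediction = np.maximum(np.round(final_prediction), 0)
```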
240 |
241 | Theoretically, one can also consider the first two methods as a kind of ensemble
242 | learning; that's right, but I used them mainly for variance reduction.
243 |
244 | ## Hyperparameter tuning
245 | There are many model parameters (number of layers, layer depths,
246 | activation functions, dropout coefficients, etc.) that can be (and should be) tuned to
247 | achieve optimal model performance. Manual tuning is a tedious and
248 | time-consuming process, so I decided to automate it and use the [SMAC3](https://automl.github.io/SMAC3/stable/) package for hyperparameter search.
249 | Some benefits of SMAC3:
250 | * Support for conditional parameters (e.g. jointly tune number of layers
251 | and dropout for each layer; dropout on the second layer will be tuned only if
252 | n_layers > 1)
253 | * Explicit handling of model variance. SMAC trains several instances
254 | of each model on different seeds, and compares models only if the instances were
255 | trained on the same seeds. One model wins if it's better than the other model on all matching seeds.
256 |
257 | Contrary to my expectations, hyperparameter search did not find a well-defined global minimum.
258 | All the best models had roughly the same performance, but different parameters.
259 | Probably the RNN model is too expressive for this task, and the best model score
260 | depends more on the data's signal-to-noise ratio than on the model architecture.
261 |
262 | Anyway, the best parameter sets can be found in the `hparams.py` file.
263 |
--------------------------------------------------------------------------------
/hparams.py:
--------------------------------------------------------------------------------
1 | import tensorflow.contrib.training as training
2 | import re
3 |
4 | # Manually selected params
5 | params_s32 = dict(
6 | batch_size=256,
7 | #train_window=380,
8 | train_window=283,
9 | train_skip_first=0,
10 | rnn_depth=267,
11 | use_attn=False,
12 | attention_depth=64,
13 | attention_heads=1,
14 | encoder_readout_dropout=0.4768781146510798,
15 |
16 | encoder_rnn_layers=1,
17 | decoder_rnn_layers=1,
18 |
19 | # decoder_state_dropout_type=['outside','outside'],
20 | decoder_input_dropout=[1.0, 1.0, 1.0],
21 | decoder_output_dropout=[0.975, 1.0, 1.0], # min 0.95
22 | decoder_state_dropout=[0.99, 0.995, 0.995], # min 0.95
23 | decoder_variational_dropout=[False, False, False],
24 | # decoder_candidate_l2=[0.0, 0.0],
25 | # decoder_gates_l2=[0.0, 0.0],
26 | #decoder_state_dropout_type='outside',
27 | #decoder_input_dropout=1.0,
28 | #decoder_output_dropout=1.0,
29 | #decoder_state_dropout=0.995, #0.98, # min 0.95
30 | # decoder_variational_dropout=False,
31 | decoder_candidate_l2=0.0,
32 | decoder_gates_l2=0.0,
33 |
34 | fingerprint_fc_dropout=0.8232342370695286,
35 | gate_dropout=0.9967589439360334,#0.9786,
36 | gate_activation='none',
37 | encoder_dropout=0.030490422531402273,
38 | encoder_stability_loss=0.0, # max 100
39 | encoder_activation_loss=1e-06, # max 0.001
40 | decoder_stability_loss=0.0, # max 100
41 | decoder_activation_loss=5e-06, # max 0.001
42 | )
43 |
44 | # Default incumbent on last smac3 search
45 | params_definc = dict(
46 | batch_size=256,
47 | train_window=100,
48 | train_skip_first=0,
49 | rnn_depth=128,
50 | use_attn=True,
51 | attention_depth=64,
52 | attention_heads=1,
53 | encoder_readout_dropout=0.4768781146510798,
54 |
55 | encoder_rnn_layers=1,
56 | decoder_rnn_layers=1,
57 |
58 | decoder_input_dropout=[1.0, 1.0, 1.0],
59 | decoder_output_dropout=[1.0, 1.0, 1.0],
60 | decoder_state_dropout=[0.995, 0.995, 0.995],
61 | decoder_variational_dropout=[False, False, False],
62 | decoder_candidate_l2=0.0,
63 | decoder_gates_l2=0.0,
64 | fingerprint_fc_dropout=0.8232342370695286,
65 | gate_dropout=0.8961710392091516,
66 | gate_activation='none',
67 | encoder_dropout=0.030490422531402273,
68 | encoder_stability_loss=0.0,
69 | encoder_activation_loss=1e-05,
70 | decoder_stability_loss=0.0,
71 | decoder_activation_loss=5e-05,
72 | )
73 |
74 | # Found incumbent 0.35503610596060753
75 | #"decoder_activation_loss='1e-05'", "decoder_output_dropout:0='1.0'", "decoder_rnn_layers='1'", "decoder_state_dropout:0='0.995'", "encoder_activation_loss='1e-05'", "encoder_rnn_layers='1'", "gate_dropout='0.7934826952854418'", "rnn_depth='243'", "train_window='135'", "use_attn='1'", "attention_depth='17'", "attention_heads='2'", "encoder_readout_dropout='0.7711751356092252'", "fingerprint_fc_dropout='0.9693950737901414'"
76 | params_foundinc = dict(
77 | batch_size=256,
78 | train_window=135,
79 | train_skip_first=0,
80 | rnn_depth=243,
81 | use_attn=True,
82 | attention_depth=17,
83 | attention_heads=2,
84 | encoder_readout_dropout=0.7711751356092252,
85 |
86 | encoder_rnn_layers=1,
87 | decoder_rnn_layers=1,
88 |
89 | decoder_input_dropout=[1.0, 1.0, 1.0],
90 | decoder_output_dropout=[1.0, 1.0, 1.0],
91 | decoder_state_dropout=[0.995, 0.995, 0.995],
92 | decoder_variational_dropout=[False, False, False],
93 | decoder_candidate_l2=0.0,
94 | decoder_gates_l2=0.0,
95 | fingerprint_fc_dropout=0.9693950737901414,
96 | gate_dropout=0.7934826952854418,
97 | gate_activation='none',
98 | encoder_dropout=0.0,
99 | encoder_stability_loss=0.0,
100 | encoder_activation_loss=1e-05,
101 | decoder_stability_loss=0.0,
102 | decoder_activation_loss=1e-05,
103 | )
104 |
105 | # 81 on smac_run0 (0.3552077534247418 x 7)
106 | #{'decoder_activation_loss': 0.0, 'decoder_output_dropout:0': 0.85, 'decoder_rnn_layers': 2, 'decoder_state_dropout:0': 0.995,
107 | # 'encoder_activation_loss': 0.0, 'encoder_rnn_layers': 2, 'gate_dropout': 0.7665920904244501, 'rnn_depth': 201,
108 | # 'train_window': 143, 'use_attn': 1, 'attention_depth': 17, 'attention_heads': 2, 'decoder_output_dropout:1': 0.975,
109 | # 'decoder_state_dropout:1': 0.99, 'encoder_dropout': 0.0304904225, 'encoder_readout_dropout': 0.4444295965935664, 'fingerprint_fc_dropout': 0.26412480387331017}
110 | params_inst81 = dict(
111 | batch_size=256,
112 | train_window=143,
113 | train_skip_first=0,
114 | rnn_depth=201,
115 | use_attn=True,
116 | attention_depth=17,
117 | attention_heads=2,
118 | encoder_readout_dropout=0.4444295965935664,
119 |
120 | encoder_rnn_layers=2,
121 | decoder_rnn_layers=2,
122 |
123 | decoder_input_dropout=[1.0, 1.0, 1.0],
124 | decoder_output_dropout=[0.85, 0.975, 1.0],
125 | decoder_state_dropout=[0.995, 0.99, 0.995],
126 | decoder_variational_dropout=[False, False, False],
127 | decoder_candidate_l2=0.0,
128 | decoder_gates_l2=0.0,
129 | fingerprint_fc_dropout=0.26412480387331017,
130 | gate_dropout=0.7665920904244501,
131 | gate_activation='none',
132 | encoder_dropout=0.0304904225,
133 | encoder_stability_loss=0.0,
134 | encoder_activation_loss=0.0,
135 | decoder_stability_loss=0.0,
136 | decoder_activation_loss=0.0,
137 | )
138 | # 121 on smac_run0 (0.3548671560628074 x 3)
139 | # {'decoder_activation_loss': 1e-05, 'decoder_output_dropout:0': 0.975, 'decoder_rnn_layers': 2, 'decoder_state_dropout:0': 1.0,
140 | # 'encoder_activation_loss': 1e-05, 'encoder_rnn_layers': 1, 'gate_dropout': 0.8631496699358483, 'rnn_depth': 122,
141 | # 'train_window': 269, 'use_attn': 1, 'attention_depth': 29, 'attention_heads': 4, 'decoder_output_dropout:1': 0.975,
142 | # 'decoder_state_dropout:1': 0.975, 'encoder_readout_dropout': 0.9835390239895767, 'fingerprint_fc_dropout': 0.7452161827064421}
143 |
144 | # 83 on smac_run1 (0.355050330259362 x 7)
145 | # {'decoder_activation_loss': 1e-06, 'decoder_output_dropout:0': 0.925, 'decoder_rnn_layers': 2, 'decoder_state_dropout:0': 0.98,
146 | # 'encoder_activation_loss': 1e-06, 'encoder_rnn_layers': 1, 'gate_dropout': 0.9275441207192259, 'rnn_depth': 138,
147 | # 'train_window': 84, 'use_attn': 1, 'attention_depth': 52, 'attention_heads': 2, 'decoder_output_dropout:1': 0.925,
148 | # 'decoder_state_dropout:1': 0.98, 'encoder_readout_dropout': 0.6415488109353416, 'fingerprint_fc_dropout': 0.2581296623398802}
149 |
150 |
151 | params_inst83 = dict(
152 | batch_size=256,
153 | train_window=84,
154 | train_skip_first=0,
155 | rnn_depth=138,
156 | use_attn=True,
157 | attention_depth=52,
158 | attention_heads=2,
159 | encoder_readout_dropout=0.6415488109353416,
160 |
161 | encoder_rnn_layers=1,
162 | decoder_rnn_layers=2,
163 |
164 | decoder_input_dropout=[1.0, 1.0, 1.0],
165 | decoder_output_dropout=[0.925, 0.925, 1.0],
166 | decoder_state_dropout=[0.98, 0.98, 0.995],
167 | decoder_variational_dropout=[False, False, False],
168 | decoder_candidate_l2=0.0,
169 | decoder_gates_l2=0.0,
170 | fingerprint_fc_dropout=0.2581296623398802,
171 | gate_dropout=0.9275441207192259,
172 | gate_activation='none',
173 | encoder_dropout=0.0,
174 | encoder_stability_loss=0.0,
175 | encoder_activation_loss=1e-06,
176 | decoder_stability_loss=0.0,
177 | decoder_activation_loss=1e-06,
178 | )
179 |
180 | def_params = params_s32
181 |
182 | sets = {
183 | 's32':params_s32,
184 | 'definc':params_definc,
185 | 'foundinc':params_foundinc,
186 | 'inst81':params_inst81,
187 | 'inst83':params_inst83,
188 | }
189 |
190 |
191 | def build_hparams(params=def_params):
192 | return training.HParams(**params)
193 |
194 |
195 | def build_from_set(set_name):
196 | return build_hparams(sets[set_name])
197 |
198 |
199 |
--------------------------------------------------------------------------------
/images/attention.xml:
--------------------------------------------------------------------------------
1 | 7V3Rcps4FP0av3ZAQoAf66TpPmxnutOd3e2jArJNi5EH4ybp168wkm0EGDkGgRLSmY4RQiDOkXTu5WDP4N3m+XOKt+svNCTxDFjh8wzezwDw5y77Py94KQpc2yoKVmkUFkX2qeBb9JvwQlFtH4VkV6qYURpn0bZcGNAkIUFWKsNpSp/K1ZY0Lp91i1ekUvAtwHG19N8ozNa8W8g6lf9BotVanNm2+J5HHPxcpXSf8PPNAFwe/ordGyza4vV3axzSp7Mi+GkG71JKs+LT5vmOxPmtFbetOO6hYe/xulOSZCoH+Kg44heO90Rc8uHCshdxMw7dIfkB1gwuntZRRr5tcZDvfWLos7J1tonZls0+rmK82/GqAd1EAf8c40cSL453547GNGW7Epqwdha7LKU/iShkN809/B33CBDyMyyjOJYOX6U4jFiPpeIlTbIHvIninIX/kDTECebFnHI24Ntnp7YOf6wcx9EqYWVpAfQixLv14TbkV8FvG0kz8tx47+0jomygELohWfrCqvADHEECPkiAx7efTpQDDi9bn9FN1MOc5atj0yek2QcOdgPwsB14Rs5t/nEXJauYfMzHVTsBymzRS4d8tKH8nxr4rgL4AQOTpHlB3n1xYuuDBV3oOp7jWnPoIbGft2x9sC2AAHJsB7qWP3c6YgyyJMZAVGGM7dcwxu6AMZ5XIQgJ2azJN2maremKJjj+dCpdBPv013HMnKhhS3T4sd9sRUOcA2eUIs9R9t/htiK+9Z0fx+5b+nK2K9/8zo/aZTgVUBd1Q07h+yBnZRT8vY6SYsdDFB9Ploj1KKfHD5JlL3wb7zPKik4d/ZPS7fFkVzD2NmYWKOS3/jKLGFJ0nwa8lufwZRSnK5KJVbmebSmJcRb9Krd/E3X8iTpXUOfO/7R4GBF1rCp1kKeJOmI2m6ijQh0fLOCoZh1YM+tYN1KHH/qVRuzMp7XRLa+NUF7ziuviR51L4paGgCs1VHSm0hCDGb+cVdvmFXYVrh97rDZzuhP91el/7y5cNCb6g5qZc65r0UUTdQzWa3YNdfwuZs5e5inHUQ8n4yj5WSaMydDXZBCOq/D5kCkPJlVSVME+C/VQTaQnym5cTaG0CNrwlaup3JCIWFsW09dQ0JsoeJmCYycd8jsindyQo6bgXkE6pJA/ndJor06jWdABLvI8CzkIekw7oVkniTTbKxPEq+bRXNhTHs1WWCrZEdF2R9ppgnfb4vnHMnrO2TIob25Kux9nrAplus67QyiBX5N2r8EedIC9gPoS9jkk28Zu8ide+FFUt67tvu23dt+ue+rQRf8VHjqMjfqBTx6XKlNmiIm/DHrgvxLlizvZCPpcH8dd8zDuZVnUh7EGUG3bOFRD97ExGSShulyCYKwjF+obubbC6jQykH3w2JjwlqZnRPzQGSnIrkYJYhkHsunzszgCaERZQWn1LTQdazihCQwMskxRmnAsShMoJF/GBrIpU1kDyDpQNS+AMEZqXh66GqUmUEjWjwxkY6TmZZB1Sk3fOJBNn58HkJpQIaDoW2q6YDipCc1LjZgiNWGDVUy/1ITmpUZMmcqaQNaBqnnPI0yRmi1DV6PUhOalAkyRmi0ga5SacEoFDASyTqkpJNxVltYyeLe6WAvTV7OLtWRWMNkNfYlIJH48dCJnThTgeCYZhwuUrrW8imeM55ZX2BDndGV5rfePASmcEBxv8pu5sluyXJ99KK7g1dZGFfVZ4VaZzhIx9XHlwmC50o+uTisVCunxNDpOmRpzaRpUtTQC+a0U2Rt59Vspatdr21dS3+6Y+uZJ9DfmlpOtkhqX+zcr6gzBHqHhsDfvKdAbw34+HPaveV383cr82975HUDmgxqN5nQh86uiSXJ6H7+m51r15XotDXWkvgC6MvCY9xx4zBVm4ZsCjx7Zqz/wGCzI8LoJMmQ6ye10FWPIo6ktxpBZ3nGMIVA4Y/l3gtO8Ryta4TtbUrOZ9O7gOanF+l5d8itrtphYP/IdmygM4yY906xZSiNMXcB08l7zTRoDWKgErFPVGPOeNAaqPrX9a8+WcTLB3jvsjt0Kuxji3eNeTaV9TUkYBVlEkwn0/kB34YCgw6pqmQKKJkl22/dP6Qgoug8U/G4ElAe0CCj5ctvCBPmyug4TkMoj9pvChB45+X7CBPnrI0fOcvly28KECsu7DhMUUpJ9Gx2RXe5jnZGmN6cjMs+0bYrTseDWGJyOaDJt6wZZB6oqibxxoWqK07Fl6Gp0OopH8QaBbIrTsQVkjU5H18DXDQyfnwdwOrojeH+7ojV9jVrTNc+1bYrWdEfzArf7Zg0+g89lTSDrQNU8644pWrNl6OrUmublAkzRmi0g69SaUy5gIJC1ak2F3IBurXn89QYdWlPYUUyiuSlas+HXDvRrTW+Km3WDrANVYByqxmjNy0NXo9b0zHsHyBiteRlkjVrTMzDhY/j8PIDW9KurcK9WW/HbrR257zo13NUA3anL0mmX1G5fMFeX5b7dte8ZaYVEfV8+ar86bU9+2gPoPftpkeR2AlZNxNzbNF5N8T4QnO1TsnvXmF8a6GuaRr/Zvtz2et+VqVp6X6zuLc2OfkGcbZ5+lr5wz7H7vP5CQ5LX+B8=
--------------------------------------------------------------------------------
/images/autocorr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arturus/kaggle-web-traffic/a9abb80c800409abf0ece21ea244ef779f758f96/images/autocorr.png
--------------------------------------------------------------------------------
/images/encoder-decoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arturus/kaggle-web-traffic/a9abb80c800409abf0ece21ea244ef779f758f96/images/encoder-decoder.png
--------------------------------------------------------------------------------
/images/encoder-decoder.xml:
--------------------------------------------------------------------------------
1 | 7VxLc6M4EP41PiYFCDA+xklm5jBbldrU7k6OMig2Mxh5hZzY8+tXgMRDAvzg4XhDLoFGL9Rft7o/yUzA/Xr3lcDN6g/soWBiaN5uAh4mhqFrlsn+xZJ9KpnqTipYEt/jhXLBs/8biZpcuvU9FJUKUowD6m/KQheHIXJpSQYJwe/lYq84KPe6gUukCJ5dGKjSf3yPrlKpY2m5/BvylyvRs67xJwvo/loSvA15fxMDvCZ/6eM1FG3x8tEKevi9IAKPE3BPMKbp1Xp3j4J4bsW0pfW+1DzNxk1QSI+pAIy0xhsMtkgMORkY3YvJQB6bG34boeWaNf2Yi+Yo9O7iKWdP3QBGke8y4YquAybQ2WVECf6F7nGASdIecDT7AQD25BWH9Atc+0GMkr8R8WAIuZhDQjf4faG6lvwxeTrQeHS1L69nU8qgivAaUbJnRd5znQqVrgrqFDKCAkj9t3LzkENrmTWX9fCEfdaxoXEzAKIdbgSGo5WbiPCWuIjXKipJash0pIZMqSEKyRJRpSGmFrgvFNvEBaKGAdtSP7Z22rjK5dlFOgJxV9BBLkoAWQNOBZtf//xLgWdibcjjaHtf+RQ9b6AbP31n3qkMxmWMUHYd48fF6xiryXUAFyiYZ8Yr4BbiEFVC2FgA286eCB8R9/DqB0GhpGchxzPPBrtjzNOOYOAvw9jEGLYRydD/hghFu8mJ+OcVZPUJJ1YwD7vCPGT0FS2hoPtG1ZqjantVrS6b8oC6tUbd9qrbTEkX0K19cd2y2X2wHp0H8/+rYGBcTsHTinjQDiifgZKa7X+3WDy4iZK5umMFdGOzyx+yq2X8/4doJtouhEwXMjaogrggTfsUYjku7TPwtBsCT6ZGsv8Rw/TWErcvvNMVJv5vVg+KUSQwfsKRT31cAkoMEZ+lHN+lAgtMKV5XQUvUuOMPKN6cEAazFCoJEwux1aUiY13vKDLW5RB7elxkfEYw6lQYhgRIlsxtkiTJD5cB4tg85P5yX6l17Aw9e2FbRzlDlp8artveMARiSYqYeZKKi361W+Dojj6zbd20HMD+LFGEt63dmtZMm2kmS50cMHNMs6NgSF4wdaD4U6MC7brV3p8KpI+4uT7cADlBGhI3eu1C7Plvleswe1F6w6cyXoj5uqEsxdlKvIGhLDtpkU+qSY0/hi724l5ZG5Ci2rU8X/grRsHEyUvWLPzCXtj7xtOBt4vkSdkqhssq+rGXJuZLiQw8nyCXxw9s7Yz7L0ciOV/WyiJMrWwR+lSNTEGVRWgdhKa6ShqMrvRKXKkhpzRDulKVkBhxcyW4AdrlcGPMFJgUt0IwoSu8xCEMirsh5QWoABa082khZ2R3LxwgTenkT0Tpns8x3FLMRHm/33GcAPa7yfJzu96IN+YL4OTohDNN47iIU0fFJDRNqlTlH51eHu0CVN6qn2jqXHokYdHaxESXCHxm9hTAo7wU0j0LTc9G453taHOnKvDpIqaROYQB2Ta9im4bl6arWJpMcMGQRmWjssSrBj7ulgT7OWHKjr1vr/ip8xeZRpuhw6t3hJrGHElmVDPu9WTWtsjFdrkpm23klTKsCliBDtxRxY57f+y/MRnZ/+HY/4tR/bOuqH65oSMPwZxB9VecTujPDMKbcRvsMxhCtjXV1hCUhno0hCrmpDdDGM3gE5hBZ4cilYZ6NAOVPXg55TjDkEht4nLqKagzw92DmFz7nhe05YwuBVUpEs9Ox7Y9vqv3iNSRT7haPsGSD8UMSXVXnW4Z+dGRH41xKfNcA/KjoIqQGP3ZdfgzOWEf0J+B+gx+9Gef3J/Z8jo7pD8bjyJcrT+zZd5lSH827l+P/qwOl/I6O6Q/U/PNKmakdqtnZEbOZEbqTs4Mw4yYSmh3JjVi2YNRI0BNcaug2rAdM4L1w7PLdt1PI08FptxQn8CcHQfMEZZXC8tp3Q86T4Wl3FCPsDR1FVp9nH79iCdcm7b02qzX0k8wzWpMtgSbsqY6Tjc+MHuJjj87og64+bMjcgJ/avlDnzVR3rvjz5qYas4/WlYrL/5RogbFz/f0oR4lOjlkAbN25Q9ZjPLeXVuMug3wgMbjrkcfdw1ju28KoM4/ljJpT2AoB/Arf1TY05FXU90peCLI85OpixR40fQ3p7XxtFCmqt/jo9kqrHYB0CMQZjYgrAdNZ9+BOKRpowtNqxzqNz+imMRKYHIPUjjqu1N9yydoKpjJrtTNbvOvJqZLTP5pSvD4Hw==
--------------------------------------------------------------------------------
/images/from_past.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arturus/kaggle-web-traffic/a9abb80c800409abf0ece21ea244ef779f758f96/images/from_past.png
--------------------------------------------------------------------------------
/images/lagged_data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arturus/kaggle-web-traffic/a9abb80c800409abf0ece21ea244ef779f758f96/images/lagged_data.png
--------------------------------------------------------------------------------
/images/losses_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arturus/kaggle-web-traffic/a9abb80c800409abf0ece21ea244ef779f758f96/images/losses_0.png
--------------------------------------------------------------------------------
/images/losses_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arturus/kaggle-web-traffic/a9abb80c800409abf0ece21ea244ef779f758f96/images/losses_1.png
--------------------------------------------------------------------------------
/images/predictions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arturus/kaggle-web-traffic/a9abb80c800409abf0ece21ea244ef779f758f96/images/predictions.png
--------------------------------------------------------------------------------
/images/split.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arturus/kaggle-web-traffic/a9abb80c800409abf0ece21ea244ef779f758f96/images/split.png
--------------------------------------------------------------------------------
/images/training.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Arturus/kaggle-web-traffic/a9abb80c800409abf0ece21ea244ef779f758f96/images/training.png
--------------------------------------------------------------------------------
/images/validation-split.xml:
--------------------------------------------------------------------------------
1 | 7Vpdb6s2GP41kbaLHBkbCLls0ma7mVSp1TnbpQMusUowM/Qk2a+fje0EbHJCO0i3tanU4tcf2M/zfvltJmi53f/CcbH5jSUkm0CQ7CfodgJhNA/Fbyk4KEHoASVIOU2UyDsJHuhfRAvNsBeakLI1sGIsq2jRFsYsz0lctWSYc7ZrD3tiWfutBU6JI3iIceZKv9Gk2uhjBeAk/5XQdGPe7AHds8bxc8rZS67fN4Hoqf6o7i02a+nx5QYnbNcQobsJWnLGKvW03S9JJqE1sKl5qzO9x31zkld9JvhQzfiOsxditlxvrDoYMOrjEDkBTNBit6EVeShwLHt3gn0h21TbTLQ88ZhmuCz10ASXm3qe7IjZlsa6I8Nrki2OUC1ZxrjoylkuFl2UFWfPxAgFgmH9ET1PNMsa8qdA/kg5y6sV3tJMqttXwhOcYy3WuuVB3W5Mj0B4i5CQawgIr8j+LI7ekR2h9IRtScUPYoiZEGpCtcJ7huDdSX2gr2WbhupAozlYq2x6XPtEm3jQzHWzODSJFs6r1Qoul9cDyos6gPKGAGo+LlC3wV106w8D1NFFHKz2VYDyQgepbzh7nj4xvsM8kZAVGa0c9MR5qzZEbWvWJt7ETYtwRtNcNGOBDxHyhUSPCpd8ozu2NEnkazo5abPWwx9Epq33DkbQ7XlvHzAAYTOHsAcRQ6frw7QUfz8J6yLM99+RsMgh7JFuiUOQyBAK+VjSPM3IjUxuXu2X7kL50z9Q1imUSXzAFz9EIYoQCMLQByEMzAg9BXwBgR/NQuAFPhCj0DwaJ1Yg5PITBB38+EPE1MDh556ThMb/YRuCA4WmuRXDQ5eXsczGjUt3ecwS127+1aQMYRzRZRK60oMhSBBZ88U8iuSJdlYG19Z1oOms+iku2dPq99rdhOFMt//QS8jne8KF/6wZq6EXOPKDmRAagZwxFR4LIiOx56mTkMS5IVrMiNOyFx6TVmZZYZ6SqiFy+Ws6rw5+jIyTDFf0e3sTXaTpN9wzKrZ3Cm3ASh+RxbvavJ7VvCFaCyE7YZ9ZC6kjOwvVOnQ8dj+1co37CmrVUJMZbKoJOKchDU2cwR6a+DaNgq5GwffUKBh5bUWYv1GjjlUeo5r2QgNqVNeFL8xkYCgLnLdUK/zzRVZdFqcSzjRWgeFGgpeufxLJxkS6T/FaUFdOzDPwfz7NF0+p/PvIMc1F0ibaeCsjQ74uC9WfqXErtQc1fIgA1qyJmCjGlSL8kyDWq6bTw9ykEy6Mbcg3SnTQLbpYEVK9jCeE27UitcojK/Q2yuOyAyU7PcLsbKQw63uX/aHEqzh7TF0fxWszHLz2+NC6IsGO83tdXmcIADqyvY+Tg/sX8j/vekk3cosL/9us+xzsV0ir/R5p9dj2bpdEjrZ9qSYyTNmxh8e7XiVb8fFDVeiEws6Q34QEHBeJ15WqzyFxaB/4EjD2HeRNMaEHMGNeEmZB845gUv5mn74N9Mz8jdE3U39jBs3c33vX5H+KAisPsB1e3+x/iqyb6YgXSviR4tYZE9239cBYbOBY7FhRDbqF7g+UxP2YFW9+kZaxcrqO/4V+2oZC5nq2gdx859M2zthGBy1j2YbvuqyvAsVERFCWj8DMmDUa2bjHlWA+ryUQ9Pw2TjNJ7Fncsb5gA+rPMGUYzy4ajvc/J9E8fQFLZR2nL7mhu78B
--------------------------------------------------------------------------------
/input_pipe.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from feeder import VarFeeder
4 | from enum import Enum
5 | from typing import List, Iterable
6 | import numpy as np
7 | import pandas as pd
8 |
9 |
10 | class ModelMode(Enum):
11 | TRAIN = 0
12 | EVAL = 1
13 | PREDICT = 2
14 |
15 |
16 | class Split:
17 | def __init__(self, test_set: List[tf.Tensor], train_set: List[tf.Tensor], test_size: int, train_size: int):
18 | self.test_set = test_set
19 | self.train_set = train_set
20 | self.test_size = test_size
21 | self.train_size = train_size
22 |
23 |
24 | class Splitter:
25 | def cluster_pages(self, cluster_idx: tf.Tensor):
26 | """
27 | Shuffles pages so that all user_agents of each unique page stay together in the shuffled list
28 | :param cluster_idx: Tensor[uniq_pages, n_agents], each value is the index of the pair (uniq_page, agent) in other page tensors
29 | :return: list of page indexes for use in the global page tensors
30 | """
31 | size = cluster_idx.shape[0].value
32 | random_idx = tf.random_shuffle(tf.range(0, size, dtype=tf.int32), self.seed)
33 | shuffled_pages = tf.gather(cluster_idx, random_idx)
34 | # Drop non-existent (uniq_page, agent) pairs. Non-existent pair has index value = -1
35 | mask = shuffled_pages >= 0
36 | page_idx = tf.boolean_mask(shuffled_pages, mask)
37 | return page_idx
38 |
39 | def __init__(self, tensors: List[tf.Tensor], cluster_indexes: tf.Tensor, n_splits, seed, train_sampling=1.0,
40 | test_sampling=1.0):
41 | size = tensors[0].shape[0].value
42 | self.seed = seed
43 | clustered_index = self.cluster_pages(cluster_indexes)
44 | index_len = tf.shape(clustered_index)[0]
45 | assert_op = tf.assert_equal(index_len, size, message='n_pages is not equal to size of clustered index')
46 | with tf.control_dependencies([assert_op]):
47 | split_nitems = int(round(size / n_splits))
48 | split_size = [split_nitems] * n_splits
49 | split_size[-1] = size - (n_splits - 1) * split_nitems
50 | splits = tf.split(clustered_index, split_size)
51 | complements = [tf.random_shuffle(tf.concat(splits[:i] + splits[i + 1:], axis=0), seed) for i in
52 | range(n_splits)]
53 | splits = [tf.random_shuffle(split, seed) for split in splits]
54 |
55 | def mk_name(prefix, tensor):
56 | return prefix + '_' + tensor.name[:-2]
57 |
58 | def prepare_split(i):
59 | test_size = split_size[i]
60 | train_size = size - test_size
61 | test_sampled_size = int(round(test_size * test_sampling))
62 | train_sampled_size = int(round(train_size * train_sampling))
63 | test_idx = splits[i][:test_sampled_size]
64 | train_idx = complements[i][:train_sampled_size]
65 | test_set = [tf.gather(tensor, test_idx, name=mk_name('test', tensor)) for tensor in tensors]
66 | train_set = [tf.gather(tensor, train_idx, name=mk_name('train', tensor)) for tensor in tensors]
67 | return Split(test_set, train_set, test_sampled_size, train_sampled_size)
68 |
69 | self.splits = [prepare_split(i) for i in range(n_splits)]
70 |
71 |
72 | class FakeSplitter:
73 | def __init__(self, tensors: List[tf.Tensor], n_splits, seed, test_sampling=1.0):
74 | total_pages = tensors[0].shape[0].value
75 | n_pages = int(round(total_pages * test_sampling))
76 |
77 | def mk_name(prefix, tensor):
78 | return prefix + '_' + tensor.name[:-2]
79 |
80 | def prepare_split(i):
81 | idx = tf.random_shuffle(tf.range(0, n_pages, dtype=tf.int32), seed + i)
82 | train_tensors = [tf.gather(tensor, idx, name=mk_name('shfl', tensor)) for tensor in tensors]
83 | if test_sampling < 1.0:
84 | sampled_idx = idx[:n_pages]
85 | test_tensors = [tf.gather(tensor, sampled_idx, name=mk_name('shfl_test', tensor)) for tensor in tensors]
86 | else:
87 | test_tensors = train_tensors
88 | return Split(test_tensors, train_tensors, n_pages, total_pages)
89 |
90 | self.splits = [prepare_split(i) for i in range(n_splits)]
91 |
92 |
93 | class InputPipe:
94 | def cut(self, hits, start, end):
95 | """
96 | Cuts the [start:end] range from the input data
97 | :param hits: hits timeseries
98 | :param start: start index
99 | :param end: end index
100 | :return: tuple (train_hits, test_hits, dow, lagged_hits)
101 | """
102 | # Pad hits to ensure we have enough array length for prediction
103 | hits = tf.concat([hits, tf.fill([self.predict_window], np.NaN)], axis=0)
104 | cropped_hit = hits[start:end]
105 |
106 | # cut day of week
107 | cropped_dow = self.inp.dow[start:end]
108 |
109 | # Cut lagged hits
110 | # gather() accepts only int32 indexes
111 | cropped_lags = tf.cast(self.inp.lagged_ix[start:end], tf.int32)
112 | # Mask for -1 (no data) lag indexes
113 | lag_mask = cropped_lags < 0
114 | # Convert -1 to 0 for gather(); it doesn't accept negative indices
115 | cropped_lags = tf.maximum(cropped_lags, 0)
116 | # Translate lag indexes to hit values
117 | lagged_hit = tf.gather(hits, cropped_lags)
118 | # Convert masked (see above) or NaN lagged hits to zeros
119 | lag_zeros = tf.zeros_like(lagged_hit)
120 | lagged_hit = tf.where(lag_mask | tf.is_nan(lagged_hit), lag_zeros, lagged_hit)
121 |
122 | # Split for train and test
123 | x_hits, y_hits = tf.split(cropped_hit, [self.train_window, self.predict_window], axis=0)
124 |
125 | # Convert NaN to zero for train data
126 | x_hits = tf.where(tf.is_nan(x_hits), tf.zeros_like(x_hits), x_hits)
127 | return x_hits, y_hits, cropped_dow, lagged_hit
128 |
129 | def cut_train(self, hits, *args):
130 | """
131 | Cuts a segment of time series for training. Randomly chooses starting point.
132 | :param hits: hits timeseries
133 | :param args: pass-through data, will be appended to result
134 | :return: result of cut() + args
135 | """
136 | n_days = self.predict_window + self.train_window
137 | # How much free space we have to choose starting day
138 | free_space = self.inp.data_days - n_days - self.back_offset - self.start_offset
139 | if self.verbose:
140 | lower_train_start = self.inp.data_start + pd.Timedelta(self.start_offset, 'D')
141 | lower_test_end = lower_train_start + pd.Timedelta(n_days, 'D')
142 | lower_test_start = lower_test_end - pd.Timedelta(self.predict_window, 'D')
143 | upper_train_start = self.inp.data_start + pd.Timedelta(free_space - 1, 'D')
144 | upper_test_end = upper_train_start + pd.Timedelta(n_days, 'D')
145 | upper_test_start = upper_test_end - pd.Timedelta(self.predict_window, 'D')
146 | print(f"Free space for training: {free_space} days.")
147 | print(f" Lower train {lower_train_start}, prediction {lower_test_start}..{lower_test_end}")
148 | print(f" Upper train {upper_train_start}, prediction {upper_test_start}..{upper_test_end}")
149 | # Random starting point
150 | offset = tf.random_uniform((), self.start_offset, free_space, dtype=tf.int32, seed=self.rand_seed)
151 | end = offset + n_days
152 | # Cut all the things
153 | return self.cut(hits, offset, end) + args
154 |
155 | def cut_eval(self, hits, *args):
156 | """
157 | Cuts segment of time series for evaluation.
158 | Always cuts train_window + predict_window length segment beginning at start_offset point
159 | :param hits: hits timeseries
160 | :param args: pass-through data, will be appended to result
161 | :return: result of cut() + args
162 | """
163 | end = self.start_offset + self.train_window + self.predict_window
164 | return self.cut(hits, self.start_offset, end) + args
165 |
166 | def reject_filter(self, x_hits, y_hits, *args):
167 | """
168 | Rejects timeseries having too many zero datapoints (more than self.max_train_empty)
169 | """
170 | if self.verbose:
171 | print("max empty %d train %d predict" % (self.max_train_empty, self.max_predict_empty))
172 | zeros_x = tf.reduce_sum(tf.to_int32(tf.equal(x_hits, 0.0)))
173 | keep = zeros_x <= self.max_train_empty
174 | return keep
175 |
176 | def make_features(self, x_hits, y_hits, dow, lagged_hits, pf_agent, pf_country, pf_site, page_ix,
177 | page_popularity, year_autocorr, quarter_autocorr):
178 | """
179 | Main method. Assembles input data into final tensors
180 | """
181 | # Split day of week to train and test
182 | x_dow, y_dow = tf.split(dow, [self.train_window, self.predict_window], axis=0)
183 |
184 | # Normalize hits
185 | mean = tf.reduce_mean(x_hits)
186 | std = tf.sqrt(tf.reduce_mean(tf.squared_difference(x_hits, mean)))
187 | norm_x_hits = (x_hits - mean) / std
188 | norm_y_hits = (y_hits - mean) / std
189 | norm_lagged_hits = (lagged_hits - mean) / std
190 |
191 | # Split lagged hits to train and test
192 | x_lagged, y_lagged = tf.split(norm_lagged_hits, [self.train_window, self.predict_window], axis=0)
193 |
194 | # Combine all page features into single tensor
195 | stacked_features = tf.stack([page_popularity, quarter_autocorr, year_autocorr])
196 | flat_page_features = tf.concat([pf_agent, pf_country, pf_site, stacked_features], axis=0)
197 | page_features = tf.expand_dims(flat_page_features, 0)
198 |
199 | # Train features
200 | x_features = tf.concat([
201 | # [n_days] -> [n_days, 1]
202 | tf.expand_dims(norm_x_hits, -1),
203 | x_dow,
204 | x_lagged,
205 | # Stretch page_features to all training days
206 | # [1, features] -> [n_days, features]
207 | tf.tile(page_features, [self.train_window, 1])
208 | ], axis=1)
209 |
210 | # Test features
211 | y_features = tf.concat([
212 |             # No observed hits for prediction days: only day of week, lagged hits and page features
213 | y_dow,
214 | y_lagged,
215 | # Stretch page_features to all testing days
216 | # [1, features] -> [n_days, features]
217 | tf.tile(page_features, [self.predict_window, 1])
218 | ], axis=1)
219 |
220 | return x_hits, x_features, norm_x_hits, x_lagged, y_hits, y_features, norm_y_hits, mean, std, flat_page_features, page_ix
221 |
222 | def __init__(self, inp: VarFeeder, features: Iterable[tf.Tensor], n_pages: int, mode: ModelMode, n_epoch=None,
223 | batch_size=127, runs_in_burst=1, verbose=True, predict_window=60, train_window=500,
224 | train_completeness_threshold=1, predict_completeness_threshold=1, back_offset=0,
225 | train_skip_first=0, rand_seed=None):
226 | """
227 | Create data preprocessing pipeline
228 | :param inp: Raw input data
229 | :param features: Features tensors (subset of data in inp)
230 | :param n_pages: Total number of pages
231 | :param mode: Train/Predict/Eval mode selector
232 | :param n_epoch: Number of epochs. Generates endless data stream if None
233 | :param batch_size:
234 |         :param runs_in_burst: How many batches can be consumed in a short time interval (burst). Multiplier for prefetch()
235 | :param verbose: Print additional information during graph construction
236 | :param predict_window: Number of days to predict
237 |         :param train_window: Use train_window days for training
238 |         :param train_completeness_threshold: Minimum fraction of non-zero datapoints required in the train timeseries
239 |         :param predict_completeness_threshold: Minimum fraction of non-zero datapoints required in the test/predict timeseries
240 | :param back_offset: Don't use back_offset days at the end of timeseries
241 | :param train_skip_first: Don't use train_skip_first days at the beginning of timeseries
242 | :param rand_seed:
243 |
244 | """
245 | self.n_pages = n_pages
246 | self.inp = inp
247 | self.batch_size = batch_size
248 | self.rand_seed = rand_seed
249 | self.back_offset = back_offset
250 | if verbose:
251 | print("Mode:%s, data days:%d, Data start:%s, data end:%s, features end:%s " % (
252 | mode, inp.data_days, inp.data_start, inp.data_end, inp.features_end))
253 |
254 | if mode == ModelMode.TRAIN:
255 | # reserve predict_window at the end for validation
256 | assert inp.data_days - predict_window > predict_window + train_window, \
257 | "Predict+train window length (+predict window for validation) is larger than total number of days in dataset"
258 | self.start_offset = train_skip_first
259 | elif mode == ModelMode.EVAL or mode == ModelMode.PREDICT:
260 | self.start_offset = inp.data_days - train_window - back_offset
261 | if verbose:
262 | train_start = inp.data_start + pd.Timedelta(self.start_offset, 'D')
263 | eval_start = train_start + pd.Timedelta(train_window, 'D')
264 | end = eval_start + pd.Timedelta(predict_window - 1, 'D')
265 | print("Train start %s, predict start %s, end %s" % (train_start, eval_start, end))
266 | assert self.start_offset >= 0
267 |
268 | self.train_window = train_window
269 | self.predict_window = predict_window
270 | self.attn_window = train_window - predict_window + 1
271 | self.max_train_empty = int(round(train_window * (1 - train_completeness_threshold)))
272 | self.max_predict_empty = int(round(predict_window * (1 - predict_completeness_threshold)))
273 | self.mode = mode
274 | self.verbose = verbose
275 |
276 | # Reserve more processing threads for eval/predict because of larger batches
277 | num_threads = 3 if mode == ModelMode.TRAIN else 6
278 |
279 | # Choose right cutter function for current ModelMode
280 | cutter = {ModelMode.TRAIN: self.cut_train, ModelMode.EVAL: self.cut_eval, ModelMode.PREDICT: self.cut_eval}
281 | # Create dataset, transform features and assemble batches
282 | root_ds = tf.data.Dataset.from_tensor_slices(tuple(features)).repeat(n_epoch)
283 | batch = (root_ds
284 | .map(cutter[mode])
285 | .filter(self.reject_filter)
286 | .map(self.make_features, num_parallel_calls=num_threads)
287 | .batch(batch_size)
288 | .prefetch(runs_in_burst * 2)
289 | )
290 |
291 | self.iterator = batch.make_initializable_iterator()
292 | it_tensors = self.iterator.get_next()
293 |
294 | # Assign all tensors to class variables
295 | self.true_x, self.time_x, self.norm_x, self.lagged_x, self.true_y, self.time_y, self.norm_y, self.norm_mean, \
296 | self.norm_std, self.page_features, self.page_ix = it_tensors
297 |
298 | self.encoder_features_depth = self.time_x.shape[2].value
299 |
300 | def load_vars(self, session):
301 | self.inp.restore(session)
302 |
303 | def init_iterator(self, session):
304 | session.run(self.iterator.initializer)
305 |
306 |
307 | def page_features(inp: VarFeeder):
308 | return (inp.hits, inp.pf_agent, inp.pf_country, inp.pf_site,
309 | inp.page_ix, inp.page_popularity, inp.year_autocorr, inp.quarter_autocorr)
310 |
--------------------------------------------------------------------------------
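A note on the normalization performed in `InputPipe.make_features` above: the mean and standard deviation are computed from the training window only and then reused for the prediction window and the lagged series, so no statistics leak from the future. A minimal NumPy sketch of the same idea (names are illustrative, not from the repository):

import numpy as np

def normalize_window(x_hits, y_hits, lagged_hits):
    # Statistics come from the training window only
    mean = x_hits.mean()
    std = np.sqrt(((x_hits - mean) ** 2).mean())
    # The same (mean, std) pair is applied to train, prediction and lagged data,
    # and is also returned so predictions can be denormalized later
    return ((x_hits - mean) / std,
            (y_hits - mean) / std,
            (lagged_hits - mean) / std,
            mean, std)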
/make_features.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import os.path
4 | import os
5 | import argparse
6 |
7 | import extractor
8 | from feeder import VarFeeder
9 | import numba
10 | from typing import Tuple, Dict, Collection, List
11 |
12 |
13 | def read_cached(name) -> pd.DataFrame:
14 | """
15 |     Reads csv file (maybe zipped) from data directory and caches its content as a pickled DataFrame
16 | :param name: file name without extension
17 | :return: file content
18 | """
19 | cached = 'data/%s.pkl' % name
20 | sources = ['data/%s.csv' % name, 'data/%s.csv.zip' % name]
21 | if os.path.exists(cached):
22 | return pd.read_pickle(cached)
23 | else:
24 | for src in sources:
25 | if os.path.exists(src):
26 | df = pd.read_csv(src)
27 | df.to_pickle(cached)
28 | return df
29 |
30 |
31 | def read_all() -> pd.DataFrame:
32 | """
33 | Reads source data for training/prediction
34 | """
35 | def read_file(file):
36 | df = read_cached(file).set_index('Page')
37 | df.columns = df.columns.astype('M8[D]')
38 | return df
39 |
40 | # Path to cached data
41 | path = os.path.join('data', 'all.pkl')
42 | if os.path.exists(path):
43 | df = pd.read_pickle(path)
44 | else:
45 | # Official data
46 | df = read_file('train_2')
47 | # Scraped data
48 | scraped = read_file('2017-08-15_2017-09-11')
49 | # Update last two days by scraped data
50 | df[pd.Timestamp('2017-09-10')] = scraped['2017-09-10']
51 | df[pd.Timestamp('2017-09-11')] = scraped['2017-09-11']
52 |
53 | df = df.sort_index()
54 | # Cache result
55 | df.to_pickle(path)
56 | return df
57 |
58 | # todo:remove
59 | def make_holidays(tagged, start, end) -> pd.DataFrame:
60 | def read_df(lang):
61 | result = pd.read_pickle('data/holidays/%s.pkl' % lang)
62 | return result[~result.dw].resample('D').size().rename(lang)
63 |
64 | holidays = pd.DataFrame([read_df(lang) for lang in ['de', 'en', 'es', 'fr', 'ja', 'ru', 'zh']])
65 | holidays = holidays.loc[:, start:end].fillna(0)
66 |     result = tagged[['country']].join(holidays, on='country').drop('country', axis=1).fillna(0).astype(np.int8)
67 | result.columns = pd.DatetimeIndex(result.columns.values)
68 | return result
69 |
70 |
71 | def read_x(start, end) -> pd.DataFrame:
72 | """
73 | Gets source data from start to end date. Any date can be None
74 | """
75 | df = read_all()
76 | # User GoogleAnalitycsRoman has really bad data with huge traffic spikes in all incarnations.
77 | # Wikipedia banned him, we'll ban it too
78 | bad_roman = df.index.str.startswith("User:GoogleAnalitycsRoman")
79 | df = df[~bad_roman]
80 | if start and end:
81 | return df.loc[:, start:end]
82 | elif end:
83 | return df.loc[:, :end]
84 | else:
85 | return df
86 |
87 |
88 | @numba.jit(nopython=True)
89 | def single_autocorr(series, lag):
90 | """
91 | Autocorrelation for single data series
92 | :param series: traffic series
93 | :param lag: lag, days
94 | :return:
95 | """
96 | s1 = series[lag:]
97 | s2 = series[:-lag]
98 | ms1 = np.mean(s1)
99 | ms2 = np.mean(s2)
100 | ds1 = s1 - ms1
101 | ds2 = s2 - ms2
102 | divider = np.sqrt(np.sum(ds1 * ds1)) * np.sqrt(np.sum(ds2 * ds2))
103 | return np.sum(ds1 * ds2) / divider if divider != 0 else 0
104 |
105 |
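`single_autocorr` above is simply the Pearson correlation between a series and its lagged copy, compiled with numba for speed. A quick sanity check against NumPy (illustrative, not part of the repository):

import numpy as np

series = np.random.RandomState(0).rand(400)
lag = 7
fast = single_autocorr(series, lag)
# np.corrcoef computes the same Pearson correlation between the series
# and its lag-shifted copy
slow = np.corrcoef(series[lag:], series[:-lag])[0, 1]
assert np.isclose(fast, slow)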
106 | @numba.jit(nopython=True)
107 | def batch_autocorr(data, lag, starts, ends, threshold, backoffset=0):
108 | """
109 | Calculate autocorrelation for batch (many time series at once)
110 | :param data: Time series, shape [n_pages, n_days]
111 | :param lag: Autocorrelation lag
112 | :param starts: Start index for each series
113 | :param ends: End index for each series
114 | :param threshold: Minimum support (ratio of time series length to lag) to calculate meaningful autocorrelation.
115 | :param backoffset: Offset from the series end, days.
116 | :return: autocorrelation, shape [n_series]. If series is too short (support less than threshold),
117 | autocorrelation value is NaN
118 | """
119 | n_series = data.shape[0]
120 | n_days = data.shape[1]
121 | max_end = n_days - backoffset
122 | corr = np.empty(n_series, dtype=np.float64)
123 | support = np.empty(n_series, dtype=np.float64)
124 | for i in range(n_series):
125 | series = data[i]
126 | end = min(ends[i], max_end)
127 | real_len = end - starts[i]
128 | support[i] = real_len/lag
129 | if support[i] > threshold:
130 | series = series[starts[i]:end]
131 | c_365 = single_autocorr(series, lag)
132 | c_364 = single_autocorr(series, lag-1)
133 | c_366 = single_autocorr(series, lag+1)
134 |             # Average value between the exact lag and its two nearest neighbors, for smoothness
135 | corr[i] = 0.5 * c_365 + 0.25 * c_364 + 0.25 * c_366
136 | else:
137 | corr[i] = np.NaN
138 | return corr #, support
139 |
140 |
141 | @numba.jit(nopython=True)
142 | def find_start_end(data: np.ndarray):
143 | """
144 | Calculates start and end of real traffic data. Start is an index of first non-zero, non-NaN value,
145 | end is index of last non-zero, non-NaN value
146 | :param data: Time series, shape [n_pages, n_days]
147 | :return:
148 | """
149 | n_pages = data.shape[0]
150 | n_days = data.shape[1]
151 | start_idx = np.full(n_pages, -1, dtype=np.int32)
152 | end_idx = np.full(n_pages, -1, dtype=np.int32)
153 | for page in range(n_pages):
154 | # scan from start to the end
155 | for day in range(n_days):
156 | if not np.isnan(data[page, day]) and data[page, day] > 0:
157 | start_idx[page] = day
158 | break
159 | # reverse scan, from end to start
160 | for day in range(n_days - 1, -1, -1):
161 | if not np.isnan(data[page, day]) and data[page, day] > 0:
162 | end_idx[page] = day
163 | break
164 | return start_idx, end_idx
165 |
166 |
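A toy check of `find_start_end` (values are illustrative): the start is the first day that is neither NaN nor zero, the end is the last such day, and a series with no valid data gets -1 for both:

import numpy as np

data = np.array([[np.nan, 0.0, 3.0, 5.0, 0.0],
                 [np.nan, np.nan, np.nan, np.nan, np.nan]])
starts, ends = find_start_end(data)
print(starts)  # expected: [ 2 -1]
print(ends)    # expected: [ 3 -1]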
167 | def prepare_data(start, end, valid_threshold) -> Tuple[pd.DataFrame, pd.DataFrame, np.ndarray, np.ndarray]:
168 | """
169 | Reads source data, calculates start and end of each series, drops bad series, calculates log1p(series)
170 | :param start: start date of effective time interval, can be None to start from beginning
171 | :param end: end date of effective time interval, can be None to return all data
172 |     :param valid_threshold: minimal ratio of a series' real length to the entire (end-start) interval. The series is
173 |     dropped if the ratio is less than the threshold
174 | :return: tuple(log1p(series), nans, series start, series end)
175 | """
176 | df = read_x(start, end)
177 | starts, ends = find_start_end(df.values)
178 | # boolean mask for bad (too short) series
179 | page_mask = (ends - starts) / df.shape[1] < valid_threshold
180 | print("Masked %d pages from %d" % (page_mask.sum(), len(df)))
181 | inv_mask = ~page_mask
182 | df = df[inv_mask]
183 | nans = pd.isnull(df)
184 | return np.log1p(df.fillna(0)), nans, starts[inv_mask], ends[inv_mask]
185 |
186 |
187 | def lag_indexes(begin, end) -> List[pd.Series]:
188 | """
189 | Calculates indexes for 3, 6, 9, 12 months backward lag for the given date range
190 | :param begin: start of date range
191 | :param end: end of date range
192 |     :return: List of 4 Series, one for each lag. For each Series, index is a date in range(begin, end), value is the index
193 |     of the target (lagged) date in the same Series. If the target date is out of the (begin, end) range, the value is -1
194 | """
195 | dr = pd.date_range(begin, end)
196 | # key is date, value is day index
197 | base_index = pd.Series(np.arange(0, len(dr)), index=dr)
198 |
199 | def lag(offset):
200 | dates = dr - offset
201 | return pd.Series(data=base_index.loc[dates].fillna(-1).astype(np.int16).values, index=dr)
202 |
203 | return [lag(pd.DateOffset(months=m)) for m in (3, 6, 9, 12)]
204 |
205 |
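To make the -1 convention of `lag_indexes` concrete, here is an equivalent standalone sketch (using reindex instead of .loc, which behaves the same for this purpose): a date whose 3-month-lagged counterpart falls before the range maps to -1, otherwise to the day index of that counterpart:

import numpy as np
import pandas as pd

dr = pd.date_range('2017-01-01', '2017-12-31')
base_index = pd.Series(np.arange(len(dr)), index=dr)
lag3 = pd.Series(base_index.reindex(dr - pd.DateOffset(months=3))
                 .fillna(-1).astype(np.int16).values, index=dr)
print(lag3.loc['2017-04-01'])  # 0, because 2017-01-01 is day 0
print(lag3.loc['2017-02-01'])  # -1, because 2016-11-01 is before the range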
206 | def make_page_features(pages: np.ndarray) -> pd.DataFrame:
207 | """
208 | Calculates page features (site, country, agent, etc) from urls
209 | :param pages: Source urls
210 | :return: DataFrame with features as columns and urls as index
211 | """
212 | tagged = extractor.extract(pages).set_index('page')
213 | # Drop useless features
214 | features: pd.DataFrame = tagged.drop(['term', 'marker'], axis=1)
215 | return features
216 |
217 |
218 | def uniq_page_map(pages:Collection):
219 | """
220 | Finds agent types (spider, desktop, mobile, all) for each unique url, i.e. groups pages by agents
221 | :param pages: all urls (must be presorted)
222 | :return: array[num_unique_urls, 4], where each column corresponds to agent type and each row corresponds to unique url.
223 | Value is an index of page in source pages array. If agent is missing, value is -1
224 | """
225 | import re
226 | result = np.full([len(pages), 4], -1, dtype=np.int32)
227 | pat = re.compile(
228 | '(.+(?:(?:wikipedia\.org)|(?:commons\.wikimedia\.org)|(?:www\.mediawiki\.org)))_([a-z_-]+?)')
229 | prev_page = None
230 | num_page = -1
231 | agents = {'all-access_spider': 0, 'desktop_all-agents': 1, 'mobile-web_all-agents': 2, 'all-access_all-agents': 3}
232 | for i, entity in enumerate(pages):
233 | match = pat.fullmatch(entity)
234 | assert match
235 | page = match.group(1)
236 | agent = match.group(2)
237 | if page != prev_page:
238 | prev_page = page
239 | num_page += 1
240 | result[num_page, agents[agent]] = i
241 | return result[:num_page+1]
242 |
243 |
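A toy illustration of `uniq_page_map` on presorted URLs (hypothetical page names in the competition's format). Columns correspond to (all-access_spider, desktop_all-agents, mobile-web_all-agents, all-access_all-agents); values are row indices into the input, -1 where an agent variant is missing:

pages = [
    'Bar_en.wikipedia.org_all-access_spider',
    'Foo_en.wikipedia.org_all-access_all-agents',
    'Foo_en.wikipedia.org_desktop_all-agents',
]
print(uniq_page_map(pages))
# expected:
# [[ 0 -1 -1 -1]
#  [-1  2 -1  1]]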
244 | def encode_page_features(df) -> Dict[str, pd.DataFrame]:
245 | """
246 | Applies one-hot encoding to page features and normalises result
247 | :param df: page features DataFrame (one column per feature)
248 | :return: dictionary feature_name:encoded_values. Encoded values is [n_pages,n_values] array
249 | """
250 | def encode(column) -> pd.DataFrame:
251 | one_hot = pd.get_dummies(df[column], drop_first=False)
252 | # noinspection PyUnresolvedReferences
253 | return (one_hot - one_hot.mean()) / one_hot.std()
254 |
255 | return {str(column): encode(column) for column in df}
256 |
257 |
258 | def normalize(values: np.ndarray):
259 | return (values - values.mean()) / np.std(values)
260 |
261 |
262 | def run():
263 | parser = argparse.ArgumentParser(description='Prepare data')
264 | parser.add_argument('data_dir')
265 | parser.add_argument('--valid_threshold', default=0.0, type=float, help="Series minimal length threshold (pct of data length)")
266 |     parser.add_argument('--add_days', default=64, type=int, help="Add N days in the future for prediction")
267 | parser.add_argument('--start', help="Effective start date. Data before the start is dropped")
268 | parser.add_argument('--end', help="Effective end date. Data past the end is dropped")
269 | parser.add_argument('--corr_backoffset', default=0, type=int, help='Offset for correlation calculation')
270 | args = parser.parse_args()
271 |
272 | # Get the data
273 | df, nans, starts, ends = prepare_data(args.start, args.end, args.valid_threshold)
274 |
275 | # Our working date range
276 | data_start, data_end = df.columns[0], df.columns[-1]
277 |
278 | # We have to project some date-dependent features (day of week, etc) to the future dates for prediction
279 | features_end = data_end + pd.Timedelta(args.add_days, unit='D')
280 | print(f"start: {data_start}, end:{data_end}, features_end:{features_end}")
281 |
282 | # Group unique pages by agents
283 | assert df.index.is_monotonic_increasing
284 | page_map = uniq_page_map(df.index.values)
285 |
286 |     # Yearly (annual) autocorrelation
287 | raw_year_autocorr = batch_autocorr(df.values, 365, starts, ends, 1.5, args.corr_backoffset)
288 | year_unknown_pct = np.sum(np.isnan(raw_year_autocorr))/len(raw_year_autocorr) # type: float
289 |
290 | # Quarterly autocorrelation
291 | raw_quarter_autocorr = batch_autocorr(df.values, int(round(365.25/4)), starts, ends, 2, args.corr_backoffset)
292 | quarter_unknown_pct = np.sum(np.isnan(raw_quarter_autocorr)) / len(raw_quarter_autocorr) # type: float
293 |
294 | print("Percent of undefined autocorr = yearly:%.3f, quarterly:%.3f" % (year_unknown_pct, quarter_unknown_pct))
295 |
296 | # Normalise all the things
297 | year_autocorr = normalize(np.nan_to_num(raw_year_autocorr))
298 | quarter_autocorr = normalize(np.nan_to_num(raw_quarter_autocorr))
299 |
300 | # Calculate and encode page features
301 | page_features = make_page_features(df.index.values)
302 | encoded_page_features = encode_page_features(page_features)
303 |
304 | # Make time-dependent features
305 | features_days = pd.date_range(data_start, features_end)
306 | #dow = normalize(features_days.dayofweek.values)
307 | week_period = 7 / (2 * np.pi)
308 | dow_norm = features_days.dayofweek.values / week_period
309 | dow = np.stack([np.cos(dow_norm), np.sin(dow_norm)], axis=-1)
310 |
311 | # Assemble indices for quarterly lagged data
312 | lagged_ix = np.stack(lag_indexes(data_start, features_end), axis=-1)
313 |
314 | page_popularity = df.median(axis=1)
315 | page_popularity = (page_popularity - page_popularity.mean()) / page_popularity.std()
316 |
317 | # Put NaNs back
318 | df[nans] = np.NaN
319 |
320 | # Assemble final output
321 | tensors = dict(
322 | hits=df,
323 | lagged_ix=lagged_ix,
324 | page_map=page_map,
325 | page_ix=df.index.values,
326 | pf_agent=encoded_page_features['agent'],
327 | pf_country=encoded_page_features['country'],
328 | pf_site=encoded_page_features['site'],
329 | page_popularity=page_popularity,
330 | year_autocorr=year_autocorr,
331 | quarter_autocorr=quarter_autocorr,
332 | dow=dow,
333 | )
334 | plain = dict(
335 | features_days=len(features_days),
336 | data_days=len(df.columns),
337 | n_pages=len(df),
338 | data_start=data_start,
339 | data_end=data_end,
340 | features_end=features_end
341 |
342 | )
343 |
344 | # Store data to the disk
345 | VarFeeder(args.data_dir, tensors, plain)
346 |
347 |
348 | if __name__ == '__main__':
349 | run()
350 |
--------------------------------------------------------------------------------
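The day-of-week feature built in `run()` above is encoded as a point on the unit circle (cos, sin), so that Sunday and Monday end up adjacent instead of six steps apart. A standalone sketch of the same encoding:

import numpy as np
import pandas as pd

days = pd.date_range('2017-01-01', '2017-01-14')
week_period = 7 / (2 * np.pi)
angle = days.dayofweek.values / week_period
dow = np.stack([np.cos(angle), np.sin(angle)], axis=-1)  # shape [n_days, 2]
# Consecutive days are equally spaced on the circle, and day 6 (Sunday)
# sits next to day 0 (Monday).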
/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | import tensorflow.contrib.cudnn_rnn as cudnn_rnn
4 | import tensorflow.contrib.rnn as rnn
5 | import tensorflow.contrib.layers as layers
6 | from tensorflow.python.util import nest
7 |
8 | from input_pipe import InputPipe, ModelMode
9 |
10 | GRAD_CLIP_THRESHOLD = 10
11 | RNN = cudnn_rnn.CudnnGRU
12 | # RNN = tf.contrib.cudnn_rnn.CudnnLSTM
13 | # RNN = tf.contrib.cudnn_rnn.CudnnRNNRelu
14 |
15 |
16 | def default_init(seed):
17 | # replica of tf.glorot_uniform_initializer(seed=seed)
18 | return layers.variance_scaling_initializer(factor=1.0,
19 | mode="FAN_AVG",
20 | uniform=True,
21 | seed=seed)
22 |
23 |
24 | def selu(x):
25 | """
26 | SELU activation
27 | https://arxiv.org/abs/1706.02515
28 | :param x:
29 | :return:
30 | """
31 | with tf.name_scope('elu') as scope:
32 | alpha = 1.6732632423543772848170429916717
33 | scale = 1.0507009873554804934193349852946
34 | return scale * tf.where(x >= 0.0, x, alpha * tf.nn.elu(x))
35 |
36 |
37 | def make_encoder(time_inputs, encoder_features_depth, is_train, hparams, seed, transpose_output=True):
38 | """
39 | Builds encoder, using CUDA RNN
40 | :param time_inputs: Input tensor, shape [batch, time, features]
41 | :param encoder_features_depth: Static size for features dimension
42 | :param is_train:
43 | :param hparams:
44 | :param seed:
45 | :param transpose_output: Transform RNN output to batch-first shape
46 | :return:
47 | """
48 |
49 | def build_rnn():
50 | return RNN(num_layers=hparams.encoder_rnn_layers, num_units=hparams.rnn_depth,
51 | #input_size=encoder_features_depth,
52 | kernel_initializer=tf.initializers.random_uniform(minval=-0.05, maxval=0.05,
53 | seed=seed + 1 if seed else None),
54 | direction='unidirectional',
55 | dropout=hparams.encoder_dropout if is_train else 0, seed=seed)
56 |
57 | cuda_model = build_rnn()
58 |
59 | # [batch, time, features] -> [time, batch, features]
60 | time_first = tf.transpose(time_inputs, [1, 0, 2])
61 | rnn_time_input = time_first
62 | if RNN == tf.contrib.cudnn_rnn.CudnnLSTM:
63 | rnn_out, (rnn_state, c_state) = cuda_model(inputs=rnn_time_input)
64 | else:
65 | rnn_out, (rnn_state,) = cuda_model(inputs=rnn_time_input)
66 | c_state = None
67 | if transpose_output:
68 | rnn_out = tf.transpose(rnn_out, [1, 0, 2])
69 | return rnn_out, rnn_state, c_state
70 |
71 |
72 | def compressed_readout(rnn_out, hparams, dropout, seed):
73 | """
74 | FC compression layer, reduces RNN output depth to hparams.attention_depth
75 | :param rnn_out:
76 | :param hparams:
77 | :param dropout:
78 | :param seed:
79 | :return:
80 | """
81 | if dropout < 1.0:
82 | rnn_out = tf.nn.dropout(rnn_out, dropout, seed=seed)
83 | return tf.layers.dense(rnn_out, hparams.attention_depth,
84 | use_bias=True,
85 | activation=selu,
86 | kernel_initializer=layers.variance_scaling_initializer(factor=1.0, seed=seed),
87 | name='compress_readout'
88 | )
89 |
90 |
91 | def make_fingerprint(x, is_train, fc_dropout, seed):
92 | """
93 | Calculates 'fingerprint' of timeseries, to feed into attention layer
94 | :param x:
95 | :param is_train:
96 | :param fc_dropout:
97 | :param seed:
98 | :return:
99 | """
100 | with tf.variable_scope("fingerpint"):
101 | # x = tf.expand_dims(x, -1)
102 | with tf.variable_scope('convnet', initializer=layers.variance_scaling_initializer(seed=seed)):
103 | c11 = tf.layers.conv1d(x, filters=16, kernel_size=7, activation=tf.nn.relu, padding='same')
104 | c12 = tf.layers.conv1d(c11, filters=16, kernel_size=3, activation=tf.nn.relu, padding='same')
105 | pool1 = tf.layers.max_pooling1d(c12, 2, 2, padding='same')
106 | c21 = tf.layers.conv1d(pool1, filters=32, kernel_size=3, activation=tf.nn.relu, padding='same')
107 | c22 = tf.layers.conv1d(c21, filters=32, kernel_size=3, activation=tf.nn.relu, padding='same')
108 | pool2 = tf.layers.max_pooling1d(c22, 2, 2, padding='same')
109 | c31 = tf.layers.conv1d(pool2, filters=64, kernel_size=3, activation=tf.nn.relu, padding='same')
110 | c32 = tf.layers.conv1d(c31, filters=64, kernel_size=3, activation=tf.nn.relu, padding='same')
111 | pool3 = tf.layers.max_pooling1d(c32, 2, 2, padding='same')
112 | dims = pool3.shape.dims
113 | pool3 = tf.reshape(pool3, [-1, dims[1].value * dims[2].value])
114 | if is_train and fc_dropout < 1.0:
115 | cnn_out = tf.nn.dropout(pool3, fc_dropout, seed=seed)
116 | else:
117 | cnn_out = pool3
118 | with tf.variable_scope('fc_convnet',
119 | initializer=layers.variance_scaling_initializer(factor=1.0, mode='FAN_IN', seed=seed)):
120 | fc_encoder = tf.layers.dense(cnn_out, 512, activation=selu, name='fc_encoder')
121 | out_encoder = tf.layers.dense(fc_encoder, 16, activation=selu, name='out_encoder')
122 | return out_encoder
123 |
124 |
125 | def attn_readout_v3(readout, attn_window, attn_heads, page_features, seed):
126 | # input: [n_days, batch, readout_depth]
127 | # [n_days, batch, readout_depth] -> [batch(readout_depth), width=n_days, channels=batch]
128 | readout = tf.transpose(readout, [2, 0, 1])
129 | # [batch(readout_depth), width, channels] -> [batch, height=1, width, channels]
130 | inp = readout[:, tf.newaxis, :, :]
131 |
132 | # attn_window = train_window - predict_window + 1
133 | # [batch, attn_window * n_heads]
134 | filter_logits = tf.layers.dense(page_features, attn_window * attn_heads, name="attn_focus",
135 | kernel_initializer=default_init(seed)
136 | # kernel_initializer=layers.variance_scaling_initializer(uniform=True)
137 | # activation=selu,
138 | # kernel_initializer=layers.variance_scaling_initializer(factor=1.0, mode='FAN_IN')
139 | )
140 | # [batch, attn_window * n_heads] -> [batch, attn_window, n_heads]
141 | filter_logits = tf.reshape(filter_logits, [-1, attn_window, attn_heads])
142 |
143 | # attns_max = tf.nn.softmax(filter_logits, dim=1)
144 | attns_max = filter_logits / tf.reduce_sum(filter_logits, axis=1, keep_dims=True)
145 | # [batch, attn_window, n_heads] -> [width(attn_window), channels(batch), n_heads]
146 | attns_max = tf.transpose(attns_max, [1, 0, 2])
147 |
148 | # [width(attn_window), channels(batch), n_heads] -> [height(1), width(attn_window), channels(batch), multiplier(n_heads)]
149 | attn_filter = attns_max[tf.newaxis, :, :, :]
150 | # [batch(readout_depth), height=1, width=n_days, channels=batch] -> [batch(readout_depth), height=1, width=predict_window, channels=batch*n_heads]
151 | averaged = tf.nn.depthwise_conv2d_native(inp, attn_filter, [1, 1, 1, 1], 'VALID')
152 |     # [batch, height=1, width=predict_window, channels=readout_depth*n_heads] -> [batch(depth), predict_window, batch*n_heads]
153 | attn_features = tf.squeeze(averaged, 1)
154 | # [batch(depth), predict_window, batch*n_heads] -> [batch*n_heads, predict_window, depth]
155 | attn_features = tf.transpose(attn_features, [2, 1, 0])
156 | # [batch * n_heads, predict_window, depth] -> n_heads * [batch, predict_window, depth]
157 | heads = [attn_features[head_no::attn_heads] for head_no in range(attn_heads)]
158 | # n_heads * [batch, predict_window, depth] -> [batch, predict_window, depth*n_heads]
159 | result = tf.concat(heads, axis=-1)
160 | # attn_diag = tf.unstack(attns_max, axis=-1)
161 | return result, None
162 |
163 |
164 | def calc_smape_rounded(true, predicted, weights):
165 | """
166 |     Calculates SMAPE on rounded submission values. Should be close to the official SMAPE used in the competition
167 | :param true:
168 | :param predicted:
169 | :param weights: Weights mask to exclude some values
170 | :return:
171 | """
172 | n_valid = tf.reduce_sum(weights)
173 | true_o = tf.round(tf.expm1(true))
174 | pred_o = tf.maximum(tf.round(tf.expm1(predicted)), 0.0)
175 | summ = tf.abs(true_o) + tf.abs(pred_o)
176 | zeros = summ < 0.01
177 | raw_smape = tf.abs(pred_o - true_o) / summ * 2.0
178 | smape = tf.where(zeros, tf.zeros_like(summ, dtype=tf.float32), raw_smape)
179 | return tf.reduce_sum(smape * weights) / n_valid
180 |
181 |
182 | def smape_loss(true, predicted, weights):
183 | """
184 | Differentiable SMAPE loss
185 | :param true: Truth values
186 | :param predicted: Predicted values
187 | :param weights: Weights mask to exclude some values
188 | :return:
189 | """
190 | epsilon = 0.1 # Smoothing factor, helps SMAPE to be well-behaved near zero
191 | true_o = tf.expm1(true)
192 | pred_o = tf.expm1(predicted)
193 | summ = tf.maximum(tf.abs(true_o) + tf.abs(pred_o) + epsilon, 0.5 + epsilon)
194 | smape = tf.abs(pred_o - true_o) / summ * 2.0
195 | return tf.losses.compute_weighted_loss(smape, weights, loss_collection=None)
196 |
197 |
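For intuition, the rounded metric and the training loss above implement the same symmetric error; the training version only adds a floor to the denominator so gradients stay bounded near zero traffic. A NumPy sketch ignoring the log1p/expm1 bookkeeping and the weight mask (illustrative only):

import numpy as np

def smape(true, pred):
    summ = np.abs(true) + np.abs(pred)
    raw = 2.0 * np.abs(pred - true) / np.maximum(summ, 0.01)
    # a term counts as 0 when both true and predicted values are (close to) 0
    return np.mean(np.where(summ < 0.01, 0.0, raw))

def smape_smooth(true, pred, epsilon=0.1):
    # flooring the denominator keeps gradients well-behaved near zero
    summ = np.maximum(np.abs(true) + np.abs(pred) + epsilon, 0.5 + epsilon)
    return np.mean(2.0 * np.abs(pred - true) / summ)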
198 | def decode_predictions(decoder_readout, inp: InputPipe):
199 | """
200 |     Converts normalized prediction values to log1p(pageviews), i.e. reverts normalization
201 | :param decoder_readout: Decoder output, shape [n_days, batch]
202 | :param inp: Input tensors
203 | :return:
204 | """
205 | # [n_days, batch] -> [batch, n_days]
206 | batch_readout = tf.transpose(decoder_readout)
207 | batch_std = tf.expand_dims(inp.norm_std, -1)
208 | batch_mean = tf.expand_dims(inp.norm_mean, -1)
209 | return batch_readout * batch_std + batch_mean
210 |
211 |
212 | def calc_loss(predictions, true_y, additional_mask=None):
213 | """
214 | Calculates losses, ignoring NaN true values (assigning zero loss to them)
215 | :param predictions: Predicted values
216 | :param true_y: True values
217 | :param additional_mask:
218 | :return: MAE loss, differentiable SMAPE loss, competition SMAPE loss
219 | """
220 | # Take into account NaN's in true values
221 | mask = tf.is_finite(true_y)
222 | # Fill NaNs by zeros (can use any value)
223 | true_y = tf.where(mask, true_y, tf.zeros_like(true_y))
224 | # Assign zero weight to NaNs
225 | weights = tf.to_float(mask)
226 | if additional_mask is not None:
227 | weights = weights * tf.expand_dims(additional_mask, axis=0)
228 |
229 | mae_loss = tf.losses.absolute_difference(labels=true_y, predictions=predictions, weights=weights)
230 | return mae_loss, smape_loss(true_y, predictions, weights), calc_smape_rounded(true_y, predictions,
231 | weights), tf.size(true_y)
232 |
233 |
234 | def make_train_op(loss, ema_decay=None, prefix=None):
235 | optimizer = tf.train.AdamOptimizer()
236 | glob_step = tf.train.get_global_step()
237 |
238 | # Add regularization losses
239 | reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
240 |     total_loss = loss + tf.add_n(reg_losses) if reg_losses else loss
241 |
242 | # Clip gradients
243 | grads_and_vars = optimizer.compute_gradients(total_loss)
244 | gradients, variables = zip(*grads_and_vars)
245 | clipped_gradients, glob_norm = tf.clip_by_global_norm(gradients, GRAD_CLIP_THRESHOLD)
246 | sgd_op, glob_norm = optimizer.apply_gradients(zip(clipped_gradients, variables)), glob_norm
247 |
248 | # Apply SGD averaging
249 | if ema_decay:
250 | ema = tf.train.ExponentialMovingAverage(decay=ema_decay, num_updates=glob_step)
251 | if prefix:
252 | # Some magic to handle multiple models trained in single graph
253 | ema_vars = [var for var in variables if var.name.startswith(prefix)]
254 | else:
255 | ema_vars = variables
256 | update_ema = ema.apply(ema_vars)
257 | with tf.control_dependencies([sgd_op]):
258 | training_op = tf.group(update_ema)
259 | else:
260 | training_op = sgd_op
261 | ema = None
262 | return training_op, glob_norm, ema
263 |
264 |
265 | def convert_cudnn_state_v2(h_state, hparams, seed, c_state=None, dropout=1.0):
266 | """
267 | Converts RNN state tensor from cuDNN representation to TF RNNCell compatible representation.
268 | :param h_state: tensor [num_layers, batch_size, depth]
269 | :param c_state: LSTM additional state, should be same shape as h_state
270 | :return: TF cell representation matching RNNCell.state_size structure for compatible cell
271 | """
272 |
273 | def squeeze(seq):
274 | return tuple(seq) if len(seq) > 1 else seq[0]
275 |
276 | def wrap_dropout(structure):
277 | if dropout < 1.0:
278 | return nest.map_structure(lambda x: tf.nn.dropout(x, keep_prob=dropout, seed=seed), structure)
279 | else:
280 | return structure
281 |
282 | # Cases:
283 | # decoder_layer = encoder_layers, straight mapping
284 | # encoder_layers > decoder_layers: get outputs of upper encoder layers
285 | # encoder_layers < decoder_layers: feed encoder outputs to lower decoder layers, feed zeros to top layers
286 | h_layers = tf.unstack(h_state)
287 | if hparams.encoder_rnn_layers >= hparams.decoder_rnn_layers:
288 | return squeeze(wrap_dropout(h_layers[hparams.encoder_rnn_layers - hparams.decoder_rnn_layers:]))
289 | else:
290 | lower_inputs = wrap_dropout(h_layers)
291 | upper_inputs = [tf.zeros_like(h_layers[0]) for _ in
292 | range(hparams.decoder_rnn_layers - hparams.encoder_rnn_layers)]
293 | return squeeze(lower_inputs + upper_inputs)
294 |
295 |
296 | def rnn_stability_loss(rnn_output, beta):
297 | """
298 | REGULARIZING RNNS BY STABILIZING ACTIVATIONS
299 | https://arxiv.org/pdf/1511.08400.pdf
300 | :param rnn_output: [time, batch, features]
301 | :return: loss value
302 | """
303 | if beta == 0.0:
304 | return 0.0
305 | # [time, batch, features] -> [time, batch]
306 | l2 = tf.sqrt(tf.reduce_sum(tf.square(rnn_output), axis=-1))
307 | # [time, batch] -> []
308 | return beta * tf.reduce_mean(tf.square(l2[1:] - l2[:-1]))
309 |
310 |
311 | def rnn_activation_loss(rnn_output, beta):
312 | """
313 | REGULARIZING RNNS BY STABILIZING ACTIVATIONS
314 | https://arxiv.org/pdf/1511.08400.pdf
315 | :param rnn_output: [time, batch, features]
316 | :return: loss value
317 | """
318 | if beta == 0.0:
319 | return 0.0
320 | return tf.nn.l2_loss(rnn_output) * beta
321 |
322 |
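Both regularizers above follow "Regularizing RNNs by Stabilizing Activations": the stability term penalizes step-to-step changes in the L2 norm of the hidden activations, and the activation term is a plain L2 penalty. An equivalent NumPy expression for the stability term (illustrative):

import numpy as np

def stability_penalty(rnn_output, beta):
    # rnn_output: [time, batch, features]
    l2 = np.sqrt((rnn_output ** 2).sum(axis=-1))  # [time, batch]
    return beta * np.mean((l2[1:] - l2[:-1]) ** 2)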
323 | class Model:
324 | def __init__(self, inp: InputPipe, hparams, is_train, seed, graph_prefix=None, asgd_decay=None, loss_mask=None):
325 | """
326 | Encoder-decoder prediction model
327 | :param inp: Input tensors
328 | :param hparams:
329 | :param is_train:
330 | :param seed:
331 | :param graph_prefix: Subgraph prefix for multi-model graph
332 | :param asgd_decay: Decay for SGD averaging
333 | :param loss_mask: Additional mask for losses calculation (one value for each prediction day), shape=[predict_window]
334 | """
335 | self.is_train = is_train
336 | self.inp = inp
337 | self.hparams = hparams
338 | self.seed = seed
339 | self.inp = inp
340 |
341 | encoder_output, h_state, c_state = make_encoder(inp.time_x, inp.encoder_features_depth, is_train, hparams, seed,
342 | transpose_output=False)
343 | # Encoder activation losses
344 | enc_stab_loss = rnn_stability_loss(encoder_output, hparams.encoder_stability_loss / inp.train_window)
345 | enc_activation_loss = rnn_activation_loss(encoder_output, hparams.encoder_activation_loss / inp.train_window)
346 |
347 | # Convert state from cuDNN representation to TF RNNCell-compatible representation
348 |         encoder_state = convert_cudnn_state_v2(h_state, hparams, seed, c_state,
349 | dropout=hparams.gate_dropout if is_train else 1.0)
350 |
351 | # Attention calculations
352 | # Compress encoder outputs
353 | enc_readout = compressed_readout(encoder_output, hparams,
354 | dropout=hparams.encoder_readout_dropout if is_train else 1.0, seed=seed)
355 | # Calculate fingerprint from input features
356 | fingerprint_inp = tf.concat([inp.lagged_x, tf.expand_dims(inp.norm_x, -1)], axis=-1)
357 | fingerprint = make_fingerprint(fingerprint_inp, is_train, hparams.fingerprint_fc_dropout, seed)
358 | # Calculate attention vector
359 | attn_features, attn_weights = attn_readout_v3(enc_readout, inp.attn_window, hparams.attention_heads,
360 | fingerprint, seed=seed)
361 |
362 | # Run decoder
363 | decoder_targets, decoder_outputs = self.decoder(encoder_state,
364 | attn_features if hparams.use_attn else None,
365 | inp.time_y, inp.norm_x[:, -1])
366 | # Decoder activation losses
367 | dec_stab_loss = rnn_stability_loss(decoder_outputs, hparams.decoder_stability_loss / inp.predict_window)
368 | dec_activation_loss = rnn_activation_loss(decoder_outputs, hparams.decoder_activation_loss / inp.predict_window)
369 |
370 | # Get final denormalized predictions
371 | self.predictions = decode_predictions(decoder_targets, inp)
372 |
373 | # Calculate losses and build training op
374 | if inp.mode == ModelMode.PREDICT:
375 | # Pseudo-apply ema to get variable names later in ema.variables_to_restore()
376 | # This is copypaste from make_train_op()
377 | if asgd_decay:
378 | self.ema = tf.train.ExponentialMovingAverage(decay=asgd_decay)
379 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
380 | if graph_prefix:
381 | ema_vars = [var for var in variables if var.name.startswith(graph_prefix)]
382 | else:
383 | ema_vars = variables
384 | self.ema.apply(ema_vars)
385 | else:
386 | self.mae, smape_loss, self.smape, self.loss_item_count = calc_loss(self.predictions, inp.true_y,
387 | additional_mask=loss_mask)
388 | if is_train:
389 | # Sum all losses
390 | total_loss = smape_loss + enc_stab_loss + dec_stab_loss + enc_activation_loss + dec_activation_loss
391 | self.train_op, self.glob_norm, self.ema = make_train_op(total_loss, asgd_decay, prefix=graph_prefix)
392 |
393 |
394 |
395 | def default_init(self, seed_add=0):
396 | return default_init(self.seed + seed_add)
397 |
398 | def decoder(self, encoder_state, attn_features, prediction_inputs, previous_y):
399 | """
400 | :param encoder_state: shape [batch_size, encoder_rnn_depth]
401 | :param prediction_inputs: features for prediction days, tensor[batch_size, time, input_depth]
402 | :param previous_y: Last day pageviews, shape [batch_size]
403 | :param attn_features: Additional features from attention layer, shape [batch, predict_window, readout_depth*n_heads]
404 | :return: decoder rnn output
405 | """
406 | hparams = self.hparams
407 |
408 | def build_cell(idx):
409 | with tf.variable_scope('decoder_cell', initializer=self.default_init(idx)):
410 | cell = rnn.GRUBlockCell(self.hparams.rnn_depth)
411 | has_dropout = hparams.decoder_input_dropout[idx] < 1 \
412 | or hparams.decoder_state_dropout[idx] < 1 or hparams.decoder_output_dropout[idx] < 1
413 |
414 | if self.is_train and has_dropout:
415 | attn_depth = attn_features.shape[-1].value if attn_features is not None else 0
416 | input_size = attn_depth + prediction_inputs.shape[-1].value + 1 if idx == 0 else self.hparams.rnn_depth
417 | cell = rnn.DropoutWrapper(cell, dtype=tf.float32, input_size=input_size,
418 | variational_recurrent=hparams.decoder_variational_dropout[idx],
419 | input_keep_prob=hparams.decoder_input_dropout[idx],
420 | output_keep_prob=hparams.decoder_output_dropout[idx],
421 | state_keep_prob=hparams.decoder_state_dropout[idx], seed=self.seed + idx)
422 | return cell
423 |
424 | if hparams.decoder_rnn_layers > 1:
425 | cells = [build_cell(idx) for idx in range(hparams.decoder_rnn_layers)]
426 | cell = rnn.MultiRNNCell(cells)
427 | else:
428 | cell = build_cell(0)
429 |
430 | nest.assert_same_structure(encoder_state, cell.state_size)
431 | predict_days = self.inp.predict_window
432 | assert prediction_inputs.shape[1] == predict_days
433 |
434 | # [batch_size, time, input_depth] -> [time, batch_size, input_depth]
435 | inputs_by_time = tf.transpose(prediction_inputs, [1, 0, 2])
436 |
437 | # Return raw outputs for RNN losses calculation
438 | return_raw_outputs = self.hparams.decoder_stability_loss > 0.0 or self.hparams.decoder_activation_loss > 0.0
439 |
440 | # Stop condition for decoding loop
441 | def cond_fn(time, prev_output, prev_state, array_targets: tf.TensorArray, array_outputs: tf.TensorArray):
442 | return time < predict_days
443 |
444 | # FC projecting layer to get single predicted value from RNN output
445 | def project_output(tensor):
446 | return tf.layers.dense(tensor, 1, name='decoder_output_proj', kernel_initializer=self.default_init())
447 |
448 | def loop_fn(time, prev_output, prev_state, array_targets: tf.TensorArray, array_outputs: tf.TensorArray):
449 | """
450 | Main decoder loop
451 | :param time: Day number
452 | :param prev_output: Output(prediction) from previous step
453 | :param prev_state: RNN state tensor from previous step
454 | :param array_targets: Predictions, each step will append new value to this array
455 | :param array_outputs: Raw RNN outputs (for regularization losses)
456 | :return:
457 | """
458 | # RNN inputs for current step
459 | features = inputs_by_time[time]
460 |
461 | # [batch, predict_window, readout_depth * n_heads] -> [batch, readout_depth * n_heads]
462 | if attn_features is not None:
463 | # [batch_size, 1] + [batch_size, input_depth]
464 | attn = attn_features[:, time, :]
465 | # Append previous predicted value + attention vector to input features
466 | next_input = tf.concat([prev_output, features, attn], axis=1)
467 | else:
468 | # Append previous predicted value to input features
469 | next_input = tf.concat([prev_output, features], axis=1)
470 |
471 | # Run RNN cell
472 | output, state = cell(next_input, prev_state)
473 | # Make prediction from RNN outputs
474 | projected_output = project_output(output)
475 | # Append step results to the buffer arrays
476 | if return_raw_outputs:
477 | array_outputs = array_outputs.write(time, output)
478 | array_targets = array_targets.write(time, projected_output)
479 | # Increment time and return
480 | return time + 1, projected_output, state, array_targets, array_outputs
481 |
482 | # Initial values for loop
483 | loop_init = [tf.constant(0, dtype=tf.int32),
484 | tf.expand_dims(previous_y, -1),
485 | encoder_state,
486 | tf.TensorArray(dtype=tf.float32, size=predict_days),
487 | tf.TensorArray(dtype=tf.float32, size=predict_days) if return_raw_outputs else tf.constant(0)]
488 | # Run the loop
489 | _, _, _, targets_ta, outputs_ta = tf.while_loop(cond_fn, loop_fn, loop_init)
490 |
491 | # Get final tensors from buffer arrays
492 | targets = targets_ta.stack()
493 | # [time, batch_size, 1] -> [time, batch_size]
494 | targets = tf.squeeze(targets, axis=-1)
495 | raw_outputs = outputs_ta.stack() if return_raw_outputs else None
496 | return targets, raw_outputs
497 |
--------------------------------------------------------------------------------
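The decoder above is autoregressive: at every step it concatenates the previous prediction with the current day's features (and the attention vector, when enabled), runs one GRU step, and projects the output to a single value that is fed back on the next step. A framework-free sketch of that loop, with hypothetical cell_step/project callables standing in for the GRU cell and the projection layer:

import numpy as np

def decode(cell_step, project, encoder_state, features_by_day, last_known_y):
    # cell_step(x, state) -> (output, new_state); project(output) -> scalar
    state = encoder_state
    prev_y = last_known_y
    predictions = []
    for day_features in features_by_day:      # predict_window iterations
        x = np.concatenate([[prev_y], day_features])
        output, state = cell_step(x, state)
        prev_y = project(output)              # prediction is fed back next step
        predictions.append(prev_y)
    return np.array(predictions)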
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | pandas
3 | tqdm
4 | matplotlib
5 | tensorflow-gpu>=1.10
6 |
7 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import shutil
3 | import sys
4 | import numpy as np
5 | import tensorflow as tf
6 | from tqdm import trange
7 | from typing import List, Tuple
8 | import heapq
9 | import logging
10 | import pandas as pd
11 | from enum import Enum
12 |
13 | from hparams import build_from_set, build_hparams
14 | from feeder import VarFeeder
15 | from input_pipe import InputPipe, ModelMode, Splitter, FakeSplitter, page_features
16 | from model import Model
17 | import argparse
18 |
19 |
20 | log = logging.getLogger('trainer')
21 |
22 | class Ema:
23 | def __init__(self, k=0.99):
24 | self.k = k
25 | self.state = None
26 | self.steps = 0
27 |
28 | def __call__(self, *args, **kwargs):
29 | v = args[0]
30 | self.steps += 1
31 | if self.state is None:
32 | self.state = v
33 | else:
34 | eff_k = min(1 - 1 / self.steps, self.k)
35 | self.state = eff_k * self.state + (1 - eff_k) * v
36 | return self.state
37 |
38 |
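`Ema` above is a warm-started exponential moving average used to smooth noisy per-step metrics. A quick usage example (values are illustrative):

smooth = Ema(k=0.9)
for raw in [10.0, 12.0, 8.0, 11.0]:
    print(smooth(raw))
# The first call returns the raw value; early calls use an effective k of
# 1 - 1/steps, so the average warms up quickly before settling at k=0.9.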
39 | class Metric:
40 | def __init__(self, name: str, op, smoothness: float = None):
41 | self.name = name
42 | self.op = op
43 | self.smoother = Ema(smoothness) if smoothness else None
44 | self.epoch_values = []
45 | self.best_value = np.Inf
46 | self.best_step = 0
47 | self.last_epoch = -1
48 | self.improved = False
49 | self._top = []
50 |
51 | @property
52 | def avg_epoch(self):
53 | return np.mean(self.epoch_values)
54 |
55 | @property
56 | def best_epoch(self):
57 | return np.min(self.epoch_values)
58 |
59 | @property
60 | def last(self):
61 | return self.epoch_values[-1] if self.epoch_values else np.nan
62 |
63 | @property
64 | def top(self):
65 | return -np.mean(self._top)
66 |
67 |
68 | def update(self, value, epoch, step):
69 | if self.smoother:
70 | value = self.smoother(value)
71 | if epoch > self.last_epoch:
72 | self.epoch_values = []
73 | self.last_epoch = epoch
74 | self.epoch_values.append(value)
75 | if value < self.best_value:
76 | self.best_value = value
77 | self.best_step = step
78 | self.improved = True
79 | else:
80 | self.improved = False
81 | if len(self._top) >= 5:
82 | heapq.heappushpop(self._top, -value)
83 | else:
84 | heapq.heappush(self._top, -value)
85 |
86 |
87 | class AggMetric:
88 | def __init__(self, metrics: List[Metric]):
89 | self.metrics = metrics
90 |
91 | def _mean(self, fun) -> float:
92 | # noinspection PyTypeChecker
93 | return np.mean([fun(metric) for metric in self.metrics])
94 |
95 | @property
96 | def avg_epoch(self):
97 | return self._mean(lambda m: m.avg_epoch)
98 |
99 | @property
100 | def best_epoch(self):
101 | return self._mean(lambda m: m.best_epoch)
102 |
103 | @property
104 | def last(self):
105 | return self._mean(lambda m: m.last)
106 |
107 | @property
108 | def top(self):
109 | return self._mean(lambda m: m.top)
110 |
111 | @property
112 | def improved(self):
113 | return np.any([metric.improved for metric in self.metrics])
114 |
115 |
116 | class DummyMetric:
117 | @property
118 | def avg_epoch(self):
119 | return np.nan
120 |
121 | @property
122 | def best_epoch(self):
123 | return np.nan
124 |
125 | @property
126 | def last(self):
127 | return np.nan
128 |
129 | @property
130 | def top(self):
131 | return np.nan
132 |
133 | @property
134 | def improved(self):
135 | return False
136 |
137 | @property
138 | def metrics(self):
139 | return []
140 |
141 |
142 | class Stage(Enum):
143 | TRAIN = 0
144 | EVAL_SIDE = 1
145 | EVAL_FRWD = 2
146 | EVAL_SIDE_EMA = 3
147 | EVAL_FRWD_EMA = 4
148 |
149 |
150 | class ModelTrainerV2:
151 | def __init__(self, train_model: Model, eval: List[Tuple[Stage, Model]], model_no=0,
152 | patience=None, stop_metric=None, summary_writer=None):
153 | self.train_model = train_model
154 | if eval:
155 | self.eval_stages, self.eval_models = zip(*eval)
156 | else:
157 | self.eval_stages, self.eval_models = [], []
158 | self.stopped = False
159 | self.model_no = model_no
160 | self.patience = patience
161 | self.best_metric = np.inf
162 | self.bad_epochs = 0
163 | self.stop_metric = stop_metric
164 | self.summary_writer = summary_writer
165 |
166 | def std_metrics(model: Model, smoothness):
167 | return [Metric('SMAPE', model.smape, smoothness), Metric('MAE', model.mae, smoothness)]
168 |
169 | self._metrics = {Stage.TRAIN: std_metrics(train_model, 0.9) + [Metric('GrNorm', train_model.glob_norm)]}
170 | for stage, model in eval:
171 | self._metrics[stage] = std_metrics(model, None)
172 | self.dict_metrics = {key: {metric.name: metric for metric in metrics} for key, metrics in self._metrics.items()}
173 |
174 | def init(self, sess):
175 | for model in list(self.eval_models) + [self.train_model]:
176 | model.inp.init_iterator(sess)
177 |
178 | @property
179 | def metrics(self):
180 | return self._metrics
181 |
182 | @property
183 | def train_ops(self):
184 | model = self.train_model
185 | return [model.train_op] # , model.summaries
186 |
187 | def metric_ops(self, key):
188 | return [metric.op for metric in self._metrics[key]]
189 |
190 | def process_metrics(self, key, run_results, epoch, step):
191 | metrics = self._metrics[key]
192 | summaries = []
193 | for result, metric in zip(run_results, metrics):
194 | metric.update(result, epoch, step)
195 | summaries.append(tf.Summary.Value(tag=f"{key.name}/{metric.name}_0", simple_value=result))
196 | return summaries
197 |
198 | def end_epoch(self):
199 | if self.stop_metric:
200 | best_metric = self.stop_metric(self.dict_metrics)# self.dict_metrics[Stage.EVAL_FRWD]['SMAPE'].avg_epoch
201 | if self.best_metric > best_metric:
202 | self.best_metric = best_metric
203 | self.bad_epochs = 0
204 | else:
205 | self.bad_epochs += 1
206 | if self.bad_epochs > self.patience:
207 | self.stopped = True
208 |
209 |
210 | class MultiModelTrainer:
211 | def __init__(self, trainers: List[ModelTrainerV2], inc_step_op,
212 | misc_global_ops=None):
213 | self.trainers = trainers
214 | self.inc_step = inc_step_op
215 | self.global_ops = misc_global_ops or []
216 | self.eval_stages = trainers[0].eval_stages
217 |
218 | def active(self):
219 | return [trainer for trainer in self.trainers if not trainer.stopped]
220 |
221 | def _metric_step(self, stage, initial_ops, sess: tf.Session, epoch: int, step=None, repeats=1, summary_every=1):
222 | ops = initial_ops
223 | offsets, lengths = [], []
224 | trainers = self.active()
225 | for trainer in trainers:
226 | offsets.append(len(ops))
227 | metric_ops = trainer.metric_ops(stage)
228 | lengths.append(len(metric_ops))
229 | ops.extend(metric_ops)
230 | if repeats > 1:
231 | all_results = np.stack([np.array(sess.run(ops)) for _ in range(repeats)])
232 | results = np.mean(all_results, axis=0)
233 | else:
234 | results = sess.run(ops)
235 | if step is None:
236 | step = results[0]
237 |
238 | for trainer, offset, length in zip(trainers, offsets, lengths):
239 | chunk = results[offset: offset + length]
240 | summaries = trainer.process_metrics(stage, chunk, epoch, step)
241 | if trainer.summary_writer and step > 200 and (step % summary_every == 0):
242 | summary = tf.Summary(value=summaries)
243 | trainer.summary_writer.add_summary(summary, global_step=step)
244 | return results
245 |
246 | def train_step(self, sess: tf.Session, epoch: int):
247 | ops = [self.inc_step] + self.global_ops
248 | for trainer in self.active():
249 | ops.extend(trainer.train_ops)
250 | results = self._metric_step(Stage.TRAIN, ops, sess, epoch, summary_every=20)
251 | #return results[:len(self.global_ops) + 1] # step, grad_norm
252 | return results[0]
253 |
254 | def eval_step(self, sess: tf.Session, epoch: int, step, n_batches, stages:List[Stage]=None):
255 | target_stages = stages if stages is not None else self.eval_stages
256 | for stage in target_stages:
257 | self._metric_step(stage, [], sess, epoch, step, repeats=n_batches)
258 |
259 | def metric(self, stage, name):
260 | return AggMetric([trainer.dict_metrics[stage][name] for trainer in self.trainers])
261 |
262 | def end_epoch(self):
263 | for trainer in self.active():
264 | trainer.end_epoch()
265 |
266 | def has_active(self):
267 | return len(self.active())
268 |
269 |
270 | class ModelTrainer:
271 | def __init__(self, train_model, eval_model, model_no=0, summary_writer=None, keep_best=5, patience=None):
272 | self.train_model = train_model
273 | self.eval_model = eval_model
274 | self.stopped = False
275 | self.smooth_train_mae = Ema()
276 | self.smooth_train_smape = Ema()
277 | self.smooth_eval_mae = Ema(0.5)
278 | self.smooth_eval_smape = Ema(0.5)
279 | self.smooth_grad = Ema(0.9)
280 | self.summary_writer = summary_writer
281 | self.model_no = model_no
282 | self.best_top_n_loss = []
283 | self.keep_best = keep_best
284 | self.best_step = 0
285 | self.patience = patience
286 | self.train_pipe = train_model.inp
287 | self.eval_pipe = eval_model.inp
288 | self.epoch_mae = []
289 | self.epoch_smape = []
290 | self.last_epoch = -1
291 |
292 | @property
293 | def train_ops(self):
294 | model = self.train_model
295 | return [model.train_op, model.update_ema, model.summaries, model.mae, model.smape, model.glob_norm]
296 |
297 | def process_train_results(self, run_results, offset, global_step, write_summary):
298 | offset += 2
299 | summaries, mae, smape, glob_norm = run_results[offset:offset + 4]
300 | results = self.smooth_train_mae(mae), self.smooth_train_smape(smape), self.smooth_grad(glob_norm)
301 | if self.summary_writer and write_summary:
302 | self.summary_writer.add_summary(summaries, global_step=global_step)
303 | return np.array(results)
304 |
305 | @property
306 | def eval_ops(self):
307 | model = self.eval_model
308 | return [model.mae, model.smape]
309 |
310 | @property
311 | def eval_len(self):
312 | return len(self.eval_ops)
313 |
314 | @property
315 | def train_len(self):
316 | return len(self.train_ops)
317 |
318 | @property
319 | def best_top_loss(self):
320 | return -np.array(self.best_top_n_loss).mean()
321 |
322 | @property
323 | def best_epoch_mae(self):
324 | return min(self.epoch_mae) if self.epoch_mae else np.NaN
325 |
326 | @property
327 | def mean_epoch_mae(self):
328 | return np.mean(self.epoch_mae) if self.epoch_mae else np.NaN
329 |
330 | @property
331 | def mean_epoch_smape(self):
332 | return np.mean(self.epoch_smape) if self.epoch_smape else np.NaN
333 |
334 | @property
335 | def best_epoch_smape(self):
336 | return min(self.epoch_smape) if self.epoch_smape else np.NaN
337 |
338 | def remember_for_epoch(self, epoch, mae, smape):
339 | if epoch > self.last_epoch:
340 | self.last_epoch = epoch
341 | self.epoch_mae = []
342 | self.epoch_smape = []
343 | self.epoch_mae.append(mae)
344 | self.epoch_smape.append(smape)
345 |
346 | @property
347 | def best_epoch_metrics(self):
348 | return np.array([self.best_epoch_mae, self.best_epoch_smape])
349 |
350 | @property
351 | def mean_epoch_metrics(self):
352 | return np.array([self.mean_epoch_mae, self.mean_epoch_smape])
353 |
354 | def process_eval_results(self, run_results, offset, global_step, epoch):
355 |         totals = np.zeros(self.eval_len, np.float64)
356 | for result in run_results:
357 | items = np.array(result[offset:offset + self.eval_len])
358 | totals += items
359 | results = totals / len(run_results)
360 | mae, smape = results
361 | if self.summary_writer and global_step > 200:
362 | summary = tf.Summary(value=[
363 | tf.Summary.Value(tag=f"test/MAE_{self.model_no}", simple_value=mae),
364 | tf.Summary.Value(tag=f"test/SMAPE_{self.model_no}", simple_value=smape),
365 | ])
366 | self.summary_writer.add_summary(summary, global_step=global_step)
367 | smooth_mae = self.smooth_eval_mae(mae)
368 | smooth_smape = self.smooth_eval_smape(smape)
369 | self.remember_for_epoch(epoch, mae, smape)
370 |
371 | current_loss = -smooth_smape
372 |
373 | prev_best_n = np.mean(self.best_top_n_loss) if self.best_top_n_loss else -np.inf
374 | if self.best_top_n_loss:
375 | log.debug("Current loss=%.3f, old best=%.3f, wait steps=%d", -current_loss,
376 | -max(self.best_top_n_loss), global_step - self.best_step)
377 |
378 | if len(self.best_top_n_loss) >= self.keep_best:
379 | heapq.heappushpop(self.best_top_n_loss, current_loss)
380 | else:
381 | heapq.heappush(self.best_top_n_loss, current_loss)
382 | log.debug("Best loss=%.3f, top_5 avg loss=%.3f, top_5=%s",
383 | -max(self.best_top_n_loss), -np.mean(self.best_top_n_loss),
384 | ",".join(["%.3f" % -mae for mae in self.best_top_n_loss]))
385 | new_best_n = np.mean(self.best_top_n_loss)
386 |
387 | new_best = new_best_n > prev_best_n
388 | if new_best:
389 | self.best_step = global_step
390 | log.debug("New best step %d, current loss=%.3f", global_step, -current_loss)
391 | else:
392 | step_count = global_step - self.best_step
393 | if step_count > self.patience:
394 | self.stopped = True
395 |
396 | return mae, smape, new_best, smooth_mae, smooth_smape
397 |
398 |
399 | def train(name, hparams, multi_gpu=False, n_models=1, train_completeness_threshold=0.01,
400 | seed=None, logdir='data/logs', max_epoch=100, patience=2, train_sampling=1.0,
401 | eval_sampling=1.0, eval_memsize=5, gpu=0, gpu_allow_growth=False, save_best_model=False,
402 | forward_split=False, write_summaries=False, verbose=False, asgd_decay=None, tqdm=True,
403 | side_split=True, max_steps=None, save_from_step=None, do_eval=True, predict_window=63):
404 |
405 | eval_k = int(round(26214 * eval_memsize / n_models))
406 | eval_batch_size = int(
407 | eval_k / (hparams.rnn_depth * hparams.encoder_rnn_layers)) # 128 -> 1024, 256->512, 512->256
408 | eval_pct = 0.1
409 | batch_size = hparams.batch_size
410 | train_window = hparams.train_window
411 | tf.reset_default_graph()
412 | if seed:
413 | tf.set_random_seed(seed)
414 |
415 | with tf.device("/cpu:0"):
416 | inp = VarFeeder.read_vars("data/vars")
417 | if side_split:
418 | splitter = Splitter(page_features(inp), inp.page_map, 3, train_sampling=train_sampling,
419 | test_sampling=eval_sampling, seed=seed)
420 | else:
421 | splitter = FakeSplitter(page_features(inp), 3, seed=seed, test_sampling=eval_sampling)
422 |
423 | real_train_pages = splitter.splits[0].train_size
424 | real_eval_pages = splitter.splits[0].test_size
425 |
426 | items_per_eval = real_eval_pages * eval_pct
427 | eval_batches = int(np.ceil(items_per_eval / eval_batch_size))
428 | steps_per_epoch = real_train_pages // batch_size
429 | eval_every_step = int(round(steps_per_epoch * eval_pct))
430 | # eval_every_step = int(round(items_per_eval * train_sampling / batch_size))
431 |
432 | global_step = tf.train.get_or_create_global_step()
433 | inc_step = tf.assign_add(global_step, 1)
434 |
435 |
436 | all_models: List[ModelTrainerV2] = []
437 |
438 | def create_model(scope, index, prefix, seed):
439 |
440 | with tf.variable_scope('input') as inp_scope:
441 | with tf.device("/cpu:0"):
442 | split = splitter.splits[index]
443 | pipe = InputPipe(inp, features=split.train_set, n_pages=split.train_size,
444 | mode=ModelMode.TRAIN, batch_size=batch_size, n_epoch=None, verbose=verbose,
445 | train_completeness_threshold=train_completeness_threshold,
446 | predict_completeness_threshold=train_completeness_threshold, train_window=train_window,
447 | predict_window=predict_window,
448 | rand_seed=seed, train_skip_first=hparams.train_skip_first,
449 | back_offset=predict_window if forward_split else 0)
450 | inp_scope.reuse_variables()
451 | if side_split:
452 | side_eval_pipe = InputPipe(inp, features=split.test_set, n_pages=split.test_size,
453 | mode=ModelMode.EVAL, batch_size=eval_batch_size, n_epoch=None,
454 | verbose=verbose, predict_window=predict_window,
455 | train_completeness_threshold=0.01, predict_completeness_threshold=0,
456 | train_window=train_window, rand_seed=seed, runs_in_burst=eval_batches,
457 | back_offset=predict_window * (2 if forward_split else 1))
458 | else:
459 | side_eval_pipe = None
460 | if forward_split:
461 | forward_eval_pipe = InputPipe(inp, features=split.test_set, n_pages=split.test_size,
462 | mode=ModelMode.EVAL, batch_size=eval_batch_size, n_epoch=None,
463 | verbose=verbose, predict_window=predict_window,
464 | train_completeness_threshold=0.01, predict_completeness_threshold=0,
465 | train_window=train_window, rand_seed=seed, runs_in_burst=eval_batches,
466 | back_offset=predict_window)
467 | else:
468 | forward_eval_pipe = None
469 | avg_sgd = asgd_decay is not None
470 | #asgd_decay = 0.99 if avg_sgd else None
471 | train_model = Model(pipe, hparams, is_train=True, graph_prefix=prefix, asgd_decay=asgd_decay, seed=seed)
472 | scope.reuse_variables()
473 |
474 | eval_stages = []
475 | if side_split:
476 | side_eval_model = Model(side_eval_pipe, hparams, is_train=False,
477 | #loss_mask=np.concatenate([np.zeros(50, dtype=np.float32), np.ones(10, dtype=np.float32)]),
478 | seed=seed)
479 | eval_stages.append((Stage.EVAL_SIDE, side_eval_model))
480 | if avg_sgd:
481 | eval_stages.append((Stage.EVAL_SIDE_EMA, side_eval_model))
482 | if forward_split:
483 | forward_eval_model = Model(forward_eval_pipe, hparams, is_train=False, seed=seed)
484 | eval_stages.append((Stage.EVAL_FRWD, forward_eval_model))
485 | if avg_sgd:
486 | eval_stages.append((Stage.EVAL_FRWD_EMA, forward_eval_model))
487 |
488 | if write_summaries:
489 | summ_path = f"{logdir}/{name}_{index}"
490 | if os.path.exists(summ_path):
491 | shutil.rmtree(summ_path)
492 | summ_writer = tf.summary.FileWriter(summ_path) # , graph=tf.get_default_graph()
493 | else:
494 | summ_writer = None
495 | if do_eval and forward_split:
496 | stop_metric = lambda metrics: metrics[Stage.EVAL_FRWD]['SMAPE'].avg_epoch
497 | else:
498 | stop_metric = None
499 | return ModelTrainerV2(train_model, eval_stages, index, patience=patience,
500 | stop_metric=stop_metric,
501 | summary_writer=summ_writer)
502 |
503 |
504 | if n_models == 1:
505 | with tf.device(f"/gpu:{gpu}"):
506 | scope = tf.get_variable_scope()
507 | all_models = [create_model(scope, 0, None, seed=seed)]
508 | else:
509 | for i in range(n_models):
510 | device = f"/gpu:{i}" if multi_gpu else f"/gpu:{gpu}"
511 | with tf.device(device):
512 | prefix = f"m_{i}"
513 | with tf.variable_scope(prefix) as scope:
514 | all_models.append(create_model(scope, i, prefix=prefix, seed=seed + i))
515 | trainer = MultiModelTrainer(all_models, inc_step)
516 | if save_best_model or save_from_step:
517 | saver_path = f'data/cpt/{name}'
518 | if os.path.exists(saver_path):
519 | shutil.rmtree(saver_path)
520 | os.makedirs(saver_path)
521 | saver = tf.train.Saver(max_to_keep=10, name='train_saver')
522 | else:
523 | saver = None
524 | avg_sgd = asgd_decay is not None
525 | if avg_sgd:
526 | from itertools import chain
527 | def ema_vars(model):
528 | ema = model.train_model.ema
529 | return {ema.average_name(v):v for v in model.train_model.ema._averages}
530 |
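    |         # ema_vars maps EMA shadow-variable names to the live model variables, so
    |         # ema_loader.restore() can overwrite the model weights with their exponential
    |         # moving averages (note: ema._averages is a private TF attribute).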
531 | ema_names = dict(chain(*[ema_vars(model).items() for model in all_models]))
532 | #ema_names = all_models[0].train_model.ema.variables_to_restore()
533 | ema_loader = tf.train.Saver(var_list=ema_names, max_to_keep=1, name='ema_loader')
534 | ema_saver = tf.train.Saver(max_to_keep=1, name='ema_saver')
535 | else:
536 | ema_loader = None
537 |
538 | init = tf.global_variables_initializer()
539 |
540 | if forward_split and do_eval:
541 | eval_smape = trainer.metric(Stage.EVAL_FRWD, 'SMAPE')
542 | eval_mae = trainer.metric(Stage.EVAL_FRWD, 'MAE')
543 | else:
544 | eval_smape = DummyMetric()
545 | eval_mae = DummyMetric()
546 |
547 | if side_split and do_eval:
548 | eval_mae_side = trainer.metric(Stage.EVAL_SIDE, 'MAE')
549 | eval_smape_side = trainer.metric(Stage.EVAL_SIDE, 'SMAPE')
550 | else:
551 | eval_mae_side = DummyMetric()
552 | eval_smape_side = DummyMetric()
553 |
554 | train_smape = trainer.metric(Stage.TRAIN, 'SMAPE')
555 | train_mae = trainer.metric(Stage.TRAIN, 'MAE')
556 | grad_norm = trainer.metric(Stage.TRAIN, 'GrNorm')
557 | eval_stages = []
558 | ema_eval_stages = []
559 | if forward_split and do_eval:
560 | eval_stages.append(Stage.EVAL_FRWD)
561 | ema_eval_stages.append(Stage.EVAL_FRWD_EMA)
562 | if side_split and do_eval:
563 | eval_stages.append(Stage.EVAL_SIDE)
564 | ema_eval_stages.append(Stage.EVAL_SIDE_EMA)
565 |
566 | # gpu_options=tf.GPUOptions(allow_growth=False),
567 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
568 | gpu_options=tf.GPUOptions(allow_growth=gpu_allow_growth))) as sess:
569 | sess.run(init)
570 | # pipe.load_vars(sess)
571 | inp.restore(sess)
572 | for model in all_models:
573 | model.init(sess)
574 | # if beholder:
575 | # visualizer = Beholder(session=sess, logdir=summ_path)
576 | step = 0
577 | prev_top = np.inf
578 | best_smape = np.inf
579 | # Contains best value (first item) and subsequent values
580 | best_epoch_smape = []
581 |
582 | for epoch in range(max_epoch):
583 |
584 | # n_steps = pusher.n_pages // batch_size
585 | if tqdm:
586 | tqr = trange(steps_per_epoch, desc="%2d" % (epoch + 1), leave=False)
587 | else:
588 | tqr = range(steps_per_epoch)
589 |
590 | for _ in tqr:
591 | try:
592 | step = trainer.train_step(sess, epoch)
593 | except tf.errors.OutOfRangeError:
594 | break
595 | # if beholder:
596 | # if step % 5 == 0:
597 | # noinspection PyUnboundLocalVariable
598 | # visualizer.update()
599 | if step % eval_every_step == 0:
600 | if eval_stages:
601 | trainer.eval_step(sess, epoch, step, eval_batches, stages=eval_stages)
602 |
603 | if save_best_model and epoch > 0 and eval_smape.last < best_smape:
604 | best_smape = eval_smape.last
605 | saver.save(sess, f'data/cpt/{name}/cpt', global_step=step)
606 | if save_from_step and step >= save_from_step:
607 | saver.save(sess, f'data/cpt/{name}/cpt', global_step=step)
608 |
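    |                     # ASGD weight swap: snapshot the current weights to a temporary
    |                     # checkpoint, load the EMA shadow values into the model variables,
    |                     # run the *_EMA eval stages, then restore the original weights.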
609 | if avg_sgd and ema_eval_stages:
610 | ema_saver.save(sess, 'data/cpt_tmp/ema', write_meta_graph=False)
611 | # restore ema-backed vars
612 | ema_loader.restore(sess, 'data/cpt_tmp/ema')
613 |
614 | trainer.eval_step(sess, epoch, step, eval_batches, stages=ema_eval_stages)
615 | # restore normal vars
616 | ema_saver.restore(sess, 'data/cpt_tmp/ema')
617 |
618 | MAE = "%.3f/%.3f/%.3f" % (eval_mae.last, eval_mae_side.last, train_mae.last)
619 | improvement = '↑' if eval_smape.improved else ' '
620 | SMAPE = "%s%.3f/%.3f/%.3f" % (improvement, eval_smape.last, eval_smape_side.last, train_smape.last)
621 | if tqdm:
622 | tqr.set_postfix(gr=grad_norm.last, MAE=MAE, SMAPE=SMAPE)
623 | if not trainer.has_active() or (max_steps and step > max_steps):
624 | break
625 |
626 | if tqdm:
627 | tqr.close()
628 | trainer.end_epoch()
629 | if not best_epoch_smape or eval_smape.avg_epoch < best_epoch_smape[0]:
630 | best_epoch_smape = [eval_smape.avg_epoch]
631 | else:
632 | best_epoch_smape.append(eval_smape.avg_epoch)
633 |
634 | current_top = eval_smape.top
635 | if prev_top > current_top:
636 | prev_top = current_top
637 | has_best_indicator = '↑'
638 | else:
639 | has_best_indicator = ' '
640 | status = "%2d: Best top SMAPE=%.3f%s (%s)" % (
641 | epoch + 1, current_top, has_best_indicator,
642 | ",".join(["%.3f" % m.top for m in eval_smape.metrics]))
643 |
644 | if trainer.has_active():
645 |             status += ", frwd/side best MAE=%.3f/%.3f, SMAPE=%.3f/%.3f; avg MAE=%.3f/%.3f, SMAPE=%.3f/%.3f, %d active models" % \
646 | (eval_mae.best_epoch, eval_mae_side.best_epoch, eval_smape.best_epoch, eval_smape_side.best_epoch,
647 | eval_mae.avg_epoch, eval_mae_side.avg_epoch, eval_smape.avg_epoch, eval_smape_side.avg_epoch,
648 | trainer.has_active())
649 | print(status, file=sys.stderr)
650 | else:
651 | print(status, file=sys.stderr)
652 | print("Early stopping!", file=sys.stderr)
653 | break
654 | if max_steps and step > max_steps:
655 |                 print("Max steps reached", file=sys.stderr)
656 | break
657 | sys.stderr.flush()
658 |
659 | # noinspection PyUnboundLocalVariable
660 | return np.mean(best_epoch_smape, dtype=np.float64)
661 |
662 |
663 | def predict(checkpoints, hparams, return_x=False, verbose=False, predict_window=6, back_offset=0, n_models=1,
664 | target_model=0, asgd=False, seed=1, batch_size=1024):
665 | with tf.variable_scope('input') as inp_scope:
666 | with tf.device("/cpu:0"):
667 | inp = VarFeeder.read_vars("data/vars")
668 | pipe = InputPipe(inp, page_features(inp), inp.n_pages, mode=ModelMode.PREDICT, batch_size=batch_size,
669 | n_epoch=1, verbose=verbose,
670 | train_completeness_threshold=0.01,
671 | predict_window=predict_window,
672 | predict_completeness_threshold=0.0, train_window=hparams.train_window,
673 | back_offset=back_offset)
674 | asgd_decay = 0.99 if asgd else None
675 | if n_models == 1:
676 | model = Model(pipe, hparams, is_train=False, seed=seed, asgd_decay=asgd_decay)
677 | else:
678 | models = []
679 | for i in range(n_models):
680 | prefix = f"m_{i}"
681 | with tf.variable_scope(prefix) as scope:
682 | models.append(Model(pipe, hparams, is_train=False, seed=seed, asgd_decay=asgd_decay, graph_prefix=prefix))
683 | model = models[target_model]
684 |
685 | if asgd:
686 | var_list = model.ema.variables_to_restore()
687 | prefix = f"m_{target_model}"
688 | for var in list(var_list.keys()):
689 | if var.endswith('ExponentialMovingAverage') and not var.startswith(prefix):
690 | del var_list[var]
691 | else:
692 | var_list = None
693 | saver = tf.train.Saver(name='eval_saver', var_list=var_list)
694 | x_buffer = []
695 | predictions = None
696 | with tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))) as sess:
697 | pipe.load_vars(sess)
698 | for checkpoint in checkpoints:
699 | pred_buffer = []
700 | pipe.init_iterator(sess)
701 | saver.restore(sess, checkpoint)
702 | cnt = 0
703 | while True:
704 | try:
705 | if return_x:
706 | pred, x, pname = sess.run([model.predictions, model.inp.true_x, model.inp.page_ix])
707 | else:
708 | pred, pname = sess.run([model.predictions, model.inp.page_ix])
709 | utf_names = [str(name, 'utf-8') for name in pname]
710 | pred_df = pd.DataFrame(index=utf_names, data=np.expm1(pred))
711 | pred_buffer.append(pred_df)
712 | if return_x:
713 | # noinspection PyUnboundLocalVariable
714 | x_values = pd.DataFrame(index=utf_names, data=np.round(np.expm1(x)).astype(np.int64))
715 | x_buffer.append(x_values)
716 | newline = cnt % 80 == 0
717 | if cnt > 0:
718 | print('.', end='\n' if newline else '', flush=True)
719 | if newline:
720 | print(cnt, end='')
721 | cnt += 1
722 | except tf.errors.OutOfRangeError:
723 | print('🎉')
724 | break
725 | cp_predictions = pd.concat(pred_buffer)
726 | if predictions is None:
727 | predictions = cp_predictions
728 | else:
729 | predictions += cp_predictions
730 | predictions /= len(checkpoints)
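    |     # Predictions from all checkpoints are averaged elementwise (a simple ensemble
    |     # mean in the de-logged expm1 space) before prediction dates are attached.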
731 | offset = pd.Timedelta(back_offset, 'D')
732 | start_prediction = inp.data_end + pd.Timedelta('1D') - offset
733 | end_prediction = start_prediction + pd.Timedelta(predict_window - 1, 'D')
734 | predictions.columns = pd.date_range(start_prediction, end_prediction)
735 | if return_x:
736 | x = pd.concat(x_buffer)
737 |         start_data = inp.data_end - pd.Timedelta(hparams.train_window - 1, 'D') - offset
738 |         end_data = inp.data_end - offset
739 | x.columns = pd.date_range(start_data, end_data)
740 | return predictions, x
741 | else:
742 | return predictions
743 |
744 |
745 | if __name__ == '__main__':
746 | parser = argparse.ArgumentParser(description='Train the model')
747 | parser.add_argument('--name', default='s32', help='Model name to identify different logs/checkpoints')
748 | parser.add_argument('--hparam_set', default='s32', help="Hyperparameters set to use (see hparams.py for available sets)")
749 | parser.add_argument('--n_models', default=1, type=int, help="Jointly train n models with different seeds")
750 | parser.add_argument('--multi_gpu', default=False, action='store_true', help="Use multiple GPUs for multi-model training, one GPU per model")
751 | parser.add_argument('--seed', default=5, type=int, help="Random seed")
752 | parser.add_argument('--logdir', default='data/logs', help="Directory for summary logs")
753 | parser.add_argument('--max_epoch', type=int, default=100, help="Max number of epochs")
754 | parser.add_argument('--patience', type=int, default=2, help="Early stopping: stop after N epochs without improvement. Requires do_eval=True")
755 |     parser.add_argument('--train_sampling', type=float, default=1.0, help="Sample this fraction of data for training")
756 |     parser.add_argument('--eval_sampling', type=float, default=1.0, help="Sample this fraction of data for evaluation")
757 |     parser.add_argument('--eval_memsize', type=int, default=5, help="Approximate amount of available GPU memory, used to calculate the optimal evaluation batch size")
758 | parser.add_argument('--gpu', default=0, type=int, help='GPU instance to use')
759 |     parser.add_argument('--gpu_allow_growth', default=False, action='store_true', help='Allow GPU memory usage to grow gradually instead of grabbing all available memory at startup')
760 | parser.add_argument('--save_best_model', default=False, action='store_true', help='Save best model during training. Requires do_eval=True')
761 |     parser.add_argument('--no_forward_split', default=True, dest='forward_split', action='store_false', help='Disable walk-forward split for model evaluation (walk-forward evaluation requires do_eval=True)')
762 | parser.add_argument('--side_split', default=False, action='store_true', help='Use side split for model evaluation. Requires do_eval=True')
763 | parser.add_argument('--no_eval', default=True, dest='do_eval', action='store_false', help="Don't evaluate model quality during training")
764 |     parser.add_argument('--no_summaries', default=True, dest='write_summaries', action='store_false', help="Don't write TensorFlow summaries")
765 | parser.add_argument('--verbose', default=False, action='store_true', help='Print additional information during graph construction')
766 |     parser.add_argument('--asgd_decay', type=float, help="EMA decay for averaged SGD. ASGD is disabled if not set")
767 | parser.add_argument('--no_tqdm', default=True, dest='tqdm', action='store_false', help="Don't use tqdm for status display during training")
768 | parser.add_argument('--max_steps', type=int, help="Stop training after max steps")
769 | parser.add_argument('--save_from_step', type=int, help="Save model on each evaluation (10 evals per epoch), starting from this step")
770 | parser.add_argument('--predict_window', default=63, type=int, help="Number of days to predict")
771 | args = parser.parse_args()
772 |
773 | param_dict = dict(vars(args))
774 | param_dict['hparams'] = build_from_set(args.hparam_set)
775 | del param_dict['hparam_set']
776 | train(**param_dict)
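    |     # Example command-line invocation (illustrative values; see the arguments above):
    |     #   python trainer.py --name=s32 --hparam_set=s32 --n_models=3 \
    |     #       --asgd_decay=0.99 --max_steps=11500 --save_from_step=10500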
777 |
778 | # hparams = build_hparams()
779 | # result = train("definc_attn", hparams, n_models=1, train_sampling=1.0, eval_sampling=1.0, patience=5, multi_gpu=True,
780 | # save_best_model=False, gpu=0, eval_memsize=15, seed=5, verbose=True, forward_split=False,
781 | # write_summaries=True, side_split=True, do_eval=False, predict_window=63, asgd_decay=None, max_steps=11500,
782 | # save_from_step=10500)
783 |
784 | # print("Training result:", result)
785 | # preds = predict('data/cpt/fair_365-15428', 380, hparams, verbose=True, back_offset=60, n_models=3)
786 | # print(preds)
787 |
--------------------------------------------------------------------------------