├── bulbea ├── app │ ├── config │ │ ├── client.py │ │ ├── base.py │ │ ├── server.py │ │ └── __init__.py │ ├── client │ │ └── __init__.py │ ├── __init__.py │ └── server │ │ ├── __init__.py │ │ └── server.py ├── _util │ ├── tests │ │ ├── __init__.py │ │ └── test__util.py │ ├── const.py │ ├── __init__.py │ ├── color.py │ └── _util.py ├── entity │ ├── tests │ │ └── test_share.py │ ├── base.py │ ├── stock.py │ ├── __init__.py │ └── share.py ├── learn │ ├── __init__.py │ ├── evaluation │ │ ├── __init__.py │ │ └── split.py │ ├── models │ │ ├── __init__.py │ │ ├── model.py │ │ └── ann.py │ └── sentiment │ │ ├── __init__.py │ │ ├── sentiment.py │ │ └── twitter.py ├── exceptions.py ├── config │ ├── base.py │ ├── __init__.py │ └── app.py ├── __main__.py ├── __init__.py └── cli.py ├── requirements ├── test.txt ├── documentation.txt ├── development.txt └── production.txt ├── docs ├── blog │ ├── visualizing-the-market.rst │ ├── artificial-neural-networks.rst │ └── data-data-everywhere.rst ├── _static │ ├── bulbea.png │ └── google_bollinger_bands.png ├── guides │ ├── api.rst │ └── user │ │ ├── quickstart.rst │ │ ├── introduction.rst │ │ └── installation.rst ├── _templates │ └── sidebar-logo.html ├── Makefile ├── conf.py ├── index.rst └── _themes │ └── flask_theme_support.py ├── .github ├── plot.png └── bulbea.png ├── examples ├── bulbea.png └── bulbea.ipynb ├── requirements.txt ├── .gitignore ├── Makefile ├── LICENSE ├── package.py ├── README.md └── setup.py /bulbea/app/config/client.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bulbea/_util/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bulbea/app/client/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /bulbea/entity/tests/test_share.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements/test.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | -------------------------------------------------------------------------------- /requirements/documentation.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | -------------------------------------------------------------------------------- /docs/blog/visualizing-the-market.rst: -------------------------------------------------------------------------------- 1 | Visualizing the Market 2 | ====================== -------------------------------------------------------------------------------- /.github/plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/achillesrasquinha/bulbea/HEAD/.github/plot.png -------------------------------------------------------------------------------- /requirements/development.txt: -------------------------------------------------------------------------------- 1 | jupyter 2 | numpy 3 | scipy 4 | matplotlib 5 | scikit-learn 6 | -------------------------------------------------------------------------------- /.github/bulbea.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/achillesrasquinha/bulbea/HEAD/.github/bulbea.png -------------------------------------------------------------------------------- /examples/bulbea.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/achillesrasquinha/bulbea/HEAD/examples/bulbea.png 
-------------------------------------------------------------------------------- /bulbea/learn/__init__.py: -------------------------------------------------------------------------------- 1 | # module - bulbea.learn 2 | from bulbea.learn.sentiment import sentiment 3 | -------------------------------------------------------------------------------- /docs/_static/bulbea.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/achillesrasquinha/bulbea/HEAD/docs/_static/bulbea.png -------------------------------------------------------------------------------- /bulbea/exceptions.py: -------------------------------------------------------------------------------- 1 | TYPE_ERROR_STRING = 'Expected {expected_type_name}, got {recieved_type_name} instead.' 2 | -------------------------------------------------------------------------------- /requirements/production.txt: -------------------------------------------------------------------------------- 1 | six 2 | matplotlib 3 | Pillow 4 | quandl 5 | keras 6 | tweepy 7 | textblob 8 | Flask 9 | -------------------------------------------------------------------------------- /bulbea/learn/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from bulbea.learn.evaluation.split import split 4 | -------------------------------------------------------------------------------- /docs/_static/google_bollinger_bands.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/achillesrasquinha/bulbea/HEAD/docs/_static/google_bollinger_bands.png -------------------------------------------------------------------------------- /bulbea/entity/base.py: -------------------------------------------------------------------------------- 1 | from six import with_metaclass 2 | 3 | from abc import ABCMeta 4 | 5 | class 
Entity(with_metaclass(ABCMeta)): 6 | pass 7 | -------------------------------------------------------------------------------- /bulbea/config/base.py: -------------------------------------------------------------------------------- 1 | from six import with_metaclass 2 | 3 | from abc import ABCMeta 4 | 5 | class BaseConfig(with_metaclass(ABCMeta)): 6 | pass 7 | -------------------------------------------------------------------------------- /bulbea/_util/const.py: -------------------------------------------------------------------------------- 1 | ABSURL_QUANDL = 'https://www.quandl.com' 2 | QUANDL_MAX_DAILY_CALLS = 50000 3 | 4 | SHARE_ACCEPTED_SAVE_FORMATS = ('csv', 'pkl') -------------------------------------------------------------------------------- /bulbea/app/__init__.py: -------------------------------------------------------------------------------- 1 | # imports - compatibility packages 2 | from __future__ import absolute_import 3 | 4 | # module imports 5 | from bulbea.app.server import app 6 | -------------------------------------------------------------------------------- /bulbea/config/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from bulbea.config.base import BaseConfig 4 | from bulbea.config.app import AppConfig 5 | -------------------------------------------------------------------------------- /bulbea/app/config/base.py: -------------------------------------------------------------------------------- 1 | # imports - standard packages 2 | from enum import Enum 3 | 4 | class BaseConfig(object): 5 | class URL(object): 6 | BASE = "/" 7 | -------------------------------------------------------------------------------- /bulbea/app/server/__init__.py: -------------------------------------------------------------------------------- 1 | # imports - compatibility packages 2 | from __future__ import absolute_import 3 | 4 | # module imports 5 | from 
bulbea.app.server.server import app 6 | -------------------------------------------------------------------------------- /bulbea/learn/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from bulbea.learn.models.model import Model, Supervised 4 | from bulbea.learn.models.ann import ANN, RNN 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jupyter 2 | numpy 3 | scipy 4 | matplotlib 5 | scikit-learn 6 | sphinx 7 | six 8 | matplotlib 9 | Pillow 10 | quandl 11 | keras 12 | tweepy 13 | textblob 14 | Flask 15 | pytest 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # virtualenvs 2 | .venv 3 | 4 | # docs 5 | docs/_build 6 | 7 | # Jupyter Checkpoints 8 | examples/.ipynb_checkpoints 9 | 10 | # twitter 11 | twitter.sh 12 | 13 | # misc 14 | snippet.py 15 | -------------------------------------------------------------------------------- /bulbea/entity/stock.py: -------------------------------------------------------------------------------- 1 | # imports - compatibility packages 2 | from __future__ import absolute_import 3 | 4 | # module imports 5 | from bulbea.entity import Entity 6 | 7 | class Stock(Entity): 8 | pass 9 | -------------------------------------------------------------------------------- /bulbea/app/config/server.py: -------------------------------------------------------------------------------- 1 | # imports - compatibility packages 2 | from __future__ import absolute_import 3 | 4 | # module imports 5 | from bulbea.app.config import BaseConfig 6 | 7 | class ServerConfig(BaseConfig): 8 | pass 9 | -------------------------------------------------------------------------------- 
/bulbea/app/config/__init__.py: -------------------------------------------------------------------------------- 1 | # imports - compatibility packages 2 | from __future__ import absolute_import 3 | 4 | # module imports 5 | from bulbea.app.config.base import BaseConfig 6 | from bulbea.app.config.server import ServerConfig 7 | -------------------------------------------------------------------------------- /bulbea/learn/sentiment/__init__.py: -------------------------------------------------------------------------------- 1 | # imports - compatibility packages 2 | from __future__ import absolute_import 3 | 4 | # module imports 5 | from bulbea.learn.sentiment.sentiment import sentiment 6 | from bulbea.learn.sentiment.twitter import Twitter 7 | -------------------------------------------------------------------------------- /bulbea/entity/__init__.py: -------------------------------------------------------------------------------- 1 | # modules - compatibility packages 2 | from __future__ import absolute_import 3 | 4 | # module bulbea.entity 5 | from bulbea.entity.base import Entity 6 | from bulbea.entity.share import Share 7 | from bulbea.entity.stock import Stock 8 | -------------------------------------------------------------------------------- /docs/guides/api.rst: -------------------------------------------------------------------------------- 1 | Developer Interface 2 | =================== 3 | 4 | Entities 5 | ++++++++ 6 | 7 | .. autoclass:: bulbea.Share 8 | :inherited-members: 9 | .. 
autoclass:: bulbea.Stock 10 | :inherited-members: 11 | 12 | Modelling 13 | +++++++++ 14 | -------------------------------------------------------------------------------- /bulbea/learn/sentiment/sentiment.py: -------------------------------------------------------------------------------- 1 | # imports - compatibility imports 2 | from __future__ import absolute_import 3 | 4 | # imports - third-party packages 5 | import textblob 6 | 7 | # module imports 8 | from bulbea.learn.sentiment.twitter import Twitter 9 | 10 | def sentiment(share): 11 | pass 12 | -------------------------------------------------------------------------------- /bulbea/__main__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # imports - compatibility imports 4 | from __future__ import absolute_import 5 | 6 | # imports - standard imports 7 | import sys 8 | 9 | # imports - module imports 10 | from bulbea.cli import main 11 | 12 | if __name__ == '__main__': 13 | sys.exit(main()) 14 | -------------------------------------------------------------------------------- /bulbea/__init__.py: -------------------------------------------------------------------------------- 1 | # imports - compatibility packages 2 | from __future__ import absolute_import 3 | 4 | # module imports 5 | from bulbea.entity import Share, Stock 6 | from bulbea.config import AppConfig 7 | from bulbea.app import app 8 | from bulbea.learn import sentiment 9 | 10 | __version__ = AppConfig.VERSION 11 | -------------------------------------------------------------------------------- /bulbea/app/server/server.py: -------------------------------------------------------------------------------- 1 | # imports - compatibility packages 2 | from __future__ import absolute_import 3 | 4 | # imports - third-party packages 5 | from flask import Flask 6 | 7 | # module imports 8 | from bulbea.app.config import ServerConfig 9 | 10 | app = Flask(__name__) 11 | 12 | 
@app.route(ServerConfig.URL.BASE) 13 | def index(): 14 | pass 15 | -------------------------------------------------------------------------------- /docs/_templates/sidebar-logo.html: -------------------------------------------------------------------------------- 1 | 6 | 7 |

8 | 9 |

10 | -------------------------------------------------------------------------------- /docs/blog/artificial-neural-networks.rst: -------------------------------------------------------------------------------- 1 | Artificial Neural Networks 2 | ========================== 3 | 4 | *"All models are wrong, but some are useful."* - George E. P. Box 5 | 6 | Recurrent Neural Networks 7 | +++++++++++++++++++++++++ 8 | 9 | A vanilla Recurrent Neural Network (hereby, RNN) is a kind of an Artificial Neural Network that considers a scenario - *at which time-step did you feed the input?* -------------------------------------------------------------------------------- /bulbea/learn/models/model.py: -------------------------------------------------------------------------------- 1 | from six import with_metaclass 2 | 3 | from abc import ABCMeta, abstractmethod 4 | 5 | class Model(with_metaclass(ABCMeta)): 6 | @abstractmethod 7 | def fit(self, X, y): 8 | pass 9 | 10 | @abstractmethod 11 | def predict(self, X): 12 | pass 13 | 14 | class Supervised(Model): 15 | def fit(self, X, y): 16 | pass 17 | 18 | def predict(self, X): 19 | pass 20 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: docs 2 | 3 | PYTHON = python 4 | 5 | install: 6 | cat requirements/*.txt > requirements.txt 7 | pip install -r requirements.txt 8 | 9 | pip install tensorflow-gpu 10 | 11 | $(PYTHON) setup.py install 12 | 13 | bash twitter.sh 14 | 15 | docs: 16 | cd docs && make html 17 | 18 | tests: 19 | $(PYTHON) setup.py test 20 | 21 | clean: 22 | $(PYTHON) setup.py clean 23 | 24 | all: 25 | make install docs tests clean 26 | -------------------------------------------------------------------------------- /docs/guides/user/quickstart.rst: -------------------------------------------------------------------------------- 1 | Quickstart 2 | ========== 3 | 4 | Waiting to make some 
money? We introduce you to a quick way of building your first prediction model. 5 | 6 | Create a :py:class:`Share ` object 7 | ++++++++++++++++++++++++++++++++++++++++++++++++ 8 | 9 | The canonical way of importing bulbea as follows: 10 | 11 | .. code:: python 12 | 13 | >>> import bulbea as bb 14 | 15 | Go ahead and create a share object. 16 | 17 | .. code:: python 18 | 19 | >>> share = bb.Share(source = 'YAHOO', ticker = 'GOOGL') 20 | -------------------------------------------------------------------------------- /bulbea/_util/__init__.py: -------------------------------------------------------------------------------- 1 | # imports - compatibility packages 2 | from __future__ import absolute_import 3 | 4 | # module imports 5 | from bulbea._util._util import ( 6 | _check_type, 7 | _check_str, 8 | _check_int, 9 | _check_real, 10 | _check_pandas_series, 11 | _check_pandas_dataframe, 12 | _check_iterable, 13 | _check_sequence, 14 | _check_environment_variable_set, 15 | _validate_in_range, 16 | _validate_date, 17 | _assign_if_none, 18 | _get_type_name, 19 | _get_datetime_str, 20 | _raise_type_error, 21 | _is_sequence_all 22 | ) 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2017 Achilles Rasquinha 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = bulbea 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /bulbea/_util/color.py: -------------------------------------------------------------------------------- 1 | class Color(object): 2 | RESET = '\x1b[0m' 3 | 4 | BLACK = 0 5 | RED = 1 6 | GREEN = 2 7 | YELLOW = 3 8 | BLUE = 4 9 | MAGENTA = 5 10 | CYAN = 6 11 | WHITE = 7 12 | 13 | NORMAL = 0 14 | BOLD = 1 15 | 16 | @staticmethod 17 | def to_color_string(string, 18 | foreground = 7, 19 | background = None, 20 | style = 1): 21 | style = '\x1b[0%sm' % style 22 | foreground = '\x1b[3%sm' % foreground 23 | background = '' if background is None else '\x1b[4%sm' % background 24 | preset = style + foreground + background 25 | 26 | colored = preset + string + Color.RESET 27 | 28 | return colored 29 | 30 | def warn(string): 31 | colored = Color.to_color_string(string, foreground = Color.YELLOW) 32 | 33 | return colored -------------------------------------------------------------------------------- /docs/guides/user/introduction.rst: -------------------------------------------------------------------------------- 1 | Introduction 2 | ============ 3 | 4 | What's 
in the name? 5 | +++++++++++++++++++ 6 | 7 | **bulbea** is a portmanteau of the very nature of a stock market - the bull and the bear. Hence, the name. 8 | 9 | .. _license: 10 | 11 | License 12 | +++++++ 13 | 14 | **bulbea** is released under the Apache 2.0 License. 15 | 16 | .. code:: raw 17 | 18 | Copyright 2017 Achilles Rasquinha 19 | 20 | Licensed under the Apache License, Version 2.0 (the "License"); 21 | you may not use this file except in compliance with the License. 22 | You may obtain a copy of the License at 23 | 24 | http://www.apache.org/licenses/LICENSE-2.0 25 | 26 | Unless required by applicable law or agreed to in writing, software 27 | distributed under the License is distributed on an "AS IS" BASIS, 28 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 29 | See the License for the specific language governing permissions and 30 | limitations under the License. 31 | -------------------------------------------------------------------------------- /bulbea/cli.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import argparse 4 | 5 | from bulbea import AppConfig 6 | from bulbea._util.color import Color 7 | 8 | import bulbea as bb 9 | 10 | _description = '{logo} v{version}'.format(logo = Color.to_color_string(AppConfig.LOGO, AppConfig.COLOR_PRIMARY), version = '.'.join(map(str, AppConfig.VERSION))) 11 | parser = argparse.ArgumentParser(description = _description, 12 | formatter_class = argparse.RawDescriptionHelpFormatter) 13 | parser.add_argument('SOURCE', 14 | help = 'source code for economic data') 15 | parser.add_argument('SYMBOL', 16 | help = 'ticker symbol of a company') 17 | parser.add_argument('-i', '--gui', 18 | action = 'store_true', 19 | help = 'launch the graphical user interface') 20 | parser.add_argument('-v', '--version', 21 | action = 'version', 22 | version = 'v{version}'.format(version = '.'.join(map(str, AppConfig.VERSION))), 23 
| help = 'show the version information and exit') 24 | 25 | def main(): 26 | args = parser.parse_args() 27 | -------------------------------------------------------------------------------- /bulbea/config/app.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from bulbea.config import BaseConfig 4 | from bulbea._util.color import Color 5 | 6 | class AppConfig(BaseConfig): 7 | NAME = 'bulbea' 8 | VERSION = (0,1,0) 9 | LOGO = \ 10 | """ 11 | 888 888 888 12 | 888 888 888 13 | 888 888 888 14 | 88888b. 888 888 888 88888b. .d88b. 8888b. 15 | 888 "88b 888 888 888 888 "88b d8P Y8b "88b 16 | 888 888 888 888 888 888 888 88888888 .d888888 17 | 888 d88P Y88b 888 888 888 d88P Y8b. 888 888 18 | 88888P" "Y88888 888 88888P" "Y8888 "Y888888""" 19 | COLOR_PRIMARY = Color.YELLOW 20 | 21 | WINDOW_ASPECT_RATIO = 3 / 2 22 | WINDOW_WIDTH = 320 23 | WINDOW_HEIGHT = int(WINDOW_WIDTH * WINDOW_ASPECT_RATIO) 24 | 25 | PLOT_STYLE = 'seaborn' 26 | 27 | ENVIRONMENT_VARIABLE = { 28 | 'quandl_api_key': 'BULBEA_QUANDL_API_KEY', 29 | 'twitter_api_key': 'BULBEA_TWITTER_API_KEY', 30 | 'twitter_api_secret': 'BULBEA_TWITTER_API_SECRET', 31 | 'twitter_access_token': 'BULBEA_TWITTER_ACCESS_TOKEN', 32 | 'twitter_access_token_secret': 'BULBEA_TWITTER_ACCESS_TOKEN_SECRET' 33 | } 34 | -------------------------------------------------------------------------------- /package.py: -------------------------------------------------------------------------------- 1 | # Inspired by npm's package.json file 2 | name = 'bulbea' 3 | version = '0.1.0' 4 | release = '0.1.0' 5 | description = 'A neural stock market predictor and model builder' 6 | long_description = ['README.md'] 7 | keywords = ['neural', 'network', 'machine', 'deep', 8 | 'learning', 'tensorflow', 'stock', 'market', 'prediction'] 9 | authors = [ 10 | { 'name': 'Achilles Rasquinha', 'email': 'achillesrasquinha@gmail.com' } 11 | ] 12 | maintainers = [ 13 | { 'name': 'Achilles 
Rasquinha', 'email': 'achillesrasquinha@gmail.com' } 14 | ] 15 | license = 'Apache 2.0' 16 | modules = [ 17 | 'bulbea', 18 | 'bulbea.config', 19 | 'bulbea._util', 20 | 'bulbea.entity', 21 | 'bulbea.learn', 22 | 'bulbea.learn.models', 23 | 'bulbea.learn.evaluation', 24 | 'bulbea.learn.sentiment', 25 | 'bulbea.app', 26 | 'bulbea.app.client', 27 | 'bulbea.app.server', 28 | 'bulbea.app.config' 29 | ] 30 | test_modules = [ 31 | 'bulbea._util.tests' 32 | ] 33 | homepage = 'https://achillesrasquinha.github.io/bulbea' 34 | github_username = 'achillesrasquinha' 35 | github_repository = 'bulbea' 36 | github_url = '{baseurl}/{username}/{repository}'.format( 37 | baseurl = 'https://github.com', 38 | username = github_username, 39 | repository = github_repository) 40 | -------------------------------------------------------------------------------- /bulbea/learn/models/ann.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from six import with_metaclass 3 | 4 | from keras.models import Sequential 5 | from keras.layers import recurrent 6 | from keras.layers import core 7 | 8 | from bulbea.learn.models import Supervised 9 | 10 | class ANN(Supervised): 11 | pass 12 | 13 | class RNNCell(object): 14 | RNN = recurrent.SimpleRNN 15 | GRU = recurrent.GRU 16 | LSTM = recurrent.LSTM 17 | 18 | class RNN(ANN): 19 | def __init__(self, sizes, 20 | cell = RNNCell.LSTM, 21 | dropout = 0.2, 22 | activation = 'linear', 23 | loss = 'mse', 24 | optimizer = 'rmsprop'): 25 | self.model = Sequential() 26 | self.model.add(cell( 27 | input_dim = sizes[0], 28 | output_dim = sizes[1], 29 | return_sequences = True 30 | )) 31 | 32 | for i in range(2, len(sizes) - 1): 33 | self.model.add(cell(sizes[i], return_sequences = False)) 34 | self.model.add(core.Dropout(dropout)) 35 | 36 | self.model.add(core.Dense(output_dim = sizes[-1])) 37 | self.model.add(core.Activation(activation)) 38 | 39 | self.model.compile(loss = loss, 
optimizer = optimizer) 40 | 41 | def fit(self, X, y, *args, **kwargs): 42 | return self.model.fit(X, y, *args, **kwargs) 43 | 44 | def predict(self, X): 45 | return self.model.predict(X) 46 | -------------------------------------------------------------------------------- /bulbea/learn/sentiment/twitter.py: -------------------------------------------------------------------------------- 1 | # imports - compatibility packages 2 | from __future__ import absolute_import 3 | 4 | # imports - third-party packages 5 | import tweepy 6 | 7 | # module imports 8 | from bulbea import AppConfig 9 | from bulbea._util import _check_environment_variable_set 10 | 11 | class Twitter(object): 12 | def __init__(self): 13 | api_key = AppConfig.ENVIRONMENT_VARIABLE['twitter_api_key'] 14 | api_secret = AppConfig.ENVIRONMENT_VARIABLE['twitter_api_secret'] 15 | access_token = AppConfig.ENVIRONMENT_VARIABLE['twitter_access_token'] 16 | access_token_secret = AppConfig.ENVIRONMENT_VARIABLE['twitter_access_token_secret'] 17 | 18 | _check_environment_variable_set(api_key, raise_err = True) 19 | _check_environment_variable_set(api_secret, raise_err = True) 20 | _check_environment_variable_set(access_token, raise_err = True) 21 | _check_environment_variable_set(access_token_secret, raise_err = True) 22 | 23 | self.api_key = api_key 24 | self.api_secret = api_secret 25 | self.access_token = access_token 26 | self.access_token_secret = access_token_secret 27 | 28 | self.auth_handler = tweepy.OAuthHandler(self.api_key, self.api_secret) 29 | self.auth_handler.set_access_token(self.access_token, self.access_token_secret) 30 | 31 | self.api = tweepy.API(self.auth_handler) 32 | -------------------------------------------------------------------------------- /bulbea/_util/tests/test__util.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from bulbea._util import ( 4 | _raise_type_error, 5 | _get_type_name, 6 | _check_type, 7 | 
_validate_in_range, 8 | _validate_date, 9 | _assign_if_none 10 | ) 11 | 12 | def test__raise_type_error(): 13 | with pytest.raises(TypeError): 14 | _raise_type_error( 15 | expected_type_name = 'expected', 16 | recieved_type_name = 'recieved' 17 | ) 18 | 19 | def test__get_type_name(): 20 | assert _get_type_name('foo') == 'str' 21 | assert _get_type_name(12345) == 'int' 22 | assert _get_type_name(1.234) == 'float' 23 | 24 | def test__check_type(): 25 | with pytest.raises(TypeError): 26 | _check_type('foo', type_ = int, raise_err = True, expected_type_name = 'str') 27 | with pytest.raises(TypeError): 28 | _check_type(12345, type_ = str, raise_err = True, expected_type_name = 'int') 29 | 30 | assert _check_type('bar', type_ = str) == True 31 | assert _check_type('foo', type_ = int) == False 32 | 33 | def test__validate_in_range(): 34 | with pytest.raises(ValueError): 35 | _validate_in_range(123, 0, 1, raise_err = True) 36 | 37 | assert _validate_in_range(0.5, 0, 1) == True 38 | assert _validate_in_range(123, 0, 1) == False 39 | 40 | def test__validate_date(): 41 | with pytest.raises(ValueError): 42 | _validate_date('12/12/12', raise_err = True) 43 | 44 | assert _validate_date('2012-01-01') == True 45 | assert _validate_date('2012/01/01') == False 46 | 47 | def test__assign_if_none(): 48 | assert _assign_if_none(None, 1) == 1 49 | assert _assign_if_none(1,'foo') == 1 50 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys 4 | import os 5 | 6 | sys.path.insert(0, os.path.abspath('..')) 7 | sys.path.insert(0, os.path.abspath('_themes')) 8 | 9 | import package 10 | import bulbea as bb 11 | 12 | project = package.name 13 | version = package.version 14 | release = package.release 15 | html_theme = 'alabaster' 16 | html_theme_options = { 17 | 'github_user' : package.github_username, 18 | 'github_repo' : 
package.github_repository, 19 | 'github_banner' : True, 20 | 'show_powered_by' : False, 21 | 'show_related' : False, 22 | } 23 | html_show_sourcelink = False 24 | html_static_path = ['_static'] 25 | html_sidebars = { 26 | '**': ['sidebar-logo.html'] 27 | } 28 | pygments_style = 'flask_theme_support.FlaskyStyle' 29 | extensions = [ 30 | 'sphinx.ext.autodoc', 31 | 'sphinx.ext.mathjax' 32 | ] 33 | source_suffix = '.rst' 34 | master_doc = 'index' 35 | templates_path = ['_templates'] 36 | exclude_patterns = ['_build'] 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | copyright = '2017, Achilles Rasquinha' 52 | author = 'Achilles Rasquinha' 53 | # The language for content autogenerated by Sphinx. Refer to documentation 54 | # for a list of supported languages. 55 | # 56 | # This is also used if you do content translation via gettext catalogs. 57 | # Usually you set "language" from the command line for these cases. 58 | language = None 59 | 60 | 61 | # If true, `todo` and `todoList` produce output, else they produce nothing. 62 | todo_include_todos = False 63 | -------------------------------------------------------------------------------- /docs/guides/user/installation.rst: -------------------------------------------------------------------------------- 1 | .. _installation: 2 | 3 | Installation 4 | ============ 5 | 6 | Building from source 7 | ++++++++++++++++++++ 8 | 9 | **bulbea** is actively developed on GitHub_ and is always available. 10 | 11 | .. _GitHub: https://github.com/achillesrasquinha/bulbea 12 | 13 | You can clone the base repository with :code:`git` as follows: 14 | 15 | .. code-block:: console 16 | 17 | $ git clone git@github.com:achillesrasquinha/bulbea.git 18 | 19 | Alternatively, you can download the tarball_ or zipball_ as follows: 20 | 21 | .. _tarball: https://github.com/achillesrasquinha/bulbea/tarball/master 22 | .. _zipball: https://github.com/achillesrasquinha/bulbea/zipball/master 23 | 24 | **For Linux Users** 25 | 26 | .
code-block:: console 27 | 28 | $ curl -L -o bulbea.tar.gz https://github.com/achillesrasquinha/bulbea/tarball/master 29 | 30 | **For Windows Users** 31 | 32 | .. code-block:: console 33 | 34 | $ curl -L -o bulbea.zip https://github.com/achillesrasquinha/bulbea/zipball/master 35 | 36 | Install the necessary dependencies: 37 | 38 | .. code-block:: console 39 | 40 | $ pip install -r requirements.txt 41 | 42 | **bulbea** depends on Keras, which in turn uses TensorFlow as its backend. You may have to install TensorFlow manually as follows: 43 | 44 | .. code-block:: console 45 | 46 | $ pip install tensorflow # CPU-only 47 | 48 | OR 49 | 50 | .. code-block:: console 51 | 52 | $ pip install tensorflow-gpu # GPU support, requires NVIDIA CUDA and cuDNN 53 | 54 | Then, go ahead and install **bulbea** in your site-packages as follows: 55 | 56 | .. code-block:: console 57 | 58 | $ python setup.py install 59 | 60 | Check to see if you've installed **bulbea** correctly. 61 | 62 | .. code-block:: python 63 | 64 | >>> import bulbea as bb 65 | -------------------------------------------------------------------------------- /bulbea/learn/evaluation/split.py: -------------------------------------------------------------------------------- 1 | # imports - compatibility packages 2 | from __future__ import absolute_import 3 | 4 | # imports - third-party packages 5 | import numpy as np 6 | from sklearn.preprocessing import MinMaxScaler 7 | 8 | # module imports 9 | from bulbea._util import ( 10 | _check_type, 11 | _check_int, 12 | _check_real, 13 | _check_iterable, 14 | _check_sequence, 15 | _validate_in_range 16 | ) 17 | from bulbea.entity.share import _get_cummulative_return 18 | import bulbea as bb 19 | 20 | def split(share, 21 | attrs = 'Close', 22 | window = 0.01, 23 | train = 0.60, 24 | shift = 1, 25 | normalize = False): 26 | ''' 27 | :param attrs: `str` or `list` of attribute names of a share, defaults to *Close* attribute 28 | :type attrs: :obj: `str`, :obj:`list` 29 | ''' 30 | _check_type(share, type_ = bb.Share, raise_err = 
True, expected_type_name = 'bulbea.Share') 31 | _check_iterable(attrs, raise_err = True) 32 | _check_int(shift, raise_err = True) 33 | _check_real(window, raise_err = True) 34 | _check_real(train, raise_err = True) 35 | 36 | _validate_in_range(window, 0, 1, raise_err = True) 37 | _validate_in_range(train, 0, 1, raise_err = True) 38 | 39 | data = share.data[attrs] 40 | 41 | length = len(share) 42 | 43 | window = int(np.rint(length * window)) 44 | offset = shift - 1 45 | 46 | splits = np.array([data[i if i is 0 else i + offset: i + window] for i in range(length - window)]) 47 | 48 | if normalize: 49 | splits = np.array([_get_cummulative_return(split) for split in splits]) 50 | 51 | size = len(splits) 52 | split = int(np.rint(train * size)) 53 | 54 | train = splits[:split,:] 55 | test = splits[split:,:] 56 | 57 | Xtrain, Xtest = train[:,:-1], test[:,:-1] 58 | ytrain, ytest = train[:, -1], test[:, -1] 59 | 60 | return (Xtrain, Xtest, ytrain, ytest) 61 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | bulbea 2 | ====== 3 | *"Deep Learning based Python library for Stock Market Prediction and Modelling."* 4 | 5 | Release: v\ |release| (:ref:`Installation `) 6 | 7 | .. image:: https://img.shields.io/gitter/room/bulbea/bulbea.svg 8 | :target: https://gitter.im/bulbea/bulbea 9 | 10 | **bulbea** is an Open Source Python module (released under the :ref:`Apache 2.0 License `) that consists a growing collection of statistical, visualization and modelling tools for financial data analysis and prediction using deep learning. 11 | 12 | **bulbea** helps you with 13 | 14 | **Financial Data Loading** 15 | 16 | .. 
code:: python 17 | 18 | >>> import bulbea as bb 19 | >>> share = bb.Share('YAHOO', 'GOOGL') # Get Google's historical data from Yahoo's database 20 | >>> share.data 21 | Open High Low Close Volume Adjusted Close 22 | Date 23 | 2004-08-19 99.999999 104.059999 95.959998 100.339998 44659000.0 50.220219 24 | 2004-08-20 101.010005 109.079998 100.500002 108.310002 22834300.0 54.209210 25 | 2004-08-23 110.750003 113.479998 109.049999 109.399998 18256100.0 54.754754 26 | ... 27 | 28 | **Statistical Visualization** 29 | 30 | .. code:: python 31 | 32 | >>> share.plot(bollinger_bands = True, period = 100, bandwidth = 2) 33 | 34 | .. image:: _static/google_bollinger_bands.png 35 | 36 | **bulbea** is created and currently maintained by `Achilles Rasquinha `_. 37 | 38 | **bulbea** officially supports Python 2.7 and 3.5. 39 | 40 | Guide - User 41 | ++++++++++++ 42 | 43 | .. toctree:: 44 | :maxdepth: 2 45 | 46 | guides/user/introduction 47 | guides/user/installation 48 | guides/user/quickstart 49 | 50 | Guide - API 51 | +++++++++++ 52 | 53 | .. toctree:: 54 | :maxdepth: 2 55 | 56 | guides/api 57 | 58 | Blog 59 | ++++ 60 | 61 | .. toctree:: 62 | :maxdepth: 2 63 | 64 | blog/data-data-everywhere 65 | blog/visualizing-the-market 66 | blog/artificial-neural-networks 67 | -------------------------------------------------------------------------------- /docs/blog/data-data-everywhere.rst: -------------------------------------------------------------------------------- 1 | Data, Data Everywhere 2 | ===================== 3 | *"In God we trust, all others must bring data."* - W. Edwards Deming 4 | 5 | How data is stored 6 | ++++++++++++++++++ 7 | 8 | Data streams in from the moment a stock exchange opens until it closes. Such data contains vital information that is archived each day. Some of the many types of information received after trading hours are the *opening price*, *closing price*, *volume of shares*, *highest price*, *lowest price*, etc. for each enterprise.
9 | 10 | **bulbea** helps you access such information (both archived and the latest). Simply create a :py:class:`Share ` with a known :code:`source` and :code:`ticker` as follows: 11 | 12 | .. code:: python 13 | 14 | >>> import bulbea as bb 15 | >>> share = bb.Share(source = 'YAHOO', ticker = 'GOOGL') 16 | >>> share.data 17 | Open High Low Close Volume Adjusted Close 18 | Date 19 | 2004-08-19 99.999999 104.059999 95.959998 100.339998 44659000.0 50.220219 20 | 2004-08-20 101.010005 109.079998 100.500002 108.310002 22834300.0 54.209210 21 | 2004-08-23 110.750003 113.479998 109.049999 109.399998 18256100.0 54.754754 22 | ... 23 | 24 | Data is accessed through the :py:mod:`Quandl ` API and is stored remotely at the sources in the form of CSV (Comma-Separated Values) files. Information retrieved from such a CSV file is then wrapped in a :py:class:`pandas.DataFrame ` object. 25 | 26 | Comma, Separated, Value? 27 | ------------------------ 28 | 29 | CSV files store tabular data as simple plain text (well, it fits the need). Each row, containing the values associated with each attribute of a table, is stored on a single line, where each value is separated by a delimiter - you guessed it right, a comma. For instance, a data set containing the weight (in kilograms) and height (in feet) of members of my family would look something like the following: 30 | 31 | .. code:: raw 32 | 33 | weight,height 34 | 87,6.2 35 | 51,5.8 36 | 68,5.9 37 | ... 38 | 39 | Almost always, the top-most line (also called *the header*) should denote the attribute names separated by the delimiter. 40 | 41 | You can save a share object in CSV format as follows: 42 | 43 | .. code:: python 44 | 45 | >>> share.save() 46 | 47 | By default, the :py:meth:`save ` method saves a share as a CSV file in the working directory with a file name of the format - :code:`___.csv`. You could also name the file anything you like as follows: 48 | 49 | .. 
code:: python 50 | 51 | >>> share.save(filename = 'mycsvfile.csv') 52 | 53 | :py:class:`pandas.DataFrame` 54 | ---------------------------- -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # bulbea 2 | > *"Deep Learning based Python Library for Stock Market Prediction and Modelling."* 3 | 4 | [![Gitter](https://img.shields.io/gitter/room/bulbea/bulbea.svg)](https://gitter.im/bulbea/bulbea) [![Documentation Status](https://readthedocs.org/projects/bulbea/badge/?version=latest)](http://bulbea.readthedocs.io/en/latest/?badge=latest) 5 | 6 | ![](.github/bulbea.png) 7 | 8 | ### Table of Contents 9 | * [Installation](#installation) 10 | * [Usage](#usage) 11 | * [Documentation](#documentation) 12 | * [Dependencies](#dependencies) 13 | * [License](#license) 14 | 15 | ### Installation 16 | Clone the git repository: 17 | ```console 18 | $ git clone https://github.com/achillesrasquinha/bulbea.git && cd bulbea 19 | ``` 20 | 21 | Install necessary dependencies 22 | ```console 23 | $ pip install -r requirements.txt 24 | ``` 25 | 26 | Go ahead and install as follows: 27 | ```console 28 | $ python setup.py install 29 | ``` 30 | 31 | You may have to install TensorFlow: 32 | ```console 33 | $ pip install tensorflow # CPU 34 | $ pip install tensorflow-gpu # GPU - Requires CUDA, CuDNN 35 | ``` 36 | 37 | ### Usage 38 | #### 1. Prediction 39 | ##### a. Loading 40 | Create a share object. 
41 | ```python 42 | >>> import bulbea as bb 43 | >>> share = bb.Share('YAHOO', 'GOOGL') 44 | >>> share.data 45 | # Open High Low Close Volume \ 46 | # Date 47 | # 2004-08-19 99.999999 104.059999 95.959998 100.339998 44659000.0 48 | # 2004-08-20 101.010005 109.079998 100.500002 108.310002 22834300.0 49 | # 2004-08-23 110.750003 113.479998 109.049999 109.399998 18256100.0 50 | # 2004-08-24 111.239999 111.599998 103.570003 104.870002 15247300.0 51 | # 2004-08-25 104.960000 108.000002 103.880003 106.000005 9188600.0 52 | ... 53 | ``` 54 | ##### b. Preprocessing 55 | Split your data set into training and testing sets. 56 | ```python 57 | >>> from bulbea.learn.evaluation import split 58 | >>> Xtrain, Xtest, ytrain, ytest = split(share, 'Close', normalize = True) 59 | ``` 60 | 61 | ##### c. Modelling 62 | ```python 63 | >>> import numpy as np 64 | >>> Xtrain = np.reshape(Xtrain, (Xtrain.shape[0], Xtrain.shape[1], 1)) 65 | >>> Xtest = np.reshape( Xtest, ( Xtest.shape[0], Xtest.shape[1], 1)) 66 | 67 | >>> from bulbea.learn.models import RNN 68 | >>> rnn = RNN([1, 100, 100, 1]) # number of neurons in each layer 69 | >>> rnn.fit(Xtrain, ytrain) 70 | # Epoch 1/10 71 | # 1877/1877 [==============================] - 6s - loss: 0.0039 72 | # Epoch 2/10 73 | # 1877/1877 [==============================] - 6s - loss: 0.0019 74 | ... 75 | ``` 76 | 77 | ##### d. Testing 78 | ```python 79 | >>> from sklearn.metrics import mean_squared_error 80 | >>> p = rnn.predict(Xtest) 81 | >>> mean_squared_error(ytest, p) 82 | 0.00042927869370525931 83 | >>> import matplotlib.pyplot as pplt 84 | >>> pplt.plot(ytest) 85 | >>> pplt.plot(p) 86 | >>> pplt.show() 87 | ``` 88 | ![](.github/plot.png) 89 | 90 | #### 2. Sentiment Analysis 91 | Add your Twitter credentials to your environment variables. 
92 | ```bash 93 | export BULBEA_TWITTER_API_KEY="" 94 | export BULBEA_TWITTER_API_SECRET="" 95 | 96 | export BULBEA_TWITTER_ACCESS_TOKEN="" 97 | export BULBEA_TWITTER_ACCESS_TOKEN_SECRET="" 98 | ``` 99 | And then, 100 | ```python 101 | >>> bb.sentiment(share) 102 | 0.07580128205128206 103 | ``` 104 | 105 | ### Documentation 106 | Detailed documentation is available [here](http://bulbea.readthedocs.io/en/latest/). 107 | 108 | ### Dependencies 109 | 1. quandl 110 | 2. keras 111 | 3. tweepy 112 | 4. textblob 113 | 114 | ### License 115 | This code has been released under the [Apache 2.0 License](LICENSE). 116 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys 4 | import os 5 | import shutil 6 | import codecs 7 | 8 | from distutils.core import Command 9 | from distutils.command.clean import clean as Clean 10 | 11 | import package 12 | 13 | ABSPATH_ROOTDIR = os.path.dirname(os.path.abspath(__file__)) 14 | RELPATH_FILES_CLEAN = ['build', 'dist', '{name}.egg-info'.format(name = package.name), '.cache'] 15 | RELPATH_WALK_FILES_EXT_CLEAN = ['.pyc'] 16 | RELPATH_WALK_DIRS_CLEAN = ['__pycache__'] 17 | 18 | class CleanCommand(Clean): 19 | def run(self): 20 | Clean.run(self) 21 | 22 | for filename in RELPATH_FILES_CLEAN: 23 | if os.path.exists(filename): 24 | shutil.rmtree(filename) 25 | 26 | for dirpath, dirnames, filenames in os.walk(ABSPATH_ROOTDIR): 27 | for filename in filenames: 28 | for extension in RELPATH_WALK_FILES_EXT_CLEAN: 29 | if filename.endswith(extension): 30 | path = os.path.join(dirpath, filename) 31 | os.unlink(path) 32 | 33 | for dirname in dirnames: 34 | if dirname in RELPATH_WALK_DIRS_CLEAN: 35 | path = os.path.join(dirpath, dirname) 36 | shutil.rmtree(path, ignore_errors = True) 37 | 38 | class TestCommand(Command): 39 | user_options = [('pytest=', 'a', 'arguments to be passed to 
pytest')] 40 | 41 | def initialize_options(self): 42 | self.args_pytest = [ ] 43 | 44 | def finalize_options(self): 45 | pass 46 | 47 | def run(self): 48 | import pytest 49 | 50 | errno = pytest.main(self.args_pytest) 51 | 52 | sys.exit(errno) 53 | 54 | def get_long_description(filepaths): 55 | content = '' 56 | filepaths = filepaths if isinstance(filepaths, list) else [filepaths] 57 | 58 | if filepaths: 59 | for i, filepath in enumerate(filepaths): 60 | if os.path.exists(filepath): 61 | if os.path.isfile(filepath): 62 | if os.path.getsize(filepath) > 0: 63 | f = codecs.open(filepath, mode = 'r', encoding = 'utf-8') 64 | raw = f.read() 65 | content += '{prepend}{content}'.format(prepend = '' if i is 0 else '\n\n', content = raw) 66 | 67 | f.close() 68 | else: 69 | raise ValueError('Not a file: {filepath}'.format(filepath = filepath)) 70 | else: 71 | raise FileNotFoundError('No such file found: {filepath}'.format(filepath = filepath)) 72 | 73 | return content 74 | 75 | def main(): 76 | try: 77 | from setuptools import setup 78 | args_setuptools = dict( 79 | keywords = ', '.join([keyword for keyword in package.keywords]) 80 | ) 81 | except ImportError: 82 | from distutils.core import setup 83 | args_setuptools = dict() 84 | 85 | metadata = dict( 86 | name = package.name, 87 | version = package.version, 88 | description = package.description, 89 | long_description = get_long_description(package.long_description), 90 | author = ','.join([author['name'] for author in package.authors]), 91 | author_email = ','.join([author['email'] for author in package.authors]), 92 | maintainer = ','.join([maintainer['name'] for maintainer in package.maintainers]), 93 | maintainer_email = ','.join([maintainer['email'] for maintainer in package.maintainers]), 94 | license = package.license, 95 | packages = package.modules, 96 | url = package.homepage, 97 | cmdclass = { 98 | 'clean': CleanCommand, 99 | 'test': TestCommand 100 | }, 101 | **args_setuptools 102 | ) 103 | 104 | 
setup(**metadata) 105 | 106 | if __name__ == '__main__': 107 | main() 108 | -------------------------------------------------------------------------------- /bulbea/_util/_util.py: -------------------------------------------------------------------------------- 1 | # imports - compatibility packages 2 | from __future__ import absolute_import 3 | from six import string_types 4 | 5 | # imports - standard packages 6 | import os 7 | import collections 8 | import numbers 9 | from datetime import datetime 10 | 11 | # imports - third-party packages 12 | import pandas as pd 13 | 14 | # module imports 15 | from bulbea.exceptions import TYPE_ERROR_STRING 16 | 17 | def _raise_type_error(expected_type_name, recieved_type_name): 18 | raise TypeError(TYPE_ERROR_STRING.format( 19 | expected_type_name = expected_type_name, 20 | recieved_type_name = recieved_type_name 21 | )) 22 | 23 | def _get_type_name(o): 24 | type_ = type(o) 25 | name = type_.__name__ 26 | 27 | return name 28 | 29 | def _get_datetime_str(dt, format_): 30 | if _check_type(dt, pd.Timestamp): 31 | dt = dt.to_pydatetime() 32 | 33 | _check_type(dt, type_ = datetime, raise_err = True, expected_type_name = 'datetime.datetime') 34 | 35 | string = dt.strftime(format_) 36 | 37 | return string 38 | 39 | def _check_type(o, type_, raise_err = False, expected_type_name = None): 40 | if not isinstance(o, type_): 41 | if raise_err: 42 | _raise_type_error( 43 | expected_type_name = expected_type_name, 44 | recieved_type_name = _get_type_name(o) 45 | ) 46 | else: 47 | return False 48 | else: 49 | return True 50 | 51 | def _check_str(o, raise_err = False): 52 | return _check_type(o, string_types, raise_err = raise_err, expected_type_name = 'str') 53 | 54 | def _check_int(o, raise_err = False): 55 | return _check_type(o, numbers.Integral, raise_err = raise_err, expected_type_name = 'int') 56 | 57 | def _check_real(o, raise_err = False): 58 | return _check_type(o, numbers.Real, raise_err = raise_err, expected_type_name = '(int, 
float)') 59 | 60 | def _check_pandas_series(data, raise_err = False): 61 | return _check_type(data, pd.Series, raise_err = raise_err, expected_type_name = 'pandas.Series') 62 | 63 | def _check_pandas_dataframe(data, raise_err = False): 64 | return _check_type(data, pd.DataFrame, raise_err = raise_err, expected_type_name = 'pandas.DataFrame') 65 | 66 | def _check_iterable(o, raise_err = False): 67 | return _check_type(o, collections.Iterable, raise_err = raise_err, expected_type_name = '(str, list, tuple)') 68 | 69 | def _check_sequence(o, string = True, raise_err = False): 70 | return _check_type(o, collections.Sequence, raise_err = raise_err, expected_type_name = '(list, tuple)') 71 | 72 | def _check_environment_variable_set(variable, raise_err = False): 73 | _check_str(variable, raise_err = raise_err) 74 | 75 | try: 76 | os.environ[variable] 77 | except KeyError: 78 | if raise_err: 79 | raise ValueError('Environment variable {variable} not set.') 80 | else: 81 | return False 82 | 83 | return True 84 | 85 | def _validate_in_range(value, low, high, raise_err = False): 86 | if not low <= value <= high: 87 | if raise_err: 88 | raise ValueError('{value} out of bounds, must be in range [{low}, {high}].'.format( 89 | value = value, 90 | low = low, 91 | high = high 92 | )) 93 | else: 94 | return False 95 | else: 96 | return True 97 | 98 | def _validate_date(value, format_ = '%Y-%m-%d', raise_err = False): 99 | _check_str(value, raise_err = raise_err) 100 | 101 | try: 102 | datetime.strptime(value, format_) 103 | except ValueError: 104 | if raise_err: 105 | raise ValueError('Expected {format_} format, got {value} instead.'.format( 106 | format_ = format_, 107 | value = value 108 | )) 109 | else: 110 | return False 111 | 112 | return True 113 | 114 | def _assign_if_none(a, b): 115 | return b if a is None else a 116 | 117 | def _is_sequence_all(seq): 118 | _check_sequence(seq, raise_err = True) 119 | 120 | length = len(seq) 121 | is_seq = True if length != 0 and 
seq.count(seq[0]) == length else False 122 | 123 | return is_seq 124 | -------------------------------------------------------------------------------- /docs/_themes/flask_theme_support.py: -------------------------------------------------------------------------------- 1 | # flasky extensions. flasky pygments style based on tango style 2 | from pygments.style import Style 3 | from pygments.token import Keyword, Name, Comment, String, Error, \ 4 | Number, Operator, Generic, Whitespace, Punctuation, Other, Literal 5 | 6 | 7 | class FlaskyStyle(Style): 8 | background_color = "#f8f8f8" 9 | default_style = "" 10 | 11 | styles = { 12 | # No corresponding class for the following: 13 | #Text: "", # class: '' 14 | Whitespace: "underline #f8f8f8", # class: 'w' 15 | Error: "#a40000 border:#ef2929", # class: 'err' 16 | Other: "#000000", # class 'x' 17 | 18 | Comment: "italic #8f5902", # class: 'c' 19 | Comment.Preproc: "noitalic", # class: 'cp' 20 | 21 | Keyword: "bold #004461", # class: 'k' 22 | Keyword.Constant: "bold #004461", # class: 'kc' 23 | Keyword.Declaration: "bold #004461", # class: 'kd' 24 | Keyword.Namespace: "bold #004461", # class: 'kn' 25 | Keyword.Pseudo: "bold #004461", # class: 'kp' 26 | Keyword.Reserved: "bold #004461", # class: 'kr' 27 | Keyword.Type: "bold #004461", # class: 'kt' 28 | 29 | Operator: "#582800", # class: 'o' 30 | Operator.Word: "bold #004461", # class: 'ow' - like keywords 31 | 32 | Punctuation: "bold #000000", # class: 'p' 33 | 34 | # because special names such as Name.Class, Name.Function, etc. 35 | # are not recognized as such later in the parsing, we choose them 36 | # to look the same as ordinary variables. 
37 | Name: "#000000", # class: 'n' 38 | Name.Attribute: "#c4a000", # class: 'na' - to be revised 39 | Name.Builtin: "#004461", # class: 'nb' 40 | Name.Builtin.Pseudo: "#3465a4", # class: 'bp' 41 | Name.Class: "#000000", # class: 'nc' - to be revised 42 | Name.Constant: "#000000", # class: 'no' - to be revised 43 | Name.Decorator: "#888", # class: 'nd' - to be revised 44 | Name.Entity: "#ce5c00", # class: 'ni' 45 | Name.Exception: "bold #cc0000", # class: 'ne' 46 | Name.Function: "#000000", # class: 'nf' 47 | Name.Property: "#000000", # class: 'py' 48 | Name.Label: "#f57900", # class: 'nl' 49 | Name.Namespace: "#000000", # class: 'nn' - to be revised 50 | Name.Other: "#000000", # class: 'nx' 51 | Name.Tag: "bold #004461", # class: 'nt' - like a keyword 52 | Name.Variable: "#000000", # class: 'nv' - to be revised 53 | Name.Variable.Class: "#000000", # class: 'vc' - to be revised 54 | Name.Variable.Global: "#000000", # class: 'vg' - to be revised 55 | Name.Variable.Instance: "#000000", # class: 'vi' - to be revised 56 | 57 | Number: "#990000", # class: 'm' 58 | 59 | Literal: "#000000", # class: 'l' 60 | Literal.Date: "#000000", # class: 'ld' 61 | 62 | String: "#4e9a06", # class: 's' 63 | String.Backtick: "#4e9a06", # class: 'sb' 64 | String.Char: "#4e9a06", # class: 'sc' 65 | String.Doc: "italic #8f5902", # class: 'sd' - like a comment 66 | String.Double: "#4e9a06", # class: 's2' 67 | String.Escape: "#4e9a06", # class: 'se' 68 | String.Heredoc: "#4e9a06", # class: 'sh' 69 | String.Interpol: "#4e9a06", # class: 'si' 70 | String.Other: "#4e9a06", # class: 'sx' 71 | String.Regex: "#4e9a06", # class: 'sr' 72 | String.Single: "#4e9a06", # class: 's1' 73 | String.Symbol: "#4e9a06", # class: 'ss' 74 | 75 | Generic: "#000000", # class: 'g' 76 | Generic.Deleted: "#a40000", # class: 'gd' 77 | Generic.Emph: "italic #000000", # class: 'ge' 78 | Generic.Error: "#ef2929", # class: 'gr' 79 | Generic.Heading: "bold #000080", # class: 'gh' 80 | Generic.Inserted: "#00A000", # class: 
'gi' 81 | Generic.Output: "#888", # class: 'go' 82 | Generic.Prompt: "#745334", # class: 'gp' 83 | Generic.Strong: "bold #000000", # class: 'gs' 84 | Generic.Subheading: "bold #800080", # class: 'gu' 85 | Generic.Traceback: "bold #a40000", # class: 'gt' 86 | } -------------------------------------------------------------------------------- /bulbea/entity/share.py: -------------------------------------------------------------------------------- 1 | # imports - compatibility packages 2 | from __future__ import absolute_import 3 | 4 | # imports - standard packages 5 | import os 6 | import warnings 7 | 8 | # imports - third-party packages 9 | import numpy as np 10 | import matplotlib.pyplot as pplt 11 | import pandas as pd 12 | import quandl 13 | 14 | # module imports 15 | from bulbea.config.app import AppConfig 16 | from bulbea.entity import Entity 17 | from bulbea._util import ( 18 | _check_type, 19 | _check_str, 20 | _check_int, 21 | _check_pandas_series, 22 | _check_pandas_dataframe, 23 | _check_iterable, 24 | _check_environment_variable_set, 25 | _validate_date, 26 | _assign_if_none, 27 | _get_type_name, 28 | _get_datetime_str, 29 | _raise_type_error, 30 | _is_sequence_all 31 | ) 32 | from bulbea._util.const import ( 33 | ABSURL_QUANDL, 34 | QUANDL_MAX_DAILY_CALLS, 35 | SHARE_ACCEPTED_SAVE_FORMATS 36 | ) 37 | from bulbea._util.color import Color 38 | import bulbea as bb 39 | 40 | pplt.style.use(AppConfig.PLOT_STYLE) 41 | 42 | def _get_cummulative_return(data): 43 | cumret = (data / data[0]) - 1 44 | 45 | return cumret 46 | 47 | def _get_bollinger_bands_columns(data): 48 | _check_pandas_dataframe(data, raise_err = True) 49 | 50 | columns = list(data.columns) 51 | ncols = len(columns) 52 | 53 | if ncols != 3: 54 | raise ValueError('Expected a pandas.DataFrame with exactly 3 columns, got {ncols} instead.'.format( 55 | ncols = ncols 56 | )) 57 | 58 | if not _is_sequence_all(columns): 59 | raise ValueError('Ambiguous column names: {columns}'.format( 60 | columns = 
columns 61 | )) 62 | 63 | attr = columns[0] 64 | prefixes = ['Lower', 'Mean', 'Upper'] 65 | columns = ['{prefix} ({attr})'.format( 66 | prefix = prefix, 67 | attr = attr 68 | ) for prefix in prefixes] 69 | 70 | return columns 71 | 72 | def _get_bollinger_bands(data, period = 50, bandwidth = 1): 73 | _check_int(period, raise_err = True) 74 | _check_int(bandwidth, raise_err = True) 75 | 76 | _check_pandas_series(data, raise_err = True) 77 | 78 | roll = data.rolling(window = period) 79 | std, mean = roll.std(), roll.mean() 80 | 81 | upper = mean + bandwidth * std 82 | lower = mean - bandwidth * std 83 | 84 | return (lower, mean, upper) 85 | 86 | def _get_share_filename(share, extension = None): 87 | _check_type(share, bb.Share, raise_err = True, expected_type_name = 'bulbea.Share') 88 | 89 | if extension is not None: 90 | _check_str(extension, raise_err = True) 91 | 92 | source = share.source 93 | ticker = share.ticker 94 | 95 | start = _get_datetime_str(share.data.index.min(), format_ = '%Y%m%d') 96 | end = _get_datetime_str(share.data.index.max(), format_ = '%Y%m%d') 97 | 98 | filename = '{source}_{ticker}_{start}_{end}'.format( 99 | source = source, 100 | ticker = ticker, 101 | start = start, 102 | end = end 103 | ) 104 | 105 | if extension: 106 | filename = '{filename}.{extension}'.format( 107 | filename = filename, 108 | extension = extension 109 | ) 110 | 111 | return filename 112 | 113 | def _plot_global_mean(data, axes): 114 | _check_pandas_series(data, raise_err = True) 115 | 116 | mean = data.mean() 117 | axes.axhline(mean, color = 'b', linestyle = '-.') 118 | 119 | def _plot_bollinger_bands(data, axes, period = 50, bandwidth = 1): 120 | _check_int(period, raise_err = True) 121 | _check_int(bandwidth, raise_err = True) 122 | 123 | _check_pandas_series(data, raise_err = True) 124 | 125 | lowr, mean, uppr = _get_bollinger_bands(data, period = period, bandwidth = bandwidth) 126 | 127 | axes.plot(lowr, color = 'r', linestyle = '--') 128 | axes.plot(mean, color = 
'g', linestyle = '--') 129 | axes.plot(uppr, color = 'r', linestyle = '--') 130 | 131 | class Share(Entity): 132 | ''' 133 | A user-created :class:`Share ` object. 134 | 135 | :param source: *source* symbol for economic data 136 | :type source: :obj:`str` 137 | 138 | :param ticker: *ticker* symbol of a share 139 | :type ticker: :obj:`str` 140 | 141 | :param start: starting date string in the form YYYY-MM-DD for acquiring historical records, defaults to the earliest available records 142 | :type start: :obj:`str` 143 | 144 | :param end: ending date string in the form YYYY-MM-DD for acquiring historical records, defaults to the latest available records 145 | :type end: :obj:`str` 146 | 147 | :param latest: acquires the latest N records 148 | :type latest: :obj:`int` 149 | 150 | :Example: 151 | 152 | >>> import bulbea as bb 153 | >>> share = bb.Share(source = 'YAHOO', ticker = 'GOOGL') 154 | >>> share.data.sample(1) 155 | Open High Low Close Volume Adjusted Close 156 | Date 157 | 2003-05-15 18.6 18.849999 18.470001 18.73 71248800.0 1.213325 158 | ''' 159 | def __init__(self, source, ticker, start = None, end = None, latest = None, cache = False): 160 | _check_str(source, raise_err = True) 161 | _check_str(ticker, raise_err = True) 162 | 163 | envvar = AppConfig.ENVIRONMENT_VARIABLE['quandl_api_key'] 164 | 165 | if not _check_environment_variable_set(envvar): 166 | message = Color.warn("Environment variable {envvar} for Quandl hasn't been set. A maximum of {max_calls} calls per day can be made. 
Visit {url} to get your API key.".format(envvar = envvar, max_calls = QUANDL_MAX_DAILY_CALLS, url = ABSURL_QUANDL)) 167 | 168 | warnings.warn(message) 169 | else: 170 | quandl.ApiConfig.api_key = os.getenv(envvar) 171 | 172 | self.source = source 173 | self.ticker = ticker 174 | 175 | self.update(start = start, end = end, latest = latest, cache = cache) 176 | 177 | def update(self, start = None, end = None, latest = None, cache = False): 178 | ''' 179 | Update the share with the latest available data. 180 | 181 | :Example: 182 | 183 | >>> import bulbea as bb 184 | >>> share = bb.Share(source = 'YAHOO', ticker = 'AAPL') 185 | >>> share.update() 186 | ''' 187 | self.data = quandl.get('{database}/{code}'.format( 188 | database = self.source, 189 | code = self.ticker 190 | )) 191 | self.length = len(self.data) 192 | self.attrs = list(self.data.columns) 193 | 194 | def __len__(self): 195 | ''' 196 | Number of data points available for a given share. 197 | 198 | :Example: 199 | >>> import bulbea as bb 200 | >>> share = bb.Share(source = 'YAHOO', ticker = 'AAPL') 201 | >>> len(share) 202 | 9139 203 | ''' 204 | return self.length 205 | 206 | def bollinger_bands(self, 207 | attrs = 'Close', 208 | period = 50, 209 | bandwidth = 1): 210 | ''' 211 | Returns the Bollinger Bands (R) for each attribute. 
212 | 213 | :param attrs: `str` or `list` of attribute name(s) of a share, defaults to *Close* 214 | :type attrs: :obj:`str`, :obj:`list` 215 | 216 | :param period: length of the window to compute moving averages, upper and lower bands 217 | :type period: :obj:`int` 218 | 219 | :param bandwidth: multiple of the standard deviation of upper and lower bands 220 | :type bandwidth: :obj:`int` 221 | 222 | :Example: 223 | 224 | >>> import bulbea as bb 225 | >>> share = bb.Share(source = 'YAHOO', ticker = 'AAPL') 226 | >>> bollinger = share.bollinger_bands() 227 | >>> bollinger.tail() 228 | Lower (Close) Mean (Close) Upper (Close) 229 | Date 230 | 2017-03-07 815.145883 831.694803 848.243724 231 | 2017-03-08 816.050821 832.574004 849.097187 232 | 2017-03-09 817.067353 833.574805 850.082257 233 | 2017-03-10 817.996674 834.604404 851.212135 234 | 2017-03-13 819.243360 835.804605 852.365849 235 | ''' 236 | _check_iterable(attrs, raise_err = True) 237 | 238 | if _check_str(attrs): 239 | attrs = [attrs] 240 | 241 | frames = list() 242 | 243 | for attr in attrs: 244 | data = self.data[attr] 245 | lowr, mean, upper = _get_bollinger_bands(data, period = period, bandwidth = bandwidth) 246 | bollinger_bands = pd.concat([lowr, mean, upper], axis = 1) 247 | bollinger_bands.columns = _get_bollinger_bands_columns(bollinger_bands) 248 | 249 | frames.append(bollinger_bands) 250 | 251 | return frames[0] if len(frames) == 1 else frames 252 | 253 | def plot(self, 254 | attrs = 'Close', 255 | global_mean = False, 256 | bollinger_bands = False, 257 | period = 50, 258 | bandwidth = 1, 259 | subplots = False, *args, **kwargs): 260 | ''' 261 | :param attrs: `str` or `list` of attribute names of a share to plot, defaults to *Close* attribute 262 | :type attrs: :obj: `str`, :obj:`list` 263 | 264 | :Example: 265 | 266 | >>> import bulbea as bb 267 | >>> share = bb.Share(source = 'YAHOO', ticker = 'AAPL') 268 | >>> share.plot() 269 | ''' 270 | _check_iterable(attrs, raise_err = True) 271 | 272 | if 
_check_str(attrs): 273 | attrs = [attrs] 274 | 275 | plot_stats = global_mean or bollinger_bands 276 | subplots = True if len(attrs) != 1 and plot_stats else subplots 277 | axes = self.data[attrs].plot(subplots = subplots, *args, **kwargs) 278 | 279 | if plot_stats: 280 | if subplots: 281 | for i, attr in enumerate(attrs): 282 | data = self.data[attr] 283 | ax = axes[i] 284 | 285 | if global_mean: 286 | _plot_global_mean(data, ax) 287 | 288 | if bollinger_bands: 289 | _plot_bollinger_bands(data, ax, period = period, bandwidth = bandwidth) 290 | else: 291 | attr = attrs[0] 292 | data = self.data[attr] 293 | 294 | if global_mean: 295 | _plot_global_mean(data, axes) 296 | 297 | if bollinger_bands: 298 | _plot_bollinger_bands(data, axes, period = period, bandwidth = bandwidth) 299 | 300 | return axes 301 | 302 | def save(self, format_ = 'csv', filename = None): 303 | ''' 304 | :param format_: type of format to save the Share object, default 'csv'. 305 | :type format_: :obj:`str` 306 | ''' 307 | if format_ not in SHARE_ACCEPTED_SAVE_FORMATS: 308 | raise ValueError('Format {format_} not accepted. 
Accepted formats are: {accepted_formats}'.format( 309 | format_ = format_, 310 | accepted_formats = SHARE_ACCEPTED_SAVE_FORMATS 311 | )) 312 | 313 | if filename is not None: 314 | _check_str(filename, raise_err = True) 315 | else: 316 | filename = _get_share_filename(self, extension = format_) 317 | 318 | if format_ is 'csv': 319 | self.data.to_csv(filename) -------------------------------------------------------------------------------- /examples/bulbea.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "deletable": true, 7 | "editable": true 8 | }, 9 | "source": [ 10 | "# bulbea\n", 11 | "> Deep Learning based Python Library for Stock Market Prediction and Modelling\n", 12 | "\n", 13 | "![](bulbea.png)" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "metadata": { 19 | "deletable": true, 20 | "editable": true 21 | }, 22 | "source": [ 23 | "A canonical way of importing the `bulbea` module is as follows:" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": { 30 | "collapsed": false, 31 | "deletable": true, 32 | "editable": true, 33 | "scrolled": true 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "import bulbea as bb" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": { 43 | "deletable": true, 44 | "editable": true 45 | }, 46 | "source": [ 47 | "### `bulbea.Share`" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": { 53 | "deletable": true, 54 | "editable": true 55 | }, 56 | "source": [ 57 | "In order to analyse a desired share, we use the `Share` object defined under `bulbea` which considers 2 arguments - *the **source code** for the economic data* and *the **ticker symbol** for a said company*." 
58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": { 64 | "collapsed": true, 65 | "deletable": true, 66 | "editable": true 67 | }, 68 | "outputs": [], 69 | "source": [ 70 | "source, ticker = 'YAHOO', 'INDEX_GSPC'" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": { 76 | "deletable": true, 77 | "editable": true 78 | }, 79 | "source": [ 80 | "Go ahead and create a `Share` object as follows:" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": { 87 | "collapsed": false, 88 | "deletable": true, 89 | "editable": true 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "share = bb.Share(source, ticker)" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": { 99 | "deletable": true, 100 | "editable": true 101 | }, 102 | "source": [ 103 | "By default, a `Share` object for a given source and symbol provides historical data since the company's inception as a `pandas.DataFrame` object.
To access this data, use the `Share` object's `data` attribute as follows:" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": { 110 | "collapsed": false, 111 | "deletable": true, 112 | "editable": true 113 | }, 114 | "outputs": [], 115 | "source": [ 116 | "data = share.data\n", 117 | "nsamples = 5\n", 118 | "data.tail(nsamples)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": { 124 | "deletable": true, 125 | "editable": true 126 | }, 127 | "source": [ 128 | "To analyse a given attribute, plot it as follows:" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": { 135 | "collapsed": false, 136 | "deletable": true, 137 | "editable": true 138 | }, 139 | "outputs": [], 140 | "source": [ 141 | "figsize = (20, 15)\n", 142 | "\n", 143 | "% matplotlib inline" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": { 150 | "collapsed": false, 151 | "deletable": true, 152 | "editable": true 153 | }, 154 | "outputs": [], 155 | "source": [ 156 | "share.plot(figsize = figsize)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": { 163 | "collapsed": false, 164 | "deletable": true, 165 | "editable": true 166 | }, 167 | "outputs": [], 168 | "source": [ 169 | "share.plot(['Close', 'Adjusted Close'], figsize = figsize)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "metadata": { 175 | "deletable": true, 176 | "editable": true 177 | }, 178 | "source": [ 179 | "### Statistics" 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "metadata": { 185 | "deletable": true, 186 | "editable": true 187 | }, 188 | "source": [ 189 | "#### Global Mean\n", 190 | "To plot the **global mean** of the stock, use the following:" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 |
"execution_count": null, 196 | "metadata": { 197 | "collapsed": false, 198 | "deletable": true, 199 | "editable": true 200 | }, 201 | "outputs": [], 202 | "source": [ 203 | "share.plot(figsize = (20, 15), global_mean = True)" 204 | ] 205 | }, 206 | { 207 | "cell_type": "markdown", 208 | "metadata": { 209 | "deletable": true, 210 | "editable": true 211 | }, 212 | "source": [ 213 | "#### Moving Averages and Bollinger Bands (R)" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "metadata": { 220 | "collapsed": false, 221 | "deletable": true, 222 | "editable": true 223 | }, 224 | "outputs": [], 225 | "source": [ 226 | "bands = share.bollinger_bands(period = 50, bandwidth = 2)\n", 227 | "bands.tail(nsamples)" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": { 234 | "collapsed": false, 235 | "deletable": true, 236 | "editable": true 237 | }, 238 | "outputs": [], 239 | "source": [ 240 | "share.plot(['Close', 'Adjusted Close'], figsize = (20, 15), bollinger_bands = True, period = 100, bandwidth = 2)" 241 | ] 242 | }, 243 | { 244 | "cell_type": "markdown", 245 | "metadata": { 246 | "deletable": true, 247 | "editable": true 248 | }, 249 | "source": [ 250 | "### Training & Testing" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": null, 256 | "metadata": { 257 | "collapsed": false, 258 | "deletable": true, 259 | "editable": true 260 | }, 261 | "outputs": [], 262 | "source": [ 263 | "from bulbea.learn.evaluation import split" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "metadata": { 270 | "collapsed": false, 271 | "deletable": true, 272 | "editable": true 273 | }, 274 | "outputs": [], 275 | "source": [ 276 | "scaler, Xtrain, Xtest, ytrain, ytest = split(share, 'Close', normalize = True)" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": { 283 | "collapsed": false, 
284 | "deletable": true, 285 | "editable": true 286 | }, 287 | "outputs": [], 288 | "source": [ 289 | "import numpy as np" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "metadata": { 296 | "collapsed": false, 297 | "deletable": true, 298 | "editable": true 299 | }, 300 | "outputs": [], 301 | "source": [ 302 | "Xtrain = np.reshape(Xtrain, (Xtrain.shape[0], Xtrain.shape[1], 1))\n", 303 | "Xtest = np.reshape(Xtest, ( Xtest.shape[0], Xtest.shape[1], 1))" 304 | ] 305 | }, 306 | { 307 | "cell_type": "markdown", 308 | "metadata": { 309 | "deletable": true, 310 | "editable": true 311 | }, 312 | "source": [ 313 | "### Modelling" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": null, 319 | "metadata": { 320 | "collapsed": false, 321 | "deletable": true, 322 | "editable": true 323 | }, 324 | "outputs": [], 325 | "source": [ 326 | "layers = [1, 100, 100, 1] # number of neurons in each layer\n", 327 | "nbatch = 512 \n", 328 | "epochs = 5 \n", 329 | "nvalidation = 0.05" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": null, 335 | "metadata": { 336 | "collapsed": false, 337 | "deletable": true, 338 | "editable": true 339 | }, 340 | "outputs": [], 341 | "source": [ 342 | "from bulbea.learn.models import RNN\n", 343 | "from bulbea.learn.models.ann import RNNCell" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": null, 349 | "metadata": { 350 | "collapsed": false, 351 | "deletable": true, 352 | "editable": true 353 | }, 354 | "outputs": [], 355 | "source": [ 356 | "rnn = RNN(layers, cell = RNNCell.LSTM)" 357 | ] 358 | }, 359 | { 360 | "cell_type": "markdown", 361 | "metadata": { 362 | "deletable": true, 363 | "editable": true 364 | }, 365 | "source": [ 366 | "#### TRAINING" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": null, 372 | "metadata": { 373 | "collapsed": false, 374 | "deletable": true, 375 | "editable": true 376 | }, 
377 | "outputs": [], 378 | "source": [ 379 | "rnn.fit(Xtrain, ytrain,\n", 380 | " batch_size = nbatch,\n", 381 | " nb_epoch = epochs,\n", 382 | " validation_split = nvalidation)" 383 | ] 384 | }, 385 | { 386 | "cell_type": "markdown", 387 | "metadata": { 388 | "deletable": true, 389 | "editable": true 390 | }, 391 | "source": [ 392 | "#### TESTING" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": null, 398 | "metadata": { 399 | "collapsed": false, 400 | "deletable": true, 401 | "editable": true 402 | }, 403 | "outputs": [], 404 | "source": [ 405 | "predicted = rnn.predict(Xtest)" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": null, 411 | "metadata": { 412 | "collapsed": false, 413 | "deletable": true, 414 | "editable": true 415 | }, 416 | "outputs": [], 417 | "source": [ 418 | "from sklearn.metrics import mean_squared_error" 419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": null, 424 | "metadata": { 425 | "collapsed": false, 426 | "deletable": true, 427 | "editable": true, 428 | "scrolled": true 429 | }, 430 | "outputs": [], 431 | "source": [ 432 | "mean_squared_error(ytest, predicted)" 433 | ] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "execution_count": null, 438 | "metadata": { 439 | "collapsed": false, 440 | "deletable": true, 441 | "editable": true 442 | }, 443 | "outputs": [], 444 | "source": [ 445 | "from bulbea.entity.share import _plot_bollinger_bands\n", 446 | "import pandas as pd\n", 447 | "import matplotlib.pyplot as pplt\n", 448 | "\n", 449 | "figsize = (20, 15)\n", 450 | "figure = pplt.figure(figsize = figsize)\n", 451 | "\n", 452 | "axes = figure.add_subplot(111)\n", 453 | "\n", 454 | "series = pd.Series(data = scaler.inverse_transform(ytest))\n", 455 | "\n", 456 | "# axes.plot(scaler.inverse_transform(ytest))\n", 457 | "axes.plot(scaler.inverse_transform(predicted))\n", 458 | "\n", 459 | "_plot_bollinger_bands(series, axes, bandwidth = 10)" 460 | ] 461 | }, 
462 | { 463 | "cell_type": "markdown", 464 | "metadata": { 465 | "deletable": true, 466 | "editable": true 467 | }, 468 | "source": [ 469 | "### Sentiment Analysis" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": null, 475 | "metadata": { 476 | "collapsed": false, 477 | "deletable": true, 478 | "editable": true 479 | }, 480 | "outputs": [], 481 | "source": [ 482 | "s = bb.sentiment(share)\n", 483 | "s" 484 | ] 485 | } 486 | ], 487 | "metadata": { 488 | "kernelspec": { 489 | "display_name": "Python 3", 490 | "language": "python", 491 | "name": "python3" 492 | }, 493 | "language_info": { 494 | "codemirror_mode": { 495 | "name": "ipython", 496 | "version": 3 497 | }, 498 | "file_extension": ".py", 499 | "mimetype": "text/x-python", 500 | "name": "python", 501 | "nbconvert_exporter": "python", 502 | "pygments_lexer": "ipython3", 503 | "version": "3.5.2" 504 | } 505 | }, 506 | "nbformat": 4, 507 | "nbformat_minor": 2 508 | } 509 | --------------------------------------------------------------------------------