├── ts_charting ├── lab │ ├── __init__.py │ ├── test │ │ └── test_json.py │ └── lab.py ├── test │ ├── __init__.py │ ├── test_styler.py │ ├── test_formatter.py │ └── test_figure.py ├── util.py ├── __init__.py ├── extras.py ├── ipython.py ├── span.py ├── plot_3d.py ├── monkey.py ├── json.py ├── heatmap.py ├── styles.py ├── boxplot.py ├── imagefile.py ├── ohlc.py ├── charting.py ├── formatter.py └── figure.py ├── README.md ├── setup.py └── LICENSE.txt /ts_charting/lab/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ts_charting/test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ts_charting/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def process_signal(series, source): 4 | """ 5 | Take any non 0/na value and changes it to corresponding value of source 6 | """ 7 | temp = series.astype(float).copy() 8 | temp[temp == 0] = np.nan 9 | temp, source = temp.align(source, join='left') 10 | temp *= source 11 | return temp 12 | -------------------------------------------------------------------------------- /ts_charting/__init__.py: -------------------------------------------------------------------------------- 1 | from ts_charting.figure import Figure, Grapher 2 | 3 | from ts_charting.charting import * 4 | import ts_charting.ohlc 5 | import ts_charting.boxplot 6 | import ts_charting.span 7 | from ts_charting.styles import styler, marker_styler, level_styler 8 | from ts_charting.ipython import figsize, IN_NOTEBOOK 9 | from ts_charting.plot_3d import plot_wireframe 10 | from ts_charting.imagefile import plot_pdf, save_images 11 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | Time Series Charting 2 | =================== 3 | 4 | Essentially moving `trtools.charting` out to its own repo. 5 | 6 | Example Notebooks 7 | ================= 8 | ## [ohlc (nbviewer)](http://nbviewer.ipython.org/urls/raw.github.com/dalejung/ts-charting/master/notebooks/ohlc.ipynb) 9 | 10 | **covers:** 11 | 12 | * `ohlc_plot` 13 | * Multiple subplot 14 | * Multiple y-axis per plot with splines 15 | * Aligning x-axis with sharex 16 | * Highlighting horizontal spans 17 | * Custom tick locator 18 | 19 | d3.js 20 | ====== 21 | 22 | Been working on a `d3.js` adapter: [id3](https://github.com/dalejung/id3) 23 | 24 | -------------------------------------------------------------------------------- /ts_charting/extras.py: -------------------------------------------------------------------------------- 1 | # TODO SeriesByGroupBy.boxplot 2 | """ 3 | import matplotlib.ticker as ticker 4 | 5 | labels = [] 6 | data = [] 7 | for label, group in grouped: 8 | labels.append(label) 9 | data.append(group) 10 | r = labels 11 | N = len(r) 12 | ind = np.arange(N) # the evenly spaced plot indices 13 | def format_date(x, pos=None): 14 | thisind = np.clip(int(x+0.5), 0, N-1) 15 | return r[thisind].strftime('%Y-%m-%d') 16 | 17 | fig = gcf() 18 | ax = gca() 19 | ax.xaxis.set_major_formatter(ticker.FuncFormatter(format_date)) 20 | _ = boxplot(data) 21 | fig.autofmt_xdate() 22 | """ 23 | -------------------------------------------------------------------------------- /ts_charting/lab/test/test_json.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from pandas import json 4 | from ts_charting.json import to_json 5 | 6 | import ts_charting.lab.lab as tslab 7 | 8 | plot_index = pd.date_range(start="2000-1-1", freq="B", periods=10000) 9 | df = pd.DataFrame(index=plot_index) 10 | df['open'] = np.random.randn(len(plot_index)) 11 | 
df['high'] = np.random.randn(len(plot_index)) 12 | df['low'] = np.random.randn(len(plot_index)) 13 | df['close'] = np.random.randn(len(plot_index)) 14 | 15 | lab = tslab.Lab() 16 | fig = lab.station('candle') 17 | df.tail(5).ohlc_plot() 18 | fig.plot_markers('high', df.high > df.high.shift(1), yvalues=df.open) 19 | 20 | jd = to_json(lab) 21 | obj = json.loads(jd) 22 | 23 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | DISTNAME='ts_charting' 4 | FULLVERSION='0.1' 5 | 6 | setup(name=DISTNAME, 7 | version=FULLVERSION, 8 | packages=['ts_charting', 9 | ] 10 | ) 11 | -------------------------------------------------------------------------------- /ts_charting/ipython.py: -------------------------------------------------------------------------------- 1 | """ 2 | Specific ipython stuff 3 | """ 4 | import IPython 5 | 6 | from ts_charting import reset_figure 7 | 8 | IN_NOTEBOOK = True 9 | instance = IPython.Application._instance 10 | 11 | # IPython.frontend was flattened so its submodule now live in the root 12 | # namespace. i.e. 
IPython.frontend.terminal -> IPython.terminal 13 | if hasattr(IPython, 'frontend'): 14 | terminal = IPython.frontend.terminal.ipapp.TerminalIPythonApp 15 | else: 16 | terminal = IPython.terminal.ipapp.TerminalIPythonApp 17 | 18 | if isinstance(instance, terminal): 19 | IN_NOTEBOOK = False 20 | 21 | def figsize(width, height): 22 | """ 23 | Resize figure 24 | """ 25 | IPython.core.pylabtools.figsize(width, height) 26 | 27 | # in notebook, reset the CURRENT_FIGURE for every cell execution 28 | # this allows us to have cell specific plots 29 | shell = IPython.InteractiveShell._instance 30 | if IN_NOTEBOOK and shell: 31 | shell.register_post_execute(reset_figure) 32 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2013 dalejung 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /ts_charting/span.py: -------------------------------------------------------------------------------- 1 | """ 2 | Span highlighting 3 | """ 4 | import ts_charting as charting 5 | 6 | def highlight_span(start=None, end=None, color='g', alpha=0.5, grapher=None): 7 | """ 8 | A quick shortcut way to highlight regions of a chart. Uses grapher.index 9 | to translate non int-position arguments to int locations. 10 | """ 11 | if start is None and end is None: 12 | raise Exception("start and end cannot both be None") 13 | 14 | if grapher is None: 15 | fig = charting.gcf() 16 | grapher = fig.grapher 17 | index = grapher.index 18 | 19 | if index is None: 20 | raise Exception("grapher/ax has no plots on it.
Can only highlight populated ax") 21 | 22 | if start is None: 23 | start = 0 24 | if end is None: 25 | end = len(index) 26 | 27 | # convert from object/string to pos-index 28 | if not isinstance(start, int): 29 | start = index.get_loc(start) 30 | if not isinstance(end, int): 31 | end = index.get_loc(end) 32 | 33 | grapher.ax.axvspan(start, end, color=color, alpha=alpha) 34 | 35 | def hl_span_figure(self, *args, **kwargs): 36 | grapher = self.grapher 37 | kwargs['grapher'] = grapher 38 | return highlight_span(*args, **kwargs) 39 | 40 | charting.Figure.hl_span = hl_span_figure 41 | -------------------------------------------------------------------------------- /ts_charting/plot_3d.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | from mpl_toolkits.mplot3d import Axes3D 5 | from pylab import plt 6 | 7 | def grab_first_unique(index): 8 | """ 9 | There are instances where you are subsetting your data and end up with a MultiIndex where 10 | one level is constant. This function grabs the first (and hopefully) unique level and returns it. 
11 | This is so you can plot a DataFrame that has the correct format but might have an extraneous index level 12 | """ 13 | if isinstance(index, pd.MultiIndex): 14 | for i in range(index.nlevels): 15 | ind = index.get_level_values(i) 16 | if ind.is_unique: 17 | return ind 18 | return index 19 | 20 | def _3d_values(df): 21 | # grab the first unique level of the index and of the columns 22 | index = grab_first_unique(df.index) 23 | columns = grab_first_unique(df.columns) 24 | 25 | X, Y = np.meshgrid(index, columns) 26 | Z = df.values.T # meshgrid grids are (n_columns, n_index); transpose so Z[i, j] is the value at (X[i, j], Y[i, j]) 27 | return {'values': (X, Y, Z), 'labels': (index.name, columns.name)} 28 | 29 | def plot_wireframe(df, ax=None, *args, **kwargs): 30 | if ax is None: 31 | fig = plt.figure() 32 | ax = Axes3D(fig) 33 | 34 | res = _3d_values(df) 35 | X, Y, Z = res['values'] 36 | x_name, y_name = res['labels'] 37 | 38 | ax.plot_wireframe(X, Y, Z, *args, **kwargs) 39 | ax.set_xlabel(x_name) 40 | ax.set_ylabel(y_name) 41 | return ax 42 | -------------------------------------------------------------------------------- /ts_charting/monkey.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tools to shard logic into separate files like ohlc.py. 3 | """ 4 | def merge(base, mixin, overrides=None): 5 | """ 6 | Merge attributes from mixin class to base 7 | 8 | overrides : list of strings 9 | List of names that will be transferred regardless of previous checks. 10 | for the time being, this is only for double underscore names 11 | """ 12 | if overrides is None: 13 | overrides = [] 14 | 15 | 16 | for name, meth in list(mixin.__dict__.items()): 17 | if name.startswith('__') and name not in overrides: 18 | continue 19 | 20 | if hasattr(base, name): 21 | # note that the base._mixins_ check should prevent us from 22 | # running the same mixin 23 | raise Exception("We should never replace an existing method.
{0}".format(name)) 24 | setattr(base, name, meth) 25 | 26 | def mixin(base, overrides=None): 27 | """ 28 | Create mixin decorator for specific base class; the decorator returns the mixin class itself 29 | """ 30 | def _mixin(mixin): 31 | mixin_name = mixin.__name__ 32 | _mixins_ = getattr(base, '_mixins_', []) 33 | if mixin_name in _mixins_: 34 | print(('{mixin_name} already mixed'.format(mixin_name=mixin_name))) 35 | return mixin 36 | _mixins_.append(mixin_name) 37 | setattr(base, '_mixins_', _mixins_) 38 | 39 | merge(base, mixin, overrides=overrides) 40 | return mixin # a decorator must hand the class back, otherwise the decorated name is rebound to None 41 | return _mixin 42 | -------------------------------------------------------------------------------- /ts_charting/json.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pandas import json 3 | from IPython.display import JSON 4 | 5 | def dataframe_json(df): 6 | data = {} 7 | for k, v in list(df.items()): 8 | data[k] = v.values 9 | data['index'] = df.index 10 | data['_pandas_type'] = 'dataframe' 11 | data['__repr__'] = repr(df) 12 | return json.dumps(data) 13 | 14 | def series_json(series): 15 | data = {} 16 | data['data'] = series.values 17 | data['index'] = series.index 18 | data['name'] = series.name 19 | data['_pandas_type'] = 'series' 20 | data['__repr__'] = repr(series) 21 | return json.dumps(data) 22 | 23 | def to_json(obj): 24 | if isinstance(obj, pd.DataFrame): 25 | return dataframe_json(obj) 26 | 27 | if isinstance(obj, pd.Series): 28 | return series_json(obj) 29 | 30 | if isinstance(obj, list): 31 | jlist = [] 32 | for v in obj: 33 | jlist.append(to_json(v)) 34 | return json_list(jlist) 35 | 36 | if isinstance(obj, dict): 37 | jdict = {} 38 | for k, v in list(obj.items()): 39 | jdict[k] = to_json(v) 40 | return json_dict(jdict) 41 | 42 | if hasattr(obj, 'to_json'): 43 | return obj.to_json() 44 | 45 | return json.dumps(obj) 46 | 47 | def json_dict(dct): 48 | items = [] 49 | for k, v in list(dct.items()): 50 | items.append("{k}:{v}".format(k=json.dumps(str(k)), v=v)) # json.dumps quotes and escapes the key; values are already JSON text 51 | return "{" + 
','.join(items) + "}" 52 | 53 | def json_list(lst): 54 | return "[" + ','.join(lst) + "]" 55 | 56 | def to_json_display(obj): 57 | return JSON(to_json(obj)); 58 | -------------------------------------------------------------------------------- /ts_charting/test/test_styler.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | import pandas as pd 3 | import numpy as np 4 | 5 | import ts_charting.styles as styles 6 | level_styler = styles.level_styler 7 | 8 | 9 | class TestStyler(TestCase): 10 | # testing base pandas 11 | 12 | def __init__(self, *args, **kwargs): 13 | TestCase.__init__(self, *args, **kwargs) 14 | 15 | def runTest(self): 16 | pass 17 | 18 | def setUp(self): 19 | pass 20 | 21 | def test_level_styler(self): 22 | """ 23 | Test level styler. 24 | """ 25 | df = pd.DataFrame({'value':np.random.randn(100)}) 26 | df['num_cat'] = np.random.choice(list(range(5)), 100) * 100 27 | df['name'] = np.random.choice(['dale', 'bob', 'wes', 'frank'], 100) 28 | 29 | styles = level_styler(color=df.num_cat, linestyle=df.name) 30 | style_df = pd.DataFrame(styles) 31 | 32 | for name in df.name.unique(): 33 | inds = df.name == name 34 | # all values with the same name shoudl have same linestyle 35 | assert len(np.unique(style_df.ix[inds].linestyle)) == 1 36 | 37 | for i in df.num_cat.unique(): 38 | inds = df.num_cat == i 39 | # all values with the same name shoudl have same color 40 | assert len(np.unique(style_df.ix[inds].color)) == 1 41 | 42 | if __name__ == '__main__': 43 | import nose 44 | nose.runmodule(argv=[__file__,'-vvs','-x','--pdb', '--pdb-failure'],exit=False) 45 | -------------------------------------------------------------------------------- /ts_charting/heatmap.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import pandas as pd 5 | from pylab import * 6 | 7 | 
def _gen_labels(labels, names=None): 8 | if names is None: 9 | names = labels.names 10 | if np.isscalar(labels[0]): 11 | labels = [(l,) for l in labels] 12 | zips = [list(zip(names, l)) for l in labels] 13 | new_labels = [', '.join(['{1}'.format(*m) for m in z]) for z in zips] 14 | return new_labels, names 15 | 16 | def heatmap(data, xlabels=None, ylabels=None, title=None): 17 | fig, ax = plt.subplots() 18 | 19 | values = data.values 20 | cmap = plt.cm.RdYlGn 21 | # plot np.nan as white 22 | cmap.set_bad('w',1.) 23 | masked_array = np.ma.array(values, mask=np.isnan(values)) 24 | 25 | heatmap = ax.pcolormesh(masked_array, cmap=cmap) 26 | plt.colorbar(heatmap) 27 | 28 | 29 | xaxis = data.columns 30 | yaxis = data.index 31 | xlabels, xnames = _gen_labels(xaxis) 32 | ylabels, ynames = _gen_labels(yaxis) 33 | ax.set_xticklabels(xlabels, minor=False) 34 | ax.set_yticklabels(ylabels, minor=False) 35 | 36 | ax.set_xlabel(xnames) 37 | ax.set_ylabel(ynames) 38 | 39 | # trying to be smart about creating ticks. 
40 | # previously this was to cut down on having like 1000's of labels 41 | if isinstance(yaxis, pd.MultiIndex): 42 | yaxis_labels = yaxis.labels 43 | else: 44 | yaxis_labels = [yaxis] 45 | 46 | for i in range(len(yaxis_labels)): 47 | labels, ind = np.unique(yaxis_labels[i], return_index=True) 48 | yticks = ind + 0.5 49 | if len(ind) > 1: 50 | break 51 | 52 | ax.set_xticks(np.arange(len(xaxis))+0.5, minor=False) 53 | ax.set_yticks(yticks, minor=False) 54 | plt.xticks(rotation=90) 55 | ax.set_xlim(0, len(xaxis)) 56 | ax.set_ylim(0, len(yaxis)) 57 | if title: 58 | ax.set_title(title) 59 | return ax 60 | -------------------------------------------------------------------------------- /ts_charting/styles.py: -------------------------------------------------------------------------------- 1 | try: 2 | from itertools import izip 3 | except ImportError: 4 | izip = zip 5 | 6 | import itertools 7 | from collections import OrderedDict 8 | from pandas.core.algorithms import factorize 9 | import numpy as np 10 | 11 | LINESTYLES = ('-', '--', ':') 12 | COLORS = ('b', 'g', 'r', 'c', 'm', 'y', 'k') 13 | MARKERS = (None,'o', 's', 'v', '*', '^', 'x') 14 | 15 | def styler(): 16 | """ 17 | Default styler that cycles colors than line-styles 18 | """ 19 | styles = itertools.product(LINESTYLES, COLORS) 20 | 21 | # cycle through 22 | while True: 23 | yield dict(list(zip(('linestyle', 'color'), next(styles)))) 24 | 25 | def marker_styler(): 26 | """ 27 | Adds differing markers 28 | """ 29 | styles = itertools.product(LINESTYLES, MARKERS, COLORS) 30 | 31 | # cycle through 32 | while True: 33 | yield dict(list(zip(('linestyle', 'marker', 'color'), next(styles)))) 34 | 35 | class StyleCategory(object): 36 | def __init__(self, name, values): 37 | self.name = name 38 | self.values = values 39 | 40 | STYLES = {} 41 | STYLES['linestyle'] = LINESTYLES 42 | STYLES['color'] = COLORS 43 | STYLES['marker'] = MARKERS 44 | 45 | def level_styler(linestyle=None, color=None, marker=None): 46 | """ 47 | 
This function is useful for categorical plotting. Based on certain categories, 48 | it will return styles that are persistant. 49 | 50 | This is when you want to distinguish groups of line plots by their style 51 | """ 52 | vars = locals().copy() 53 | 54 | styles = OrderedDict() 55 | 56 | for k, SC in list(STYLES.items()): 57 | vals = vars.get(k, None) 58 | if vals is None: 59 | continue 60 | labels, uniques = factorize(vals) 61 | labels = labels % len(SC) # cycle back to start 62 | style_values = np.take(SC, labels) 63 | styles[k] = style_values 64 | 65 | keys = list(styles.keys()) 66 | return [dict(list(zip(keys, st))) for st in izip(*list(styles.values()))] 67 | -------------------------------------------------------------------------------- /ts_charting/boxplot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pandas.core.series import remove_na 3 | 4 | from ts_charting import Figure, Grapher 5 | from ts_charting.monkey import mixin 6 | 7 | @mixin(Figure) 8 | class BoxPlotFigure(object): 9 | 10 | def boxplot(self, df, axis=0, *args, **kwargs): 11 | self.figure.autofmt_xdate() 12 | self.grapher.boxplot(df, axis=axis, *args, **kwargs) 13 | 14 | @mixin(Grapher) 15 | class BoxPlotGrapher(object): 16 | def boxplot(self, df, axis=0, secondary_y=False, *args, **kwargs): 17 | """ 18 | Currently supports plotting DataFrames. 19 | 20 | Downside is that this only works for data that has equal columns. 21 | For something like plotting groups with varying sizes, you'd 22 | need to use boxplot(list()). 
Example is creating a SeriesGroupBy.boxplot 23 | """ 24 | if axis == 1: 25 | df = df.T 26 | index = df.columns 27 | self.set_index(index) 28 | clean_values = [remove_na(x) for x in df.values.T] 29 | 30 | ax = self.find_ax(secondary_y, kwargs) 31 | 32 | # positions need to start at 0 to align with TimestampLocator 33 | ax.boxplot(clean_values, positions=np.arange(len(index))) 34 | self.setup_datetime(index) 35 | self.set_formatter() 36 | 37 | def boxplot_list(self, data, secondary_y=False, *args, **kwargs): 38 | pass 39 | 40 | # TODO SeriesByGroupBy.boxplot 41 | """ 42 | import matplotlib.ticker as ticker 43 | 44 | labels = [] 45 | data = [] 46 | for label, group in grouped: 47 | labels.append(label) 48 | data.append(group) 49 | r = labels 50 | N = len(r) 51 | ind = np.arange(N) # the evenly spaced plot indices 52 | def format_date(x, pos=None): 53 | thisind = np.clip(int(x+0.5), 0, N-1) 54 | return r[thisind].strftime('%Y-%m-%d') 55 | 56 | fig = gcf() 57 | ax = gca() 58 | ax.xaxis.set_major_formatter(ticker.FuncFormatter(format_date)) 59 | _ = boxplot(data) 60 | fig.autofmt_xdate() 61 | """ 62 | -------------------------------------------------------------------------------- /ts_charting/imagefile.py: -------------------------------------------------------------------------------- 1 | """ 2 | Idea is to turn on writing to images for all plots 3 | """ 4 | import os 5 | import errno 6 | import tempfile 7 | 8 | import pandas as pd 9 | import IPython 10 | import IPython.core.pylabtools as pylabtools 11 | import matplotlib.pylab as pylab 12 | import matplotlib.pyplot as plt 13 | from matplotlib.backends.backend_pdf import PdfPages 14 | 15 | import ts_charting as charting 16 | 17 | def save_to_pdf(file, figs=None): 18 | with PdfPages(file) as pdf: 19 | 20 | if figs is None: 21 | figs = pylabtools.getfigs() 22 | 23 | for fig in figs: 24 | fig.savefig(pdf, format='pdf') 25 | 26 | close_figures() 27 | 28 | def plot_pdf(fn=None, open=True): 29 | if fn is None: 30 | file = 
tempfile.NamedTemporaryFile(suffix='.pdf', delete=False) 31 | fn = file.name 32 | 33 | dir, file = os.path.split(fn) 34 | if dir: 35 | mkdir_p(dir) 36 | 37 | save_to_pdf(fn) 38 | if open: 39 | os.system('open '+fn) 40 | return fn 41 | 42 | def _get_title(fig): 43 | """ 44 | grab title from figure. Assume it's a one ax per figure or that 45 | the main ax is the first one 46 | """ 47 | ax = fig.get_axes()[0] # assume first ax is correct 48 | title = ax.title.get_text() 49 | return title 50 | 51 | def mkdir_p(path): 52 | """ 53 | http://stackoverflow.com/questions/600268/mkdir-p-functionality-in-python 54 | """ 55 | try: 56 | os.makedirs(path) 57 | except OSError as exc: # Python >2.5 58 | if exc.errno == errno.EEXIST and os.path.isdir(path): 59 | pass 60 | else: raise 61 | 62 | def save_images(dir='', figs=None, prefix=None): 63 | """ 64 | Save all open figures to image files. 65 | 66 | Parameters: 67 | ---------- 68 | dir : string 69 | Directory to place image files into 70 | figs : list of Figures 71 | will default to open figures 72 | prefix : string 73 | prefix all image file names 74 | """ 75 | if figs is None: 76 | figs = pylabtools.getfigs() 77 | 78 | if dir: 79 | mkdir_p(dir) 80 | 81 | for i, fig in enumerate(figs, 1): 82 | label = _get_title(fig) 83 | if label == '': 84 | label = "Figure_%d" % i 85 | if prefix: 86 | label = prefix + '_' + label 87 | filepath = os.path.join(dir, label+'.png') 88 | fig.savefig(filepath) 89 | 90 | close_figures() 91 | 92 | def close_figures(): 93 | plt.close('all') 94 | charting.gcf(reset=True) 95 | 96 | 97 | # start of doing something where the execution stuff runs automatically? 
98 | def imagefile_reroute(func): 99 | def wrapped(*args, **kwargs): 100 | return func(*args, **kwargs) 101 | return wrapped 102 | 103 | shell = IPython.InteractiveShell._instance 104 | shell = None 105 | 106 | # check so we don't break non ipython runs 107 | if shell: 108 | execution_magic = shell.magics_manager.registry['ExecutionMagics'] 109 | execution_magic.default_runner = imagefile_reroute(execution_magic.default_runner) 110 | -------------------------------------------------------------------------------- /ts_charting/ohlc.py: -------------------------------------------------------------------------------- 1 | try: 2 | from itertools import izip 3 | except ImportError: 4 | izip = zip 5 | 6 | import pandas as pd 7 | import numpy as np 8 | 9 | from matplotlib.finance import candlestick_ochl as candlestick 10 | 11 | from ts_charting import Figure, Grapher, gcf 12 | from ts_charting.monkey import mixin 13 | 14 | def _match_col(col, columns): 15 | """ 16 | Match column name by the following process: 17 | 1. col == 'open' 18 | 2. col == 'Open' 19 | 3. 'open' in col.lower() 20 | """ 21 | for test in columns: 22 | # merged 1 and 2, assuming one won't have 'open' and 'Open' 23 | if col == test.lower(): 24 | return test 25 | if col in test.lower(): 26 | return test 27 | 28 | def normalize_ohlc(df): 29 | """ 30 | Return an OHLC where the column names are single word lower cased 31 | 32 | This is support dataframes like ones from quantmod whihc have the 33 | symbol embedded in the column name. i.e. 
SPY.Close 34 | """ 35 | cols = ['open', 'high', 'low', 'close'] 36 | matched = [] 37 | for col in cols: 38 | match = _match_col(col, df.columns) 39 | if match: 40 | matched.append(match) 41 | continue 42 | 43 | raise Exception("{col} not found".format(col=col)) 44 | res = df.ix[:, matched] 45 | res.columns = cols 46 | return res 47 | 48 | @mixin(Figure) 49 | class OHLCFigure(object): 50 | def candlestick(self, *args, **kwargs): 51 | if self.ax is None: 52 | print('NO AX set') 53 | return 54 | self.figure.autofmt_xdate() 55 | self.grapher.candlestick(*args, **kwargs) 56 | 57 | def ohlc(self, *args, **kwargs): 58 | if self.ax is None: 59 | print('NO AX set') 60 | return 61 | self.figure.autofmt_xdate() 62 | self.grapher.ohlc(*args, **kwargs) 63 | 64 | @mixin(Grapher) 65 | class OHLCGrapher(object): 66 | 67 | def candlestick(self, index, open, high, low, close, width=0.3, secondary_y=False, 68 | *args, **kwargs): 69 | """ 70 | Plots a candlestick from an index plus open/high/low/close series. 71 | Column matching for DataFrames is done by ohlc() via normalize_ohlc. 72 | """ 73 | data = {} 74 | data['open'] = open 75 | data['high'] = high 76 | data['low'] = low 77 | data['close'] = close 78 | df = pd.DataFrame(data, index=index) 79 | 80 | if self.index is None: 81 | self.index = index 82 | 83 | # grab merged data 84 | xax = np.arange(len(self.index)) 85 | quotes = izip(xax, df['open'], df['close'], df['high'], df['low']) 86 | 87 | ax = self.find_ax(secondary_y, kwargs) 88 | 89 | self.setup_datetime(index) 90 | candlestick(ax, quotes, width=width, colorup='g') 91 | 92 | def ohlc(self, df, width=0.3, *args, **kwargs): 93 | ohlc_df = normalize_ohlc(df) 94 | self.candlestick(df.index, ohlc_df.open, ohlc_df.high, ohlc_df.low, ohlc_df.close, width, *args, **kwargs) # pass width through so the caller's candle width is used 95 | 96 | def ohlc_plot(self, width=0.3, *args, **kwargs): 97 | fig = gcf() 98 | fig.ohlc(self, width=width, *args, **kwargs) 99 | return fig 100 | 101 | pd.DataFrame.ohlc_plot = ohlc_plot 102 | 103 | --------------------------------------------------------------------------------
/ts_charting/lab/lab.py: -------------------------------------------------------------------------------- 1 | """ 2 | Conceptually a Lab is a JSON-able working area. 3 | 4 | A station, at least for now, maps to a Figure. 5 | 6 | I'm not 100% sure these abstractions are the way to go, or 7 | if this even belongs in this module. 8 | """ 9 | from collections import OrderedDict 10 | from ts_charting import json 11 | 12 | from ts_charting import Figure, scf 13 | from ts_charting.util import process_signal 14 | 15 | class FakeFigure(object): 16 | """ 17 | Figure that does no plotting. 18 | """ 19 | def __getattr__(self, name): 20 | return self.fake_call 21 | 22 | def fake_call(self, *args, **kwargs): 23 | pass 24 | 25 | class Lab(object): 26 | def __init__(self, draw=False): 27 | self.draw = draw 28 | self.data = {} 29 | self.plots = {} 30 | self.stations = OrderedDict() 31 | 32 | def station(self, name): 33 | station = Station(self, name, draw=self.draw) 34 | self.stations[name] = station 35 | scf(station) 36 | return station 37 | 38 | def to_json(self): 39 | dct = {} 40 | dct['stations'] = self.stations 41 | return json.to_json(dct) 42 | 43 | class Station(object): 44 | def __init__(self, lab, name, draw=False): 45 | self.lab = lab 46 | self.name = name 47 | if draw: 48 | self.figure = Figure(1, warn=False) 49 | else: 50 | self.figure = FakeFigure() 51 | self.layers = [] 52 | 53 | def plot_markers(self, name, series, yvalues=None, xindex=None, **kwargs): 54 | geom = {'type': 'marker'} 55 | # dont support xindex for now 56 | #geom['xindex'] = xindex 57 | geom.update(kwargs) 58 | 59 | if yvalues is not None: 60 | series = process_signal(series, yvalues) 61 | self.add_layer(name, series, geom) 62 | self.figure.plot_markers(name, series, **kwargs) 63 | 64 | def plot(self, label, series, **kwargs): 65 | geom = {'type': 'line'} 66 | geom.update(kwargs) 67 | self.add_layer(label, series, geom) 68 | self.figure.plot(label, series, **kwargs) 69 | 70 | def ohlc(self, df, 
width=0.3): 71 | # ohlc_df = normalize_ohlc(df) 72 | self.add_layer('candlestick', df, {'type': 'candlestick', 'width': width}) 73 | self.figure.ohlc(df, width=width) 74 | 75 | def add_layer(self, name, data, geoms): 76 | if not isinstance(geoms, list): 77 | geoms = [geoms] 78 | 79 | self.layers.append({'name': name, 'data': data, 'geoms':geoms}) 80 | 81 | def to_json(self): 82 | dct = self.__dict__.copy() 83 | del dct['figure'] 84 | del dct['lab'] 85 | index = self.consolidate_index() 86 | dct['index'] = index 87 | return json.to_json(dct) 88 | 89 | def consolidate_index(self): 90 | """ 91 | Take the index of every layer's data and create a single index. Currently 92 | acts like the regular Grapher and uses the first data's index as the master 93 | index. All other layers' data get reindexed to that one (in place) 94 | """ 95 | index = None 96 | for layer in self.layers: 97 | if index is None: 98 | index = layer['data'].index 99 | continue 100 | # reindex the layer data onto the master index; geoms are left untouched 101 | layer['data'] = layer['data'].reindex(index) 102 | return index 103 | -------------------------------------------------------------------------------- /ts_charting/charting.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pandas.util.decorators import Appender 3 | 4 | from ts_charting import Figure 5 | import ts_charting.styles as cstyler 6 | 7 | CURRENT_FIGURE = None 8 | 9 | def reset_figure(*args): 10 | """ 11 | In ipython notebook, clear the figure after each cell execute.
12 | This negates the need to specify a Figure for each plot 13 | """ 14 | global CURRENT_FIGURE 15 | CURRENT_FIGURE = None 16 | 17 | def gcf(reset=False): 18 | global CURRENT_FIGURE 19 | if reset: 20 | CURRENT_FIGURE = None 21 | return CURRENT_FIGURE 22 | 23 | if CURRENT_FIGURE is None: 24 | CURRENT_FIGURE = figure(1) 25 | return CURRENT_FIGURE 26 | 27 | def scf(figure): 28 | global CURRENT_FIGURE 29 | CURRENT_FIGURE = figure 30 | 31 | _fplot_doc = """ 32 | Keyword Parameters 33 | ---------- 34 | secondary_y : bool 35 | Plot on a secondary y-axis 36 | yax : string: 37 | named y-axis plot. 38 | """ 39 | # Monkey Patches, no good reason for this to be here... 40 | @Appender(_fplot_doc) 41 | def series_plot(self, label=None, *args, **kwargs): 42 | label = plot_label(self, label, **kwargs) 43 | 44 | fig = gcf() 45 | fig.plot(str(label), self, *args, **kwargs) 46 | 47 | pd.Series.fplot = series_plot 48 | pd.TimeSeries.fplot = series_plot 49 | 50 | def df_plot(self, *args, **kwargs): 51 | force_plot = kwargs.pop('force_plot', False) 52 | styler = kwargs.pop('styler', cstyler.marker_styler()) 53 | 54 | if len(self.columns) > 20 and not force_plot: 55 | raise Exception("Are you crazy? 
Too many columns") 56 | 57 | # pass styler to each series plot 58 | kwargs['styler'] = styler 59 | for col in self.columns: 60 | series = self[col] 61 | series.fplot(*args, **kwargs) 62 | 63 | pd.DataFrame.fplot = df_plot 64 | 65 | def series_plot_markers(self, label=None, yvalues=None, *args, **kwargs): 66 | """ 67 | Really just an automated way of calling gcf 68 | """ 69 | fig = gcf() 70 | label = plot_label(self, label, **kwargs) 71 | fig.plot_markers(str(label), self, yvalues=yvalues, *args, **kwargs) 72 | 73 | pd.Series.fplot_markers = series_plot_markers 74 | 75 | def figure(*args, **kwargs): 76 | """ create Figure and set as current """ 77 | kwargs['warn'] = False 78 | fig = Figure(*args, **kwargs) 79 | scf(fig) 80 | return fig 81 | 82 | def plot_label(self, label=None, **kwargs): 83 | """ 84 | Logic to grab plot label 85 | 86 | Note that this both takes label as a positional argument and a keyword. 87 | 88 | This is a legacy issue where instead of grabbing label from kwargs, 89 | which is how matplotlib handles it, I decided to make it a positional 90 | argument, under the assumption that you would almost always need a label. 91 | 92 | While this is true, it makes it so I have to check check both types of 93 | args. 
94 | """ 95 | label = label or kwargs.get('label') 96 | if label is None: # allow series to define non `.name` label 97 | label = getattr(self, 'plot_label', None) 98 | label = label or self.name 99 | 100 | prefix = kwargs.pop('prefix', None) 101 | if prefix: 102 | label = prefix +' '+label 103 | 104 | return label 105 | 106 | # try to monkey patch pandas_composition 107 | # we do this to get access to the subclass self 108 | # get around: https://github.com/dalejung/pandas-composition/issues/19 109 | try: 110 | import pandas_composition as pc 111 | pc.UserFrame.fplot = df_plot 112 | pc.UserSeries.fplot = series_plot 113 | pc.UserSeries.fplot_markers = series_plot_markers 114 | except ImportError: 115 | pass 116 | -------------------------------------------------------------------------------- /ts_charting/formatter.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import matplotlib.ticker as ticker 6 | from pandas import datetools, DatetimeIndex 7 | from pandas.tseries.resample import _get_range_edges 8 | import pandas.lib as lib 9 | 10 | class TimestampLocator(ticker.Locator): 11 | """ 12 | Place a tick on every multiple of some base number of points 13 | plotted, eg on every 5th point. It is assumed that you are doing 14 | index plotting; ie the axis is 0, len(data). This is mainly 15 | useful for x ticks. 
16 | """ 17 | def __init__(self, index, freq=None, xticks=None, min_ticks=5): 18 | """ 19 | place ticks on the i-th data points where (i-offset)%base==0 20 | 21 | Parameters 22 | ---------- 23 | index : DatetimeIndex 24 | freq : pd.Offset (optional) 25 | Fixed frequency 26 | xticks : pd.DateTimeIndex or bool array 27 | Either an Index of datetimes representing ticks or a boolean 28 | array where True denotes a tick 29 | min_ticks : int 30 | Minimum number of ticks before jumping up to a lower frequency 31 | 32 | """ 33 | self.index = index 34 | self.min_ticks = min_ticks 35 | self.freq = freq 36 | self.set_xticks(xticks) 37 | 38 | self.gen_freq = None 39 | 40 | _xticks = None 41 | 42 | @property 43 | def xticks(self): 44 | return self._xticks 45 | 46 | def set_xticks(self, value): 47 | xticks = self._init_xticks(value) 48 | self._xticks = xticks 49 | 50 | def _init_xticks(self, xticks): 51 | if xticks is None: 52 | return xticks 53 | 54 | if isinstance(xticks, (list, tuple)): 55 | xticks = pd.DatetimeIndex(xticks) 56 | 57 | if isinstance(xticks, pd.DatetimeIndex): 58 | xticks = pd.Series(1, index=xticks).reindex(self.index, fill_value=0) 59 | return xticks.astype(bool) 60 | 61 | if xticks.dtype == bool: 62 | return xticks 63 | 64 | raise Exception("xticks must be DatetimeIndex or bool Series") 65 | 66 | def __call__(self): 67 | 'Return the locations of the ticks' 68 | vmin, vmax = self.axis.get_view_interval() 69 | xticks = self._process(vmin, vmax) 70 | return self.raise_if_exceeds(xticks) 71 | 72 | def _process(self, vmin, vmax): 73 | vmin = int(math.ceil(vmin)) 74 | vmax = int(math.floor(vmax)) or len(self.index) - 1 75 | vmax = min(vmax, len(self.index) -1) 76 | 77 | if self.xticks is None: 78 | xticks = self._xticks_from_freq(vmin, vmax) 79 | else: 80 | if self.xticks.dtype != bool: 81 | raise Exception("xticks must be a bool series") 82 | sub_xticks = self.xticks[vmin:vmax] 83 | xticks = np.where(sub_xticks)[0] 84 | return xticks 85 | 86 | def 
_xticks_from_freq(self, vmin, vmax): 87 | dmin = self.index[vmin] 88 | dmax = self.index[vmax] 89 | 90 | freq = self.freq 91 | if freq is None: 92 | freq = self.infer_scale(dmin, dmax) 93 | 94 | self.gen_freq = freq 95 | 96 | sub_index = self.index[vmin:vmax] 97 | 98 | xticks = self.generate_xticks(sub_index, freq) 99 | return xticks 100 | 101 | def infer_scale(self, dmin, dmax): 102 | delta = datetools.relativedelta(dmax, dmin) 103 | 104 | numYears = (delta.years * 1.0) 105 | numMonths = (numYears * 12.0) + delta.months 106 | numDays = (numMonths * 31.0) + delta.days 107 | numWeeks = numDays // 7 108 | numHours = (numDays * 24.0) + delta.hours 109 | numMinutes = (numHours * 60.0) + delta.minutes 110 | nums = [('AS', numYears), ('MS', numMonths), ('W', numWeeks), ('D', numDays), ('H', numHours), 111 | ('15min', numMinutes)] 112 | freq = None 113 | for key, num in nums: 114 | if num > self.min_ticks: 115 | freq = key 116 | break 117 | 118 | return freq 119 | 120 | def generate_xticks(self, index, freq): 121 | """ 122 | This is a lot like binning except we done have an extra label 123 | containing an unclosed bin. 124 | """ 125 | tg = pd.TimeGrouper(freq) 126 | binlabels = pd.Series(1, index=index).groupby(tg).grouper.binlabels 127 | # bound between start/end of index. 
128 | if binlabels[0] < index[0]: 129 | binlabels = binlabels[1:] 130 | if binlabels[-1] > index[-1]: 131 | binlabels = binlabels[:-1] 132 | 133 | if tg.closed == 'left': 134 | method = 'bfill' 135 | else: 136 | method = 'ffill' 137 | ticks = index.get_indexer(binlabels, method) 138 | # -1 is a sentinel for out of index range 139 | ticks = ticks[ticks != -1] 140 | return ticks 141 | 142 | class TimestampFormatter(object): 143 | def __init__(self, index, locator): 144 | self.index = index 145 | self.locator = locator 146 | 147 | def format_date(self, x, pos=None): 148 | thisind = np.clip(int(x+0.5), 0, len(self.index)-1) 149 | date = self.index[thisind] 150 | gen_freq = self.locator.gen_freq 151 | if gen_freq == 'T': 152 | return date.strftime('%H:%M %m/%d/%y') 153 | if gen_freq == 'H': 154 | return date.strftime('%H:%M %m/%d/%y') 155 | if gen_freq in ['D', 'W']: 156 | return date.strftime('%m/%d/%Y') 157 | if gen_freq in ['M', 'MS']: 158 | return date.strftime('%m/%d/%Y') 159 | return date.strftime('%m/%d/%Y %H:%M') 160 | 161 | @property 162 | def ticker_func(self): 163 | return ticker.FuncFormatter(self.format_date) 164 | -------------------------------------------------------------------------------- /ts_charting/test/test_formatter.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pandas.util.testing as tm 6 | 7 | import ts_charting.formatter as formatter 8 | 9 | plot_index = pd.date_range(start="2000-1-1", freq="B", periods=10000) 10 | class TestTimestampLocator(TestCase): 11 | 12 | def __init__(self, *args, **kwargs): 13 | TestCase.__init__(self, *args, **kwargs) 14 | 15 | def runTest(self): 16 | pass 17 | 18 | def setUp(self): 19 | pass 20 | 21 | def test_inferred_freq(self): 22 | """ 23 | inferred freqs are based off of min_ticks 24 | """ 25 | plot_index = pd.date_range(start="2000-1-1", freq="B", periods=10000) 26 | tl = 
formatter.TimestampLocator(plot_index) 27 | # showing only the first 10 should give us days 28 | xticks = tl._process(1, 10) 29 | assert tl.gen_freq == 'D' 30 | 31 | # showing only the first 70 should give us weeks 32 | xticks = tl._process(1, 6 * 7 + 1) 33 | assert tl.gen_freq == 'W' 34 | 35 | # months should trigger at around 6 * 31 36 | xticks = tl._process(1, 6 * 31 ) 37 | assert tl.gen_freq == 'MS' 38 | 39 | # year should trigger at around 6 *366 40 | xticks = tl._process(1, 6 * 366 + 1) 41 | assert tl.gen_freq == 'AS' 42 | 43 | def test_fixed_freq(self): 44 | """ 45 | Test passing in a fixed freq. This will allow len(xticks) 46 | less than min_ticks 47 | """ 48 | plot_index = pd.date_range(start="2000-1-1", freq="D", periods=10000) 49 | tl = formatter.TimestampLocator(plot_index, 'MS') 50 | xticks = tl._process(0, 30*3) 51 | assert len(xticks) == 3 52 | 53 | tl = formatter.TimestampLocator(plot_index, 'MS') 54 | xticks = tl._process(0, 30*6) 55 | assert len(xticks) == 6 56 | 57 | tl = formatter.TimestampLocator(plot_index, 'W') 58 | xticks = tl._process(0, 10*7) 59 | assert len(xticks) == 10 60 | 61 | tl = formatter.TimestampLocator(plot_index, 'AS') 62 | xticks = tl._process(0, 10 * 365) 63 | assert len(xticks) == 10 64 | 65 | def test_bool_xticks(self): 66 | """ 67 | ability to set ticks with a bool series where True == tick 68 | """ 69 | plot_index = pd.date_range(start="2000-1-1", freq="D", periods=10000) 70 | freq = 'M' 71 | ds = pd.Series(1, index=plot_index) 72 | # True when freq market is hit 73 | bool_ticks = ds.resample(freq).reindex(plot_index).fillna(0).astype(bool) 74 | tl = formatter.TimestampLocator(plot_index, xticks=bool_ticks) 75 | xticks = tl._process(0, 90) 76 | tl = formatter.TimestampLocator(plot_index, freq=freq) 77 | correct = tl._process(0, 90) 78 | tm.assert_almost_equal(xticks, correct) 79 | 80 | freq = 'MS' 81 | ds = pd.Series(1, index=plot_index) 82 | # True when freq market is hit 83 | bool_ticks = 
ds.resample(freq).reindex(plot_index).fillna(0).astype(bool) 84 | tl = formatter.TimestampLocator(plot_index, xticks=bool_ticks) 85 | xticks = tl._process(3, 94) 86 | tl = formatter.TimestampLocator(plot_index, freq=freq) 87 | correct = tl._process(3, 94) 88 | tm.assert_almost_equal(xticks, correct) 89 | 90 | freq = 'W' 91 | ds = pd.Series(1, index=plot_index) 92 | # True when freq market is hit 93 | bool_ticks = ds.resample(freq).reindex(plot_index).fillna(0).astype(bool) 94 | tl = formatter.TimestampLocator(plot_index, xticks=bool_ticks) 95 | xticks = tl._process(3, 94) 96 | tl = formatter.TimestampLocator(plot_index, freq=freq) 97 | correct = tl._process(3, 94) 98 | tm.assert_almost_equal(xticks, correct) 99 | 100 | def test_list_of_datetimes(self): 101 | """ 102 | The other xticks option is sending in a DatetimeIndex of the dates you want 103 | """ 104 | plot_index = pd.date_range(start="2000-1-1", freq="D", periods=10000) 105 | freq = 'M' 106 | 107 | dates = pd.Series(1, index=plot_index).resample(freq).index 108 | tl = formatter.TimestampLocator(plot_index, xticks=dates) 109 | test = tl._process(3, 900) 110 | 111 | tl = formatter.TimestampLocator(plot_index, freq=freq) 112 | correct = tl._process(3, 900) 113 | tm.assert_almost_equal(test, correct) 114 | 115 | freq = 'MS' 116 | dates = pd.Series(1, index=plot_index).resample(freq).index 117 | tl = formatter.TimestampLocator(plot_index, xticks=dates) 118 | test = tl._process(3, 900) 119 | 120 | tl = formatter.TimestampLocator(plot_index, freq=freq) 121 | correct = tl._process(3, 900) 122 | tm.assert_almost_equal(test, correct) 123 | 124 | # straight list of dates 125 | freq = 'MS' 126 | dates = pd.Series(1, index=plot_index).resample(freq).index 127 | dates = list(dates) 128 | tl = formatter.TimestampLocator(plot_index, xticks=dates) 129 | test = tl._process(3, 900) 130 | 131 | tl = formatter.TimestampLocator(plot_index, freq=freq) 132 | correct = tl._process(3, 900) 133 | tm.assert_almost_equal(test, correct) 
134 | 135 | def test_sparse_index(self): 136 | """ 137 | Make sure that we match the correct dates even if the 138 | freq gives us days not in index. 139 | """ 140 | freq = 'MS' 141 | index = pd.date_range(start="2000-1-1", freq="B", periods=1000) 142 | tl = formatter.TimestampLocator(plot_index, freq=freq) 143 | test = tl._process(0, 900) 144 | new_ind = index[test] 145 | assert np.all(new_ind.day < 5) 146 | 147 | freq = 'M' 148 | index = pd.date_range(start="2000-1-1", freq="B", periods=1000) 149 | tl = formatter.TimestampLocator(plot_index, freq=freq) 150 | test = tl._process(0, 900) 151 | new_ind = index[test] 152 | assert np.all(new_ind.day > 25) 153 | 154 | if __name__ == '__main__': 155 | import nose 156 | nose.runmodule(argv=[__file__,'-vs','-x','--pdb', '--pdb-failure'],exit=False) 157 | -------------------------------------------------------------------------------- /ts_charting/test/test_figure.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pandas.util.testing as tm 6 | 7 | import ts_charting.figure as figure 8 | from ts_charting.figure import process_series 9 | 10 | class Testprocess_data(TestCase): 11 | 12 | def __init__(self, *args, **kwargs): 13 | TestCase.__init__(self, *args, **kwargs) 14 | 15 | def runTest(self): 16 | pass 17 | 18 | def setUp(self): 19 | pass 20 | 21 | def test_already_aligned(self): 22 | plot_index = pd.date_range(start="2000", freq="D", periods=100) 23 | series = pd.Series(range(100), index=plot_index) 24 | plot_series = process_series(series, plot_index) 25 | tm.assert_almost_equal(series, plot_series) 26 | tm.assert_almost_equal(plot_series.index, plot_index) 27 | 28 | def test_partial_plot(self): 29 | """ 30 | Test plotting series that is a subset of plot_index. 
31 | Should align and fill with nans 32 | """ 33 | plot_index = pd.date_range(start="2000", freq="D", periods=100) 34 | series = pd.Series(range(100), index=plot_index) 35 | series = series[:50] # only first 50 36 | plot_series = process_series(series, plot_index) 37 | 38 | # have same index 39 | tm.assert_almost_equal(plot_series.index, plot_index) 40 | assert plot_series.count() == 50 41 | assert np.all(plot_series[50:].isnull()) # method=None so fill with nan 42 | assert np.all(plot_series[:50] == series[:50]) 43 | 44 | def test_unaligned_indexes(self): 45 | """ 46 | Test when series.index and plot_index have no common datetimes 47 | """ 48 | plot_index = pd.date_range(start="2000", freq="D", periods=100) 49 | series = pd.Series(range(100), index=plot_index) 50 | # move days to 11 PM the night before 51 | shift_series = series.tshift(-1, '1h') 52 | plot_series = process_series(shift_series, plot_index) 53 | # without method, data doesn't align and we nothing but nans 54 | tm.assert_almost_equal(plot_series.index, plot_index) # index aligh properly 55 | assert np.all(plot_series.isnull()) # no data 56 | 57 | # method = 'ffill' 58 | plot_series = process_series(shift_series, plot_index, method='ffill') 59 | # without method, data doesn't align and we nothing but nans 60 | tm.assert_almost_equal(plot_series.index, plot_index) # index align 61 | # since we're forward filling a series we tshifted into past 62 | # plot_series should just equal the original series 63 | tm.assert_almost_equal(plot_series, series) 64 | 65 | 66 | def test_different_freqs(self): 67 | """ 68 | Tests indexes of differeing frequencies. This is more of repeat 69 | test of test_partial_plot but with many holes instead of one half missing 70 | value. 
71 | """ 72 | plot_index = pd.date_range(start="2000-01-01", freq="D", periods=100) 73 | series = pd.Series(range(100), index=plot_index) 74 | grouped_series = series.resample('MS', 'max') 75 | plot_series = process_series(grouped_series, plot_index) 76 | tm.assert_almost_equal(plot_series.index, plot_index) # index align 77 | # method=None, dropna should give back same series 78 | tm.assert_almost_equal(plot_series.dropna(), grouped_series) 79 | 80 | plot_series = process_series(grouped_series, plot_index, method='ffill') 81 | tm.assert_almost_equal(plot_series.index, plot_index) # index align 82 | assert plot_series.isnull().sum() == 0 83 | month_ind = plot_series.index.month - 1 84 | # assert that each value corresponds to its month in grouped_series 85 | assert np.all(grouped_series[month_ind] == plot_series) 86 | 87 | def test_scalar(self): 88 | """ 89 | Test the various ways we handle scalars. 90 | """ 91 | plot_index = pd.date_range(start="2000-01-01", freq="D", periods=100) 92 | plot_series = process_series(5, plot_index) 93 | tm.assert_almost_equal(plot_series.index, plot_index) # index align 94 | assert np.all(plot_series == 5) 95 | 96 | # explicitly pass in the series index. Should have a plot_series with only iloc[10:20] 97 | # equal to the scalar 5. 98 | plot_series = process_series(5, plot_index, series_index=plot_index[10:20]) 99 | tm.assert_almost_equal(plot_series.index, plot_index) # index align 100 | assert np.all(plot_series[10:20] == 5) 101 | assert plot_series.isnull().sum() == 90 102 | 103 | # no plot_index. This still works because we're passing in series_index 104 | plot_series = process_series(5, None, series_index=plot_index[10:20]) 105 | correct = pd.Series(5, index=plot_index[10:20]) 106 | tm.assert_almost_equal(correct, plot_series) 107 | 108 | # without any index, a scalar will error. 
Cannot plot a scalar on an 109 | # empty plot without passing in an index 110 | try: 111 | plot_series = process_series(5, None) 112 | except: 113 | pass 114 | else: 115 | assert False, "scalar should fail without plot_index or series_index" 116 | 117 | def test_iterable(self): 118 | """ 119 | Non pd.Series iterables require an equal length series_index or 120 | plot_index. 121 | """ 122 | try: 123 | plot_series = process_series(range(10), None) 124 | except: 125 | pass 126 | else: 127 | assert False, "iterable should fail without plot_index or series_index" 128 | 129 | plot_index = pd.date_range(start="2000-01-01", freq="D", periods=100) 130 | try: 131 | plot_series = process_series(range(10), plot_index) 132 | except: 133 | pass 134 | else: 135 | assert False, "iterable requires an index of same length" 136 | 137 | # equal length, good times 138 | plot_series = process_series(range(10), plot_index[:10]) 139 | correct = pd.Series(range(10), index=plot_index[:10]) 140 | tm.assert_almost_equal(correct, plot_series) 141 | 142 | if __name__ == '__main__': 143 | import nose 144 | nose.runmodule(argv=[__file__,'-vvs','-x','--pdb', '--pdb-failure'],exit=False) 145 | -------------------------------------------------------------------------------- /ts_charting/figure.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from matplotlib import pyplot as plt 6 | 7 | import ts_charting.styles as cstyler 8 | from ts_charting.formatter import TimestampFormatter, TimestampLocator 9 | from ts_charting.util import process_signal 10 | 11 | class Figure(object): 12 | def __init__(self, rows=1, cols=1, skip_na=True, warn=True): 13 | self.figure = plt.figure() 14 | self.rows = rows 15 | self.cols = cols 16 | self.ax = None 17 | self.axnum = None 18 | self.graphers = {} 19 | self.grapher = None 20 | self.skip_na = skip_na 21 | if rows == 1: 22 | self.set_ax(1) 23 | if warn: 24 | 
print("Use charting.figure instead. Lowercase f") 25 | 26 | def get_ax(self, axnum): 27 | if axnum not in self.graphers: 28 | return None 29 | return self.graphers[axnum].ax 30 | 31 | def _set_ax(self, axnum): 32 | self.axnum = axnum 33 | grapher = self.graphers[axnum] 34 | self.grapher = grapher 35 | self.ax = grapher.ax 36 | 37 | def init_ax(self, axnum, sharex=None, skip_na=None): 38 | if skip_na is None: 39 | skip_na = self.skip_na 40 | shared_index = None 41 | if type(sharex) == int: 42 | shared_index = self.graphers[sharex].index 43 | ax = plt.subplot(self.rows, self.cols, axnum) 44 | self.graphers[axnum] = Grapher(ax, skip_na, sharex=shared_index) 45 | 46 | def set_ax(self, axnum, sharex=None, skip_na=None): 47 | if self.get_ax(axnum) is None: 48 | self.init_ax(axnum, sharex, skip_na) 49 | self._set_ax(axnum) 50 | 51 | def align_xlim(self, axes=None): 52 | """ 53 | Make sure the axes line up their xlims 54 | """ 55 | # TODO take a param of ax numbers to align 56 | left = [] 57 | right = [] 58 | for grapher in list(self.graphers.values()): 59 | if grapher.index is None: 60 | continue 61 | l, r = grapher.ax.get_xlim() 62 | left.append(l) 63 | right.append(r) 64 | 65 | for grapher in list(self.graphers.values()): 66 | if grapher.index is None: 67 | continue 68 | grapher.ax.set_xlim(min(left), max(right)) 69 | 70 | def plot(self, label, series, index=None, method=None, **kwargs): 71 | if self.ax is None: 72 | print('NO AX set') 73 | return 74 | self.figure.tight_layout() 75 | plt.xticks(rotation=30, ha='right') 76 | self.grapher.plot(label, series, index, method, **kwargs) 77 | 78 | def plot_markers(self, label, series, yvalues=None, xindex=None, **kwargs): 79 | if self.ax is None: 80 | print('NO AX set') 81 | return 82 | self.grapher.plot_markers(label, series, yvalues, xindex, **kwargs) 83 | 84 | def clear(self, axnum=None): 85 | if axnum is None: 86 | axnum = self.axnum 87 | 88 | grapher = self.graphers[axnum] 89 | ax = grapher.ax 90 | ax.clear() 91 | del 
self.graphers[axnum] 92 | self.ax = None 93 | self.set_ax(axnum) 94 | 95 | def __getattr__(self, name): 96 | if hasattr(self.grapher, name): 97 | return getattr(self.grapher, name) 98 | raise AttributeError() 99 | 100 | class Grapher(object): 101 | def __init__(self, ax, skip_na=True, sharex=None): 102 | self.index = None 103 | self.formatter = None 104 | self.locator = None 105 | self.ax = ax 106 | self.skip_na = skip_na 107 | self.sharex = sharex 108 | self.styler = cstyler.styler() 109 | self.yaxes = {} 110 | 111 | @property 112 | def right_ax(self): 113 | return self.yaxes.get('right', None) 114 | 115 | def is_datetime(self): 116 | return self.index.inferred_type in ('datetime', 'date', 'datetime64') 117 | 118 | def find_ax(self, secondary_y, kwargs): 119 | """ 120 | multiple y-axis support. stay backward compatible with secondary_y 121 | 122 | Note: we take in the actual kwargs because we want to pop('yax') 123 | to affect the callers kwargs 124 | """ 125 | yax = kwargs.pop('yax', None) 126 | if yax and secondary_y: 127 | raise Exception('yax and secondary_y should not both be set') 128 | if secondary_y: 129 | yax = 'right' 130 | 131 | ax = self.ax 132 | if yax: 133 | ax = self.get_yax(yax) 134 | return ax 135 | 136 | def plot(self, label, series, index=None, method=None, secondary_y=False, 137 | **kwargs): 138 | 139 | # use default styler if one is not passed in 140 | styler = kwargs.pop('styler', self.styler) 141 | if styler: 142 | style_dict = next(styler) 143 | # note we do it this way so explicit args passed in kwargs 144 | # override style_dict 145 | kwargs = dict(list(style_dict.items()) + list(kwargs.items())) 146 | 147 | plot_index = self.index 148 | if self.sharex is not None: 149 | plot_index = self.sharex 150 | 151 | series = process_series(series, plot_index, series_index=index, method=method) 152 | 153 | # first plot, set index 154 | if self.index is None: 155 | self.index = series.index 156 | 157 | is_datetime = self.is_datetime() 158 | if 
is_datetime: 159 | self.setup_datetime(self.index) 160 | 161 | plot_series = series 162 | 163 | if label is not None: 164 | kwargs['label'] = label 165 | 166 | xax = self.index 167 | if self.skip_na and is_datetime: 168 | xax = np.arange(len(self.index)) 169 | self.formatter.index = self.index 170 | 171 | ax = self.find_ax(secondary_y, kwargs) 172 | ax.plot(xax, plot_series, **kwargs) 173 | 174 | # generate combined legend 175 | lines, labels = self.consolidate_legend() 176 | self.ax.legend(lines, labels, loc=0) 177 | 178 | if is_datetime: 179 | # plot empty space for leading NaN and trailing NaN 180 | # not sure if I should only call this for is_datetime 181 | plt.xlim(0, len(self.index)-1) 182 | 183 | def consolidate_legend(self): 184 | """ 185 | consolidate the legends from all axes and merge into one 186 | """ 187 | lines, labels = self.ax.get_legend_handles_labels() 188 | for k, ax in list(self.yaxes.items()): 189 | new_lines, new_labels = ax.get_legend_handles_labels() 190 | lines = lines + new_lines 191 | labels = labels + new_labels 192 | return lines, labels 193 | 194 | def get_right_ax(self): 195 | return self.get_yax('right') 196 | 197 | def get_yax(self, name): 198 | """ 199 | Get a yaxis keyed by name. 
Returns a newly 200 | generted twinx if it doesn't exist 201 | """ 202 | def make_patch_spines_invisible(ax): 203 | ax.set_frame_on(True) 204 | ax.patch.set_visible(False) 205 | for sp in list(ax.spines.values()): 206 | sp.set_visible(False) 207 | 208 | size = len(self.yaxes) 209 | if name not in self.yaxes: 210 | ax = self.ax.twinx() 211 | self.yaxes[name] = ax 212 | # set spine 213 | ax.spines["right"].set_position(("outward", 50 * size)) 214 | make_patch_spines_invisible(ax) 215 | ax.spines["right"].set_visible(True) 216 | ax.set_ylabel(name) 217 | 218 | self.set_formatter() 219 | return self.yaxes[name] 220 | 221 | def setup_datetime(self, index=None): 222 | """ 223 | Setup the int based matplotlib x-index to translate 224 | to datetime 225 | 226 | Separated out here to share between plot and candlestick 227 | """ 228 | if index is None: 229 | index = self.index 230 | 231 | is_datetime = self.is_datetime() 232 | if self.formatter is None and self.skip_na and is_datetime: 233 | self.locator = TimestampLocator(index) 234 | self.formatter = TimestampFormatter(index, self.locator) 235 | self.set_formatter() 236 | 237 | # reupdate index 238 | self.locator.index = index 239 | self.formatter.index = index 240 | 241 | def set_index(self, index): 242 | if self.index is not None: 243 | raise Exception("Cannot set index if index already exists") 244 | self.index = index 245 | 246 | def set_formatter(self): 247 | """ quick call to reset locator/formatter when lost. i.e. 
boxplot """ 248 | if self.formatter: 249 | # set to xaxis 250 | ax = self.ax 251 | ax.xaxis.set_major_locator(self.locator) 252 | ax.xaxis.set_major_formatter(self.formatter.ticker_func) 253 | ax.xaxis.grid(True) 254 | 255 | def set_xticks(self, xticks): 256 | # freq 257 | if isinstance(xticks, str): 258 | self.locator.freq = xticks 259 | self.locator.set_xticks(None) 260 | else: 261 | self.locator.set_xticks(xticks) 262 | self.locator.freq = None 263 | 264 | def plot_markers(self, label, series, yvalues=None, xindex=None, **kwargs): 265 | if yvalues is not None: 266 | series = process_signal(series, yvalues) 267 | props = {} 268 | props['linestyle'] = 'None' 269 | props['marker'] = 'o' 270 | props['markersize'] = 10 271 | props.update(kwargs) 272 | 273 | if xindex is not None: 274 | series = series.copy() 275 | series.index = xindex 276 | 277 | self.plot(label, series, **props) 278 | 279 | def line(self, val, *args, **kwargs): 280 | """ print horizontal line """ 281 | self.plot(None, val, *args, **kwargs) 282 | 283 | def process_series(series, plot_index, series_index=None, method=None): 284 | """ 285 | Parameters 286 | ---------- 287 | series : int/float, iterable, pd.Series 288 | Data to be plotted. 289 | plot_index : pd.DatetimeIndex 290 | Index of the x-axis. Can be None if subplot has no plots. 291 | series_index : pd.DatetimeIndex 292 | Index of series to be plotted. Only really applicable to iterables/scalars. 293 | method : {'backfill', 'bfill', 'pad', 'ffill', None} 294 | Passed along to `reindex`. 
295 | """ 296 | if series_index is not None: 297 | series = pd.Series(series, index=series_index) 298 | 299 | # no need to align index 300 | if plot_index is None: 301 | if hasattr(series, 'index') and isinstance(series.index, pd.Index): 302 | return series 303 | else: 304 | raise Exception("First plotted series must have an index") 305 | 306 | if np.isscalar(series): 307 | series = pd.Series(series, index=plot_index) 308 | 309 | if not isinstance(series, pd.Series) and isinstance(series, collections.Iterable): 310 | series = pd.Series(series, index=plot_index) 311 | 312 | series = series.reindex(plot_index, method=method) 313 | 314 | return series 315 | --------------------------------------------------------------------------------