├── .gitignore ├── 3d.py ├── README.md ├── animate.py ├── env.yml ├── functions.py ├── mbrl.py └── plot.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pdf 2 | .DS_Store 3 | -------------------------------------------------------------------------------- /3d.py: -------------------------------------------------------------------------------- 1 | from scipy.integrate import odeint 2 | 3 | 4 | def sim_lorenz(X, t, sigma, beta, rho): 5 | """The Lorenz equations.""" 6 | u, v, w = X 7 | up = -sigma * (u - v) 8 | vp = rho * u - v - u * w 9 | wp = -beta * w + u * v 10 | return up, vp, wp 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # plotting-basics 2 | Collection of the few lines of code for clean plots in Plotly and Matplotlib. 3 | 4 | Includes: 5 | - `plot.py` for basic plots 6 | - `functions.py` for useful plotly functions to add features like mark every and easy error bar line plots. 7 | - `3d.py` for 3d plots (under construction) 8 | - `animate.py` for advanced features (under construction) 9 | 10 | I use these tools to make pretty things in my papers. It is very important. I am willing to help if people are having trouble with my code. 11 | ![Make pretty plots, get paid.](https://firebasestorage.googleapis.com/v0/b/firescript-577a2.appspot.com/o/imgs%2Fapp%2Fnatolambert%2FgbyyYxbe4d.png?alt=media&token=dc6377f1-666a-438e-849e-16a20647904e) 12 | 13 | -------------------------------------------------------------------------------- /animate.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/natolambert/plotting-basics/642fe0b80f5ee8bc6a4d280ab6d0c30a27cae8e0/animate.py -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: plot 2 | channels: 3 | - plotly 4 | - defaults 5 | dependencies: 6 | - asn1crypto=1.3.0=py37_0 7 | - blas=1.0=mkl 8 | - ca-certificates=2020.1.1=0 9 | - certifi=2019.11.28=py37_1 10 | - cffi=1.14.0=py37hb5b8e2f_0 11 | - chardet=3.0.4=py37_1003 12 | - cryptography=2.8=py37ha12b0ac_0 13 | - cycler=0.10.0=py37_0 14 | - freetype=2.9.1=hb4e5f40_0 15 | - idna=2.9=py_1 16 | - intel-openmp=2019.4=233 17 | - kiwisolver=1.1.0=py37h0a44026_0 18 | - libcxx=4.0.1=hcfea43d_1 19 | - libcxxabi=4.0.1=hcfea43d_1 20 | - libedit=3.1.20181209=hb402a30_0 21 | - libffi=3.2.1=h475c297_4 22 | - libgfortran=3.0.1=h93005f0_2 23 | - libpng=1.6.37=ha441bb4_0 24 | - mkl=2019.4=233 25 | - mkl-service=2.3.0=py37hfbe908c_0 26 | - mkl_fft=1.0.15=py37h5e564d8_0 27 | - mkl_random=1.1.0=py37ha771720_0 28 | - ncurses=6.2=h0a44026_0 29 | - numpy=1.18.1=py37h7241aed_0 30 | - numpy-base=1.18.1=py37h6575580_1 31 | - openssl=1.1.1f=h1de35cc_0 32 | - pandas=1.0.3=py37h6c726b0_0 33 | - pip=20.0.2=py37_1 34 | - plotly=4.6.0=py_0 35 | - plotly-orca=1.2.1=1 36 | - psutil=5.7.0=py37h1de35cc_0 37 | - pycparser=2.20=py_0 38 | - pyopenssl=19.1.0=py37_0 39 | - pyparsing=2.4.6=py_0 40 | - pysocks=1.7.1=py37_0 41 | - python=3.7.7=hc70fcce_0_cpython 42 | - python-dateutil=2.8.1=py_0 43 | - pytz=2019.3=py_0 44 | - readline=8.0=h1de35cc_0 45 | - requests=2.23.0=py37_0 46 | - retrying=1.3.3=py37_2 47 | - scipy=1.4.1=py37h9fa6033_0 48 | - seaborn=0.10.0=py_0 49 | - setuptools=46.1.3=py37_0 50 | - six=1.14.0=py37_0 51 | - sqlite=3.31.1=ha441bb4_0 52 | - tk=8.6.8=ha441bb4_0 53 | - tornado=6.0.4=py37h1de35cc_1 54 | - urllib3=1.25.8=py37_0 55 | - wheel=0.34.2=py37_0 56 | - xz=5.2.4=h1de35cc_4 57 | - zlib=1.2.11=h1de35cc_3 58 | - pip: 59 | - matplotlib==3.2.1 60 | prefix: /Users/nato/miniconda3/envs/plot 61 | 62 | -------------------------------------------------------------------------------- /functions.py: -------------------------------------------------------------------------------- 1 | import plotly 2 | import plotly.graph_objects as go 3 | 4 | 5 | def add_marker(err_traces, color=[], symbol=None, skip=None, m_every=5): 6 | mark_every = m_every 7 | size = 30 8 | l = len(err_traces[0]['x']) 9 | if skip is not None: 10 | size_list = [0] * skip + [size] + [0] * (mark_every - 1 - skip) 11 | else: 12 | size_list = [size] + [0] * (mark_every - 1) 13 | repeat = int(l / mark_every) 14 | size_list = size_list * repeat 15 | line = err_traces[0] 16 | line['mode'] = 'lines+markers' 17 | line['cliponaxis'] = False 18 | line['marker'] = dict( 19 | color=line['line']['color'], 20 | size=size_list, 21 | symbol="x" if symbol is None else symbol, 22 | line=dict(width=1, 23 | color='rgba(1,1,1,1)') 24 | ) 25 | err_traces[0] = line 26 | return err_traces 27 | 28 | def generate_errorbar_traces(ys, xs=None, percentiles='66+95', color=None, name=None): 29 | if xs is None: 30 | xs = [list(range(len(y))) for y in ys] 31 | 32 | minX = min([len(x) for x in xs]) 33 | 34 | if "#" in color: 35 | color = 'rgb' + str(tuple(int(color[i + 1:i + 3], 16) for i in (0, 2, 4))) 36 | 37 | xs = [x[:minX] for x in xs] 38 | ys = [y[:minX] for y in ys] 39 | 40 | assert np.all([(len(y) == len(ys[0])) for y in ys]), \ 41 | 'Y should be the same size for all traces' 42 | 43 | assert np.all([(x == xs[0]) for x in xs]), \ 44 | 'X should be the same for all traces' 45 | 46 | y = np.array(ys) 47 | 48 | def median_percentile(data, des_percentiles='66+95'): 49 | median = np.nanmedian(data, axis=0) 50 | out = np.array(list(map(int, des_percentiles.split("+")))) 51 | for i in range(out.size): 52 | assert 0 <= out[i] <= 100, 'Percentile must be >0 <100; instead is %f' % out[i] 53 | list_percentiles = np.empty((2 * out.size,), dtype=out.dtype) 54 | list_percentiles[0::2] = out # Compute the percentile 55 | list_percentiles[1::2] = 100 - out # Compute also the mirror percentile 56 | percentiles = np.nanpercentile(data, list_percentiles, axis=0) 57 | return [median, percentiles] 58 | 59 | out = median_percentile(y, des_percentiles=percentiles) 60 | ymed = out[0] 61 | 62 | err_traces = [ 63 | dict(x=xs[0], y=ymed.tolist(), mode='lines', name=name, type='scatter', legendgroup=f"group-{name}", 64 | line=dict(color=color, width=4))] 65 | 66 | intensity = .3 67 | ''' 68 | interval = scipy.stats.norm.interval(percentile/100, loc=y, scale=np.sqrt(variance)) 69 | interval = np.nan_to_num(interval) # Fix stupid case of norm.interval(0) returning nan 70 | ''' 71 | 72 | for i, p_str in enumerate(percentiles.split("+")): 73 | p = int(p_str) 74 | high = out[1][2 * i, :] 75 | low = out[1][2 * i + 1, :] 76 | 77 | err_traces.append(dict( 78 | x=xs[0] + xs[0][::-1], 79 | type='scatter', 80 | y=(high).tolist() + (low).tolist()[::-1], 81 | fill='toself', 82 | fillcolor=(color[:-1] + str(f", {intensity})")).replace('rgb', 'rgba') 83 | if color is not None else None, 84 | line=dict(color='rgba(0,0,0,0)'), # transparent'), 85 | # legendgroup=f"group-{name}", 86 | showlegend=False, 87 | name=name + str(f"_std{p}") if name is not None else None, 88 | ), ) 89 | intensity -= .1 90 | 91 | return err_traces, xs, ys 92 | 93 | 94 | -------------------------------------------------------------------------------- /mbrl.py: -------------------------------------------------------------------------------- 1 | # These are functions I made to evaluate models in MBRL. Need to integrate these with my open-source library. 2 | 3 | def plot_dynamics_model_tests(vis_env, filtered_logs, configs, diffs): 4 | """ 5 | :param checkpoint: File to load from /checkpoint/usr/runs/ with dynamics model and training dataset 6 | :param policy: policy to plan with for trajectory estimates 7 | :param vis_env: page to return plots on visdom (default 'dynam') 8 | :return: 9 | """ 10 | import gym 11 | for log_dir, logs in filtered_logs.items(): 12 | # get name 13 | if len(diffs[log_dir]) > 0: 14 | title = ",".join([f"{e[0]}={e[1]}" for e in diffs[log_dir]]) 15 | else: 16 | title = log_dir 17 | cfg = configs[log_dir][0] 18 | 19 | env_name = logs[0]['env_name'] 20 | ys = np.stack([np.asarray(log['rewards']) for log in logs]) 21 | 22 | # load checkpoint items 23 | dynam_model = logs[0]['dynamics_model'] 24 | training_dataset = logs[0]['training_dataset'] 25 | testing_dataset = logs[0]['testing_dataset'] 26 | groundtruth_ep = logs[0]['episode'] 27 | groundtruth_traj = torch.stack([s.s0 for s in groundtruth_ep], dim=0) 28 | 29 | # generate arrays 30 | states = torch.stack([s for s in training_dataset.states0], dim=0) 31 | groundtruth_pred = torch.stack([s for s in training_dataset.states1], dim=0) 32 | actions = torch.stack([a for a in training_dataset.actions], dim=0) 33 | 34 | # prediction one step dynamics 35 | predictions = dynam_model.predict(states, actions) 36 | state0 = groundtruth_traj[0, :] # states[0, :] 37 | device = state0.device 38 | 39 | # convert for plotting 40 | predictions = predictions.cpu().numpy() 41 | groundtruth_pred = groundtruth_pred.cpu().numpy() 42 | 43 | plot_one_step_ahead_predictions(predictions[:, :, 0], groundtruth_pred, vis_env, title) 44 | 45 | env = gym.make(cfg.env.name) 46 | policy = utils.instantiate(cfg.policy, cfg) 47 | policy.setup(dynam_model, env.action_space, utils.get_static_method(cfg.env.reward_func)) 48 | 49 | def gen_plans(true_trajectory, policy, dynam_model): 50 | time = true_trajectory.shape[0] 51 | ds = true_trajectory.shape[1] 52 | da = policy.cfg.env.action_size 53 | plans = torch.empty((time, policy.planning_horizon + 1, ds)) 54 | for i, state in enumerate(true_trajectory): 55 | action_seqs = policy.plan_action_sequence(state) 56 | plans[i, :, :] = compute_trajectories(dynam_model, state, 57 | action_seqs.actions.reshape(1, policy.planning_horizon, da) 58 | ).squeeze() 59 | 60 | return plans 61 | 62 | def to_np(tensor): 63 | return tensor.cpu().detach().numpy() 64 | 65 | plans = gen_plans(groundtruth_traj, policy, dynam_model) 66 | plot_planning(to_np(plans), to_np(groundtruth_traj), vis_env) 67 | 68 | # 3. plot predicted trajectories of planned actions through model 69 | N = 10 70 | da = actions.shape[1] 71 | ds = state0.shape[0] 72 | 73 | # adjust policy to longer prediction horizon 74 | cfg.policy.params.planning_horizon = groundtruth_traj.shape[0] 75 | policy = utils.instantiate(cfg.policy, cfg) 76 | policy.setup(dynam_model, env.action_space, utils.get_static_method(cfg.env.reward_func)) 77 | 78 | actions_long = gather_actions(policy.optimizer, N, state0, da).to(device) 79 | predictions_traj = compute_trajectories(dynam_model, state0, 80 | actions_long.reshape(N, policy.planning_horizon, da) 81 | ).squeeze() 82 | plot_trajectory_pred(to_np(predictions_traj[:, :-1, :]), to_np(groundtruth_traj), vis_env) 83 | 84 | 85 | def plot_one_step_ahead_predictions(predictions, groundtruth, vis_env, title='One Step Model Predictions', 86 | cmap=cm.tab10): 87 | """ 88 | This function plot the predictions of the model against the groundtruth state. 89 | For visualization purposes, it also sorts the data w.r.t. groundtruth. 90 | (by doing so, we get as a bonus also the cdf of the groundtruth) 91 | :param predictions: np.array of dimension [N.Data x N.States] 92 | :param groundtruth: np.array of dimension [N.Data x N.States] 93 | :return: 94 | """ 95 | assert groundtruth.shape == predictions.shape, 'Wrong dimensions' 96 | 97 | for i, d in enumerate(range(groundtruth.shape[1])): 98 | opts = dict(title=title, 99 | font=dict(family='Times New Roman', size=18, color='#7f7f7f'), 100 | showlegend=True, 101 | xlabel='Sorted Ground Truth', 102 | ylabel=f"State Dimension - {i}", 103 | # legend=['Predictions', 'Ground Truth'], 104 | linecolor=np.array([[int(c * 255) for c in cmap(1)], [int(c * 255) for c in cmap(0)]]), linewidth=3, 105 | win=f"dim-{i}") 106 | 107 | idx = np.argsort(groundtruth[:, d]) 108 | vis.line(Y=np.hstack([predictions[idx, d:d + 1], groundtruth[idx, d:d + 1]]), 109 | name=['Predictions', 'Ground Truth'], env=vis_env + "_pred", 110 | opts=opts) 111 | 112 | 113 | def plot_planning_mpl(plans, groundtruth, vis_env, title='', cmap=cm.viridis): 114 | """ 115 | Plots trajectories replanned at each timestep given a ground truth trajectory. Use this for lower memory impact 116 | :param plans: np.array of dimension [Time x Horizon x N.States] 117 | :param groundtruth: np.array of dimension [Time x N.States] 118 | :param cmap: matplotlib cmap 119 | :return: 120 | """ 121 | n_curves = plans.shape[0] 122 | h = plans.shape[1] 123 | T = groundtruth.shape[0] 124 | dim_state = groundtruth.shape[1] 125 | 126 | colors = cmap(np.linspace(0, 0.85, n_curves)) # Removing final 15% of the viridis colormap 127 | 128 | fig_list = [] 129 | for i, d in enumerate(range(dim_state)): 130 | fig = plt.figure() 131 | for it, t in enumerate(range(n_curves)): 132 | y = np.arange(it, it + h) 133 | plt.plot(y, plans[it, :, d], color=colors[it], linewidth=1) 134 | plt.plot(groundtruth[:, d], color='black', linewidth=2, label='groundtruth') 135 | plt.ylabel('Variable %d' % i) 136 | plt.xlabel('Time') 137 | plt.show() 138 | fig_list.append(fig) 139 | 140 | for i, fig in enumerate(fig_list): 141 | win = f"dim_{i}_policy_plans" 142 | vis.matplot(fig, win=win, env=vis_env) 143 | 144 | 145 | def plot_planning(plans, groundtruth, vis_env, filter_num=3, dims=[0], name=''): 146 | """ 147 | Plots trajectories replanned at each timestep given a ground truth trajectory 148 | :param plans: np.array of dimension [Time x Horizon x N.States] 149 | :param groundtruth: np.array of dimension [Time x N.States] 150 | :param cmap: matplotlib cmap 151 | :return: 152 | """ 153 | 154 | n_curves = plans.shape[0] 155 | h = plans.shape[1] 156 | T = groundtruth.shape[0] 157 | dim_state = groundtruth.shape[1] 158 | 159 | for i, d in enumerate(dims): 160 | traces = [] 161 | for it, t in enumerate(range(n_curves)): 162 | if (it % filter_num != 0): continue 163 | y = np.arange(it, it + h).tolist() 164 | 165 | estimated_trace_pt = dict( 166 | x=y, 167 | y=plans[it, :, d].tolist(), 168 | type='scatter', 169 | mode='markers', 170 | marker=dict(color=np.arange(h).tolist(), colorscale='Viridis', size=6), 171 | showlegend=(it == 1), 172 | legendgroup=f"trajs-{d}", 173 | name='Planned Trajectory', 174 | ) 175 | traces.append(estimated_trace_pt) 176 | 177 | estimated_trace_line = dict( 178 | x=y, 179 | y=plans[it, :, d].tolist(), 180 | type='line', 181 | mode='lines', 182 | line=dict(color='rgba(100,100,100,.3)', width=1), 183 | showlegend=False, 184 | legendgroup=f"trajs-{d}", 185 | ) 186 | traces.append(estimated_trace_line) 187 | 188 | truth_trace = dict( 189 | x=np.arange(T).tolist(), 190 | y=groundtruth[:, d].tolist(), 191 | type='line', 192 | showlegend=True, 193 | line=dict(color='black', width=4), 194 | name='True Episode', 195 | ) 196 | traces.append(truth_trace) 197 | 198 | layout = dict(title=f"Estimated Trajectories from Each Episode Step (Dim {d}) - {name}", 199 | xaxis={'title': 'Time (Step)'}, 200 | yaxis={'title': 'State Value'}, 201 | font=dict(family='Times New Roman', size=18, color='#7f7f7f'), 202 | legend={'x': .83, 'y': .05, 'bgcolor': 'rgba(50, 50, 50, .03)'}, 203 | # width=1200, 204 | # height=700, 205 | ) 206 | 207 | vis_env._send({'data': traces, 'layout': layout, 'win': f"traj-dim-{d}-{name}", 'eid': 'policies'}) 208 | 209 | 210 | def plot_trajectory_pred(predictions, groundtruth, vis_env, title='State Trajectories Predicted via Optimizer', 211 | cmap=cm.tab10): 212 | """ 213 | Plots predicted trajectories from initial state for given number of trajectories 214 | :param predictions: np.array of dimension [Time x N.States x N.Predictions] 215 | :param groundtruth: np.array of dimension [Time x N.States] 216 | :param title: 217 | :param cmap: 218 | :return: 219 | """ 220 | dim_state = groundtruth.shape[1] 221 | 222 | linecolors = [[int(c * 255) for c in cmap(1)] for i in range(predictions.shape[2])] 223 | linecolors.append([int(c * 255) for c in cmap(0)]) 224 | linecolors = np.array(linecolors) 225 | 226 | for i, d in enumerate(range(dim_state)): 227 | opts = dict(title=title, 228 | font=dict(family='Times New Roman', size=18, color='#7f7f7f'), 229 | # showlegend=False, 230 | xlabel='Time (points)', 231 | ylabel=f"State Dimension - {i}", 232 | linecolor=linecolors, 233 | linewidth=[2], 234 | # legend=['Predictions', 'Ground Truth'], 235 | win=f"dim-{i}") 236 | 237 | data = np.hstack([predictions[n, :, d:d + 1] for n in range(predictions.shape[2])]) 238 | data = np.hstack([data, groundtruth[:, d:d + 1]]) 239 | 240 | vis.line(Y=data, X=np.arange(groundtruth.shape[0]), 241 | name=['Predictions', 'Ground Truth'], env=vis_env + "_long", 242 | opts=opts) -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | # file tools 2 | import os 3 | import sys 4 | 5 | # plotting tools 6 | import plotly 7 | import plotly.graph_objects as go 8 | import seaborn as sns 9 | import matplotlib 10 | import matplotlib.pyplot as plt 11 | 12 | # Core 13 | import numpy as np 14 | 15 | ###### Data ###### 16 | # Frome https://plotly.com/python/line-charts/ 17 | title = 'Main Source for News' 18 | labels = ['Television', 'Newspaper', 'Internet', 'Radio'] 19 | colors = ['#1f77b4', # muted blue 20 | '#ff7f0e', # safety orange 21 | '#2ca02c', # cooked asparagus green 22 | '#d62728', # brick red 23 | ] 24 | 25 | """ 26 | colors = [ 27 | '#1f77b4', # muted blue 28 | '#ff7f0e', # safety orange 29 | '#2ca02c', # cooked asparagus green 30 | '#d62728', # brick red 31 | '#9467bd', # muted purple 32 | '#8c564b', # chestnut brown 33 | '#e377c2', # raspberry yogurt pink 34 | '#7f7f7f', # middle gray 35 | '#bcbd22', # curry yellow-green 36 | '#17becf' # blue-teal 37 | ] 38 | 39 | markers = [ 40 | "cross", 41 | "circle", 42 | "x", 43 | "triangle-up", 44 | "y-down-open", 45 | "diamond", 46 | "hourglass", 47 | "hash", 48 | "star", 49 | "square", 50 | ] 51 | 52 | 3d plots remember plot_bgcolor='white', 53 | """ 54 | 55 | # NOTE: This file is for line plots. For a single dimension of data, one should use quartile plots like this 56 | """ 57 | https://plotly.com/python/box-plots/#box-plot-with-precomputed-quartiles 58 | import plotly.graph_objects as go 59 | 60 | fig = go.Figure() 61 | fig.add_trace(go.Box( 62 | y=[0.75, 5.25, 5.5, 6, 6.2, 6.6, 6.80, 7.0, 7.2, 7.5, 7.5, 7.75, 8.15, 63 | 8.15, 8.65, 8.93, 9.2, 9.5, 10, 10.25, 11.5, 12, 16, 20.90, 22.3, 23.25], 64 | name="All Points", 65 | jitter=0.3, 66 | pointpos=-1.8, 67 | boxpoints='all', # represent all points 68 | marker_color='rgb(7,40,89)', 69 | line_color='rgb(7,40,89)' 70 | )) 71 | 72 | fig.add_trace(go.Box( 73 | y=[0.75, 5.25, 5.5, 6, 6.2, 6.6, 6.80, 7.0, 7.2, 7.5, 7.5, 7.75, 8.15, 74 | 8.15, 8.65, 8.93, 9.2, 9.5, 10, 10.25, 11.5, 12, 16, 20.90, 22.3, 23.25], 75 | name="Only Whiskers", 76 | boxpoints=False, # no data points 77 | marker_color='rgb(9,56,125)', 78 | line_color='rgb(9,56,125)' 79 | )) 80 | 81 | fig.add_trace(go.Box( 82 | y=[0.75, 5.25, 5.5, 6, 6.2, 6.6, 6.80, 7.0, 7.2, 7.5, 7.5, 7.75, 8.15, 83 | 8.15, 8.65, 8.93, 9.2, 9.5, 10, 10.25, 11.5, 12, 16, 20.90, 22.3, 23.25], 84 | name="Suspected Outliers", 85 | boxpoints='suspectedoutliers', # only suspected outliers 86 | marker=dict( 87 | color='rgb(8,81,156)', 88 | outliercolor='rgba(219, 64, 82, 0.6)', 89 | line=dict( 90 | outliercolor='rgba(219, 64, 82, 0.6)', 91 | outlierwidth=2)), 92 | line_color='rgb(8,81,156)' 93 | )) 94 | 95 | fig.add_trace(go.Box( 96 | y=[0.75, 5.25, 5.5, 6, 6.2, 6.6, 6.80, 7.0, 7.2, 7.5, 7.5, 7.75, 8.15, 97 | 8.15, 8.65, 8.93, 9.2, 9.5, 10, 10.25, 11.5, 12, 16, 20.90, 22.3, 23.25], 98 | name="Whiskers and Outliers", 99 | boxpoints='outliers', # only outliers 100 | marker_color='rgb(107,174,214)', 101 | line_color='rgb(107,174,214)' 102 | )) 103 | 104 | 105 | fig.update_layout(title_text="Box Plot Styling Outliers") 106 | fig.show() 107 | """ 108 | 109 | mode_size = [8, 8, 12, 8] 110 | line_size = [2, 2, 4, 2] 111 | 112 | x_data = np.vstack((np.arange(2001, 2014),)*4) 113 | 114 | y_data = np.array([ 115 | [74, 82, 80, 74, 73, 72, 74, 70, 70, 66, 66, 69], 116 | [45, 42, 50, 46, 36, 36, 34, 35, 32, 31, 31, 28], 117 | [13, 14, 20, 24, 20, 24, 24, 40, 35, 41, 43, 50], 118 | [18, 21, 18, 21, 16, 14, 13, 18, 17, 16, 19, 23], 119 | ]) 120 | 121 | ###### Init plot / subplots ###### 122 | # mpl 123 | fig_mpl, ax = plt.subplots() 124 | 125 | # plotly 126 | fig_plo = plotly.subplots.make_subplots(rows=1, cols=1) 127 | 128 | ###### add data ###### 129 | 130 | for i in range(0, 4): 131 | # mpl 132 | ax.plot(x_data[i][:-1], y_data[i], color=colors[i], linewidth=line_size[i], label=labels[i]) 133 | 134 | # plotly 135 | fig_plo.add_trace(go.Scatter(x=x_data[i], y=y_data[i], mode='lines', 136 | name=labels[i], 137 | line=dict(color=colors[i], width=line_size[i]), 138 | connectgaps=True, 139 | )) 140 | 141 | 142 | ##### ##### ##### ##### ##### ##### ##### ##### ##### ##### 143 | ###### ###### Stlye below ###### ###### ###### ###### ##### 144 | ##### ##### ##### ##### ##### ##### ##### ##### ##### ##### 145 | 146 | ###### Font ###### 147 | # mpl 148 | font = {'size': 24, 'family': 'serif', 'serif': ['Times']} 149 | matplotlib.rc('font', **font) 150 | matplotlib.rc('text', usetex=True) 151 | 152 | # plotly 153 | fig_plo.update_layout(font=dict( 154 | family="Times New Roman, Times, serif", 155 | size=24, 156 | color="black" 157 | ), 158 | ) 159 | 160 | ###### axis lines & fun ###### 161 | # mpl 162 | ax.spines['right'].set_visible(False) 163 | ax.spines['top'].set_visible(False) 164 | ax.set_xlabel("Year") 165 | ax.set_ylabel("Market Share (%)") 166 | 167 | # plotly 168 | fig_plo.update_xaxes(title_text="Year", linecolor='black', # account for white background 169 | row=1, col=1, zeroline=True, zerolinecolor='rgba(0,0,0,.5)', zerolinewidth=1,) 170 | fig_plo.update_yaxes(title_text="Market Share (%)", linecolor='black', # account for white background 171 | row=1, col=1, zeroline=True, zerolinecolor='rgba(0,0,0,.5)', zerolinewidth=1,) 172 | 173 | ###### grid lines & fun ###### 174 | # mpl 175 | # Hide grid lines 176 | ax.grid(False) 177 | 178 | # Hide axes ticks 179 | #ax.set_xticks([]) 180 | #ax.set_yticks([]) 181 | #ax.set_zticks([]) 182 | 183 | #plotly 184 | fig.update_layout(xaxis_showgrid=False, yaxis_showgrid=False) 185 | 186 | 187 | 188 | ###### resolution & additions ###### 189 | # mpl 190 | ax.legend() 191 | # no resolution tool until saving :( 192 | 193 | # plotly 194 | fig_plo.update_layout(showlegend=True,) 195 | fig_plo.update_layout(width=1000, 196 | height=1000,) 197 | 198 | ###### background color ###### 199 | # mpl 200 | ax.set_facecolor((1, 1, 1)) 201 | 202 | # plotly 203 | fig_plo.update_layout(plot_bgcolor='white',) 204 | 205 | ###### saving ###### 206 | # mpl 207 | plt.show() 208 | #ax.savefig(os.getcwd()+"plop_mpl.pdf", dpi=300) 209 | 210 | # plotly 211 | # advanced for saving 212 | fig_plo.update_layout(legend_orientation="h", 213 | legend=dict(x=.6, y=0.07, 214 | bgcolor='rgba(205, 223, 212, .4)', 215 | bordercolor="Black", 216 | ), 217 | plot_bgcolor='white', 218 | width=1600, 219 | height=1000, 220 | margin=dict(r=10, l=10, b=10, t=10), 221 | ) 222 | 223 | fig_plo.write_image(os.getcwd()+"plot_plotly.pdf") 224 | fig_plo.show() 225 | 226 | 227 | --------------------------------------------------------------------------------