├── .gitignore
├── LICENSE
├── README.md
├── bin
│   └── ml_board
├── devel_requirements.txt
├── gifs
│   ├── dropdown.gif
│   ├── table.gif
│   └── thoughts.png
├── ml_board
│   ├── Logger.py
│   ├── __init__.py
│   └── utils.py
├── requirements.txt
└── setup.py


/.gitignore:
--------------------------------------------------------------------------------
# Compiled python modules.
*.pyc

# Setuptools distribution folder.
/dist/

# Python egg metadata, regenerated from source files by setuptools.
/*.egg-info

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 bbli

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ml_board
## Why ml_board/Limitations of Tensorboard
I decided to create this machine learning dashboard after using [tensorboardX](https://github.com/lanpa/tensorboardX) for a couple of months to train regular neural networks for deep reinforcement learning. As great as tensorboardX was in helping me debug and understand neural networks (it certainly beats printing out statistics to the terminal), I found myself using only a subset of its features, and I also discovered certain limitations of tensorboard:

* **Text is buggy**: The text tab will sometimes take a long time to load, or it will load the text from another run. When training a machine learning model, I often go through many settings and test various hypotheses. Having a reliable log is a must, so that I know which thoughts go with which runs.
* **Disconnect between visualization and settings**: My hypotheses often involve varying a hyperparameter and seeing its effect on quantities such as the loss, percentage of activations, etc. But the graphs don't have a legend that tells me which setting each run used. As a result, I am forced to go back and forth between the Scalars and Text tabs, interrupting my train of thought. As an example, if I were to log a bunch of experiments from random search, I would have to put in a non-trivial amount of effort to remember which experiment used which setting (since the numbers won't monotonically increase/decrease as in grid search).
* **Inadequate search**: Suppose I wanted to view all the runs that achieved a certain accuracy, or were run on a particular hyperparameter setting. As far as I know, this is not possible in tensorboard. So in some sense, past runs are only useful if I can remember them.
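
To make the search limitation concrete, here is a minimal sketch of the kind of query I have in mind. It assumes each run is summarized as a plain dictionary of logged values; the `runs` records and the `select_runs` helper are hypothetical and exist only for illustration, they are not part of tensorboard or ml_board.

```python
# Hypothetical run summaries: one dictionary of logged values per run.
runs = [
    {"Time": "2018-07-01-10:00:00", "Learning Rate": 0.01, "Accuracy": 0.91},
    {"Time": "2018-07-01-11:00:00", "Learning Rate": 0.10, "Accuracy": 0.78},
    {"Time": "2018-07-01-12:00:00", "Learning Rate": 0.01, "Accuracy": 0.95},
]


def select_runs(runs, conditions):
    """Return the runs that satisfy every condition.

    `conditions` maps a parameter name to a predicate on its value.
    A run that never logged one of the parameters is left out.
    """
    selected = []
    for run in runs:
        if all(name in run and test(run[name]) for name, test in conditions.items()):
            selected.append(run)
    return selected


# All runs that reached at least 90% accuracy.
accurate_runs = select_runs(runs, {"Accuracy": lambda value: value >= 0.9})
print([run["Time"] for run in accurate_runs])
```

Several conditions can be combined in one call, for example an accuracy threshold together with an exact learning rate, which is the equality/inequality style of filtering I wanted from a dashboard.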

Although tensorboard has great visualization capabilities (embeddings, computational graphs, etc.), it is not the best tool for tracking, presenting, and organizing the knowledge one obtains while running many machine learning experiments. So the focus of this project will not be on visualization (see [tensorboard](https://github.com/tensorflow/tensorboard)), or on experiment control and reproducibility (see [sacred](https://github.com/IDSIA/sacred)), but on creating a better interface for the scientist to view the relationship between a model's parameters and its output characteristics.
## Features
* **Interactive Table -> Filters Visualizations**: Allows individual selection of runs, and numerical filtering based on equality/inequality. Once these choices are made, the Plots, Histograms, and Image Tabs are updated accordingly, allowing you to choose which runs' visualizations you see. Also, because the plots are plotly Graph objects, one can click on the individual items in the plot legend to remove the corresponding plot from view.

![table](gifs/table.gif)
* **Legend Dropdown -> Hyperparameter Display**: Allows you to choose the hyperparameter setting (well, technically any statistic you wish) to be displayed as the title (or legend) for each run in the Plots/Histogram/Image Tabs. I limited the title to one item because I did not want the figures to be cluttered with words; I believe this is worth the tradeoff of the occasional lack of uniqueness.

![dropdown](gifs/dropdown.gif)
* **Figure/AutoUpdate Toggle**: As in tensorboard, you can click on the figure's title to minimize it. Also, every 10 seconds, the app will reread the data from the database, unless the autoupdate toggle is turned off.

* **Log of Thoughts**: As explained later, the user specifies a MongoDB database name and a MongoDB collection name (I call a collection a "folder") where the run's statistics will be stored.
The intended usage is that the user will specify a different folder for every hypothesis they want to test. Examples include "debug_binary_loss" or "lr_hyper_tune". The dashboard displays the runs of only one folder at a time, since visual information takes up a substantial amount of space, which inevitably leads to scrolling, a flow state killer. The limitation of this is that the user is unable to view the entire progression of their ideas. To preserve this folder-independent, sequential flow of thoughts, the Thoughts Tab aggregates the logged thoughts across all folders within the given database and displays them in order by time, each labeled with its folder name to give the comment context.

![thoughts](gifs/thoughts.png)


* **Extensibility**: The Dash library comes with awesome interactive components, such as the Table and Tabs components that were used in my project. Because I did not need to write these primitives myself, I could focus my attention on the domain problem/vision at hand, something that Peter Norvig talks about in [As We May Program](https://vimeo.com/215418110) (see 17:50-19:00). Though this is a project geared towards my own personal usage, it can be easily extended by end-users as they see fit (because it is written in a language data scientists are intimately familiar with, the official Dash tutorial is excellent, and the community is great). After all, I created this with no prior web app experience!

# Installation
Until the tabs feature is integrated into the master branch of [dash](https://github.com/plotly/dash), and I do more testing and write up the documentation, you will have to manually install the package with the following commands:
```
git clone https://github.com/bbli/ml_board.git
cd ml_board
## Activate/create conda environment you want to install in
pip3 install .
```
After this, install MongoDB (and make sure it is enabled and started) and you are good to go!

If something goes wrong during usage, and you can't debug it, you can try installing the exact versions this package was tested on:
1. Comment out the `install_requires` field in setup.py
2. Run `pip install -r requirements.txt && pip install .`

# Usage
### Logging
Usage is very similar to tensorboard. The differences are:
* There is no need to specify a count, as ml_board will append the result to a MongoDB list.
* ml_board has an additional `add_experiment_parameter` method, which is intended to log hyperparameters to a table.
```
from ml_board import SummaryWriter
w = SummaryWriter('name_of_mongodb_database','name_of_mongo_db_collection')

## These two append to a list
w.add_scalar("example_loss_name",example_loss_value)
w.add_histogram("example_histogram_name",example_loss_value)

w.add_image("example_picture_name",numpy_2d_matrix_in_range_0_to_1)
w.add_thought("example_thoughts")
## the current time is automatically logged as an experiment parameter when you create a SummaryWriter object
w.add_experiment_parameter("example_hyperparameter_name",example_hyperparameter_value)
w.close()
```
For more details, look at the `Logger.py` file in the ml_board folder.
### Visualizing
From the command line, enter
```
ml_board --database_name name_of_mongodb_database --folder_name name_of_mongo_db_collection
# shorthand notation
ml_board -db name_of_mongodb_database -f name_of_mongo_db_collection
# specific port. Default 8000
ml_board -db name_of_mongodb_database -f name_of_mongo_db_collection -p 8050
```
### Comments
* If autoUpdate is on, do not filter rows, as the filter will be overwritten.
* Don't click on the filter rows button twice, or it will filter permanently.
If this does happen, refresh the webpage to reset the app's state.
* FYI, the Histogram Tab generally takes the longest time to update (because multiple plotly Figure objects are created for each histogram name).

# TODO
* allow user to change folders from within the dashboard
* put priority on the callbacks (basically, if I am on the Plots Tab, its callbacks should finish first)
* testing/documentation

--------------------------------------------------------------------------------
/bin/ml_board:
--------------------------------------------------------------------------------
#!/usr/bin/env python
import dash
import dash_core_components as dcc
import dash_html_components as html
import dash_table_experiments as dt
from dash.dependencies import Input, Output, State
from ml_board.utils import *
# import sys
# sys.path.append('/home/benson/Dropbox/Code/Projects/ml_board/ml_board')
# from utils import *
import plotly.graph_objs as go
import ipdb

import argparse

################ **Parsing Command Line Arguments** ##################
parser = argparse.ArgumentParser()
parser.add_argument('--database_name','-db',
        required=True,
        help='which database you want to use' )
parser.add_argument('--folder_name','-f',
        required=True,
        help='which folder(aka a MongoDB collection) you want to view' )
## type=int so that a port given on the command line is not passed on as a string
parser.add_argument('--port','-p',
        type=int,
        default = 8000,
        help='port number you want to serve on' )

args = parser.parse_args()

################ **App Startup** ##################
app = dash.Dash(__name__)
app.title = "Machine Learning Dashboard"
# Bootstrap CSS.
app.css.append_css({
    "external_url": "https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/css/bootstrap.min.css"
})

# Extra Dash styling.
app.css.append_css({
    "external_url": 'https://codepen.io/chriddyp/pen/bWLwgP.css'
})

# JQuery is required for Bootstrap.
app.scripts.append_script({
    "external_url": "https://code.jquery.com/jquery-3.2.1.min.js"
})

# Bootstrap Javascript.
app.scripts.append_script({
    "external_url": "https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/js/bootstrap.min.js"
})

##############################################################
class BaseTab():
    def __init__(self,database_name,folder_name,title,f):
        self.title = title
        self.f = f
        self.nameObjects_for_each_run = getDictOfNameObjects(database_name,folder_name,self.title,self.f)
        self.figure_names = getFigureNames(self.nameObjects_for_each_run)
        self.database_name = database_name
        self.folder_name = folder_name

    #########################################
    def createHTMLStructure(self):
        html_row_list = []
        for figure_name in self.figure_names:
            button_row = html.Div(html.Button(figure_name,id=self.title+':'+figure_name+'button'),className='row')
            html_row_list.append(button_row)

            figure_row = html.Div(id=self.title+':'+figure_name+'content')
            html_row_list.append(figure_row)
        return html.Div(html_row_list,id=self.title)
    def assignCallbacks(self,app):
        for figure_name in self.figure_names:
            self.assignFigureShowCallback(figure_name,app)
            self.assignFigureCallback(figure_name,app)

        self.assignTabShowCallback(app)


    #############################################
    def assignFigureShowCallback(self,figure_name,app):
        @app.callback(
            ## Still Need to define this html structure
            Output(self.title+':'+figure_name+'content','style'),
            [Input(self.title+':'+figure_name+'button','n_clicks')]
        )
        def show_figure(n_clicks):
            ## every click on the figure's button toggles its visibility
            if n_clicks is not None:
                if n_clicks%2==0:
                    return {'display':'inline'}
                else:
                    return {'display':'none'}
            ##initial display
            return {'display':'inline'}
    #########################
    def assignFigureCallback(self,figure_name,app):
        @app.callback(
            Output(self.title+':'+figure_name+'content','children'),
            [Input('buffer','children'),
             ## can change due to user interaction
             Input('legend','value'),
             ## can change due to filter
             Input('datatable', 'rows'),
             ## can change based on user interaction
             Input('datatable', 'selected_row_indices')],
        )
        # @profile(self.title)
        # @timeFigureUpdate(self.title)
        def update_figure_and_data_structure(children,legend_value,rows,selected_row_indices):
            ################ **Updating Data Structures** ##################
            global g_dict_of_param_dicts
            global g_legend_names
            g_dict_of_param_dicts = getParamDict(self.database_name,self.folder_name)
            g_legend_names = getLegendNames(g_dict_of_param_dicts)

            self.nameObjects_for_each_run = getDictOfNameObjects(self.database_name,self.folder_name,self.title,self.f)
            self.figure_names = getFigureNames(self.nameObjects_for_each_run)
            ################ **Interacting with DataTable to get Selected Runs** ##################
            times_of_each_run = getSelectedRunsFromDatatable(rows,selected_row_indices)
            figure_content_for_this_name = self.getFigureContentForThisName(figure_name,times_of_each_run,legend_value)
            return figure_content_for_this_name

    def getFigureContentForThisName(self,figure_name,times_of_each_run,legend_value):
        '''
        figure_name is so we know figure info is pulled correctly
        times_of_each_run is so we know which runs to pull
        legend_value for formatting the figure
        '''
        raise NotImplementedError("Implement this function!")
    #########################
    def assignTabShowCallback(self,app):
        @app.callback(
            Output(self.title,'style'),
            [Input('tabs','value')]
        )
        def show_tab(value):
            if value == self.title:
                return {'display':'inline'}
            else:
                return {'display':'none'}
################ **Components** ##################
class PlotTab(BaseTab):
    def __init__(self,database_name,folder_name):
        title = 'Plots'
        f = None
        super().__init__(database_name,folder_name,title,f)
    def getFigureContentForThisName(self,figure_name,times_of_each_run,legend_value):
        plot_for_each_run = []
        for time in times_of_each_run:
            one_run_plots = self.nameObjects_for_each_run[time]
            one_run_params = g_dict_of_param_dicts[time]
            # run_dict = {'y':list(filtered_df[plot_name])}
            scatter_obj = self.createScatterObject(figure_name,one_run_plots,one_run_params,legend_value)
            plot_for_each_run.append(scatter_obj)

        data_dict= {'data':plot_for_each_run}
        ## Note id is required, even though I don't use it in my callbacks
        figure_object = dcc.Graph(id=figure_name+' Plot',figure= data_dict)
        return html.Div(html.Div(figure_object,className='col-md-10'),className='row')
    @staticmethod
    def createScatterObject(name,one_run_plots,one_run_params,legend_value):
        label = legend_value+':'+str(one_run_params[legend_value])
        return go.Scatter(
            y = list(one_run_plots[name]),
            mode = 'lines',
            name = label,
            text = label,
            hoverinfo='y'
        )

class HistogramTab(BaseTab):
    def __init__(self,database_name,folder_name):
        title = 'Histograms'
        f = None
        super().__init__(database_name,folder_name,title,f)
    def getFigureContentForThisName(self,figure_name,times_of_each_run,legend_value):
        histo_component_list = []
        for time in times_of_each_run:
            one_run_histogram = self.nameObjects_for_each_run[time]
            one_run_params = g_dict_of_param_dicts[time]

            histo_component = self.createHistogramComponent(time,figure_name,one_run_histogram,one_run_params,legend_value)
            histo_component_list.append(histo_component)

        return html.Div(histo_component_list,className='row')
    # @profile("Temp title") ## Note this needs to be the first decorator
    @staticmethod
    def createHistogramComponent(time,figure_name,one_run_histogram,one_run_params,legend_value):
        ################ **Creating Data Object** ##################
        one_run_values = one_run_histogram[figure_name]
        histo_data = [go.Histogram(x=one_run_values,histnorm='probability')]
        label = legend_value+':'+str(one_run_params[legend_value])
        histo_layout = go.Layout(title=label)
        data_dict = go.Figure(data=histo_data,layout=histo_layout)
        ##################################################

        ## Note id is required, even though I don't use it in my callbacks
        figure_object = dcc.Graph(id=time+':'+figure_name+' Histogram',figure= data_dict)
        return html.Div(figure_object,className='col-md-6')


class ImageTab(BaseTab):
    def __init__(self,database_name,folder_name):
        title = 'Images'
        f = getBase64Encoding
        super().__init__(database_name,folder_name,title,f)
    def getFigureContentForThisName(self,figure_name,times_of_each_run,legend_value):
        html_row_objects = []
        ################ **Creating the Components** ##################
        image_component_list = []
        for time in times_of_each_run:
            one_run_images = self.nameObjects_for_each_run[time]
            one_run_params = g_dict_of_param_dicts[time]

            image_component = self.createImageComponent(figure_name,one_run_images,one_run_params,legend_value)
            image_component_list.append(image_component)
        image_component_row = html.Div(image_component_list,className='row')
        html_row_objects.append(image_component_row)
        return html_row_objects
    @staticmethod
    def createImageComponent(figure_name,one_run_image,one_run_params,legend_value):
        base64_image = one_run_image[figure_name]
        figure_object = html.Img(src='data:image/png;base64,{}'.format(base64_image),className='center-block')
        figure_caption = legend_value+':'+str(one_run_params[legend_value])
        figure=html.Figure([figure_caption,figure_object],style={'text-align':'center'})
        return html.Div(figure,className='col-md-4')

class ThoughtsTab(BaseTab):
    def __init__(self,database_name):
        self.title = 'Thoughts'
        ## stored so that the update callback does not depend on the module-level variable
        self.database_name = database_name
        self.dict_of_all_thought_lists = getDictOfAllThoughtLists(database_name)
        self.ordered_thoughtList_keys = getOrderedKeys(self.dict_of_all_thought_lists)
    def createHTMLStructure(self):
        html_row_list = createHTMLRowList(self)
        # print(html_row_list,file=sys.stdout)
        return html.Div(html_row_list,id=self.title)
    def assignCallbacks(self,app):
        self.assignTabShowCallback(app)
        self.assignFigureCallback(app)
    ## no button callback, so no need to define on separate figures
    def assignFigureCallback(self,app):
        @app.callback(
            Output(self.title,'children'),
            [Input('buffer','children')]
        )
        def update_thoughts_tab(children):
            self.dict_of_all_thought_lists = getDictOfAllThoughtLists(self.database_name)
            self.ordered_thoughtList_keys = getOrderedKeys(self.dict_of_all_thought_lists)
            html_row_list = createHTMLRowList(self)
            # print(self.dict_of_all_thought_lists)
            # print("break")
            # print(html_row_list,file=sys.stdout)
            return html_row_list




################ **Global Variables** ##################
# database_name='software_testing'
# database_name = 'pendulum'
# folder_name = 'lunarlander'
# folder_name = 'ml_board_gifs'

database_name = args.database_name
folder_name = args.folder_name
plotTab_object = PlotTab(database_name,folder_name)
histoTab_object = HistogramTab(database_name,folder_name)
imageTab_object = ImageTab(database_name,folder_name)
thoughtTab_object = ThoughtsTab(database_name)

g_dict_of_param_dicts = getParamDict(database_name,folder_name)
g_legend_names = getLegendNames(g_dict_of_param_dicts)
g_inital_legend_name = g_legend_names[0]

g_tab_names = [plotTab_object.title,histoTab_object.title,imageTab_object.title,thoughtTab_object.title]

################ **Layout** ##################
app.layout = html.Div(
    [html.Div(
        [html.H1("Machine Learning Dashboard", className="text-center")]
        ,className="row")]+
    [html.Div(
        [html.Div(
            dcc.Checklist(
                id='autoupdateToggle',
                options=[{'label':'AutoUpdate','value':'On'}],
                values=['On'])
            ,className ='col-md-2'),
        html.Div(
            dcc.Interval(
                id='interval',
                interval=1*10_000,
                n_intervals=0)
            ,className="col-md-1"),
        html.Div(
            html.Div(
                style={'display':"none"},
                id='buffer')
            ,className="col-md-5"),
        html.Div(
            dcc.Dropdown(
                id='legend',
                options=[{'label':param,'value':param} for param in g_legend_names],
                # options=[{'label':"test","value":"test"}],
                value = g_inital_legend_name,
                # labelStyle={'display': 'inline-block'}
                )
            ,className='col-md-4')
        ]
        ,className='row')]+
    [html.Div(
        [dt.DataTable(
            rows= [value for key,value in g_dict_of_param_dicts.items()],
            # optional - sets the order of columns
            columns= g_legend_names,

            row_selectable=True,
            filterable=True,
            sortable=True,
            editable=False,
            selected_row_indices=[],
            id='datatable'
            )]
        ,className="row")]+

    [html.Div(
        [html.P("Debug Value",id='debug',className="text-center")]
        ,className="row",style={'display':'none'})]+
    [html.Div(
        [html.P("Debug Value",id='debug2',className="text-center")]
        ,className="row",style={'display':'none'})]+
    [html.Div(
        dcc.Tabs(
            tabs=[{'label': '{}'.format(name), 'value': name} for name in g_tab_names],
            value=g_tab_names[0],
            id='tabs'
            )
        ,className="row")]
    +[plotTab_object.createHTMLStructure()]
    +[imageTab_object.createHTMLStructure()]
    +[histoTab_object.createHTMLStructure()]
    +[thoughtTab_object.createHTMLStructure()]
    , className="container-fluid")


################ **Assigning Callbacks** ##################
plotTab_object.assignCallbacks(app)
imageTab_object.assignCallbacks(app)
histoTab_object.assignCallbacks(app)
thoughtTab_object.assignCallbacks(app)

# Time toggle buffer
@app.callback(
    Output("buffer","children"),
    [Input("interval","n_intervals")],
    [State("autoupdateToggle","values")]
)
def add_more_datapoints(n_intervals,values):
    if 'On' in values:
        return "changed"
    else:
        ## raising aborts this callback, so the buffer (and everything listening to it) is left untouched
        raise Exception

## Table data
@app.callback(
    Output("datatable","rows"),
    [Input('buffer','children')],
)
def update_table(children):
    rows= [value for key,value in g_dict_of_param_dicts.items()]
    # print("line break")
    # print(type(rows))
    return rows
@app.callback(
    Output("datatable","selected_row_indices"),
    [Input("datatable","rows")],
    [State("datatable","selected_row_indices")]
)
def preserver_selected_row_index(rows,index):
    return index
## Table columns
@app.callback(
    Output("datatable","columns"),
    [Input('buffer','children')],
)
def update_table_columns(children):
    return g_legend_names

## Debug
@app.callback(
    Output('debug','children'),
    [Input('datatable','rows')]
)
def printer(children):
    return "Debug Value 1:"+str(children)
# @app.callback(
#         Output('debug2','children'),
#         [Input("datatable",'rows')],
#         )
# def printer(rows):
#     # return str(children)+str(rows[14:])
#     return "Debug Value 2:"+str(rows)

# if __name__=='__main__':
port_number = args.port
app.run_server(port=port_number)
# app.run_server(port=8000,debug=True)

--------------------------------------------------------------------------------
/devel_requirements.txt:
--------------------------------------------------------------------------------
backcall==0.1.0
certifi==2018.4.16
chardet==3.0.4
click==6.7
dash==0.21.1
dash-core-components==0.21.0rc1
dash-html-components==0.11.0
dash-renderer==0.13.0
dash-table-experiments==0.6.0
decorator==4.3.0
docopt==0.6.2
Flask==1.0.2
Flask-Compress==1.4.0
idna==2.7
ipdb==0.11
ipython==6.4.0
ipython-genutils==0.2.0
itsdangerous==0.24
jedi==0.11.1
Jinja2==2.10
jsonschema==2.6.0
jupyter-core==4.4.0
loremipsum==1.0.5
MarkupSafe==1.0
ml-board==0.0.1
nbformat==4.4.0
numpy==1.14.5
parso==0.1.1
pexpect==4.6.0
pickleshare==0.7.4
Pillow==5.2.0
pipreqs==0.4.9
plotly==3.0.0
prompt-toolkit==1.0.15
ptyprocess==0.6.0
Pygments==2.2.0
pymongo==3.7.0
pytz==2018.5
requests>=2.20.0
retrying==1.3.3
simplegeneric==0.8.1
six==1.11.0
traitlets==4.3.2
urllib3==1.23
wcwidth==0.1.7
Werkzeug==0.14.1
yarg==0.1.9

--------------------------------------------------------------------------------
/gifs/dropdown.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bbli/ml_board/f88da939138559065b32b633a71a169e3c9d604f/gifs/dropdown.gif
--------------------------------------------------------------------------------
/gifs/table.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bbli/ml_board/f88da939138559065b32b633a71a169e3c9d604f/gifs/table.gif
--------------------------------------------------------------------------------
/gifs/thoughts.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bbli/ml_board/f88da939138559065b32b633a71a169e3c9d604f/gifs/thoughts.png
--------------------------------------------------------------------------------
/ml_board/Logger.py:
--------------------------------------------------------------------------------
from pymongo import MongoClient
import datetime
from ml_board.utils import Database
from bson.binary import Binary
# import cPickle
import pickle

class SummaryWriter(Database):
    def __init__(self,database_name,folder_name):
        super().__init__()
        self.database_name = database_name
        self.folder_name = folder_name
        self.runs = self.client[database_name][folder_name]

        ## the creation time identifies this run's document in all later updates
        self.date = datetime.datetime.today().strftime("%Y-%m-%d-%H:%M:%S")
        self.runs.insert_one( {"Experimental Parameters":{"Time":self.date}})
        # self.runs.update_one({"Experimental Parameters.Time":self.date},{'$set':{"Time":self.date}})

    def add_scalar(self,variable_name:str, f:int):
        self.runs.update_one({"Experimental Parameters.Time":self.date},{'$push':{"Plots."+variable_name:f}},upsert= True)
    def add_histogram(self, histogram_name:str, f:int):
        self.runs.update_one({"Experimental Parameters.Time":self.date},{'$push':{"Histograms."+histogram_name:f}},upsert= True)
    def add_image(self,image_name,image):
        processed_image = Binary(pickle.dumps(image,protocol=2))
        self.runs.update_one({"Experimental Parameters.Time":self.date},{'$set':{"Images."+image_name:processed_image}},upsert= True)
    def add_experiment_parameter(self, parameter_name:str, value:int):
        self.runs.update_one({"Experimental Parameters.Time":self.date}, {'$set':{"Experimental Parameters."+parameter_name:value}})
    def add_thought(self,string):
        ## each thought is stored as two consecutive entries: the folder name, then the text
        self.runs.update_one({"Experimental Parameters.Time":self.date},{'$push':{"Thoughts":self.folder_name}},upsert= True)
        self.runs.update_one({"Experimental Parameters.Time":self.date},{'$push':{"Thoughts":string}},upsert= True)
    def viewRun(self):
        '''
        show all the data logged from the run
        '''
        for doc in self.runs.find({"Experimental Parameters.Time":self.date}):
            print(doc)




if __name__ == '__main__':
    w = SummaryWriter('test_db','test_collection')
    w.add_experiment_parameter('Learning Rate',2)
    w.add_experiment_parameter('Neurons',3)
    for i in range(5):
        w.add_scalar("Loss",i**2)
    w.add_thought("hi")
    w.add_thought("hello")
    w.viewRun()
    # ipdb.set_trace()
    # w.removeFolder('test_db','test_collection')
    # w.close()


--------------------------------------------------------------------------------
/ml_board/__init__.py:
--------------------------------------------------------------------------------
from ml_board.Logger import SummaryWriter

--------------------------------------------------------------------------------
/ml_board/utils.py:
--------------------------------------------------------------------------------
from line_profiler import LineProfiler
from pymongo import MongoClient
from threading import Thread
import dash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output
import itertools
from io import BytesIO
from PIL import Image
import pickle
import base64
import numpy as np
import sys
################ **Optimizing Utils** ##################
import time

def timeFigureUpdate(title):
    def wrapper(func):
        def timeit(*args):
            start = time.time()
            x = func(*args)
            end = time.time()
            # print("Elapsed Time: {}".format(end-start),file=sys.stdout)
            sys.stdout.write("Elapsed Time of {} update_figure_and_data_structure function: {}\n".format(title,end-start))
            return x
        return timeit
    return wrapper

def profile(title):
    def wrapper(f):
        def printProfile(*args):
            lp = LineProfiler()
            dec_f = lp(f)
            output_value = dec_f(*args)
            print("Line Profile for:",title)
            print("----------------------")
            lp.print_stats()
            return output_value
        return printProfile
    return wrapper
##############################################################

class Database():
    def __init__(self):
        self.client = MongoClient()
        self.checkConnection()
    ## Database utilities
    ## I do not want the user to accidentally delete all their data
    # def removeDataBase(self,folder_name):
        # self.client.drop_database(folder_name)

    def removeFolder(self,database_name,folder_name):
        self.client[database_name][folder_name].drop()

    def viewDataBase(self,database_name):
        '''
        show all collections in a database
        '''
        # include include_system_collections=False?
62 | for collection in self.client[database_name].list_collection_names(): 63 | print(collection) 64 | def getAllFolderIteratorsFromDatabase(self,database_name): 65 | folder_iterators_list= [] 66 | folder_names = self.client[database_name].list_collection_names() 67 | for folder_name in folder_names: 68 | iterator = self.client[database_name][folder_name].find() 69 | folder_iterators_list.append(iterator) 70 | 71 | return folder_iterators_list 72 | 73 | 74 | def viewFolder(self,database_name,folder_name): 75 | ''' 76 | show all documents in a collection 77 | ''' 78 | for doc in self.client[database_name][folder_name].find(): 79 | print(doc) 80 | def close(self): 81 | self.client.close() 82 | 83 | ## Connection utilities, not meant to be used by user 84 | def checkConnection(self): 85 | t = Thread(target=self.testInsert) 86 | t.start() 87 | t.join(2) 88 | if t.is_alive(): 89 | raise Exception("Cannot connect to MongoDB") 90 | 91 | def testInsert(self): 92 | doc = self.client['test_db']['test_collection'] 93 | doc.insert_one({"Test":1}) 94 | doc.delete_many({"Test":1}) 95 | ################ **Misc** ################## 96 | from functools import partial 97 | def partial_decomaker(partial_name): 98 | def decorator(func): 99 | partial_func = partial(func,partial_name=partial_name) 100 | return partial_func 101 | return decorator 102 | 103 | from inspect import getsource 104 | def code(function): 105 | print(getsource(function)) 106 | 107 | ################ **Functions used to load Data in** ################## 108 | def getParamDict(database_name,folder_name): 109 | mongo = Database() 110 | runs = mongo.client[database_name][folder_name] 111 | ## all the runs in the folder 112 | runs_iterator = runs.find() 113 | 114 | dict_of_dicts = {} 115 | for run_object in runs_iterator: 116 | Experimental_Parameters = run_object['Experimental Parameters'] 117 | time = Experimental_Parameters['Time'] 118 | dict_of_dicts[time] = Experimental_Parameters 119 | return dict_of_dicts 120 | def 
getLegendNames(dict_of_param_dicts): 121 | list_of_param_names = [] 122 | for time,plot_dict in dict_of_param_dicts.items(): 123 | list_of_param_names.append(plot_dict.keys()) 124 | legend_names = sorted(set(list(itertools.chain(*list_of_param_names)))) 125 | return legend_names 126 | ## Object Related 127 | def getDictOfNameObjects(database_name,folder_name,name,f=None): 128 | mongo = Database() 129 | runs = mongo.client[database_name][folder_name] 130 | ## all the runs in the folder 131 | runs_iterator = runs.find() 132 | 133 | nameObjects_for_each_run = {} 134 | # paramObjects_for_each_run = {} 135 | for run_object in runs_iterator: 136 | Experimental_Parameters = run_object['Experimental Parameters'] 137 | time = Experimental_Parameters['Time'] 138 | # param_objects_for_each_run[time] = Experimental_Parameters 139 | 140 | try: 141 | one_run_dict = run_object[name] 142 | if f: 143 | one_run_dict = f(one_run_dict) 144 | nameObjects_for_each_run[time] = one_run_dict 145 | except KeyError: 146 | print("Name does not exist in the run") 147 | mongo.close() 148 | # return nameObjects_for_each_run, paramObjects_for_each_run 149 | return nameObjects_for_each_run 150 | def getBase64Encoding(one_run_dict): 151 | return {image_name:binaryToBase64(binary_image) for image_name,binary_image in one_run_dict.items()} 152 | def binaryToBase64(binary_image): 153 | numpy_matrix=pickle.loads(binary_image) 154 | img = Image.fromarray(np.uint8(numpy_matrix*255),'L') 155 | # base64_string= base64.b64encode(numpy_matrix) 156 | buff = BytesIO() 157 | img.save(buff, format="JPEG") 158 | base64_string = base64.b64encode(buff.getvalue()) 159 | buff.close() 160 | return base64_string.decode('ascii') 161 | def getFigureNames(nameObjects_for_each_run): 162 | list_of_names = [] 163 | for time, one_run_dict in nameObjects_for_each_run.items(): 164 | list_of_names.append(one_run_dict.keys()) 165 | names = sorted(set(list(itertools.chain(*list_of_names)))) 166 | return names 167 | 168 | 
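Both `getLegendNames` and `getFigureNames` in ml_board/utils.py compute the same thing: the sorted union of the keys found in each run's inner dictionary. The sketch below isolates that step with the standard library only. The helper name `union_of_keys` and the sample `runs` data are hypothetical and are not part of ml_board.

```python
import itertools


def union_of_keys(dict_of_dicts):
    """Return the sorted union of the keys of every inner dictionary."""
    # One keys() view per run, e.g. dict_keys(['Time', 'Learning Rate'])
    key_views = [inner.keys() for inner in dict_of_dicts.values()]
    # chain.from_iterable flattens the views; set() removes duplicates
    return sorted(set(itertools.chain.from_iterable(key_views)))


# Hypothetical runs keyed by timestamp, each logging different parameters
runs = {
    "2018-07-01 10:00:00": {"Time": "2018-07-01 10:00:00", "Learning Rate": 0.01},
    "2018-07-02 11:30:00": {"Time": "2018-07-02 11:30:00", "Neurons": 64},
}

legend_names = union_of_keys(runs)
print(legend_names)  # ['Learning Rate', 'Neurons', 'Time']
```

Taking the union means a run that never logged a given parameter still gets a column or legend slot for it, at the cost of having to handle the missing value wherever the names are consumed.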
############################################################## 169 | def createHTMLRowList(self): 170 | html_row_list = [] 171 | for time in self.ordered_thoughtList_keys: 172 | thought_list = self.dict_of_all_thought_lists[time] 173 | 174 | title_row = createThoughtsTitle(thought_list,time) 175 | html_row_list.append(title_row) 176 | 177 | paragraph_for_each_thought = createThoughts(thought_list) 178 | paragraph_row = html.Div(paragraph_for_each_thought,className='row') 179 | html_row_list.append(paragraph_row) 180 | return html_row_list 181 | ## Only takes 0.1 seconds, so there is no issue in updating it 182 | # @profile("Thoughts") 183 | def getDictOfAllThoughtLists(database_name): 184 | mongo = Database() 185 | folder_iterators_list = mongo.getAllFolderIteratorsFromDatabase(database_name) 186 | database_dict = {} 187 | for folder_iterator in folder_iterators_list: 188 | dict_of_thoughtlists = getDictOfThoughtLists(folder_iterator) 189 | database_dict.update(dict_of_thoughtlists) 190 | mongo.close() 191 | return database_dict 192 | 193 | ######################### 194 | def getDictOfThoughtLists(folder_iterator): 195 | dict_of_thoughtlists = {} 196 | for run_object in folder_iterator: 197 | Experimental_Parameters = run_object['Experimental Parameters'] 198 | time = Experimental_Parameters['Time'] 199 | try: 200 | thought_list = run_object['Thoughts'] 201 | ## the extra self.folder_name entries are skipped later, in createThoughts 202 | dict_of_thoughtlists[time]=thought_list 203 | except KeyError: 204 | print("Run object does not have 'Thoughts' as a key") 205 | 206 | return dict_of_thoughtlists 207 | ######################### 208 | def getOrderedKeys(dict_of_thoughtlists): 209 | return sorted(dict_of_thoughtlists.keys()) 210 | 211 | def createThoughts(list_of_thoughts): 212 | paragraph_list = [] 213 | ## skipping the folder_names 214 | for thought in list_of_thoughts[1::2]: 215 | paragraph = html.P(thought) 216 | paragraph_list.append(paragraph) 217 | return paragraph_list 218 | 219 | def 
createThoughtsTitle(list_of_thoughts,time): 220 | folder_name = list_of_thoughts[0] 221 | ## No need for year and seconds 222 | title_row = html.Div(html.B(time[5:-3]+': '+folder_name),className='row') 223 | return title_row 224 | 225 | ############################################################## 226 | 227 | ################ **Functions used During Callbacks** ################## 228 | def getSelectedRunsFromDatatable(rows,selected_row_indices): 229 | if selected_row_indices==[]: 230 | selected_runs= rows 231 | else: 232 | selected_runs = [rows[i] for i in selected_row_indices] 233 | return [run_dict['Time'] for run_dict in selected_runs] 234 | 235 | 236 | 237 | 238 | if __name__ == '__main__': 239 | database = Database() 240 | database.client['test_db']['test_collection'].insert_one({"Test":"test"}) 241 | database.viewFolder('test_db','test_collection') 242 | database.removeFolder('test_db','test_collection') 243 | database.viewFolder('test_db','test_collection') 244 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | setuptools==39.2.0 2 | dash_core_components==0.21.0rc1 3 | dash-renderer==0.13.0 4 | dash_table_experiments==0.6.0 5 | plotly==3.0.0 6 | dash_html_components==0.11.0 7 | pymongo==3.7.0 8 | numpy==1.14.5 9 | dash==0.21.1 10 | Pillow==5.2.0 11 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setup(name='ml_board', 7 | version='0.0.1', 8 | description="A machine learning dashboard that displays hyperparameter settings alongside visualizations, and logs the scientist's thoughts throughout the training process", 9 | long_description = long_description, 10 | 
long_description_content_type="text/markdown", 11 | url='http://github.com/bbli/ml_board', 12 | author='Benson Li', 13 | scripts=['bin/ml_board'], 14 | author_email='bensonbinbinli@gmail.com', 15 | license='MIT', 16 | packages=['ml_board'], 17 | install_requires=[ 18 | 'dash_core_components==0.21.0rc1', 19 | 'dash-renderer', 20 | 'plotly', 21 | 'dash_html_components', 22 | 'pymongo', 23 | 'numpy', 24 | 'dash', 25 | 'Pillow', 26 | 'dash_table_experiments==0.6.0' 27 | ], 28 | classifiers=[ 29 | "Programming Language :: Python :: 3", 30 | "License :: OSI Approved :: MIT License", 31 | "Operating System :: OS Independent", 32 | ] 33 | ) 34 | --------------------------------------------------------------------------------
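The thought log in this repo relies on an implicit layout: `add_thought` in ml_board/Logger.py pushes the folder name and then the thought itself, so the stored list alternates between the two, and `createThoughts` / `createThoughtsTitle` in ml_board/utils.py undo that with `[1::2]` and `[0]`. The sketch below reproduces that round trip without MongoDB or Dash. The helper name `split_thought_log` is hypothetical, and the timestamp format `YYYY-MM-DD HH:MM:SS` is an assumption inferred from the "No need for year and seconds" comment, since the code that sets `self.date` is not shown here.

```python
def split_thought_log(thought_list, time):
    """Split an interleaved [folder, thought, folder, thought, ...] list."""
    # Even positions hold the folder name; the first one is enough for a title
    folder_name = thought_list[0]
    # Odd positions hold the thoughts themselves
    thoughts = thought_list[1::2]
    # Drop the leading "YYYY-" and the trailing ":SS" (assumed format)
    title = time[5:-3] + ': ' + folder_name
    return title, thoughts


# Emulate two add_thought calls: each pushes the folder name, then the text
log = []
for thought in ["hi", "hello"]:
    log.append("test_collection")
    log.append(thought)

title, thoughts = split_thought_log(log, "2018-07-04 12:30:45")
print(title)     # 07-04 12:30: test_collection
print(thoughts)  # ['hi', 'hello']
```

Storing the folder name next to every thought is redundant, but it lets the reader side recover the folder from any run document without a second query. A single `Folder` field on the document would be a leaner alternative.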