├── README.md
├── Visual_UI.ipynb
├── Visual_UI2.ipynb
├── _config.yml
├── environment.yml
├── paperspace_ui.py
├── static
│   ├── CM_FN.PNG
│   ├── CM_FP.PNG
│   ├── CM_TN.PNG
│   ├── CM_TP.PNG
│   ├── CM_eight.PNG
│   ├── CM_five.PNG
│   ├── CM_four.PNG
│   ├── CM_nine.PNG
│   ├── CM_one.PNG
│   ├── CM_seven.PNG
│   ├── CM_six.PNG
│   ├── CM_three.PNG
│   ├── CM_two.PNG
│   ├── LR.PNG
│   ├── LR_one.PNG
│   ├── LR_three.PNG
│   ├── LR_two.PNG
│   ├── Lr_four.PNG
│   ├── aug_one.PNG
│   ├── aug_one2.PNG
│   ├── aug_three.PNG
│   ├── aug_three2.PNG
│   ├── aug_three3.PNG
│   ├── aug_two.PNG
│   ├── aug_two2.PNG
│   ├── aug_two3.PNG
│   ├── batch.PNG
│   ├── batch_three.PNG
│   ├── batch_two.PNG
│   ├── cm_class.PNG
│   ├── data.PNG
│   ├── data2.PNG
│   ├── heatmap3.PNG
│   ├── info.PNG
│   ├── info_dashboard.PNG
│   ├── metrics.PNG
│   ├── model.PNG
│   ├── visionUI2_part1.gif
│   ├── visionUI2_part2.gif
│   └── visionUI2_part3.gif
├── viola_test.ipynb
├── vision_ui.py
├── vision_ui2.py
└── xresnet2.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
### Vision_UI
Graphical user interface for fastai

[![GitHub license](https://img.shields.io/github/license/Naereen/StrapDown.js.svg)](https://github.com/Naereen/StrapDown.js/blob/master/LICENSE) ![](https://github.com/fastai/nbdev/workflows/CI/badge.svg)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1O_H41XhABAEQxg_p8KZd_BCQ8pj-eJX6) currently only works with **version 1**

Visual UI adds a graphical interface to fastai, allowing the user to quickly load data, choose parameters, train, and view results without needing to dig deep into the code.

________________________________________________________________________________________________________________________________________

### Updates

#### 06/03/2020
- Can now be installed with `pip`: [fast-gui](https://pypi.org/project/fast-gui/)

#### 03/17/2020
- Update for compatibility with [fastai2](https://github.com/fastai/fastai2)
- Files: `Visual_UI2.ipynb` and `vision_ui2.py`

#### Updates below are for version 1
Files: `Visual_UI.ipynb` and `vision_ui.py`

#### 12/23/2019
- Added `ImageDataBunch.from_csv`
- Added augmentations: cutout, jitter, contrast, brightness, rotate, symmetric warp and padding
- Added the ClassConfusion widget
- Added a 'Code' tab for viewing the code

#### 11/12/2019
- Under the 'Info' tab you can now easily load some common datasets: Cats&Dogs, Imagenette, Imagewoof, Cifar and Mnist

- Under the 'Results' tab you can now view the confusion matrix when there are more than 2 classes, although the confusion matrix upgrades still only work with 2 classes
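
With more than two classes, the confusion matrix is an n×n table of counts: one row per actual class and one column per predicted class. Below is a minimal pure-Python sketch of that counting step. It assumes the labels are already integer class indices. `confusion_matrix` is a hypothetical helper, not the fastai v1 code behind `ClassificationInterpretation.plot_confusion_matrix`.

```python
from typing import List, Sequence

def confusion_matrix(actual: Sequence[int], predicted: Sequence[int], n_classes: int) -> List[List[int]]:
    """Count (actual, predicted) pairs: rows are actual classes, columns are predicted classes."""
    if len(actual) != len(predicted):
        raise ValueError("actual and predicted must have the same length")
    cm = [[0] * n_classes for _ in range(n_classes)]
    for a, p in zip(actual, predicted):
        cm[a][p] += 1
    return cm

# Toy labels for three classes (illustrative only, not real results)
cm = confusion_matrix([0, 1, 2, 2, 1], [0, 2, 2, 2, 1], n_classes=3)
for row in cm:
    print(row)
```

Correct predictions sit on the diagonal, so `sum(cm[i][i] for i in range(3))` counts them (4 of the 5 toy predictions here).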

#### 10/12/2019 - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1O_H41XhABAEQxg_p8KZd_BCQ8pj-eJX6)
- Works with Google Colab via [Colab_UI](https://github.com/asvcode/Colab_UI); the Results tab is not currently available in Colab

#### 09/25/2019 - xresnet architecture
- xresnet architectures now work (using xresnet2.py from fastai)

#### 09/12/2019 - Confusion Matrix Upgrades (currently only works if there are 2 classes)
- Under the Results tab, the confusion matrix tab now includes enhanced viewing features:

> Option to view images with or without heatmaps

> Option to view images within each section of the matrix

> If the heatmap option is 'YES', you can choose the colormap, interpolation and alpha parameters

> Examples of viewing images with different parameters

> You can also view the images without the heatmap feature. Images within each matrix class display their Index, Actual_Class, Predicted_Class, Prediction value, Loss and Image location

> Images are stored within the path folder under their respective confusion matrix tags

> View saved image files from various sections of the confusion matrix and compare their heatmap images.

*Example images: False Positive, True Positive, True Negative, False Negative*
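
In the two-class case, each prediction falls into exactly one confusion matrix quadrant. Below is a minimal sketch of tagging predictions and grouping their indices by quadrant. It assumes class `1` is the positive class. `cm_tag` and `group_by_tag` are hypothetical names, not Vision_UI functions, and the sketch does not reproduce the folder layout Vision_UI writes to disk.

```python
from collections import defaultdict
from typing import Dict, List, Sequence

def cm_tag(actual: int, predicted: int, positive: int = 1) -> str:
    """Return the quadrant ('TP', 'FP', 'TN' or 'FN') for a single prediction."""
    if predicted == positive:
        return "TP" if actual == positive else "FP"
    return "FN" if actual == positive else "TN"

def group_by_tag(actual: Sequence[int], predicted: Sequence[int], positive: int = 1) -> Dict[str, List[int]]:
    """Map each quadrant tag to the indices of the predictions that fall into it."""
    groups: Dict[str, List[int]] = defaultdict(list)
    for idx, (a, p) in enumerate(zip(actual, predicted)):
        groups[cm_tag(a, p, positive)].append(idx)
    return dict(groups)

print(group_by_tag([1, 0, 1, 0], [1, 1, 0, 0]))
```

Keeping the index under its tag makes it easy to look up each image's loss or file location later and to save or display each quadrant separately.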

#### 07/09/2019
- After a training run, the model is saved in the models folder under a name built from the architecture, the pretrained flag, the batch size and the image size, e.g. `resnet50_pretrained_True_batch_32_image_128.pth`
- Updated the tkinter askdirectory code: the tkinter dialog box is now destroyed after a file is chosen (previously the box would remain open)

#### 06/05/2019
- Added a Results tab where you can load your saved model and plot multi_plot_losses, top_losses and the confusion matrix

#### 06/03/2019
- path and image_path (for augmentations) are now set within vision_ui, so there is no need for a separate cell to specify the path
- Included links to the fastai docs and forum in the 'Info' tab

________________________________________________________________________________________________________________________________________

All tabs are provided within an accordion design using ipywidgets, which keeps every aspect of choosing and viewing parameters in one line of sight
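
The 07/09/2019 entry above describes how saved models are named. Below is a sketch of building that name. `model_save_name` is a hypothetical helper, and the exact code Vision_UI uses is not shown here. The sizes are cast to `int` because the batch-size and image-size controls in `paperspace_ui.py` are `FloatSlider`s, which would otherwise produce names like `batch_32.0`.

```python
def model_save_name(arch: str, pretrained: bool, batch_size: float, image_size: float, ext: str = ".pth") -> str:
    """Build a checkpoint name such as resnet50_pretrained_True_batch_32_image_128.pth."""
    return f"{arch}_pretrained_{pretrained}_batch_{int(batch_size)}_image_{int(image_size)}{ext}"

print(model_save_name("resnet50", True, 32, 128))
```

fastai v1's `Learner.save` appends `.pth` itself and writes into the learner's `models` folder. When handing the name to `learn.save`, pass `ext=""`.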

The Augmentation tab exposes the parameters of fastai's `get_transforms`, so you can view and compare what different image augmentations look like
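
The tab's toggle buttons return the strings 'Yes' and 'No'. Any non-empty string is truthy in Python, so these values must be converted to booleans before they reach `get_transforms`. Below is a sketch of collecting the tab's settings into keyword arguments, assuming the widget values have already been read into plain Python values. `yes_no` and `AugSettings` are hypothetical helpers. The keyword names are the fastai v1 `get_transforms` parameters used in `paperspace_ui.py`, and the numeric defaults mirror that file's slider defaults.

```python
from dataclasses import dataclass, asdict

def yes_no(value: str) -> bool:
    """Convert a 'Yes'/'No' toggle value to a bool (the string 'No' is truthy on its own)."""
    if value not in ("Yes", "No"):
        raise ValueError(f"expected 'Yes' or 'No', got {value!r}")
    return value == "Yes"

@dataclass
class AugSettings:
    do_flip: bool = True
    flip_vert: bool = False
    max_rotate: float = 10.0
    max_zoom: float = 1.1
    max_lighting: float = 0.2
    max_warp: float = 0.2
    p_affine: float = 0.75
    p_lighting: float = 0.75

    def to_kwargs(self) -> dict:
        """Keyword arguments shaped for get_transforms(**kwargs)."""
        return asdict(self)

settings = AugSettings(do_flip=yes_no("Yes"), flip_vert=yes_no("No"), max_rotate=20.0)
print(settings.to_kwargs())
# In a fastai v1 session this could then be used as: tfms = get_transforms(**settings.to_kwargs())
```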

View a batch of images with the chosen augmentations and normalization applied
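
Before a batch is shown, the data is normalized with the statistics selected under 'Normalization'. Each channel has its mean subtracted and is then divided by its standard deviation. The sketch below uses the per-channel values that `stats_info()` in `paperspace_ui.py` assigns. `STATS` and `normalize_pixel` are hypothetical names, and the 'Custom' option (no fixed statistics) is left out.

```python
from typing import Dict, List, Sequence, Tuple

# (mean, std) per RGB channel, as assigned in stats_info()
STATS: Dict[str, Tuple[List[float], List[float]]] = {
    "Imagenet": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    "Cifar": ([0.491, 0.482, 0.447], [0.247, 0.243, 0.261]),
    "Mnist": ([0.15, 0.15, 0.15], [0.15, 0.15, 0.15]),
}

def normalize_pixel(rgb: Sequence[float], choice: str = "Imagenet") -> List[float]:
    """Normalize one RGB pixel (values in [0, 1]) channel by channel: (x - mean) / std."""
    mean, std = STATS[choice]
    return [(x - m) / s for x, m, s in zip(rgb, mean, std)]

print(normalize_pixel([0.485, 0.456, 0.406]))  # the ImageNet mean maps to zeros
```

In fastai v1, `data.normalize(stats)` takes the same (mean, std) pair. Passing `None` instead makes it compute statistics from a batch of your own data.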

Review model information and choose suitable metrics for training
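
The metrics tab turns a set of Yes/No toggles into the list of metric functions passed to the learner. Below is a compact sketch of that selection step. It uses stand-in metric functions so it runs without fastai; in a real session the values would be fastai v1's `error_rate`, `accuracy` and so on. `AVAILABLE` and `select_metrics` are hypothetical names.

```python
from typing import Callable, Dict, List

# Stand-ins for fastai metrics; here they only need to be named callables.
def error_rate(preds, targs): ...
def accuracy(preds, targs): ...
def top_k_accuracy(preds, targs): ...

AVAILABLE: Dict[str, Callable] = {
    "Error Choice": error_rate,
    "Accuracy": accuracy,
    "Top K": top_k_accuracy,
}

def select_metrics(toggles: Dict[str, str]) -> List[Callable]:
    """Keep the metrics whose toggle is 'Yes', in the order listed in AVAILABLE."""
    return [fn for name, fn in AVAILABLE.items() if toggles.get(name) == "Yes"]

chosen = select_metrics({"Error Choice": "Yes", "Accuracy": "No", "Top K": "Yes"})
print([fn.__name__ for fn in chosen])
```

A single dictionary keeps the label-to-metric mapping in one place instead of one `if` block per metric.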

Review the parameters, find a learning rate and train using the one cycle policy
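
The one cycle policy (fastai's `fit_one_cycle`) ramps the learning rate up from a low value to `lr_max`, then anneals it down to well below its starting point. Below is a pure-Python sketch of that schedule's shape, using cosine interpolation for both phases. The defaults (`pct_start=0.3`, `div_factor=25`, a final value of `lr_max / (div_factor * 1e4)`) are assumed to follow fastai v1. This is an illustration with hypothetical function names, not fastai's implementation, and it omits momentum cycling.

```python
import math
from typing import List, Optional

def cos_anneal(start: float, end: float, pct: float) -> float:
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle_lrs(lr_max: float, n_iters: int, pct_start: float = 0.3,
                  div_factor: float = 25.0, final_div: Optional[float] = None) -> List[float]:
    """Per-iteration learning rates: warm up to lr_max, then anneal back down."""
    if final_div is None:
        final_div = div_factor * 1e4
    lr_start, lr_end = lr_max / div_factor, lr_max / final_div
    n_up = max(1, int(n_iters * pct_start))
    n_down = n_iters - n_up
    warmup = [cos_anneal(lr_start, lr_max, i / n_up) for i in range(n_up)]
    anneal = [cos_anneal(lr_max, lr_end, i / max(1, n_down - 1)) for i in range(n_down)]
    return warmup + anneal

lrs = one_cycle_lrs(1e-3, n_iters=100)
print(f"start={lrs[0]:.2e} peak={max(lrs):.2e} end={lrs[-1]:.2e}")
```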

Experiment with various learning rates and train
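
Choosing a learning rate usually starts with the learning rate finder (`learn.lr_find()` followed by `learn.recorder.plot()` in `paperspace_ui.py`). The finder trains briefly while raising the learning rate exponentially and records the loss at each step. The sketch below reproduces the exponential sweep and one common heuristic: pick the rate where the loss falls fastest. This is similar in spirit to the suggestion fastai v1 shows with `recorder.plot(suggestion=True)`. The default range (1e-7 to 10 over 100 iterations) is assumed to match fastai v1's `lr_find`. The function names and the toy losses are illustrative only.

```python
import math
from typing import List, Sequence

def lr_range(start_lr: float = 1e-7, end_lr: float = 10.0, num_it: int = 100) -> List[float]:
    """Exponentially spaced learning rates from start_lr to end_lr (inclusive)."""
    ratio = end_lr / start_lr
    return [start_lr * ratio ** (i / (num_it - 1)) for i in range(num_it)]

def steepest_drop_lr(lrs: Sequence[float], losses: Sequence[float]) -> float:
    """Return the learning rate at which the loss falls fastest with respect to log(lr)."""
    best_i, best_slope = 1, math.inf
    for i in range(1, len(lrs)):
        slope = (losses[i] - losses[i - 1]) / (math.log(lrs[i]) - math.log(lrs[i - 1]))
        if slope < best_slope:
            best_i, best_slope = i, slope
    return lrs[best_i]

lrs = lr_range()
print(f"{len(lrs)} rates from {lrs[0]:.0e} to {lrs[-1]:.0e}")
# Toy losses (not real measurements) recorded at four learning rates
print(steepest_drop_lr([1e-4, 1e-3, 1e-2, 1e-1], [2.0, 1.9, 1.0, 3.0]))
```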

### Requirements
- fastai

I am using the developer version:

`git clone https://github.com/fastai/fastai`

`cd fastai`

`tools/run-after-git-clone`

`pip install -e ".[dev]"`

For installation instructions, visit the [Fastai Installation Readme](https://github.com/fastai/fastai/blob/master/README.md#installation)

- ipywidgets

`pip install ipywidgets`

`jupyter nbextension enable --py widgetsnbextension`

or

`conda install -c conda-forge ipywidgets`

For installation instructions, visit the [Installation docs](https://ipywidgets.readthedocs.io/en/stable/user_install.html)

- psutil

psutil (process and system utilities) is a cross-platform library for retrieving information on running processes and system utilization (CPU, memory, disks, network, sensors) in Python

`pip install psutil`


### Installation

git clone this repository

`git clone https://github.com/asvcode/Vision_UI.git`

open `Visual_UI.ipynb` and run `display_ui()`


### Known Issues

- **Colab** - version 1 works with Colab via [Colab_UI](https://github.com/asvcode/Colab_UI) but is glitchy

### Future Work

- ~~Integrate into fastai v2~~ - Compatibility with fastai v2 is done, but without the full functionality of the v1 version
- ~~Create pip install version~~ - Done! [fast-gui](https://pypi.org/project/fast-gui/)
- Include full functionality in the v2 version
- Create a v2 that is compatible with Colab

--------------------------------------------------------------------------------
/Visual_UI.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "## Visual UI for fastai ##\n",
    "##########################\n",
    "\n",
    "from vision_ui import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6639f7fa54964e988ad176392b0df4cb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Tab(children=(Output(), Output(), Output(), Output(), Output(), Output(), Output(), Output(), Output()), _titl…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "display_ui()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

--------------------------------------------------------------------------------
/Visual_UI2.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "###2 cells below are scripts for improving display appearance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": [
       "require(\n",
       " ['notebook/js/outputarea'],\n",
       " function(oa) {\n",
       " oa.OutputArea.auto_scroll_threshold = -1;\n",
       " console.log(\"Setting auto to -1\")\n",
       " });\n"
      ],
      "text/plain": [
       ""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%%javascript\n",
    "require(\n",
    " ['notebook/js/outputarea'],\n",
    " function(oa) {\n",
    " oa.OutputArea.auto_scroll_threshold = -1;\n",
    " console.log(\"Setting auto to -1\")\n",
    " });"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": [
       "(function(on) {\n",
       "const e=$( \"Setup failed\" );\n",
       "const ns=\"js_jupyter_suppress_warnings\";\n",
       "var cssrules=$(\"#\"+ns);\n",
       "if(!cssrules.length) cssrules = $(\"\").appendTo(\"head\");\n",
       "e.click(function() {\n",
       " var s='Showing'; \n",
       " cssrules.empty()\n",
       " if(on) {\n",
       " s='Hiding';\n",
       " cssrules.append(\"div.output_stderr, div[data-mime-type*='.stderr'] { display:none; }\");\n",
       " }\n",
       " e.text(s+' warnings (click to toggle)');\n",
       " on=!on;\n",
       "}).click();\n",
       "$(element).append(e);\n",
       "})(true);\n"
      ],
      "text/plain": [
       ""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%%javascript\n",
    "(function(on) {\n",
    "const e=$( \"Setup failed\" );\n",
    "const ns=\"js_jupyter_suppress_warnings\";\n",
    "var cssrules=$(\"#\"+ns);\n",
    "if(!cssrules.length) cssrules = $(\"\").appendTo(\"head\");\n",
    "e.click(function() {\n",
    " var s='Showing'; \n",
    " cssrules.empty()\n",
    " if(on) {\n",
    " s='Hiding';\n",
    " cssrules.append(\"div.output_stderr, div[data-mime-type*='.stderr'] { display:none; }\");\n",
    " }\n",
    " e.text(s+' warnings (click to toggle)');\n",
    " on=!on;\n",
    "}).click();\n",
    "$(element).append(e);\n",
    "})(true);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "196a73bfd83c424f8cba5ebce1f70877",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Tab(children=(Output(), Output(), Output(), Output(), Output(), Output()), _titles={'0': 'Info', '1': 'Data', …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from vision_ui2 import *\n",
    "display_ui()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
theme: jekyll-theme-slate
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: fastai2
channels:
  - fastai
  - pytorch
  - defaults
dependencies:
  - jupyter
  - pytorch>=1.3.0
  - torchvision>=0.5
  - matplotlib
  - pandas
  - requests
  - pyyaml
  - fastprogress>=0.1.22
  - pillow
  - python>=3.6
  - pip
  - scikit-learn
  - scipy
  - spacy
  - voila

--------------------------------------------------------------------------------
/paperspace_ui.py:
--------------------------------------------------------------------------------
"""
Paperspace_UI based on Vision_UI
Visual graphical interface for Fastai

Last Update: 10/12/2019
https://github.com/asvcode/Vision_UI
"""

import sys

import torch
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets
import ipywidgets as widgets
import IPython
from IPython.display import display,clear_output

import webbrowser
from IPython.display import YouTubeVideo

from fastai.vision import *
from fastai.widgets import *
from fastai.callbacks import *

def version():
    import fastai
    import psutil

    print ('>> Paperspace version')

    button = widgets.Button(description='System')
    but = widgets.HBox([button])
    display(but)

    out = widgets.Output()
    display(out)

    def on_button_clicked_info(b):
        with out:
            clear_output()
            print(f'Fastai Version: {fastai.__version__}')
            print(f'Cuda: {torch.cuda.is_available()}')
            print(f'GPU: {torch.cuda.get_device_name(0)}')
            print(f'Python version: {sys.version}')
            print(psutil.cpu_percent())
            print(psutil.virtual_memory()) # physical memory usage
            print('memory % used:', psutil.virtual_memory()[2])

    button.on_click(on_button_clicked_info)

def dashboard_one():
    style = {'description_width': 'initial'}

    print('>> Currently only works with files FROM_FOLDERS' '\n')
    dashboard_one.datain = widgets.ToggleButtons(
        options=['from_folder'],
        description='Data In:',
        disabled=True,
        button_style='success', # 'success', 'info', 'warning', 'danger' or ''
        tooltips=['Data in folder', 'Data in csv format - NOT ACTIVE', 'Data in dataframe - NOT ACTIVE'],
    )
    dashboard_one.norma = widgets.ToggleButtons(
        options=['Imagenet', 'Custom', 'Cifar', 'Mnist'],
        description='Normalization:',
        disabled=False,
        button_style='info', # 'success', 'info', 'warning', 'danger' or ''
        tooltips=['Imagenet stats', 'Create your own', 'Cifar stats', 'Mnist stats'],
        style=style
    )
    dashboard_one.archi = widgets.ToggleButtons(
        options=['alexnet', 'BasicBlock', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'resnet18',
                 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'squeezenet1_0', 'squeezenet1_1', 'vgg16_bn',
                 'vgg19_bn', 'xresnet18', 'xresnet34', 'xresnet50', 'xresnet101', 'xresnet152'],
        description='Architecture:',
        disabled=False,
        button_style='', # 'success', 'info', 'warning', 'danger' or ''
        tooltips=[],
    )
    layout = widgets.Layout(width='auto', height='40px') #set width and height

    xres_text = widgets.Button(
        description='For xresnet models: these are not pretrained, so UNCHECK the Pretrained box to avoid errors.',
        disabled=True,
        display='flex',
        flex_flow='column',
        align_items='stretch',
        layout = layout
    )
    dashboard_one.pretrain_check = widgets.Checkbox(
        options=['Yes', "No"],
        description='Pretrained:',
        disabled=False,
        value=True,
        box_style='success',
        button_style='lightgreen', # 'success', 'info', 'warning', 'danger' or ''
        tooltips=['Default: Checked = use pretrained weights, Unchecked = No pretrained weights'],
    )

    layout = {'width':'90%', 'height': '50px', 'border': 'solid', 'fontcolor':'lightgreen'}
    layout_two = {'width':'100%', 'height': '200px', 'border': 'solid', 'fontcolor':'lightgreen'}
    style_green = {'handle_color': 'green', 'readout_color': 'red', 'slider_color': 'blue'}
    style_blue = {'handle_color': 'blue', 'readout_color': 'red', 'slider_color': 'blue'}
    dashboard_one.f=widgets.FloatSlider(min=8,max=64,step=8,value=32, continuous_update=False, layout=layout, style=style_green, description="Batch size")
    dashboard_one.m=widgets.FloatSlider(min=0, max=360, step=16, value=128, continuous_update=False, layout=layout, style=style_green, description='Image size')


    display(dashboard_one.datain, dashboard_one.norma, dashboard_one.archi, xres_text, dashboard_one.pretrain_check, dashboard_one.f, dashboard_one.m)

def dashboard_two():
    button = widgets.Button(description="View")
    print ('>> Choose image to view augmentations:')

    image_choice()
    print('Augmentations')

    layout = {'width':'90%', 'height': '50px', 'border': 'solid', 'fontcolor':'lightgreen'}
    layout_two = {'width':'100%', 'height': '200px', 'border': 'solid', 'fontcolor':'lightgreen'}
    style_green = {'handle_color': 'green', 'readout_color': 'red', 'slider_color': 'blue'}
    style_blue = {'handle_color': 'blue', 'readout_color': 'red', 'slider_color': 'blue'}

    dashboard_two.doflip = widgets.ToggleButtons(
        options=['Yes', "No"],
        description='Do Flip:',
        disabled=False,
        button_style='success', # 'success', 'info', 'warning', 'danger' or ''
        tooltips=['Description of slow', 'Description of regular', 'Description of fast'],
    )
    dashboard_two.dovert = widgets.ToggleButtons(
        options=['Yes', "No"],
        description='Do Vert:',
        disabled=False,
        button_style='info', # 'success', 'info', 'warning', 'danger' or ''
        tooltips=['Description of slow', 'Description of regular', 'Description of fast'],
    )
    dashboard_two.two = widgets.FloatSlider(min=0,max=20,step=1,value=10, description='Max Rotate', orientation='vertical', style=style_green, layout=layout_two)
    dashboard_two.three = widgets.FloatSlider(min=1.1,max=4,step=1,value=1.1, description='Max Zoom', orientation='vertical', style=style_green, layout=layout_two)
    dashboard_two.four = widgets.FloatSlider(min=0.25, max=1.0, step=0.1, value=0.75, description='p_affine', orientation='vertical', style=style_green, layout=layout_two)
    dashboard_two.five = widgets.FloatSlider(min=0.2,max=0.99, step=0.1,value=0.2, description='Max Lighting', orientation='vertical', style=style_blue, layout=layout_two)
    dashboard_two.six = widgets.FloatSlider(min=0.25, max=1.1, step=0.1, value=0.75, description='p_lighting', orientation='vertical', style=style_blue, layout=layout_two)
    dashboard_two.seven = widgets.FloatSlider(min=0.1, max=0.9, step=0.1, value=0.2, description='Max warp', orientation='vertical', style=style_green, layout=layout_two)

    ui2 = widgets.VBox([dashboard_two.doflip, dashboard_two.dovert])
    ui = widgets.HBox([dashboard_two.two,dashboard_two.three, dashboard_two.seven, dashboard_two.four,dashboard_two.five, dashboard_two.six])
    ui3 = widgets.HBox([ui2, ui])

    display (ui3)

    print ('>> Press button to view augmentations. Pressing the button again will let you view additional augmentations below')
    display(button)

    def on_button_clicked(b):
        image_path = str(image_choice.output_variable.value)
        print('>> Displaying augmentations')
        display_augs(image_path)

    button.on_click(on_button_clicked)

def get_image(image_path):
    print(image_path)

def display_augs(image_path):
    get_image(image_path)
    image_d = open_image(image_path)
    print(image_d)
    def get_ex(): return open_image(image_path)

    out_flip = dashboard_two.doflip.value == 'Yes' # do flip (toggle returns 'Yes'/'No' strings, 'No' is truthy)
    out_vert = dashboard_two.dovert.value == 'Yes' # do vert
    out_rotate = dashboard_two.two.value #max rotate
    out_zoom = dashboard_two.three.value #max_zoom
    out_affine = dashboard_two.four.value #p_affine
    out_lighting = dashboard_two.five.value #Max_lighting
    out_plight = dashboard_two.six.value #p_lighting
    out_warp = dashboard_two.seven.value #Max_warp

    tfms = get_transforms(do_flip=out_flip, flip_vert=out_vert, max_zoom=out_zoom,
                          p_affine=out_affine, max_lighting=out_lighting, p_lighting=out_plight, max_warp=out_warp,
                          max_rotate=out_rotate)

    _, axs = plt.subplots(2,4,figsize=(12,6))
    for ax in axs.flatten():
        img = get_ex().apply_tfms(tfms[0], size=224) # apply the training-set transforms to a fresh copy
        img.show(ax=ax)

def view_batch_folder():

    print('>> IMPORTANT: Select data folder under INFO tab prior to clicking on batch button to avoid errors')
    button_g = widgets.Button(description="View Batch")
    display(button_g)

    batch_val = int(dashboard_one.f.value) # batch size
    image_val = int(dashboard_one.m.value) # image size

    out = widgets.Output()
    display(out)

    def on_button_click(b):
        with out:
            clear_output()
            print('\n''Augmentations''\n''Do Flip:', dashboard_two.doflip.value,'|''Do Vert:', dashboard_two.dovert.value, '\n'
                  '\n''Max Rotate: ', dashboard_two.two.value,'|''Max Zoom: ', dashboard_two.three.value,'|''Max Warp: ',
                  dashboard_two.seven.value,'|''p affine: ', dashboard_two.four.value, '\n''Max Lighting: ', dashboard_two.five.value,
                  'p lighting: ', dashboard_two.six.value, '\n'
                  '\n''Normalization Value:', dashboard_one.norma.value, '\n''\n''working....')

            tfms = get_transforms(do_flip=dashboard_two.doflip.value == 'Yes', flip_vert=dashboard_two.dovert.value == 'Yes', max_zoom=dashboard_two.three.value,
                                  p_affine=dashboard_two.four.value, max_lighting=dashboard_two.five.value, p_lighting=dashboard_two.six.value,
                                  max_warp=dashboard_two.seven.value, max_rotate=dashboard_two.two.value, xtra_tfms=None)

            path = path_load.path_choice
            data = ImageDataBunch.from_folder(path, ds_tfms=tfms, bs=batch_val, size=image_val, test='test')
            data.normalize(stats_info())
            data.show_batch(rows=5, figsize=(10,10))

    button_g.on_click(on_button_click)

def stats_info():

    if dashboard_one.norma.value == 'Imagenet':
        stats_info.stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    elif dashboard_one.norma.value == 'Cifar':
        stats_info.stats = ([0.491, 0.482, 0.447], [0.247, 0.243, 0.261])
    elif dashboard_one.norma.value == 'Mnist':
        stats_info.stats = ([0.15, 0.15, 0.15], [0.15, 0.15, 0.15])
    else: # dashboard_one.norma.value == 'Custom':
        stats_info.stats = None # None lets data.normalize() compute stats from a batch

    stats = stats_info.stats
    return stats

mets_list = []

precision = Precision()
recall = Recall()

def metrics_list(mets_list):
    mets_error = metrics_dashboard.error_choice.value
    mets_accuracy= metrics_dashboard.accuracy.value
    mets_accuracy_thr = metrics_dashboard.topk.value
    mets_precision = metrics_dashboard.precision.value
    mets_recall = metrics_dashboard.recall.value
    mets_dice = metrics_dashboard.dice.value

    mets_list=[]

    if mets_error == 'Yes':
        mets_list.append(error_rate)
    if mets_accuracy == 'Yes':
        mets_list.append(accuracy)
    if mets_accuracy_thr == 'Yes':
        mets_list.append(top_k_accuracy)
    if mets_precision == 'Yes':
        mets_list.append(precision)
    if mets_recall == 'Yes':
        mets_list.append(recall)
    if mets_dice == 'Yes':
        mets_list.append(dice)

    return mets_list

def model_summary():

    print('>> Review Model information: ', dashboard_one.archi.value)

    batch_val = int(dashboard_one.f.value) # batch size
    image_val = int(dashboard_one.m.value) # image size

    button_summary = widgets.Button(description="Model Summary")
    button_model_0 = widgets.Button(description='Model[0]')
    button_model_1 = widgets.Button(description='Model[1]')

    tfms = get_transforms(do_flip=dashboard_two.doflip.value == 'Yes', flip_vert=dashboard_two.dovert.value == 'Yes', max_zoom=dashboard_two.three.value,
                          p_affine=dashboard_two.four.value, max_lighting=dashboard_two.five.value, p_lighting=dashboard_two.six.value,
                          max_warp=dashboard_two.seven.value, max_rotate=dashboard_two.two.value, xtra_tfms=None)

    path = path_load.path_choice
    data = ImageDataBunch.from_folder(path, ds_tfms=tfms, bs=batch_val, size=image_val, test='test')

    r = dashboard_one.pretrain_check.value

    ui_out = widgets.HBox([button_summary, button_model_0, button_model_1])

    arch_work()

    display(ui_out)
    out = widgets.Output()
    display(out)

    def on_button_clicked_summary(b):
        with out:
            clear_output()
            print('working''\n')
            learn = cnn_learner(data, base_arch=arch_work.info, pretrained=r, custom_head=None)
            print('Model Summary')
            info = learn.summary()
            print(info)

    button_summary.on_click(on_button_clicked_summary)

    def on_button_clicked_model_0(b):
        with out:
            clear_output()
            print('working''\n')
            learn = cnn_learner(data, base_arch=arch_work.info, pretrained=r, custom_head=None)
            print('Model[0]')
            info_s = learn.model[0]
            print(info_s)

    button_model_0.on_click(on_button_clicked_model_0)

    def on_button_clicked_model_1(b):
        with out:
            clear_output()
            print('working''\n')
            learn = cnn_learner(data, base_arch=arch_work.info, pretrained=r, custom_head=None)
            print('Model[1]')
            info_sm = learn.model[1]
            print(info_sm)

    button_model_1.on_click(on_button_clicked_model_1)

def arch_work():
    if dashboard_one.archi.value == 'alexnet':
        arch_work.info = models.alexnet
    elif dashboard_one.archi.value == 'BasicBlock':
        arch_work.info = models.BasicBlock
    elif dashboard_one.archi.value == 'densenet121':
        arch_work.info = models.densenet121
    elif dashboard_one.archi.value == 'densenet161':
        arch_work.info = models.densenet161
    elif dashboard_one.archi.value == 'densenet169':
        arch_work.info = models.densenet169
    elif dashboard_one.archi.value == 'densenet201':
        arch_work.info = models.densenet201
    elif dashboard_one.archi.value == 'resnet18':
        arch_work.info = models.resnet18
    elif dashboard_one.archi.value == 'resnet34':
        arch_work.info = models.resnet34
    elif dashboard_one.archi.value == 'resnet50':
        arch_work.info = models.resnet50
    elif dashboard_one.archi.value == 'resnet101':
        arch_work.info = models.resnet101
    elif dashboard_one.archi.value == 'resnet152':
        arch_work.info = models.resnet152
    elif dashboard_one.archi.value == 'squeezenet1_0':
        arch_work.info = models.squeezenet1_0
    elif dashboard_one.archi.value == 'squeezenet1_1':
        arch_work.info = models.squeezenet1_1
    elif dashboard_one.archi.value == 'vgg16_bn':
        arch_work.info = models.vgg16_bn
    elif dashboard_one.archi.value == 'vgg19_bn':
        arch_work.info = models.vgg19_bn
    #elif dashboard_one.archi.value == 'wrn_22':
    #    arch_work.info = models.wrn_22
    elif dashboard_one.archi.value == 'xresnet18':
        arch_work.info = xresnet2.xresnet18
    elif dashboard_one.archi.value == 'xresnet34':
        arch_work.info = xresnet2.xresnet34
    elif dashboard_one.archi.value == 'xresnet50':
        arch_work.info = xresnet2.xresnet50
    elif dashboard_one.archi.value == 'xresnet101':
        arch_work.info = xresnet2.xresnet101
    elif dashboard_one.archi.value == 'xresnet152':
        arch_work.info = xresnet2.xresnet152

    output = arch_work.info
    print(output)

def metrics_dashboard():
    button = widgets.Button(description="Metrics")

    batch_val = int(dashboard_one.f.value) # batch size
    image_val = int(dashboard_one.m.value) # image size

    tfms = get_transforms(do_flip=dashboard_two.doflip.value == 'Yes', flip_vert=dashboard_two.dovert.value == 'Yes', max_zoom=dashboard_two.three.value,
                          p_affine=dashboard_two.four.value, max_lighting=dashboard_two.five.value, p_lighting=dashboard_two.six.value,
                          max_warp=dashboard_two.seven.value, max_rotate=dashboard_two.two.value, xtra_tfms=None)

    path = path_load.path_choice
    data = ImageDataBunch.from_folder(path, ds_tfms=tfms, bs=batch_val, size=image_val, test='test')

    layout = {'width':'90%', 'height': '50px', 'border': 'solid', 'fontcolor':'lightgreen'}
    style_green = {'button_color': 'green','handle_color': 'green', 'readout_color': 'red', 'slider_color': 'blue'}

    metrics_dashboard.error_choice = widgets.ToggleButtons(
        options=['Yes', 'No'],
description='Error Choice:', 402 | value='No', 403 | disabled=False, 404 | button_style='success', # 'success', 'info', 'warning', 'danger' or '' 405 | tooltips=[''], 406 | ) 407 | metrics_dashboard.accuracy = widgets.ToggleButtons( 408 | options=['Yes', 'No'], 409 | description='Accuracy:', 410 | value='No', 411 | disabled=False, 412 | button_style='info', # 'success', 'info', 'warning', 'danger' or '' 413 | tooltips=[''], 414 | ) 415 | metrics_dashboard.topk = widgets.ToggleButtons( 416 | options=['Yes', 'No'], 417 | description='Top K:', 418 | value='No', 419 | disabled=False, 420 | button_style='warning', # 'success', 'info', 'warning', 'danger' or '' 421 | tooltips=[''], 422 | ) 423 | metrics_dashboard.recall = widgets.ToggleButtons( 424 | options=['Yes', 'No'], 425 | description='Recall:', 426 | value='No', 427 | disabled=False, 428 | button_style='success', # 'success', 'info', 'warning', 'danger' or '' 429 | tooltips=[''], 430 | ) 431 | metrics_dashboard.precision = widgets.ToggleButtons( 432 | options=['Yes', 'No'], 433 | description='Precision:', 434 | value='No', 435 | disabled=False, 436 | button_style='info', # 'success', 'info', 'warning', 'danger' or '' 437 | tooltips=[''], 438 | ) 439 | metrics_dashboard.dice = widgets.ToggleButtons( 440 | options=['Yes', 'No'], 441 | description='Dice:', 442 | value='No', 443 | disabled=False, 444 | button_style='warning', # 'success', 'info', 'warning', 'danger' or '' 445 | tooltips=[''], 446 | ) 447 | layout = widgets.Layout(width='auto', height='40px') #set width and height 448 | 449 | centre_t = widgets.Button( 450 | description='', 451 | disabled=True, 452 | display='flex', 453 | flex_flow='column', 454 | align_items='stretch', 455 | layout = layout 456 | ) 457 | ui = widgets.HBox([metrics_dashboard.error_choice, metrics_dashboard.accuracy, metrics_dashboard.topk]) 458 | ui2 = widgets.HBox([metrics_dashboard.recall, metrics_dashboard.precision, metrics_dashboard.dice]) 459 | ui3 = widgets.VBox([ui, centre_t, 
ui2]) 460 | 461 | r = dashboard_one.pretrain_check.value 462 | 463 | display(ui3) 464 | 465 | print('>> Click to view chosen metrics') 466 | display(button) 467 | 468 | out = widgets.Output() 469 | display(out) 470 | 471 | def on_button_clicked(b): 472 | with out: 473 | clear_output() 474 | print('Training Metrics''\n') 475 | print('arch:', dashboard_one.archi.value, '\n''pretrain: ', dashboard_one.pretrain_check.value, '\n' ,'Chosen metrics: ',metrics_list(mets_list)) 476 | 477 | button.on_click(on_button_clicked) 478 | 479 | def info_lr(): 480 | button = widgets.Button(description='Review Parameters') 481 | button_two = widgets.Button(description='LR') 482 | button_three = widgets.Button(description='Train') 483 | 484 | butlr = widgets.HBox([button, button_two, button_three]) 485 | display(butlr) 486 | 487 | out = widgets.Output() 488 | display(out) 489 | 490 | def on_button_clicked_info(b): 491 | with out: 492 | clear_output() 493 | print('Data in:', dashboard_one.datain.value,'|' 'Normalization:', dashboard_one.norma.value,'|' 'Architecture:', dashboard_one.archi.value, 494 | 'Pretrain:', dashboard_one.pretrain_check.value,'\n''Batch Size:', dashboard_one.f.value,'|''Image Size:', dashboard_one.m.value,'\n' 495 | '\n''Augmentations''\n''Do Flip:', dashboard_two.doflip.value,'|''Do Vert:', dashboard_two.dovert.value, '\n' 496 | '\n''Max Rotate: ', dashboard_two.two.value,'|''Max Zoom: ', dashboard_two.three.value,'|''Max Warp: ', 497 | dashboard_two.seven.value,'|''p affine: ', dashboard_two.four.value, '\n''Max Lighting: ', dashboard_two.five.value, 498 | '|''p lighting: ', dashboard_two.six.value, '\n' 499 | '\n''Normalization Value:', dashboard_one.norma.value,'\n' '\n''Training Metrics''\n', 500 | metrics_list(mets_list)) 501 | 502 | button.on_click(on_button_clicked_info) 503 | 504 | def on_button_clicked_info2(b): 505 | with out: 506 | clear_output() 507 | dashboard_one.datain.value, dashboard_one.norma.value, dashboard_one.archi.value, 
dashboard_one.pretrain_check.value, 508 | dashboard_one.f.value, dashboard_one.m.value, dashboard_two.doflip.value, dashboard_two.dovert.value, 509 | dashboard_two.two.value, dashboard_two.three.value, dashboard_two.seven.value, dashboard_two.four.value, dashboard_two.five.value, 510 | dashboard_two.six.value, dashboard_one.norma.value,metrics_list(mets_list) 511 | 512 | learn_dash() 513 | 514 | button_two.on_click(on_button_clicked_info2) 515 | 516 | def on_button_clicked_info3(b): 517 | with out: 518 | clear_output() 519 | print('Train') 520 | training() 521 | 522 | button_three.on_click(on_button_clicked_info3) 523 | 524 | def learn_dash(): 525 | button = widgets.Button(description="Learn") 526 | print('Chosen metrics: ',metrics_list(mets_list)) 527 | metrics_list(mets_list) 528 | 529 | batch_val = int(dashboard_one.f.value) # batch size 530 | image_val = int(dashboard_one.m.value) # image size 531 | 532 | r = dashboard_one.pretrain_check.value 533 | t = metrics_list(mets_list) 534 | 535 | tfms = get_transforms(do_flip=dashboard_two.doflip.value, flip_vert=dashboard_two.dovert.value, max_zoom=dashboard_two.three.value, 536 | p_affine=dashboard_two.four.value, max_lighting=dashboard_two.five.value, p_lighting=dashboard_two.six.value, 537 | max_warp=dashboard_two.seven.value, max_rotate=dashboard_two.two.value, xtra_tfms=None) 538 | 539 | path = path_load.path_choice 540 | data = ImageDataBunch.from_folder(path, ds_tfms=tfms, bs=batch_val, size=image_val, test='test') 541 | 542 | learn = cnn_learner(data, base_arch=arch_work.info, pretrained=r, metrics=metrics_list(mets_list), custom_head=None) 543 | 544 | learn.lr_find() 545 | learn.recorder.plot() 546 | 547 | 548 | def model_button(): 549 | button_m = widgets.Button(description='Model') 550 | 551 | print('>> View Model information (model_summary, model[0], model[1])''\n\n''>> For xresnet: Pretrained needs to be set to FALSE') 552 | display(button_m) 553 | 554 | out_two = widgets.Output() 555 | display(out_two) 
556 | 557 | def on_button_clicked_train(b): 558 | with out_two: 559 | clear_output() 560 | print('Your pretrained setting: ', dashboard_one.pretrain_check.value) 561 | model_summary() 562 | 563 | button_m.on_click(on_button_clicked_train) 564 | 565 | #def drive_upload(): 566 | # from google.colab import drive 567 | #print('mounting drive') 568 | #drive.mount('/content/gdrive', force_remount=True) 569 | #drive_upload.root_dir = "/content/gdrive/My Drive/" 570 | #print('drive mounted') 571 | 572 | def path_load(): 573 | 574 | #path = Path(get_path.output_variable) 575 | file_location = str(get_path.output_variable.value) 576 | #path_load.path_choice = path/file_location 577 | path_load.path_choice = file_location 578 | 579 | il = ImageList.from_folder(path_load.path_choice) 580 | print(len(il.items)) 581 | print(path_load.path_choice) 582 | 583 | def image_choice(): 584 | 585 | from ipywidgets import widgets 586 | button_choice = widgets.Button(description="Image Path") 587 | 588 | # Create text widget for output 589 | image_choice.output_variable = widgets.Text() 590 | display(image_choice.output_variable) 591 | 592 | display(button_choice) 593 | 594 | def get_path(): 595 | 596 | from ipywidgets import widgets 597 | button_choice = widgets.Button(description="Load Path") 598 | 599 | # Create text widget for output 600 | get_path.output_variable = widgets.Text() 601 | display(get_path.output_variable) 602 | 603 | display(button_choice) 604 | 605 | def on_button_clicked_summary(b): 606 | path_load() 607 | 608 | button_choice.on_click(on_button_clicked_summary) 609 | 610 | def metric_button(): 611 | button_b = widgets.Button(description="Metrics") 612 | print ('>> Click button to choose appropriate metrics') 613 | display(button_b) 614 | 615 | out = widgets.Output() 616 | display(out) 617 | 618 | def on_button_clicked_learn(b): 619 | with out: 620 | clear_output() 621 | arch_work() 622 | metrics_dashboard() 623 | 624 | button_b.on_click(on_button_clicked_learn) 625 | 
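# --- Hypothetical helper (a sketch, not part of the original Vision_UI code) ---
# lr_work() below maps each learning-rate option string to a float through an
# if/elif chain. Because the ToggleButtons options ('1e-6' ... '1e-1') are
# already valid Python float literals, float() can parse them directly.
# LR_OPTIONS and lr_from_choice are illustrative names; nothing in this file
# calls them.
LR_OPTIONS = ('1e-6', '1e-5', '1e-4', '1e-3', '1e-2', '1e-1')

def lr_from_choice(choice):
    """Return the float learning rate for an option string such as '1e-2'."""
    if choice not in LR_OPTIONS:
        # Reject anything the Train tab does not offer instead of silently
        # leaving the learning rate unset (lr_work() would skip every branch).
        raise ValueError('Unsupported learning rate option: %r' % (choice,))
    return float(choice)  # e.g. '1e-2' -> 0.01

# Possible usage, assuming it replaced the body of lr_work():
#   lr_work.info = lr_from_choice(training.lr.value)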
626 | def training(): 627 | print('>> Using fit_one_cycle') 628 | button = widgets.Button(description='Train') 629 | 630 | style = {'description_width': 'initial'} 631 | 632 | layout = {'width':'90%', 'height': '50px', 'border': 'solid', 'fontcolor':'lightgreen'} 633 | layout_two = {'width':'100%', 'height': '200px', 'border': 'solid', 'fontcolor':'lightgreen'} 634 | style_green = {'handle_color': 'green', 'readout_color': 'red', 'slider_color': 'blue'} 635 | style_blue = {'handle_color': 'blue', 'readout_color': 'red', 'slider_color': 'blue'} 636 | 637 | training.cl=widgets.FloatSlider(min=1,max=64,step=1,value=1, continuous_update=False, layout=layout, style=style_green, description="Cycle Length") 638 | training.lr = widgets.ToggleButtons( 639 | options=['1e-6', '1e-5', '1e-4', '1e-3', '1e-2', '1e-1'], 640 | description='Learning Rate:', 641 | disabled=False, 642 | button_style='info', # 'success', 'info', 'warning', 'danger' or '' 643 | style=style, 644 | value='1e-2', 645 | tooltips=['Choose a suitable learning rate'], 646 | ) 647 | 648 | display(training.cl, training.lr) 649 | 650 | display(button) 651 | 652 | out = widgets.Output() 653 | display(out) 654 | 655 | def on_button_clicked(b): 656 | with out: 657 | clear_output() 658 | lr_work() 659 | print('>> Training....''\n''Learning Rate: ', lr_work.info) 660 | dashboard_one.datain.value, dashboard_one.norma.value, dashboard_one.archi.value, dashboard_one.pretrain_check.value, 661 | dashboard_one.f.value, dashboard_one.m.value, dashboard_two.doflip.value, dashboard_two.dovert.value, 662 | dashboard_two.two.value, dashboard_two.three.value, dashboard_two.seven.value, dashboard_two.four.value, dashboard_two.five.value, 663 | dashboard_two.six.value, dashboard_one.norma.value,metrics_list(mets_list) 664 | 665 | metrics_list(mets_list) 666 | 667 | batch_val = int(dashboard_one.f.value) # batch size 668 | image_val = int(dashboard_one.m.value) # image size 669 | 670 | #values for saving model 671 | value_mone = 
str(dashboard_one.archi.value) 672 | value_mtwo = str(dashboard_one.pretrain_check.value) 673 | value_mthree = str(round(dashboard_one.f.value)) 674 | value_mfour = str(round(dashboard_one.m.value)) 675 | 676 | r = dashboard_one.pretrain_check.value 677 | 678 | tfms = get_transforms(do_flip=dashboard_two.doflip.value, flip_vert=dashboard_two.dovert.value, max_zoom=dashboard_two.three.value, 679 | p_affine=dashboard_two.four.value, max_lighting=dashboard_two.five.value, p_lighting=dashboard_two.six.value, 680 | max_warp=dashboard_two.seven.value, max_rotate=dashboard_two.two.value, xtra_tfms=None) 681 | 682 | path = path_load.path_choice 683 | 684 | data = (ImageList.from_folder(path) 685 | .split_by_folder() 686 | .label_from_folder() 687 | .transform(tfms, size=image_val) 688 | .add_test_folder('test') 689 | .databunch(bs=batch_val, path=path)) # use the batch size chosen in the Data tab 690 | 691 | learn = cnn_learner(data, base_arch=arch_work.info, pretrained=r, metrics=metrics_list(mets_list), custom_head=None, callback_fns=ShowGraph) 692 | 693 | cycle_l = int(training.cl.value) 694 | 695 | learn.fit_one_cycle(cycle_l, slice(lr_work.info)) 696 | 697 | #save model 698 | file_model_name = value_mone + '_pretrained_' + value_mtwo + '_batch_' + value_mthree + '_image_' + value_mfour 699 | 700 | learn.save(file_model_name) 701 | 702 | button.on_click(on_button_clicked) 703 | 704 | def lr_work(): 705 | if training.lr.value == '1e-6': 706 | lr_work.info = float(0.000001) 707 | elif training.lr.value == '1e-5': 708 | lr_work.info = float(0.00001) 709 | elif training.lr.value == '1e-4': 710 | lr_work.info = float(0.0001) 711 | elif training.lr.value == '1e-3': 712 | lr_work.info = float(0.001) 713 | elif training.lr.value == '1e-2': 714 | lr_work.info = float(0.01) 715 | elif training.lr.value == '1e-1': 716 | lr_work.info = float(0.1) 717 | 718 | def display_ui(): 719 | button = widgets.Button(description="Train") 720 | button_b = widgets.Button(description="Metrics") 721 | button_m = widgets.Button(description='Model') 722 | 
button_l = widgets.Button(description='LR') 723 | 724 | out1aa = widgets.Output() 725 | out1a = widgets.Output() 726 | out1 = widgets.Output() 727 | out2 = widgets.Output() 728 | out3 = widgets.Output() 729 | out4 = widgets.Output() 730 | out5 = widgets.Output() 731 | out6 = widgets.Output() 732 | 733 | data1aa = pd.DataFrame(np.random.normal(size = 50)) 734 | data1a = pd.DataFrame(np.random.normal(size = 100)) 735 | data1 = pd.DataFrame(np.random.normal(size = 150)) 736 | data2 = pd.DataFrame(np.random.normal(size = 200)) 737 | data3 = pd.DataFrame(np.random.normal(size = 250)) 738 | data4 = pd.DataFrame(np.random.normal(size = 300)) 739 | data5 = pd.DataFrame(np.random.normal(size = 350)) 740 | data6 = pd.DataFrame(np.random.normal(size = 400)) 741 | 742 | with out1aa: #path_choice 743 | print('path') 744 | get_path() 745 | 746 | with out1a: #info 747 | version() 748 | 749 | with out1: #data 750 | dashboard_one() 751 | 752 | with out2: #augmentation 753 | dashboard_two() 754 | 755 | with out3: #Batch 756 | print('Click to view Batch' '\n\n') 757 | view_batch_folder() 758 | 759 | with out4: #model 760 | print('>> View Model information (model_summary, model[0], model[1])''\n\n''>> For xresnet: Pretrained needs to be set to FALSE, setting to TRUE results in error: NameError: name model_urls is not defined') 761 | display(button_m) 762 | 763 | out_two = widgets.Output() 764 | display(out_two) 765 | 766 | def on_button_clicked_train(b): 767 | with out_two: 768 | clear_output() 769 | print('Your pretrained setting: ', dashboard_one.pretrain_check.value) 770 | model_summary() 771 | 772 | button_m.on_click(on_button_clicked_train) 773 | 774 | with out5: #Metrics 775 | print ('>> Click button to choose appropriate metrics') 776 | display(button_b) 777 | 778 | out = widgets.Output() 779 | display(out) 780 | 781 | def on_button_clicked_learn(b): 782 | with out: 783 | clear_output() 784 | arch_work() 785 | metrics_dashboard() 786 | 787 | 
button_b.on_click(on_button_clicked_learn) 788 | 789 | with out6: #train 790 | print ('>> Click to view training parameters and learning rate''\n''\n' 791 | '>> IMPORTANT: You have to go through METRICS tab prior to choosing LR') 792 | info_lr() 793 | 794 | tab = widgets.Tab(children = [out1aa, out1a, out1, out2, out3, out4, out5, out6]) 795 | tab.set_title(0, 'Path') 796 | tab.set_title(1, 'Info') 797 | tab.set_title(2, 'Data') 798 | tab.set_title(3, 'Augmentation') 799 | tab.set_title(4, 'Batch') 800 | tab.set_title(5, 'Model') 801 | tab.set_title(6, 'Metrics') 802 | tab.set_title(7, 'Train') 803 | display(tab) 804 | -------------------------------------------------------------------------------- /static/CM_FN.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_FN.PNG -------------------------------------------------------------------------------- /static/CM_FP.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_FP.PNG -------------------------------------------------------------------------------- /static/CM_TN.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_TN.PNG -------------------------------------------------------------------------------- /static/CM_TP.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_TP.PNG -------------------------------------------------------------------------------- /static/CM_eight.PNG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_eight.PNG -------------------------------------------------------------------------------- /static/CM_five.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_five.PNG -------------------------------------------------------------------------------- /static/CM_four.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_four.PNG -------------------------------------------------------------------------------- /static/CM_nine.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_nine.PNG -------------------------------------------------------------------------------- /static/CM_one.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_one.PNG -------------------------------------------------------------------------------- /static/CM_seven.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_seven.PNG -------------------------------------------------------------------------------- /static/CM_six.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_six.PNG -------------------------------------------------------------------------------- /static/CM_three.PNG: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_three.PNG -------------------------------------------------------------------------------- /static/CM_two.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_two.PNG -------------------------------------------------------------------------------- /static/LR.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/LR.PNG -------------------------------------------------------------------------------- /static/LR_one.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/LR_one.PNG -------------------------------------------------------------------------------- /static/LR_three.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/LR_three.PNG -------------------------------------------------------------------------------- /static/LR_two.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/LR_two.PNG -------------------------------------------------------------------------------- /static/Lr_four.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/Lr_four.PNG 
-------------------------------------------------------------------------------- /static/aug_one.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/aug_one.PNG -------------------------------------------------------------------------------- /static/aug_one2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/aug_one2.PNG -------------------------------------------------------------------------------- /static/aug_three.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/aug_three.PNG -------------------------------------------------------------------------------- /static/aug_three2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/aug_three2.PNG -------------------------------------------------------------------------------- /static/aug_three3.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/aug_three3.PNG -------------------------------------------------------------------------------- /static/aug_two.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/aug_two.PNG -------------------------------------------------------------------------------- /static/aug_two2.PNG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/aug_two2.PNG -------------------------------------------------------------------------------- /static/aug_two3.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/aug_two3.PNG -------------------------------------------------------------------------------- /static/batch.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/batch.PNG -------------------------------------------------------------------------------- /static/batch_three.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/batch_three.PNG -------------------------------------------------------------------------------- /static/batch_two.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/batch_two.PNG -------------------------------------------------------------------------------- /static/cm_class.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/cm_class.PNG -------------------------------------------------------------------------------- /static/data.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/data.PNG -------------------------------------------------------------------------------- /static/data2.PNG: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/data2.PNG -------------------------------------------------------------------------------- /static/heatmap3.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/heatmap3.PNG -------------------------------------------------------------------------------- /static/info.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/info.PNG -------------------------------------------------------------------------------- /static/info_dashboard.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/info_dashboard.PNG -------------------------------------------------------------------------------- /static/metrics.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/metrics.PNG -------------------------------------------------------------------------------- /static/model.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/model.PNG -------------------------------------------------------------------------------- /static/visionUI2_part1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/visionUI2_part1.gif 
-------------------------------------------------------------------------------- /static/visionUI2_part2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/visionUI2_part2.gif -------------------------------------------------------------------------------- /static/visionUI2_part3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/visionUI2_part3.gif -------------------------------------------------------------------------------- /viola_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "#hide\n", 10 | "from fastai2.vision.all import *\n", 11 | "from fastai2.vision.widgets import *\n", 12 | "\n", 13 | "import ipywidgets as widgets\n", 14 | "from IPython.display import display,clear_output" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 13, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "RED = '\\033[31m'\n", 24 | "BLUE = '\\033[94m'\n", 25 | "GREEN = '\\033[92m'\n", 26 | "BOLD = '\\033[1m'\n", 27 | "ITALIC = '\\033[3m'\n", 28 | "RESET = '\\033[0m'\n", 29 | "\n", 30 | "style = {'description_width': 'initial'}" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 9, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "def dashboard_one():\n", 40 | " \"\"\"GUI for first accordion window\"\"\"\n", 41 | " import psutil\n", 42 | " import torchvision\n", 43 | " try:\n", 44 | " import fastai2; fastver = fastai2.__version__\n", 45 | " except ImportError:\n", 46 | " fastver = 'fastai not found'\n", 47 | " try:\n", 48 | " import fastprogress; fastprog = 
fastprogress.__version__\n", 49 | " except ImportError:\n", 50 | " fastprog = 'fastprogress not found'\n", 51 | " try:\n", 52 | " import fastpages; fastp = fastpages.__version__\n", 53 | " except ImportError:\n", 54 | " fastp = 'fastpages not found'\n", 55 | " try:\n", 56 | " import nbdev; nbd = nbdev.__version__\n", 57 | " except ImportError:\n", 58 | " nbd = 'nbdev not found'\n", 59 | "\n", 60 | " print (BOLD + RED + '>> Vision_UI Update: 03/17/2020')\n", 61 | " style = {'description_width': 'initial'}\n", 62 | "\n", 63 | " button = widgets.Button(description='System', button_style='success')\n", 64 | " ex_button = widgets.Button(description='Explore', button_style='success')\n", 65 | " display(button)\n", 66 | "\n", 67 | " out = widgets.Output()\n", 68 | " display(out)\n", 69 | "\n", 70 | " def on_button_clicked_info(b):\n", 71 | " with out:\n", 72 | " clear_output()\n", 73 | " print(BOLD + BLUE + \"fastai2 Version: \" + RESET + ITALIC + str(fastver))\n", 74 | " print(BOLD + BLUE + \"nbdev Version: \" + RESET + ITALIC + str(nbd))\n", 75 | " print(BOLD + BLUE + \"fastprogress Version: \" + RESET + ITALIC + str(fastprog))\n", 76 | " print(BOLD + BLUE + \"fastpages Version: \" + RESET + ITALIC + str(fastp) + '\\n')\n", 77 | " print(BOLD + BLUE + \"python Version: \" + RESET + ITALIC + str(sys.version))\n", 78 | " print(BOLD + BLUE + \"torchvision: \" + RESET + ITALIC + str(torchvision.__version__))\n", 79 | " print(BOLD + BLUE + \"torch version: \" + RESET + ITALIC + str(torch.__version__))\n", 80 | " print(BOLD + BLUE + \"\\nCuda: \" + RESET + ITALIC + str(torch.cuda.is_available()))\n", 81 | " print(BOLD + BLUE + \"cuda Version: \" + RESET + ITALIC + str(torch.version.cuda))\n", 82 | " print(BOLD + BLUE + \"GPU: \" + RESET + ITALIC + str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'))\n", 83 | " print(BOLD + BLUE + \"\\nCPU%: \" + RESET + ITALIC + str(psutil.cpu_percent()))\n", 84 | " print(BOLD + BLUE + \"\\nmemory % used: \" + RESET + ITALIC + 
str(psutil.virtual_memory()[2]))\n", 85 | " button.on_click(on_button_clicked_info)\n" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 14, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "def dashboard_two():\n", 95 | " \"\"\"GUI for second accordion window\"\"\"\n", 96 | " dashboard_two.datas = widgets.ToggleButtons(\n", 97 | " options=['PETS', 'CIFAR', 'IMAGENETTE_160', 'IMAGEWOOF_160', 'MNIST_TINY'],\n", 98 | " description='Choose',\n", 99 | " value=None,\n", 100 | " disabled=False,\n", 101 | " button_style='info',\n", 102 | " tooltips=[''],\n", 103 | " style=style\n", 104 | " )\n", 105 | "\n", 106 | " display(dashboard_two.datas)\n", 107 | "\n", 108 | " button = widgets.Button(description='Explore', button_style='success')\n", 109 | " display(button)\n", 110 | " out = widgets.Output()\n", 111 | " display(out)\n", 112 | " def on_button_explore(b):\n", 113 | " with out:\n", 114 | " clear_output()\n", 115 | " ds_choice()\n", 116 | "\n", 117 | " button.on_click(on_button_explore)" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 16, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "def ds_choice():\n", 127 | " \"\"\"Helper for dataset choices\"\"\"\n", 128 | " if dashboard_two.datas.value == 'PETS':\n", 129 | " ds_choice.source = untar_data(URLs.DOGS)\n", 130 | " elif dashboard_two.datas.value == 'CIFAR':\n", 131 | " ds_choice.source = untar_data(URLs.CIFAR)\n", 132 | " elif dashboard_two.datas.value == 'IMAGENETTE_160':\n", 133 | " ds_choice.source = untar_data(URLs.IMAGENETTE_160)\n", 134 | " elif dashboard_two.datas.value == 'IMAGEWOOF_160':\n", 135 | " ds_choice.source = untar_data(URLs.IMAGEWOOF_160)\n", 136 | " elif dashboard_two.datas.value == 'MNIST_TINY':\n", 137 | " ds_choice.source = untar_data(URLs.MNIST_TINY)\n", 138 | "\n", 139 | " print(BOLD + BLUE + \"Dataset: \" + RESET + BOLD + RED + str(dashboard_two.datas.value))\n", 140 | " plt_classes()" 141 | ] 142 | }, 143 | 
{ 144 | "cell_type": "code", 145 | "execution_count": 83, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "def plt_classes():\n", 150 | " \"\"\"Helper for plotting classes in folder\"\"\"\n", 151 | " disp_img_but = widgets.Button(description='View Images?', button_style='success')\n", 152 | " Path.BASE_PATH = ds_choice.source\n", 153 | " train_source = (ds_choice.source/'train/').ls().items\n", 154 | " print(BOLD + BLUE + \"Folders: \" + RESET + BOLD + RED + str(train_source))\n", 155 | " print(BOLD + BLUE + \"\\n\" + \"No of classes: \" + RESET + BOLD + RED + str(len(train_source)))\n", 156 | "\n", 157 | " num_l = []\n", 158 | " class_l = []\n", 159 | " for j, name in enumerate(train_source):\n", 160 | " fol = (ds_choice.source/name).ls().sorted()\n", 161 | " names = str(name)\n", 162 | " class_split = names.split('train')\n", 163 | " class_l.append(class_split[1])\n", 164 | " num_l.append(len(fol))\n", 165 | "\n", 166 | " y_pos = np.arange(len(train_source))\n", 167 | " performance = num_l\n", 168 | " fig = plt.figure(figsize=(10,5))\n", 169 | "\n", 170 | " plt.style.use('seaborn')\n", 171 | " plt.bar(y_pos, performance, align='center', alpha=0.5, color=['black', 'red', 'green', 'blue', 'cyan'])\n", 172 | " plt.xticks(y_pos, class_l, rotation=90)\n", 173 | " plt.ylabel('Images')\n", 174 | " plt.title('Images per Class')\n", 175 | " plt.show()\n", 176 | "\n", 177 | " display(disp_img_but)\n", 178 | " img_out = widgets.Output()\n", 179 | " display(img_out)\n", 180 | " def on_disp_button(b):\n", 181 | " with img_out:\n", 182 | " clear_output()\n", 183 | " display_images()\n", 184 | " #display_i()\n", 185 | " disp_img_but.on_click(on_disp_button)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 75, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "def display_images():\n", 195 | " \"\"\"Helper for displaying images from folder\"\"\"\n", 196 | " train_source = (ds_choice.source/'train/').ls().items\n", 
197 | " for i, name in enumerate(train_source):\n", 198 | " fol = (ds_choice.source/name).ls().sorted()\n", 199 | " fol_disp = fol[0:5]\n", 200 | " filename = fol_disp.items\n", 201 | " fol_tensor = [tensor(Image.open(o)) for o in fol_disp]\n", 202 | " print(BOLD + BLUE + \"Loc: \" + RESET + str(name) + \" \" + BOLD + BLUE + \"Number of Images: \" + RESET +\n", 203 | " BOLD + RED + str(len(fol)))\n", 204 | "\n", 205 | " fig = plt.figure(figsize=(15,15))\n", 206 | " columns = 5\n", 207 | " rows = 1\n", 208 | " ax = []\n", 209 | "\n", 210 | " for i in range(columns*rows):\n", 211 | " for i, j in enumerate(fol_tensor):\n", 212 | " img = fol_tensor[i] # create subplot and append to ax\n", 213 | " ax.append( fig.add_subplot(rows, columns, i+1))\n", 214 | " ax[-1].set_title(\"ax:\"+str(filename[i])) # set title\n", 215 | " plt.tick_params(bottom=\"on\", left=\"on\")\n", 216 | " plt.imshow(img)\n", 217 | " plt.xticks([])\n", 218 | " plt.show()\n" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 85, 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "def display_ui():\n", 228 | " \"\"\" Display tabs for visual display\"\"\"\n", 229 | " button = widgets.Button(description=\"Train\")\n", 230 | " button_b = widgets.Button(description=\"Metrics\")\n", 231 | " button_m = widgets.Button(description='Model')\n", 232 | " button_l = widgets.Button(description='LR')\n", 233 | "\n", 234 | " test_button = widgets.Button(description='Batch')\n", 235 | " test2_button = widgets.Button(description='Test2')\n", 236 | "\n", 237 | " out1a = widgets.Output()\n", 238 | " out1 = widgets.Output()\n", 239 | " out2 = widgets.Output()\n", 240 | " out3 = widgets.Output()\n", 241 | " out4 = widgets.Output()\n", 242 | " out5 = widgets.Output()\n", 243 | "\n", 244 | " data1a = pd.DataFrame(np.random.normal(size = 50))\n", 245 | " data1 = pd.DataFrame(np.random.normal(size = 100))\n", 246 | " data2 = pd.DataFrame(np.random.normal(size = 150))\n", 247 | " data3 
= pd.DataFrame(np.random.normal(size = 200))\n", 248 | " data4 = pd.DataFrame(np.random.normal(size = 250))\n", 249 | " data5 = pd.DataFrame(np.random.normal(size = 300))\n", 250 | "\n", 251 | " with out1a: #info\n", 252 | " clear_output()\n", 253 | " dashboard_one()\n", 254 | "\n", 255 | " with out1: #data\n", 256 | " clear_output()\n", 257 | " dashboard_two()\n", 258 | "\n", 259 | " with out2: #augmentation\n", 260 | " clear_output()\n", 261 | " #aug_dash()\n", 262 | "\n", 263 | " with out3: #Block\n", 264 | " clear_output()\n", 265 | " #ds_3()\n", 266 | "\n", 267 | " with out4: #code\n", 268 | " clear_output()\n", 269 | " #write_code()\n", 270 | "\n", 271 | " with out5: #Imagewoof Play\n", 272 | " clear_output()\n", 273 | " #print(BOLD + BLUE + 'Work in progress.....')\n", 274 | " #play_button = widgets.Button(description='Parameters')\n", 275 | " #display(play_button)\n", 276 | " #play_out = widgets.Output()\n", 277 | " #display(play_out)\n", 278 | " #def button_play(b):\n", 279 | " # with play_out:\n", 280 | " # clear_output()\n", 281 | " # play_info()\n", 282 | " #play_button.on_click(button_play)\n", 283 | "\n", 284 | " display_ui.tab = widgets.Tab(children = [out1a, out1, out2, out3, out4, out5])\n", 285 | " display_ui.tab.set_title(0, 'Info')\n", 286 | " display_ui.tab.set_title(1, 'Data')\n", 287 | " display_ui.tab.set_title(2, 'Augmentation')\n", 288 | " display_ui.tab.set_title(3, 'DataBlock')\n", 289 | " display_ui.tab.set_title(4, 'Code')\n", 290 | " display_ui.tab.set_title(5, 'ImageWoof Play')\n", 291 | " display(display_ui.tab)\n" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 86, 297 | "metadata": {}, 298 | "outputs": [ 299 | { 300 | "data": { 301 | "application/vnd.jupyter.widget-view+json": { 302 | "model_id": "e2cb716a70dc4a2fa2dc5db83604a1e4", 303 | "version_major": 2, 304 | "version_minor": 0 305 | }, 306 | "text/plain": [ 307 | "Tab(children=(Output(), Output(), Output(), Output(), Output(), Output()), 
_titles={'0': 'Info', '1': 'Data', …" 308 | ] 309 | }, 310 | "metadata": {}, 311 | "output_type": "display_data" 312 | } 313 | ], 314 | "source": [ 315 | "display_ui()" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": null, 321 | "metadata": {}, 322 | "outputs": [], 323 | "source": [] 324 | } 325 | ], 326 | "metadata": { 327 | "kernelspec": { 328 | "display_name": "Python 3", 329 | "language": "python", 330 | "name": "python3" 331 | }, 332 | "language_info": { 333 | "codemirror_mode": { 334 | "name": "ipython", 335 | "version": 3 336 | }, 337 | "file_extension": ".py", 338 | "mimetype": "text/x-python", 339 | "name": "python", 340 | "nbconvert_exporter": "python", 341 | "pygments_lexer": "ipython3", 342 | "version": "3.7.6" 343 | } 344 | }, 345 | "nbformat": 4, 346 | "nbformat_minor": 4 347 | } 348 | -------------------------------------------------------------------------------- /vision_ui2.py: -------------------------------------------------------------------------------- 1 | ############################### 2 | ##### Visual_UI version 2 ##### 3 | ##### Update 04/06/2020 ##### 4 | ############################### 5 | from fastai2.vision.all import* 6 | from utils import* 7 | 8 | from ipywidgets import interact, interactive, fixed, interact_manual, Box, Layout 9 | import ipywidgets 10 | import ipywidgets as widgets 11 | from ipywidgets import Layout, Button, Box, FloatText, Textarea, Dropdown, Label, IntSlider, ToggleButton, FloatSlider, VBox 12 | 13 | from IPython.display import display,clear_output, Javascript 14 | import webbrowser 15 | from IPython.display import YouTubeVideo 16 | 17 | from tkinter import Tk 18 | from tkinter import filedialog 19 | from tkinter.filedialog import askdirectory 20 | 21 | import torch 22 | import collections as co 23 | 24 | style = {'description_width': 'initial'} 25 | 26 | RED = '\033[31m' 27 | BLUE = '\033[94m' 28 | GREEN = '\033[92m' 29 | BOLD = '\033[1m' 30 | ITALIC = '\033[3m' 31 | RESET = '\033[0m' 
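The constants above are ANSI escape sequences. The functions below concatenate them by hand, typically as `BOLD + BLUE + label + RESET + RED + str(value)`. The following is only a sketch, and it assumes the output goes to a terminal or notebook cell that renders ANSI codes. It shows how that repeated pattern could be folded into a single helper. `label()` and `strip_ansi()` are hypothetical names that do not exist in `vision_ui2.py`. The trailing `RESET` is a deliberate choice so that the value color does not carry over into the next print.

```python
import re

# ANSI escape sequences, copied from the constants in vision_ui2.py
RED = '\033[31m'
BLUE = '\033[94m'
BOLD = '\033[1m'
RESET = '\033[0m'

# Matches SGR (color/style) sequences such as '\033[1m' or '\033[0m'
_ANSI_SGR = re.compile(r'\x1b\[[0-9;]*m')

def label(name, value, value_color=RED):
    """Hypothetical helper: bold blue 'name: ' followed by a colored value.

    The trailing RESET keeps the value color from leaking into later prints.
    """
    return BOLD + BLUE + name + ": " + RESET + value_color + str(value) + RESET

def strip_ansi(text):
    """Hypothetical helper: remove ANSI color/style codes, leaving plain text."""
    return _ANSI_SGR.sub('', text)

if __name__ == "__main__":
    print(label("Batch Size", 16))              # colored in an ANSI-aware output
    print(strip_ansi(label("Batch Size", 16)))  # plain: Batch Size: 16
```

With this helper, a call like `print(label("Item Size", item_size))` would stand in for one of the long concatenations used in the dashboards.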
32 | 33 | align_kw = dict( 34 | _css = (('.widget-label', 'min-width', '20ex'),), 35 | margin = '0px 0px 5px 12px') 36 | 37 | def dashboard_one(): 38 | """GUI for first accordion window""" 39 | import psutil 40 | import torchvision 41 | try: 42 | import fastai2; fastver = fastai2.__version__ 43 | except ImportError: 44 | fastver = 'fastai not found' 45 | try: 46 | import fastprogress; fastprog = fastprogress.__version__ 47 | except ImportError: 48 | fastprog = 'fastprogress not found' 49 | try: 50 | import fastpages; fastp = fastpages.__version__ 51 | except ImportError: 52 | fastp = 'fastpages not found' 53 | try: 54 | import nbdev; nbd = nbdev.__version__ 55 | except ImportError: 56 | nbd = 'nbdev not found' 57 | 58 | print (BOLD + RED + '>> Vision_UI Update: 03/17/2020') 59 | style = {'description_width': 'initial'} 60 | 61 | button = widgets.Button(description='System', button_style='success') 62 | ex_button = widgets.Button(description='Explore', button_style='success') 63 | display(button) 64 | 65 | out = widgets.Output() 66 | display(out) 67 | 68 | def on_button_clicked_info(b): 69 | with out: 70 | clear_output() 71 | print(BOLD + BLUE + "fastai2 Version: " + RESET + ITALIC + str(fastver)) 72 | print(BOLD + BLUE + "nbdev Version: " + RESET + ITALIC + str(nbd)) 73 | print(BOLD + BLUE + "fastprogress Version: " + RESET + ITALIC + str(fastprog)) 74 | print(BOLD + BLUE + "fastpages Version: " + RESET + ITALIC + str(fastp) + '\n') 75 | print(BOLD + BLUE + "python Version: " + RESET + ITALIC + str(sys.version)) 76 | print(BOLD + BLUE + "torchvision: " + RESET + ITALIC + str(torchvision.__version__)) 77 | print(BOLD + BLUE + "torch version: " + RESET + ITALIC + str(torch.__version__)) 78 | print(BOLD + BLUE + "\nCuda: " + RESET + ITALIC + str(torch.cuda.is_available())) 79 | print(BOLD + BLUE + "cuda Version: " + RESET + ITALIC + str(torch.version.cuda)) 80 | print(BOLD + BLUE + "GPU: " + RESET + ITALIC + str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None')) 81 | print(BOLD + BLUE +
"\nCPU%: " + RESET + ITALIC + str(psutil.cpu_percent())) 82 | print(BOLD + BLUE + "\nmemory % used: " + RESET + ITALIC + str(psutil.virtual_memory()[2])) 83 | button.on_click(on_button_clicked_info) 84 | 85 | print ('>> Resources') 86 | button_two = widgets.Button(description='Fastai Docs', button_style='info') 87 | button_three = widgets.Button(description='Fastai Forums', button_style='info') 88 | button_four = widgets.Button(description='Vision_UI github', button_style='info') 89 | 90 | but_two = widgets.HBox([button_two, button_three, button_four]) 91 | display(but_two) 92 | 93 | def on_doc_info(b): 94 | webbrowser.open('https://dev.fast.ai/') 95 | button_two.on_click(on_doc_info) 96 | 97 | def on_forum(b): 98 | webbrowser.open('https://forums.fast.ai/') 99 | button_three.on_click(on_forum) 100 | 101 | def vision_utube(b): 102 | webbrowser.open('https://github.com/asvcode/Vision_UI') 103 | button_four.on_click(vision_utube) 104 | 105 | def dashboard_two(): 106 | """GUI for second accordion window""" 107 | dashboard_two.datas = widgets.ToggleButtons( 108 | options=['PETS', 'CIFAR', 'IMAGENETTE_160', 'IMAGEWOOF_160', 'MNIST_TINY'], 109 | description='Choose', 110 | value=None, 111 | disabled=False, 112 | button_style='info', 113 | tooltips=[''], 114 | style=style 115 | ) 116 | 117 | display(dashboard_two.datas) 118 | 119 | button = widgets.Button(description='Explore', button_style='success') 120 | display(button) 121 | out = widgets.Output() 122 | display(out) 123 | def on_button_explore(b): 124 | with out: 125 | clear_output() 126 | ds_choice() 127 | 128 | button.on_click(on_button_explore) 129 | 130 | #Helpers for dashboard two 131 | def ds_choice(): 132 | """Helper for dataset choices""" 133 | if dashboard_two.datas.value == 'PETS': 134 | ds_choice.source = untar_data(URLs.DOGS) 135 | elif dashboard_two.datas.value == 'CIFAR': 136 | ds_choice.source = untar_data(URLs.CIFAR) 137 | elif dashboard_two.datas.value == 'IMAGENETTE_160': 138 | ds_choice.source = 
untar_data(URLs.IMAGENETTE_160) 139 | elif dashboard_two.datas.value == 'IMAGEWOOF_160': 140 | ds_choice.source = untar_data(URLs.IMAGEWOOF_160) 141 | elif dashboard_two.datas.value == 'MNIST_TINY': 142 | ds_choice.source = untar_data(URLs.MNIST_TINY) 143 | 144 | print(BOLD + BLUE + "Dataset: " + RESET + BOLD + RED + str(dashboard_two.datas.value)) 145 | plt_classes() 146 | 147 | def plt_classes(): 148 | """Helper for plotting classes in folder""" 149 | disp_img_but = widgets.Button(description='View Images?', button_style='success') 150 | Path.BASE_PATH = ds_choice.source 151 | train_source = (ds_choice.source/'train/').ls().items 152 | print(BOLD + BLUE + "Folders: " + RESET + BOLD + RED + str(train_source)) 153 | print(BOLD + BLUE + "\n" + "No of classes: " + RESET + BOLD + RED + str(len(train_source))) 154 | 155 | num_l = [] 156 | class_l = [] 157 | for j, name in enumerate(train_source): 158 | fol = (ds_choice.source/name).ls().sorted() 159 | names = str(name) 160 | class_split = names.split('train') 161 | class_l.append(class_split[1]) 162 | num_l.append(len(fol)) 163 | 164 | y_pos = np.arange(len(train_source)) 165 | performance = num_l 166 | 167 | plt.style.use('seaborn') 168 | plt.bar(y_pos, performance, align='center', alpha=0.5, color=['black', 'red', 'green', 'blue', 'cyan']) 169 | plt.xticks(y_pos, class_l, rotation=90) 170 | plt.ylabel('Images') 171 | plt.title('Images per Class') 172 | plt.show() 173 | 174 | display(disp_img_but) 175 | out_img = widgets.Output() 176 | display(out_img) 177 | def on_disp_button(b): 178 | with out_img: 179 | clear_output() 180 | display_images() 181 | disp_img_but.on_click(on_disp_button) 182 | 183 | def display_images(): 184 | """Helper for displaying images from folder""" 185 | train_source = (ds_choice.source/'train/').ls().items 186 | for i, name in enumerate(train_source): 187 | fol = (ds_choice.source/name).ls().sorted() 188 | fol_disp = fol[0:5] 189 | filename = fol_disp.items 190 | fol_tensor = 
[tensor(Image.open(o)) for o in fol_disp] 191 | print(BOLD + BLUE + "Loc: " + RESET + str(name) + " " + BOLD + BLUE + "Number of Images: " + RESET + 192 | BOLD + RED + str(len(fol))) 193 | 194 | fig = plt.figure(figsize=(30,13)) 195 | columns = 5 196 | rows = 1 197 | ax = [] 198 | 199 | for i in range(columns*rows): 200 | if i < len(fol_tensor): 201 | img = fol_tensor[i] # create subplot and append to ax 202 | ax.append( fig.add_subplot(rows, columns, i+1)) 203 | ax[-1].set_title("ax:"+str(filename[i])) # set title 204 | plt.tick_params(bottom=True, left=True) 205 | plt.imshow(img) 206 | plt.xticks([]) 207 | plt.show() 208 | 209 | #Helpers for augmentation dashboard 210 | def aug_choice(): 211 | """Helper for whether augmentations are chosen or not""" 212 | view_button = widgets.Button(description='View') 213 | display(view_button) 214 | view_out = widgets.Output() 215 | display(view_out) 216 | def on_view_button(b): 217 | with view_out: 218 | clear_output() 219 | if aug_dash.aug.value == 'No': 220 | code_test() 221 | if aug_dash.aug.value == 'Yes': 222 | aug_paras() 223 | view_button.on_click(on_view_button) 224 | 225 | def aug_paras(): 226 | """If augmentations are chosen, show the available parameters""" 227 | print(BOLD + BLUE + "Choose Augmentation Parameters: ") 228 | button_paras = widgets.Button(description='Confirm', button_style='success') 229 | 230 | aug_paras.hh = widgets.ToggleButton(value=False, description='Erase', button_style='info', 231 | style=style) 232 | aug_paras.cc = widgets.ToggleButton(value=False, description='Contrast', button_style='info', 233 | style=style) 234 | aug_paras.dd = widgets.ToggleButton(value=False, description='Rotate', button_style='info', 235 | style=style) 236 | aug_paras.ee = widgets.ToggleButton(value=False, description='Warp', button_style='info', 237 | style=style) 238 | aug_paras.ff = widgets.ToggleButton(value=False, description='Bright', button_style='info', 239 | style=style) 240 | aug_paras.gg =
widgets.ToggleButton(value=False, description='DihedralFlip', button_style='info', 241 | style=style) 242 | aug_paras.ii = widgets.ToggleButton(value=False, description='Zoom', button_style='info', 243 | style=style) 244 | 245 | qq = widgets.HBox([aug_paras.hh, aug_paras.cc, aug_paras.dd, aug_paras.ee, aug_paras.ff, aug_paras.gg, aug_paras.ii]) 246 | display(qq) 247 | display(button_paras) 248 | aug_par = widgets.Output() 249 | display(aug_par) 250 | def on_button_two_click(b): 251 | with aug_par: 252 | clear_output() 253 | aug() 254 | aug_dash_choice() 255 | button_paras.on_click(on_button_two_click) 256 | 257 | def aug(): 258 | """Aug choice helper""" 259 | #Erase 260 | if aug_paras.hh.value == True: 261 | aug.b_max = FloatSlider(min=0,max=50,step=1,value=0, description='max count', 262 | orientation='horizontal', disabled=False) 263 | aug.b_pval = FloatSlider(min=0,max=1,step=0.1,value=0, description=r"$p$", 264 | orientation='horizontal', disabled=False) 265 | aug.b_asp = FloatSlider(min=0.1,max=5, step=0.1, value=0.3, description=r'$aspect$', 266 | orientation='horizontal', disabled=False) 267 | aug.b_len = FloatSlider(min=0.1,max=5, step=0.1, value=0.3, description=r'$sl$', 268 | orientation='horizontal', disabled=False) 269 | aug.b_ht = FloatSlider(min=0.1,max=5, step=0.1, value=0.3, description=r'$sh$', 270 | orientation='horizontal', disabled=False) 271 | aug.erase_code = 'this is ERASE on' 272 | if aug_paras.hh.value == False: 273 | aug.b_max = FloatSlider(min=0,max=10,step=1,value=0, description='max count', 274 | orientation='horizontal', disabled=True) 275 | aug.b_pval = FloatSlider(min=0,max=1,step=0.1,value=0, description='p', 276 | orientation='horizontal', disabled=True) 277 | aug.b_asp = FloatSlider(min=0.1,max=1.7,value=0.3, description='aspect', 278 | orientation='horizontal', disabled=True) 279 | aug.b_len = FloatSlider(min=0.1,max=1.7,value=0.3, description='length', 280 | orientation='horizontal', disabled=True) 281 | aug.b_ht = 
FloatSlider(min=0.1,max=1.7,value=0.3, description='height', 282 | orientation='horizontal', disabled=True) 283 | aug.erase_code = 'this is ERASE OFF' 284 | #Contrast 285 | if aug_paras.cc.value == True: 286 | aug.b1_max = FloatSlider(min=0,max=0.9,step=0.1,value=0.2, description='max light', 287 | orientation='horizontal', disabled=False) 288 | aug.b1_pval = FloatSlider(min=0,max=1.0,step=0.05,value=0.75, description='p', 289 | orientation='horizontal', disabled=False) 290 | aug.b1_draw = FloatSlider(min=0,max=100,step=1,value=1, description='draw', 291 | orientation='horizontal', disabled=False) 292 | else: 293 | aug.b1_max = FloatSlider(min=0,max=0.9,step=0.1,value=0, description='max light', 294 | orientation='horizontal', disabled=True) 295 | aug.b1_pval = FloatSlider(min=0,max=1.0,step=0.05,value=0.75, description='p', 296 | orientation='horizontal', disabled=True) 297 | aug.b1_draw = FloatSlider(min=0,max=100,step=1,value=1, description='draw', 298 | orientation='horizontal', disabled=True) 299 | #Rotate 300 | if aug_paras.dd.value == True: 301 | aug.b2_max = FloatSlider(min=0,max=10,step=1,value=0, description='max degree', 302 | orientation='horizontal', disabled=False) 303 | aug.b2_pval = FloatSlider(min=0,max=1,step=0.1,value=0.5, description='p', 304 | orientation='horizontal', disabled=False) 305 | else: 306 | aug.b2_max = FloatSlider(min=0,max=10,step=1,value=0, description='max degree', 307 | orientation='horizontal', disabled=True) 308 | aug.b2_pval = FloatSlider(min=0,max=1,step=0.1,value=0, description='p', 309 | orientation='horizontal', disabled=True) 310 | #Warp 311 | if aug_paras.ee.value == True: 312 | aug.b3_mag = FloatSlider(min=0,max=10,step=1,value=0, description='magnitude', 313 | orientation='horizontal', disabled=False) 314 | aug.b3_pval = FloatSlider(min=0,max=1,step=0.1,value=0, description='p', 315 | orientation='horizontal', disabled=False) 316 | else: 317 | aug.b3_mag = FloatSlider(min=0,max=10,step=1,value=0, 
description='magnitude', 318 | orientation='horizontal', disabled=True) 319 | aug.b3_pval = FloatSlider(min=0,max=1,step=0.1,value=0, description='p', 320 | orientation='horizontal', disabled=True) 321 | #Bright 322 | if aug_paras.ff.value == True: 323 | aug.b4_max = FloatSlider(min=0,max=10,step=1,value=0, description='max light', 324 | orientation='horizontal', disabled=False) 325 | aug.b4_pval = FloatSlider(min=0,max=1,step=0.1,value=0, description='p', 326 | orientation='horizontal', disabled=False) 327 | else: 328 | aug.b4_max = FloatSlider(min=0,max=10,step=1,value=0, description='max light', 329 | orientation='horizontal', disabled=True) 330 | aug.b4_pval = FloatSlider(min=0,max=1,step=0.1,value=0, description='p', 331 | orientation='horizontal', disabled=True) 332 | #DihedralFlip 333 | if aug_paras.gg.value == True: 334 | aug.b5_pval = FloatSlider(min=0,max=1,step=0.1, description='p', 335 | orientation='horizontal', disabled=False) 336 | aug.b5_draw = FloatSlider(min=0,max=7,step=1, description='draw', 337 | orientation='horizontal', disabled=False) 338 | else: 339 | aug.b5_pval = FloatSlider(min=0,max=1,step=0.1, description='p', 340 | orientation='horizontal', disabled=True) 341 | aug.b5_draw = FloatSlider(min=0,max=7,step=1, description='draw', 342 | orientation='horizontal', disabled=True) 343 | #Zoom 344 | if aug_paras.ii.value == True: 345 | aug.b6_zoom = FloatSlider(min=1,max=5,step=0.1, description='max_zoom', 346 | orientation='horizontal', disabled=False) 347 | aug.b6_pval = FloatSlider(min=0,max=1,step=0.1, description='p', 348 | orientation='horizontal', disabled=False) 349 | else: 350 | aug.b6_zoom = FloatSlider(min=1,max=5,step=0.1, description='max_zoom', 351 | orientation='horizontal', disabled=True) 352 | aug.b6_pval = FloatSlider(min=0,max=1,step=0.1, description='p', 353 | orientation='horizontal', disabled=True) 354 | 355 | #Single/Multi 356 | if aug_dash.bi.value == 'Single': 357 | aug.get_items = repeat_one 358 | aug.val = 'Single' 359 | if
aug_dash.bi.value == 'Multi': 360 | aug.get_items = get_image_files 361 | aug.val = 'Multi' 362 | 363 | 364 | def aug_dash_choice(): 365 | """Augmentation parameter display helper""" 366 | button_aug_dash = widgets.Button(description='View', button_style='success') 367 | item_erase_val= widgets.HBox([aug.b_max, aug.b_pval, aug.b_asp, aug.b_len, aug.b_ht]) 368 | item_erase = widgets.VBox([aug_paras.hh, item_erase_val]) 369 | 370 | item_contrast_val = widgets.HBox([aug.b1_max, aug.b1_pval, aug.b1_draw]) 371 | item_contrast = widgets.VBox([aug_paras.cc, item_contrast_val]) 372 | 373 | item_rotate_val = widgets.HBox([aug.b2_max, aug.b2_pval]) 374 | item_rotate = widgets.VBox([aug_paras.dd, item_rotate_val]) 375 | 376 | item_warp_val = widgets.HBox([aug.b3_mag, aug.b3_pval]) 377 | item_warp = widgets.VBox([aug_paras.ee, item_warp_val]) 378 | 379 | item_bright_val = widgets.HBox([aug.b4_max, aug.b4_pval]) 380 | item_bright = widgets.VBox([aug_paras.ff, item_bright_val]) 381 | 382 | item_dihedral_val = widgets.HBox([aug.b5_pval, aug.b5_draw]) 383 | item_dihedral = widgets.VBox([aug_paras.gg, item_dihedral_val]) 384 | 385 | item_zoom_val = widgets.HBox([aug.b6_zoom, aug.b6_pval]) 386 | item_zoom = widgets.VBox([aug_paras.ii, item_zoom_val]) 387 | 388 | items = [item_erase, item_contrast, item_rotate, item_warp, item_bright, item_dihedral, item_zoom] 389 | dia = Box(items, layout=Layout( 390 | display='flex', 391 | flex_flow='column', 392 | flex_grow=0, 393 | flex_wrap='wrap', 394 | border='solid 1px', 395 | align_items='flex-start', 396 | align_content='flex-start', 397 | justify_content='space-between', 398 | width='flex' 399 | )) 400 | display(dia) 401 | display(button_aug_dash) 402 | aug_dash_out = widgets.Output() 403 | display(aug_dash_out) 404 | def on_button_two(b): 405 | with aug_dash_out: 406 | clear_output() 407 | stats_info() 408 | #image_show() 409 | code_test() 410 | button_aug_dash.on_click(on_button_two) 411 | 412 | def aug_dash(): 413 | """GUI for
augmentation dashboard""" 414 | aug_button = widgets.Button(description='Confirm', button_style='success') 415 | 416 | tb = widgets.Button(description='Batch Image', disabled=True, button_style='') 417 | aug_dash.bi = widgets.ToggleButtons(value='Multi', options=['Multi', 'Single'], description='', button_style='info', 418 | style=style, layout=Layout(width='auto')) 419 | tg = widgets.Button(description='Padding', disabled=True, button_style='') 420 | aug_dash.pad = widgets.ToggleButtons(value='zeros', options=['zeros', 'reflection', 'border'], description='', button_style='info', 421 | style=style, layout=Layout(width='auto')) 422 | th = widgets.Button(description='Normalization', disabled=True, button_style='') 423 | aug_dash.norm = widgets.ToggleButtons(value='Imagenet', options=['Imagenet', 'Mnist', 'Cifar', 'None'], description='', button_style='info', 424 | style=style, layout=Layout(width='auto')) 425 | tr = widgets.Button(description='Batch Size', disabled=True, button_style='warning') 426 | aug_dash.bs = widgets.ToggleButtons(value='16', options=['8', '16', '32', '64'], description='', button_style='warning', 427 | style=style, layout=Layout(width='auto')) 428 | spj = widgets.Button(description='Presizing', disabled=True, button_style='primary') 429 | te = widgets.Button(description='Item Size', disabled=True, button_style='primary') 430 | aug_dash.imgsiz = widgets.ToggleButtons(value='194', options=['28', '64', '128', '194', '254'], 431 | description='', button_style='primary', style=style, layout=Layout(width='auto')) 432 | to = widgets.Button(description='Batch Size', disabled=True, button_style='primary') 433 | aug_dash.imgbth = widgets.ToggleButtons(value='128', options=['28', '64', '128', '194', '254'], 434 | description='', button_style='primary', style=style, layout=Layout(width='auto')) 435 | tf = widgets.Button(description='Augmentation', disabled=True, button_style='danger') 436 | aug_dash.aug = widgets.ToggleButtons(value='No', options=['No', 
'Yes'], description='', button_style='info', 437 | style=style, layout=Layout(width='auto')) 438 | 439 | it = [tb, aug_dash.bi] 440 | it2 = [tg, aug_dash.pad] 441 | it3 = [th, aug_dash.norm] 442 | it4 = [tr, aug_dash.bs] 443 | it5 = [te, aug_dash.imgsiz] 444 | it52 = [to, aug_dash.imgbth] 445 | it6 = [tf, aug_dash.aug] 446 | il = widgets.HBox(it) 447 | ij = widgets.HBox(it2) 448 | ik = widgets.HBox(it3) 449 | ie = widgets.HBox(it4) 450 | iw = widgets.HBox(it5) 451 | ip = widgets.HBox(it52) 452 | iq = widgets.HBox(it6) 453 | ir = widgets.VBox([il, ij, ik, ie, spj, iw, ip, iq]) 454 | display(ir) 455 | display(aug_button) 456 | 457 | aug_out = widgets.Output() 458 | display(aug_out) 459 | def on_aug_button(b): 460 | with aug_out: 461 | clear_output() 462 | aug_choice() 463 | aug_button.on_click(on_aug_button) 464 | 465 | #Helpers for ds_3 466 | def stats_info(): 467 | """Stats helper""" 468 | if aug_dash.norm.value == 'Imagenet': 469 | stats_info.stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 470 | stats_info.code = ('*imagenet_stats') 471 | if aug_dash.norm.value == 'Cifar': 472 | stats_info.stats = ([0.491, 0.482, 0.447], [0.247, 0.243, 0.261]) 473 | stats_info.code = ('*cifar_stats') 474 | if aug_dash.norm.value == 'Mnist': 475 | stats_info.stats = ([0.15, 0.15, 0.15], [0.15, 0.15, 0.15]) 476 | stats_info.code = ('*mnist_stats') 477 | if aug_dash.norm.value == 'None': 478 | stats_info.stats = ([0., 0., 0.], [1., 1., 1.]) 479 | stats_info.code = ('[0., 0., 0.], [1., 1., 1.]') 480 | stats = stats_info.stats 481 | 482 | def repeat_one(source, n=128): 483 | """Single image helper for displaying batch""" 484 | return [get_image_files(ds_choice.source)[9]]*n 485 | 486 | def block_ch(): 487 | """Helper for configuring mid-level datablock""" 488 | if ds_3.d1.value == 'PILImage': 489 | block_ch.cls = PILImage.create 490 | block_ch.code = 'PILImage.create' 491 | else: 492 | block_ch.cls = PILImageBW.create 493 | block_ch.code = 'PILImageBW.create' 494 | if
ds_3.e1.value == 'CategoryBlock': 495 | block_ch.ctg = Categorize 496 | block_ch.ctg_code = 'Categorize' 497 | else: 498 | block_ch.ctg = MultiCategorize 499 | block_ch.ctg_code = 'MultiCategorize' 500 | if ds_3.g.value == True: 501 | block_ch.outputb = parent_label 502 | block_ch.outputb_code = 'parent_label' 503 | else: 504 | block_ch.outputb = None; block_ch.outputb_code = 'None' 505 | if ds_3.s.value == 'RandomSplitter': 506 | block_ch.spl_val = RandomSplitter() 507 | block_ch.spl_val_code = 'RandomSplitter()' 508 | if ds_3.s.value == 'GrandparentSplitter': 509 | block_ch.spl_val = GrandparentSplitter() 510 | block_ch.spl_val_code = 'GrandparentSplitter()' 511 | 512 | def ds_3(): 513 | """GUI for 3rd Accordion window""" 514 | db_button = widgets.Button(description='Confirm') 515 | ds_3.d1 = widgets.ToggleButtons(value=None, options=['PILImage', 'PILImageBW'], description='', button_style='info') 516 | ds_3.e1 = widgets.ToggleButtons(value=None, options=['CategoryBlock', 'MultiCategoryBlock'], description='', button_style='warning') 517 | ds_3.g = widgets.ToggleButton(description='parent_label', button_style='danger', value=False) 518 | ds_3.s = widgets.ToggleButtons(value=None, options=['RandomSplitter', 'GrandparentSplitter'], 519 | description='', button_style='success') 520 | form_items = [ds_3.d1, ds_3.e1, ds_3.s ,ds_3.g] 521 | 522 | form_t = Layout(display='flex', 523 | flex_flow='row', 524 | align_items='stretch', 525 | border='solid 1px', 526 | width='100%') 527 | form = Box(children=form_items, layout=form_t) 528 | display(form) 529 | display(db_button) 530 | db_out = widgets.Output() 531 | display(db_out) 532 | def on_db_click(b): 533 | with db_out: 534 | clear_output() 535 | block_ch() 536 | code_test() 537 | db_button.on_click(on_db_click) 538 | 539 | def code_test(): 540 | """Build the DataLoaders from the dashboard settings and preview a batch""" 541 | db_button2 = widgets.Button(description='DataBlock') 542 | stats_info() 543 | method = ResizeMethod.Pad 544 | 545 | item_size = int(aug_dash.imgsiz.value) 546 | final_size =
int(aug_dash.imgbth.value) 547 | 548 | if aug_dash.bi.value == 'Single': 549 | code_test.items = repeat_one(ds_choice.source) 550 | if aug_dash.bi.value == 'Multi': 551 | code_test.items = get_image_files(ds_choice.source) 552 | if aug_dash.aug.value == 'No': 553 | print(BOLD + BLUE + "working.....: " + RESET + RED + 'No Augmentations\n') 554 | print(BOLD + BLUE + "Multi/Single Image: " + RESET + RED + str(aug_dash.bi.value)) 555 | print(BOLD + BLUE + "Padding: " + RESET + RED + str(aug_dash.pad.value)) 556 | print(BOLD + BLUE + "Normalization: " + RESET + RED + str(stats_info.stats)) 557 | print(BOLD + BLUE + "Batch Size: " + RESET + RED + (aug_dash.bs.value)) 558 | print(BOLD + BLUE + "Item Size: " + RESET + RED + str(item_size)) 559 | print(BOLD + BLUE + "Final Size: " + RESET + RED + str(final_size)) 560 | after_b = [Resize(final_size), IntToFloatTensor(), Normalize.from_stats(*stats_info.stats)] 561 | if aug_dash.aug.value == 'Yes': 562 | print(BOLD + BLUE + "working.....: " + RESET + RED + 'Augmentations\n') 563 | print(BOLD + BLUE + "RandomErasing: " + RESET + RED + 'max_count=' + str(aug.b_max.value) + ' p=' + str(aug.b_pval.value)) 564 | print(BOLD + BLUE + "Contrast: " + RESET + RED + 'max_light=' + str(aug.b1_max.value) + ' p=' + str(aug.b1_pval.value)) 565 | print(BOLD + BLUE + "Rotate: " + RESET + RED + 'max_degree=' + str(aug.b2_max.value) + ' p=' + str(aug.b2_pval.value)) 566 | print(BOLD + BLUE + "Warp: " + RESET + RED + 'magnitude=' + str(aug.b3_mag.value) + ' p=' + str(aug.b3_pval.value)) 567 | print(BOLD + BLUE + "Brightness: " + RESET + RED + 'max_light=' + str(aug.b4_max.value) + ' p=' + str(aug.b4_pval.value)) 568 | print(BOLD + BLUE + "DihedralFlip: " + RESET + RED + ' p=' + str(aug.b5_pval.value) + ' draw=' + str(aug.b5_draw.value)) 569 | print(BOLD + BLUE + "Zoom: " + RESET + RED + 'max_zoom=' + str(aug.b6_zoom.value) + ' p=' + str(aug.b6_pval.value)) 570 | print(BOLD + BLUE + "\nMulti/Single Image: " + RESET + RED + str(aug_dash.bi.value)) 571 | print(BOLD + BLUE + "Padding: " + RESET + RED +
str(aug_dash.pad.value)) 572 | print(BOLD + BLUE + "Normalization: " + RESET + RED + str(stats_info.stats)) 573 | print(BOLD + BLUE + "Batch Size: " + RESET + RED + (aug_dash.bs.value)) 574 | print(BOLD + BLUE + "Item Size: " + RESET + RED + str(item_size)) 575 | print(BOLD + BLUE + "Final Size: " + RESET + RED + str(final_size)) 576 | 577 | xtra_tfms = [RandomErasing(p=aug.b_pval.value, max_count=int(aug.b_max.value), min_aspect=aug.b_asp.value, sl=aug.b_len.value, sh=aug.b_ht.value), #p= probability 578 | Brightness(max_lighting=aug.b4_max.value, p=aug.b4_pval.value, draw=None, batch=None), 579 | Rotate(max_deg=aug.b2_max.value, p=aug.b2_pval.value, draw=None, size=None, mode='bilinear', pad_mode=aug_dash.pad.value), 580 | Warp(magnitude=aug.b3_mag.value,p=aug.b3_pval.value,draw_x=None,draw_y=None,size=None,mode='bilinear',pad_mode=aug_dash.pad.value,batch=False,), 581 | Contrast(max_lighting=aug.b1_max.value, p=aug.b1_pval.value, draw=aug.b1_draw.value, batch=True), #draw = 1 is normal batch=batch tfms or not 582 | Dihedral(p=aug.b5_pval.value, draw=aug.b5_draw.value, size=None, mode='bilinear', pad_mode=PadMode.Reflection, batch=False), 583 | Zoom(max_zoom=aug.b6_zoom.value, p=aug.b6_pval.value, draw=None, draw_x=None, draw_y=None, size=None, mode='bilinear',pad_mode=aug_dash.pad.value, batch=False) 584 | ] 585 | 586 | after_b = [Resize(final_size), IntToFloatTensor(), *aug_transforms(xtra_tfms=xtra_tfms, pad_mode=aug_dash.pad.value), 587 | Normalize.from_stats(*stats_info.stats)] 588 | 589 | if display_ui.tab.selected_index == 2: #>>> Augmentation tab 590 | 591 | tfms = [[PILImage.create], [parent_label, Categorize]] 592 | item_tfms = [ToTensor(), Resize(item_size)] 593 | dsets = Datasets(code_test.items, tfms=tfms) 594 | dls = dsets.dataloaders(after_item=item_tfms, after_batch=after_b, bs=int(aug_dash.bs.value), num_workers=0) 595 | 596 | dls.show_batch(max_n=12, nrows=2, ncols=6) 597 | 598 | if display_ui.tab.selected_index == 3: #>>> DataBlock tab 599 | 600 | items =
get_image_files(ds_choice.source/'train') 601 | split_idx = block_ch.spl_val(items) 602 | tfms = [[block_ch.cls], [block_ch.outputb, block_ch.ctg]] 603 | item_tfms = [ToTensor(), Resize(item_size)] 604 | dsets = Datasets(items, tfms=tfms, splits=split_idx) 605 | dls = dsets.dataloaders(after_item=item_tfms, after_batch=after_b, bs=int(aug_dash.bs.value), num_workers=0) 606 | 607 | display(db_button2) 608 | db_out = widgets.Output() 609 | display(db_out) 610 | def on_db_out(b): 611 | with db_out: 612 | clear_output(); xb, yb = dls.one_batch() 613 | print(BOLD + BLUE + "Train: " + RESET + RED + '(' + str(len(dls.train)) + ', ' + str(len(dls.train_ds)) + ') ' + 614 | BOLD + BLUE + "Valid: "+ RESET + RED + '(' + str(len(dls.valid)) + ', ' + str(len(dls.valid_ds)) + ')') 615 | print(BOLD + BLUE + "Input Shape: " + RESET + RED + str(xb.shape)) 616 | print(BOLD + BLUE + "Output Shape: " + RESET + RED + str(yb.shape) + " by " + str(dls.c) + " classes") 617 | dls.show_batch(max_n=12, nrows=2, ncols=6) 618 | db_button2.on_click(on_db_out) 619 | 620 | def write_code(): 621 | """Helper for writing code""" 622 | write_button = widgets.Button(description='Code', button_style = 'success') 623 | display(write_button) 624 | write_out = widgets.Output() 625 | display(write_out) 626 | def on_write_button(b): 627 | with write_out: 628 | clear_output() 629 | print(RED + BOLD + '"""import libraries"""' + RESET) 630 | print(GREEN + BOLD + 'from' + RESET + ' fastai2.vision.all ' + GREEN + BOLD + 'import*' + RESET) 631 | print(RED + BOLD + '\n"""get data source and image files from source"""' + RESET) 632 | print('source = untar_data(URLs.'
+ str(dashboard_two.datas.value) + ')') 633 | print('items = get_image_files(source)') 634 | print(RED + BOLD + '\n"""get item, split and batch transforms"""' + RESET) 635 | print('tfms = [[' + str(block_ch.code) + ']' + ', ' + '[' + str(block_ch.outputb_code) + ', ' 636 | + str(block_ch.ctg_code) + ']]') 637 | print('item_tfms = [ToTensor(), Resize(' + GREEN + str(aug_dash.imgsiz.value) + RESET + ')]') 638 | print('split_idx = ' + str(block_ch.spl_val_code) + '(items)') 639 | print(RED + BOLD + '\n"""image augmentations"""' + RESET) 640 | if aug_dash.aug.value == 'No': 641 | print('xtra_tfms = '+ GREEN + BOLD + 'None' + RESET) 642 | if aug_dash.aug.value == 'Yes': 643 | print(RESET + 'xtra_tfms = [RandomErasing(p=' + GREEN + str(aug.b_pval.value) + RESET + ', max_count=' + 644 | GREEN + str(aug.b_max.value) + RESET + ', min_aspect=' + GREEN + str(aug.b_asp.value) + RESET + ', sl=' + 645 | GREEN + str(aug.b_len.value) + RESET + ', sh=' + RESET + GREEN + str(aug.b_ht.value) + RESET + '),') 646 | print(RESET + ' Brightness(max_lighting=' + GREEN + str(aug.b4_max.value) + RESET + ', p=' + GREEN + 647 | str(aug.b4_pval.value) + RESET + ', draw=' + GREEN + BOLD + 'None' + RESET + ', batch=' + GREEN + BOLD + 'None' + 648 | RESET + '),') 649 | print(RESET + ' Rotate(max_deg=' + GREEN + str(aug.b2_max.value) + RESET + ', p=' + GREEN + 650 | str(aug.b2_pval.value) + RESET + ', draw=' + GREEN + BOLD + 'None' + RESET + ', size=' + GREEN + BOLD + 'None' + 651 | RESET + ', mode=' + RED + "'bilinear'" + RESET + ', pad_mode=' + RED + "'" + str(aug_dash.pad.value) + "'" + RESET + '),') 652 | print(RESET + ' Warp(magnitude=' + GREEN + str(aug.b3_mag.value) + RESET + ', p=' + GREEN + 653 | str(aug.b3_pval.value) + RESET + ', draw_x=' + GREEN + BOLD + 'None' + RESET + ', draw_y=' + GREEN + BOLD + 654 | 'None' + RESET + ', size=' + GREEN + BOLD + 'None' + RESET + ', mode=' + RED + "'bilinear'" + RESET + ', pad_mode=' + 655 | RED + "'" + str(aug_dash.pad.value) + "'" + RESET + ', batch=' + GREEN + BOLD +
                      'False' + RESET + '),')
                print(RESET + '             Contrast(max_lighting=' + GREEN + str(aug.b1_max.value) + RESET + ', p=' + GREEN +
                      str(aug.b1_pval.value) + RESET + ', draw=' + GREEN + str(aug.b1_draw.value) + RESET + ', batch=' + GREEN + BOLD +
                      'True' + RESET + '),')
                print(RESET + '             Dihedral(p=' + GREEN + str(aug.b5_pval.value) + RESET + ', draw=' + GREEN +
                      str(aug.b5_draw.value) + RESET + ', size=' + GREEN + BOLD + 'None' + RESET + ', mode=' + RED + "'bilinear'" + RESET +
                      ', pad_mode=' + RED + "'" + str(aug_dash.pad.value) + "'" + RESET + ', batch=' + GREEN + BOLD + 'False' + RESET + '),')
                print(RESET + '             Zoom(max_zoom=' + GREEN + str(aug.b6_zoom.value) + RESET + ', p=' + GREEN +
                      str(aug.b6_pval.value) + RESET + ', draw=' + GREEN + BOLD + 'None' + RESET + ', draw_x=' + GREEN + BOLD + 'None' + RESET +
                      ', draw_y=' + GREEN + BOLD + 'None' + RESET + ', size=' + GREEN + BOLD + 'None' + RESET + ', mode=' + RED + "'bilinear'" +
                      RESET + ', pad_mode=' + RED + "'" + str(aug_dash.pad.value) + "'" + RESET + ', batch=' + GREEN + BOLD + 'False' + RESET + ')]')

            print('\nafter_b = [Resize(' + GREEN + str(aug_dash.imgbth.value) + RESET + '), IntToFloatTensor(), ' + '\n           ' + "*aug_transforms(xtra_tfms=xtra_tfms, pad_mode="
                  + RED + "'" + str(aug_dash.pad.value) + "'" + RESET + '),' + ' Normalize.from_stats(' + GREEN + str(stats_info.code) + RESET + ')]')

            print('\ndsets = Datasets(items, tfms=tfms, splits=split_idx)')
            print('dls = dsets.dataloaders(after_item=item_tfms, after_batch=after_b, bs=' + GREEN + str(aug_dash.bs.value) + RESET +
                  ', num_workers=' + GREEN + '0' + RESET + ')')


            print(RED + BOLD + '\n"""Check training and valid shapes"""' + RESET)
            print('xb, yb = dls.one_batch()')
            print(RESET + 'dls.train' + RESET + RED + BOLD + ' #train')
            print(RESET + 'dls.train_ds' + RESET + RED + BOLD + ' #train_ds')
            print(RESET + 'dls.valid' + RESET + RED + BOLD + ' #valid')
            print(RESET + 'dls.valid_ds' + RESET + RED + BOLD + ' #valid_ds')

            print(RESET + RED + BOLD + '\n"""show batch"""' + RESET)
            print('dls.show_batch(max_n=' + GREEN + '12' + RESET + ', nrows=' + GREEN + '2' + RESET + ', ncols=' +
                  GREEN + '6' + RESET + ')')
            print(RESET + RED + BOLD + '\n"""train"""' + RESET)
            print('arch = xresnet50(pretrained=' + GREEN + BOLD + 'False' + RESET + ')')
            print('learn = Learner(dls, model=arch, loss_func=LabelSmoothingCrossEntropy(),' +
                  '\n                metrics=[top_k_accuracy, accuracy])')
            print(RESET + 'learn.fit_one_cycle(' + GREEN + '1' + RESET + ', ' + GREEN + '1e-2' + RESET + ')')
            print(RESET + RED + BOLD + '\n"""interpretations"""' + RESET)
            print('interp = ClassificationInterpretation.from_learner(learn)')
            print('losses, idxs = interp.top_losses()')
            print(GREEN + 'len' + RESET + '(dls.valid_ds)' + RESET + '==' + GREEN + 'len' + RESET + '(losses)==' +
                  GREEN + 'len' + RESET + '(idxs)')
            print(RESET + RED + BOLD + '\n"""confusion matrix"""' + RESET)
            print('interp.plot_confusion_matrix(figsize=(' + GREEN + '7' + RESET + ',' + GREEN + '7' + RESET + '))')
    write_button.on_click(on_write_button)

def play_info():
    """Show the current settings, preview an ImageWoof-160 batch and set up the Play button"""

    item_size = int(aug_dash.imgsiz.value)
    final_size = int(aug_dash.imgbth.value)

    if aug_dash.aug.value == 'No':
        print(BOLD + BLUE + "Working.....: " + RESET + RED + 'No Augmentations\n')
        print(BOLD + BLUE + "Multi/Single Image: " + RESET + RED + str(aug_dash.bi.value))
        print(BOLD + BLUE + "Padding: " + RESET + RED + str(aug_dash.pad.value))
        print(BOLD + BLUE + "Normalization: " + RESET + RED + str(stats_info.stats))
        print(BOLD + BLUE + "Batch Size: " + RESET + RED + str(aug_dash.bs.value))
        print(BOLD + BLUE + "Item Size: " + RESET + RED + str(item_size))
        print(BOLD + BLUE + "Final Size: " + RESET + RED + str(final_size))
        after_b = None

    if aug_dash.aug.value == 'Yes':
        print(BOLD + BLUE + 'Loading ImageWoof-160\n' + RESET)
        print(BOLD + BLUE + "Current Augmentations:" + RESET)
        print(BOLD + BLUE + "RandomErasing: " + RESET + RED + 'max_count=' + str(aug.b_max.value) + ' p=' + str(aug.b_pval.value))
        print(BOLD + BLUE + "Contrast: " + RESET + RED + 'max_light=' + str(aug.b1_max.value) + ' p=' + str(aug.b1_pval.value))
        print(BOLD + BLUE + "Rotate: " + RESET + RED + 'max_degree=' + str(aug.b2_max.value) + ' p=' + str(aug.b2_pval.value))
        print(BOLD + BLUE + "Warp: " + RESET + RED + 'magnitude=' + str(aug.b3_mag.value) + ' p=' + str(aug.b3_pval.value))
        print(BOLD + BLUE + "Brightness: " + RESET + RED + 'max_light=' + str(aug.b4_max.value) + ' p=' + str(aug.b4_pval.value))
        print(BOLD + BLUE + "DihedralFlip: " + RESET + RED + 'p=' + str(aug.b5_pval.value) + ' draw=' + str(aug.b5_draw.value))
        print(BOLD + BLUE + "Zoom: " + RESET + RED + 'max_zoom=' + str(aug.b6_zoom.value) + ' p=' + str(aug.b6_pval.value))
        print(BOLD + BLUE + "\nMulti/Single Image: " + RESET + RED + str(aug_dash.bi.value))
        print(BOLD + BLUE + "Padding: " + RESET + RED + str(aug_dash.pad.value))
        print(BOLD + BLUE + "Normalization: " + RESET + RED + str(stats_info.stats))
        print(BOLD + BLUE + "Batch Size: " + RESET + RED + str(aug_dash.bs.value))
        print(BOLD + BLUE + "Item Size: " + RESET + RED + str(item_size))
        print(BOLD + BLUE + "Final Size: " + RESET + RED + str(final_size))

        xtra_tfms = [RandomErasing(p=aug.b_pval.value, max_count=aug.b_max.value, min_aspect=aug.b_asp.value, sl=aug.b_len.value, sh=aug.b_ht.value), # p = probability
                     Brightness(max_lighting=aug.b4_max.value, p=aug.b4_pval.value, draw=None, batch=None),
                     Rotate(max_deg=aug.b2_max.value, p=aug.b2_pval.value, draw=None, size=None, mode='bilinear', pad_mode=aug_dash.pad.value),
                     # Warp and Dihedral follow the user-selected padding, like Rotate and Zoom
                     Warp(magnitude=aug.b3_mag.value, p=aug.b3_pval.value, draw_x=None, draw_y=None, size=None, mode='bilinear', pad_mode=aug_dash.pad.value, batch=False),
                     Contrast(max_lighting=aug.b1_max.value, p=aug.b1_pval.value, draw=aug.b1_draw.value, batch=True), # draw=1 leaves contrast unchanged; batch=True draws one value for the whole batch
                     Dihedral(p=aug.b5_pval.value, draw=aug.b5_draw.value, size=None, mode='bilinear', pad_mode=aug_dash.pad.value, batch=False),
                     Zoom(max_zoom=aug.b6_zoom.value, p=aug.b6_pval.value, draw=None, draw_x=None, draw_y=None, size=None, mode='bilinear', pad_mode=aug_dash.pad.value, batch=False)
                    ]
        after_b = [Resize(final_size), IntToFloatTensor(), *aug_transforms(xtra_tfms=xtra_tfms, pad_mode=aug_dash.pad.value),
                   Normalize(stats_info.stats)]

    source_play = untar_data(URLs.IMAGEWOOF_160)
    items = get_image_files(source_play/'train')

    tfms = [[PILImage.create], [parent_label, Categorize]]
    item_tfms = [ToTensor(), Resize(item_size)]
    dsets = Datasets(items, tfms=tfms)
    dls = dsets.dataloaders(after_item=item_tfms, after_batch=after_b, bs=int(aug_dash.bs.value), num_workers=0)

    dls.show_batch(max_n=6, nrows=1, ncols=6)
    imagewoof_plaz()

def imagewoof_plaz():
    """Train on ImageWoof-160 with the current settings and show interpretation plots"""

    item_size = int(aug_dash.imgsiz.value)
    final_size = int(aug_dash.imgbth.value)

    button_t2 = widgets.Button(description='Play')
    button_res = widgets.Button(description='View')
    res = widgets.ToggleButtons(value=None, options=['Confusion Matrix', 'Most Confused', 'Top Losses'], description='', button_style='info')
    play_out = widgets.Output()
    display(button_t2)
    display(play_out)
    def on_play_button(b):
        with play_out:
            clear_output()
            print('Training.....')
            if aug_dash.aug.value == 'No':
                after_b = None
            if aug_dash.aug.value == 'Yes':

                xtra_tfms = [RandomErasing(p=aug.b_pval.value, max_count=aug.b_max.value, min_aspect=aug.b_asp.value,
                                           sl=aug.b_len.value, sh=aug.b_ht.value), # p = probability
                             Brightness(max_lighting=aug.b4_max.value, p=aug.b4_pval.value, draw=None, batch=None),
                             Rotate(max_deg=aug.b2_max.value, p=aug.b2_pval.value, draw=None, size=None, mode='bilinear', pad_mode=aug_dash.pad.value),
                             # Warp, Dihedral and Zoom follow the user-selected padding, like Rotate
                             Warp(magnitude=aug.b3_mag.value, p=aug.b3_pval.value, draw_x=None, draw_y=None, size=None, mode='bilinear', pad_mode=aug_dash.pad.value, batch=False),
                             Contrast(max_lighting=aug.b1_max.value, p=aug.b1_pval.value, draw=aug.b1_draw.value, batch=True), # draw=1 leaves contrast unchanged; batch=True draws one value for the whole batch
                             Dihedral(p=aug.b5_pval.value, draw=aug.b5_draw.value, size=None, mode='bilinear', pad_mode=aug_dash.pad.value, batch=False),
                             Zoom(max_zoom=aug.b6_zoom.value, p=aug.b6_pval.value, draw=None, draw_x=None, draw_y=None, size=None, mode='bilinear', pad_mode=aug_dash.pad.value, batch=False)
                            ]

                after_b = [Resize(final_size), IntToFloatTensor(), *aug_transforms(xtra_tfms=xtra_tfms, pad_mode=aug_dash.pad.value),
                           Normalize(stats_info.stats)]

            source_play = untar_data(URLs.IMAGEWOOF_160)
            items = get_image_files(source_play/'train')

            split_idx = block_ch.spl_val(items)
            tfms = [[block_ch.cls], [block_ch.outputb, block_ch.ctg]]
            item_tfms = [ToTensor(), Resize(item_size)]
            dsets = Datasets(items, tfms=tfms, splits=split_idx)

            dls = dsets.dataloaders(after_item=item_tfms, after_batch=after_b, bs=int(aug_dash.bs.value), num_workers=0)
            arch = xresnet50(pretrained=False)
            #learn = Learner(dls, model=arch, loss_func=LabelSmoothingCrossEntropy(),
            #                metrics=[top_k_accuracy, accuracy])
            learn = cnn_learner(dls, xresnet50, loss_func=LabelSmoothingCrossEntropy(), metrics=[top_k_accuracy, accuracy])
            learn.fine_tune(1)
            print('Getting Interpretations....')
            interp = ClassificationInterpretation.from_learner(learn)

            losses, idxs = interp.top_losses()
            assert len(dls.valid_ds) == len(losses) == len(idxs)  # top_losses() returns one loss per validation item
            #display(play_opt)
            display(res)
            display(button_res)
            res_out = widgets.Output()
            display(res_out)
            def on_res_button(b):
                with res_out:
                    clear_output()
                    if res.value == 'Confusion Matrix':
                        interp.plot_confusion_matrix(figsize=(7,7))
                    if res.value == 'Most Confused':
                        print(interp.most_confused(min_val=1))
                    if res.value == 'Top Losses':
                        interp.plot_top_losses(9, figsize=(7,7))
            button_res.on_click(on_res_button)

    button_t2.on_click(on_play_button)

def display_ui():
    """Display the Vision UI tabs: Info, Data, Augmentation, DataBlock, Code and ImageWoof Play"""
    button = widgets.Button(description="Train")
    button_b = widgets.Button(description="Metrics")
    button_m = widgets.Button(description='Model')
    button_l = widgets.Button(description='LR')

    test_button = widgets.Button(description='Batch')
    test2_button = widgets.Button(description='Test2')

    out1a = widgets.Output()
    out1 = widgets.Output()
    out2 = widgets.Output()
    out3 = widgets.Output()
    out4 = widgets.Output()
    out5 = widgets.Output()

    data1a = pd.DataFrame(np.random.normal(size = 50))
    data1 = pd.DataFrame(np.random.normal(size = 100))
    data2 = pd.DataFrame(np.random.normal(size = 150))
    data3 = pd.DataFrame(np.random.normal(size = 200))
    data4 = pd.DataFrame(np.random.normal(size = 250))
    data5 = pd.DataFrame(np.random.normal(size = 300))

    with out1a: #info
        clear_output()
        dashboard_one()

    with out1: #data
        clear_output()
        dashboard_two()

    with out2: #augmentation
        clear_output()
        aug_dash()

    with out3: #Block
        clear_output()
        ds_3()

    with out4: #code
        clear_output()
        write_code()

    with out5: #Imagewoof Play
        clear_output()
        print(BOLD + BLUE + 'Work in progress.....')
        play_button = widgets.Button(description='Parameters')
        display(play_button)
        play_out = widgets.Output()
        display(play_out)
        def button_play(b):
            with play_out:
                clear_output()
                play_info()
        play_button.on_click(button_play)

    display_ui.tab = widgets.Tab(children = [out1a, out1, out2, out3, out4, out5])
    display_ui.tab.set_title(0, 'Info')
    display_ui.tab.set_title(1, 'Data')
    display_ui.tab.set_title(2, 'Augmentation')
    display_ui.tab.set_title(3, 'DataBlock')
    display_ui.tab.set_title(4, 'Code')
    display_ui.tab.set_title(5, 'ImageWoof Play')
    display(display_ui.tab)
--------------------------------------------------------------------------------
/xresnet2.py:
--------------------------------------------------------------------------------
#from https://github.com/fastai/fastai/blob/master/fastai/vision/models/xresnet2.py

import torch.nn as nn
import torch
import math
import torch.utils.model_zoo as model_zoo
from torch.nn import Module #changed from torch_core


__all__ = ['XResNet', 'xresnet18', 'xresnet34_2', 'xresnet50_2', 'xresnet101', 'xresnet152']


def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None: residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None: residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

def conv2d(ni, nf, stride):
    return nn.Sequential(nn.Conv2d(ni, nf, kernel_size=3, stride=stride, padding=1, bias=False),
                         nn.BatchNorm2d(nf), nn.ReLU(inplace=True))

class XResNet(Module):

    def __init__(self, block, layers, c_out=1000):
        self.inplanes = 64
        super(XResNet, self).__init__()
        self.conv1 = conv2d(3, 32, 2)
        self.conv2 = conv2d(32, 32, 1)
        self.conv3 = conv2d(32, 64, 1)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3],
                                       stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, c_out)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        for m in self.modules():
            if isinstance(m, BasicBlock): m.bn2.weight = nn.Parameter(torch.zeros_like(m.bn2.weight))
            if isinstance(m, Bottleneck): m.bn3.weight = nn.Parameter(torch.zeros_like(m.bn3.weight))
            if isinstance(m, nn.Linear): m.weight.data.normal_(0, 0.01)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            layers = []
            if stride==2: layers.append(nn.AvgPool2d(kernel_size=2, stride=2))
            layers += [
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(planes * block.expansion) ]
            downsample = nn.Sequential(*layers)

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks): layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def xresnet18(pretrained=False, **kwargs):
    """Constructs an XResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = XResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['xresnet18']))
    return model


def xresnet34_2(pretrained=False, **kwargs):
    """Constructs an XResNet-34 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = XResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['xresnet34']))
    return model


def xresnet50_2(pretrained=False, **kwargs):
    """Constructs an XResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = XResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['xresnet50']))
    return model


def xresnet101(pretrained=False, **kwargs):
    """Constructs an XResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = XResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['xresnet101']))
    return model


def xresnet152(pretrained=False, **kwargs):
    """Constructs an XResNet-152 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = XResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['xresnet152']))
    return model

--------------------------------------------------------------------------------
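In vision_ui2.py, `write_code` builds the code it prints by concatenating strings with ANSI color codes. That approach makes it easy to drop a comma between list items or to leave a string argument such as `pad_mode` unquoted. Below is a minimal sketch of an alternative. It assumes the widget values are already Python numbers, booleans, `None` or strings. The hypothetical helpers `render_call` and `render_list` are not part of the repository. They format every argument with `repr()`, comma-separate the list items, and check the result with `ast.parse`. The transform names and values are illustrative only.

```python
import ast


def render_call(name, **kwargs):
    """Render `name(key=value, ...)`, formatting each value with repr()."""
    args = ", ".join(f"{key}={value!r}" for key, value in kwargs.items())
    return f"{name}({args})"


def render_list(var, calls):
    """Render `var = [call, ...]` with a comma and an aligned line break between items."""
    prefix = f"{var} = ["
    return prefix + (",\n" + " " * len(prefix)).join(calls) + "]"


# Illustrative values standing in for widget settings such as aug.b2_max.value
code = render_list("xtra_tfms", [
    render_call("Rotate", max_deg=10, p=0.5, mode="bilinear", pad_mode="reflection"),
    render_call("Zoom", max_zoom=1.1, p=0.5, draw=None, batch=False),
])
ast.parse(code)  # raises SyntaxError if the generated source is not valid Python
print(code)
```

`ast.parse` runs on the uncolored text because ANSI escape codes are not valid Python. Any coloring would therefore have to be applied after validation.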
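The stem and stage layout of `XResNet` in xresnet2.py determines the feature-map size that reaches `AdaptiveAvgPool2d(1)`. The sketch below traces that size with a hypothetical pure-Python helper, which is not part of the repository. It applies the standard convolution output-size formula `(n + 2*padding - kernel) // stride + 1` to three parts of the network:

- the three stem convolutions,
- the max-pool,
- the first block of each stage, whose 3x3 convolution carries the stride.

Here `expansion=1` corresponds to `BasicBlock` and `expansion=4` to `Bottleneck`.

```python
def conv_out(n, kernel, stride, padding):
    """Output length along one spatial dimension of a conv or pooling layer."""
    return (n + 2 * padding - kernel) // stride + 1


def xresnet_shapes(size=224, expansion=1):
    """Trace (stage, channels, spatial size) through XResNet for a square input.

    expansion=1 mirrors BasicBlock (xresnet18, xresnet34_2);
    expansion=4 mirrors Bottleneck (xresnet50_2, xresnet101, xresnet152).
    """
    trace = []
    # Stem: conv2d(3, 32, 2), conv2d(32, 32, 1), conv2d(32, 64, 1); all 3x3 with padding=1
    for stride in (2, 1, 1):
        size = conv_out(size, kernel=3, stride=stride, padding=1)
    trace.append(("stem", 64, size))
    # nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    size = conv_out(size, kernel=3, stride=2, padding=1)
    trace.append(("maxpool", 64, size))
    # _make_layer: only the first block of a stage is strided (via its 3x3 conv);
    # for even sizes the shortcut's AvgPool2d(2, 2) produces the same size
    for name, planes, stride in (("layer1", 64, 1), ("layer2", 128, 2),
                                 ("layer3", 256, 2), ("layer4", 512, 2)):
        size = conv_out(size, kernel=3, stride=stride, padding=1)
        trace.append((name, planes * expansion, size))
    # nn.AdaptiveAvgPool2d(1) leaves a 1x1 map; fc takes 512 * expansion features
    trace.append(("avgpool", 512 * expansion, 1))
    return trace


if __name__ == "__main__":
    for stage in xresnet_shapes(224, expansion=4):
        print(stage)
```

Take a 224-pixel input with `expansion=4`, as in `xresnet50_2`. The trace gives 112 after the stem and 56 after the max-pool, then 56, 28, 14 and 7 through `layer1` to `layer4`. At that point 2048 channels feed `fc`.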