├── README.md
├── Visual_UI.ipynb
├── Visual_UI2.ipynb
├── _config.yml
├── environment.yml
├── paperspace_ui.py
├── static
│   ├── CM_FN.PNG
│   ├── CM_FP.PNG
│   ├── CM_TN.PNG
│   ├── CM_TP.PNG
│   ├── CM_eight.PNG
│   ├── CM_five.PNG
│   ├── CM_four.PNG
│   ├── CM_nine.PNG
│   ├── CM_one.PNG
│   ├── CM_seven.PNG
│   ├── CM_six.PNG
│   ├── CM_three.PNG
│   ├── CM_two.PNG
│   ├── LR.PNG
│   ├── LR_one.PNG
│   ├── LR_three.PNG
│   ├── LR_two.PNG
│   ├── Lr_four.PNG
│   ├── aug_one.PNG
│   ├── aug_one2.PNG
│   ├── aug_three.PNG
│   ├── aug_three2.PNG
│   ├── aug_three3.PNG
│   ├── aug_two.PNG
│   ├── aug_two2.PNG
│   ├── aug_two3.PNG
│   ├── batch.PNG
│   ├── batch_three.PNG
│   ├── batch_two.PNG
│   ├── cm_class.PNG
│   ├── data.PNG
│   ├── data2.PNG
│   ├── heatmap3.PNG
│   ├── info.PNG
│   ├── info_dashboard.PNG
│   ├── metrics.PNG
│   ├── model.PNG
│   ├── visionUI2_part1.gif
│   ├── visionUI2_part2.gif
│   └── visionUI2_part3.gif
├── viola_test.ipynb
├── vision_ui.py
├── vision_ui2.py
└── xresnet2.py
/README.md:
--------------------------------------------------------------------------------
1 | ### Vision_UI
2 | Graphical User interface for fastai
3 |
4 | [](https://github.com/Naereen/StrapDown.js/blob/master/LICENSE) 
5 |
6 | [](https://colab.research.google.com/drive/1O_H41XhABAEQxg_p8KZd_BCQ8pj-eJX6) currently only works with **version 1**
7 |
8 |
9 | 
10 |
11 | Visual UI adds a graphical interface to fastai, allowing the user to quickly load data, choose parameters, train and view results without the need to dig deep into the code.
12 |
13 | ________________________________________________________________________________________________________________________________________
14 |
15 | ### Updates
16 |
17 | #### 06/03/2020
18 | - Can now be `pip` installed [fast-gui](https://pypi.org/project/fast-gui/)
19 |
20 | #### 03/17/2020
21 | - Update for compatibility with [fastai2](https://github.com/fastai/fastai2)
22 | - Files: `Visual_UI2.ipynb` and `vision_ui2.py`
23 |
24 | #### Updates below are for version 1
25 | Files: `Visual_UI.ipynb` and `vision_ui.py`
26 |
27 | #### 12/23/2019
28 | - Inclusion of ImageDataBunch.from_csv
29 | - Additional augmentations included [cutout, jitter, contrast, brightness, rotate, symmetric warp, padding]
30 | - Inclusion of ClassConfusion widget
31 | - Addition of 'Code' tab to view code
32 |
33 | #### 11/12/2019
34 | - Under the 'Info' tab you can now easily download some common datasets: Cats&Dogs, Imagenette, Imagewoof, Cifar and Mnist
35 |
36 |
37 |
38 |
39 | - Under the 'Results' tab, the confusion matrix upgrades will not work if there are more than 2 classes, but you can now view the standard confusion matrix
40 |
41 |
42 |
43 |
44 | #### 10/12/2019 - [](https://colab.research.google.com/drive/1O_H41XhABAEQxg_p8KZd_BCQ8pj-eJX6)
45 | - Works with Google Colab (https://github.com/asvcode/Colab_UI) - Results tab not currently available in Colab
46 |
47 | #### 09/25/2019 - xresnet architecture
48 | - xresnet architectures now working (using xresnet2.py from fastai)
49 |
50 | #### 09/12/2019 - Confusion Matrix Upgrades (currently only works if there are 2 classes)
51 | - Under the Results tab, the confusion matrix tab now includes enhanced viewing features:
52 |
53 | > Option to view images with heatmaps or not
54 |
55 |
56 |
57 |
58 | > Option to view images within each section of the matrix
59 |
60 |
61 |
62 |
63 | > If the heatmap option is 'YES' you can choose the colormap, interpolation and alpha parameters
64 |
65 |
66 |
67 |
68 | > Examples of using different parameters for viewing images
69 |
70 | 
71 |
72 |
73 | > You also have the option to view the images without the heatmap feature. Images within each matrix section display Index, Actual_Class, Predicted_Class, Prediction value, Loss and Image location
74 |
75 |
76 |
77 |
78 | > Images are stored within the path folder under their respective confusion matrix tags
79 |
80 | > View saved image files from various sections of the confusion matrix and compare their heatmap images.
81 |
82 |
83 | False Positive
84 |
85 | True Positive
86 |
87 | True Negative
88 |
89 | False Negative
90 |
91 |
92 |
93 |
94 |
95 | #### 07/09/2019
96 | - After a training run, the model is saved in the models folder with the following name: 'architecture' + 'pretrained' + batch size + image size, e.g. resnet50_pretrained_True_batch_32_image_128.pth
97 | - Updated the tkinter askdirectory code: the tkinter dialog box is now destroyed after choosing a file; previously the box would remain open
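The naming convention above can be sketched as follows (a minimal illustration; the variable names here are hypothetical, not taken from the project's code):

```python
# Sketch of the saved-model naming convention described above.
arch = 'resnet50'     # architecture name
pretrained = True     # whether pretrained weights were used
batch_size = 32
image_size = 128

save_name = f'{arch}_pretrained_{pretrained}_batch_{batch_size}_image_{image_size}'
# learn.save(save_name) would then write models/resnet50_pretrained_True_batch_32_image_128.pth
print(save_name)
```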
98 |
99 | #### 06/05/2019
100 | - results tab added where you can load your saved model and plot multi_plot_losses, top_losses and Confusion_matrix
101 |
102 | #### 06/03/2019
103 | - path and image_path (for augmentations) are now set within vision_ui, so there is no need for a separate cell to specify the path
104 | - Included links to the fastai docs and forum in the 'Info' tab
105 |
106 | ________________________________________________________________________________________________________________________________________
107 |
108 |
109 |
110 | All tabs are provided within an accordion design using ipywidgets; this keeps all aspects of choosing and viewing parameters in one line of sight
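The tabbed layout can be sketched with ipywidgets like this (a minimal sketch, not the project's actual code; the tab titles beyond 'Info' and 'Data' are illustrative):

```python
import ipywidgets as widgets

# Each tab holds an Output widget that the dashboard renders into,
# mirroring the Tab(children=(Output(), ...)) structure shown by display_ui().
titles = ['Info', 'Data', 'Augmentation', 'Batch', 'Model', 'Metrics']
outputs = [widgets.Output() for _ in titles]

tab = widgets.Tab(children=outputs)
for i, title in enumerate(titles):
    tab.set_title(i, title)
# display(tab)  # in a notebook, this renders the tabbed UI
```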
111 |
112 |
113 | 
114 |
115 |
116 | The Augmentation tab utilizes fastai parameters so you can view what different image augmentations look like and compare them
117 |
118 |
119 | 
120 |
121 |
122 | View batch information
123 |
124 |
125 |
126 |
127 |
128 | Review model data and choose suitable metrics for training
129 |
130 |
131 | 
132 |
133 |
134 | Review parameters, get the learning rate and train using the one cycle policy
135 |
136 |
137 | 
138 |
139 |
140 | Experiment with various learning rates and train
141 |
142 |
143 | 
144 |
145 |
146 |
147 |
148 | ### Requirements
149 | - fastai
150 |
151 | I am using the developer version:
152 |
153 |
154 |
155 |
156 |
157 |
158 | `git clone https://github.com/fastai/fastai`
159 |
160 | `cd fastai`
161 |
162 | `tools/run-after-git-clone`
163 |
164 | `pip install -e ".[dev]"`
165 |
166 | for installation instructions visit [Fastai Installation Readme](https://github.com/fastai/fastai/blob/master/README.md#installation)
167 |
168 | - ipywidgets
169 |
170 | `pip install ipywidgets`
171 | `jupyter nbextension enable --py widgetsnbextension`
172 |
173 | or
174 |
175 | `conda install -c conda-forge ipywidgets`
176 |
177 | for installation instructions visit [Installation docs](https://ipywidgets.readthedocs.io/en/stable/user_install.html)
178 |
179 | - psutil
180 |
181 | psutil (process and system utilities) is a cross-platform library for retrieving information on running processes and system utilization (CPU, memory, disks, network, sensors) in Python
182 |
183 | `pip install psutil`
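The system stats shown under the 'Info' tab come from calls like these (a minimal sketch of the psutil calls used in `paperspace_ui.py`):

```python
import psutil

# Instantaneous CPU utilization as a percentage.
print('CPU usage %:', psutil.cpu_percent())

# Physical memory usage; .percent is the same value as psutil.virtual_memory()[2].
vm = psutil.virtual_memory()
print('memory % used:', vm.percent)
```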
184 |
185 |
186 | ### Installation
187 |
188 | Clone this repository:
189 |
190 | `git clone https://github.com/asvcode/Vision_UI.git`
191 |
192 | open `Visual_UI.ipynb` and run `display_ui()`
193 |
194 |
195 | ### Known Issues
196 |
197 | - **Colab** - version 1 works with colab [Colab_UI](https://github.com/asvcode/Colab_UI) but is glitchy
198 |
199 | ### Future Work
200 |
201 | - ~~Integrate into fastai v2~~ - Compatibility with fastai v2 is done, but not with the full functionality of the v1 version
202 | - ~~Create pip install version~~ - Done! [fast-gui](https://pypi.org/project/fast-gui/)
203 | - Include full functionality with v2 version
204 | - Create a v2 that is compatible with Colab
205 |
--------------------------------------------------------------------------------
/Visual_UI.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "##########################\n",
10 | "## Visual UI for fastai ##\n",
11 | "##########################\n",
12 | "\n",
13 | "from vision_ui import *"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 2,
19 | "metadata": {},
20 | "outputs": [
21 | {
22 | "data": {
23 | "application/vnd.jupyter.widget-view+json": {
24 | "model_id": "6639f7fa54964e988ad176392b0df4cb",
25 | "version_major": 2,
26 | "version_minor": 0
27 | },
28 | "text/plain": [
29 | "Tab(children=(Output(), Output(), Output(), Output(), Output(), Output(), Output(), Output(), Output()), _titl…"
30 | ]
31 | },
32 | "metadata": {},
33 | "output_type": "display_data"
34 | }
35 | ],
36 | "source": [
37 | "display_ui()"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": null,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": []
46 | }
47 | ],
48 | "metadata": {
49 | "kernelspec": {
50 | "display_name": "Python 3",
51 | "language": "python",
52 | "name": "python3"
53 | },
54 | "language_info": {
55 | "codemirror_mode": {
56 | "name": "ipython",
57 | "version": 3
58 | },
59 | "file_extension": ".py",
60 | "mimetype": "text/x-python",
61 | "name": "python",
62 | "nbconvert_exporter": "python",
63 | "pygments_lexer": "ipython3",
64 | "version": "3.7.2"
65 | }
66 | },
67 | "nbformat": 4,
68 | "nbformat_minor": 2
69 | }
70 |
--------------------------------------------------------------------------------
/Visual_UI2.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 6,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "### The 2 cells below are scripts for improving display appearance"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 7,
15 | "metadata": {},
16 | "outputs": [
17 | {
18 | "data": {
19 | "application/javascript": [
20 | "require(\n",
21 | " ['notebook/js/outputarea'],\n",
22 | " function(oa) {\n",
23 | " oa.OutputArea.auto_scroll_threshold = -1;\n",
24 | " console.log(\"Setting auto to -1\")\n",
25 | " });\n"
26 | ],
27 | "text/plain": [
28 | ""
29 | ]
30 | },
31 | "metadata": {},
32 | "output_type": "display_data"
33 | }
34 | ],
35 | "source": [
36 | "%%javascript\n",
37 | "require(\n",
38 | " ['notebook/js/outputarea'],\n",
39 | " function(oa) {\n",
40 | " oa.OutputArea.auto_scroll_threshold = -1;\n",
41 | " console.log(\"Setting auto to -1\")\n",
42 | " });"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": 8,
48 | "metadata": {},
49 | "outputs": [
50 | {
51 | "data": {
52 | "application/javascript": [
53 | "(function(on) {\n",
54 | "const e=$( \"Setup failed\" );\n",
55 | "const ns=\"js_jupyter_suppress_warnings\";\n",
56 | "var cssrules=$(\"#\"+ns);\n",
57 | "if(!cssrules.length) cssrules = $(\"\").appendTo(\"head\");\n",
58 | "e.click(function() {\n",
59 | " var s='Showing'; \n",
60 | " cssrules.empty()\n",
61 | " if(on) {\n",
62 | " s='Hiding';\n",
63 | " cssrules.append(\"div.output_stderr, div[data-mime-type*='.stderr'] { display:none; }\");\n",
64 | " }\n",
65 | " e.text(s+' warnings (click to toggle)');\n",
66 | " on=!on;\n",
67 | "}).click();\n",
68 | "$(element).append(e);\n",
69 | "})(true);\n"
70 | ],
71 | "text/plain": [
72 | ""
73 | ]
74 | },
75 | "metadata": {},
76 | "output_type": "display_data"
77 | }
78 | ],
79 | "source": [
80 | "%%javascript\n",
81 | "(function(on) {\n",
82 | "const e=$( \"Setup failed\" );\n",
83 | "const ns=\"js_jupyter_suppress_warnings\";\n",
84 | "var cssrules=$(\"#\"+ns);\n",
85 | "if(!cssrules.length) cssrules = $(\"\").appendTo(\"head\");\n",
86 | "e.click(function() {\n",
87 | " var s='Showing'; \n",
88 | " cssrules.empty()\n",
89 | " if(on) {\n",
90 | " s='Hiding';\n",
91 | " cssrules.append(\"div.output_stderr, div[data-mime-type*='.stderr'] { display:none; }\");\n",
92 | " }\n",
93 | " e.text(s+' warnings (click to toggle)');\n",
94 | " on=!on;\n",
95 | "}).click();\n",
96 | "$(element).append(e);\n",
97 | "})(true);"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": 9,
103 | "metadata": {},
104 | "outputs": [
105 | {
106 | "data": {
107 | "application/vnd.jupyter.widget-view+json": {
108 | "model_id": "196a73bfd83c424f8cba5ebce1f70877",
109 | "version_major": 2,
110 | "version_minor": 0
111 | },
112 | "text/plain": [
113 | "Tab(children=(Output(), Output(), Output(), Output(), Output(), Output()), _titles={'0': 'Info', '1': 'Data', …"
114 | ]
115 | },
116 | "metadata": {},
117 | "output_type": "display_data"
118 | }
119 | ],
120 | "source": [
121 | "from vision_ui2 import *\n",
122 | "display_ui()"
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": null,
128 | "metadata": {},
129 | "outputs": [],
130 | "source": []
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": null,
135 | "metadata": {},
136 | "outputs": [],
137 | "source": []
138 | }
139 | ],
140 | "metadata": {
141 | "kernelspec": {
142 | "display_name": "Python 3",
143 | "language": "python",
144 | "name": "python3"
145 | },
146 | "language_info": {
147 | "codemirror_mode": {
148 | "name": "ipython",
149 | "version": 3
150 | },
151 | "file_extension": ".py",
152 | "mimetype": "text/x-python",
153 | "name": "python",
154 | "nbconvert_exporter": "python",
155 | "pygments_lexer": "ipython3",
156 | "version": "3.7.6"
157 | }
158 | },
159 | "nbformat": 4,
160 | "nbformat_minor": 2
161 | }
162 |
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-slate
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: fastai2
2 | channels:
3 | - fastai
4 | - pytorch
5 | - defaults
6 | dependencies:
7 | - jupyter
8 | - pytorch>=1.3.0
9 | - torchvision>=0.5
10 | - matplotlib
11 | - pandas
12 | - requests
13 | - pyyaml
14 | - fastprogress>=0.1.22
15 | - pillow
16 | - python>=3.6
17 | - pip
18 | - scikit-learn
19 | - scipy
20 | - spacy
21 | - voila
22 |
--------------------------------------------------------------------------------
/paperspace_ui.py:
--------------------------------------------------------------------------------
1 | """
2 | Paperspace_UI based on Vision_UI
3 | Visual graphical interface for Fastai
4 |
5 | Last Update: 10/12/2019
6 | https://github.com/asvcode/Vision_UI
7 | """
8 |
9 | from ipywidgets import interact, interactive, fixed, interact_manual
10 | import ipywidgets
11 | import ipywidgets as widgets
12 | import IPython
13 | from IPython.display import display,clear_output
14 |
15 | import webbrowser
16 | from IPython.display import YouTubeVideo
17 |
18 | from fastai.vision import *
19 | from fastai.widgets import *
20 | from fastai.callbacks import *
21 |
22 | def version():
23 | import fastai
24 | import psutil
25 |
26 | print ('>> Paperspace version')
27 |
28 | button = widgets.Button(description='System')
29 | but = widgets.HBox([button])
30 | display(but)
31 |
32 | out = widgets.Output()
33 | display(out)
34 |
35 | def on_button_clicked_info(b):
36 | with out:
37 | clear_output()
38 | print(f'Fastai Version: {fastai.__version__}')
39 | print(f'Cuda: {torch.cuda.is_available()}')
40 | print(f'GPU: {torch.cuda.get_device_name(0)}')
41 | print(f'Python version: {sys.version}')
42 | print(psutil.cpu_percent())
43 | print(psutil.virtual_memory()) # physical memory usage
44 | print('memory % used:', psutil.virtual_memory()[2])
45 |
46 | button.on_click(on_button_clicked_info)
47 |
48 | def dashboard_one():
49 | style = {'description_width': 'initial'}
50 |
51 | print('>> Currently only works with files FROM_FOLDERS' '\n')
52 | dashboard_one.datain = widgets.ToggleButtons(
53 | options=['from_folder'],
54 | description='Data In:',
55 | disabled=True,
56 | button_style='success', # 'success', 'info', 'warning', 'danger' or ''
57 | tooltips=['Data in folder', 'Data in csv format - NOT ACTIVE', 'Data in dataframe - NOT ACTIVE'],
58 | )
59 | dashboard_one.norma = widgets.ToggleButtons(
60 | options=['Imagenet', 'Custom', 'Cifar', 'Mnist'],
61 | description='Normalization:',
62 | disabled=False,
63 | button_style='info', # 'success', 'info', 'warning', 'danger' or ''
64 | tooltips=['Imagenet stats', 'Create your own', 'Cifar stats', 'Mnist stats'],
65 | style=style
66 | )
67 | dashboard_one.archi = widgets.ToggleButtons(
68 | options=['alexnet', 'BasicBlock', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'resnet18',
69 | 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'squeezenet1_0', 'squeezenet1_1', 'vgg16_bn',
70 | 'vgg19_bn', 'xresnet18', 'xresnet34', 'xresnet50', 'xresnet101', 'xresnet152'],
71 | description='Architecture:',
72 | disabled=False,
73 | button_style='', # 'success', 'info', 'warning', 'danger' or ''
74 | tooltips=[],
75 | )
76 | layout = widgets.Layout(width='auto', height='40px') #set width and height
77 |
78 | xres_text = widgets.Button(
79 | description='FOR Xresnet models: Are not pretrained so have to UNCHECK Pretrain box to avoid errors.',
80 | disabled=True,
81 | display='flex',
82 | flex_flow='column',
83 | align_items='stretch',
84 | layout = layout
85 | )
86 | dashboard_one.pretrain_check = widgets.Checkbox(
87 | options=['Yes', "No"],
88 | description='Pretrained:',
89 | disabled=False,
90 | value=True,
91 | box_style='success',
92 | button_style='lightgreen', # 'success', 'info', 'warning', 'danger' or ''
93 | tooltips=['Default: Checked = use pretrained weights, Unchecked = No pretrained weights'],
94 | )
95 |
96 | layout = {'width':'90%', 'height': '50px', 'border': 'solid', 'fontcolor':'lightgreen'}
97 | layout_two = {'width':'100%', 'height': '200px', 'border': 'solid', 'fontcolor':'lightgreen'}
98 | style_green = {'handle_color': 'green', 'readout_color': 'red', 'slider_color': 'blue'}
99 | style_blue = {'handle_color': 'blue', 'readout_color': 'red', 'slider_color': 'blue'}
100 | dashboard_one.f=widgets.FloatSlider(min=8,max=64,step=8,value=32, continuous_update=False, layout=layout, style=style_green, description="Batch size")
101 | dashboard_one.m=widgets.FloatSlider(min=0, max=360, step=16, value=128, continuous_update=False, layout=layout, style=style_green, description='Image size')
102 |
103 |
104 | display(dashboard_one.datain, dashboard_one.norma, dashboard_one.archi, xres_text, dashboard_one.pretrain_check, dashboard_one.f, dashboard_one.m)
105 |
106 | def dashboard_two():
107 | button = widgets.Button(description="View")
108 | print ('>> Choose image to view augmentations:')
109 |
110 | image_choice()
111 | print('Augmentations')
112 |
113 | layout = {'width':'90%', 'height': '50px', 'border': 'solid', 'fontcolor':'lightgreen'}
114 | layout_two = {'width':'100%', 'height': '200px', 'border': 'solid', 'fontcolor':'lightgreen'}
115 | style_green = {'handle_color': 'green', 'readout_color': 'red', 'slider_color': 'blue'}
116 | style_blue = {'handle_color': 'blue', 'readout_color': 'red', 'slider_color': 'blue'}
117 |
118 | dashboard_two.doflip = widgets.ToggleButtons(
119 | options=['Yes', "No"],
120 | description='Do Flip:',
121 | disabled=False,
122 | button_style='success', # 'success', 'info', 'warning', 'danger' or ''
123 | tooltips=['Apply horizontal flips', 'No horizontal flips'],
124 | )
125 | dashboard_two.dovert = widgets.ToggleButtons(
126 | options=['Yes', "No"],
127 | description='Do Vert:',
128 | disabled=False,
129 | button_style='info', # 'success', 'info', 'warning', 'danger' or ''
130 | tooltips=['Apply vertical flips', 'No vertical flips'],
131 | )
132 | dashboard_two.two = widgets.FloatSlider(min=0,max=20,step=1,value=10, description='Max Rotate', orientation='vertical', style=style_green, layout=layout_two)
133 | dashboard_two.three = widgets.FloatSlider(min=1.1,max=4,step=1,value=1.1, description='Max Zoom', orientation='vertical', style=style_green, layout=layout_two)
134 | dashboard_two.four = widgets.FloatSlider(min=0.25, max=1.0, step=0.1, value=0.75, description='p_affine', orientation='vertical', style=style_green, layout=layout_two)
135 | dashboard_two.five = widgets.FloatSlider(min=0.2,max=0.99, step=0.1,value=0.2, description='Max Lighting', orientation='vertical', style=style_blue, layout=layout_two)
136 | dashboard_two.six = widgets.FloatSlider(min=0.25, max=1.1, step=0.1, value=0.75, description='p_lighting', orientation='vertical', style=style_blue, layout=layout_two)
137 | dashboard_two.seven = widgets.FloatSlider(min=0.1, max=0.9, step=0.1, value=0.2, description='Max warp', orientation='vertical', style=style_green, layout=layout_two)
138 |
139 | ui2 = widgets.VBox([dashboard_two.doflip, dashboard_two.dovert])
140 | ui = widgets.HBox([dashboard_two.two,dashboard_two.three, dashboard_two.seven, dashboard_two.four,dashboard_two.five, dashboard_two.six])
141 | ui3 = widgets.HBox([ui2, ui])
142 |
143 | display (ui3)
144 |
145 | print ('>> Press button to view augmentations. Pressing the button again will let you view additional augmentations below')
146 | display(button)
147 |
148 | def on_button_clicked(b):
149 | image_path = str(image_choice.output_variable.value)
150 | print('>> Displaying augmentations')
151 | display_augs(image_path)
152 |
153 | button.on_click(on_button_clicked)
154 |
155 | def get_image(image_path):
156 | print(image_path)
157 |
158 | def display_augs(image_path):
159 | get_image(image_path)
160 | image_d = open_image(image_path)
161 | print(image_d)
162 | def get_ex(): return open_image(image_path)
163 |
164 | out_flip = dashboard_two.doflip.value #do flip
165 | out_vert = dashboard_two.dovert.value # do vert
166 | out_rotate = dashboard_two.two.value #max rotate
167 | out_zoom = dashboard_two.three.value #max_zoom
168 | out_affine = dashboard_two.four.value #p_affine
169 | out_lighting = dashboard_two.five.value #Max_lighting
170 | out_plight = dashboard_two.six.value #p_lighting
171 | out_warp = dashboard_two.seven.value #Max_warp
172 |
173 | tfms = get_transforms(do_flip=out_flip, flip_vert=out_vert, max_zoom=out_zoom,
174 | p_affine=out_affine, max_lighting=out_lighting, p_lighting=out_plight, max_warp=out_warp,
175 | max_rotate=out_rotate)
176 |
177 | _, axs = plt.subplots(2,4,figsize=(12,6))
178 | for ax in axs.flatten():
179 | img = get_ex().apply_tfms(tfms[0], size=224)
180 | img.show(ax=ax)
181 |
182 | def view_batch_folder():
183 |
184 | print('>> IMPORTANT: Select data folder under INFO tab prior to clicking on batch button to avoid errors')
185 | button_g = widgets.Button(description="View Batch")
186 | display(button_g)
187 |
188 | batch_val = int(dashboard_one.f.value) # batch size
189 | image_val = int(dashboard_one.m.value) # image size
190 |
191 | out = widgets.Output()
192 | display(out)
193 |
194 | def on_button_click(b):
195 | with out:
196 | clear_output()
197 | print('\n''Augmentations''\n''Do Flip:', dashboard_two.doflip.value,'|''Do Vert:', dashboard_two.dovert.value, '\n'
198 | '\n''Max Rotate: ', dashboard_two.two.value,'|''Max Zoom: ', dashboard_two.three.value,'|''Max Warp: ',
199 | dashboard_two.seven.value,'|''p affine: ', dashboard_two.four.value, '\n''Max Lighting: ', dashboard_two.five.value,
200 | 'p lighting: ', dashboard_two.six.value, '\n'
201 | '\n''Normalization Value:', dashboard_one.norma.value, '\n''\n''working....')
202 |
203 | tfms = get_transforms(do_flip=dashboard_two.doflip.value, flip_vert=dashboard_two.dovert.value, max_zoom=dashboard_two.three.value,
204 | p_affine=dashboard_two.four.value, max_lighting=dashboard_two.five.value, p_lighting=dashboard_two.six.value,
205 | max_warp=dashboard_two.seven.value, max_rotate=dashboard_two.two.value, xtra_tfms=None)
206 |
207 | path = path_load.path_choice
208 | data = ImageDataBunch.from_folder(path, ds_tfms=tfms, bs=batch_val, size=image_val, test='test')
209 | data.normalize(stats_info())
210 | data.show_batch(rows=5, figsize=(10,10))
211 |
212 | button_g.on_click(on_button_click)
213 |
214 | def stats_info():
215 |
216 | if dashboard_one.norma.value == 'Imagenet':
217 | stats_info.stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
218 | elif dashboard_one.norma.value == 'Cifar':
219 | stats_info.stats = ([0.491, 0.482, 0.447], [0.247, 0.243, 0.261])
220 | elif dashboard_one.norma.value == 'Mnist':
221 | stats_info.stats = ([0.15, 0.15, 0.15], [0.15, 0.15, 0.15])
222 | else: # dashboard_one.norma.value == 'Custom':
223 | stats_info.stats = None
224 |
225 | stats = stats_info.stats
226 |
227 | mets_list = []
228 |
229 | precision = Precision()
230 | recall = Recall()
231 |
232 | def metrics_list(mets_list):
233 | mets_error = metrics_dashboard.error_choice.value
234 | mets_accuracy= metrics_dashboard.accuracy.value
235 | mets_accuracy_thr = metrics_dashboard.topk.value
236 | mets_precision = metrics_dashboard.precision.value
237 | mets_recall = metrics_dashboard.recall.value
238 | mets_dice = metrics_dashboard.dice.value
239 |
240 | mets_list=[]
241 | output_acc = accuracy
242 | output_thresh = top_k_accuracy
243 | output = error_rate
244 |
245 | if mets_error == 'Yes':
246 | mets_list.append(error_rate)
247 | else:
248 | None
249 | if mets_accuracy == 'Yes':
250 | mets_list.append(accuracy)
251 | else:
252 | None
253 | if mets_accuracy_thr == 'Yes':
254 | mets_list.append(top_k_accuracy)
255 | else:
256 | None
257 | if mets_precision == 'Yes':
258 | mets_list.append(precision)
259 | else:
260 | None
261 | if mets_recall == 'Yes':
262 | mets_list.append(recall)
263 | else:
264 | None
265 | if mets_dice == 'Yes':
266 | mets_list.append(dice)
267 | else:
268 | None
269 |
270 | metrics_info = mets_list
271 |
272 | return mets_list
273 |
274 | def model_summary():
275 |
276 | print('>> Review Model information: ', dashboard_one.archi.value)
277 |
278 | batch_val = int(dashboard_one.f.value) # batch size
279 | image_val = int(dashboard_one.m.value) # image size
280 |
281 | button_summary = widgets.Button(description="Model Summary")
282 | button_model_0 = widgets.Button(description='Model[0]')
283 | button_model_1 = widgets.Button(description='Model[1]')
284 |
285 | tfms = get_transforms(do_flip=dashboard_two.doflip.value, flip_vert=dashboard_two.dovert.value, max_zoom=dashboard_two.three.value,
286 | p_affine=dashboard_two.four.value, max_lighting=dashboard_two.five.value, p_lighting=dashboard_two.six.value,
287 | max_warp=dashboard_two.seven.value, max_rotate=dashboard_two.two.value, xtra_tfms=None)
288 |
289 | path = path_load.path_choice
290 | data = ImageDataBunch.from_folder(path, ds_tfms=tfms, bs=batch_val, size=image_val, test='test')
291 |
292 | r = dashboard_one.pretrain_check.value
293 |
294 | ui_out = widgets.HBox([button_summary, button_model_0, button_model_1])
295 |
296 | arch_work()
297 |
298 | display(ui_out)
299 | out = widgets.Output()
300 | display(out)
301 |
302 | def on_button_clicked_summary(b):
303 | with out:
304 | clear_output()
305 | print('working''\n')
306 | learn = cnn_learner(data, base_arch=arch_work.info, pretrained=r, custom_head=None)
307 | print('Model Summary')
308 | info = learn.summary()
309 | print(info)
310 |
311 | button_summary.on_click(on_button_clicked_summary)
312 |
313 | def on_button_clicked_model_0(b):
314 | with out:
315 | clear_output()
316 | print('working''\n')
317 | learn = cnn_learner(data, base_arch=arch_work.info, pretrained=r, custom_head=None)
318 | print('Model[0]')
319 | info_s = learn.model[0]
320 | print(info_s)
321 |
322 | button_model_0.on_click(on_button_clicked_model_0)
323 |
324 | def on_button_clicked_model_1(b):
325 | with out:
326 | clear_output()
327 | print('working''\n')
328 | learn = cnn_learner(data, base_arch=arch_work.info, pretrained=r, custom_head=None)
329 | print('Model[1]')
330 | info_sm = learn.model[1]
331 | print(info_sm)
332 |
333 | button_model_1.on_click(on_button_clicked_model_1)
334 |
335 | def arch_work():
336 | if dashboard_one.archi.value == 'alexnet':
337 | arch_work.info = models.alexnet
338 | elif dashboard_one.archi.value == 'BasicBlock':
339 | arch_work.info = models.BasicBlock
340 | elif dashboard_one.archi.value == 'densenet121':
341 | arch_work.info = models.densenet121
342 | elif dashboard_one.archi.value == 'densenet161':
343 | arch_work.info = models.densenet161
344 | elif dashboard_one.archi.value == 'densenet169':
345 | arch_work.info = models.densenet169
346 | elif dashboard_one.archi.value == 'densenet201':
347 | arch_work.info = models.densenet201
348 | if dashboard_one.archi.value == 'resnet18':
349 | arch_work.info = models.resnet18
350 | elif dashboard_one.archi.value == 'resnet34':
351 | arch_work.info = models.resnet34
352 | elif dashboard_one.archi.value == 'resnet50':
353 | arch_work.info = models.resnet50
354 | elif dashboard_one.archi.value == 'resnet101':
355 | arch_work.info = models.resnet101
356 | elif dashboard_one.archi.value == 'resnet152':
357 | arch_work.info = models.resnet152
358 | elif dashboard_one.archi.value == 'squeezenet1_0':
359 | arch_work.info = models.squeezenet1_0
360 | elif dashboard_one.archi.value == 'squeezenet1_1':
361 | arch_work.info = models.squeezenet1_1
362 | elif dashboard_one.archi.value == 'vgg16_bn':
363 | arch_work.info = models.vgg16_bn
364 | elif dashboard_one.archi.value == 'vgg19_bn':
365 | arch_work.info = models.vgg19_bn
366 | #elif dashboard_one.archi.value == 'wrn_22':
367 | # arch_work.info = models.wrn_22
368 | elif dashboard_one.archi.value == 'xresnet18':
369 | arch_work.info = xresnet2.xresnet18
370 | elif dashboard_one.archi.value == 'xresnet34':
371 | arch_work.info = xresnet2.xresnet34
372 | elif dashboard_one.archi.value == 'xresnet50':
373 | arch_work.info = xresnet2.xresnet50
374 | elif dashboard_one.archi.value == 'xresnet101':
375 | arch_work.info = xresnet2.xresnet101
376 | elif dashboard_one.archi.value == 'xresnet152':
377 | arch_work.info = xresnet2.xresnet152
378 |
379 | output = arch_work.info
380 | output
381 | print(output)
382 |
383 | def metrics_dashboard():
384 | button = widgets.Button(description="Metrics")
385 |
386 | batch_val = int(dashboard_one.f.value) # batch size
387 | image_val = int(dashboard_one.m.value) # image size
388 |
389 | tfms = get_transforms(do_flip=dashboard_two.doflip.value, flip_vert=dashboard_two.dovert.value, max_zoom=dashboard_two.three.value,
390 | p_affine=dashboard_two.four.value, max_lighting=dashboard_two.five.value, p_lighting=dashboard_two.six.value,
391 | max_warp=dashboard_two.seven.value, max_rotate=dashboard_two.two.value, xtra_tfms=None)
392 |
393 | path = path_load.path_choice
394 | data = ImageDataBunch.from_folder(path, ds_tfms=tfms, bs=batch_val, size=image_val, test='test')
395 |
396 | layout = {'width':'90%', 'height': '50px', 'border': 'solid', 'fontcolor':'lightgreen'}
397 | style_green = {'button_color': 'green','handle_color': 'green', 'readout_color': 'red', 'slider_color': 'blue'}
398 |
399 | metrics_dashboard.error_choice = widgets.ToggleButtons(
400 | options=['Yes', 'No'],
401 | description='Error Choice:',
402 | value='No',
403 | disabled=False,
404 | button_style='success', # 'success', 'info', 'warning', 'danger' or ''
405 | tooltips=[''],
406 | )
407 | metrics_dashboard.accuracy = widgets.ToggleButtons(
408 | options=['Yes', 'No'],
409 | description='Accuracy:',
410 | value='No',
411 | disabled=False,
412 | button_style='info', # 'success', 'info', 'warning', 'danger' or ''
413 | tooltips=[''],
414 | )
415 | metrics_dashboard.topk = widgets.ToggleButtons(
416 | options=['Yes', 'No'],
417 | description='Top K:',
418 | value='No',
419 | disabled=False,
420 | button_style='warning', # 'success', 'info', 'warning', 'danger' or ''
421 | tooltips=[''],
422 | )
423 | metrics_dashboard.recall = widgets.ToggleButtons(
424 | options=['Yes', 'No'],
425 | description='Recall:',
426 | value='No',
427 | disabled=False,
428 | button_style='success', # 'success', 'info', 'warning', 'danger' or ''
429 | tooltips=[''],
430 | )
431 | metrics_dashboard.precision = widgets.ToggleButtons(
432 | options=['Yes', 'No'],
433 | description='Precision:',
434 | value='No',
435 | disabled=False,
436 | button_style='info', # 'success', 'info', 'warning', 'danger' or ''
437 | tooltips=[''],
438 | )
439 | metrics_dashboard.dice = widgets.ToggleButtons(
440 | options=['Yes', 'No'],
441 | description='Dice:',
442 | value='No',
443 | disabled=False,
444 | button_style='warning', # 'success', 'info', 'warning', 'danger' or ''
445 | tooltips=[''],
446 | )
447 | layout = widgets.Layout(width='auto', height='40px') #set width and height
448 |
449 | centre_t = widgets.Button(
450 | description='',
451 | disabled=True,
452 | display='flex',
453 | flex_flow='column',
454 | align_items='stretch',
455 | layout = layout
456 | )
457 | ui = widgets.HBox([metrics_dashboard.error_choice, metrics_dashboard.accuracy, metrics_dashboard.topk])
458 | ui2 = widgets.HBox([metrics_dashboard.recall, metrics_dashboard.precision, metrics_dashboard.dice])
459 | ui3 = widgets.VBox([ui, centre_t, ui2])
460 |
461 | r = dashboard_one.pretrain_check.value
462 |
463 | display(ui3)
464 |
465 |     print('>> Click to view chosen metrics')
466 | display(button)
467 |
468 | out = widgets.Output()
469 | display(out)
470 |
471 | def on_button_clicked(b):
472 | with out:
473 | clear_output()
474 |             print('Training Metrics\n')
475 |             print('arch:', dashboard_one.archi.value, '\n''pretrain: ', dashboard_one.pretrain_check.value, '\n', 'Chosen metrics: ', metrics_list(mets_list))
476 |
477 | button.on_click(on_button_clicked)
478 |
479 | def info_lr():
480 | button = widgets.Button(description='Review Parameters')
481 | button_two = widgets.Button(description='LR')
482 | button_three = widgets.Button(description='Train')
483 |
484 | butlr = widgets.HBox([button, button_two, button_three])
485 | display(butlr)
486 |
487 | out = widgets.Output()
488 | display(out)
489 |
490 | def on_button_clicked_info(b):
491 | with out:
492 | clear_output()
493 |                 print('Data in:', dashboard_one.datain.value, '| Normalization:', dashboard_one.norma.value, '| Architecture:', dashboard_one.archi.value,
494 |                       'Pretrain:', dashboard_one.pretrain_check.value, '\n''Batch Size:', dashboard_one.f.value, '| Image Size:', dashboard_one.m.value, '\n'
495 |                       '\n''Augmentations''\n''Do Flip:', dashboard_two.doflip.value, '| Do Vert:', dashboard_two.dovert.value, '\n'
496 |                       '\n''Max Rotate: ', dashboard_two.two.value, '| Max Zoom: ', dashboard_two.three.value, '| Max Warp: ',
497 |                       dashboard_two.seven.value, '| p affine: ', dashboard_two.four.value, '\n''Max Lighting: ', dashboard_two.five.value,
498 |                       '| p lighting: ', dashboard_two.six.value, '\n'
499 |                       '\n''Normalization Value:', dashboard_one.norma.value, '\n''\n''Training Metrics''\n',
500 |                       metrics_list(mets_list))
501 |
502 | button.on_click(on_button_clicked_info)
503 |
504 | def on_button_clicked_info2(b):
505 | with out:
506 | clear_output()
507 | dashboard_one.datain.value, dashboard_one.norma.value, dashboard_one.archi.value, dashboard_one.pretrain_check.value,
508 | dashboard_one.f.value, dashboard_one.m.value, dashboard_two.doflip.value, dashboard_two.dovert.value,
509 | dashboard_two.two.value, dashboard_two.three.value, dashboard_two.seven.value, dashboard_two.four.value, dashboard_two.five.value,
510 | dashboard_two.six.value, dashboard_one.norma.value,metrics_list(mets_list)
511 |
512 | learn_dash()
513 |
514 | button_two.on_click(on_button_clicked_info2)
515 |
516 | def on_button_clicked_info3(b):
517 | with out:
518 | clear_output()
519 | print('Train')
520 | training()
521 |
522 | button_three.on_click(on_button_clicked_info3)
523 |
524 | def learn_dash():
525 | button = widgets.Button(description="Learn")
526 |     print('Chosen metrics: ', metrics_list(mets_list))
527 | metrics_list(mets_list)
528 |
529 | batch_val = int(dashboard_one.f.value) # batch size
530 | image_val = int(dashboard_one.m.value) # image size
531 |
532 | r = dashboard_one.pretrain_check.value
533 | t = metrics_list(mets_list)
534 |
535 | tfms = get_transforms(do_flip=dashboard_two.doflip.value, flip_vert=dashboard_two.dovert.value, max_zoom=dashboard_two.three.value,
536 | p_affine=dashboard_two.four.value, max_lighting=dashboard_two.five.value, p_lighting=dashboard_two.six.value,
537 | max_warp=dashboard_two.seven.value, max_rotate=dashboard_two.two.value, xtra_tfms=None)
538 |
539 | path = path_load.path_choice
540 | data = ImageDataBunch.from_folder(path, ds_tfms=tfms, bs=batch_val, size=image_val, test='test')
541 |
542 | learn = cnn_learner(data, base_arch=arch_work.info, pretrained=r, metrics=metrics_list(mets_list), custom_head=None)
543 |
544 | learn.lr_find()
545 | learn.recorder.plot()
546 |
547 |
548 | def model_button():
549 | button_m = widgets.Button(description='Model')
550 |
551 | print('>> View Model information (model_summary, model[0], model[1])''\n\n''>> For xresnet: Pretrained needs to be set to FALSE')
552 | display(button_m)
553 |
554 | out_two = widgets.Output()
555 | display(out_two)
556 |
557 | def on_button_clicked_train(b):
558 | with out_two:
559 | clear_output()
560 | print('Your pretrained setting: ', dashboard_one.pretrain_check.value)
561 | model_summary()
562 |
563 | button_m.on_click(on_button_clicked_train)
564 |
565 | #def drive_upload():
566 | # from google.colab import drive
567 | #print('mounting drive')
568 | #drive.mount('/content/gdrive', force_remount=True)
569 | #drive_upload.root_dir = "/content/gdrive/My Drive/"
570 | #print('drive mounted')
571 |
572 | def path_load():
573 |
574 | #path = Path(get_path.output_variable)
575 | file_location = str(get_path.output_variable.value)
576 | #path_load.path_choice = path/file_location
577 | path_load.path_choice = file_location
578 |
579 | il = ImageList.from_folder(path_load.path_choice)
580 | print(len(il.items))
581 | print(path_load.path_choice)
582 |
583 | def image_choice():
584 |
585 | from ipywidgets import widgets
586 | button_choice = widgets.Button(description="Image Path")
587 |
588 | # Create text widget for output
589 | image_choice.output_variable = widgets.Text()
590 | display(image_choice.output_variable)
591 |
592 | display(button_choice)
593 |
594 | def get_path():
595 |
596 | from ipywidgets import widgets
597 | button_choice = widgets.Button(description="Load Path")
598 |
599 | # Create text widget for output
600 | get_path.output_variable = widgets.Text()
601 | display(get_path.output_variable)
602 |
603 | display(button_choice)
604 |
605 | def on_button_clicked_summary(b):
606 | path_load()
607 |
608 | button_choice.on_click(on_button_clicked_summary)
609 |
610 | def metric_button():
611 | button_b = widgets.Button(description="Metrics")
612 | print ('>> Click button to choose appropriate metrics')
613 | display(button_b)
614 |
615 | out = widgets.Output()
616 | display(out)
617 |
618 | def on_button_clicked_learn(b):
619 | with out:
620 | clear_output()
621 | arch_work()
622 | metrics_dashboard()
623 |
624 | button_b.on_click(on_button_clicked_learn)
625 |
626 | def training():
627 | print('>> Using fit_one_cycle')
628 | button = widgets.Button(description='Train')
629 |
630 | style = {'description_width': 'initial'}
631 |
632 | layout = {'width':'90%', 'height': '50px', 'border': 'solid', 'fontcolor':'lightgreen'}
633 | layout_two = {'width':'100%', 'height': '200px', 'border': 'solid', 'fontcolor':'lightgreen'}
634 | style_green = {'handle_color': 'green', 'readout_color': 'red', 'slider_color': 'blue'}
635 | style_blue = {'handle_color': 'blue', 'readout_color': 'red', 'slider_color': 'blue'}
636 |
637 | training.cl=widgets.FloatSlider(min=1,max=64,step=1,value=1, continuous_update=False, layout=layout, style=style_green, description="Cycle Length")
638 | training.lr = widgets.ToggleButtons(
639 | options=['1e-6', '1e-5', '1e-4', '1e-3', '1e-2', '1e-1'],
640 | description='Learning Rate:',
641 | disabled=False,
642 | button_style='info', # 'success', 'info', 'warning', 'danger' or ''
643 | style=style,
644 | value='1e-2',
645 | tooltips=['Choose a suitable learning rate'],
646 | )
647 |
648 | display(training.cl, training.lr)
649 |
650 | display(button)
651 |
652 | out = widgets.Output()
653 | display(out)
654 |
655 | def on_button_clicked(b):
656 | with out:
657 | clear_output()
658 | lr_work()
659 | print('>> Training....''\n''Learning Rate: ', lr_work.info)
660 | dashboard_one.datain.value, dashboard_one.norma.value, dashboard_one.archi.value, dashboard_one.pretrain_check.value,
661 | dashboard_one.f.value, dashboard_one.m.value, dashboard_two.doflip.value, dashboard_two.dovert.value,
662 | dashboard_two.two.value, dashboard_two.three.value, dashboard_two.seven.value, dashboard_two.four.value, dashboard_two.five.value,
663 | dashboard_two.six.value, dashboard_one.norma.value,metrics_list(mets_list)
664 |
665 | metrics_list(mets_list)
666 |
667 | batch_val = int(dashboard_one.f.value) # batch size
668 | image_val = int(dashboard_one.m.value) # image size
669 |
670 | #values for saving model
671 | value_mone = str(dashboard_one.archi.value)
672 | value_mtwo = str(dashboard_one.pretrain_check.value)
673 | value_mthree = str(round(dashboard_one.f.value))
674 | value_mfour = str(round(dashboard_one.m.value))
675 |
676 | r = dashboard_one.pretrain_check.value
677 |
678 | tfms = get_transforms(do_flip=dashboard_two.doflip.value, flip_vert=dashboard_two.dovert.value, max_zoom=dashboard_two.three.value,
679 | p_affine=dashboard_two.four.value, max_lighting=dashboard_two.five.value, p_lighting=dashboard_two.six.value,
680 | max_warp=dashboard_two.seven.value, max_rotate=dashboard_two.two.value, xtra_tfms=None)
681 |
682 | path = path_load.path_choice
683 |
684 | data = (ImageList.from_folder(path)
685 | .split_by_folder()
686 | .label_from_folder()
687 | .transform(tfms, size=image_val)
688 | .add_test_folder('test')
689 | .databunch(path=path))
690 |
691 | learn = cnn_learner(data, base_arch=arch_work.info, pretrained=r, metrics=metrics_list(mets_list), custom_head=None, callback_fns=ShowGraph)
692 |
693 | cycle_l = int(training.cl.value)
694 |
695 | learn.fit_one_cycle(cycle_l, slice(lr_work.info))
696 |
697 | #save model
698 | file_model_name = value_mone + '_pretrained_' + value_mtwo + '_batch_' + value_mthree + '_image_' + value_mfour
699 |
700 | learn.save(file_model_name)
701 |
702 | button.on_click(on_button_clicked)
703 |
704 | def lr_work():
705 | if training.lr.value == '1e-6':
706 | lr_work.info = float(0.000001)
707 | elif training.lr.value == '1e-5':
708 | lr_work.info = float(0.00001)
709 | elif training.lr.value == '1e-4':
710 | lr_work.info = float(0.0001)
711 | elif training.lr.value == '1e-3':
712 | lr_work.info = float(0.001)
713 | elif training.lr.value == '1e-2':
714 | lr_work.info = float(0.01)
715 | elif training.lr.value == '1e-1':
716 | lr_work.info = float(0.1)
717 |
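The learning-rate toggle options are valid Python float literals, so the if/elif ladder in `lr_work()` could be collapsed into a single `float()` conversion; a sketch, not wired into the UI:

```python
# The options in training.lr are strings like '1e-3', which float()
# parses directly, so no per-value branch is needed.
options = ['1e-6', '1e-5', '1e-4', '1e-3', '1e-2', '1e-1']
values = [float(o) for o in options]
print(values)  # [1e-06, 1e-05, 0.0001, 0.001, 0.01, 0.1]
```

With this, `lr_work.info = float(training.lr.value)` would replace the whole ladder.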
718 | def display_ui():
719 | button = widgets.Button(description="Train")
720 | button_b = widgets.Button(description="Metrics")
721 | button_m = widgets.Button(description='Model')
722 | button_l = widgets.Button(description='LR')
723 |
724 | out1aa = widgets.Output()
725 | out1a = widgets.Output()
726 | out1 = widgets.Output()
727 | out2 = widgets.Output()
728 | out3 = widgets.Output()
729 | out4 = widgets.Output()
730 | out5 = widgets.Output()
731 | out6 = widgets.Output()
732 |
733 | data1aa = pd.DataFrame(np.random.normal(size = 50))
734 | data1a = pd.DataFrame(np.random.normal(size = 100))
735 | data1 = pd.DataFrame(np.random.normal(size = 150))
736 | data2 = pd.DataFrame(np.random.normal(size = 200))
737 | data3 = pd.DataFrame(np.random.normal(size = 250))
738 | data4 = pd.DataFrame(np.random.normal(size = 300))
739 | data5 = pd.DataFrame(np.random.normal(size = 350))
740 | data6 = pd.DataFrame(np.random.normal(size = 400))
741 |
742 | with out1aa: #path_choice
743 | print('path')
744 | get_path()
745 |
746 | with out1a: #info
747 | version()
748 |
749 | with out1: #data
750 | dashboard_one()
751 |
752 | with out2: #augmentation
753 | dashboard_two()
754 |
755 | with out3: #Batch
756 | print('Click to view Batch' '\n\n')
757 | view_batch_folder()
758 |
759 | with out4: #model
760 | print('>> View Model information (model_summary, model[0], model[1])''\n\n''>> For xresnet: Pretrained needs to be set to FALSE, setting to TRUE results in error: NameError: name model_urls is not defined')
761 | display(button_m)
762 |
763 | out_two = widgets.Output()
764 | display(out_two)
765 |
766 | def on_button_clicked_train(b):
767 | with out_two:
768 | clear_output()
769 | print('Your pretrained setting: ', dashboard_one.pretrain_check.value)
770 | model_summary()
771 |
772 | button_m.on_click(on_button_clicked_train)
773 |
774 | with out5: #Metrics
775 | print ('>> Click button to choose appropriate metrics')
776 | display(button_b)
777 |
778 | out = widgets.Output()
779 | display(out)
780 |
781 | def on_button_clicked_learn(b):
782 | with out:
783 | clear_output()
784 | arch_work()
785 | metrics_dashboard()
786 |
787 | button_b.on_click(on_button_clicked_learn)
788 |
789 | with out6: #train
790 | print ('>> Click to view training parameters and learning rate''\n''\n'
791 | '>> IMPORTANT: You have to go through METRICS tab prior to choosing LR')
792 | info_lr()
793 |
794 | tab = widgets.Tab(children = [out1aa, out1a, out1, out2, out3, out4, out5, out6])
795 | tab.set_title(0, 'Path')
796 | tab.set_title(1, 'Info')
797 | tab.set_title(2, 'Data')
798 | tab.set_title(3, 'Augmentation')
799 | tab.set_title(4, 'Batch')
800 | tab.set_title(5, 'Model')
801 | tab.set_title(6, 'Metrics')
802 | tab.set_title(7, 'Train')
803 | display(tab)
804 |
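The eight `tab.set_title` calls at the end of `display_ui()` could equally be driven from a single list, keeping tab order and titles in one place; a sketch with the widget construction omitted:

```python
# Titles in display order; pairing each with its index replaces the
# eight individual tab.set_title(i, name) calls in display_ui().
titles = ['Path', 'Info', 'Data', 'Augmentation', 'Batch', 'Model', 'Metrics', 'Train']
indexed = list(enumerate(titles))
# In display_ui() this would become:
#   for i, name in indexed:
#       tab.set_title(i, name)
```

This also guards against the title indices drifting out of sync with the `children` list when tabs are added or removed.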
--------------------------------------------------------------------------------
/static/CM_FN.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_FN.PNG
--------------------------------------------------------------------------------
/static/CM_FP.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_FP.PNG
--------------------------------------------------------------------------------
/static/CM_TN.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_TN.PNG
--------------------------------------------------------------------------------
/static/CM_TP.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_TP.PNG
--------------------------------------------------------------------------------
/static/CM_eight.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_eight.PNG
--------------------------------------------------------------------------------
/static/CM_five.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_five.PNG
--------------------------------------------------------------------------------
/static/CM_four.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_four.PNG
--------------------------------------------------------------------------------
/static/CM_nine.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_nine.PNG
--------------------------------------------------------------------------------
/static/CM_one.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_one.PNG
--------------------------------------------------------------------------------
/static/CM_seven.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_seven.PNG
--------------------------------------------------------------------------------
/static/CM_six.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_six.PNG
--------------------------------------------------------------------------------
/static/CM_three.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_three.PNG
--------------------------------------------------------------------------------
/static/CM_two.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/CM_two.PNG
--------------------------------------------------------------------------------
/static/LR.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/LR.PNG
--------------------------------------------------------------------------------
/static/LR_one.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/LR_one.PNG
--------------------------------------------------------------------------------
/static/LR_three.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/LR_three.PNG
--------------------------------------------------------------------------------
/static/LR_two.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/LR_two.PNG
--------------------------------------------------------------------------------
/static/Lr_four.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/Lr_four.PNG
--------------------------------------------------------------------------------
/static/aug_one.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/aug_one.PNG
--------------------------------------------------------------------------------
/static/aug_one2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/aug_one2.PNG
--------------------------------------------------------------------------------
/static/aug_three.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/aug_three.PNG
--------------------------------------------------------------------------------
/static/aug_three2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/aug_three2.PNG
--------------------------------------------------------------------------------
/static/aug_three3.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/aug_three3.PNG
--------------------------------------------------------------------------------
/static/aug_two.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/aug_two.PNG
--------------------------------------------------------------------------------
/static/aug_two2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/aug_two2.PNG
--------------------------------------------------------------------------------
/static/aug_two3.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/aug_two3.PNG
--------------------------------------------------------------------------------
/static/batch.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/batch.PNG
--------------------------------------------------------------------------------
/static/batch_three.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/batch_three.PNG
--------------------------------------------------------------------------------
/static/batch_two.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/batch_two.PNG
--------------------------------------------------------------------------------
/static/cm_class.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/cm_class.PNG
--------------------------------------------------------------------------------
/static/data.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/data.PNG
--------------------------------------------------------------------------------
/static/data2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/data2.PNG
--------------------------------------------------------------------------------
/static/heatmap3.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/heatmap3.PNG
--------------------------------------------------------------------------------
/static/info.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/info.PNG
--------------------------------------------------------------------------------
/static/info_dashboard.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/info_dashboard.PNG
--------------------------------------------------------------------------------
/static/metrics.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/metrics.PNG
--------------------------------------------------------------------------------
/static/model.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/model.PNG
--------------------------------------------------------------------------------
/static/visionUI2_part1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/visionUI2_part1.gif
--------------------------------------------------------------------------------
/static/visionUI2_part2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/visionUI2_part2.gif
--------------------------------------------------------------------------------
/static/visionUI2_part3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asvcode/Vision_UI/9a54e808a2de5553383ae167051d9a8117696ba8/static/visionUI2_part3.gif
--------------------------------------------------------------------------------
/viola_test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 3,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "#hide\n",
10 | "from fastai2.vision.all import *\n",
11 | "from fastai2.vision.widgets import *\n",
12 | "\n",
13 | "import ipywidgets as widgets\n",
14 | "from IPython.display import display,clear_output"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 13,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "RED = '\\033[31m'\n",
24 | "BLUE = '\\033[94m'\n",
25 | "GREEN = '\\033[92m'\n",
26 | "BOLD = '\\033[1m'\n",
27 | "ITALIC = '\\033[3m'\n",
28 | "RESET = '\\033[0m'\n",
29 | "\n",
30 | "style = {'description_width': 'initial'}"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": 9,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "def dashboard_one():\n",
40 | " \"\"\"GUI for first accordion window\"\"\"\n",
41 | " import psutil\n",
42 | " import torchvision\n",
43 | " try:\n",
44 | " import fastai2; fastver = fastai2.__version__\n",
45 | " except ImportError:\n",
46 | " fastver = 'fastai not found'\n",
47 | " try:\n",
48 | " import fastprogress; fastprog = fastprogress.__version__\n",
49 | " except ImportError:\n",
50 | " fastprog = 'fastprogress not found'\n",
51 | " try:\n",
52 | " import fastpages; fastp = fastpages.__version__\n",
53 | " except ImportError:\n",
54 | " fastp = 'fastpages not found'\n",
55 | " try:\n",
56 | " import nbdev; nbd = nbdev.__version__\n",
57 | " except ImportError:\n",
58 |             "            nbd = 'nbdev not found'\n",
59 | "\n",
60 | " print (BOLD + RED + '>> Vision_UI Update: 03/17/2020')\n",
61 | " style = {'description_width': 'initial'}\n",
62 | "\n",
63 | " button = widgets.Button(description='System', button_style='success')\n",
64 | " ex_button = widgets.Button(description='Explore', button_style='success')\n",
65 | " display(button)\n",
66 | "\n",
67 | " out = widgets.Output()\n",
68 | " display(out)\n",
69 | "\n",
70 | " def on_button_clicked_info(b):\n",
71 | " with out:\n",
72 | " clear_output()\n",
73 | " print(BOLD + BLUE + \"fastai2 Version: \" + RESET + ITALIC + str(fastver))\n",
74 | " print(BOLD + BLUE + \"nbdev Version: \" + RESET + ITALIC + str(nbd))\n",
75 | " print(BOLD + BLUE + \"fastprogress Version: \" + RESET + ITALIC + str(fastprog))\n",
76 | " print(BOLD + BLUE + \"fastpages Version: \" + RESET + ITALIC + str(fastp) + '\\n')\n",
77 | " print(BOLD + BLUE + \"python Version: \" + RESET + ITALIC + str(sys.version))\n",
78 | " print(BOLD + BLUE + \"torchvision: \" + RESET + ITALIC + str(torchvision.__version__))\n",
79 | " print(BOLD + BLUE + \"torch version: \" + RESET + ITALIC + str(torch.__version__))\n",
80 | " print(BOLD + BLUE + \"\\nCuda: \" + RESET + ITALIC + str(torch.cuda.is_available()))\n",
81 | " print(BOLD + BLUE + \"cuda Version: \" + RESET + ITALIC + str(torch.version.cuda))\n",
82 | " print(BOLD + BLUE + \"GPU: \" + RESET + ITALIC + str(torch.cuda.get_device_name(0)))\n",
83 | " print(BOLD + BLUE + \"\\nCPU%: \" + RESET + ITALIC + str(psutil.cpu_percent()))\n",
84 | " print(BOLD + BLUE + \"\\nmemory % used: \" + RESET + ITALIC + str(psutil.virtual_memory()[2]))\n",
85 | " button.on_click(on_button_clicked_info)\n"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": 14,
91 | "metadata": {},
92 | "outputs": [],
93 | "source": [
94 | "def dashboard_two():\n",
95 | " \"\"\"GUI for second accordion window\"\"\"\n",
96 | " dashboard_two.datas = widgets.ToggleButtons(\n",
97 | " options=['PETS', 'CIFAR', 'IMAGENETTE_160', 'IMAGEWOOF_160', 'MNIST_TINY'],\n",
98 | " description='Choose',\n",
99 | " value=None,\n",
100 | " disabled=False,\n",
101 | " button_style='info',\n",
102 | " tooltips=[''],\n",
103 | " style=style\n",
104 | " )\n",
105 | "\n",
106 | " display(dashboard_two.datas)\n",
107 | "\n",
108 | " button = widgets.Button(description='Explore', button_style='success')\n",
109 | " display(button)\n",
110 | " out = widgets.Output()\n",
111 | " display(out)\n",
112 | " def on_button_explore(b):\n",
113 | " with out:\n",
114 | " clear_output()\n",
115 | " ds_choice()\n",
116 | "\n",
117 | " button.on_click(on_button_explore)"
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": 16,
123 | "metadata": {},
124 | "outputs": [],
125 | "source": [
126 | "def ds_choice():\n",
127 | " \"\"\"Helper for dataset choices\"\"\"\n",
128 | " if dashboard_two.datas.value == 'PETS':\n",
130 |     "        ds_choice.source = untar_data(URLs.PETS)\n",
130 | " elif dashboard_two.datas.value == 'CIFAR':\n",
131 | " ds_choice.source = untar_data(URLs.CIFAR)\n",
132 | " elif dashboard_two.datas.value == 'IMAGENETTE_160':\n",
133 | " ds_choice.source = untar_data(URLs.IMAGENETTE_160)\n",
134 | " elif dashboard_two.datas.value == 'IMAGEWOOF_160':\n",
135 | " ds_choice.source = untar_data(URLs.IMAGEWOOF_160)\n",
136 | " elif dashboard_two.datas.value == 'MNIST_TINY':\n",
137 | " ds_choice.source = untar_data(URLs.MNIST_TINY)\n",
138 | "\n",
139 | " print(BOLD + BLUE + \"Dataset: \" + RESET + BOLD + RED + str(dashboard_two.datas.value))\n",
140 | " plt_classes()"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": 83,
146 | "metadata": {},
147 | "outputs": [],
148 | "source": [
149 | "def plt_classes():\n",
150 | " \"\"\"Helper for plotting classes in folder\"\"\"\n",
151 | " disp_img_but = widgets.Button(description='View Images?', button_style='success')\n",
152 | " Path.BASE_PATH = ds_choice.source\n",
153 | " train_source = (ds_choice.source/'train/').ls().items\n",
154 | " print(BOLD + BLUE + \"Folders: \" + RESET + BOLD + RED + str(train_source))\n",
155 | " print(BOLD + BLUE + \"\\n\" + \"No of classes: \" + RESET + BOLD + RED + str(len(train_source)))\n",
156 | "\n",
157 | " num_l = []\n",
158 | " class_l = []\n",
159 | " for j, name in enumerate(train_source):\n",
160 | " fol = (ds_choice.source/name).ls().sorted()\n",
161 | " names = str(name)\n",
162 | " class_split = names.split('train')\n",
163 | " class_l.append(class_split[1])\n",
164 | " num_l.append(len(fol))\n",
165 | "\n",
166 | " y_pos = np.arange(len(train_source))\n",
167 | " performance = num_l\n",
168 | " fig = plt.figure(figsize=(10,5))\n",
169 | "\n",
170 | " plt.style.use('seaborn')\n",
171 | " plt.bar(y_pos, performance, align='center', alpha=0.5, color=['black', 'red', 'green', 'blue', 'cyan'])\n",
172 | " plt.xticks(y_pos, class_l, rotation=90)\n",
173 | " plt.ylabel('Images')\n",
174 | " plt.title('Images per Class')\n",
175 | " plt.show()\n",
176 | "\n",
177 | " display(disp_img_but)\n",
178 | " img_out = widgets.Output()\n",
179 | " display(img_out)\n",
180 | " def on_disp_button(b):\n",
181 | " with img_out:\n",
182 | " clear_output()\n",
183 | " display_images()\n",
184 | " #display_i()\n",
185 | " disp_img_but.on_click(on_disp_button)"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": 75,
191 | "metadata": {},
192 | "outputs": [],
193 | "source": [
194 | "def display_images():\n",
195 | " \"\"\"Helper for displaying images from folder\"\"\"\n",
196 | " train_source = (ds_choice.source/'train/').ls().items\n",
197 | " for i, name in enumerate(train_source):\n",
198 | " fol = (ds_choice.source/name).ls().sorted()\n",
199 | " fol_disp = fol[0:5]\n",
200 | " filename = fol_disp.items\n",
201 | " fol_tensor = [tensor(Image.open(o)) for o in fol_disp]\n",
202 | " print(BOLD + BLUE + \"Loc: \" + RESET + str(name) + \" \" + BOLD + BLUE + \"Number of Images: \" + RESET +\n",
203 | " BOLD + RED + str(len(fol)))\n",
204 | "\n",
205 | " fig = plt.figure(figsize=(15,15))\n",
206 | " columns = 5\n",
207 | " rows = 1\n",
208 | " ax = []\n",
209 | "\n",
210 | "        # one pass over the sample images; the nested loops previously\n",
211 | "        # redrew the same five subplots several times over\n",
212 | "        for i, img in enumerate(fol_tensor):\n",
213 | "            ax.append(fig.add_subplot(rows, columns, i+1))  # create subplot and append to ax\n",
214 | "            ax[-1].set_title(\"ax:\"+str(filename[i]))  # set title\n",
215 | "            plt.tick_params(bottom=\"on\", left=\"on\")\n",
216 | "            plt.imshow(img)\n",
217 | "            plt.xticks([])\n",
218 | "        plt.show()\n"
219 | ]
220 | },
221 | {
222 | "cell_type": "code",
223 | "execution_count": 85,
224 | "metadata": {},
225 | "outputs": [],
226 | "source": [
227 | "def display_ui():\n",
228 | " \"\"\" Display tabs for visual display\"\"\"\n",
229 | " button = widgets.Button(description=\"Train\")\n",
230 | " button_b = widgets.Button(description=\"Metrics\")\n",
231 | " button_m = widgets.Button(description='Model')\n",
232 | " button_l = widgets.Button(description='LR')\n",
233 | "\n",
234 | " test_button = widgets.Button(description='Batch')\n",
235 | " test2_button = widgets.Button(description='Test2')\n",
236 | "\n",
237 | " out1a = widgets.Output()\n",
238 | " out1 = widgets.Output()\n",
239 | " out2 = widgets.Output()\n",
240 | " out3 = widgets.Output()\n",
241 | " out4 = widgets.Output()\n",
242 | " out5 = widgets.Output()\n",
243 | "\n",
244 | " data1a = pd.DataFrame(np.random.normal(size = 50))\n",
245 | " data1 = pd.DataFrame(np.random.normal(size = 100))\n",
246 | " data2 = pd.DataFrame(np.random.normal(size = 150))\n",
247 | " data3 = pd.DataFrame(np.random.normal(size = 200))\n",
248 | " data4 = pd.DataFrame(np.random.normal(size = 250))\n",
249 | " data5 = pd.DataFrame(np.random.normal(size = 300))\n",
250 | "\n",
251 | " with out1a: #info\n",
252 | " clear_output()\n",
253 | " dashboard_one()\n",
254 | "\n",
255 | " with out1: #data\n",
256 | " clear_output()\n",
257 | " dashboard_two()\n",
258 | "\n",
259 | " with out2: #augmentation\n",
260 | " clear_output()\n",
261 | " #aug_dash()\n",
262 | "\n",
263 | " with out3: #Block\n",
264 | " clear_output()\n",
265 | " #ds_3()\n",
266 | "\n",
267 | " with out4: #code\n",
268 | " clear_output()\n",
269 | " #write_code()\n",
270 | "\n",
271 | " with out5: #Imagewoof Play\n",
272 | " clear_output()\n",
273 | " #print(BOLD + BLUE + 'Work in progress.....')\n",
274 | " #play_button = widgets.Button(description='Parameters')\n",
275 | " #display(play_button)\n",
276 | " #play_out = widgets.Output()\n",
277 | " #display(play_out)\n",
278 | " #def button_play(b):\n",
279 | " # with play_out:\n",
280 | " # clear_output()\n",
281 | " # play_info()\n",
282 | " #play_button.on_click(button_play)\n",
283 | "\n",
284 | " display_ui.tab = widgets.Tab(children = [out1a, out1, out2, out3, out4, out5])\n",
285 | " display_ui.tab.set_title(0, 'Info')\n",
286 | " display_ui.tab.set_title(1, 'Data')\n",
287 | " display_ui.tab.set_title(2, 'Augmentation')\n",
288 | " display_ui.tab.set_title(3, 'DataBlock')\n",
289 | " display_ui.tab.set_title(4, 'Code')\n",
290 | " display_ui.tab.set_title(5, 'ImageWoof Play')\n",
291 | " display(display_ui.tab)\n"
292 | ]
293 | },
294 | {
295 | "cell_type": "code",
296 | "execution_count": 86,
297 | "metadata": {},
298 | "outputs": [
299 | {
300 | "data": {
301 | "application/vnd.jupyter.widget-view+json": {
302 | "model_id": "e2cb716a70dc4a2fa2dc5db83604a1e4",
303 | "version_major": 2,
304 | "version_minor": 0
305 | },
306 | "text/plain": [
307 | "Tab(children=(Output(), Output(), Output(), Output(), Output(), Output()), _titles={'0': 'Info', '1': 'Data', …"
308 | ]
309 | },
310 | "metadata": {},
311 | "output_type": "display_data"
312 | }
313 | ],
314 | "source": [
315 | "display_ui()"
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "execution_count": null,
321 | "metadata": {},
322 | "outputs": [],
323 | "source": []
324 | }
325 | ],
326 | "metadata": {
327 | "kernelspec": {
328 | "display_name": "Python 3",
329 | "language": "python",
330 | "name": "python3"
331 | },
332 | "language_info": {
333 | "codemirror_mode": {
334 | "name": "ipython",
335 | "version": 3
336 | },
337 | "file_extension": ".py",
338 | "mimetype": "text/x-python",
339 | "name": "python",
340 | "nbconvert_exporter": "python",
341 | "pygments_lexer": "ipython3",
342 | "version": "3.7.6"
343 | }
344 | },
345 | "nbformat": 4,
346 | "nbformat_minor": 4
347 | }
348 |
--------------------------------------------------------------------------------
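
Both the notebook above and `vision_ui2.py` below build their coloured console output from raw ANSI escape codes (`BOLD`, `BLUE`, `RED`, `RESET`). A minimal, stdlib-only sketch of that pattern; the `label` helper is hypothetical, added purely for illustration and is not part of the repo:

```python
# ANSI escape-code pattern used throughout vision_ui2.py for coloured prints.
# The constants mirror the ones defined at the top of the module; label() is
# an illustrative helper, not part of the repo.
RED = '\033[31m'
BLUE = '\033[94m'
BOLD = '\033[1m'
RESET = '\033[0m'

def label(name, value):
    """Return 'name: value' with the name in bold blue, then reset styling."""
    return BOLD + BLUE + name + ": " + RESET + str(value)

print(label("Dataset", "MNIST_TINY"))
```

Appending `RESET` after the label leaves the value text unstyled, which is why nearly every `print` in the module closes a coloured segment with `RESET`.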
/vision_ui2.py:
--------------------------------------------------------------------------------
1 | ###############################
2 | ##### Visual_UI version 2 #####
3 | ##### Update 04/06/2020 #####
4 | ###############################
5 | from fastai2.vision.all import *
6 | from utils import *
7 |
8 | from ipywidgets import interact, interactive, fixed, interact_manual, Box, Layout
9 | import ipywidgets
10 | import ipywidgets as widgets
11 | from ipywidgets import Layout, Button, Box, FloatText, Textarea, Dropdown, Label, IntSlider, ToggleButton, FloatSlider, VBox
12 |
13 | from IPython.display import display,clear_output, Javascript
14 | import webbrowser
15 | from IPython.display import YouTubeVideo
16 |
17 | from tkinter import Tk
18 | from tkinter import filedialog
19 | from tkinter.filedialog import askdirectory
20 |
21 | import torch
22 | import collections as co
23 |
24 | style = {'description_width': 'initial'}
25 |
26 | RED = '\033[31m'
27 | BLUE = '\033[94m'
28 | GREEN = '\033[92m'
29 | BOLD = '\033[1m'
30 | ITALIC = '\033[3m'
31 | RESET = '\033[0m'
32 |
33 | align_kw = dict(
34 | _css = (('.widget-label', 'min-width', '20ex'),),
35 | margin = '0px 0px 5px 12px')
36 |
37 | def dashboard_one():
38 | """GUI for first accordion window"""
39 | import psutil
40 | import torchvision
41 | try:
42 | import fastai2; fastver = fastai2.__version__
43 | except ImportError:
44 | fastver = 'fastai not found'
45 | try:
46 | import fastprogress; fastprog = fastprogress.__version__
47 | except ImportError:
48 | fastprog = 'fastprogress not found'
49 | try:
50 | import fastpages; fastp = fastpages.__version__
51 | except ImportError:
52 | fastp = 'fastpages not found'
53 | try:
54 | import nbdev; nbd = nbdev.__version__
55 | except ImportError:
56 | nbd = 'nbdev not found'
57 |
58 |     print(BOLD + RED + '>> Vision_UI Update: 03/17/2020')
59 | style = {'description_width': 'initial'}
60 |
61 | button = widgets.Button(description='System', button_style='success')
62 | ex_button = widgets.Button(description='Explore', button_style='success')
63 | display(button)
64 |
65 | out = widgets.Output()
66 | display(out)
67 |
68 | def on_button_clicked_info(b):
69 | with out:
70 | clear_output()
71 | print(BOLD + BLUE + "fastai2 Version: " + RESET + ITALIC + str(fastver))
72 | print(BOLD + BLUE + "nbdev Version: " + RESET + ITALIC + str(nbd))
73 | print(BOLD + BLUE + "fastprogress Version: " + RESET + ITALIC + str(fastprog))
74 | print(BOLD + BLUE + "fastpages Version: " + RESET + ITALIC + str(fastp) + '\n')
75 | print(BOLD + BLUE + "python Version: " + RESET + ITALIC + str(sys.version))
76 | print(BOLD + BLUE + "torchvision: " + RESET + ITALIC + str(torchvision.__version__))
77 | print(BOLD + BLUE + "torch version: " + RESET + ITALIC + str(torch.__version__))
78 | print(BOLD + BLUE + "\nCuda: " + RESET + ITALIC + str(torch.cuda.is_available()))
79 | print(BOLD + BLUE + "cuda Version: " + RESET + ITALIC + str(torch.version.cuda))
80 |             print(BOLD + BLUE + "GPU: " + RESET + ITALIC + (torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'))
81 | print(BOLD + BLUE + "\nCPU%: " + RESET + ITALIC + str(psutil.cpu_percent()))
82 | print(BOLD + BLUE + "\nmemory % used: " + RESET + ITALIC + str(psutil.virtual_memory()[2]))
83 | button.on_click(on_button_clicked_info)
84 |
85 |     print('>> Resources')
86 | button_two = widgets.Button(description='Fastai Docs', button_style='info')
87 | button_three = widgets.Button(description='Fastai Forums', button_style='info')
88 | button_four = widgets.Button(description='Vision_UI github', button_style='info')
89 |
90 | but_two = widgets.HBox([button_two, button_three, button_four])
91 | display(but_two)
92 |
93 | def on_doc_info(b):
94 | webbrowser.open('https://dev.fast.ai/')
95 | button_two.on_click(on_doc_info)
96 |
97 | def on_forum(b):
98 | webbrowser.open('https://forums.fast.ai/')
99 | button_three.on_click(on_forum)
100 |
101 |     def on_github(b):
102 |         webbrowser.open('https://github.com/asvcode/Vision_UI')
103 |     button_four.on_click(on_github)
104 |
105 | def dashboard_two():
106 | """GUI for second accordion window"""
107 | dashboard_two.datas = widgets.ToggleButtons(
108 | options=['PETS', 'CIFAR', 'IMAGENETTE_160', 'IMAGEWOOF_160', 'MNIST_TINY'],
109 | description='Choose',
110 | value=None,
111 | disabled=False,
112 | button_style='info',
113 | tooltips=[''],
114 | style=style
115 | )
116 |
117 | display(dashboard_two.datas)
118 |
119 | button = widgets.Button(description='Explore', button_style='success')
120 | display(button)
121 | out = widgets.Output()
122 | display(out)
123 | def on_button_explore(b):
124 | with out:
125 | clear_output()
126 | ds_choice()
127 |
128 | button.on_click(on_button_explore)
129 |
130 | #Helpers for dashboard two
131 | def ds_choice():
132 | """Helper for dataset choices"""
133 | if dashboard_two.datas.value == 'PETS':
134 |         ds_choice.source = untar_data(URLs.DOGS) # dogs-vs-cats download; unlike URLs.PETS it has the train/ folder plt_classes expects
135 | elif dashboard_two.datas.value == 'CIFAR':
136 | ds_choice.source = untar_data(URLs.CIFAR)
137 | elif dashboard_two.datas.value == 'IMAGENETTE_160':
138 | ds_choice.source = untar_data(URLs.IMAGENETTE_160)
139 | elif dashboard_two.datas.value == 'IMAGEWOOF_160':
140 | ds_choice.source = untar_data(URLs.IMAGEWOOF_160)
141 | elif dashboard_two.datas.value == 'MNIST_TINY':
142 | ds_choice.source = untar_data(URLs.MNIST_TINY)
143 |
144 | print(BOLD + BLUE + "Dataset: " + RESET + BOLD + RED + str(dashboard_two.datas.value))
145 | plt_classes()
146 |
147 | def plt_classes():
148 | """Helper for plotting classes in folder"""
149 | disp_img_but = widgets.Button(description='View Images?', button_style='success')
150 | Path.BASE_PATH = ds_choice.source
151 | train_source = (ds_choice.source/'train/').ls().items
152 | print(BOLD + BLUE + "Folders: " + RESET + BOLD + RED + str(train_source))
153 | print(BOLD + BLUE + "\n" + "No of classes: " + RESET + BOLD + RED + str(len(train_source)))
154 |
155 | num_l = []
156 | class_l = []
157 | for j, name in enumerate(train_source):
158 | fol = (ds_choice.source/name).ls().sorted()
159 | names = str(name)
160 | class_split = names.split('train')
161 | class_l.append(class_split[1])
162 | num_l.append(len(fol))
163 |
164 | y_pos = np.arange(len(train_source))
165 | performance = num_l
166 |
167 | plt.style.use('seaborn')
168 | plt.bar(y_pos, performance, align='center', alpha=0.5, color=['black', 'red', 'green', 'blue', 'cyan'])
169 | plt.xticks(y_pos, class_l, rotation=90)
170 | plt.ylabel('Images')
171 | plt.title('Images per Class')
172 | plt.show()
173 |
174 | display(disp_img_but)
175 | out_img = widgets.Output()
176 | display(out_img)
177 | def on_disp_button(b):
178 | with out_img:
179 | clear_output()
180 | display_images()
181 | disp_img_but.on_click(on_disp_button)
182 |
183 | def display_images():
184 | """Helper for displaying images from folder"""
185 | train_source = (ds_choice.source/'train/').ls().items
186 | for i, name in enumerate(train_source):
187 | fol = (ds_choice.source/name).ls().sorted()
188 | fol_disp = fol[0:5]
189 | filename = fol_disp.items
190 | fol_tensor = [tensor(Image.open(o)) for o in fol_disp]
191 | print(BOLD + BLUE + "Loc: " + RESET + str(name) + " " + BOLD + BLUE + "Number of Images: " + RESET +
192 | BOLD + RED + str(len(fol)))
193 |
194 | fig = plt.figure(figsize=(30,13))
195 | columns = 5
196 | rows = 1
197 | ax = []
198 |
199 |         # one pass over the sample images; the extra range(columns*rows)
200 |         # loop previously redrew the same five subplots five times over
201 |         for i, img in enumerate(fol_tensor):
202 |             ax.append(fig.add_subplot(rows, columns, i+1))  # create subplot and append to ax
203 |             ax[-1].set_title("ax:"+str(filename[i]))  # set title
204 |             plt.tick_params(bottom="on", left="on")
205 |             plt.imshow(img)
206 |             plt.xticks([])
207 |         plt.show()
208 |
209 | #Helpers for augmentation dashboard
210 | def aug_choice():
211 |     """Helper for whether augmentations are chosen or not"""
212 | view_button = widgets.Button(description='View')
213 | display(view_button)
214 | view_out = widgets.Output()
215 | display(view_out)
216 | def on_view_button(b):
217 | with view_out:
218 | clear_output()
219 | if aug_dash.aug.value == 'No':
220 | code_test()
221 | if aug_dash.aug.value == 'Yes':
222 | aug_paras()
223 | view_button.on_click(on_view_button)
224 |
225 | def aug_paras():
226 |     """If augmentation is chosen, show the available parameters"""
227 | print(BOLD + BLUE + "Choose Augmentation Parameters: ")
228 | button_paras = widgets.Button(description='Confirm', button_style='success')
229 |
230 | aug_paras.hh = widgets.ToggleButton(value=False, description='Erase', button_style='info',
231 | style=style)
232 | aug_paras.cc = widgets.ToggleButton(value=False, description='Contrast', button_style='info',
233 | style=style)
234 | aug_paras.dd = widgets.ToggleButton(value=False, description='Rotate', button_style='info',
235 | style=style)
236 | aug_paras.ee = widgets.ToggleButton(value=False, description='Warp', button_style='info',
237 | style=style)
238 | aug_paras.ff = widgets.ToggleButton(value=False, description='Bright', button_style='info',
239 | style=style)
240 | aug_paras.gg = widgets.ToggleButton(value=False, description='DihedralFlip', button_style='info',
241 | style=style)
242 | aug_paras.ii = widgets.ToggleButton(value=False, description='Zoom', button_style='info',
243 | style=style)
244 |
245 | qq = widgets.HBox([aug_paras.hh, aug_paras.cc, aug_paras.dd, aug_paras.ee, aug_paras.ff, aug_paras.gg, aug_paras.ii])
246 | display(qq)
247 | display(button_paras)
248 | aug_par = widgets.Output()
249 | display(aug_par)
250 | def on_button_two_click(b):
251 | with aug_par:
252 | clear_output()
253 | aug()
254 | aug_dash_choice()
255 | button_paras.on_click(on_button_two_click)
256 |
257 | def aug():
258 | """Aug choice helper"""
259 | #Erase
260 | if aug_paras.hh.value == True:
261 | aug.b_max = FloatSlider(min=0,max=50,step=1,value=0, description='max count',
262 | orientation='horizontal', disabled=False)
263 | aug.b_pval = FloatSlider(min=0,max=1,step=0.1,value=0, description=r"$p$",
264 | orientation='horizontal', disabled=False)
265 | aug.b_asp = FloatSlider(min=0.1,max=5, step=0.1, value=0.3, description=r'$aspect$',
266 | orientation='horizontal', disabled=False)
267 | aug.b_len = FloatSlider(min=0.1,max=5, step=0.1, value=0.3, description=r'$sl$',
268 | orientation='horizontal', disabled=False)
269 | aug.b_ht = FloatSlider(min=0.1,max=5, step=0.1, value=0.3, description=r'$sh$',
270 | orientation='horizontal', disabled=False)
271 | aug.erase_code = 'this is ERASE on'
272 | if aug_paras.hh.value == False:
273 | aug.b_max = FloatSlider(min=0,max=10,step=1,value=0, description='max count',
274 | orientation='horizontal', disabled=True)
275 | aug.b_pval = FloatSlider(min=0,max=1,step=0.1,value=0, description='p',
276 | orientation='horizontal', disabled=True)
277 | aug.b_asp = FloatSlider(min=0.1,max=1.7,value=0.3, description='aspect',
278 | orientation='horizontal', disabled=True)
279 | aug.b_len = FloatSlider(min=0.1,max=1.7,value=0.3, description='length',
280 | orientation='horizontal', disabled=True)
281 | aug.b_ht = FloatSlider(min=0.1,max=1.7,value=0.3, description='height',
282 | orientation='horizontal', disabled=True)
283 | aug.erase_code = 'this is ERASE OFF'
284 | #Contrast
285 | if aug_paras.cc.value == True:
286 | aug.b1_max = FloatSlider(min=0,max=0.9,step=0.1,value=0.2, description='max light',
287 | orientation='horizontal', disabled=False)
288 | aug.b1_pval = FloatSlider(min=0,max=1.0,step=0.05,value=0.75, description='p',
289 | orientation='horizontal', disabled=False)
290 | aug.b1_draw = FloatSlider(min=0,max=100,step=1,value=1, description='draw',
291 | orientation='horizontal', disabled=False)
292 | else:
293 | aug.b1_max = FloatSlider(min=0,max=0.9,step=0.1,value=0, description='max light',
294 | orientation='horizontal', disabled=True)
295 | aug.b1_pval = FloatSlider(min=0,max=1.0,step=0.05,value=0.75, description='p',
296 | orientation='horizontal', disabled=True)
297 | aug.b1_draw = FloatSlider(min=0,max=100,step=1,value=1, description='draw',
298 | orientation='horizontal', disabled=True)
299 | #Rotate
300 | if aug_paras.dd.value == True:
301 | aug.b2_max = FloatSlider(min=0,max=10,step=1,value=0, description='max degree',
302 | orientation='horizontal', disabled=False)
303 | aug.b2_pval = FloatSlider(min=0,max=1,step=0.1,value=0.5, description='p',
304 | orientation='horizontal', disabled=False)
305 | else:
306 | aug.b2_max = FloatSlider(min=0,max=10,step=1,value=0, description='max degree',
307 | orientation='horizontal', disabled=True)
308 | aug.b2_pval = FloatSlider(min=0,max=1,step=0.1,value=0, description='p',
309 | orientation='horizontal', disabled=True)
310 | #Warp
311 | if aug_paras.ee.value == True:
312 | aug.b3_mag = FloatSlider(min=0,max=10,step=1,value=0, description='magnitude',
313 | orientation='horizontal', disabled=False)
314 | aug.b3_pval = FloatSlider(min=0,max=1,step=0.1,value=0, description='p',
315 | orientation='horizontal', disabled=False)
316 | else:
317 | aug.b3_mag = FloatSlider(min=0,max=10,step=1,value=0, description='magnitude',
318 | orientation='horizontal', disabled=True)
319 |         aug.b3_pval = FloatSlider(min=0,max=1,step=0.1,value=0, description='p',
320 | orientation='horizontal', disabled=True)
321 | #Bright
322 | if aug_paras.ff.value == True:
323 | aug.b4_max = FloatSlider(min=0,max=10,step=1,value=0, description='max light',
324 | orientation='horizontal', disabled=False)
325 | aug.b4_pval = FloatSlider(min=0,max=1,step=0.1,value=0, description='p',
326 | orientation='horizontal', disabled=False)
327 | else:
328 | aug.b4_max = FloatSlider(min=0,max=10,step=1,value=0, description='max_light',
329 | orientation='horizontal', disabled=True)
330 | aug.b4_pval = FloatSlider(min=0,max=1,step=0.1,value=0, description='p',
331 | orientation='horizontal', disabled=True)
332 | #DihedralFlip
333 | if aug_paras.gg.value == True:
334 | aug.b5_pval = FloatSlider(min=0,max=1,step=0.1, description='p',
335 | orientation='horizontal', disabled=False)
336 |         aug.b5_draw = FloatSlider(min=0,max=7,step=1, description='draw',
337 | orientation='horizontal', disabled=False)
338 | else:
339 | aug.b5_pval = FloatSlider(min=0,max=1,step=0.1, description='p',
340 | orientation='horizontal', disabled=True)
341 |         aug.b5_draw = FloatSlider(min=0,max=7,step=1, description='draw',
342 | orientation='horizontal', disabled=True)
343 | #Zoom
344 | if aug_paras.ii.value == True:
345 | aug.b6_zoom = FloatSlider(min=1,max=5,step=0.1, description='max_zoom',
346 | orientation='horizontal', disabled=False)
347 | aug.b6_pval = FloatSlider(min=0,max=1,step=0.1, description='p',
348 | orientation='horizontal', disabled=False)
349 | else:
350 | aug.b6_zoom = FloatSlider(min=1,max=5,step=0.1, description='max_zoom',
351 | orientation='horizontal', disabled=True)
352 |         aug.b6_pval = FloatSlider(min=0,max=1,step=0.1, description='p',
353 | orientation='horizontal', disabled=True)
354 |
355 | #Single/Multi
356 | if aug_dash.bi.value == 'Single':
357 | aug.get_items = repeat_one
358 | aug.val = 'Single'
359 | if aug_dash.bi.value == 'Multi':
360 | aug.get_items = get_image_files
361 | aug.val = 'Multi'
362 |
363 |
364 | def aug_dash_choice():
365 |     """Augmentation parameter display helper"""
366 | button_aug_dash = widgets.Button(description='View', button_style='success')
367 | item_erase_val= widgets.HBox([aug.b_max, aug.b_pval, aug.b_asp, aug.b_len, aug.b_ht])
368 | item_erase = widgets.VBox([aug_paras.hh, item_erase_val])
369 |
370 | item_contrast_val = widgets.HBox([aug.b1_max, aug.b1_pval, aug.b1_draw])
371 | item_contrast = widgets.VBox([aug_paras.cc, item_contrast_val])
372 |
373 | item_rotate_val = widgets.HBox([aug.b2_max, aug.b2_pval])
374 | item_rotate = widgets.VBox([aug_paras.dd, item_rotate_val])
375 |
376 | item_warp_val = widgets.HBox([aug.b3_mag, aug.b3_pval])
377 | item_warp = widgets.VBox([aug_paras.ee, item_warp_val])
378 |
379 | item_bright_val = widgets.HBox([aug.b4_max, aug.b4_pval])
380 | item_bright = widgets.VBox([aug_paras.ff, item_bright_val])
381 |
382 | item_dihedral_val = widgets.HBox([aug.b5_pval, aug.b5_draw])
383 | item_dihedral = widgets.VBox([aug_paras.gg, item_dihedral_val])
384 |
385 | item_zoom_val = widgets.HBox([aug.b6_zoom, aug.b6_pval])
386 | item_zoom = widgets.VBox([aug_paras.ii, item_zoom_val])
387 |
388 | items = [item_erase, item_contrast, item_rotate, item_warp, item_bright, item_dihedral, item_zoom]
389 | dia = Box(items, layout=Layout(
390 | display='flex',
391 | flex_flow='column',
392 | flex_grow=0,
393 | flex_wrap='wrap',
394 | border='solid 1px',
395 | align_items='flex-start',
396 | align_content='flex-start',
397 | justify_content='space-between',
398 |         width='auto'
399 | ))
400 | display(dia)
401 | display(button_aug_dash)
402 | aug_dash_out = widgets.Output()
403 | display(aug_dash_out)
404 | def on_button_two(b):
405 | with aug_dash_out:
406 | clear_output()
407 | stats_info()
408 | #image_show()
409 | code_test()
410 | button_aug_dash.on_click(on_button_two)
411 |
412 | def aug_dash():
413 | """GUI for augmentation dashboard"""
414 | aug_button = widgets.Button(description='Confirm', button_style='success')
415 |
416 | tb = widgets.Button(description='Batch Image', disabled=True, button_style='')
417 | aug_dash.bi = widgets.ToggleButtons(value='Multi', options=['Multi', 'Single'], description='', button_style='info',
418 | style=style, layout=Layout(width='auto'))
419 | tg = widgets.Button(description='Padding', disabled=True, button_style='')
420 | aug_dash.pad = widgets.ToggleButtons(value='zeros', options=['zeros', 'reflection', 'border'], description='', button_style='info',
421 | style=style, layout=Layout(width='auto'))
422 | th = widgets.Button(description='Normalization', disabled=True, button_style='')
423 | aug_dash.norm = widgets.ToggleButtons(value='Imagenet', options=['Imagenet', 'Mnist', 'Cifar', 'None'], description='', button_style='info',
424 | style=style, layout=Layout(width='auto'))
425 | tr = widgets.Button(description='Batch Size', disabled=True, button_style='warning')
426 | aug_dash.bs = widgets.ToggleButtons(value='16', options=['8', '16', '32', '64'], description='', button_style='warning',
427 | style=style, layout=Layout(width='auto'))
428 | spj = widgets.Button(description='Presizing', disabled=True, button_style='primary')
429 | te = widgets.Button(description='Item Size', disabled=True, button_style='primary')
430 | aug_dash.imgsiz = widgets.ToggleButtons(value='194', options=['28', '64', '128', '194', '254'],
431 | description='', button_style='primary', style=style, layout=Layout(width='auto'))
432 |     to = widgets.Button(description='Final Size', disabled=True, button_style='primary')
433 | aug_dash.imgbth = widgets.ToggleButtons(value='128', options=['28', '64', '128', '194', '254'],
434 | description='', button_style='primary', style=style, layout=Layout(width='auto'))
435 | tf = widgets.Button(description='Augmentation', disabled=True, button_style='danger')
436 | aug_dash.aug = widgets.ToggleButtons(value='No', options=['No', 'Yes'], description='', button_style='info',
437 | style=style, layout=Layout(width='auto'))
438 |
439 | it = [tb, aug_dash.bi]
440 | it2 = [tg, aug_dash.pad]
441 | it3 = [th, aug_dash.norm]
442 | it4 = [tr, aug_dash.bs]
443 | it5 = [te, aug_dash.imgsiz]
444 | it52 = [to, aug_dash.imgbth]
445 | it6 = [tf, aug_dash.aug]
446 | il = widgets.HBox(it)
447 | ij = widgets.HBox(it2)
448 | ik = widgets.HBox(it3)
449 | ie = widgets.HBox(it4)
450 | iw = widgets.HBox(it5)
451 | ip = widgets.HBox(it52)
452 | iq = widgets.HBox(it6)
453 | ir = widgets.VBox([il, ij, ik, ie, spj, iw, ip, iq])
454 | display(ir)
455 | display(aug_button)
456 |
457 | aug_out = widgets.Output()
458 | display(aug_out)
459 | def on_aug_button(b):
460 | with aug_out:
461 | clear_output()
462 | aug_choice()
463 | aug_button.on_click(on_aug_button)
464 |
465 | #Helpers for ds_3
466 | def stats_info():
467 |     """Normalization statistics helper"""
468 | if aug_dash.norm.value == 'Imagenet':
469 | stats_info.stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
470 | stats_info.code = ('*imagenet_stats')
471 | if aug_dash.norm.value == 'Cifar':
472 | stats_info.stats = ([0.491, 0.482, 0.447], [0.247, 0.243, 0.261])
473 | stats_info.code = ('*cifar_stats')
474 | if aug_dash.norm.value == 'Mnist':
475 | stats_info.stats = ([0.15, 0.15, 0.15], [0.15, 0.15, 0.15])
476 | stats_info.code = ('*mnist_stats')
477 | if aug_dash.norm.value == 'None':
478 | stats_info.stats = ([0., 0., 0.], [0., 0., 0.])
479 | stats_info.code = ('[0., 0., 0.], [0., 0., 0.]')
480 | stats = stats_info.stats
481 |
482 | def repeat_one(source, n=128):
483 |     """Return one image from source repeated n times, for single-image batches"""
484 |     return [get_image_files(source)[9]]*n
485 |
486 | def block_ch():
487 | """Helper for configuring mid-level datablock"""
488 | if ds_3.d1.value == 'PILImage':
489 | block_ch.cls = PILImage.create
490 | block_ch.code = 'PILImage.create'
491 | else:
492 | block_ch.cls = PILImageBW.create
493 | block_ch.code = 'PILImageBW.create'
494 | if ds_3.e1.value == 'CategoryBlock':
495 | block_ch.ctg = Categorize
496 | block_ch.ctg_code = 'Categorize'
497 | else:
498 | block_ch.ctg = MultiCategorize
499 | block_ch.ctg_code = 'MultiCategorize'
500 | if ds_3.g.value == True:
501 | block_ch.outputb = parent_label
502 | block_ch.outputb_code = 'parent_label'
503 | else:
504 | block_ch.outputb = None
505 | if ds_3.s.value == 'RandomSplitter':
506 | block_ch.spl_val = RandomSplitter()
507 | block_ch.spl_val_code = 'RandomSplitter()'
508 | if ds_3.s.value == 'GrandparentSplitter':
509 | block_ch.spl_val = GrandparentSplitter()
510 | block_ch.spl_val_code = 'GrandparentSplitter()'
511 |
512 | def ds_3():
513 | """GUI for 3rd Accordion window"""
514 | db_button = widgets.Button(description='Confirm')
515 | ds_3.d1 = widgets.ToggleButtons(value=None, options=['PILImage', 'PILImageBW'], description='', button_style='info')
516 | ds_3.e1 = widgets.ToggleButtons(value=None, options=['CategoryBlock', 'MultiCategoryBlock'], description='', button_style='warning')
517 | ds_3.g = widgets.ToggleButton(description='parent_label', button_style='danger', value=False)
518 | ds_3.s = widgets.ToggleButtons(value=None, options=['RandomSplitter', 'GrandparentSplitter'],
519 | description='', button_style='success')
520 | form_items = [ds_3.d1, ds_3.e1, ds_3.s ,ds_3.g]
521 |
522 | form_t = Layout(display='flex',
523 | flex_flow='row',
524 | align_items='stretch',
525 | border='solid 1px',
526 | width='100%')
527 | form = Box(children=form_items, layout=form_t)
528 | display(form)
529 | display(db_button)
530 | db_out = widgets.Output()
531 | display(db_out)
532 | def on_db_click(b):
533 | with db_out:
534 | clear_output()
535 | block_ch()
536 | code_test()
537 | db_button.on_click(on_db_click)
538 |
539 | def code_test():
540 |     """Build DataLoaders from the dashboard selections and display a batch"""
541 | db_button2 = widgets.Button(description='DataBlock')
542 | stats_info()
543 | method = ResizeMethod.Pad
544 |
545 | item_size = int(aug_dash.imgsiz.value)
546 | final_size = int(aug_dash.imgbth.value)
547 |
548 | if aug_dash.bi.value == 'Single':
549 | code_test.items = repeat_one(ds_choice.source)
550 | if aug_dash.bi.value == 'Multi':
551 | code_test.items = get_image_files(ds_choice.source)
552 | if aug_dash.aug.value == 'No':
553 | print(BOLD + BLUE + "working.....: " + RESET + RED + 'No Augmentations\n')
554 | print(BOLD + BLUE + "Multi/Single Image: " + RESET + RED + str(aug_dash.bi.value))
555 | print(BOLD + BLUE + "Padding: " + RESET + RED + str(aug_dash.pad.value))
556 | print(BOLD + BLUE + "Normalization: " + RESET + RED + str(stats_info.stats))
557 | print(BOLD + BLUE + "Batch Size: " + RESET + RED + (aug_dash.bs.value))
558 | print(BOLD + BLUE + "Item Size: " + RESET + RED + str(item_size))
559 | print(BOLD + BLUE + "Final Size: " + RESET + RED + str(final_size))
560 | after_b = None
561 | if aug_dash.aug.value == 'Yes':
562 | print(BOLD + BLUE + "working.....: " + RESET + RED + 'Augmentations\n')
563 | print(BOLD + BLUE + "RandomErasing: " + RESET + RED + 'max_count=' + str(aug.b_max.value) + ' p=' + str(aug.b_pval.value))
564 | print(BOLD + BLUE + "Contrast: " + RESET + RED + 'max_light=' + str(aug.b1_max.value) + ' p=' + str(aug.b1_pval.value))
565 | print(BOLD + BLUE + "Rotate: " + RESET + RED + 'max_degree=' + str(aug.b2_max.value) + ' p=' + str(aug.b2_pval.value))
566 | print(BOLD + BLUE + "Warp: " + RESET + RED + 'magnitude=' + str(aug.b3_mag.value) + ' p=' + str(aug.b3_pval.value))
567 | print(BOLD + BLUE + "Brightness: " + RESET + RED + 'max_light=' + str(aug.b4_max.value) + ' p=' + str(aug.b4_pval.value))
568 |         print(BOLD + BLUE + "DihedralFlip: " + RESET + RED + 'p=' + str(aug.b5_pval.value) + ' draw=' + str(aug.b5_draw.value))
569 | print(BOLD + BLUE + "Zoom: " + RESET + RED + 'max_zoom=' + str(aug.b6_zoom.value) + ' p=' + str(aug.b6_pval.value))
570 | print(BOLD + BLUE + "\nMulti/Single Image: " + RESET + RED + str(aug_dash.bi.value))
571 | print(BOLD + BLUE + "Padding: " + RESET + RED + str(aug_dash.pad.value))
572 | print(BOLD + BLUE + "Normalization: " + RESET + RED + str(stats_info.stats))
573 | print(BOLD + BLUE + "Batch Size: " + RESET + RED + (aug_dash.bs.value))
574 | print(BOLD + BLUE + "Item Size: " + RESET + RED + str(item_size))
575 | print(BOLD + BLUE + "Final Size: " + RESET + RED + str(final_size))
576 |
577 |         xtra_tfms = [RandomErasing(p=aug.b_pval.value, max_count=aug.b_max.value, min_aspect=aug.b_asp.value, sl=aug.b_len.value, sh=aug.b_ht.value), # p = probability
578 | Brightness(max_lighting=aug.b4_max.value, p=aug.b4_pval.value, draw=None, batch=None),
579 | Rotate(max_deg=aug.b2_max.value, p=aug.b2_pval.value, draw=None, size=None, mode='bilinear', pad_mode=aug_dash.pad.value),
580 | Warp(magnitude=aug.b3_mag.value,p=aug.b3_pval.value,draw_x=None,draw_y=None,size=None,mode='bilinear',pad_mode=aug_dash.pad.value,batch=False,),
581 |                      Contrast(max_lighting=aug.b1_max.value, p=aug.b1_pval.value, draw=aug.b1_draw.value, batch=True), # draw=1 leaves contrast unchanged; batch=True applies one draw to the whole batch
582 | Dihedral(p=aug.b5_pval.value, draw=aug.b5_draw.value, size=None, mode='bilinear', pad_mode=PadMode.Reflection, batch=False),
583 | Zoom(max_zoom=aug.b6_zoom.value, p=aug.b6_pval.value, draw=None, draw_x=None, draw_y=None, size=None, mode='bilinear',pad_mode=aug_dash.pad.value, batch=False)
584 | ]
585 |
586 | after_b = [Resize(final_size), IntToFloatTensor(), *aug_transforms(xtra_tfms=xtra_tfms, pad_mode=aug_dash.pad.value),
587 |                    Normalize.from_stats(*stats_info.stats)]
588 |
589 | if display_ui.tab.selected_index == 2: #>>> Augmentation tab
590 |
591 | tfms = [[PILImage.create], [parent_label, Categorize]]
592 | item_tfms = [ToTensor(), Resize(item_size)]
593 | dsets = Datasets(code_test.items, tfms=tfms)
594 | dls = dsets.dataloaders(after_item=item_tfms, after_batch=after_b, bs=int(aug_dash.bs.value), num_workers=0)
595 |
596 | dls.show_batch(max_n=12, nrows=2, ncols=6)
597 |
598 | if display_ui.tab.selected_index == 3: #>>> DataBlock tab
599 |
600 | items = get_image_files(ds_choice.source/'train')
601 | split_idx = block_ch.spl_val(items)
602 | tfms = [[block_ch.cls], [block_ch.outputb, block_ch.ctg]]
603 | item_tfms = [ToTensor(), Resize(item_size)]
604 | dsets = Datasets(items, tfms=tfms, splits=split_idx)
605 | dls = dsets.dataloaders(after_item=item_tfms, after_batch=after_b, bs=int(aug_dash.bs.value), num_workers=0)
606 |
607 | display(db_button2)
608 | db_out = widgets.Output()
609 | display(db_out)
610 | def on_db_out(b):
611 | clear_output()
612 | xb, yb = dls.one_batch()
613 | print(BOLD + BLUE + "Train: " + RESET + RED + '(' + str(len(dls.train)) + ', ' + str(len(dls.train_ds)) + ') ' +
614 | BOLD + BLUE + "Valid: "+ RESET + RED + '(' + str(len(dls.valid)) + ', ' + str(len(dls.valid_ds)) + ')')
615 | print(BOLD + BLUE + "Input Shape: " + RESET + RED + str(xb.shape))
616 | print(BOLD + BLUE + "Output Shape: " + RESET + RED + str(yb.shape) + " by " + str(dls.c) + " classes")
617 | dls.show_batch(max_n=12, nrows=2, ncols=6)
618 | db_button2.on_click(on_db_out)
619 |
620 | def write_code():
621 | """Helper for writing code"""
622 | write_button = widgets.Button(description='Code', button_style = 'success')
623 | display(write_button)
624 | write_out = widgets.Output()
625 | display(write_out)
626 | def on_write_button(b):
627 | with write_out:
628 | clear_output()
629 | print(RED + BOLD + '"""import libraries"""' + RESET)
630 |             print(GREEN + BOLD + 'from' + RESET + ' fastai2.vision.all ' + GREEN + BOLD + 'import *' + RESET)
631 | print(RED + BOLD + '\n"""get data source and image files from source"""' + RESET)
632 | print('source = untar_data(URLs.' + str(dashboard_two.datas.value) + ')')
633 | print('items = get_image_files(source)')
634 | print(RED + BOLD + '\n"""get item, split and batch transforms"""' + RESET)
635 | print('tfms = [[' + str(block_ch.code) + ']' + ' ,' + '[' + str(block_ch.outputb_code) + ', '
636 | + str(block_ch.ctg_code) + ']]')
637 | print('item_tfms = [ToTensor(), Resize(' + GREEN + str(aug_dash.imgsiz.value) + RESET + ')]')
638 | print('split_idx = ' + str(block_ch.spl_val_code) + '(items)')
639 | print(RED + BOLD + '\n"""image augmentations"""' + RESET)
640 | if aug_dash.aug.value == 'No':
641 | print('xtra_tfms = '+ GREEN + BOLD + 'None' + RESET)
642 | if aug_dash.aug.value == 'Yes':
643 | print(RESET + 'xtra_tfms = [RandomErasing(p=' + GREEN + str(aug.b_pval.value) + RESET + ', max_count=' +
644 | GREEN + str(aug.b_max.value) + RESET + ', min_aspect=' + GREEN + str(aug.b_asp.value) + RESET + ', sl=' +
645 | GREEN + str(aug.b_len.value) + RESET + ', sh=' + RESET + GREEN + str(aug.b_ht.value) + RESET + '),')
646 | print(RESET + ' Brightness(max_lighting=' + GREEN + str(aug.b4_max.value) + RESET + ', p=' + GREEN +
647 | str(aug.b4_pval.value) + RESET + ', draw=' + GREEN + BOLD + 'None' + RESET + ', batch=' + GREEN + BOLD + 'None' +
648 |                       RESET + '),')
649 | print(RESET + ' Rotate(max_deg=' + GREEN + str(aug.b2_max.value) + RESET + ', p=' + GREEN +
650 | str(aug.b2_pval.value) + RESET + ', draw=' + GREEN + BOLD + 'None' + RESET + ', size=' + GREEN + BOLD + 'None' +
651 |                       RESET + ', mode=' + RED + "'bilinear'" + RESET + ', pad_mode=' + RED + "'" + str(aug_dash.pad.value) + "'" + RESET + '),')
652 | print(RESET + ' Warp(magnitude=' + GREEN + str(aug.b3_mag.value) + RESET + ', p=' + GREEN +
653 | str(aug.b3_pval.value) + RESET + ', draw_x=' + GREEN + BOLD + 'None' + RESET + ', draw_y=' + GREEN + BOLD +
654 | 'None' + RESET + ', size=' + GREEN + BOLD + 'None' + RESET + ', mode=' + RED + "'bilinear'" + RESET + ', pad_mode=' +
655 |                       RED + "'" + str(aug_dash.pad.value) + "'" + RESET + ', batch=' + GREEN + BOLD + 'False' + RESET + '),')
656 | print(RESET + ' Contrast(max_lighting=' + GREEN + str(aug.b1_max.value) + RESET + ', p=' + GREEN +
657 | str(aug.b1_pval.value) + RESET + ', draw=' + GREEN + str(aug.b1_draw.value) + RESET + ', batch=' + GREEN + BOLD +
658 |                       'True' + RESET + '),')
659 |             print(RESET + '           Dihedral(p=' + GREEN + str(aug.b5_pval.value) + RESET + ', draw=' + GREEN +
660 | str(aug.b5_draw.value) + RESET + ', size=' + GREEN + BOLD + 'None' + RESET + ', mode=' + RED + "'bilinear'" + RESET +
661 |                       ', pad_mode=' + RED + "'" + str(aug_dash.pad.value) + "'" + RESET + ', batch=' + GREEN + BOLD + 'False' + RESET + '),')
662 | print(RESET + ' Zoom(max_zoom=' + GREEN + str(aug.b6_zoom.value) + RESET + ', p=' + GREEN +
663 | str(aug.b6_pval.value) + RESET + ', draw=' + GREEN + BOLD + 'None' + RESET + ', draw_x=' + GREEN + BOLD + 'None' + RESET +
664 | ', draw_y=' + GREEN + BOLD + 'None' + RESET + ', size=' + GREEN + BOLD + 'None' + RESET + ', mode=' + RED + "'bilinear'" +
665 |                       RESET + ', pad_mode=' + RED + "'" + str(aug_dash.pad.value) + "'" + RESET + ', batch=' + GREEN + BOLD + 'False' + RESET + ')]')
666 |
667 | print('\nafter_b = [Resize(' + GREEN + str(aug_dash.imgbth.value) + RESET + '), IntToFloatTensor(), ' + '\n ' + "*aug_transforms(xtra_tfms=xtra_tfms, pad_mode="
668 | + RED + "'" + str(aug_dash.pad.value) + "'" + RESET + '),' + ' Normalize.from_stats(' + GREEN + str(stats_info.code) + RESET + ')]')
669 |
670 | print('\ndsets = Datasets(items, tfms=tfms, splits=split_idx)')
671 |             print('dls = dsets.dataloaders(after_item=item_tfms, after_batch=after_b, bs=' + GREEN + str(aug_dash.bs.value) + RESET +
672 | ', num_workers=' + GREEN + '0' + RESET + ')')
673 |
674 |
675 | print(RED + BOLD + '\n"""Check training and valid shapes"""' + RESET)
676 | print('xb, yb = dls.one_batch()')
677 | print(RESET + 'dls.train' + RESET + RED + BOLD + ' #train')
678 |             print(RESET + 'dls.train_ds' + RESET + RED + BOLD + ' #train_ds')
679 | print(RESET + 'dls.valid' + RESET + RED + BOLD + ' #valid')
680 | print(RESET + 'dls.valid_ds' + RESET + RED + BOLD + ' #valid_ds')
681 |
682 | print(RESET + RED + BOLD + '\n"""show batch"""' + RESET)
683 |             print('dls.show_batch(max_n=' + GREEN + '12' + RESET + ', nrows=' + GREEN + '2' + RESET + ', ncols=' +
684 | GREEN + '6' + RESET + ')')
685 | print(RESET + RED + BOLD + '\n"""train"""' + RESET)
686 | print('arch = xresnet50(pretrained=' + GREEN + BOLD + 'False' + RESET + ')')
687 |             print('learn = Learner(dls, model=arch, loss_func=LabelSmoothingCrossEntropy(),' +
688 | '\n metrics=[top_k_accuracy, accuracy])')
689 | print(RESET + 'learn.fit_one_cycle(' + GREEN + '1' + RESET + ', ' + GREEN + '1e-2' + RESET + ')')
690 | print(RESET + RED + BOLD + '\n"""interpretations"""' + RESET)
691 | print('interp = ClassificationInterpretation.from_learner(learn)')
692 | print('losses, idxs = interp.top_losses()')
693 | print(GREEN + 'len' + RESET + '(dls.valid_ds)' + RESET + '==' + GREEN + 'len' + RESET + '(losses)==' +
694 | GREEN + 'len' + RESET + '(idxs)')
695 | print(RESET + RED + BOLD + '\n"""confusion matrix"""' + RESET)
696 | print('interp.plot_confusion_matrix(figsize=(' + GREEN + '7' + RESET + ',' + GREEN + '7' + RESET + '))')
697 | write_button.on_click(on_write_button)
698 |
699 | def play_info():
700 | """Helper for imagewoof play"""
701 |
702 | item_size = int(aug_dash.imgsiz.value)
703 | final_size = int(aug_dash.imgbth.value)
704 |
705 | if aug_dash.aug.value == 'No':
706 |         print(BOLD + BLUE + "Working.....: " + RESET + RED + 'No Augmentations\n')
707 | print(BOLD + BLUE + "Multi/Single Image: " + RESET + RED + str(aug_dash.bi.value))
708 | print(BOLD + BLUE + "Padding: " + RESET + RED + str(aug_dash.pad.value))
709 | print(BOLD + BLUE + "Normalization: " + RESET + RED + str(stats_info.stats))
710 |         print(BOLD + BLUE + "Batch Size: " + RESET + RED + str(aug_dash.bs.value))
711 | print(BOLD + BLUE + "Item Size: " + RESET + RED + str(item_size))
712 | print(BOLD + BLUE + "Final Size: " + RESET + RED + str(final_size))
713 | after_b = None
714 |
715 | if aug_dash.aug.value == 'Yes':
716 | print(BOLD + BLUE + 'Loading ImageWoof-160\n' + RESET)
717 | print(BOLD + BLUE + "Current Augmentations:" + RESET)
718 | print(BOLD + BLUE + "RandomErasing: " + RESET + RED + 'max_count=' + str(aug.b_max.value) + ' p=' + str(aug.b_pval.value))
719 | print(BOLD + BLUE + "Contrast: " + RESET + RED + 'max_light=' + str(aug.b1_max.value) + ' p=' + str(aug.b1_pval.value))
720 | print(BOLD + BLUE + "Rotate: " + RESET + RED + 'max_degree=' + str(aug.b2_max.value) + ' p=' + str(aug.b2_pval.value))
721 | print(BOLD + BLUE + "Warp: " + RESET + RED + 'magnitude=' + str(aug.b3_mag.value) + ' p=' + str(aug.b3_pval.value))
722 | print(BOLD + BLUE + "Brightness: " + RESET + RED + 'max_light=' + str(aug.b4_max.value) + ' p=' + str(aug.b4_pval.value))
723 |         print(BOLD + BLUE + "DihedralFlip: " + RESET + RED + 'p=' + str(aug.b5_pval.value) + ' draw=' + str(aug.b5_draw.value))
724 | print(BOLD + BLUE + "Zoom: " + RESET + RED + 'max_zoom=' + str(aug.b6_zoom.value) + ' p=' + str(aug.b6_pval.value))
725 | print(BOLD + BLUE + "\nMulti/Single Image: " + RESET + RED + str(aug_dash.bi.value))
726 | print(BOLD + BLUE + "Padding: " + RESET + RED + str(aug_dash.pad.value))
727 | print(BOLD + BLUE + "Normalization: " + RESET + RED + str(stats_info.stats))
728 |         print(BOLD + BLUE + "Batch Size: " + RESET + RED + str(aug_dash.bs.value))
729 | print(BOLD + BLUE + "Item Size: " + RESET + RED + str(item_size))
730 | print(BOLD + BLUE + "Final Size: " + RESET + RED + str(final_size))
731 |
732 |         xtra_tfms = [RandomErasing(p=aug.b_pval.value, max_count=aug.b_max.value, min_aspect=aug.b_asp.value, sl=aug.b_len.value, sh=aug.b_ht.value), #p = probability
733 |              Brightness(max_lighting=aug.b4_max.value, p=aug.b4_pval.value, draw=None, batch=None),
734 |              Rotate(max_deg=aug.b2_max.value, p=aug.b2_pval.value, draw=None, size=None, mode='bilinear', pad_mode=aug_dash.pad.value),
735 |              Warp(magnitude=aug.b3_mag.value, p=aug.b3_pval.value, draw_x=None, draw_y=None, size=None, mode='bilinear', pad_mode='reflection', batch=False),
736 |              Contrast(max_lighting=aug.b1_max.value, p=aug.b1_pval.value, draw=aug.b1_draw.value, batch=True), #draw=1 is normal; batch applies the tfm batch-wise
737 |              Dihedral(p=aug.b5_pval.value, draw=aug.b5_draw.value, size=None, mode='bilinear', pad_mode=PadMode.Reflection, batch=False),
738 |              Zoom(max_zoom=aug.b6_zoom.value, p=aug.b6_pval.value, draw=None, draw_x=None, draw_y=None, size=None, mode='bilinear', pad_mode=aug_dash.pad.value, batch=False)
739 |              ]
740 | after_b = [Resize(final_size), IntToFloatTensor(), *aug_transforms(xtra_tfms=xtra_tfms, pad_mode=aug_dash.pad.value),
741 | Normalize(stats_info.stats)]
742 |
743 | source_play = untar_data(URLs.IMAGEWOOF_160)
744 | items = get_image_files(source_play/'train')
745 |
746 | tfms = [[PILImage.create], [parent_label, Categorize]]
747 | item_tfms = [ToTensor(), Resize(item_size)]
748 | dsets = Datasets(items, tfms=tfms)
749 | dls = dsets.dataloaders(after_item=item_tfms, after_batch=after_b, bs=int(aug_dash.bs.value), num_workers=0)
750 |
751 | dls.show_batch(max_n=6, nrows=1, ncols=6)
752 | imagewoof_plaz()
753 |
754 | def imagewoof_plaz():
755 |
756 | item_size = int(aug_dash.imgsiz.value)
757 | final_size = int(aug_dash.imgbth.value)
758 |
759 | button_t2 = widgets.Button(description='Play')
760 | button_res = widgets.Button(description='View')
761 | res = widgets.ToggleButtons(value=None, options=['Confusion Matrix', 'Most Confused', 'Top Losses'], description='', button_style='info')
762 | play_out = widgets.Output()
763 | display(button_t2)
764 | display(play_out)
765 | def on_play_button(b):
766 | with play_out:
767 | clear_output()
768 | print('Training.....')
769 | if aug_dash.aug.value == 'No':
770 | after_b = None
771 | if aug_dash.aug.value == 'Yes':
772 |
773 |                 xtra_tfms = [RandomErasing(p=aug.b_pval.value, max_count=aug.b_max.value, min_aspect=aug.b_asp.value, sl=aug.b_len.value, sh=aug.b_ht.value), #p = probability
774 |              Brightness(max_lighting=aug.b4_max.value, p=aug.b4_pval.value, draw=None, batch=None),
775 |              Rotate(max_deg=aug.b2_max.value, p=aug.b2_pval.value, draw=None, size=None, mode='bilinear', pad_mode=aug_dash.pad.value),
776 |              Warp(magnitude=aug.b3_mag.value, p=aug.b3_pval.value, draw_x=None, draw_y=None, size=None, mode='bilinear', pad_mode='reflection', batch=False),
777 |              Contrast(max_lighting=aug.b1_max.value, p=aug.b1_pval.value, draw=aug.b1_draw.value, batch=True), #draw=1 is normal; batch applies the tfm batch-wise
778 |              Dihedral(p=aug.b5_pval.value, draw=aug.b5_draw.value, size=None, mode='bilinear', pad_mode=PadMode.Reflection, batch=False),
779 |              Zoom(max_zoom=aug.b6_zoom.value, p=aug.b6_pval.value, draw=None, draw_x=None, draw_y=None, size=None, mode='bilinear', pad_mode=PadMode.Reflection, batch=False)
780 |              ]
781 |
782 | after_b = [Resize(final_size), IntToFloatTensor(), *aug_transforms(xtra_tfms=xtra_tfms, pad_mode=aug_dash.pad.value),
783 | Normalize(stats_info.stats)]
784 |
785 | source_play = untar_data(URLs.IMAGEWOOF_160)
786 | items = get_image_files(source_play/'train')
787 |
788 | split_idx = block_ch.spl_val(items)
789 | tfms = [[block_ch.cls], [block_ch.outputb, block_ch.ctg]]
790 | item_tfms = [ToTensor(), Resize(item_size)]
791 | dsets = Datasets(items, tfms=tfms, splits=split_idx)
792 |
793 | dls = dsets.dataloaders(after_item=item_tfms, after_batch=after_b, bs=int(aug_dash.bs.value), num_workers=0)
794 | arch = xresnet50(pretrained=False)
795 | #learn = Learner(dls, model=arch, loss_func=LabelSmoothingCrossEntropy(),
796 | # metrics=[top_k_accuracy, accuracy])
797 | learn = cnn_learner(dls, xresnet50, loss_func=LabelSmoothingCrossEntropy(), metrics=[top_k_accuracy, accuracy])
798 | learn.fine_tune(1)
799 |                 print('Getting Interpretations....')
800 | interp = ClassificationInterpretation.from_learner(learn)
801 |
802 | losses,idxs = interp.top_losses()
803 |
804 |                 assert len(dls.valid_ds)==len(losses)==len(idxs)
805 | #display(play_opt)
806 | display(res)
807 | display(button_res)
808 | res_out = widgets.Output()
809 | display(res_out)
810 | def on_res_button(b):
811 | with res_out:
812 | clear_output()
813 | if res.value == 'Confusion Matrix':
814 | interp.plot_confusion_matrix(figsize=(7,7))
815 | if res.value == 'Most Confused':
816 | print(interp.most_confused(min_val=1))
817 | if res.value == 'Top Losses':
818 | interp.plot_top_losses(9, figsize=(7,7))
819 | button_res.on_click(on_res_button)
820 |
821 | button_t2.on_click(on_play_button)
822 |
823 | def display_ui():
824 | """ Display tabs for visual display"""
825 | button = widgets.Button(description="Train")
826 | button_b = widgets.Button(description="Metrics")
827 | button_m = widgets.Button(description='Model')
828 | button_l = widgets.Button(description='LR')
829 |
830 | test_button = widgets.Button(description='Batch')
831 | test2_button = widgets.Button(description='Test2')
832 |
833 | out1a = widgets.Output()
834 | out1 = widgets.Output()
835 | out2 = widgets.Output()
836 | out3 = widgets.Output()
837 | out4 = widgets.Output()
838 | out5 = widgets.Output()
839 |
840 | data1a = pd.DataFrame(np.random.normal(size = 50))
841 | data1 = pd.DataFrame(np.random.normal(size = 100))
842 | data2 = pd.DataFrame(np.random.normal(size = 150))
843 | data3 = pd.DataFrame(np.random.normal(size = 200))
844 | data4 = pd.DataFrame(np.random.normal(size = 250))
845 | data5 = pd.DataFrame(np.random.normal(size = 300))
846 |
847 | with out1a: #info
848 | clear_output()
849 | dashboard_one()
850 |
851 | with out1: #data
852 | clear_output()
853 | dashboard_two()
854 |
855 | with out2: #augmentation
856 | clear_output()
857 | aug_dash()
858 |
859 | with out3: #Block
860 | clear_output()
861 | ds_3()
862 |
863 | with out4: #code
864 | clear_output()
865 | write_code()
866 |
867 | with out5: #Imagewoof Play
868 | clear_output()
869 | print(BOLD + BLUE + 'Work in progress.....')
870 | play_button = widgets.Button(description='Parameters')
871 | display(play_button)
872 | play_out = widgets.Output()
873 | display(play_out)
874 | def button_play(b):
875 | with play_out:
876 | clear_output()
877 | play_info()
878 | play_button.on_click(button_play)
879 |
880 | display_ui.tab = widgets.Tab(children = [out1a, out1, out2, out3, out4, out5])
881 | display_ui.tab.set_title(0, 'Info')
882 | display_ui.tab.set_title(1, 'Data')
883 | display_ui.tab.set_title(2, 'Augmentation')
884 | display_ui.tab.set_title(3, 'DataBlock')
885 | display_ui.tab.set_title(4, 'Code')
886 | display_ui.tab.set_title(5, 'ImageWoof Play')
887 | display(display_ui.tab)
888 |
--------------------------------------------------------------------------------
/xresnet2.py:
--------------------------------------------------------------------------------
1 | #from https://github.com/fastai/fastai/blob/master/fastai/vision/models/xresnet2.py (note: model_urls is not defined in this copy, so pretrained=True will raise a NameError)
2 |
3 | import torch.nn as nn
4 | import torch
5 | import math
6 | import torch.utils.model_zoo as model_zoo
7 | from torch.nn import Module #changed from torch_core
8 |
9 |
10 | __all__ = ['XResNet', 'xresnet18', 'xresnet34_2', 'xresnet50_2', 'xresnet101', 'xresnet152']
11 |
12 |
13 | def conv3x3(in_planes, out_planes, stride=1):
14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
15 |
16 |
17 | class BasicBlock(Module):
18 | expansion = 1
19 |
20 | def __init__(self, inplanes, planes, stride=1, downsample=None):
21 | super(BasicBlock, self).__init__()
22 | self.conv1 = conv3x3(inplanes, planes, stride)
23 | self.bn1 = nn.BatchNorm2d(planes)
24 | self.relu = nn.ReLU(inplace=True)
25 | self.conv2 = conv3x3(planes, planes)
26 | self.bn2 = nn.BatchNorm2d(planes)
27 | self.downsample = downsample
28 | self.stride = stride
29 |
30 | def forward(self, x):
31 | residual = x
32 |
33 | out = self.conv1(x)
34 | out = self.bn1(out)
35 | out = self.relu(out)
36 |
37 | out = self.conv2(out)
38 | out = self.bn2(out)
39 |
40 | if self.downsample is not None: residual = self.downsample(x)
41 |
42 | out += residual
43 | out = self.relu(out)
44 |
45 | return out
46 |
47 |
48 | class Bottleneck(Module):
49 | expansion = 4
50 |
51 | def __init__(self, inplanes, planes, stride=1, downsample=None):
52 | super(Bottleneck, self).__init__()
53 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
54 | self.bn1 = nn.BatchNorm2d(planes)
55 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
56 | padding=1, bias=False)
57 | self.bn2 = nn.BatchNorm2d(planes)
58 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
59 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
60 | self.relu = nn.ReLU(inplace=True)
61 | self.downsample = downsample
62 | self.stride = stride
63 |
64 | def forward(self, x):
65 | residual = x
66 |
67 | out = self.conv1(x)
68 | out = self.bn1(out)
69 | out = self.relu(out)
70 |
71 | out = self.conv2(out)
72 | out = self.bn2(out)
73 | out = self.relu(out)
74 |
75 | out = self.conv3(out)
76 | out = self.bn3(out)
77 |
78 | if self.downsample is not None: residual = self.downsample(x)
79 |
80 | out += residual
81 | out = self.relu(out)
82 |
83 | return out
84 |
85 | def conv2d(ni, nf, stride):
86 | return nn.Sequential(nn.Conv2d(ni, nf, kernel_size=3, stride=stride, padding=1, bias=False),
87 | nn.BatchNorm2d(nf), nn.ReLU(inplace=True))
88 |
89 | class XResNet(Module):
90 |
91 | def __init__(self, block, layers, c_out=1000):
92 | self.inplanes = 64
93 | super(XResNet, self).__init__()
94 | self.conv1 = conv2d(3, 32, 2)
95 | self.conv2 = conv2d(32, 32, 1)
96 | self.conv3 = conv2d(32, 64, 1)
97 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
98 | self.layer1 = self._make_layer(block, 64, layers[0])
99 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
100 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
101 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
102 | self.avgpool = nn.AdaptiveAvgPool2d(1)
103 | self.fc = nn.Linear(512 * block.expansion, c_out)
104 |
105 | for m in self.modules():
106 | if isinstance(m, nn.Conv2d):
107 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
108 | elif isinstance(m, nn.BatchNorm2d):
109 | nn.init.constant_(m.weight, 1)
110 | nn.init.constant_(m.bias, 0)
111 |
112 | for m in self.modules():
113 | if isinstance(m, BasicBlock): m.bn2.weight = nn.Parameter(torch.zeros_like(m.bn2.weight))
114 | if isinstance(m, Bottleneck): m.bn3.weight = nn.Parameter(torch.zeros_like(m.bn3.weight))
115 | if isinstance(m, nn.Linear): m.weight.data.normal_(0, 0.01)
116 |
117 | def _make_layer(self, block, planes, blocks, stride=1):
118 | downsample = None
119 | if stride != 1 or self.inplanes != planes * block.expansion:
120 | layers = []
121 | if stride==2: layers.append(nn.AvgPool2d(kernel_size=2, stride=2))
122 | layers += [
123 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=1, bias=False),
124 | nn.BatchNorm2d(planes * block.expansion) ]
125 | downsample = nn.Sequential(*layers)
126 |
127 | layers = []
128 | layers.append(block(self.inplanes, planes, stride, downsample))
129 | self.inplanes = planes * block.expansion
130 | for i in range(1, blocks): layers.append(block(self.inplanes, planes))
131 | return nn.Sequential(*layers)
132 |
133 | def forward(self, x):
134 | x = self.conv1(x)
135 | x = self.conv2(x)
136 | x = self.conv3(x)
137 | x = self.maxpool(x)
138 |
139 | x = self.layer1(x)
140 | x = self.layer2(x)
141 | x = self.layer3(x)
142 | x = self.layer4(x)
143 |
144 | x = self.avgpool(x)
145 | x = x.view(x.size(0), -1)
146 | x = self.fc(x)
147 |
148 | return x
149 |
150 |
151 | def xresnet18(pretrained=False, **kwargs):
152 |     """Constructs an XResNet-18 model.
153 | Args:
154 | pretrained (bool): If True, returns a model pre-trained on ImageNet
155 | """
156 | model = XResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
157 | if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['xresnet18']))
158 | return model
159 |
160 |
161 | def xresnet34_2(pretrained=False, **kwargs):
162 |     """Constructs an XResNet-34 model.
163 | Args:
164 | pretrained (bool): If True, returns a model pre-trained on ImageNet
165 | """
166 | model = XResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
167 | if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['xresnet34']))
168 | return model
169 |
170 |
171 | def xresnet50_2(pretrained=False, **kwargs):
172 |     """Constructs an XResNet-50 model.
173 | Args:
174 | pretrained (bool): If True, returns a model pre-trained on ImageNet
175 | """
176 | model = XResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
177 | if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['xresnet50']))
178 | return model
179 |
180 |
181 | def xresnet101(pretrained=False, **kwargs):
182 |     """Constructs an XResNet-101 model.
183 | Args:
184 | pretrained (bool): If True, returns a model pre-trained on ImageNet
185 | """
186 | model = XResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
187 | if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['xresnet101']))
188 | return model
189 |
190 |
191 | def xresnet152(pretrained=False, **kwargs):
192 |     """Constructs an XResNet-152 model.
193 | Args:
194 | pretrained (bool): If True, returns a model pre-trained on ImageNet
195 | """
196 | model = XResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
197 | if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['xresnet152']))
198 | return model
199 |
--------------------------------------------------------------------------------
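The stem in `xresnet2.py` above replaces classic ResNet's single 7x7/stride-2 conv with three 3x3 convs (`conv1`..`conv3`) followed by a max-pool. As a quick sanity check of the shapes that stem produces, here is a minimal self-contained sketch (assumes PyTorch; `conv2d` is copied from the file above, the rest is an illustrative standalone snippet, not part of the repo):

```python
import torch
import torch.nn as nn

def conv2d(ni, nf, stride):
    # 3x3 conv -> BatchNorm -> ReLU, mirroring the conv2d helper in xresnet2.py
    return nn.Sequential(
        nn.Conv2d(ni, nf, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(nf), nn.ReLU(inplace=True))

# The XResNet stem: three 3x3 convs (the first with stride 2) plus a max-pool,
# in place of ResNet's single 7x7 stride-2 conv.
stem = nn.Sequential(conv2d(3, 32, 2), conv2d(32, 32, 1), conv2d(32, 64, 1),
                     nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

x = torch.randn(2, 3, 128, 128)   # batch of 2 RGB 128x128 images
print(stem(x).shape)              # -> torch.Size([2, 64, 32, 32])
```

The stride-2 conv and the max-pool each halve the spatial resolution, so a 128x128 input reaches `layer1` at 32x32 with 64 channels, matching `self.inplanes = 64` in `XResNet.__init__`.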