Some code examples are available in demo.ipynb and test.py. Additionally, in the section Projects you can see some practical examples of projects using this library.
For additional information on the Deep Learning library, visit the official web page www.keras.io or the GitHub repository https://github.com/fchollet/keras.
198 |
You can also use our custom Keras version, which provides several additional layers for Multimodal Learning.
15 | {%- endif %}
16 |
--------------------------------------------------------------------------------
/sphinx/source/conf.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # multimodal_keras_wrapper documentation build configuration file, created by
4 | # sphinx-quickstart on Tue Apr 26 10:43:19 2016.
5 | #
6 | # This file is execfile()d with the current directory set to its
7 | # containing dir.
8 | #
9 | # Note that not all possible configuration values are present in this
10 | # autogenerated file.
11 | #
12 | # All configuration values have a default; values that are commented out
13 | # serve to show the default.
14 |
15 | import sys
16 | import os
17 | from recommonmark.parser import CommonMarkParser
18 |
19 | # If extensions (or modules to document with autodoc) are in another directory,
20 | # add these directories to sys.path here. If the directory is relative to the
21 | # documentation root, use os.path.abspath to make it absolute, like shown here.
22 | # sys.path.insert(0, os.path.abspath('.'))
23 | sys.path.insert(0, os.path.abspath('../../'))
24 | sys.path.insert(0, os.path.abspath('../_ext'))
25 |
26 | # -- General configuration ------------------------------------------------
27 |
28 | # If your documentation needs a minimal Sphinx version, state it here.
29 | # needs_sphinx = '1.0'
30 |
31 | # Add any Sphinx extension module names here, as strings. They can be
32 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
33 | # ones.
34 | extensions = [
35 | 'sphinx.ext.doctest',
36 | 'sphinx.ext.ifconfig',
37 | 'sphinx.ext.autodoc',
38 | 'edit_on_github'
39 | ]
40 |
41 | edit_on_github_project = 'MarcBS/multimodal_keras_wrapper'
42 | edit_on_github_branch = 'master'
43 |
44 | # Add any paths that contain templates here, relative to this directory.
45 | templates_path = ['_templates']
46 |
47 | source_parsers = {
48 | '.md': CommonMarkParser,
49 | }
50 |
51 | # The suffix(es) of source filenames.
52 | # You can specify multiple suffix as a list of string:
53 | source_suffix = ['.rst', '.md']
54 | # source_suffix = '.rst'
55 |
56 | # The encoding of source files.
57 | # source_encoding = 'utf-8-sig'
58 |
59 | # The master toctree document.
60 | master_doc = 'index'
61 |
62 | # General information about the project.
63 | project = u'Multimodal Keras Wrapper'
64 | copyright = u'2016, Marc Bolaños'
65 | author = u'Marc Bolaños'
66 |
67 | # The version info for the project you're documenting, acts as replacement for
68 | # |version| and |release|, also used in various other places throughout the
69 | # built documents.
70 | #
71 | # The short X.Y version.
72 | version = u'0.55'
73 | # The full version, including alpha/beta/rc tags.
74 | release = u'0.55'
75 |
76 | # The language for content autogenerated by Sphinx. Refer to documentation
77 | # for a list of supported languages.
78 | #
79 | # This is also used if you do content translation via gettext catalogs.
80 | # Usually you set "language" from the command line for these cases.
81 | language = None
82 |
83 | # There are two options for replacing |today|: either, you set today to some
84 | # non-false value, then it is used:
85 | # today = ''
86 | # Else, today_fmt is used as the format for a strftime call.
87 | # today_fmt = '%B %d, %Y'
88 |
89 | # List of patterns, relative to source directory, that match files and
90 | # directories to ignore when looking for source files.
91 | # This patterns also effect to html_static_path and html_extra_path
92 | exclude_patterns = []
93 |
94 | # The reST default role (used for this markup: `text`) to use for all
95 | # documents.
96 | # default_role = None
97 |
98 | # If true, '()' will be appended to :func: etc. cross-reference text.
99 | # add_function_parentheses = True
100 |
101 | # If true, the current module name will be prepended to all description
102 | # unit titles (such as .. function::).
103 | # add_module_names = True
104 |
105 | # If true, sectionauthor and moduleauthor directives will be shown in the
106 | # output. They are ignored by default.
107 | # show_authors = False
108 |
109 | # The name of the Pygments (syntax highlighting) style to use.
110 | pygments_style = 'sphinx'
111 |
112 | # A list of ignored prefixes for module index sorting.
113 | # modindex_common_prefix = []
114 |
115 | # If true, keep warnings as "system message" paragraphs in the built documents.
116 | # keep_warnings = False
117 |
118 | # If true, `todo` and `todoList` produce output, else they produce nothing.
119 | todo_include_todos = False
120 |
121 | # -- Options for HTML output ----------------------------------------------
122 |
123 | # The theme to use for HTML and HTML Help pages. See the documentation for
124 | # a list of builtin themes.
125 | # html_theme = 'alabaster'
126 | html_theme = 'sphinx_rtd_theme'
127 |
128 | # Theme options are theme-specific and customize the look and feel of a theme
129 | # further. For a list of options available for each theme, see the
130 | # documentation.
131 | # html_theme_options = {}
132 |
133 | # Add any paths that contain custom themes here, relative to this directory.
134 | html_theme_path = ["_themes", ]
135 |
136 | # The name for this set of Sphinx documents.
137 | # " v documentation" by default.
138 | # html_title = u'multimodal_keras_wrapper v0.1'
139 |
140 | # A shorter title for the navigation bar. Default is the same as html_title.
141 | # html_short_title = None
142 |
143 | # The name of an image file (relative to this directory) to place at the top
144 | # of the sidebar.
145 | # html_logo = None
146 |
147 | # The name of an image file (relative to this directory) to use as a favicon of
148 | # the docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
149 | # pixels large.
150 | # html_favicon = None
151 |
152 | # Add any paths that contain custom static files (such as style sheets) here,
153 | # relative to this directory. They are copied after the builtin static files,
154 | # so a file named "default.css" will overwrite the builtin "default.css".
155 | html_static_path = ['_static']
156 |
157 | # Add any extra paths that contain custom files (such as robots.txt or
158 | # .htaccess) here, relative to this directory. These files are copied
159 | # directly to the root of the documentation.
160 | # html_extra_path = []
161 |
162 | # If not None, a 'Last updated on:' timestamp is inserted at every page
163 | # bottom, using the given strftime format.
164 | # The empty string is equivalent to '%b %d, %Y'.
165 | # html_last_updated_fmt = None
166 |
167 | # If true, SmartyPants will be used to convert quotes and dashes to
168 | # typographically correct entities.
169 | # html_use_smartypants = True
170 |
171 | # Custom sidebar templates, maps document names to template names.
172 | # html_sidebars = {}
173 |
174 | # Additional templates that should be rendered to pages, maps page names to
175 | # template names.
176 | # html_additional_pages = {}
177 |
178 | # If false, no module index is generated.
179 | # html_domain_indices = True
180 |
181 | # If false, no index is generated.
182 | # html_use_index = True
183 |
184 | # If true, the index is split into individual pages for each letter.
185 | # html_split_index = False
186 |
187 | # If true, links to the reST sources are added to the pages.
188 | # html_show_sourcelink = True
189 |
190 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
191 | # html_show_sphinx = True
192 |
193 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
194 | # html_show_copyright = True
195 |
196 | # If true, an OpenSearch description file will be output, and all pages will
197 | # contain a tag referring to it. The value of this option must be the
198 | # base URL from which the finished HTML is served.
199 | # html_use_opensearch = ''
200 |
201 | # This is the file name suffix for HTML files (e.g. ".xhtml").
202 | # html_file_suffix = None
203 |
204 | # Language to be used for generating the HTML full-text search index.
205 | # Sphinx supports the following languages:
206 | # 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja'
207 | # 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr', 'zh'
208 | # html_search_language = 'en'
209 |
210 | # A dictionary with options for the search language support, empty by default.
211 | # 'ja' uses this config value.
212 | # 'zh' user can custom change `jieba` dictionary path.
213 | # html_search_options = {'type': 'default'}
214 |
215 | # The name of a javascript file (relative to the configuration directory) that
216 | # implements a search results scorer. If empty, the default will be used.
217 | # html_search_scorer = 'scorer.js'
218 |
219 | # Output file base name for HTML help builder.
220 | htmlhelp_basename = 'multimodal_keras_wrapperdoc'
221 |
222 | # -- Options for LaTeX output ---------------------------------------------
223 |
224 | latex_elements = {
225 | # The paper size ('letterpaper' or 'a4paper').
226 | # 'papersize': 'letterpaper',
227 |
228 | # The font size ('10pt', '11pt' or '12pt').
229 | # 'pointsize': '10pt',
230 |
231 | # Additional stuff for the LaTeX preamble.
232 | # 'preamble': '',
233 |
234 | # Latex figure (float) alignment
235 | # 'figure_align': 'htbp',
236 | }
237 |
238 | # Grouping the document tree into LaTeX files. List of tuples
239 | # (source start file, target name, title,
240 | # author, documentclass [howto, manual, or own class]).
241 | latex_documents = [
242 | (master_doc, 'multimodal_keras_wrapper.tex', u'multimodal\\_keras\\_wrapper Documentation',
243 | u'Marc Bolaños', 'manual'),
244 | ]
245 |
246 | # The name of an image file (relative to this directory) to place at the top of
247 | # the title page.
248 | # latex_logo = None
249 |
250 | # For "manual" documents, if this is true, then toplevel headings are parts,
251 | # not chapters.
252 | # latex_use_parts = False
253 |
254 | # If true, show page references after internal links.
255 | # latex_show_pagerefs = False
256 |
257 | # If true, show URL addresses after external links.
258 | # latex_show_urls = False
259 |
260 | # Documents to append as an appendix to all manuals.
261 | # latex_appendices = []
262 |
263 | # If false, no module index is generated.
264 | # latex_domain_indices = True
265 |
266 |
267 | # -- Options for manual page output ---------------------------------------
268 |
269 | # One entry per manual page. List of tuples
270 | # (source start file, name, description, authors, manual section).
271 | man_pages = [
272 | (master_doc, 'multimodal_keras_wrapper', u'multimodal_keras_wrapper Documentation',
273 | [author], 1)
274 | ]
275 |
276 | # If true, show URL addresses after external links.
277 | # man_show_urls = False
278 |
279 |
280 | # -- Options for Texinfo output -------------------------------------------
281 |
282 | # Grouping the document tree into Texinfo files. List of tuples
283 | # (source start file, target name, title, author,
284 | # dir menu entry, description, category)
285 | texinfo_documents = [
286 | (master_doc, 'multimodal_keras_wrapper', u'multimodal_keras_wrapper Documentation',
287 | author, 'multimodal_keras_wrapper', 'One line description of project.',
288 | 'Miscellaneous'),
289 | ]
290 |
291 | # Documents to append as an appendix to all manuals.
292 | # texinfo_appendices = []
293 |
294 | # If false, no module index is generated.
295 | # texinfo_domain_indices = True
296 |
297 | # How to display URL addresses: 'footnote', 'no', or 'inline'.
298 | # texinfo_show_urls = 'footnote'
299 |
300 | # If true, do not generate a @detailmenu in the "Top" node's menu.
301 | # texinfo_no_detailmenu = False
302 |
303 |
304 | # -- Options for Epub output ----------------------------------------------
305 |
306 | # Bibliographic Dublin Core info.
307 | epub_title = project
308 | epub_author = author
309 | epub_publisher = author
310 | epub_copyright = copyright
311 |
312 | # The basename for the epub file. It defaults to the project name.
313 | # epub_basename = project
314 |
315 | # The HTML theme for the epub output. Since the default themes are not
316 | # optimized for small screen space, using the same theme for HTML and epub
317 | # output is usually not wise. This defaults to 'epub', a theme designed to save
318 | # visual space.
319 | # epub_theme = 'epub'
320 |
321 | # The language of the text. It defaults to the language option
322 | # or 'en' if the language is not set.
323 | # epub_language = ''
324 |
325 | # The scheme of the identifier. Typical schemes are ISBN or URL.
326 | # epub_scheme = ''
327 |
328 | # The unique identifier of the text. This can be a ISBN number
329 | # or the project homepage.
330 | # epub_identifier = ''
331 |
332 | # A unique identification for the text.
333 | # epub_uid = ''
334 |
335 | # A tuple containing the cover image and cover page html template filenames.
336 | # epub_cover = ()
337 |
338 | # A sequence of (type, uri, title) tuples for the guide element of content.opf.
339 | # epub_guide = ()
340 |
341 | # HTML files that should be inserted before the pages created by sphinx.
342 | # The format is a list of tuples containing the path and title.
343 | # epub_pre_files = []
344 |
345 | # HTML files that should be inserted after the pages created by sphinx.
346 | # The format is a list of tuples containing the path and title.
347 | # epub_post_files = []
348 |
349 | # A list of files that should not be packed into the epub file.
350 | epub_exclude_files = ['search.html']
351 |
352 | # The depth of the table of contents in toc.ncx.
353 | # epub_tocdepth = 3
354 |
355 | # Allow duplicate toc entries.
356 | # epub_tocdup = True
357 |
358 | # Choose between 'default' and 'includehidden'.
359 | # epub_tocscope = 'default'
360 |
361 | # Fix unsupported image types using the Pillow.
362 | # epub_fix_images = False
363 |
364 | # Scale large images.
365 | # epub_max_image_width = 0
366 |
367 | # How to display URL addresses: 'footnote', 'no', or 'inline'.
368 | # epub_show_urls = 'inline'
369 |
370 | # If false, no index is generated.
371 | # epub_use_index = True
372 |
--------------------------------------------------------------------------------
/sphinx/source/index.rst:
--------------------------------------------------------------------------------
1 | .. multimodal_keras_wrapper documentation master file, created by
2 | sphinx-quickstart on Tue Apr 26 10:43:19 2016.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to Multimodal Keras Wrapper's documentation!
7 | ======================================================
8 |
9 | Contents:
10 |
11 | .. toctree::
12 | :maxdepth: 2
13 |
14 | intro.md
15 | tutorial.md
16 | modules.rst
17 |
18 |
19 | Indices and tables
20 | ====================
21 |
22 | * :ref:`genindex`
23 | * :ref:`modindex`
24 | * :ref:`search`
25 |
26 |
--------------------------------------------------------------------------------
/sphinx/source/intro.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 |
3 | ## Multimodal Keras Wrapper
4 | Wrapper for Keras with support for easy loading and handling of multimodal data and models.
5 |
6 | You can download and contribute to the code by cloning [this repository](https://github.com/MarcBS/multimodal_keras_wrapper).
7 |
8 |
9 | ## Documentation
10 |
11 | You can access the library documentation page at [marcbs.github.io/multimodal_keras_wrapper/](http://marcbs.github.io/multimodal_keras_wrapper/)
12 |
13 | Some code examples are available in demo.ipynb and test.py. Additionally, in the section Projects you can see some practical examples of projects using this library.
14 |
15 |
16 | ## Dependencies
17 |
18 | The following dependencies are required for using this library:
19 |
20 | - [Anaconda](https://www.continuum.io/downloads)
21 | - Keras - [custom fork](https://github.com/MarcBS/keras) or [original version](https://github.com/fchollet/keras)
22 | - [cloud](https://pypi.python.org/pypi/cloud/2.8.5) >= 2.8.5
23 | - [scipy](https://pypi.python.org/pypi/scipy/0.7.0)
24 |
25 | Only when using NMS for certain localization utilities:
26 | - [cython](https://pypi.python.org/pypi/Cython/0.25.2) >= 0.23.4
27 |
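If you plan to use those localization utilities, the extra dependency can also be installed on its own; a minimal example using the version pin listed above:

```
pip install "Cython>=0.23.4"
```
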
28 | ## Installation
29 |
30 | In order to install the library you just have to follow these steps:
31 |
32 | 1) Clone this repository:
33 | ```
34 | git clone https://github.com/MarcBS/multimodal_keras_wrapper.git
35 | ```
36 |
37 | 2) Include the repository path into your PYTHONPATH:
38 | ```
39 | export PYTHONPATH=$PYTHONPATH:/path/to/multimodal_keras_wrapper
40 | ```
41 |
42 | 3) If you wish to install the dependencies (it will install our [custom Keras fork](https://github.com/MarcBS/keras)):
43 | ```
44 | pip install -r requirements.txt
45 | ```
46 |
47 | ## Projects
48 |
49 | You can see more practical examples in projects which use this library:
50 |
51 | [VIBIKNet for Visual Question Answering](https://github.com/MarcBS/VIBIKNet)
52 |
53 | [ABiViRNet for Video Description](https://github.com/lvapeab/ABiViRNet)
54 |
55 | [Sentence-SelectioNN for Domain Adaptation in SMT](https://github.com/lvapeab/sentence-selectioNN)
56 |
57 |
58 | ## Keras
59 |
60 | For additional information on the Deep Learning library, visit the official web page www.keras.io or the GitHub repository https://github.com/fchollet/keras.
61 |
62 | You can also use our [custom Keras version](https://github.com/MarcBS/keras), which provides several additional layers for Multimodal Learning.
63 |
--------------------------------------------------------------------------------
/sphinx/source/modules.rst:
--------------------------------------------------------------------------------
1 | Available Modules
2 | **************************
3 |
4 | List of all files, classes and methods available in the library.
5 |
6 |
7 | dataset.py
8 | =============================
9 |
10 | .. automodule:: keras_wrapper.dataset
11 | :members:
12 |
13 |
14 | cnn_model.py
15 | =============================
16 |
17 | .. automodule:: keras_wrapper.cnn_model
18 | :members:
19 |
20 |
21 | callbacks_keras_wrapper.py
22 | =============================
23 |
24 | .. automodule:: keras_wrapper.callbacks_keras_wrapper
25 | :members:
26 |
27 |
28 | beam_search_ensemble.py
29 | =============================
30 |
31 | .. automodule:: keras_wrapper.beam_search_ensemble
32 | :members:
33 |
34 |
35 | utils.py
36 | =============================
37 |
38 | .. automodule:: keras_wrapper.utils
39 | :members:
40 |
--------------------------------------------------------------------------------
/sphinx/source/tutorial.md:
--------------------------------------------------------------------------------
1 | # Tutorial
2 |
3 | ## Basic components
4 |
5 | There are two basic components that have to be built in order to use the Multimodal Keras Wrapper,
6 | which are a **[Dataset](https://github.com/MarcBS/multimodal_keras_wrapper/blob/6d0b11248fd353cc189f674dc30beaf9689da182/keras_wrapper/dataset.py#L331)** and a **[Model_Wrapper](https://github.com/MarcBS/multimodal_keras_wrapper/blob/6d0b11248fd353cc189f674dc30beaf9689da182/keras_wrapper/cnn_model.py#L154)**.
7 |
8 | The class **Dataset** is in charge of:
9 | - Storing, preprocessing and loading any kind of data for training a model (inputs).
10 | - Storing, preprocessing and loading the ground truth associated to our data (outputs).
11 | - Loading the data in batches for training or prediction (see the sketch below).
12 |
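Regarding batch loading, a populated Dataset (like the `ds` object built later in this tutorial) can serve batches directly. The following is only a quick sketch, assuming the `getXY(set_name, batch_size)` method keeps the signature commonly seen in code using this library:

```
# Retrieve one batch of preprocessed inputs (X) and outputs (Y) from the 'train' split.
# 'ds' is assumed to be an already populated Dataset instance; 10 is the batch size.
X_batch, Y_batch = ds.getXY('train', 10)
```
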
13 | The Dataset class can manage different [types of input/output data](https://github.com/MarcBS/multimodal_keras_wrapper/blob/6d0b11248fd353cc189f674dc30beaf9689da182/keras_wrapper/dataset.py#L389-L390), which can be summarized as:
14 | - input types: 'raw-image', 'video', 'image-features', 'video-features', 'text'
15 | - output types: 'categorical', 'binary', 'real', 'text', '3DLabel'
16 |
17 | Currently, the class Dataset can be used for multiple kinds of multimodal problems,
18 | e.g. image/video classification, detection, multilabel prediction, regression, image/video captioning,
19 | visual question answering, multimodal translation, neural machine translation, etc.
20 |
21 | The class **Model_Wrapper** is in charge of:
22 | - Storing an instance of a Keras model.
23 | - Receiving the inputs/outputs of the class Dataset and using the model for training or prediction.
24 | - Providing two different methods for prediction: either [predictNet()](http://marcbs.github.io/multimodal_keras_wrapper/modules.html#keras_wrapper.cnn_model.Model_Wrapper.predictNet), which uses a conventional Keras model for prediction, or [predictBeamSearchNet()](http://marcbs.github.io/multimodal_keras_wrapper/modules.html#keras_wrapper.cnn_model.Model_Wrapper.predictBeamSearchNet), which applies beam search for sequence-generative models and additionally allows creating separate models **model_init** and **model_next** for an optimized prediction (see [this](https://github.com/MarcBS/multimodal_keras_wrapper/blob/b348ce9d52404434b1e98316c7f09b5d5fd3df00/keras_wrapper/cnn_model.py#L1319-L1328) and [this](https://github.com/MarcBS/multimodal_keras_wrapper/blob/f269207a65bfc77d5c2c89ea708bad8bff7f72ab/keras_wrapper/cnn_model.py#L1057) for further information).
25 |
26 | In this tutorial we will learn how to create each of the two basic components and how to use a
27 | model for training and prediction.
28 |
29 |
30 | ## Creating a Dataset
31 |
32 | First, let's create a simple Dataset object with some sample data.
33 | The data used for this example can be obtained by executing `/repository_root/data/get_data.sh`.
34 | This will download the data used for this example into `/repository_root/data/sample_data`.
35 |
36 |
37 | Dataset parameters definition.
38 |
39 | ```
40 | from keras_wrapper.dataset import Dataset
41 |
42 | dataset_name = 'test_dataset'
43 | image_id = 'input_image'
44 | label_id = 'output_label'
45 | images_size = [256, 256, 3]
46 | images_crop_size = [224, 224, 3]
47 | train_mean = [103.939, 116.779, 123.68]
48 | base_path = '/data/sample_data'
49 | ```
50 |
51 | Empty dataset instance creation
52 |
53 | ```
54 | ds = Dataset(dataset_name, base_path+'/images')
55 | ```
56 |
57 |
58 | Insert dataset/model inputs
59 |
60 | ```
61 | # train split
62 | ds.setInput(base_path + '/train.txt', 'train',
63 | type='raw-image', id=image_id,
64 | img_size=images_size, img_size_crop=images_crop_size)
65 | # val split
66 | ds.setInput(base_path + '/val.txt', 'val',
67 | type='raw-image', id=image_id,
68 | img_size=images_size, img_size_crop=images_crop_size)
69 | # test split
70 | ds.setInput(base_path + '/test.txt', 'test',
71 | type='raw-image', id=image_id,
72 | img_size=images_size, img_size_crop=images_crop_size)
73 | ```
74 |
75 | Insert the pre-calculated image mean of the training set
76 |
77 | ```
78 | ds.setTrainMean(train_mean, image_id)
79 | ```
80 |
81 | Insert dataset/model outputs
82 |
83 | ```
84 | # train split
85 | ds.setOutput(base_path+'/train_labels.txt', 'train',
86 | type='categorical', id=label_id)
87 | # val split
88 | ds.setOutput(base_path+'/val_labels.txt', 'val',
89 | type='categorical', id=label_id)
90 | # test split
91 | ds.setOutput(base_path+'/test_labels.txt', 'test',
92 | type='categorical', id=label_id)
93 | ```
94 |
95 | ## Saving or loading a Dataset
96 |
97 | ```
98 | from keras_wrapper.dataset import saveDataset, loadDataset
99 |
100 | save_path = '/Datasets'
101 |
102 | # Save dataset
103 | saveDataset(ds, save_path)
104 |
105 | # Load dataset
106 | ds = loadDataset(save_path+'/Dataset_'+dataset_name+'.pkl')
107 | ```
108 |
109 | In addition, we can print some basic information about the data stored in the dataset:
110 |
111 | ```
112 | print(ds)
113 | ```
114 |
115 | ## Creating a Model_Wrapper
116 |
117 | Model_Wrapper parameters definition.
118 |
119 | ```
120 | from keras_wrapper.cnn_model import Model_Wrapper
121 |
122 | model_name = 'our_model'
123 | type = 'VGG_19_ImageNet'
124 | save_path = '/Models/'
125 | ```
126 |
127 | Create a basic CNN model
128 |
129 | ```
130 | net = Model_Wrapper(nOutput=2, type=type, model_name=model_name, input_shape=images_crop_size)
131 | net.setOptimizer(lr=0.001, metrics=['accuracy']) # compile it
132 | ```
133 |
134 | By default, the model type built is the one defined in [Model_Wrapper.basic_model()](https://github.com/MarcBS/multimodal_keras_wrapper/blob/6d0b11248fd353cc189f674dc30beaf9689da182/keras_wrapper/cnn_model.py#L2003).
135 | However, any kind of custom model can be defined just by:
136 | - Defining a new method for the class Model_Wrapper which builds the model and stores it in self.model.
137 | - Referencing it with type='method_name' when creating a new Model_Wrapper instance (see the sketch below).
138 |
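For example, here is a minimal sketch of this mechanism (reusing `images_crop_size` from the snippets above). The `type='method_name'` lookup is the mechanism described in the two points above, but the builder signature `(self, nOutput, input_shape)` is an assumption based on `basic_model()`: check that method for the exact signature expected by your version of the library. The layers and names are purely illustrative.

```
from keras.models import Sequential
from keras.layers import Flatten, Dense
from keras_wrapper.cnn_model import Model_Wrapper

def my_custom_model(self, nOutput, input_shape):
    # Build any Keras model and store it in self.model, as the wrapper expects.
    model = Sequential()
    model.add(Flatten(input_shape=tuple(input_shape)))
    model.add(Dense(128, activation='relu'))
    model.add(Dense(nOutput, activation='softmax'))
    self.model = model

# Attach the builder to the class so that it can be referenced by name.
setattr(Model_Wrapper, 'my_custom_model', my_custom_model)

net = Model_Wrapper(nOutput=2, type='my_custom_model', model_name='our_custom_model',
                    input_shape=images_crop_size)
net.setOptimizer(lr=0.001, metrics=['accuracy'])  # compile it
```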
139 |
140 | ## Saving or loading a Model_Wrapper
141 |
142 | ```
143 | from keras_wrapper.cnn_model import saveModel, loadModel
144 |
145 | save_epoch = 0
146 |
147 | # Save model
148 | saveModel(net, save_epoch)
149 |
150 | # Load model
151 | net = loadModel(save_path+'/'+model_name, save_epoch)
152 | ```
153 |
154 |
155 | ## Connecting a Dataset to a Model_Wrapper
156 |
157 | In order to establish a correct communication between the Dataset and the Model_Wrapper objects, we have to link the positions of the Dataset ids to their corresponding input/output layer identifiers in the Keras model, providing these links as a dictionary.
158 |
159 | In this case we only have one input and one output, so both ids are mapped to position 0 of our Dataset.
160 |
161 | ```
162 | net.setInputsMapping({net.ids_inputs[0]: 0})
163 | net.setOutputsMapping({net.ids_outputs[0]: 0})
164 | ```
165 |
166 |
167 | ## Training
168 |
169 | We can specify several options for training our model, which are [summarized here](http://marcbs.github.io/multimodal_keras_wrapper/modules.html#keras_wrapper.cnn_model.Model_Wrapper.trainNet). Any option that is not overridden will keep its [default value](https://github.com/MarcBS/multimodal_keras_wrapper/blob/011393580b2253a01c168d638b8c0bd06fe6d522/keras_wrapper/cnn_model.py#L454-L458).
170 |
171 | ```
172 | train_overridden_parameters = {'n_epochs': 2, 'batch_size': 10}
173 | 
174 | net.trainNet(ds, train_overridden_parameters)
175 | ```
176 |
177 | ## Prediction
178 |
179 | The same applies to the prediction method. We can find the [available parameters here](http://marcbs.github.io/multimodal_keras_wrapper/modules.html#keras_wrapper.cnn_model.Model_Wrapper.predictNet) and the [default values here](https://github.com/MarcBS/multimodal_keras_wrapper/blob/011393580b2253a01c168d638b8c0bd06fe6d522/keras_wrapper/cnn_model.py#L1468-L1470).
180 |
181 | ```
182 | predict_overridden_parameters = {'batch_size': 10, 'predict_on_sets': ['test']}
183 | 
184 | net.predictNet(ds, predict_overridden_parameters)
185 | ```
--------------------------------------------------------------------------------
/tests/data/test_data.txt:
--------------------------------------------------------------------------------
1 | This is a text file. Containing characters of different encodings.
2 | ẁñ á é í ó ú à è ì ò ù ä ë ï ö ü ^
3 | 首先 ,
4 |
--------------------------------------------------------------------------------
/tests/extra/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcBS/multimodal_keras_wrapper/1349edaaa0e13092a72280bb24316b460ed841de/tests/extra/__init__.py
--------------------------------------------------------------------------------
/tests/extra/test_wrapper_callbacks.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from six import iteritems
3 |
4 | # TODO
5 |
--------------------------------------------------------------------------------
/tests/extra/test_wrapper_evaluation.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 | from keras_wrapper.extra.evaluation import get_sacrebleu_score, get_coco_score, multilabel_metrics, get_perplexity
4 |
5 |
6 | def test_get_sacrebleu_score():
7 | pred_list = ['Prediction 1 X W Z', 'Prediction 2 X W Z', 'Prediction 3 X W Z']
8 |
9 | for tokenize_hypothesis in {True, False}:
10 | for tokenize_references in {True, False}:
11 | for apply_detokenization in {True, False}:
12 | extra_vars = {'val': {'references': {0: ['Prediction 1 X W Z', 'Prediction 5'],
13 | 1: ['Prediction 2 X W Z', 'X Y Z'],
14 | 2: ['Prediction 3 X W Z', 'Prediction 5']},
15 | },
16 |
17 | 'test': {'references': {0: ['Prediction 2 X W Z'],
18 | 1: ['Prediction 3 X W Z'],
19 | 2: ['Prediction 1 X W Z']}
20 | },
21 | 'tokenize_hypothesis': tokenize_hypothesis,
22 | 'tokenize_references': tokenize_references,
23 |                                       'apply_detokenization': apply_detokenization,
24 | 'tokenize_f': lambda x: x,
25 | 'detokenize_f': lambda x: x,
26 | }
27 | val_scores = get_sacrebleu_score(pred_list, 0, extra_vars, 'val')
28 | assert np.allclose(val_scores['Bleu_4'], 100.0, atol=1e6)
29 |
30 | test_scores = get_sacrebleu_score(pred_list, 0, extra_vars, 'test')
31 | assert np.allclose(test_scores['Bleu_4'], 0., atol=1e6)
32 |
33 |
34 | def test_get_coco_score():
35 | pred_list = ['Prediction 1', 'Prediction 2', 'Prediction 3']
36 | extra_vars = {'val': {'references': {0: ['Prediction 1'], 1: ['Prediction 2'],
37 | 2: ['Prediction 3', 'Prediction 5']}},
38 | 'test': {'references': {0: ['Prediction 2'], 1: ['Prediction 3'],
39 | 2: ['Prediction 1']}}
40 | }
41 | val_scores = get_coco_score(pred_list, 0, extra_vars, 'val')
42 | assert np.allclose(val_scores['Bleu_1'], 1.0, atol=1e6)
43 | assert np.allclose(val_scores['Bleu_2'], 1.0, atol=1e6)
44 | assert np.allclose(val_scores['Bleu_3'], 1.0, atol=1e6)
45 | assert np.allclose(val_scores['Bleu_4'], 1.0, atol=1e6)
46 | assert np.allclose(val_scores['ROUGE_L'], 1.0, atol=1e6)
47 | assert np.allclose(val_scores['CIDEr'], 5.0, atol=1e6)
48 | assert np.allclose(val_scores['TER'], 0., atol=1e6)
49 | assert np.allclose(val_scores['METEOR'], 1.0, atol=1e6)
50 | test_scores = get_coco_score(pred_list, 0, extra_vars, 'test')
51 |
52 | assert np.allclose(test_scores['Bleu_1'], 0.5, atol=1e6)
53 | assert np.allclose(test_scores['Bleu_2'], 0., atol=1e6)
54 | assert np.allclose(test_scores['Bleu_3'], 0., atol=1e6)
55 | assert np.allclose(test_scores['Bleu_4'], 0., atol=1e6)
56 | assert np.allclose(test_scores['ROUGE_L'], 0.5, atol=1e6)
57 | assert np.allclose(test_scores['CIDEr'], 0., atol=1e6)
58 | assert np.allclose(test_scores['TER'], 0.5, atol=1e6)
59 | assert np.allclose(test_scores['METEOR'], 0.2, atol=1e6)
60 |
61 |
62 | def test_multilabel_metrics():
63 | pred_list = [['w1'], ['w2'], ['w3']]
64 | extra_vars = {
65 | 'val': {'references': [[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0]],
66 | 'word2idx': {'w1': 0, 'w2': 1, 'w3': 3, 'w4': 3, 'w5': 4}
67 | },
68 | 'test': {'references': [[0, 0, 0, 0, 1], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]],
69 | 'word2idx': {'w1': 0, 'w2': 1, 'w3': 2, 'w4': 3, 'w5': 4}
70 | }
71 | }
72 | val_scores = multilabel_metrics(pred_list, 0, extra_vars, 'val')
73 |
74 | assert np.allclose(val_scores['f1'], 0.66, atol=1e6)
75 | assert np.allclose(val_scores['recall'], 0.66, atol=1e6)
76 | assert np.allclose(val_scores['precision'], 0.66, atol=1e6)
77 | assert np.allclose(val_scores['ranking_loss'], 0.33, atol=1e6)
78 | assert np.allclose(val_scores['coverage_error'], 2.33, atol=1e6)
79 | assert np.allclose(val_scores['average_precision'], 0.73, atol=1e6)
80 |
81 | test_scores = multilabel_metrics(pred_list, 0, extra_vars, 'test')
82 | assert np.allclose(test_scores['f1'], 0.33, atol=1e6)
83 | assert np.allclose(test_scores['recall'], 0.33, atol=1e6)
84 | assert np.allclose(test_scores['precision'], 0.22, atol=1e6)
85 | assert np.allclose(test_scores['ranking_loss'], 0.66, atol=1e6)
86 | assert np.allclose(test_scores['coverage_error'], 3.66, atol=1e6)
87 | assert np.allclose(test_scores['average_precision'], 0.466, atol=1e6)
88 |
89 |
90 | def test_multiclass_metrics():
91 | # TODO
92 | pass
93 |
94 |
95 | def test_compute_perplexity():
96 | costs = [1., 1., 1.]
97 | ppl = get_perplexity(costs=costs)
98 | assert np.allclose(ppl['Perplexity'], np.e, atol=1e6)
99 |
100 | costs = [0., 0., 0.]
101 | ppl = get_perplexity(costs=costs)
102 | assert np.allclose(ppl['Perplexity'], 0., atol=1e6)
103 |
104 |
105 | def test_semantic_segmentation_accuracy():
106 | # TODO
107 | pass
108 |
109 |
110 | def test_semantic_segmentation_meaniou():
111 | # TODO
112 | pass
113 |
114 |
115 | def test_averagePrecision():
116 | # TODO
117 | pass
118 |
119 |
120 | if __name__ == '__main__':
121 | pytest.main([__file__])
122 |
--------------------------------------------------------------------------------
/tests/extra/test_wrapper_localization_utilities.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | # TODO
4 |
--------------------------------------------------------------------------------
/tests/extra/test_wrapper_read_write.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import pytest
3 | import sys
4 | import os
5 | import numpy as np
6 | from six import iteritems
7 | from keras_wrapper.extra.read_write import *
8 | from keras_wrapper.utils import flatten_list_of_lists
9 |
10 |
11 | def test_dirac():
12 | assert dirac(1, 1) == 1
13 | assert dirac(2, 1) == 0
14 |
15 |
16 | def test_create_dir_if_not_exists():
17 | create_dir_if_not_exists('test_directory')
18 | assert os.path.isdir('test_directory')
19 |
20 |
21 | def test_clean_dir():
22 | clean_dir('test_directory')
23 | assert os.path.isdir('test_directory')
24 |
25 |
26 | def test_file2list():
27 | reference_text = 'ẁñ á é í ó ú à è ì ò ù ä ë ï ö ü ^'.decode('utf-8') if sys.version_info.major == 2 else 'ẁñ á é í ó ú à è ì ò ù ä ë ï ö ü ^'
28 | stripped_list = file2list('tests/data/test_data.txt', stripfile=True)
29 | assert len(stripped_list) == 3
30 | assert stripped_list[1] == reference_text
31 |
32 |
33 | def test_numpy2hdf5():
34 | filepath = 'test_file'
35 | data_name = 'test_data'
36 | my_np = np.random.rand(10, 10).astype('float32')
37 | numpy2hdf5(filepath, my_np, data_name=data_name)
38 | assert os.path.isfile(filepath)
39 | my_np_loaded = np.asarray(load_hdf5_simple(filepath, dataset_name=data_name)).astype('float32')
40 | assert np.all(my_np == my_np_loaded)
41 |
42 |
43 | def test_numpy2file():
44 | filepath = 'test_file'
45 | my_np = np.random.rand(10, 10).astype('float32')
46 | numpy2file(filepath, my_np)
47 | assert os.path.isfile(filepath)
48 | my_np_loaded = np.asarray(np.load(filepath)).astype('float32')
49 | assert np.all(my_np == my_np_loaded)
50 |
51 |
52 | def test_listoflists2file():
53 | mylist = [['This is a text file. Containing characters of different encodings.'],
54 | ['ẁñ á é í ó ú à è ì ò ù ä ë ï ö ü ^'],
55 | ['首先 ,']
56 | ]
57 | filepath = 'saved_list'
58 | listoflists2file(filepath, mylist)
59 | loaded_list = file2list('saved_list')
60 | flatten_list = [encode_list(sublist) for sublist in mylist]
61 | flatten_list = flatten_list_of_lists(flatten_list)
62 | assert loaded_list == flatten_list
63 |
64 |
65 | def test_list2file():
66 | mylist = ['This is a text file. Containing characters of different encodings.',
67 | 'ẁñ á é í ó ú à è ì ò ù ä ë ï ö ü ^',
68 | '首先 ,'
69 | ]
70 | filepath = 'saved_list'
71 | list2file(filepath, mylist)
72 | loaded_list = file2list('saved_list')
73 | my_encoded_list = encode_list(mylist)
74 | assert loaded_list == my_encoded_list
75 |
76 |
77 | def test_list2stdout():
78 | mylist = ['This is a text file. Containing characters of different encodings.',
79 | 'ẁñ á é í ó ú à è ì ò ù ä ë ï ö ü ^',
80 | '首先 ,'
81 | ]
82 | list2stdout(mylist)
83 |
84 |
85 | def test_nbest2file():
86 | my_nbest_list = [
87 | [[1, 'This is a text file. Containing characters of different encodings.', 0.1],
88 | [1, 'Other hypothesis. Containing characters of different encodings.', 0.2]
89 | ],
90 | [[2, 'ẁñ á é í ó ú à è ì ò ù ä ë ï ö ü ^', 0.3]],
91 | [[3, '首先 ,', 90.3]]
92 | ]
93 | filepath = 'saved_nbest'
94 | nbest2file(filepath, my_nbest_list)
95 | nbest = file2list(filepath)
96 | assert nbest == encode_list(['1 ||| This is a text file. Containing characters of different encodings. ||| 0.1',
97 | '1 ||| Other hypothesis. Containing characters of different encodings. ||| 0.2',
98 | '2 ||| ẁñ á é í ó ú à è ì ò ù ä ë ï ö ü ^ ||| 0.3',
99 | '3 ||| 首先 , ||| 90.3'])
100 |
101 |
102 | def test_dump_load_hdf5_simple():
103 | filepath = 'test_file'
104 | data_name = 'test_data'
105 | data = np.random.rand(10, 10).astype('float32')
106 | dump_hdf5_simple(filepath, data_name, data)
107 | loaded_data = load_hdf5_simple(filepath, dataset_name=data_name)
108 | assert np.all(loaded_data == data)
109 |
110 |
111 | def test_dict2file():
112 | filepath = 'saved_dict'
113 | mydict = {1: 'ẁñ á é í ó ú à è ì ò ù ä ë ï ö ü ^', '首先': 9}
114 | title = None
115 | dict2file(mydict, filepath, title, permission='w')
116 | loaded_dict = file2list(filepath)
117 | assert loaded_dict == encode_list(['1:ẁñ á é í ó ú à è ì ò ù ä ë ï ö ü ^', '首先:9'])
118 | title = 'Test dict'
119 | dict2file(mydict, filepath, title, permission='w')
120 | loaded_dict = file2list(filepath)
121 | assert loaded_dict == encode_list(['Test dict', '1:ẁñ á é í ó ú à è ì ò ù ä ë ï ö ü ^', '首先:9'])
122 |
123 |
124 | def test_dict2pkl_pkl2dict():
125 | filepath = 'saved_dict'
126 | mydict = {1: 'ẁñ á é í ó ú à è ì ò ù ä ë ï ö ü ^', '首先': 9}
127 | dict2pkl(mydict, filepath)
128 | loaded_dict = pkl2dict(filepath + '.pkl')
129 | assert loaded_dict == mydict
130 |
131 |
132 | if __name__ == '__main__':
133 | pytest.main([__file__])
134 |
--------------------------------------------------------------------------------
/tests/extra/test_wrapper_tokenizers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import pytest
3 | from six import iteritems
4 | from keras_wrapper.extra.tokenizers import *
5 |
6 |
7 | def test_tokenize_basic():
8 | untokenized_string = u'This, ¿is a , .sentence with weird\xbb symbols ù ä ë ï ö ü ^首先 ,!!!\n\n'
9 | expected_string = u'This , ¿ is a , . sentence with weird\xbb symbols ù ä ë ï ö ü ^首先 , ! ! ! '
10 | tokenized_string = tokenize_basic(untokenized_string, lowercase=False)
11 | tokenized_string_lower = tokenize_basic(untokenized_string, lowercase=True)
12 | assert expected_string == tokenized_string
13 | assert expected_string.lower() == tokenized_string_lower
14 |
15 |
16 | def test_tokenize_aggressive():
17 | untokenized_string = u'This, ¿is a , .sentence with weird\xbb symbols ù ä ë ï ö ü ^首先 ,!!!\n\n'
18 | expected_string = u'This is a sentence with weird\xbb symbols ù ä ë ï ö ü ^首先'
19 | tokenized_string = tokenize_aggressive(untokenized_string, lowercase=False)
20 | tokenized_string_lower = tokenize_aggressive(untokenized_string, lowercase=True)
21 | assert expected_string == tokenized_string
22 | assert expected_string.lower() == tokenized_string_lower
23 |
24 |
25 | def test_tokenize_icann():
26 | untokenized_string = u'This, ¿is a , .sentence with weird\xbb symbols ù ä ë ï ö ü ^首先 ,!!!\n\n'
27 | expected_string = u'This , ¿is a , . sentence with weird\xbb symbols ù ä ë ï ö ü ^首先 , ! '
28 | tokenized_string_lower = tokenize_icann(untokenized_string)
29 | assert expected_string.lower() == tokenized_string_lower
30 |
31 |
32 | def test_tokenize_montreal():
33 | untokenized_string = u'This, ¿is a , .sentence with weird\xbb symbols ù ä ë ï ö ü ^首先 ,!!!\n\n'
34 | expected_string = u'This ¿is a sentence with weird\xbb symbols ù ä ë ï ö ü ^首先 !!!'
35 | tokenized_string_lower = tokenize_montreal(untokenized_string)
36 | assert expected_string.lower() == tokenized_string_lower
37 |
38 |
39 | def test_tokenize_soft():
40 | untokenized_string = u'This, ¿is a , .sentence with weird\xbb symbols ù ä ë ï ö ü ^首先 ,!!!\n\n'
41 | expected_string = u'This , ¿is a , . sentence with weird\xbb symbols ù ä ë ï ö ü ^首先 , ! '
42 | tokenized_string = tokenize_soft(untokenized_string, lowercase=False)
43 | tokenized_string_lower = tokenize_soft(untokenized_string, lowercase=True)
44 | assert expected_string == tokenized_string
45 | assert expected_string.lower() == tokenized_string_lower
46 |
47 |
48 | def test_tokenize_none():
49 | untokenized_string = u'This, ¿is a , .sentence with weird\xbb symbols ù ä ë ï ö ü ^首先 ,!!!\n\n'
50 | expected_string = u'This, ¿is a , .sentence with weird\xbb symbols ù ä ë ï ö ü ^首先 ,!!!'
51 | tokenized_string = tokenize_none(untokenized_string)
52 | assert expected_string == tokenized_string
53 |
54 |
55 | def test_tokenize_none_char():
56 | untokenized_string = u'This, ¿is a > < , .sentence with weird\xbb symbols'
57 | expected_string = u'T h i s , ¿ i s a > < , . s e n t e n c e w i t h w e i r d \xbb s y m b o l s'
58 | tokenized_string = tokenize_none_char(untokenized_string)
59 | assert expected_string == tokenized_string
60 |
61 |
62 | def test_tokenize_CNN_sentence():
63 | # TODO
64 | pass
65 |
66 |
67 | def test_tokenize_questions():
68 | # TODO
69 | pass
70 |
71 |
72 | def test_tokenize_bpe():
73 | # TODO
74 | pass
75 |
76 |
77 | def test_detokenize_none():
78 | tokenized_string = u'This, ¿is a , .sentence with weird\xbb symbols ù ä ë ï ö ü ^首先 ,!!!'
79 | expected_string = u'This, ¿is a , .sentence with weird\xbb symbols ù ä ë ï ö ü ^首先 ,!!!'
80 | detokenized_string = detokenize_none(tokenized_string)
81 | assert expected_string == detokenized_string
82 |
83 |
84 | def test_detokenize_none_char():
85 | tokenized_string = u'T h i s , ¿ i s a > < , . s e n t e n c e w i t h w e i r d \xbb s y m b o l s'
86 | expected_string = u'This, ¿is a > < , .sentence with weird\xbb symbols'
87 | detokenized_string = detokenize_none_char(tokenized_string)
88 | assert expected_string == detokenized_string
89 |
90 |
91 | if __name__ == '__main__':
92 | pytest.main([__file__])
93 |
--------------------------------------------------------------------------------
/tests/general/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcBS/multimodal_keras_wrapper/1349edaaa0e13092a72280bb24316b460ed841de/tests/general/__init__.py
--------------------------------------------------------------------------------
/tests/general/test_model_wrapper.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from six import iteritems
3 |
4 |
5 | def test_model_wrapper():
6 | pass
7 |
8 | if __name__ == '__main__':
9 | pytest.main([__file__])
10 |
--------------------------------------------------------------------------------
/tests/general/test_model_wrapper_ensemble.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from six import iteritems
3 |
4 |
5 | def test_model_wrapper_ensemble():
6 | pass
7 |
8 | if __name__ == '__main__':
9 | pytest.main([__file__])
10 |
--------------------------------------------------------------------------------
/tests/general/test_wrapper_dataset.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from six import iteritems
3 |
4 |
5 | def test_dataset():
6 | pass
7 |
8 | if __name__ == '__main__':
9 | pytest.main([__file__])
10 |
--------------------------------------------------------------------------------
/utils/README.md:
--------------------------------------------------------------------------------
1 | # Multimodal Keras Wrapper utils
2 |
3 | In this directory, you'll find some utilities for Models and Datasets from the MKW.
4 | The main scripts are the following:
5 |
6 | * **average_models.py**: Performs model averaging for multiple models.
7 | * **minimize_dataset.py**: Removes the data stored in a dataset instance, keeping the rest of its attributes (types, ids, params, preprocessing...). See the example invocations below.
8 |
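The invocations below are illustrative sketches based on the command-line arguments defined in each script; all paths and weight values are placeholders.

```
# Average two models with equal weights and save the result to ./averaged_model
python utils/average_models.py --models /path/to/model1 /path/to/model2 \
                               --weights 0.5 0.5 --dest ./averaged_model

# Remove the stored data from a dataset instance, keeping its configuration
python utils/minimize_dataset.py --dataset /path/to/Dataset_name.pkl --output /path/to/output_dir
```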
9 |
--------------------------------------------------------------------------------
/utils/average_models.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import sys
4 | import os
5 | sys.path.insert(1, os.path.abspath("."))
6 | sys.path.insert(0, os.path.abspath("../"))
7 | from keras_wrapper.utils import average_models
8 |
9 | logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(message)s', datefmt='%d/%m/%Y %H:%M:%S')
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | def parse_args():
14 | """
15 | Argument parser.
16 | :return:
17 | """
18 | parser = argparse.ArgumentParser("Averages models")
19 |
20 | parser.add_argument("-d", "--dest",
21 | default='./model',
22 | required=False,
23 | help="Path to the averaged model. If not specified, the model is saved in './model'.")
24 | parser.add_argument("-v", "--verbose", required=False, default=0, type=int, help="Verbosity level")
25 |     parser.add_argument("-w", "--weights", nargs="*", help="Weight given to each model in the averaging. You should provide the same number of weights as models. "
26 | "By default, it applies the same weight to each model (1/N).", default=[])
27 | parser.add_argument("-m", "--models", nargs="+", required=True, help="Path to the models")
28 | return parser.parse_args()
29 |
30 |
31 | def weighted_average(args):
32 | """
33 | Apply a weighted average to the models.
34 | :param args: Options for the averaging function:
35 | * models: Path to the models.
36 | * dest: Path to the averaged model. If unspecified, the model is saved in './model'
37 |         * weights: Weight given to each model in the averaging. Should be the same number of weights as models.
38 | If unspecified, it applies the same weight to each model (1/N).
39 | :return:
40 | """
41 | logger.info("Averaging %d models" % len(args.models))
42 | average_models(args.models, args.dest, weights=args.weights)
43 | logger.info('Averaging finished.')
44 |
45 |
46 | if __name__ == "__main__":
47 |
48 | args = parse_args()
49 | weighted_average(args)
50 |
--------------------------------------------------------------------------------
/utils/minimize_dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import argparse
3 | import os
4 | from keras_wrapper.dataset import loadDataset, saveDataset
5 |
6 |
7 | def parse_args():
8 | """
9 | Argument parser
10 | :return:
11 | """
12 |     parser = argparse.ArgumentParser("Minimizes a dataset by removing the data stored in it: training, development and test. "
13 |                                      "The rest of the parameters are kept. "
14 |                                      "Useful for reloading datasets with new data.")
15 | parser.add_argument("-d", "--dataset", required=True, help="Stored instance of the dataset")
16 | parser.add_argument("-o", "--output", help="Output dataset file.",
17 | default="")
18 | return parser.parse_args()
19 |
20 | if __name__ == "__main__":
21 |
22 | args = parse_args()
23 | # Load dataset
24 | ds = loadDataset(args.dataset)
25 | # Reinitialize values to empty
26 | ds.loaded_train = [False, False]
27 | ds.loaded_val = [False, False]
28 | ds.loaded_test = [False, False]
29 |
30 | ds.loaded_raw_train = [False, False]
31 | ds.loaded_raw_val = [False, False]
32 | ds.loaded_raw_test = [False, False]
33 |
34 | ds.len_train = 0
35 | ds.len_val = 0
36 | ds.len_test = 0
37 | # Remove data
38 | for key in ds.X_train.keys():
39 | ds.X_train[key] = None
40 | for key in ds.X_val.keys():
41 | ds.X_val[key] = None
42 | for key in ds.X_test.keys():
43 | ds.X_test[key] = None
44 | for key in ds.X_train.keys():
45 | ds.X_train[key] = None
46 | for key in ds.Y_train.keys():
47 | ds.Y_train[key] = None
48 | for key in ds.Y_val.keys():
49 | ds.Y_val[key] = None
50 | for key in ds.Y_test.keys():
51 | ds.Y_test[key] = None
52 | for key in ds.X_raw_train.keys():
53 | ds.X_raw_train[key] = None
54 | for key in ds.X_raw_val.keys():
55 | ds.X_raw_val[key] = None
56 | for key in ds.X_raw_test.keys():
57 | ds.X_raw_test[key] = None
58 | for key in ds.Y_raw_train.keys():
59 | ds.Y_raw_train[key] = None
60 | for key in ds.Y_raw_val.keys():
61 | ds.Y_raw_val[key] = None
62 | for key in ds.Y_raw_test.keys():
63 | ds.Y_raw_test[key] = None
64 |
65 | # Save dataset
66 | output_path = args.output if args.output else os.path.dirname(args.dataset)
67 | saveDataset(ds, output_path)
68 |
--------------------------------------------------------------------------------