[a-zA-Z0-9_.-]+))\s*", re.X)
196 |
197 | def _parse_see_also(self, content):
198 | """
199 | func_name : Descriptive text
200 | continued text
201 | another_func_name : Descriptive text
202 | func_name1, func_name2, :meth:`func_name`, func_name3
203 |
204 | """
205 | items = []
206 |
207 | def parse_item_name(text):
208 | """Match ':role:`name`' or 'name'"""
209 | m = self._name_rgx.match(text)
210 | if m:
211 | g = m.groups()
212 | if g[1] is None:
213 | return g[3], None
214 | else:
215 | return g[2], g[1]
216 |             raise ValueError("%s is not an item name" % text)
217 |
218 | def push_item(name, rest):
219 | if not name:
220 | return
221 | name, role = parse_item_name(name)
222 | items.append((name, list(rest), role))
223 | del rest[:]
224 |
225 | current_func = None
226 | rest = []
227 |
228 | for line in content:
229 | if not line.strip():
230 | continue
231 |
232 | m = self._name_rgx.match(line)
233 | if m and line[m.end():].strip().startswith(':'):
234 | push_item(current_func, rest)
235 | current_func, line = line[:m.end()], line[m.end():]
236 | rest = [line.split(':', 1)[1].strip()]
237 | if not rest[0]:
238 | rest = []
239 | elif not line.startswith(' '):
240 | push_item(current_func, rest)
241 | current_func = None
242 | if ',' in line:
243 | for func in line.split(','):
244 | push_item(func, [])
245 | elif line.strip():
246 | current_func = line
247 | elif current_func is not None:
248 | rest.append(line.strip())
249 | push_item(current_func, rest)
250 | return items
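  |         # --- editor's illustrative sketch (not part of the original file) ---
  |         # For a block such as
  |         #     numpy.dot : Matrix product.
  |         #     :func:`outer`
  |         # this method returns a list of (name, description_lines, role)
  |         # tuples, roughly:
  |         #     [('numpy.dot', ['Matrix product.'], None),
  |         #      ('outer', [], 'func')]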
251 |
252 | def _parse_index(self, section, content):
253 | """
254 | .. index: default
255 | :refguide: something, else, and more
256 |
257 | """
258 | def strip_each_in(lst):
259 | return [s.strip() for s in lst]
260 |
261 | out = {}
262 | section = section.split('::')
263 | if len(section) > 1:
264 | out['default'] = strip_each_in(section[1].split(','))[0]
265 | for line in content:
266 | line = line.split(':')
267 | if len(line) > 2:
268 | out[line[1]] = strip_each_in(line[2].split(','))
269 | return out
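  |         # --- editor's illustrative sketch (not part of the original file) ---
  |         # For section '.. index:: default' and content
  |         # [':refguide: something, else, and more'] this returns roughly
  |         #     {'default': 'default',
  |         #      'refguide': ['something', 'else', 'and more']}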
270 |
271 | def _parse_summary(self):
272 | """Grab signature (if given) and summary"""
273 | if self._is_at_section():
274 | return
275 |
276 | summary = self._doc.read_to_next_empty_line()
277 | summary_str = " ".join([s.strip() for s in summary]).strip()
278 |         if re.compile(r'^([\w., ]+=)?\s*[\w\.]+\(.*\)$').match(summary_str):
279 | self['Signature'] = summary_str
280 | if not self._is_at_section():
281 | self['Summary'] = self._doc.read_to_next_empty_line()
282 | else:
283 | self['Summary'] = summary
284 |
285 | if not self._is_at_section():
286 | self['Extended Summary'] = self._read_to_next_section()
287 |
288 | def _parse(self):
289 | self._doc.reset()
290 | self._parse_summary()
291 |
292 | for (section, content) in self._read_sections():
293 | if not section.startswith('..'):
294 | section = ' '.join([s.capitalize()
295 | for s in section.split(' ')])
296 | if section in ('Parameters', 'Attributes', 'Methods',
297 | 'Returns', 'Raises', 'Warns'):
298 | self[section] = self._parse_param_list(content)
299 | elif section.startswith('.. index::'):
300 | self['index'] = self._parse_index(section, content)
301 | elif section == 'See Also':
302 | self['See Also'] = self._parse_see_also(content)
303 | else:
304 | self[section] = content
305 |
306 | # string conversion routines
307 |
308 | def _str_header(self, name, symbol='-'):
309 | return [name, len(name) * symbol]
310 |
311 | def _str_indent(self, doc, indent=4):
312 | out = []
313 | for line in doc:
314 | out += [' ' * indent + line]
315 | return out
316 |
317 | def _str_signature(self):
318 | if self['Signature']:
319 |             return [self['Signature'].replace('*', r'\*')] + ['']
320 | else:
321 | return ['']
322 |
323 | def _str_summary(self):
324 | if self['Summary']:
325 | return self['Summary'] + ['']
326 | else:
327 | return []
328 |
329 | def _str_extended_summary(self):
330 | if self['Extended Summary']:
331 | return self['Extended Summary'] + ['']
332 | else:
333 | return []
334 |
335 | def _str_param_list(self, name):
336 | out = []
337 | if self[name]:
338 | out += self._str_header(name)
339 | for param, param_type, desc in self[name]:
340 | out += ['%s : %s' % (param, param_type)]
341 | out += self._str_indent(desc)
342 | out += ['']
343 | return out
344 |
345 | def _str_section(self, name):
346 | out = []
347 | if self[name]:
348 | out += self._str_header(name)
349 | out += self[name]
350 | out += ['']
351 | return out
352 |
353 | def _str_see_also(self, func_role):
354 | if not self['See Also']:
355 | return []
356 | out = []
357 | out += self._str_header("See Also")
358 | last_had_desc = True
359 | for func, desc, role in self['See Also']:
360 | if role:
361 | link = ':%s:`%s`' % (role, func)
362 | elif func_role:
363 | link = ':%s:`%s`' % (func_role, func)
364 | else:
365 | link = "`%s`_" % func
366 | if desc or last_had_desc:
367 | out += ['']
368 | out += [link]
369 | else:
370 | out[-1] += ", %s" % link
371 | if desc:
372 | out += self._str_indent([' '.join(desc)])
373 | last_had_desc = True
374 | else:
375 | last_had_desc = False
376 | out += ['']
377 | return out
378 |
379 | def _str_index(self):
380 | idx = self['index']
381 | out = []
382 | out += ['.. index:: %s' % idx.get('default', '')]
383 |         for section, references in idx.items():
384 | if section == 'default':
385 | continue
386 | out += [' :%s: %s' % (section, ', '.join(references))]
387 | return out
388 |
389 | def __str__(self, func_role=''):
390 | out = []
391 | out += self._str_signature()
392 | out += self._str_summary()
393 | out += self._str_extended_summary()
394 | for param_list in ('Parameters', 'Returns', 'Raises'):
395 | out += self._str_param_list(param_list)
396 | out += self._str_section('Warnings')
397 | out += self._str_see_also(func_role)
398 | for s in ('Notes', 'References', 'Examples'):
399 | out += self._str_section(s)
400 | for param_list in ('Attributes', 'Methods'):
401 | out += self._str_param_list(param_list)
402 | out += self._str_index()
403 | return '\n'.join(out)
404 |
405 |
406 | def indent(str, indent=4):
407 | indent_str = ' ' * indent
408 | if str is None:
409 | return indent_str
410 | lines = str.split('\n')
411 | return '\n'.join(indent_str + l for l in lines)
412 |
413 |
414 | def dedent_lines(lines):
415 | """Deindent a list of lines maximally"""
416 | return textwrap.dedent("\n".join(lines)).split("\n")
417 |
418 |
419 | def header(text, style='-'):
420 | return text + '\n' + style * len(text) + '\n'
421 |
422 |
423 | class FunctionDoc(NumpyDocString):
424 | def __init__(self, func, role='func', doc=None, config={}):
425 | self._f = func
426 | self._role = role # e.g. "func" or "meth"
427 |
428 | if doc is None:
429 | if func is None:
430 | raise ValueError("No function or docstring given")
431 | doc = inspect.getdoc(func) or ''
432 | NumpyDocString.__init__(self, doc)
433 |
434 | if not self['Signature'] and func is not None:
435 | func, func_name = self.get_func()
436 | try:
437 | # try to read signature
438 | argspec = inspect.getargspec(func)
439 | argspec = inspect.formatargspec(*argspec)
440 |                 argspec = argspec.replace('*', r'\*')
441 | signature = '%s%s' % (func_name, argspec)
442 | except TypeError as e:
443 | signature = '%s()' % func_name
444 | self['Signature'] = signature
445 |
446 | def get_func(self):
447 | func_name = getattr(self._f, '__name__', self.__class__.__name__)
448 | if inspect.isclass(self._f):
449 | func = getattr(self._f, '__call__', self._f.__init__)
450 | else:
451 | func = self._f
452 | return func, func_name
453 |
454 | def __str__(self):
455 | out = ''
456 |
457 | func, func_name = self.get_func()
458 |         signature = self['Signature'].replace('*', r'\*')
459 |
460 | roles = {'func': 'function',
461 | 'meth': 'method'}
462 |
463 | if self._role:
464 | if self._role not in roles:
465 | print("Warning: invalid role %s" % self._role)
466 | out += '.. %s:: %s\n \n\n' % (roles.get(self._role, ''),
467 | func_name)
468 |
469 | out += super(FunctionDoc, self).__str__(func_role=self._role)
470 | return out
471 |
472 |
473 | class ClassDoc(NumpyDocString):
474 | def __init__(self, cls, doc=None, modulename='', func_doc=FunctionDoc,
475 | config=None):
476 | if not inspect.isclass(cls) and cls is not None:
477 | raise ValueError("Expected a class or None, but got %r" % cls)
478 | self._cls = cls
479 |
480 | if modulename and not modulename.endswith('.'):
481 | modulename += '.'
482 | self._mod = modulename
483 |
484 | if doc is None:
485 | if cls is None:
486 | raise ValueError("No class or documentation string given")
487 | doc = pydoc.getdoc(cls)
488 |
489 | NumpyDocString.__init__(self, doc)
490 |
491 | if config is not None and config.get('show_class_members', True):
492 | if not self['Methods']:
493 | self['Methods'] = [(name, '', '')
494 | for name in sorted(self.methods)]
495 | if not self['Attributes']:
496 | self['Attributes'] = [(name, '', '')
497 | for name in sorted(self.properties)]
498 |
499 | @property
500 | def methods(self):
501 | if self._cls is None:
502 | return []
503 | return [name for name, func in inspect.getmembers(self._cls)
504 | if not name.startswith('_') and callable(func)]
505 |
506 | @property
507 | def properties(self):
508 | if self._cls is None:
509 | return []
510 | return [name for name, func in inspect.getmembers(self._cls)
511 | if not name.startswith('_') and func is None]
512 |
--------------------------------------------------------------------------------
/doc/sphinxext/numpy_ext/docscrape_sphinx.py:
--------------------------------------------------------------------------------
1 | import re
2 | import inspect
3 | import textwrap
4 | import pydoc
5 | from .docscrape import NumpyDocString
6 | from .docscrape import FunctionDoc
7 | from .docscrape import ClassDoc
8 |
9 |
10 | class SphinxDocString(NumpyDocString):
11 | def __init__(self, docstring, config=None):
12 | config = {} if config is None else config
13 | self.use_plots = config.get('use_plots', False)
14 | NumpyDocString.__init__(self, docstring, config=config)
15 |
16 | # string conversion routines
17 | def _str_header(self, name, symbol='`'):
18 | return ['.. rubric:: ' + name, '']
19 |
20 | def _str_field_list(self, name):
21 | return [':' + name + ':']
22 |
23 | def _str_indent(self, doc, indent=4):
24 | out = []
25 | for line in doc:
26 | out += [' ' * indent + line]
27 | return out
28 |
29 | def _str_signature(self):
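  |         # note (editor): the unconditional return below disables signature
  |         # output in the Sphinx-rendered docstring; the code after it is
  |         # left unreachable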
30 | return ['']
31 | if self['Signature']:
32 | return ['``%s``' % self['Signature']] + ['']
33 | else:
34 | return ['']
35 |
36 | def _str_summary(self):
37 | return self['Summary'] + ['']
38 |
39 | def _str_extended_summary(self):
40 | return self['Extended Summary'] + ['']
41 |
42 | def _str_param_list(self, name):
43 | out = []
44 | if self[name]:
45 | out += self._str_field_list(name)
46 | out += ['']
47 | for param, param_type, desc in self[name]:
48 | out += self._str_indent(['**%s** : %s' % (param.strip(),
49 | param_type)])
50 | out += ['']
51 | out += self._str_indent(desc, 8)
52 | out += ['']
53 | return out
54 |
55 | @property
56 | def _obj(self):
57 | if hasattr(self, '_cls'):
58 | return self._cls
59 | elif hasattr(self, '_f'):
60 | return self._f
61 | return None
62 |
63 | def _str_member_list(self, name):
64 | """
65 | Generate a member listing, autosummary:: table where possible,
66 | and a table where not.
67 |
68 | """
69 | out = []
70 | if self[name]:
71 | out += ['.. rubric:: %s' % name, '']
72 | prefix = getattr(self, '_name', '')
73 |
74 | if prefix:
75 | prefix = '~%s.' % prefix
76 |
77 | autosum = []
78 | others = []
79 | for param, param_type, desc in self[name]:
80 | param = param.strip()
81 | if not self._obj or hasattr(self._obj, param):
82 | autosum += [" %s%s" % (prefix, param)]
83 | else:
84 | others.append((param, param_type, desc))
85 |
86 | if autosum:
87 | # GAEL: Toctree commented out below because it creates
88 | # hundreds of sphinx warnings
89 | # out += ['.. autosummary::', ' :toctree:', '']
90 | out += ['.. autosummary::', '']
91 | out += autosum
92 |
93 | if others:
94 | maxlen_0 = max([len(x[0]) for x in others])
95 | maxlen_1 = max([len(x[1]) for x in others])
96 | hdr = "=" * maxlen_0 + " " + "=" * maxlen_1 + " " + "=" * 10
97 | fmt = '%%%ds %%%ds ' % (maxlen_0, maxlen_1)
98 | n_indent = maxlen_0 + maxlen_1 + 4
99 | out += [hdr]
100 | for param, param_type, desc in others:
101 | out += [fmt % (param.strip(), param_type)]
102 | out += self._str_indent(desc, n_indent)
103 | out += [hdr]
104 | out += ['']
105 | return out
106 |
107 | def _str_section(self, name):
108 | out = []
109 | if self[name]:
110 | out += self._str_header(name)
111 | out += ['']
112 | content = textwrap.dedent("\n".join(self[name])).split("\n")
113 | out += content
114 | out += ['']
115 | return out
116 |
117 | def _str_see_also(self, func_role):
118 | out = []
119 | if self['See Also']:
120 | see_also = super(SphinxDocString, self)._str_see_also(func_role)
121 | out = ['.. seealso::', '']
122 | out += self._str_indent(see_also[2:])
123 | return out
124 |
125 | def _str_warnings(self):
126 | out = []
127 | if self['Warnings']:
128 | out = ['.. warning::', '']
129 | out += self._str_indent(self['Warnings'])
130 | return out
131 |
132 | def _str_index(self):
133 | idx = self['index']
134 | out = []
135 | if len(idx) == 0:
136 | return out
137 |
138 | out += ['.. index:: %s' % idx.get('default', '')]
139 |         for section, references in idx.items():
140 | if section == 'default':
141 | continue
142 | elif section == 'refguide':
143 | out += [' single: %s' % (', '.join(references))]
144 | else:
145 | out += [' %s: %s' % (section, ','.join(references))]
146 | return out
147 |
148 | def _str_references(self):
149 | out = []
150 | if self['References']:
151 | out += self._str_header('References')
152 | if isinstance(self['References'], str):
153 | self['References'] = [self['References']]
154 | out.extend(self['References'])
155 | out += ['']
156 | # Latex collects all references to a separate bibliography,
157 | # so we need to insert links to it
158 | import sphinx # local import to avoid test dependency
159 | if sphinx.__version__ >= "0.6":
160 | out += ['.. only:: latex', '']
161 | else:
162 | out += ['.. latexonly::', '']
163 | items = []
164 | for line in self['References']:
165 | m = re.match(r'.. \[([a-z0-9._-]+)\]', line, re.I)
166 | if m:
167 | items.append(m.group(1))
168 | out += [' ' + ", ".join(["[%s]_" % item for item in items]), '']
169 | return out
170 |
171 | def _str_examples(self):
172 | examples_str = "\n".join(self['Examples'])
173 |
174 | if (self.use_plots and 'import matplotlib' in examples_str
175 | and 'plot::' not in examples_str):
176 | out = []
177 | out += self._str_header('Examples')
178 | out += ['.. plot::', '']
179 | out += self._str_indent(self['Examples'])
180 | out += ['']
181 | return out
182 | else:
183 | return self._str_section('Examples')
184 |
185 | def __str__(self, indent=0, func_role="obj"):
186 | out = []
187 | out += self._str_signature()
188 | out += self._str_index() + ['']
189 | out += self._str_summary()
190 | out += self._str_extended_summary()
191 | for param_list in ('Parameters', 'Returns', 'Raises', 'Attributes'):
192 | out += self._str_param_list(param_list)
193 | out += self._str_warnings()
194 | out += self._str_see_also(func_role)
195 | out += self._str_section('Notes')
196 | out += self._str_references()
197 | out += self._str_examples()
198 | for param_list in ('Methods',):
199 | out += self._str_member_list(param_list)
200 | out = self._str_indent(out, indent)
201 | return '\n'.join(out)
202 |
203 |
204 | class SphinxFunctionDoc(SphinxDocString, FunctionDoc):
205 | def __init__(self, obj, doc=None, config={}):
206 | self.use_plots = config.get('use_plots', False)
207 | FunctionDoc.__init__(self, obj, doc=doc, config=config)
208 |
209 |
210 | class SphinxClassDoc(SphinxDocString, ClassDoc):
211 | def __init__(self, obj, doc=None, func_doc=None, config={}):
212 | self.use_plots = config.get('use_plots', False)
213 | ClassDoc.__init__(self, obj, doc=doc, func_doc=None, config=config)
214 |
215 |
216 | class SphinxObjDoc(SphinxDocString):
217 | def __init__(self, obj, doc=None, config=None):
218 | self._f = obj
219 | SphinxDocString.__init__(self, doc, config=config)
220 |
221 |
222 | def get_doc_object(obj, what=None, doc=None, config={}):
223 | if what is None:
224 | if inspect.isclass(obj):
225 | what = 'class'
226 | elif inspect.ismodule(obj):
227 | what = 'module'
228 | elif callable(obj):
229 | what = 'function'
230 | else:
231 | what = 'object'
232 | if what == 'class':
233 | return SphinxClassDoc(obj, func_doc=SphinxFunctionDoc, doc=doc,
234 | config=config)
235 | elif what in ('function', 'method'):
236 | return SphinxFunctionDoc(obj, doc=doc, config=config)
237 | else:
238 | if doc is None:
239 | doc = pydoc.getdoc(obj)
240 | return SphinxObjDoc(obj, doc, config=config)
241 |
--------------------------------------------------------------------------------
/doc/sphinxext/numpy_ext/numpydoc.py:
--------------------------------------------------------------------------------
1 | """
2 | ========
3 | numpydoc
4 | ========
5 |
6 | Sphinx extension that handles docstrings in the Numpy standard format. [1]
7 |
8 | It will:
9 |
10 | - Convert Parameters etc. sections to field lists.
11 | - Convert See Also section to a See also entry.
12 | - Renumber references.
13 | - Extract the signature from the docstring, if it can't be determined
14 | otherwise.
15 |
16 | .. [1] http://projects.scipy.org/numpy/wiki/CodingStyleGuidelines#docstring-standard
17 |
18 | """
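  | # --- editor's illustrative sketch (not part of the original file) ---
  | # Typical use from a Sphinx conf.py; the exact extension path depends on
  | # how doc/sphinxext is put on sys.path in this project:
  | #
  | #     extensions = ['numpy_ext.numpydoc']
  | #     numpydoc_show_class_members = True   # config values added in setup()
  | #     numpydoc_use_plots = False
  | #     numpydoc_edit_link = None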
19 |
20 | from __future__ import unicode_literals
21 |
22 | import sys # Only needed to check Python version
23 | import os
24 | import re
25 | import pydoc
26 | from .docscrape_sphinx import get_doc_object
27 | from .docscrape_sphinx import SphinxDocString
28 | import inspect
29 |
30 |
31 | def mangle_docstrings(app, what, name, obj, options, lines,
32 | reference_offset=[0]):
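  |     # note (editor): the mutable default `reference_offset=[0]` acts as a
  |     # counter that persists across calls, keeping renumbered reference
  |     # labels unique over the whole build (it is incremented at the end of
  |     # this function)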
33 |
34 | cfg = dict(use_plots=app.config.numpydoc_use_plots,
35 | show_class_members=app.config.numpydoc_show_class_members)
36 |
37 | if what == 'module':
38 | # Strip top title
39 | title_re = re.compile(r'^\s*[#*=]{4,}\n[a-z0-9 -]+\n[#*=]{4,}\s*',
40 | re.I | re.S)
41 | lines[:] = title_re.sub('', "\n".join(lines)).split("\n")
42 | else:
43 | doc = get_doc_object(obj, what, "\n".join(lines), config=cfg)
44 | if sys.version_info[0] < 3:
45 | lines[:] = unicode(doc).splitlines()
46 | else:
47 | lines[:] = str(doc).splitlines()
48 |
49 | if app.config.numpydoc_edit_link and hasattr(obj, '__name__') and \
50 | obj.__name__:
51 | if hasattr(obj, '__module__'):
52 | v = dict(full_name="%s.%s" % (obj.__module__, obj.__name__))
53 | else:
54 | v = dict(full_name=obj.__name__)
55 | lines += [u'', u'.. htmlonly::', '']
56 | lines += [u' %s' % x for x in
57 | (app.config.numpydoc_edit_link % v).split("\n")]
58 |
59 | # replace reference numbers so that there are no duplicates
60 | references = []
61 | for line in lines:
62 | line = line.strip()
63 |         m = re.match(r'^.. \[([a-z0-9_.-]+)\]', line, re.I)
64 | if m:
65 | references.append(m.group(1))
66 |
67 | # start renaming from the longest string, to avoid overwriting parts
68 | references.sort(key=lambda x: -len(x))
69 | if references:
70 | for i, line in enumerate(lines):
71 | for r in references:
72 | if re.match(r'^\d+$', r):
73 | new_r = "R%d" % (reference_offset[0] + int(r))
74 | else:
75 | new_r = u"%s%d" % (r, reference_offset[0])
76 | lines[i] = lines[i].replace(u'[%s]_' % r,
77 | u'[%s]_' % new_r)
78 | lines[i] = lines[i].replace(u'.. [%s]' % r,
79 | u'.. [%s]' % new_r)
80 |
81 | reference_offset[0] += len(references)
82 |
83 |
84 | def mangle_signature(app, what, name, obj,
85 | options, sig, retann):
86 | # Do not try to inspect classes that don't define `__init__`
87 | if (inspect.isclass(obj) and
88 | (not hasattr(obj, '__init__') or
89 | 'initializes x; see ' in pydoc.getdoc(obj.__init__))):
90 | return '', ''
91 |
92 | if not (callable(obj) or hasattr(obj, '__argspec_is_invalid_')):
93 | return
94 | if not hasattr(obj, '__doc__'):
95 | return
96 |
97 | doc = SphinxDocString(pydoc.getdoc(obj))
98 | if doc['Signature']:
99 | sig = re.sub("^[^(]*", "", doc['Signature'])
100 | return sig, ''
101 |
102 |
103 | def setup(app, get_doc_object_=get_doc_object):
104 | global get_doc_object
105 | get_doc_object = get_doc_object_
106 |
107 | if sys.version_info[0] < 3:
108 | app.connect(b'autodoc-process-docstring', mangle_docstrings)
109 | app.connect(b'autodoc-process-signature', mangle_signature)
110 | else:
111 | app.connect('autodoc-process-docstring', mangle_docstrings)
112 | app.connect('autodoc-process-signature', mangle_signature)
113 | app.add_config_value('numpydoc_edit_link', None, False)
114 | app.add_config_value('numpydoc_use_plots', None, False)
115 | app.add_config_value('numpydoc_show_class_members', True, True)
116 |
117 | # Extra mangling domains
118 | app.add_domain(NumpyPythonDomain)
119 | app.add_domain(NumpyCDomain)
120 |
121 | #-----------------------------------------------------------------------------
122 | # Docstring-mangling domains
123 | #-----------------------------------------------------------------------------
124 |
125 | try:
126 | import sphinx # lazy to avoid test dependency
127 | except ImportError:
128 | CDomain = PythonDomain = object
129 | else:
130 | from sphinx.domains.c import CDomain
131 | from sphinx.domains.python import PythonDomain
132 |
133 |
134 | class ManglingDomainBase(object):
135 | directive_mangling_map = {}
136 |
137 | def __init__(self, *a, **kw):
138 | super(ManglingDomainBase, self).__init__(*a, **kw)
139 | self.wrap_mangling_directives()
140 |
141 | def wrap_mangling_directives(self):
142 | for name, objtype in self.directive_mangling_map.items():
143 | self.directives[name] = wrap_mangling_directive(
144 | self.directives[name], objtype)
145 |
146 |
147 | class NumpyPythonDomain(ManglingDomainBase, PythonDomain):
148 | name = 'np'
149 | directive_mangling_map = {
150 | 'function': 'function',
151 | 'class': 'class',
152 | 'exception': 'class',
153 | 'method': 'function',
154 | 'classmethod': 'function',
155 | 'staticmethod': 'function',
156 | 'attribute': 'attribute',
157 | }
158 |
159 |
160 | class NumpyCDomain(ManglingDomainBase, CDomain):
161 | name = 'np-c'
162 | directive_mangling_map = {
163 | 'function': 'function',
164 | 'member': 'attribute',
165 | 'macro': 'function',
166 | 'type': 'class',
167 | 'var': 'object',
168 | }
169 |
170 |
171 | def wrap_mangling_directive(base_directive, objtype):
172 | class directive(base_directive):
173 | def run(self):
174 | env = self.state.document.settings.env
175 |
176 | name = None
177 | if self.arguments:
178 | m = re.match(r'^(.*\s+)?(.*?)(\(.*)?', self.arguments[0])
179 | name = m.group(2).strip()
180 |
181 | if not name:
182 | name = self.arguments[0]
183 |
184 | lines = list(self.content)
185 | mangle_docstrings(env.app, objtype, name, None, None, lines)
186 | # local import to avoid testing dependency
187 | from docutils.statemachine import ViewList
188 | self.content = ViewList(lines, self.content.parent)
189 |
190 | return base_directive.run(self)
191 |
192 | return directive
193 |
--------------------------------------------------------------------------------
/doc/sphinxext/sphinx_gallery/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | ==============
3 | Sphinx Gallery
4 | ==============
5 |
6 | """
7 | import os
8 | __version__ = '0.1.7'
9 |
10 |
11 | def glr_path_static():
12 | """Returns path to packaged static files"""
13 | return os.path.abspath(os.path.join(os.path.dirname(__file__), '_static'))
14 |
--------------------------------------------------------------------------------
/doc/sphinxext/sphinx_gallery/_static/broken_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svaiter/pyprox/ffc3084a2478536fec808273e16bd7f22e6a9e3c/doc/sphinxext/sphinx_gallery/_static/broken_example.png
--------------------------------------------------------------------------------
/doc/sphinxext/sphinx_gallery/_static/gallery.css:
--------------------------------------------------------------------------------
1 | /*
2 | Sphinx-Gallery has compatible CSS to fix default sphinx themes
3 | Tested for Sphinx 1.3.1 for all themes: default, alabaster, sphinxdoc,
4 | scrolls, agogo, traditional, nature, haiku, pyramid
5 | Tested for Read the Docs theme 0.1.7 */
6 | .sphx-glr-thumbcontainer {
7 | background: #fff;
8 | border: solid #fff 1px;
9 | -moz-border-radius: 5px;
10 | -webkit-border-radius: 5px;
11 | border-radius: 5px;
12 | box-shadow: none;
13 | float: left;
14 | margin: 5px;
15 | min-height: 230px;
16 | padding-top: 5px;
17 | position: relative;
18 | }
19 | .sphx-glr-thumbcontainer:hover {
20 | border: solid #b4ddfc 1px;
21 | box-shadow: 0 0 15px rgba(142, 176, 202, 0.5);
22 | }
23 | .sphx-glr-thumbcontainer a.internal {
24 | bottom: 0;
25 | display: block;
26 | left: 0;
27 | padding: 150px 10px 0;
28 | position: absolute;
29 | right: 0;
30 | top: 0;
31 | }
32 | /* The next rule prevents the Sphinx "traditional" theme from covering the
33 |    whole thumbnail with its default link background color */
34 | .sphx-glr-thumbcontainer a.internal:hover {
35 | background-color: transparent;
36 | }
37 |
38 | .sphx-glr-thumbcontainer p {
39 | margin: 0 0 .1em 0;
40 | }
41 | .sphx-glr-thumbcontainer .figure {
42 | margin: 10px;
43 | width: 160px;
44 | }
45 | .sphx-glr-thumbcontainer img {
46 | display: inline;
47 | max-height: 160px;
48 | width: 160px;
49 | }
50 | .sphx-glr-thumbcontainer[tooltip]:hover:after {
51 | background: rgba(0, 0, 0, 0.8);
52 | -webkit-border-radius: 5px;
53 | -moz-border-radius: 5px;
54 | border-radius: 5px;
55 | color: #fff;
56 | content: attr(tooltip);
57 | left: 95%;
58 | padding: 5px 15px;
59 | position: absolute;
60 | z-index: 98;
61 | width: 220px;
62 | bottom: 52%;
63 | }
64 | .sphx-glr-thumbcontainer[tooltip]:hover:before {
65 | border: solid;
66 | border-color: #333 transparent;
67 | border-width: 18px 0 0 20px;
68 | bottom: 58%;
69 | content: '';
70 | left: 85%;
71 | position: absolute;
72 | z-index: 99;
73 | }
74 |
75 | .highlight-pytb pre {
76 | background-color: #ffe4e4;
77 | border: 1px solid #f66;
78 | margin-top: 10px;
79 | padding: 7px;
80 | }
81 |
82 | .sphx-glr-script-out {
83 | color: #888;
84 | margin: 0;
85 | }
86 | .sphx-glr-script-out .highlight {
87 | background-color: transparent;
88 | margin-left: 2.5em;
89 | margin-top: -1.4em;
90 | }
91 | .sphx-glr-script-out .highlight pre {
92 | background-color: #fafae2;
93 | border: 0;
94 | max-height: 30em;
95 | overflow: auto;
96 | padding-left: 1ex;
97 | margin: 0px;
98 | word-break: break-word;
99 | }
100 | .sphx-glr-script-out + p {
101 | margin-top: 1.8em;
102 | }
103 | blockquote.sphx-glr-script-out {
104 | margin-left: 0pt;
105 | }
106 |
107 | div.sphx-glr-footer {
108 | text-align: center;
109 | }
110 |
111 | div.sphx-glr-download {
112 | display: inline-block;
113 | margin: 1em auto 1ex 2ex;
114 | vertical-align: middle;
115 | }
116 |
117 | div.sphx-glr-download a {
118 | background-color: #ffc;
119 | background-image: linear-gradient(to bottom, #FFC, #d5d57e);
120 | border-radius: 4px;
121 | border: 1px solid #c2c22d;
122 | color: #000;
123 | display: inline-block;
124 | /* Not valid in old browser, hence we keep the line above to override */
125 | display: table-caption;
126 | font-weight: bold;
127 | padding: 1ex;
128 | text-align: center;
129 | }
130 |
131 | /* The last child of a download button is the file name */
132 | div.sphx-glr-download a span:last-child {
133 | font-size: smaller;
134 | }
135 |
136 | @media (min-width: 20em) {
137 | div.sphx-glr-download a {
138 | min-width: 10em;
139 | }
140 | }
141 |
142 | @media (min-width: 30em) {
143 | div.sphx-glr-download a {
144 | min-width: 13em;
145 | }
146 | }
147 |
148 | @media (min-width: 40em) {
149 | div.sphx-glr-download a {
150 | min-width: 16em;
151 | }
152 | }
153 |
154 |
155 | div.sphx-glr-download code.download {
156 | display: inline-block;
157 | white-space: normal;
158 | word-break: normal;
159 | overflow-wrap: break-word;
160 | /* border and background are given by the enclosing 'a' */
161 | border: none;
162 | background: none;
163 | }
164 |
165 | div.sphx-glr-download a:hover {
166 | box-shadow: inset 0 1px 0 rgba(255,255,255,.1), 0 1px 5px rgba(0,0,0,.25);
167 | text-decoration: none;
168 | background-image: none;
169 | background-color: #d5d57e;
170 | }
171 |
172 | ul.sphx-glr-horizontal {
173 | list-style: none;
174 | padding: 0;
175 | }
176 | ul.sphx-glr-horizontal li {
177 | display: inline;
178 | }
179 | ul.sphx-glr-horizontal img {
180 | height: auto !important;
181 | }
182 |
183 | p.sphx-glr-signature a.reference.external {
184 | -moz-border-radius: 5px;
185 | -webkit-border-radius: 5px;
186 | border-radius: 5px;
187 | padding: 3px;
188 | font-size: 75%;
189 | text-align: right;
190 | margin-left: auto;
191 | display: table;
192 | }
193 |
194 | a.sphx-glr-code-links:hover{
195 | text-decoration: none;
196 | }
197 |
198 | a.sphx-glr-code-links[tooltip]:hover:before{
199 | background: rgba(0,0,0,.8);
200 | border-radius: 5px;
201 | color: white;
202 | content: attr(tooltip);
203 | padding: 5px 15px;
204 | position: absolute;
205 | z-index: 98;
206 | width: 16em;
207 | word-break: normal;
208 | white-space: normal;
209 | display: inline-block;
210 | text-align: center;
211 | text-indent: 0;
212 | margin-left: 0; /* Use zero to avoid overlapping with sidebar */
213 | margin-top: 1.2em;
214 | }
215 |
--------------------------------------------------------------------------------
/doc/sphinxext/sphinx_gallery/_static/no_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svaiter/pyprox/ffc3084a2478536fec808273e16bd7f22e6a9e3c/doc/sphinxext/sphinx_gallery/_static/no_image.png
--------------------------------------------------------------------------------
/doc/sphinxext/sphinx_gallery/backreferences.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Author: Óscar Nájera
3 | # License: 3-clause BSD
4 | """
5 | ========================
6 | Backreferences Generator
7 | ========================
8 |
9 | Reviews generated example files in order to keep track of used modules
10 | """
11 |
12 | from __future__ import print_function
13 | import ast
14 | import os
15 |
16 |
17 | # Try Python 2 first, otherwise load from Python 3
18 | try:
19 | import cPickle as pickle
20 | except ImportError:
21 | import pickle
22 |
23 |
24 | class NameFinder(ast.NodeVisitor):
25 | """Finds the longest form of variable names and their imports in code
26 |
27 | Only retains names from imported modules.
28 | """
29 |
30 | def __init__(self):
31 | super(NameFinder, self).__init__()
32 | self.imported_names = {}
33 | self.accessed_names = set()
34 |
35 | def visit_Import(self, node, prefix=''):
36 | for alias in node.names:
37 | local_name = alias.asname or alias.name
38 | self.imported_names[local_name] = prefix + alias.name
39 |
40 | def visit_ImportFrom(self, node):
41 | self.visit_Import(node, node.module + '.')
42 |
43 | def visit_Name(self, node):
44 | self.accessed_names.add(node.id)
45 |
46 | def visit_Attribute(self, node):
47 | attrs = []
48 | while isinstance(node, ast.Attribute):
49 | attrs.append(node.attr)
50 | node = node.value
51 |
52 | if isinstance(node, ast.Name):
53 | # This is a.b, not e.g. a().b
54 | attrs.append(node.id)
55 | self.accessed_names.add('.'.join(reversed(attrs)))
56 | else:
57 | # need to get a in a().b
58 | self.visit(node)
59 |
60 | def get_mapping(self):
61 | for name in self.accessed_names:
62 | local_name = name.split('.', 1)[0]
63 | remainder = name[len(local_name):]
64 | if local_name in self.imported_names:
65 | # Join import path to relative path
66 | full_name = self.imported_names[local_name] + remainder
67 | yield name, full_name
68 |
69 |
70 | def get_short_module_name(module_name, obj_name):
71 | """ Get the shortest possible module name """
72 | parts = module_name.split('.')
73 | short_name = module_name
74 | for i in range(len(parts) - 1, 0, -1):
75 | short_name = '.'.join(parts[:i])
76 | try:
77 | exec('from %s import %s' % (short_name, obj_name))
78 | except Exception: # libraries can throw all sorts of exceptions...
79 | # get the last working module name
80 | short_name = '.'.join(parts[:(i + 1)])
81 | break
82 | return short_name
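  | # editor's note (illustrative): get_short_module_name('numpy.random.mtrand',
  | # 'rand') would typically return 'numpy.random', since
  | # 'from numpy import rand' fails while 'from numpy.random import rand' works.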
83 |
84 |
85 | def identify_names(code):
86 | """Builds a codeobj summary by identifying and resolving used names
87 |
88 | >>> code = '''
89 | ... from a.b import c
90 | ... import d as e
91 | ... print(c)
92 | ... e.HelloWorld().f.g
93 | ... '''
94 | >>> for name, o in sorted(identify_names(code).items()):
95 | ... print(name, o['name'], o['module'], o['module_short'])
96 | c c a.b a.b
97 | e.HelloWorld HelloWorld d d
98 | """
99 | finder = NameFinder()
100 | finder.visit(ast.parse(code))
101 |
102 | example_code_obj = {}
103 | for name, full_name in finder.get_mapping():
104 | # name is as written in file (e.g. np.asarray)
105 | # full_name includes resolved import path (e.g. numpy.asarray)
106 | splitted = full_name.rsplit('.', 1)
107 | if len(splitted) == 1:
108 | # module without attribute. This is not useful for
109 | # backreferences
110 | continue
111 |
112 | module, attribute = splitted
113 | # get shortened module name
114 | module_short = get_short_module_name(module, attribute)
115 | cobj = {'name': attribute, 'module': module,
116 | 'module_short': module_short}
117 | example_code_obj[name] = cobj
118 | return example_code_obj
119 |
120 |
121 | def scan_used_functions(example_file, gallery_conf):
122 | """save variables so we can later add links to the documentation"""
123 | example_code_obj = identify_names(open(example_file).read())
124 | if example_code_obj:
125 | codeobj_fname = example_file[:-3] + '_codeobj.pickle'
126 | with open(codeobj_fname, 'wb') as fid:
127 | pickle.dump(example_code_obj, fid, pickle.HIGHEST_PROTOCOL)
128 |
129 | backrefs = set('{module_short}.{name}'.format(**entry)
130 | for entry in example_code_obj.values()
131 | if entry['module'].startswith(gallery_conf['doc_module']))
132 |
133 | return backrefs
134 |
135 |
136 | # XXX This figure:: uses a forward slash even on Windows, but the op.join's
137 | # elsewhere will use backslashes...
138 | THUMBNAIL_TEMPLATE = """
139 | .. raw:: html
140 |
141 |     <div class="sphx-glr-thumbcontainer" tooltip="{snippet}">
142 |
143 | .. only:: html
144 |
145 | .. figure:: /{thumbnail}
146 |
147 | :ref:`sphx_glr_{ref_name}`
148 |
149 | .. raw:: html
150 |
151 |     </div>
152 | """
153 |
154 | BACKREF_THUMBNAIL_TEMPLATE = THUMBNAIL_TEMPLATE + """
155 | .. only:: not html
156 |
157 | * :ref:`sphx_glr_{ref_name}`
158 | """
159 |
160 |
161 | def _thumbnail_div(full_dir, fname, snippet, is_backref=False):
162 | """Generates RST to place a thumbnail in a gallery"""
163 | thumb = os.path.join(full_dir, 'images', 'thumb',
164 | 'sphx_glr_%s_thumb.png' % fname[:-3])
165 | ref_name = os.path.join(full_dir, fname).replace(os.path.sep, '_')
166 |
167 | template = BACKREF_THUMBNAIL_TEMPLATE if is_backref else THUMBNAIL_TEMPLATE
168 | return template.format(snippet=snippet, thumbnail=thumb, ref_name=ref_name)
169 |
170 |
171 | def write_backreferences(seen_backrefs, gallery_conf,
172 | target_dir, fname, snippet):
173 | """Writes down back reference files, which include a thumbnail list
174 | of examples using a certain module"""
175 | example_file = os.path.join(target_dir, fname)
176 | backrefs = scan_used_functions(example_file, gallery_conf)
177 | for backref in backrefs:
178 | include_path = os.path.join(gallery_conf['mod_example_dir'],
179 | '%s.examples' % backref)
180 | seen = backref in seen_backrefs
181 | with open(include_path, 'a' if seen else 'w') as ex_file:
182 | if not seen:
183 | heading = '\n\nExamples using ``%s``' % backref
184 | ex_file.write(heading + '\n')
185 | ex_file.write('^' * len(heading) + '\n')
186 | ex_file.write(_thumbnail_div(target_dir, fname, snippet,
187 | is_backref=True))
188 | seen_backrefs.add(backref)
189 |
--------------------------------------------------------------------------------
/doc/sphinxext/sphinx_gallery/docs_resolv.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Author: Óscar Nájera
3 | # License: 3-clause BSD
4 | ###############################################################################
5 | # Documentation link resolver objects
6 | from __future__ import print_function
7 | import gzip
8 | import os
9 | import posixpath
10 | import re
11 | import shelve
12 | import sys
13 |
14 | # Try Python 2 first, otherwise load from Python 3
15 | try:
16 | import cPickle as pickle
17 | import urllib2 as urllib
18 | from urllib2 import HTTPError, URLError
19 | except ImportError:
20 | import pickle
21 | import urllib.request
22 | import urllib.error
23 | import urllib.parse
24 | from urllib.error import HTTPError, URLError
25 |
26 | from io import StringIO
27 |
28 |
29 | def _get_data(url):
30 | """Helper function to get data over http or from a local file"""
31 | if url.startswith('http://'):
32 | # Try Python 2, use Python 3 on exception
33 | try:
34 | resp = urllib.urlopen(url)
35 | encoding = resp.headers.dict.get('content-encoding', 'plain')
36 | except AttributeError:
37 | resp = urllib.request.urlopen(url)
38 | encoding = resp.headers.get('content-encoding', 'plain')
39 | data = resp.read()
40 | if encoding == 'plain':
41 | pass
42 | elif encoding == 'gzip':
43 | data = StringIO(data)
44 | data = gzip.GzipFile(fileobj=data).read()
45 | else:
46 | raise RuntimeError('unknown encoding')
47 | else:
48 | with open(url, 'r') as fid:
49 | data = fid.read()
50 |
51 | return data
52 |
53 |
54 | def get_data(url, gallery_dir):
55 |     """Retrieve the search index data from url, caching it in a persistent shelve store"""
56 |
57 | # shelve keys need to be str in python 2
58 | if sys.version_info[0] == 2 and isinstance(url, unicode):
59 | url = url.encode('utf-8')
60 |
61 | cached_file = os.path.join(gallery_dir, 'searchindex')
62 | search_index = shelve.open(cached_file)
63 | if url in search_index:
64 | data = search_index[url]
65 | else:
66 | data = _get_data(url)
67 | search_index[url] = data
68 | search_index.close()
69 |
70 | return data
71 |
72 |
73 | def _select_block(str_in, start_tag, end_tag):
74 | """Select first block delimited by start_tag and end_tag"""
75 | start_pos = str_in.find(start_tag)
76 | if start_pos < 0:
77 | raise ValueError('start_tag not found')
78 | depth = 0
79 | for pos in range(start_pos, len(str_in)):
80 | if str_in[pos] == start_tag:
81 | depth += 1
82 | elif str_in[pos] == end_tag:
83 | depth -= 1
84 |
85 | if depth == 0:
86 | break
87 | sel = str_in[start_pos + 1:pos]
88 | return sel
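  | # editor's note (illustrative): _select_block('objects: {a: {b: 1}}, x', '{', '}')
  | # returns 'a: {b: 1}' -- the contents of the first balanced {...} block.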
89 |
90 |
91 | def _parse_dict_recursive(dict_str):
92 | """Parse a dictionary from the search index"""
93 | dict_out = dict()
94 | pos_last = 0
95 | pos = dict_str.find(':')
96 | while pos >= 0:
97 | key = dict_str[pos_last:pos]
98 | if dict_str[pos + 1] == '[':
99 | # value is a list
100 | pos_tmp = dict_str.find(']', pos + 1)
101 | if pos_tmp < 0:
102 | raise RuntimeError('error when parsing dict')
103 | value = dict_str[pos + 2: pos_tmp].split(',')
104 | # try to convert elements to int
105 | for i in range(len(value)):
106 | try:
107 | value[i] = int(value[i])
108 | except ValueError:
109 | pass
110 | elif dict_str[pos + 1] == '{':
111 | # value is another dictionary
112 | subdict_str = _select_block(dict_str[pos:], '{', '}')
113 | value = _parse_dict_recursive(subdict_str)
114 | pos_tmp = pos + len(subdict_str)
115 | else:
116 | raise ValueError('error when parsing dict: unknown elem')
117 |
118 | key = key.strip('"')
119 | if len(key) > 0:
120 | dict_out[key] = value
121 |
122 | pos_last = dict_str.find(',', pos_tmp)
123 | if pos_last < 0:
124 | break
125 | pos_last += 1
126 | pos = dict_str.find(':', pos_last)
127 |
128 | return dict_out
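  | # editor's note (illustrative): _parse_dict_recursive('"a":[1,2],"b":{"c":[3]}')
  | # returns {'a': [1, 2], 'b': {'c': [3]}}.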
129 |
130 |
131 | def parse_sphinx_searchindex(searchindex):
132 | """Parse a Sphinx search index
133 |
134 | Parameters
135 | ----------
136 | searchindex : str
137 | The Sphinx search index (contents of searchindex.js)
138 |
139 | Returns
140 | -------
141 | filenames : list of str
142 | The file names parsed from the search index.
143 | objects : dict
144 | The objects parsed from the search index.
145 | """
146 | # Make sure searchindex uses UTF-8 encoding
147 | if hasattr(searchindex, 'decode'):
148 | searchindex = searchindex.decode('UTF-8')
149 |
150 | # parse objects
151 | query = 'objects:'
152 | pos = searchindex.find(query)
153 | if pos < 0:
154 | raise ValueError('"objects:" not found in search index')
155 |
156 | sel = _select_block(searchindex[pos:], '{', '}')
157 | objects = _parse_dict_recursive(sel)
158 |
159 | # parse filenames
160 | query = 'filenames:'
161 | pos = searchindex.find(query)
162 | if pos < 0:
163 | raise ValueError('"filenames:" not found in search index')
164 | filenames = searchindex[pos + len(query) + 1:]
165 | filenames = filenames[:filenames.find(']')]
166 | filenames = [f.strip('"') for f in filenames.split(',')]
167 |
168 | return filenames, objects
169 |
170 |
171 | class SphinxDocLinkResolver(object):
172 | """ Resolve documentation links using searchindex.js generated by Sphinx
173 |
174 | Parameters
175 | ----------
176 | doc_url : str
177 | The base URL of the project website.
178 | searchindex : str
179 | Filename of searchindex, relative to doc_url.
180 | extra_modules_test : list of str
181 | List of extra module names to test.
182 | relative : bool
183 | Return relative links (only useful for links to documentation of this
184 | package).
185 | """
186 |
187 | def __init__(self, doc_url, gallery_dir, searchindex='searchindex.js',
188 | extra_modules_test=None, relative=False):
189 | self.doc_url = doc_url
190 | self.gallery_dir = gallery_dir
191 | self.relative = relative
192 | self._link_cache = {}
193 |
194 | self.extra_modules_test = extra_modules_test
195 | self._page_cache = {}
196 | if doc_url.startswith('http://'):
197 | if relative:
198 | raise ValueError('Relative links are only supported for local '
199 |                                  'URLs (doc_url cannot start with "http://")')
200 | searchindex_url = doc_url + '/' + searchindex
201 | else:
202 | searchindex_url = os.path.join(doc_url, searchindex)
203 |
204 | # detect if we are using relative links on a Windows system
205 | if os.name.lower() == 'nt' and not doc_url.startswith('http://'):
206 | if not relative:
207 | raise ValueError('You have to use relative=True for the local'
208 | ' package on a Windows system.')
209 | self._is_windows = True
210 | else:
211 | self._is_windows = False
212 |
213 | # download and initialize the search index
214 | sindex = get_data(searchindex_url, gallery_dir)
215 | filenames, objects = parse_sphinx_searchindex(sindex)
216 |
217 | self._searchindex = dict(filenames=filenames, objects=objects)
218 |
219 | def _get_link(self, cobj):
220 | """Get a valid link, False if not found"""
221 |
222 | fname_idx = None
223 | full_name = cobj['module_short'] + '.' + cobj['name']
224 | if full_name in self._searchindex['objects']:
225 | value = self._searchindex['objects'][full_name]
226 | if isinstance(value, dict):
227 | value = value[next(iter(value.keys()))]
228 | fname_idx = value[0]
229 | elif cobj['module_short'] in self._searchindex['objects']:
230 | value = self._searchindex['objects'][cobj['module_short']]
231 | if cobj['name'] in value.keys():
232 | fname_idx = value[cobj['name']][0]
233 |
234 | if fname_idx is not None:
235 | fname = self._searchindex['filenames'][fname_idx]
236 | # In 1.5+ Sphinx seems to have changed from .rst.html to only
237 | # .html extension in converted files. But URLs could be
238 | # built with < 1.5 or >= 1.5 regardless of what we're currently
239 | # building with, so let's just check both :(
240 | fnames = [fname + '.html', os.path.splitext(fname)[0] + '.html']
241 | for fname in fnames:
242 | try:
243 | if self._is_windows:
244 | fname = fname.replace('/', '\\')
245 | link = os.path.join(self.doc_url, fname)
246 | else:
247 | link = posixpath.join(self.doc_url, fname)
248 |
249 | if hasattr(link, 'decode'):
250 | link = link.decode('utf-8', 'replace')
251 |
252 | if link in self._page_cache:
253 | html = self._page_cache[link]
254 | else:
255 | html = get_data(link, self.gallery_dir)
256 | self._page_cache[link] = html
257 | except (HTTPError, URLError, IOError):
258 | pass
259 | else:
260 | break
261 | else:
262 | raise
263 |
264 | # test if cobj appears in page
265 | comb_names = [cobj['module_short'] + '.' + cobj['name']]
266 | if self.extra_modules_test is not None:
267 | for mod in self.extra_modules_test:
268 | comb_names.append(mod + '.' + cobj['name'])
269 | url = False
270 | if hasattr(html, 'decode'):
271 | # Decode bytes under Python 3
272 | html = html.decode('utf-8', 'replace')
273 |
274 | for comb_name in comb_names:
275 | if hasattr(comb_name, 'decode'):
276 | # Decode bytes under Python 3
277 | comb_name = comb_name.decode('utf-8', 'replace')
278 | if comb_name in html:
279 | url = link + u'#' + comb_name
280 | link = url
281 | else:
282 | link = False
283 |
284 | return link
285 |
286 | def resolve(self, cobj, this_url):
287 | """Resolve the link to the documentation, returns None if not found
288 |
289 | Parameters
290 | ----------
291 | cobj : dict
292 | Dict with information about the "code object" for which we are
293 | resolving a link.
294 |             cobj['name'] : function or class name (str)
295 | cobj['module_short'] : shortened module name (str)
296 | cobj['module'] : module name (str)
297 | this_url: str
298 | URL of the current page. Needed to construct relative URLs
299 | (only used if relative=True in constructor).
300 |
301 | Returns
302 | -------
303 | link : str | None
304 | The link (URL) to the documentation.
305 | """
306 | full_name = cobj['module_short'] + '.' + cobj['name']
307 | link = self._link_cache.get(full_name, None)
308 | if link is None:
309 | # we don't have it cached
310 | link = self._get_link(cobj)
311 | # cache it for the future
312 | self._link_cache[full_name] = link
313 |
314 | if link is False or link is None:
315 | # failed to resolve
316 | return None
317 |
318 | if self.relative:
319 | link = os.path.relpath(link, start=this_url)
320 | if self._is_windows:
321 |                 # replace '\' with '/' so it works on the web
322 | link = link.replace('\\', '/')
323 |
324 | # for some reason, the relative link goes one directory too high up
325 | link = link[3:]
326 |
327 | return link
328 |
329 |
330 | def _embed_code_links(app, gallery_conf, gallery_dir):
331 | # Add resolvers for the packages for which we want to show links
332 | doc_resolvers = {}
333 |
334 | for this_module, url in gallery_conf['reference_url'].items():
335 | try:
336 | if url is None:
337 | doc_resolvers[this_module] = SphinxDocLinkResolver(
338 | app.builder.outdir,
339 | gallery_dir,
340 | relative=True)
341 | else:
342 | doc_resolvers[this_module] = SphinxDocLinkResolver(url,
343 | gallery_dir)
344 |
345 | except HTTPError as e:
346 | print("The following HTTP Error has occurred:\n")
347 | print(e.code)
348 | except URLError as e:
349 | print("\n...\n"
350 | "Warning: Embedding the documentation hyperlinks requires "
351 | "Internet access.\nPlease check your network connection.\n"
352 | "Unable to continue embedding `{0}` links due to a URL "
353 | "Error:\n".format(this_module))
354 | print(e.args)
355 |
356 | html_gallery_dir = os.path.abspath(os.path.join(app.builder.outdir,
357 | gallery_dir))
358 |
359 | # patterns for replacement
360 |     link_pattern = ('<a href="%s" tooltip="%s" '
361 |                     'class="sphx-glr-code-links">%s</a>')
362 |     orig_pattern = '<span class="n">%s</span>'
363 |     period = '<span class="o">.</span>'
364 |
365 | for dirpath, _, filenames in os.walk(html_gallery_dir):
366 | for fname in filenames:
367 | print('\tprocessing: %s' % fname)
368 | full_fname = os.path.join(html_gallery_dir, dirpath, fname)
369 | subpath = dirpath[len(html_gallery_dir) + 1:]
370 | pickle_fname = os.path.join(gallery_dir, subpath,
371 | fname[:-5] + '_codeobj.pickle')
372 |
373 | if os.path.exists(pickle_fname):
374 | # we have a pickle file with the objects to embed links for
375 | with open(pickle_fname, 'rb') as fid:
376 | example_code_obj = pickle.load(fid)
377 | fid.close()
378 | str_repl = {}
379 | # generate replacement strings with the links
380 | for name, cobj in example_code_obj.items():
381 | this_module = cobj['module'].split('.')[0]
382 |
383 | if this_module not in doc_resolvers:
384 | continue
385 |
386 | try:
387 | link = doc_resolvers[this_module].resolve(cobj,
388 | full_fname)
389 | except (HTTPError, URLError) as e:
390 | if isinstance(e, HTTPError):
391 | extra = e.code
392 | else:
393 | extra = e.reason
394 | print("\t\tError resolving %s.%s: %r (%s)"
395 | % (cobj['module'], cobj['name'], e, extra))
396 | continue
397 |
398 | if link is not None:
399 | parts = name.split('.')
400 | name_html = period.join(orig_pattern % part
401 | for part in parts)
402 | full_function_name = '%s.%s' % (
403 | cobj['module'], cobj['name'])
404 | str_repl[name_html] = link_pattern % (
405 | link, full_function_name, name_html)
406 | # do the replacement in the html file
407 |
408 | # ensure greediness
409 | names = sorted(str_repl, key=len, reverse=True)
410 |                 expr = re.compile(r'(?<!\.)\b' +  # don't follow . or word
411 |                                   '|'.join(re.escape(name)
412 |                                            for name in names))
413 |
414 |                 def substitute_link(match):
415 |                     return str_repl[match.group()]
416 |
417 |                 if len(str_repl) > 0:
418 | with open(full_fname, 'rb') as fid:
419 | lines_in = fid.readlines()
420 | with open(full_fname, 'wb') as fid:
421 | for line in lines_in:
422 | line = line.decode('utf-8')
423 | line = expr.sub(substitute_link, line)
424 | fid.write(line.encode('utf-8'))
425 | print('[done]')
426 |
427 |
428 | def embed_code_links(app, exception):
429 | """Embed hyperlinks to documentation into example code"""
430 | if exception is not None:
431 | return
432 |
433 | # No need to waste time embedding hyperlinks when not running the examples
434 | # XXX: also at the time of writing this fixes make html-noplot
435 | # for some reason I don't fully understand
436 | if not app.builder.config.plot_gallery:
437 | return
438 |
439 | # XXX: Whitelist of builders for which it makes sense to embed
440 | # hyperlinks inside the example html. Note that the link embedding
441 | # require searchindex.js to exist for the links to the local doc
442 | # and there does not seem to be a good way of knowing which
443 | # builders creates a searchindex.js.
444 | if app.builder.name not in ['html', 'readthedocs']:
445 | return
446 |
447 | print('Embedding documentation hyperlinks in examples..')
448 |
449 | gallery_conf = app.config.sphinx_gallery_conf
450 |
451 | gallery_dirs = gallery_conf['gallery_dirs']
452 | if not isinstance(gallery_dirs, list):
453 | gallery_dirs = [gallery_dirs]
454 |
455 | for gallery_dir in gallery_dirs:
456 | _embed_code_links(app, gallery_conf, gallery_dir)
457 |
--------------------------------------------------------------------------------
/doc/sphinxext/sphinx_gallery/downloads.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | r"""
3 | Utilities for downloadable items
4 | ================================
5 |
6 | """
7 | # Author: Óscar Nájera
8 | # License: 3-clause BSD
9 |
10 | from __future__ import absolute_import, division, print_function
11 |
12 | import os
13 | import zipfile
14 |
15 | CODE_DOWNLOAD = """
16 | \n.. container:: sphx-glr-footer
17 |
18 | \n .. container:: sphx-glr-download
19 |
20 | :download:`Download Python source code: {0} <{0}>`\n
21 |
22 | \n .. container:: sphx-glr-download
23 |
24 | :download:`Download Jupyter notebook: {1} <{1}>`\n"""
25 |
26 | CODE_ZIP_DOWNLOAD = """
27 | \n.. container:: sphx-glr-footer
28 |
29 | \n .. container:: sphx-glr-download
30 |
31 |     :download:`Download all examples in Python source code: {0} <{1}>`\n
32 |
33 | \n .. container:: sphx-glr-download
34 |
35 |     :download:`Download all examples in Jupyter notebooks: {2} <{3}>`\n"""
36 |
37 |
38 | def python_zip(file_list, gallery_path, extension='.py'):
39 |     """Stores all files in file_list into a zip file
40 |
41 | Parameters
42 | ----------
43 | file_list : list of strings
44 | Holds all the file names to be included in zip file
45 | gallery_path : string
46 | path to where the zipfile is stored
47 | extension : str
48 |         '.py' or '.ipynb'. To support downloads of both Python
49 |         sources and Jupyter notebooks, the extension of each file
50 |         in file_list is replaced with the value of this variable
51 |         while generating the zip file.
52 | Returns
53 | -------
54 | zipname : string
55 | zip file name, written as `target_dir_{python,jupyter}.zip`
56 | depending on the extension
57 | """
58 | zipname = gallery_path.replace(os.path.sep, '_')
59 | zipname += '_python' if extension == '.py' else '_jupyter'
60 | zipname = os.path.join(gallery_path, zipname + '.zip')
61 |
62 | zipf = zipfile.ZipFile(zipname, mode='w')
63 | for fname in file_list:
64 | file_src = os.path.splitext(fname)[0] + extension
65 | zipf.write(file_src)
66 | zipf.close()
67 |
68 | return zipname
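  | # --- editor's illustrative sketch (not part of the original file) ---
  | # python_zip(['auto_examples/plot_demo.py'], 'auto_examples') would write
  | # and return 'auto_examples/auto_examples_python.zip' (paths hypothetical).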
69 |
70 |
71 | def list_downloadable_sources(target_dir):
72 |     """Returns a list of Python source files in target_dir
73 |
74 | Parameters
75 | ----------
76 | target_dir : string
77 |         path to the directory where the Python source files are
78 | Returns
79 | -------
80 | list
81 | list of paths to all Python source files in `target_dir`
82 | """
83 | return [os.path.join(target_dir, fname)
84 | for fname in os.listdir(target_dir)
85 | if fname.endswith('.py')]
86 |
87 |
88 | def generate_zipfiles(gallery_dir):
89 | """
90 | Collects all Python source files and Jupyter notebooks in
91 | gallery_dir and makes zipfiles of them
92 |
93 | Parameters
94 | ----------
95 | gallery_dir : string
96 | path of the gallery to collect downloadable sources
97 |
98 |     Returns
99 |     -------
100 | download_rst: string
101 | RestructuredText to include download buttons to the generated files
102 | """
103 |
104 | listdir = list_downloadable_sources(gallery_dir)
105 | for directory in sorted(os.listdir(gallery_dir)):
106 | if os.path.isdir(os.path.join(gallery_dir, directory)):
107 | target_dir = os.path.join(gallery_dir, directory)
108 | listdir.extend(list_downloadable_sources(target_dir))
109 |
110 | py_zipfile = python_zip(listdir, gallery_dir)
111 | jy_zipfile = python_zip(listdir, gallery_dir, ".ipynb")
112 |
113 | dw_rst = CODE_ZIP_DOWNLOAD.format(os.path.basename(py_zipfile),
114 | py_zipfile,
115 | os.path.basename(jy_zipfile),
116 | jy_zipfile)
117 | return dw_rst
118 |
--------------------------------------------------------------------------------
/doc/sphinxext/sphinx_gallery/gen_gallery.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Author: Óscar Nájera
3 | # License: 3-clause BSD
4 | """
5 | ========================
6 | Sphinx-Gallery Generator
7 | ========================
8 |
9 | Attaches Sphinx-Gallery to Sphinx in order to generate the galleries
10 | when building the documentation.
11 | """
12 |
13 |
14 | from __future__ import division, print_function, absolute_import
15 | import copy
16 | import re
17 | import os
18 | from . import glr_path_static
19 | from .gen_rst import generate_dir_rst, SPHX_GLR_SIG
20 | from .docs_resolv import embed_code_links
21 | from .downloads import generate_zipfiles
22 |
23 | try:
24 | FileNotFoundError
25 | except NameError:
26 | # Python2
27 | FileNotFoundError = IOError
28 |
29 | DEFAULT_GALLERY_CONF = {
30 | 'filename_pattern': re.escape(os.sep) + 'plot',
31 | 'examples_dirs': os.path.join('..', 'examples'),
32 | 'gallery_dirs': 'auto_examples',
33 | 'mod_example_dir': os.path.join('modules', 'generated'),
34 | 'doc_module': (),
35 | 'reference_url': {},
36 | # build options
37 | 'plot_gallery': True,
38 | 'download_all_examples': True,
39 | 'abort_on_example_error': False,
40 | 'failing_examples': {},
41 | 'expected_failing_examples': set(),
42 | }
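  | # --- editor's illustrative sketch (not part of the original file) ---
  | # These defaults are typically overridden from conf.py via
  | # ``sphinx_gallery_conf`` (merged in generate_gallery_rst below), e.g.:
  | #
  | #     sphinx_gallery_conf = {
  | #         'examples_dirs': '../examples',
  | #         'gallery_dirs': 'auto_examples',
  | #         'doc_module': ('pyprox',),          # project name: an assumption
  | #         'reference_url': {'pyprox': None},  # None -> link to local docs
  | #     }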
43 |
44 |
45 | def clean_gallery_out(build_dir):
46 | """Deletes images under the sphx_glr namespace in the build directory"""
47 | # Sphinx hack: sphinx copies generated images to the build directory
48 | # each time the docs are made. If the desired image name already
49 | # exists, it appends a digit to prevent overwrites. The problem is,
50 | # the directory is never cleared. This means that each time you build
51 | # the docs, the number of images in the directory grows.
52 | #
53 | # This question has been asked on the sphinx development list, but there
54 | # was no response: http://osdir.com/ml/sphinx-dev/2011-02/msg00123.html
55 | #
56 | # The following is a hack that prevents this behavior by clearing the
57 | # image build directory from gallery images each time the docs are built.
58 | # If sphinx changes their layout between versions, this will not
59 | # work (though it should probably not cause a crash).
60 | # Tested successfully on Sphinx 1.0.7
61 |
62 | build_image_dir = os.path.join(build_dir, '_images')
63 | if os.path.exists(build_image_dir):
64 | filelist = os.listdir(build_image_dir)
65 | for filename in filelist:
66 | if filename.startswith('sphx_glr') and filename.endswith('png'):
67 | os.remove(os.path.join(build_image_dir, filename))
68 |
69 |
70 | def generate_gallery_rst(app):
71 | """Generate the Main examples gallery reStructuredText
72 |
73 | Start the sphinx-gallery configuration and recursively scan the examples
74 | directories in order to populate the examples gallery
75 | """
76 | try:
77 | plot_gallery = eval(app.builder.config.plot_gallery)
78 | except TypeError:
79 | plot_gallery = bool(app.builder.config.plot_gallery)
80 |
81 | gallery_conf = copy.deepcopy(DEFAULT_GALLERY_CONF)
82 | gallery_conf.update(app.config.sphinx_gallery_conf)
83 | gallery_conf.update(plot_gallery=plot_gallery)
84 | gallery_conf.update(
85 | abort_on_example_error=app.builder.config.abort_on_example_error)
86 |
 87 |     # store the merged config back so other handlers can read it
88 | app.config.sphinx_gallery_conf = gallery_conf
89 | app.config.html_static_path.append(glr_path_static())
90 |
91 | clean_gallery_out(app.builder.outdir)
92 |
93 | examples_dirs = gallery_conf['examples_dirs']
94 | gallery_dirs = gallery_conf['gallery_dirs']
95 |
96 | if not isinstance(examples_dirs, list):
97 | examples_dirs = [examples_dirs]
98 | if not isinstance(gallery_dirs, list):
99 | gallery_dirs = [gallery_dirs]
100 |
101 | mod_examples_dir = os.path.relpath(gallery_conf['mod_example_dir'],
102 | app.builder.srcdir)
103 | seen_backrefs = set()
104 |
105 | computation_times = []
106 |
107 | # cd to the appropriate directory regardless of sphinx configuration
108 | working_dir = os.getcwd()
109 | os.chdir(app.builder.srcdir)
110 | for examples_dir, gallery_dir in zip(examples_dirs, gallery_dirs):
111 | examples_dir = os.path.relpath(examples_dir,
112 | app.builder.srcdir)
113 | gallery_dir = os.path.relpath(gallery_dir,
114 | app.builder.srcdir)
115 |
116 | for workdir in [examples_dir, gallery_dir, mod_examples_dir]:
117 | if not os.path.exists(workdir):
118 | os.makedirs(workdir)
119 | # we create an index.rst with all examples
120 | fhindex = open(os.path.join(gallery_dir, 'index.rst'), 'w')
121 | # Here we don't use an os.walk, but we recurse only twice: flat is
122 | # better than nested.
123 | this_fhindex, this_computation_times = \
124 | generate_dir_rst(examples_dir, gallery_dir, gallery_conf,
125 | seen_backrefs)
126 | if this_fhindex == "":
127 | raise FileNotFoundError("Main example directory {0} does not "
128 | "have a README.txt file. Please write "
129 | "one to introduce your gallery.".format(examples_dir))
130 |
131 | computation_times += this_computation_times
132 |
133 | fhindex.write(this_fhindex)
134 | for directory in sorted(os.listdir(examples_dir)):
135 | if os.path.isdir(os.path.join(examples_dir, directory)):
136 | src_dir = os.path.join(examples_dir, directory)
137 | target_dir = os.path.join(gallery_dir, directory)
138 | this_fhindex, this_computation_times = \
139 | generate_dir_rst(src_dir, target_dir, gallery_conf,
140 | seen_backrefs)
141 | fhindex.write(this_fhindex)
142 | computation_times += this_computation_times
143 |
144 | if gallery_conf['download_all_examples']:
145 | download_fhindex = generate_zipfiles(gallery_dir)
146 | fhindex.write(download_fhindex)
147 |
148 | fhindex.write(SPHX_GLR_SIG)
149 | fhindex.flush()
150 |
151 | # Back to initial directory
152 | os.chdir(working_dir)
153 |
154 | print("Computation time summary:")
155 | for time_elapsed, fname in sorted(computation_times)[::-1]:
156 | if time_elapsed is not None:
157 | print("\t- %s : %.2g sec" % (fname, time_elapsed))
158 | else:
159 | print("\t- %s : not run" % fname)
160 |
161 |
162 | def touch_empty_backreferences(app, what, name, obj, options, lines):
163 | """Generate empty back-reference example files
164 |
165 | This avoids inclusion errors/warnings if there are no gallery
166 | examples for a class / module that is being parsed by autodoc"""
167 |
168 | examples_path = os.path.join(app.srcdir,
169 | app.config.sphinx_gallery_conf[
170 | "mod_example_dir"],
171 | "%s.examples" % name)
172 |
173 | if not os.path.exists(examples_path):
174 | # touch file
175 | open(examples_path, 'w').close()
176 |
177 |
178 | def sumarize_failing_examples(app, exception):
179 |     """Collects the list of failing examples during the build and prints them with their tracebacks
180 |
181 |     Raises ValueError if there were failing examples
182 | """
183 | if exception is not None:
184 | return
185 |
186 |     # Under no-plot builds the examples are not run, so there is nothing to summarize
187 | if not app.config.sphinx_gallery_conf['plot_gallery']:
188 | return
189 |
190 | gallery_conf = app.config.sphinx_gallery_conf
191 | failing_examples = set([os.path.normpath(path) for path in
192 | gallery_conf['failing_examples']])
193 | expected_failing_examples = set([os.path.normpath(path) for path in
194 | gallery_conf['expected_failing_examples']])
195 |
196 | examples_expected_to_fail = failing_examples.intersection(
197 | expected_failing_examples)
198 | expected_fail_msg = []
199 | if examples_expected_to_fail:
200 | expected_fail_msg.append("Examples failing as expected:")
201 | for fail_example in examples_expected_to_fail:
202 | expected_fail_msg.append(fail_example + ' failed leaving traceback:\n' +
203 | gallery_conf['failing_examples'][fail_example] + '\n')
204 | print("\n".join(expected_fail_msg))
205 |
206 | examples_not_expected_to_fail = failing_examples.difference(
207 | expected_failing_examples)
208 | fail_msgs = []
209 | if examples_not_expected_to_fail:
210 | fail_msgs.append("Unexpected failing examples:")
211 | for fail_example in examples_not_expected_to_fail:
212 | fail_msgs.append(fail_example + ' failed leaving traceback:\n' +
213 | gallery_conf['failing_examples'][fail_example] + '\n')
214 |
215 | examples_not_expected_to_pass = expected_failing_examples.difference(
216 | failing_examples)
217 | if examples_not_expected_to_pass:
218 |         fail_msgs.append("Examples expected to fail, but not failing:\n" +
219 | "Please remove these examples from\n" +
220 | "sphinx_gallery_conf['expected_failing_examples']\n" +
221 |                          "in your conf.py file:\n" +
222 |                          "\n".join(examples_not_expected_to_pass))
223 |
224 | if fail_msgs:
225 | raise ValueError("Here is a summary of the problems encountered when "
226 | "running the examples\n\n" + "\n".join(fail_msgs) +
227 | "\n" + "-" * 79)
228 |
229 |
230 | def get_default_config_value(key):
231 | def default_getter(conf):
232 | return conf['sphinx_gallery_conf'].get(key, DEFAULT_GALLERY_CONF[key])
233 | return default_getter
234 |
235 |
236 | def setup(app):
237 | """Setup sphinx-gallery sphinx extension"""
238 | app.add_config_value('sphinx_gallery_conf', DEFAULT_GALLERY_CONF, 'html')
239 | for key in ['plot_gallery', 'abort_on_example_error']:
240 | app.add_config_value(key, get_default_config_value(key), 'html')
241 |
242 | app.add_stylesheet('gallery.css')
243 |
244 | if 'sphinx.ext.autodoc' in app._extensions:
245 | app.connect('autodoc-process-docstring', touch_empty_backreferences)
246 |
247 | app.connect('builder-inited', generate_gallery_rst)
248 |
249 | app.connect('build-finished', sumarize_failing_examples)
250 | app.connect('build-finished', embed_code_links)
251 |
252 |
253 | def setup_module():
254 | # HACK: Stop nosetests running setup() above
255 | pass
256 |
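A minimal conf.py sketch for wiring the extension up; the option names come from DEFAULT_GALLERY_CONF and setup() above, while the directory values are illustrative assumptions:

    # conf.py (sketch)
    extensions = ['sphinx_gallery.gen_gallery']

    sphinx_gallery_conf = {
        # only scripts whose path matches this pattern are executed
        'filename_pattern': '/plot',
        # where the example scripts live, relative to conf.py
        'examples_dirs': '../examples',
        # where the generated gallery pages are written
        'gallery_dirs': 'auto_examples',
        # stop the whole build on the first failing example
        'abort_on_example_error': False,
    }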
--------------------------------------------------------------------------------
/doc/sphinxext/sphinx_gallery/gen_rst.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Author: Óscar Nájera
3 | # License: 3-clause BSD
4 | """
5 | ==================
6 | RST file generator
7 | ==================
8 |
9 | Generate the rst files for the examples by iterating over the python
10 | example files.
11 |
12 | Files that generate images should start with 'plot'
13 |
14 | """
15 | # Don't use unicode_literals here (be explicit with u"..." instead) otherwise
16 | # tricky errors come up with exec(code_blocks, ...) calls
17 | from __future__ import division, print_function, absolute_import
18 | from time import time
19 | import codecs
20 | import hashlib
21 | import os
22 | import re
23 | import shutil
24 | import subprocess
25 | import sys
26 | import traceback
27 | import warnings
28 |
29 |
30 | # Python 2/3 compatibility: textwrap.indent exists only in Python 3
31 | try:
32 | # textwrap indent only exists in python 3
33 | from textwrap import indent
34 | except ImportError:
35 | def indent(text, prefix, predicate=None):
36 | """Adds 'prefix' to the beginning of selected lines in 'text'.
37 |
38 | If 'predicate' is provided, 'prefix' will only be added to the lines
39 | where 'predicate(line)' is True. If 'predicate' is not provided,
40 | it will default to adding 'prefix' to all non-empty lines that do not
41 | consist solely of whitespace characters.
42 | """
43 | if predicate is None:
44 | def predicate(line):
45 | return line.strip()
46 |
47 | def prefixed_lines():
48 | for line in text.splitlines(True):
49 | yield (prefix + line if predicate(line) else line)
50 | return ''.join(prefixed_lines())
51 |
52 | from io import StringIO
53 |
54 | try:
55 | # make sure that the Agg backend is set before importing any
56 | # matplotlib
57 | import matplotlib
58 | matplotlib.use('agg')
59 | matplotlib_backend = matplotlib.get_backend()
60 |
61 | if matplotlib_backend != 'agg':
62 | mpl_backend_msg = (
63 | "Sphinx-Gallery relies on the matplotlib 'agg' backend to "
64 | "render figures and write them to files. You are "
65 | "currently using the {} backend. Sphinx-Gallery will "
66 | "terminate the build now, because changing backends is "
67 | "not well supported by matplotlib. We advise you to move "
68 | "sphinx_gallery imports before any matplotlib-dependent "
69 | "import. Moving sphinx_gallery imports at the top of "
70 | "your conf.py file should fix this issue")
71 |
72 | raise ValueError(mpl_backend_msg.format(matplotlib_backend))
73 |
74 | import matplotlib.pyplot as plt
75 | except ImportError:
76 | # this script can be imported by nosetest to find tests to run: we should
77 | # not impose the matplotlib requirement in that case.
78 | pass
79 |
80 | from . import glr_path_static
81 | from .backreferences import write_backreferences, _thumbnail_div
82 | from .downloads import CODE_DOWNLOAD
83 | from .py_source_parser import (get_docstring_and_rest,
84 | split_code_and_text_blocks)
85 |
86 | from .notebook import jupyter_notebook, text2string, save_notebook
87 |
88 | try:
89 | basestring
90 | except NameError:
91 | basestring = str
92 | unicode = str
93 |
94 |
95 | ###############################################################################
96 |
97 |
98 | class Tee(object):
99 | """A tee object to redirect streams to multiple outputs"""
100 |
101 | def __init__(self, file1, file2):
102 | self.file1 = file1
103 | self.file2 = file2
104 |
105 | def write(self, data):
106 | self.file1.write(data)
107 | self.file2.write(data)
108 |
109 | def flush(self):
110 | self.file1.flush()
111 | self.file2.flush()
112 |
113 |     # When called from a local terminal, seaborn needs isatty in Python 3
114 |     def isatty(self):
115 |         return self.file1.isatty()
116 |
117 |
118 | class MixedEncodingStringIO(StringIO):
119 | """Helper when both ASCII and unicode strings will be written"""
120 |
121 | def write(self, data):
122 | if not isinstance(data, unicode):
123 | data = data.decode('utf-8')
124 | StringIO.write(self, data)
125 |
126 |
127 | ###############################################################################
128 | # The following strings are used when we have several pictures: we use
129 | # an html div tag that our CSS uses to turn the lists into horizontal
130 | # lists.
131 | HLIST_HEADER = """
132 | .. rst-class:: sphx-glr-horizontal
133 |
134 | """
135 |
136 | HLIST_IMAGE_TEMPLATE = """
137 | *
138 |
139 | .. image:: /%s
140 | :scale: 47
141 | """
142 |
143 | SINGLE_IMAGE = """
144 | .. image:: /%s
145 | :align: center
146 | """
147 |
148 |
149 | # This one could contain unicode
150 | CODE_OUTPUT = u""".. rst-class:: sphx-glr-script-out
151 |
152 | Out::
153 |
154 | {0}\n"""
155 |
156 |
157 | SPHX_GLR_SIG = """\n.. rst-class:: sphx-glr-signature
158 |
159 |     `Generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_\n"""
160 |
161 |
162 | def codestr2rst(codestr, lang='python'):
163 | """Return reStructuredText code block from code string"""
164 | code_directive = "\n.. code-block:: {0}\n\n".format(lang)
165 | indented_block = indent(codestr, ' ' * 4)
166 | return code_directive + indented_block
167 |
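A quick sketch of the output format, assuming the module is importable as sphinx_gallery.gen_rst:

    from sphinx_gallery.gen_rst import codestr2rst

    rst = codestr2rst("print('hello')")
    # rst == "\n.. code-block:: python\n\n    print('hello')"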
168 |
169 | def extract_thumbnail_number(text):
170 | """ Pull out the thumbnail image number specified in the docstring. """
171 |
172 | # check whether the user has specified a specific thumbnail image
173 | pattr = re.compile(
174 | r"^\s*#\s*sphinx_gallery_thumbnail_number\s*=\s*([0-9]+)\s*$",
175 | flags=re.MULTILINE)
176 | match = pattr.search(text)
177 |
178 | if match is None:
179 | # by default, use the first figure created
180 | thumbnail_number = 1
181 | else:
182 | thumbnail_number = int(match.groups()[0])
183 |
184 | return thumbnail_number
185 |
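To illustrate the comment convention the regular expression above targets (the string passed here is a hand-written stand-in for a real example file):

    from sphinx_gallery.gen_rst import extract_thumbnail_number

    # A single comment line anywhere in the example selects the thumbnail
    # figure; without it the first figure is used.
    content = "# sphinx_gallery_thumbnail_number = 2\nimport numpy as np\n"
    assert extract_thumbnail_number(content) == 2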
186 |
187 | def extract_intro(filename):
188 | """ Extract the first paragraph of module-level docstring. max:95 char"""
189 |
190 | docstring, _ = get_docstring_and_rest(filename)
191 |
192 | # lstrip is just in case docstring has a '\n\n' at the beginning
193 | paragraphs = docstring.lstrip().split('\n\n')
194 | if len(paragraphs) > 1:
195 | first_paragraph = re.sub('\n', ' ', paragraphs[1])
196 | first_paragraph = (first_paragraph[:95] + '...'
197 | if len(first_paragraph) > 95 else first_paragraph)
198 | else:
199 | raise ValueError(
200 | "Example docstring should have a header for the example title "
201 | "and at least a paragraph explaining what the example is about. "
202 | "Please check the example file:\n {}\n".format(filename))
203 |
204 | return first_paragraph
205 |
206 |
207 | def get_md5sum(src_file):
208 | """Returns md5sum of file"""
209 |
210 | with open(src_file, 'rb') as src_data:
211 | src_content = src_data.read()
212 |
213 | src_md5 = hashlib.md5(src_content).hexdigest()
214 | return src_md5
215 |
216 |
217 | def md5sum_is_current(src_file):
218 | """Checks whether src_file has the same md5 hash as the one on disk"""
219 |
220 | src_md5 = get_md5sum(src_file)
221 |
222 | src_md5_file = src_file + '.md5'
223 | if os.path.exists(src_md5_file):
224 | with open(src_md5_file, 'r') as file_checksum:
225 | ref_md5 = file_checksum.read()
226 |
227 | return src_md5 == ref_md5
228 |
229 | return False
230 |
231 |
232 | def save_figures(image_path, fig_count, gallery_conf):
233 | """Save all open matplotlib figures of the example code-block
234 |
235 | Parameters
236 | ----------
237 | image_path : str
238 | Path where plots are saved (format string which accepts figure number)
239 | fig_count : int
240 |         Previous figure number count. New figure numbers are offset from this count
241 | gallery_conf : dict
242 | Contains the configuration of Sphinx-Gallery
243 |
244 | Returns
245 | -------
246 | figure_list : list of str
247 | strings containing the full path to each figure
248 | images_rst : str
249 | rst code to embed the images in the document
250 | """
251 | figure_list = []
252 |
253 | fig_numbers = plt.get_fignums()
254 | for fig_num in fig_numbers:
255 | # Set the fig_num figure as the current figure as we can't
256 | # save a figure that's not the current figure.
257 | fig = plt.figure(fig_num)
258 | kwargs = {}
259 | to_rgba = matplotlib.colors.colorConverter.to_rgba
260 | for attr in ['facecolor', 'edgecolor']:
261 | fig_attr = getattr(fig, 'get_' + attr)()
262 | default_attr = matplotlib.rcParams['figure.' + attr]
263 | if to_rgba(fig_attr) != to_rgba(default_attr):
264 | kwargs[attr] = fig_attr
265 |
266 | current_fig = image_path.format(fig_count + fig_num)
267 | fig.savefig(current_fig, **kwargs)
268 | figure_list.append(current_fig)
269 |
270 | if gallery_conf.get('find_mayavi_figures', False):
271 | from mayavi import mlab
272 | e = mlab.get_engine()
273 | last_matplotlib_fig_num = fig_count + len(figure_list)
274 | total_fig_num = last_matplotlib_fig_num + len(e.scenes)
275 | mayavi_fig_nums = range(last_matplotlib_fig_num + 1, total_fig_num + 1)
276 |
277 | for scene, mayavi_fig_num in zip(e.scenes, mayavi_fig_nums):
278 | current_fig = image_path.format(mayavi_fig_num)
279 | mlab.savefig(current_fig, figure=scene)
280 | # make sure the image is not too large
281 | scale_image(current_fig, current_fig, 850, 999)
282 | figure_list.append(current_fig)
283 | mlab.close(all=True)
284 |
285 | # Depending on whether we have one or more figures, we're using a
286 | # horizontal list or a single rst call to 'image'.
287 | images_rst = ""
288 | if len(figure_list) == 1:
289 | figure_name = figure_list[0]
290 | images_rst = SINGLE_IMAGE % figure_name.lstrip('/')
291 | elif len(figure_list) > 1:
292 | images_rst = HLIST_HEADER
293 | for figure_name in figure_list:
294 | images_rst += HLIST_IMAGE_TEMPLATE % figure_name.lstrip('/')
295 |
296 | return figure_list, images_rst
297 |
298 |
299 | def scale_image(in_fname, out_fname, max_width, max_height):
300 |     """Scales an image, preserving its aspect ratio, and centers it on a
301 |     white canvas of size (max_width, max_height).
302 |     If in_fname == out_fname the image can only be scaled down.
303 | """
304 | # local import to avoid testing dependency on PIL:
305 | try:
306 | from PIL import Image
307 | except ImportError:
308 | import Image
309 | img = Image.open(in_fname)
310 | width_in, height_in = img.size
311 | scale_w = max_width / float(width_in)
312 | scale_h = max_height / float(height_in)
313 |
314 | if height_in * scale_w <= max_height:
315 | scale = scale_w
316 | else:
317 | scale = scale_h
318 |
319 | if scale >= 1.0 and in_fname == out_fname:
320 | return
321 |
322 | width_sc = int(round(scale * width_in))
323 | height_sc = int(round(scale * height_in))
324 |
325 | # resize the image
326 | img.thumbnail((width_sc, height_sc), Image.ANTIALIAS)
327 |
328 | # insert centered
329 | thumb = Image.new('RGB', (max_width, max_height), (255, 255, 255))
330 | pos_insert = ((max_width - width_sc) // 2, (max_height - height_sc) // 2)
331 | thumb.paste(img, pos_insert)
332 |
333 | thumb.save(out_fname)
334 | # Use optipng to perform lossless compression on the resized image if
335 | # software is installed
336 | if os.environ.get('SKLEARN_DOC_OPTIPNG', False):
337 | try:
338 | subprocess.call(["optipng", "-quiet", "-o", "9", out_fname])
339 | except Exception:
340 | warnings.warn('Install optipng to reduce the size of the \
341 | generated images')
342 |
343 |
344 | def save_thumbnail(image_path_template, src_file, gallery_conf):
345 | """Save the thumbnail image"""
346 | # read specification of the figure to display as thumbnail from main text
347 | _, content = get_docstring_and_rest(src_file)
348 | thumbnail_number = extract_thumbnail_number(content)
349 | thumbnail_image_path = image_path_template.format(thumbnail_number)
350 |
351 | thumb_dir = os.path.join(os.path.dirname(thumbnail_image_path), 'thumb')
352 | if not os.path.exists(thumb_dir):
353 | os.makedirs(thumb_dir)
354 |
355 | base_image_name = os.path.splitext(os.path.basename(src_file))[0]
356 | thumb_file = os.path.join(thumb_dir,
357 | 'sphx_glr_%s_thumb.png' % base_image_name)
358 |
359 | if src_file in gallery_conf['failing_examples']:
360 | broken_img = os.path.join(glr_path_static(), 'broken_example.png')
361 | scale_image(broken_img, thumb_file, 200, 140)
362 |
363 | elif os.path.exists(thumbnail_image_path):
364 | scale_image(thumbnail_image_path, thumb_file, 400, 280)
365 |
366 | elif not os.path.exists(thumb_file):
367 | # create something to replace the thumbnail
368 | default_thumb_file = os.path.join(glr_path_static(), 'no_image.png')
369 | default_thumb_file = gallery_conf.get("default_thumb_file",
370 | default_thumb_file)
371 | scale_image(default_thumb_file, thumb_file, 200, 140)
372 |
373 |
374 | def generate_dir_rst(src_dir, target_dir, gallery_conf, seen_backrefs):
375 | """Generate the gallery reStructuredText for an example directory"""
376 | if not os.path.exists(os.path.join(src_dir, 'README.txt')):
377 | print(80 * '_')
378 | print('Example directory %s does not have a README.txt file' %
379 | src_dir)
380 | print('Skipping this directory')
381 | print(80 * '_')
382 | return "", [] # because string is an expected return type
383 |
384 | fhindex = open(os.path.join(src_dir, 'README.txt')).read()
385 | # Add empty lines to avoid bug in issue #165
386 | fhindex += "\n\n"
387 |
388 | if not os.path.exists(target_dir):
389 | os.makedirs(target_dir)
390 | sorted_listdir = [fname for fname in sorted(os.listdir(src_dir))
391 | if fname.endswith('.py')]
392 | entries_text = []
393 | computation_times = []
394 | for fname in sorted_listdir:
395 | amount_of_code, time_elapsed = \
396 | generate_file_rst(fname, target_dir, src_dir, gallery_conf)
397 | computation_times.append((time_elapsed, fname))
398 | new_fname = os.path.join(src_dir, fname)
399 | intro = extract_intro(new_fname)
400 | write_backreferences(seen_backrefs, gallery_conf,
401 | target_dir, fname, intro)
402 | this_entry = _thumbnail_div(target_dir, fname, intro) + """
403 |
404 | .. toctree::
405 | :hidden:
406 |
407 | /%s/%s\n""" % (target_dir, fname[:-3])
408 | entries_text.append((amount_of_code, this_entry))
409 |
410 | # sort to have the smallest entries in the beginning
411 | entries_text.sort()
412 |
413 | for _, entry_text in entries_text:
414 | fhindex += entry_text
415 |
416 | # clear at the end of the section
417 | fhindex += """.. raw:: html\n
418 |     <div style='clear:both'></div>\n\n"""
419 |
420 | return fhindex, computation_times
421 |
422 |
423 | def execute_code_block(code_block, example_globals,
424 | block_vars, gallery_conf):
425 | """Executes the code block of the example file"""
426 | time_elapsed = 0
427 | stdout = ''
428 |
429 | # If example is not suitable to run, skip executing its blocks
430 | if not block_vars['execute_script']:
431 | return stdout, time_elapsed
432 |
433 | plt.close('all')
434 | cwd = os.getcwd()
435 | # Redirect output to stdout and
436 | orig_stdout = sys.stdout
437 | src_file = block_vars['src_file']
438 |
439 | try:
440 | # First cd in the original example dir, so that any file
441 | # created by the example get created in this directory
442 | os.chdir(os.path.dirname(src_file))
443 | my_buffer = MixedEncodingStringIO()
444 | my_stdout = Tee(sys.stdout, my_buffer)
445 | sys.stdout = my_stdout
446 |
447 | t_start = time()
448 | # don't use unicode_literals at the top of this file or you get
449 | # nasty errors here on Py2.7
450 | exec(code_block, example_globals)
451 | time_elapsed = time() - t_start
452 |
453 | sys.stdout = orig_stdout
454 |
455 | my_stdout = my_buffer.getvalue().strip().expandtabs()
456 | # raise RuntimeError
457 | if my_stdout:
458 | stdout = CODE_OUTPUT.format(indent(my_stdout, u' ' * 4))
459 | os.chdir(cwd)
460 | fig_list, images_rst = save_figures(
461 | block_vars['image_path'], block_vars['fig_count'], gallery_conf)
462 | fig_num = len(fig_list)
463 |
464 | except Exception:
465 | formatted_exception = traceback.format_exc()
466 |
467 | fail_example_warning = 80 * '_' + '\n' + \
468 | '%s failed to execute correctly:' % src_file + \
469 | formatted_exception + 80 * '_' + '\n'
470 | warnings.warn(fail_example_warning)
471 |
472 | fig_num = 0
473 | images_rst = codestr2rst(formatted_exception, lang='pytb')
474 |
475 | # Breaks build on first example error
476 | # XXX This check can break during testing e.g. if you uncomment the
477 | # `raise RuntimeError` by the `my_stdout` call, maybe use `.get()`?
478 | if gallery_conf['abort_on_example_error']:
479 | raise
480 | # Stores failing file
481 | gallery_conf['failing_examples'][src_file] = formatted_exception
482 | block_vars['execute_script'] = False
483 |
484 | finally:
485 | os.chdir(cwd)
486 | sys.stdout = orig_stdout
487 |
488 | code_output = u"\n{0}\n\n{1}\n\n".format(images_rst, stdout)
489 | block_vars['fig_count'] += fig_num
490 |
491 | return code_output, time_elapsed
492 |
493 |
494 | def clean_modules():
495 |     """Remove ("unload") seaborn from the namespace
496 |
497 |     After a script is executed it can leave behind a variety of settings that
498 |     we do not want to carry over to the other examples in the gallery."""
499 |
500 |     # Horrible code to 'unload' seaborn, so that it resets
501 |     # its defaults the next time it is loaded
502 | # Python does not support unloading of modules
503 | # https://bugs.python.org/issue9072
504 | for module in list(sys.modules.keys()):
505 | if 'seaborn' in module:
506 | del sys.modules[module]
507 |
508 | # Reset Matplotlib to default
509 | plt.rcdefaults()
510 |
511 |
512 | def generate_file_rst(fname, target_dir, src_dir, gallery_conf):
513 | """Generate the rst file for a given example.
514 |
515 | Returns
516 | -------
517 | amount_of_code : int
518 | character count of the corresponding python script in file
519 | time_elapsed : float
520 | seconds required to run the script
521 | """
522 |
523 | src_file = os.path.join(src_dir, fname)
524 | example_file = os.path.join(target_dir, fname)
525 | shutil.copyfile(src_file, example_file)
526 | script_blocks = split_code_and_text_blocks(src_file)
527 | amount_of_code = sum([len(bcontent)
528 | for blabel, bcontent in script_blocks
529 | if blabel == 'code'])
530 |
531 | if md5sum_is_current(example_file):
532 | return amount_of_code, 0
533 |
534 | image_dir = os.path.join(target_dir, 'images')
535 | if not os.path.exists(image_dir):
536 | os.makedirs(image_dir)
537 |
538 | base_image_name = os.path.splitext(fname)[0]
539 | image_fname = 'sphx_glr_' + base_image_name + '_{0:03}.png'
540 | image_path_template = os.path.join(image_dir, image_fname)
541 |
542 | ref_fname = example_file.replace(os.path.sep, '_')
543 | example_rst = """\n\n.. _sphx_glr_{0}:\n\n""".format(ref_fname)
544 |
545 | filename_pattern = gallery_conf.get('filename_pattern')
546 | execute_script = re.search(filename_pattern, src_file) and gallery_conf[
547 | 'plot_gallery']
548 | example_globals = {
549 | # A lot of examples contains 'print(__doc__)' for example in
550 | # scikit-learn so that running the example prints some useful
551 | # information. Because the docstring has been separated from
552 | # the code blocks in sphinx-gallery, __doc__ is actually
553 | # __builtin__.__doc__ in the execution context and we do not
554 | # want to print it
555 | '__doc__': '',
556 | # Examples may contain if __name__ == '__main__' guards
557 | # for in example scikit-learn if the example uses multiprocessing
558 | '__name__': '__main__',
559 | }
560 |
561 | # A simple example has two blocks: one for the
562 | # example introduction/explanation and one for the code
563 | is_example_notebook_like = len(script_blocks) > 2
564 | time_elapsed = 0
565 | block_vars = {'execute_script': execute_script, 'fig_count': 0,
566 | 'image_path': image_path_template, 'src_file': src_file}
567 | print('Executing file %s' % src_file)
568 | for blabel, bcontent in script_blocks:
569 | if blabel == 'code':
570 | code_output, rtime = execute_code_block(bcontent,
571 | example_globals,
572 | block_vars,
573 | gallery_conf)
574 |
575 | time_elapsed += rtime
576 |
577 | if is_example_notebook_like:
578 | example_rst += codestr2rst(bcontent) + '\n'
579 | example_rst += code_output
580 | else:
581 | example_rst += code_output
582 | if 'sphx-glr-script-out' in code_output:
583 | # Add some vertical space after output
584 | example_rst += "\n\n|\n\n"
585 | example_rst += codestr2rst(bcontent) + '\n'
586 |
587 | else:
588 | example_rst += text2string(bcontent) + '\n'
589 |
590 | clean_modules()
591 |
592 |     # Write the md5 checksum if the example built correctly, i.e. it did not
593 |     # fail and was meant to run (a no-plot build shall not cache the md5sum)
594 | if block_vars['execute_script']:
595 | with open(example_file + '.md5', 'w') as file_checksum:
596 | file_checksum.write(get_md5sum(example_file))
597 |
598 | save_thumbnail(image_path_template, src_file, gallery_conf)
599 |
600 | time_m, time_s = divmod(time_elapsed, 60)
601 | example_nb = jupyter_notebook(script_blocks)
602 | save_notebook(example_nb, example_file.replace('.py', '.ipynb'))
603 | with codecs.open(os.path.join(target_dir, base_image_name + '.rst'),
604 | mode='w', encoding='utf-8') as f:
605 | example_rst += "**Total running time of the script:**" \
606 | " ({0: .0f} minutes {1: .3f} seconds)\n\n".format(
607 | time_m, time_s)
608 | example_rst += CODE_DOWNLOAD.format(fname,
609 | fname.replace('.py', '.ipynb'))
610 | example_rst += SPHX_GLR_SIG
611 | f.write(example_rst)
612 |
613 | print("{0} ran in : {1:.2g} seconds\n".format(src_file, time_elapsed))
614 |
615 | return amount_of_code, time_elapsed
616 |
--------------------------------------------------------------------------------
/doc/sphinxext/sphinx_gallery/notebook.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | r"""
3 | ============================
4 | Parser for Jupyter notebooks
5 | ============================
6 |
7 | Class that holds the Jupyter notebook information
8 |
9 | """
10 | # Author: Óscar Nájera
11 | # License: 3-clause BSD
12 |
13 | from __future__ import division, absolute_import, print_function
14 | from functools import partial
15 | import argparse
16 | import ast
16 | import json
17 | import re
18 | import sys
19 | from .py_source_parser import split_code_and_text_blocks
20 |
21 |
22 | def text2string(content):
23 | """Returns a string without the extra triple quotes"""
24 | try:
25 | return ast.literal_eval(content) + '\n'
26 | except Exception:
27 | return content + '\n'
28 |
29 |
30 | def jupyter_notebook_skeleton():
31 | """Returns a dictionary with the elements of a Jupyter notebook"""
32 | py_version = sys.version_info
33 | notebook_skeleton = {
34 | "cells": [],
35 | "metadata": {
36 | "kernelspec": {
37 | "display_name": "Python " + str(py_version[0]),
38 | "language": "python",
39 | "name": "python" + str(py_version[0])
40 | },
41 | "language_info": {
42 | "codemirror_mode": {
43 | "name": "ipython",
44 | "version": py_version[0]
45 | },
46 | "file_extension": ".py",
47 | "mimetype": "text/x-python",
48 | "name": "python",
49 | "nbconvert_exporter": "python",
50 | "pygments_lexer": "ipython" + str(py_version[0]),
51 | "version": '{0}.{1}.{2}'.format(*sys.version_info[:3])
52 | }
53 | },
54 | "nbformat": 4,
55 | "nbformat_minor": 0
56 | }
57 | return notebook_skeleton
58 |
59 |
60 | def directive_fun(match, directive):
61 | """Helper to fill in directives"""
62 | directive_to_alert = dict(note="info", warning="danger")
63 |     return ('<div class="alert alert-{0}"><h4>{1}</h4><p>{2}</p></div>'
64 | .format(directive_to_alert[directive], directive.capitalize(),
65 | match.group(1).strip()))
66 |
67 |
68 | def rst2md(text):
69 |     """Converts the RST text from the examples' docstrings and comments
70 | into markdown text for the Jupyter notebooks"""
71 |
72 | top_heading = re.compile(r'^=+$\s^([\w\s-]+)^=+$', flags=re.M)
73 | text = re.sub(top_heading, r'# \1', text)
74 |
75 | math_eq = re.compile(r'^\.\. math::((?:.+)?(?:\n+^ .+)*)', flags=re.M)
76 | text = re.sub(math_eq,
77 | lambda match: r'\begin{{align}}{0}\end{{align}}'.format(
78 | match.group(1).strip()),
79 | text)
80 | inline_math = re.compile(r':math:`(.+?)`', re.DOTALL)
81 | text = re.sub(inline_math, r'$\1$', text)
82 |
83 | directives = ('warning', 'note')
84 | for directive in directives:
85 | directive_re = re.compile(r'^\.\. %s::((?:.+)?(?:\n+^ .+)*)'
86 | % directive, flags=re.M)
87 | text = re.sub(directive_re,
88 | partial(directive_fun, directive=directive), text)
89 |
90 | links = re.compile(r'^ *\.\. _.*:.*$\n', flags=re.M)
91 | text = re.sub(links, '', text)
92 |
93 | refs = re.compile(r':ref:`')
94 | text = re.sub(refs, '`', text)
95 |
96 | contents = re.compile(r'^\s*\.\. contents::.*$(\n +:\S+: *$)*\n',
97 | flags=re.M)
98 | text = re.sub(contents, '', text)
99 |
100 | images = re.compile(
101 | r'^\.\. image::(.*$)(?:\n *:alt:(.*$)\n)?(?: +:\S+:.*$\n)*',
102 | flags=re.M)
103 | text = re.sub(
104 |         images, lambda match: '![{1}]({0})\n'.format(
105 | match.group(1).strip(), (match.group(2) or '').strip()), text)
106 |
107 | return text
108 |
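A small sketch of the conversions performed above (expected output written by hand, so treat the exact blank lines as approximate):

    from sphinx_gallery.notebook import rst2md

    md = rst2md("=====\nTitle\n=====\n\nSee :ref:`other_example` and :math:`x^2`.")
    # md now reads, up to blank lines:
    #   # Title
    #   See `other_example` and $x^2$.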
109 |
110 | def jupyter_notebook(script_blocks):
111 | """Generate a Jupyter notebook file cell-by-cell
112 |
113 | Parameters
114 | ----------
115 | script_blocks: list
116 | script execution cells
117 | """
118 |
119 | work_notebook = jupyter_notebook_skeleton()
120 | add_code_cell(work_notebook, "%matplotlib inline")
121 | fill_notebook(work_notebook, script_blocks)
122 |
123 | return work_notebook
124 |
125 |
126 | def add_code_cell(work_notebook, code):
127 | """Add a code cell to the notebook
128 |
129 | Parameters
130 | ----------
131 | code : str
132 | Cell content
133 | """
134 |
135 | code_cell = {
136 | "cell_type": "code",
137 | "execution_count": None,
138 | "metadata": {"collapsed": False},
139 | "outputs": [],
140 | "source": [code.strip()]
141 | }
142 | work_notebook["cells"].append(code_cell)
143 |
144 |
145 | def add_markdown_cell(work_notebook, text):
146 | """Add a markdown cell to the notebook
147 |
148 | Parameters
149 | ----------
150 | code : str
151 | Cell content
152 | """
153 | markdown_cell = {
154 | "cell_type": "markdown",
155 | "metadata": {},
156 | "source": [rst2md(text)]
157 | }
158 | work_notebook["cells"].append(markdown_cell)
159 |
160 |
161 | def fill_notebook(work_notebook, script_blocks):
162 | """Writes the Jupyter notebook cells
163 |
164 | Parameters
165 | ----------
166 | script_blocks : list of tuples
167 | """
168 |
169 | for blabel, bcontent in script_blocks:
170 | if blabel == 'code':
171 | add_code_cell(work_notebook, bcontent)
172 | else:
173 | add_markdown_cell(work_notebook, text2string(bcontent))
174 |
175 |
176 | def save_notebook(work_notebook, write_file):
177 | """Saves the Jupyter work_notebook to write_file"""
178 | with open(write_file, 'w') as out_nb:
179 | json.dump(work_notebook, out_nb, indent=2)
180 |
181 |
182 | ###############################################################################
183 | # Notebook shell utility
184 |
185 | def python_to_jupyter_cli(args=None, namespace=None):
186 | """Exposes the jupyter notebook renderer to the command line
187 |
188 | Takes the same arguments as ArgumentParser.parse_args
189 | """
190 | parser = argparse.ArgumentParser(
191 | description='Sphinx-Gallery Notebook converter')
192 | parser.add_argument('python_src_file', nargs='+',
193 | help='Input Python file script to convert. '
194 | 'Supports multiple files and shell wildcards'
195 | ' (e.g. *.py)')
196 | args = parser.parse_args(args, namespace)
197 |
198 | for src_file in args.python_src_file:
199 | blocks = split_code_and_text_blocks(src_file)
200 | print('Converting {0}'.format(src_file))
201 | example_nb = jupyter_notebook(blocks)
202 | save_notebook(example_nb, src_file.replace('.py', '.ipynb'))
203 |
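A usage sketch for the CLI entry point above; plot_example.py is a hypothetical gallery-style script (module docstring plus '#'-separator cells):

    from sphinx_gallery.notebook import python_to_jupyter_cli

    # Converts the script and writes plot_example.ipynb next to it.
    python_to_jupyter_cli(['plot_example.py'])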
--------------------------------------------------------------------------------
/doc/sphinxext/sphinx_gallery/py_source_parser.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | r"""
3 | Parser for python source files
4 | ==============================
5 | """
6 | # Created Sun Nov 27 14:03:07 2016
7 | # Author: Óscar Nájera
8 |
9 | from __future__ import division, absolute_import, print_function
10 | import ast
11 | import re
12 | from textwrap import dedent
13 |
14 |
15 | def get_docstring_and_rest(filename):
16 | """Separate `filename` content between docstring and the rest
17 |
18 | Strongly inspired from ast.get_docstring.
19 |
20 | Returns
21 | -------
22 | docstring: str
23 | docstring of `filename`
24 | rest: str
25 | `filename` content without the docstring
26 | """
27 | # can't use codecs.open(filename, 'r', 'utf-8') here b/c ast doesn't
28 | # seem to work with unicode strings in Python2.7
29 | # "SyntaxError: encoding declaration in Unicode string"
30 | with open(filename, 'rb') as f:
31 | content = f.read()
32 | # change from Windows format to UNIX for uniformity
33 | content = content.replace(b'\r\n', b'\n')
34 |
35 | node = ast.parse(content)
36 | if not isinstance(node, ast.Module):
37 | raise TypeError("This function only supports modules. "
38 | "You provided {0}".format(node.__class__.__name__))
39 | if node.body and isinstance(node.body[0], ast.Expr) and \
40 | isinstance(node.body[0].value, ast.Str):
41 | docstring_node = node.body[0]
42 | docstring = docstring_node.value.s
43 | if hasattr(docstring, 'decode'): # python2.7
44 | docstring = docstring.decode('utf-8')
45 |         # This gets the content of the file after the docstring's last line
46 | # Note: 'maxsplit' argument is not a keyword argument in python2
47 | rest = content.decode('utf-8').split('\n', docstring_node.lineno)[-1]
48 | return docstring, rest
49 | else:
50 | raise ValueError(('Could not find docstring in file "{0}". '
51 | 'A docstring is required by sphinx-gallery')
52 | .format(filename))
53 |
54 |
55 | def split_code_and_text_blocks(source_file):
56 | """Return list with source file separated into code and text blocks.
57 |
58 | Returns
59 | -------
60 | blocks : list of (label, content)
61 | List where each element is a tuple with the label ('text' or 'code'),
62 | and content string of block.
63 | """
64 | docstring, rest_of_content = get_docstring_and_rest(source_file)
65 | blocks = [('text', docstring)]
66 |
67 | pattern = re.compile(
68 |         r'(?P<header_line>^#{20,}.*)\s(?P<text_content>(?:^#.*\s)*)',
69 | flags=re.M)
70 |
71 | pos_so_far = 0
72 | for match in re.finditer(pattern, rest_of_content):
73 | match_start_pos, match_end_pos = match.span()
74 | code_block_content = rest_of_content[pos_so_far:match_start_pos]
75 | text_content = match.group('text_content')
76 | sub_pat = re.compile('^#', flags=re.M)
77 | text_block_content = dedent(re.sub(sub_pat, '', text_content)).lstrip()
78 | if code_block_content.strip():
79 | blocks.append(('code', code_block_content))
80 | if text_block_content.strip():
81 | blocks.append(('text', text_block_content))
82 | pos_so_far = match_end_pos
83 |
84 | remaining_content = rest_of_content[pos_so_far:]
85 | if remaining_content.strip():
86 | blocks.append(('code', remaining_content))
87 |
88 | return blocks
89 |
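A sketch of how the 20-hash separator convention plays out; plot_example.py is hypothetical and must start with a module docstring:

    from sphinx_gallery.py_source_parser import split_code_and_text_blocks

    blocks = split_code_and_text_blocks('plot_example.py')
    for label, content in blocks:
        # Typical order for a gallery script: 'text' for the module docstring,
        # then alternating 'code' blocks and 'text' blocks, one 'text' block
        # per run of '#' comment lines introduced by a line of 20 or more '#'.
        print(label, len(content))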
--------------------------------------------------------------------------------
/doc/user_guide.rst:
--------------------------------------------------------------------------------
1 | .. title:: User guide: contents
2 |
3 | ..
4 | We are putting the title as a raw HTML so that it doesn't appear in
5 | the contents
6 |
7 | .. raw:: html
8 |
 9 |     <h1>User guide: contents</h1>
10 |
11 | .. _user_guide:
12 |
13 | .. toctree::
14 | :numbered:
15 |
16 | overview
17 | auto_examples/index
18 | Reference
--------------------------------------------------------------------------------
/examples/README.txt:
--------------------------------------------------------------------------------
1 | .. _general_examples:
2 |
3 | General examples
4 | ----------------
5 |
6 | General-purpose and introductory examples for pyprox.
7 |
--------------------------------------------------------------------------------
/examples/plot_l1_constraints_dr.py:
--------------------------------------------------------------------------------
1 | """
2 | ===================================
3 | Basis Pursuit with Douglas Rachford
4 | ===================================
5 |
6 | Test DR for a standard constrained l1-minimization
7 | """
8 | # Author: Samuel Vaiter
9 | from __future__ import print_function, division
10 | print(__doc__)
11 |
12 | # modules
13 | import time
14 |
15 | import numpy as np
16 | import scipy.linalg as lin
17 | import matplotlib.pylab as plt
18 |
19 | from pyprox import douglas_rachford
20 | from pyprox.operators import soft_thresholding
21 | from pyprox.context import Context
22 |
23 | # Dimension of the problem
24 | n = 500
25 | p = n // 4
26 |
27 | # Matrix and observations
28 | A = np.random.randn(p, n)
29 | y = np.random.randn(p, 1)
30 |
31 | # operator callbacks
32 | prox_f = soft_thresholding
33 | prox_g = lambda x, tau: x + np.dot(A.T, lin.solve(np.dot(A, A.T),
34 | y - np.dot(A, x)))
35 |
36 | # context
37 | ctx = Context(full_output=True, maxiter=1000)
38 | ctx.callback = lambda x: lin.norm(x, 1)
39 |
40 | t1 = time.time()
41 | x, fx = douglas_rachford(prox_f, prox_g, np.zeros((n, 1)), context=ctx)
42 | t2 = time.time()
43 | print("Performed 1000 iterations in " + str(t2 - t1) + " seconds.")
44 |
45 | plt.plot(fx)
46 | plt.show()
47 |
--------------------------------------------------------------------------------
/examples/plot_l1_lagrangian_fb.py:
--------------------------------------------------------------------------------
1 | """
2 | =================================================================
3 | Basis Pursuit Denoising with Forward-Backward : CS Regularization
4 | =================================================================
5 |
6 | Test the use of forward-backward-like splitting for the resolution of a
7 | compressed sensing regularization
8 | """
9 | # Author: Samuel Vaiter
10 | from __future__ import print_function, division
11 | print(__doc__)
12 |
13 | # modules
14 | import time
15 |
16 | import numpy as np
17 | import scipy.linalg as lin
18 | import matplotlib.pylab as plt
19 |
20 | from pyprox import forward_backward, soft_thresholding
21 | from pyprox.context import Context
22 |
23 | n = 600
24 | p = n // 4
25 | la = 1.0 # regularization parameter
26 |
27 | # Matrix and observations
28 | A = np.random.randn(p, n)
29 | y = np.random.randn(p, 1)
30 |
31 | # List of benchmarked algorithms
32 | methods = ['fb', 'fista', 'nesterov']
33 |
34 | # operator callbacks
35 | F = lambda x: la * lin.norm(x, 1)
36 | G = lambda x: 1 / 2 * lin.norm(y - np.dot(A, x)) ** 2
37 | prox_f = lambda x, tau: soft_thresholding(x, la * tau)
38 | grad_g = lambda x: np.dot(A.T, np.dot(A, x) - y)
39 |
40 | L = lin.norm(A, 2) ** 2 # Lipschitz constant
41 |
42 | # context
43 | maxiter = 1000
44 | ctx = Context(full_output=True, maxiter=maxiter)
45 | ctx.callback = lambda x: F(x) + G(x)
46 |
47 | res = np.zeros((maxiter, len(methods)))
48 | i = 0
49 | for method in methods:
50 | t1 = time.time()
51 | x, fx = forward_backward(prox_f, grad_g, np.zeros((n, 1)), L,
52 | method=method, context=ctx)
53 | t2 = time.time()
54 |     print("[" + method + "]: Performed 1000 iterations in " +
55 |           str(t2 - t1) + " seconds.")
56 | res[:, i] = fx
57 | i += 1
58 |
59 | e = np.min(res.flatten())
60 |
61 | plt.loglog(res[:(maxiter // 10), :] - e)
62 | plt.legend(methods)
63 | plt.grid(True, which="both", ls="-")
64 | plt.tight_layout()
65 | plt.show()
66 |
--------------------------------------------------------------------------------
/examples/plot_tv_denoising_lena.py:
--------------------------------------------------------------------------------
1 | """
2 | ==============================================
3 | Total variation denoising using Chambolle Pock
4 | ==============================================
5 |
6 | Test the use of ADMM for a denoising scenario with anisotropic TV
7 | """
8 | # Author: Samuel Vaiter
9 | from __future__ import print_function, division
10 |
11 | print(__doc__)
12 |
13 | import time
14 | import numpy as np
15 | from scipy import misc
16 | import scipy.linalg as lin
17 | import pylab as plt
18 |
19 | from pyprox import dual_prox, admm
20 | from pyprox.operators import soft_thresholding
21 | from pyprox.context import Context
22 |
23 | # Load image, downsample and convert to a float
24 | im = misc.face()[:,:,0]
25 | im = misc.imresize(im, (256, 256)).astype(np.float) / 255.
26 |
27 | n = im.shape[0]
28 |
29 | # Noisy observations
30 | sigma = 0.06
31 | y = im + sigma * np.random.randn(n, n)
32 |
33 | # Regularization parameter
34 | alpha = 0.1
35 |
36 | # Gradient and divergence with periodic boundaries
37 |
38 |
39 | def gradient(x):
40 | g = np.zeros((x.shape[0], x.shape[1], 2))
41 | g[:, :, 0] = np.roll(x, -1, axis=0) - x
42 | g[:, :, 1] = np.roll(x, -1, axis=1) - x
43 | return g
44 |
45 |
46 | def divergence(p):
47 | px = p[:, :, 0]
48 | py = p[:, :, 1]
49 | resx = px - np.roll(px, 1, axis=0)
50 | resy = py - np.roll(py, 1, axis=1)
51 | return -(resx + resy)
52 |
53 | # Minimization of F(K*x) + G(x)
54 | K = gradient
55 | K.T = divergence
56 | amp = lambda u: np.sqrt(np.sum(u ** 2, axis=2))
57 | F = lambda u: alpha * np.sum(amp(u))
58 | G = lambda x: 1 / 2 * lin.norm(y - x, 'fro') ** 2
59 |
60 | # Proximity operators
61 | normalize = lambda u: u / np.tile(
62 | (np.maximum(amp(u), 1e-10))[:, :, np.newaxis],
63 | (1, 1, 2))
64 | prox_f = lambda u, tau: np.tile(
65 | soft_thresholding(amp(u), alpha * tau)[:, :, np.newaxis],
66 | (1, 1, 2)) * normalize(u)
67 | prox_fs = dual_prox(prox_f)
68 | prox_g = lambda x, tau: (x + tau * y) / (1 + tau)
69 |
70 |
71 | # context
72 | ctx = Context(full_output=True, maxiter=300)
73 | ctx.callback = lambda x: G(x) + F(K(x))
74 |
75 | t1 = time.time()
76 | x_rec, cx = admm(prox_fs, prox_g, K, y, context=ctx)
77 | t2 = time.time()
78 | print("Performed 300 iterations in " + str(t2 - t1) + " seconds.")
79 |
80 |
81 | plt.subplot(221)
82 | plt.imshow(im, cmap='gray')
83 | plt.title('Original')
84 | plt.axis('off')
85 | plt.subplot(222)
86 | plt.imshow(y, cmap='gray')
87 | plt.title('Noisy')
88 | plt.axis('off')
89 | plt.subplot(223)
90 | plt.imshow(x_rec, cmap='gray')
91 | plt.title('TV Regularization')
92 | plt.axis('off')
93 | plt.subplot(224)
94 | fplot = plt.plot(cx)
95 | plt.title('Objective versus iterations')
96 | plt.show()
97 |
--------------------------------------------------------------------------------
/pyprox/__init__.py:
--------------------------------------------------------------------------------
1 | from pyprox.algorithms import douglas_rachford, forward_backward, \
2 | forward_backward_dual, admm
3 | from pyprox.operators import soft_thresholding, dual_prox
4 |
5 | __all__ = ['algorithms', 'datasets', 'utils']
6 | __version__ = "0.1"
--------------------------------------------------------------------------------
/pyprox/algorithms.py:
--------------------------------------------------------------------------------
1 | """
2 | The :mod:`pyprox.algorithms` module includes the proximal schemes of pyprox.
3 | """
4 | # Author: Samuel Vaiter
5 |
6 | from __future__ import division
7 | import numpy as np
8 | import math
9 | from .utils import operator_norm
10 | from .context import Context, defaultContext
11 |
12 |
13 | def douglas_rachford(prox_f, prox_g, x0,
14 | mu=1, gamma=1, context=defaultContext):
15 |     """Minimize the sum of two functions using the Douglas-Rachford
16 |     splitting scheme.
17 |
18 | This algorithm assumes that F, G are both "proximable" where the
19 | optimization objective reads::
20 |
21 | F(x) + G(x)
22 |
23 | Parameters
24 | ----------
25 | prox_f : callable
26 | should take two arguments : an ndarray and a float.
27 | prox_g : callable
28 | same as prox_f.
29 | x0 : ndarray
30 | initial guess for the solution.
31 | mu : float, optional
32 | gamma : float, optional
33 | context: Context
34 | the context (default to defaultContext)
35 |
36 | Returns
37 | -------
38 | x_rec: ndarray
39 | fx: list
40 |
41 | References
42 | ----------
43 | Proximal Splitting Methods in Signal Processing,
44 | Patrick L. Combettes and Jean-Christophe Pesquet, in:
45 | Fixed-Point Algorithms for Inverse Problems in Science and Engineering,
46 | New York: Springer-Verlag, 2010.
47 | """
48 | def rProx_f(x, tau):
49 | return 2 * prox_f(x, tau) - x
50 |
51 | def rProx_g(x, tau):
52 | return 2 * prox_g(x, tau) - x
53 |
54 | x = x0.copy()
55 | y = x0.copy()
56 |
57 | def step(x, y):
58 | y = (1 - mu / 2) * y + mu / 2 * rProx_f(rProx_g(y, gamma), gamma)
59 | x = prox_g(y, gamma)
60 | return [x, y]
61 |
62 | return context.execute([x, y], step)
63 |
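A minimal usage sketch, assuming the default Context (100 iterations): split an l1 penalty F(x) = ||x||_1 against a quadratic data term G(x) = 0.5 * ||x - z||^2, whose proximal operator has the closed form used below.

    import numpy as np
    from pyprox import douglas_rachford, soft_thresholding

    z = np.array([[3.0], [-0.2], [0.0]])
    prox_f = soft_thresholding                          # prox of the l1 norm
    prox_g = lambda x, tau: (x + tau * z) / (1 + tau)   # prox of 0.5*||x - z||^2

    x_min = douglas_rachford(prox_f, prox_g, np.zeros_like(z))
    # the minimizer is the soft thresholding of z at level 1,
    # i.e. approximately [[2.0], [0.0], [0.0]]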
64 |
65 | def forward_backward(prox_f, grad_g, x0, L,
66 | method='fb', fbdamping=1.8, context=defaultContext):
67 |     """Minimize the sum of two functions using the forward-backward
68 |     splitting scheme.
69 |
70 |     This algorithm assumes that F is "proximable" and that G has an
71 |     L-Lipschitz gradient, where the optimization objective reads::
72 |
73 | F(x) + G(x)
74 |
75 | Parameters
76 | ----------
77 | prox_f : callable
78 | should take two arguments : an ndarray and a float.
79 | grad_g : callable
80 | same as prox_f.
81 | x0 : ndarray
82 | initial guess for the solution.
83 | L : float
84 |         Lipschitz constant of nabla G.
85 | method : string, optional,
86 | can be 'fb', 'fista' or 'nesterov'
87 | fbdamping : float, optional
88 | context: Context
89 | the context (default to defaultContext)
90 |
91 | Returns
92 | -------
93 | x_rec: ndarray
94 | fx: list
95 |
96 | References
97 | ----------
98 | P. L. Combettes and V. R. Wajs, Signal recovery by proximal
99 | forward-backward splitting,
100 | Multiscale Model. Simul., 4 (2005), pp. 1168-1200
101 | """
102 | # FISTA
103 | t = 1
104 |
105 | # Nesterov
106 | tt = 2 / L
107 | gg = 0
108 | A = 0
109 |
110 | y = x0.copy()
111 | x = x0.copy()
112 |
113 | def step_fb(x):
114 | x = prox_f(x - fbdamping / L * grad_g(x), fbdamping / L)
115 | return [x]
116 |
117 | def step_fista(x, y, t):
118 | xnew = prox_f(y - 1 / L * grad_g(y), 1 / L)
119 | tnew = (1 + math.sqrt(1 + 4 * t ** 2)) / 2
120 | y = xnew + (t - 1) / tnew * (xnew - x)
121 | x = xnew
122 | t = tnew
123 | return [x, y, t]
124 |
125 | def step_nesterov(x, tt, gg, A):
126 | a = (tt + math.sqrt(tt ** 2 + 4 * tt * A)) / 2
127 | v = prox_f(x0 - gg, A)
128 | z = (A * x + a * v) / (A + a)
129 | x = prox_f(z - 1 / L * grad_g(z), 1 / L)
130 | gg += a * grad_g(x)
131 | A += a
132 | return [x, tt, gg, A]
133 |
134 | if method == "fb":
135 | return context.execute([x], step_fb)
136 | elif method == "fista":
137 | return context.execute([x, y, t], step_fista)
138 | elif method == "nesterov":
139 | return context.execute([x, tt, gg, A], step_nesterov)
140 | else:
141 |         raise ValueError("method must be 'fb', 'fista' or 'nesterov'")
142 |
143 |
144 | def forward_backward_dual(grad_fs, prox_gs, K, x0, L,
145 | method='fb', fbdamping=1.8, context=defaultContext):
146 |     """Minimize the sum of a strongly convex function and a proper convex
147 | function.
148 |
149 | This algorithm minimizes
150 |
151 | F(x) + G(K(x))
152 |
153 | where F is strongly convex, G is a proper convex function and K is a
154 | linear operator by a duality argument.
155 |
156 | Parameters
157 | ----------
158 | grad_fs : callable
159 | should take one argument : an ndarray.
160 | prox_gs : callable
161 | should take two arguments : an ndarray and a float.
162 |     K : callable or ndarray
163 |         a linear operator; if callable, its adjoint must be exposed as
164 |         ``K.T`` (an ndarray is wrapped automatically using its transpose)
166 | x0 : ndarray
167 | initial guess for the solution.
168 | L : float
169 |         Lipschitz constant of nabla G.
170 | method : string, optional,
171 | can be 'fb', 'fista' or 'nesterov'
172 | fbdamping : float, optional
173 | context: Context
174 | the context (default to defaultContext)
175 |
176 | Returns
177 | -------
178 | x_rec: ndarray
179 | fx: list
180 |
181 | Notes
182 | -----
183 | This algorithm use the equivalence of
184 |
185 | min_x F(x) + G(K(x)) (*)
186 |
187 | with
188 |
189 |     min_u F^*(-K^*(u)) + G^*(u) (**)
190 |
191 |     using x = grad(F^*)(-K^*(u)), where the convex conjugate is
192 |
193 |     F^*(y) = sup_x <y, x> - F(x)
194 |
195 | It uses `forward_backward` as a solver of (**)
196 | """
197 | if isinstance(K, np.ndarray):
198 | op = lambda u: np.dot(K, u)
199 | op.T = lambda u: np.dot(K.T, u)
200 | return forward_backward_dual(
201 | grad_fs, prox_gs, op, x0, L,
202 | method=method, fbdamping=fbdamping,
203 | context=context)
204 |
205 | if context.callback:
206 | old_callback = context.callback
207 | context.callback = lambda u: old_callback(grad_fs(-K.T(u)))
208 | new_grad = lambda u: - K(grad_fs(-K.T(u)))
209 | u0 = K(x0)
210 | res = forward_backward(
211 | prox_gs, new_grad, u0, L,
212 | method=method, fbdamping=fbdamping,
213 | context=context)
214 |
215 | try:
216 | res[0] = grad_fs(-K.T(res[0]))
217 | except:
218 | res = grad_fs(-K.T(res))
219 | return res
220 |
221 |
222 | def admm(prox_fs, prox_g, K, x0,
223 | theta=1, sigma=None, tau=None, context=defaultContext):
224 | """Minimize an optimization problem using the Preconditioned Alternating
225 | direction method of multipliers
226 |
227 | This algorithm assumes that F, G are both "proximable" where the
228 | optimization objective reads::
229 |
230 | F(K(x)) + G(x)
231 |
232 | where K is a linear operator.
233 |
234 | Parameters
235 | ----------
236 | prox_fs : callable
237 | should take two arguments : an ndarray and a float.
238 | prox_g : callable
239 | same as prox_f.
240 |     K : callable or ndarray
241 |         a linear operator; if callable, its adjoint must be exposed as
242 |         ``K.T`` (an ndarray is wrapped automatically using its transpose)
244 | x0 : ndarray
245 | initial guess for the solution.
246 | theta : float, optional
247 | sigma : float, optional
248 | parameters of the method.
249 |         They should satisfy sigma * tau * norm(K)^2 < 1
250 | context: Context
251 | the context (default to defaultContext)
252 |
253 | Returns
254 | -------
255 | x_rec: ndarray
256 | fx: list
257 |
258 | References
259 | ----------
260 | A. Chambolle and T. Pock,
261 | A First-Order Primal-Dual Algorithm for Convex Problems
262 | with Applications to Imaging,
263 | JOURNAL OF MATHEMATICAL IMAGING AND VISION
264 | Volume 40, Number 1 (2011)
265 | """
266 | if isinstance(K, np.ndarray):
267 | op = lambda u: np.dot(K, u)
268 | op.T = lambda u: np.dot(K.T, u)
269 | return admm(prox_fs, prox_g, op, x0, theta=theta,
270 | sigma=sigma, tau=tau, context=context)
271 | if not(sigma and tau):
272 | L = operator_norm(
273 | lambda x: K.T(K(x)),
274 | np.random.randn(x0.shape[0], 1)
275 | )
276 | sigma = 10.0
277 | if sigma * L > 1e-10:
278 | tau = .9 / (sigma * L)
279 | else:
280 | tau = 0.0
281 |
282 | x = x0.copy()
283 | x1 = x0.copy()
284 | xold = x0.copy()
285 | y = K(x)
286 |
287 | def step(x, x1, xold, y):
288 | xold = x.copy()
289 | y = prox_fs(y + sigma * K(x1), sigma)
290 | x = prox_g(x - tau * K.T(y), tau)
291 | x1 = x + theta * (x - xold)
292 | return [x, x1, xold, y]
293 |
294 | return context.execute([x, x1, xold, y], step)
295 |
--------------------------------------------------------------------------------
/pyprox/context.py:
--------------------------------------------------------------------------------
1 | """
2 | The :mod:`pyprox.context` module includes the definition of severals contexts.
3 | """
4 | # Author: Samuel Vaiter
5 |
6 |
7 | def maxiter_criterion(values, iteration, allvecs, fx, maxiter=100):
8 | return iteration < maxiter
9 |
10 |
11 | class Context(object):
12 | """A Context object
13 | maxiter : int, optional
14 | maximum number of iterations.
15 | full_output : bool, optional
16 | non-zero to return all optional outputs.
17 | retall : bool, optional
18 | Return a list of results at each iteration if non-zero.
19 | callback : callable, optional
20 | An optional user-supplied function to call after each iteration.
21 | Called as callback(xk), where xk is the current parameter vector.
22 | """
23 | def __init__(self, criterion=maxiter_criterion,
24 | full_output=False, retall=False, callback=None, **kwargs):
25 | self.criterion = criterion
26 | if 'maxiter' in kwargs:
27 | self.criterion = lambda v, i, a, f: maxiter_criterion(v, i, a, f, maxiter=kwargs['maxiter'])
28 | self.full_output = full_output
29 | self.retall = retall
30 | self.callback = callback
31 |
32 | def execute(self, values, step):
33 | allvecs = [values[0]]
34 | fx = []
35 | iteration = 0
36 | while self.criterion(values, iteration, allvecs, fx):
37 | values = step(*values)
38 | x = values[0]
39 | if self.full_output:
40 | pass
41 | if self.retall:
42 | allvecs.append(x)
43 | if self.callback:
44 | fx.append(self.callback(x))
45 | iteration += 1
46 | return self._output_helper(x, fx, allvecs)
47 |
48 | def _output_helper(self, x, fx, allvecs):
49 | if self.full_output:
50 | retlist = x, fx
51 | if self.retall:
52 | retlist += (allvecs,)
53 | else:
54 | retlist = x
55 | if self.retall:
56 | retlist = (x, allvecs)
57 |
58 | return retlist
59 |
60 |
61 | defaultContext = Context()
62 |
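A small sketch of driving an iteration through a Context; the step function and target below are illustrative:

    import numpy as np
    from pyprox.context import Context

    target = np.ones((3, 1))
    ctx = Context(full_output=True, maxiter=50)
    ctx.callback = lambda x: float(np.linalg.norm(x - target))

    def step(x):
        return [(x + target) / 2]      # halve the distance to the target

    x_final, fx = ctx.execute([np.zeros((3, 1))], step)
    # fx holds the callback value after each of the 50 iterations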
--------------------------------------------------------------------------------
/pyprox/operators.py:
--------------------------------------------------------------------------------
1 | """
2 | Proximal operators
3 | """
4 | # Author: Samuel Vaiter
5 |
6 | from __future__ import division
7 | import numpy as np
8 |
9 |
10 | def soft_thresholding(x, gamma):
11 | return np.maximum(0, 1 - gamma / np.maximum(np.abs(x), 1E-10)) * x
12 |
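A worked check of the formula above: entries with |x| <= gamma are set to zero, larger entries shrink toward zero by gamma.

    import numpy as np
    from pyprox.operators import soft_thresholding

    x = np.array([3.0, -0.5, 0.2])
    print(soft_thresholding(x, 1.0))   # -> [2., -0., 0.], effectively [2, 0, 0]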
13 |
14 | def dual_prox(prox):
15 | return lambda u, sigma: u - sigma * prox(u / sigma, 1 / sigma)
16 |
--------------------------------------------------------------------------------
/pyprox/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svaiter/pyprox/ffc3084a2478536fec808273e16bd7f22e6a9e3c/pyprox/tests/__init__.py
--------------------------------------------------------------------------------
/pyprox/tests/test_admm.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from numpy.testing import assert_array_almost_equal
3 |
4 | import numpy as np
5 | from pyprox.algorithms import admm
6 |
7 |
8 | def test_admm_virtual_zero():
9 | # Virtual 0-prox
10 | prox_fs = lambda u, la: u * 0
11 | prox_g = lambda u, la: u * 0
12 |
13 | # ndarray
14 | k_nd = np.zeros((5, 5))
15 | # explicit
16 | k_exp = lambda u: 0 * u
17 | k_exp.T = lambda u: 0 * u
18 |
19 | # observations of size (5,1)
20 | y = np.zeros((5, 1))
21 | x_rec = admm(prox_fs, prox_g, k_nd, y)
22 | assert_array_almost_equal(y, x_rec)
23 | x_rec = admm(prox_fs, prox_g, k_exp, y)
24 | assert_array_almost_equal(y, x_rec)
25 |
26 | # observations of size (5,2)
27 | y = np.zeros((5, 2))
28 | x_rec = admm(prox_fs, prox_g, k_nd, y)
29 | assert_array_almost_equal(y, x_rec)
30 | x_rec = admm(prox_fs, prox_g, k_exp, y)
31 | assert_array_almost_equal(y, x_rec)
32 |
--------------------------------------------------------------------------------
/pyprox/tests/test_douglas_rachford.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from numpy.testing import assert_array_almost_equal
3 |
4 | import numpy as np
5 | import scipy.linalg as lin
6 | from pyprox.algorithms import douglas_rachford
7 | from pyprox.operators import soft_thresholding
8 |
9 |
10 | def test_dr_virtual_zero():
11 | # Virtual 0-prox
12 |     prox_f = lambda u, la: u * 0
13 | prox_g = lambda u, la: u * 0
14 |
15 | # observations of size (5,1)
16 | y = np.zeros((5, 1))
17 | x_rec = douglas_rachford(prox_f, prox_g, y)
18 | assert_array_almost_equal(y, x_rec)
19 |
20 | # observations of size (5,2)
21 | y = np.zeros((5, 2))
22 | x_rec = douglas_rachford(prox_f, prox_g, y)
23 | assert_array_almost_equal(y, x_rec)
24 |
25 |
26 | def test_dr_zero():
27 | # Prox of F, G = 0
28 | prox_f = lambda u, la: u
29 | prox_g = lambda u, la: u
30 |
31 | # observations of size (5,1)
32 | y = np.zeros((5, 1))
33 | x_rec = douglas_rachford(prox_f, prox_g, y)
34 | assert_array_almost_equal(y, x_rec)
35 |
36 | # observations of size (5,2)
37 | y = np.zeros((5, 2))
38 | x_rec = douglas_rachford(prox_f, prox_g, y)
39 | assert_array_almost_equal(y, x_rec)
40 |
41 |
42 | def test_dr_l1_cs():
43 | # Dimension of the problem
44 | n = 200
45 | p = n // 4
46 |
47 | # Matrix and observations
48 | A = np.random.randn(p, n)
49 | # Use a very sparse vector for the test
50 | x = np.zeros((n, 1))
51 | x[1, :] = 1
52 | y = np.dot(A, x)
53 |
54 | # operator callbacks
55 | prox_f = soft_thresholding
56 | prox_g = lambda x, tau: x + np.dot(A.T, lin.solve(np.dot(A, A.T),
57 | y - np.dot(A, x)))
58 |
59 | x_rec = douglas_rachford(prox_f, prox_g, np.zeros((n, 1)))
60 | assert_array_almost_equal(x, x_rec)
61 |
--------------------------------------------------------------------------------
/pyprox/tests/test_forward_backward.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from numpy.testing import assert_array_almost_equal
3 |
4 | import numpy as np
5 | from pyprox.algorithms import forward_backward
6 | from pyprox.operators import soft_thresholding
7 |
8 | methods = ['fb', 'fista', 'nesterov']
9 |
10 |
11 | def test_fb_virtual_zero():
12 | # Virtual 0-prox
13 |     prox_f = lambda u, la: u * 0
14 | grad_g = lambda u: u * 0
15 |
16 | # observations of size (5,1)
17 | y = np.zeros((5, 1))
18 | for method in methods:
19 | x_rec = forward_backward(prox_f, grad_g, y, 1, method=method)
20 | assert_array_almost_equal(y, x_rec)
21 |
22 | # observations of size (5,2)
23 | y = np.zeros((5, 2))
24 | for method in methods:
25 | x_rec = forward_backward(prox_f, grad_g, y, 1, method=method)
26 | assert_array_almost_equal(y, x_rec)
27 |
28 |
29 | def test_fb_zero():
30 | prox_f = lambda u, la: u
31 | grad_g = lambda u: u * 0
32 |
33 | # observations of size (5,1)
34 | y = np.zeros((5, 1))
35 | for method in methods:
36 | x_rec = forward_backward(prox_f, grad_g, y, 1, method=method)
37 | assert_array_almost_equal(y, x_rec)
38 |
39 | # observations of size (5,2)
40 | y = np.zeros((5, 2))
41 | for method in methods:
42 | x_rec = forward_backward(prox_f, grad_g, y, 1, method=method)
43 | assert_array_almost_equal(y, x_rec)
44 |
45 |
46 | def test_fb_l1_denoising():
47 | n = 1000
48 | # Use a very sparse vector for the test
49 | x = np.zeros((n, 1))
50 | x[1, :] = 100
51 | y = x + 0.06 * np.random.randn(n, 1)
52 |
53 | la = 0.2
54 | prox_f = lambda x, tau: soft_thresholding(x, la * tau)
55 | grad_g = lambda x: x - y
56 |
57 | for method in methods:
58 | x_rec = forward_backward(prox_f, grad_g, y, 1, method=method)
59 | #TODO ugly test to change
60 | assert_array_almost_equal(x, x_rec, decimal=0)
61 |
--------------------------------------------------------------------------------
/pyprox/tests/test_forward_backward_dual.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from numpy.testing import assert_array_almost_equal
3 |
4 | import numpy as np
5 | from pyprox.algorithms import forward_backward_dual
6 |
7 |
8 | def test_fb_dual_virtual_zero():
9 | # Virtual 0-prox
10 | grad_fs = lambda u: u * 0
11 | prox_gs = lambda u, la: u * 0
12 |
13 | # ndarray
14 | k_nd = np.zeros((5, 5))
15 | # explicit
16 | k_exp = lambda u: 0 * u
17 | k_exp.T = lambda u: 0 * u
18 |
19 | # observations of size (5,1)
20 | y = np.zeros((5, 1))
21 | x_rec = forward_backward_dual(grad_fs, prox_gs, k_nd, y, 1)
22 | assert_array_almost_equal(y, x_rec)
23 | x_rec = forward_backward_dual(grad_fs, prox_gs, k_exp, y, 1)
24 | assert_array_almost_equal(y, x_rec)
25 |
26 | # observations of size (5,2)
27 | y = np.zeros((5, 2))
28 | x_rec = forward_backward_dual(grad_fs, prox_gs, k_nd, y, 1)
29 | assert_array_almost_equal(y, x_rec)
30 | x_rec = forward_backward_dual(grad_fs, prox_gs, k_exp, y, 1)
31 | assert_array_almost_equal(y, x_rec)
32 |
--------------------------------------------------------------------------------
/pyprox/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Misc utils
3 | """
4 | # Author: Samuel Vaiter
5 |
6 | from __future__ import division
7 | import numpy as np
8 | import scipy.linalg as lin
9 |
10 |
11 | def operator_norm(linop, n=None, maxiter=30, check=False):
12 | if hasattr(linop, 'norm') and not check:
13 | return linop.norm
14 | if n is None:
15 | n = np.random.randn(linop.dim[1], 1)
16 | if np.size(n) == 1:
17 | u = np.random.randn(n, 1)
18 | else:
19 | u = n
20 | unorm = lin.norm(u)
21 | if unorm > 1e-10:
22 | u = u / unorm
23 | else:
24 | return 0
25 | e = []
26 | for i in range(maxiter):
27 | if hasattr(linop, 'T'):
28 | v = linop.T(linop(u))
29 | else:
30 | # assume square (implicit) operator
31 | v = linop(u)
32 | e.append((u[:] * v[:]).sum())
33 | vnorm = lin.norm(v[:])
34 | if vnorm > 1e-10:
35 | u = v / vnorm
36 | else:
37 | return 0
38 | L = e[-1]
39 | return L
40 |
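
operator_norm is a power iteration: when the operator exposes an adjoint .T it iterates linop.T(linop(u)), so the returned value estimates the largest eigenvalue of linop^T linop, i.e. the squared spectral norm of linop; without an adjoint it estimates the largest eigenvalue of linop itself, which is how algorithms.py uses it (passing the composition K.T(K(.)) directly). A quick check, not part of the module (matrix size and iteration count are arbitrary):

import numpy as np
import scipy.linalg as lin
from pyprox.utils import operator_norm

A = np.random.randn(20, 50)
op = lambda u: np.dot(A, u)
op.T = lambda u: np.dot(A.T, u)

estimate = operator_norm(op, np.random.randn(50, 1), maxiter=100)
exact = lin.norm(A, 2) ** 2       # squared largest singular value of A
print(estimate / exact)           # should be close to 1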
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | #
3 | descr = """Python module for proximal algorithms"""
4 |
5 | import sys
6 | import os
7 | import shutil
8 |
9 | DISTNAME = 'pyprox'
10 | DESCRIPTION = 'Python module for proximal algorithms'
11 | LONG_DESCRIPTION = open('README.rst').read()
12 | MAINTAINER = 'Samuel Vaiter'
13 | MAINTAINER_EMAIL = 'samuel.vaiter@gmail.com'
14 | URL = 'http://svaiter.github.com/pyprox'
15 | LICENSE = 'new BSD'
16 | DOWNLOAD_URL = 'http://github.com/svaiter/pyprox/downloads'
17 | VERSION = '0.1'
18 |
19 | import setuptools # we are using a setuptools namespace
20 | from numpy.distutils.core import setup
21 |
22 |
23 | def configuration(parent_package='', top_path=None):
24 | if os.path.exists('MANIFEST'):
25 | os.remove('MANIFEST')
26 |
27 | from numpy.distutils.misc_util import Configuration
28 | config = Configuration(None, parent_package, top_path,
29 | namespace_packages=['pyprox'])
30 |
31 | config.add_subpackage('pyprox')
32 | config.add_subpackage('pyprox/tests')
33 |
34 | return config
35 |
36 | if __name__ == "__main__":
37 |
38 | old_path = os.getcwd()
39 | local_path = os.path.dirname(os.path.abspath(sys.argv[0]))
40 | os.chdir(local_path)
41 | sys.path.insert(0, local_path)
42 |
43 | setup(configuration=configuration,
44 | name=DISTNAME,
45 | maintainer=MAINTAINER,
46 | include_package_data=True,
47 | maintainer_email=MAINTAINER_EMAIL,
48 | description=DESCRIPTION,
49 | license=LICENSE,
50 | url=URL,
51 | version=VERSION,
52 | download_url=DOWNLOAD_URL,
53 | long_description=LONG_DESCRIPTION,
54 | zip_safe=False, # the package can run out of an .egg file
55 | classifiers=[
56 | 'Intended Audience :: Science/Research',
57 | 'Intended Audience :: Developers',
58 | 'License :: OSI Approved',
59 | 'Programming Language :: C',
60 | 'Programming Language :: Python',
61 | 'Topic :: Software Development',
62 | 'Topic :: Scientific/Engineering',
63 | 'Operating System :: Microsoft :: Windows',
64 | 'Operating System :: POSIX',
65 | 'Operating System :: Unix',
66 | 'Operating System :: MacOS'
67 | ]
68 | )
69 |
--------------------------------------------------------------------------------