# -*- coding: utf-8 -*-
"""
Defines the :class:`kwplot.mpl_plotnums.PlotNums` class to help manage a grid
of subplot numbers.
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np
import six
[docs]class PlotNums(object):
"""
Convinience class for dealing with plot numberings (pnums)
This is useful in the case where you want a certain number of subplots, but
you might swap the order in which those subplots are called. This class
introduces the idea of either getting the "next" subplot, or getting one at
a specific instance. The total number of subplots can be modified in just a
single place in the code (the arguments to the ``PlotNums`` constructor)
instead of each instance where you would specify a pnum normally.
Example:
>>> import ubelt as ub
>>> pnum_ = PlotNums(nRows=2, nCols=2)
>>> # Indexable
>>> print(pnum_[0])
(2, 2, 1)
>>> # Iterable
>>> print(ub.repr2(list(pnum_), nl=0, nobr=1))
(2, 2, 1), (2, 2, 2), (2, 2, 3), (2, 2, 4)
>>> # Callable (iterates through a default iterator)
>>> print(pnum_())
(2, 2, 1)
>>> print(pnum_())
(2, 2, 2)
"""
def __init__(self, nRows=None, nCols=None, nSubplots=None, start=0):
nRows, nCols = self._get_num_rc(nSubplots, nRows, nCols)
self.nRows = nRows
self.nCols = nCols
base = 0
self.offset = 0 if base == 1 else 1
self.start = start
self._iter = None
def __getitem__(self, px):
return (self.nRows, self.nCols, px + self.offset)
def __call__(self):
"""
replacement for make_pnum_nextgen
Example:
>>> import ubelt as ub
>>> import itertools as it
>>> pnum_ = PlotNums(nSubplots=9)
>>> pnum_list = [pnum_() for _ in range(len(pnum_))]
>>> result = ('pnum_list = %s' % (ub.repr2(pnum_list),))
>>> print(result)
Example:
>>> import ubelt as ub
>>> import itertools as it
>>> for nRows, nCols, nSubplots in it.product([None, 3], [None, 3], [None, 9]):
>>> start = 0
>>> pnum_ = PlotNums(nRows, nCols, nSubplots, start)
>>> pnum_list = [pnum_() for _ in range(len(pnum_))]
>>> print((nRows, nCols, nSubplots))
>>> result = ('pnum_list = %s' % (ub.repr2(pnum_list),))
>>> print(result)
"""
if self._iter is None:
self._iter = iter(self)
return six.next(self._iter)
def __iter__(self):
r"""
Yields:
tuple : pnum
Example:
>>> import ubelt as ub
>>> pnum_ = iter(PlotNums(nRows=3, nCols=2))
>>> result = ub.repr2(list(pnum_), nl=1, nobr=1)
>>> print(result)
(3, 2, 1),
(3, 2, 2),
(3, 2, 3),
(3, 2, 4),
(3, 2, 5),
(3, 2, 6),
Example:
>>> import ubelt as ub
>>> nRows = 3
>>> nCols = 2
>>> pnum_ = iter(PlotNums(nRows, nCols, start=3))
>>> result = ub.repr2(list(pnum_), nl=1, nobr=1)
>>> print(result)
(3, 2, 4),
(3, 2, 5),
(3, 2, 6),
"""
for px in range(self.start, len(self)):
yield self[px]
def __len__(self):
total_plots = self.nRows * self.nCols
return total_plots
@classmethod
def _get_num_rc(PlotNums, nSubplots=None, nRows=None, nCols=None):
r"""
Gets a constrained row column plot grid
Args:
nSubplots (None): (default = None)
nRows (None): (default = None)
nCols (None): (default = None)
Returns:
tuple: (nRows, nCols)
Example:
>>> import ubelt as ub
>>> cases = [
>>> dict(nRows=None, nCols=None, nSubplots=None),
>>> dict(nRows=2, nCols=None, nSubplots=5),
>>> dict(nRows=None, nCols=2, nSubplots=5),
>>> dict(nRows=None, nCols=None, nSubplots=5),
>>> ]
>>> for kw in cases:
>>> print('----')
>>> size = PlotNums._get_num_rc(**kw)
>>> if kw['nSubplots'] is not None:
>>> assert size[0] * size[1] >= kw['nSubplots']
>>> print('**kw = %s' % (ub.repr2(kw),))
>>> print('size = %r' % (size,))
"""
if nSubplots is None:
if nRows is None:
nRows = 1
if nCols is None:
nCols = 1
else:
if nRows is None and nCols is None:
nRows, nCols = PlotNums._get_square_row_cols(nSubplots)
elif nRows is not None:
nCols = int(np.ceil(nSubplots / nRows))
elif nCols is not None:
nRows = int(np.ceil(nSubplots / nCols))
return nRows, nCols
@staticmethod
def _get_square_row_cols(nSubplots, max_cols=None, fix=False, inclusive=True):
r"""
Args:
nSubplots (int):
max_cols (int):
Returns:
tuple: (int, int)
Example:
>>> nSubplots = 9
>>> nSubplots_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
>>> max_cols = None
>>> rc_list = [PlotNums._get_square_row_cols(nSubplots, fix=True) for nSubplots in nSubplots_list]
>>> print(repr(np.array(rc_list).T))
array([[1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3],
[1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4]])
"""
if nSubplots == 0:
return 0, 0
if inclusive:
rounder = np.ceil
else:
rounder = np.floor
if fix:
# This function is very broken, but it might have dependencies
# this is the correct version
nCols = int(rounder(np.sqrt(nSubplots)))
nRows = int(rounder(nSubplots / nCols))
return nRows, nCols
else:
# This is the clamped num cols version
# probably used in ibeis.viz
if max_cols is None:
max_cols = 5
if nSubplots in [4]:
max_cols = 2
if nSubplots in [5, 6, 7]:
max_cols = 3
if nSubplots in [8]:
max_cols = 4
nCols = int(min(nSubplots, max_cols))
#nCols = int(min(rounder(np.sqrt(nrids)), 5))
nRows = int(rounder(nSubplots / nCols))
return nRows, nCols