# Source code for kwplot.mpl_draw

# -*- coding: utf-8 -*-
"""
Note, this module should be refactored into MPL figure drawings and cv2
on-image drawings.
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import copy
import numpy as np
import ubelt as ub

# Names exported by ``from kwplot.mpl_draw import *``.
__all__ = [
    'draw_boxes',
    'draw_line_segments',
    'plot_matrix',
    'draw_points',
    # draw_*_on_image functions are deprecated in favor of kwimage versions
    'draw_text_on_image',
    'draw_boxes_on_image',
    'draw_clf_on_image',
]


def draw_boxes(boxes, alpha=None, color='blue', labels=None, centers=False,
               fill=False, ax=None, lw=2):
    """
    Draw a set of boxes as rectangle patches on a matplotlib axes.

    Boxes that resolve to the same RGBA color are batched into a single
    ``PatchCollection``.

    Args:
        boxes (kwimage.Boxes): boxes to draw; read through
            ``boxes.to_xywh().data`` and ``boxes.xy_center``.
        alpha (float | List[float]): a single alpha applied to every box or
            one alpha per box. Defaults to 1.0 for every box.
        color: anything accepted by ``kwplot.Color``.
        labels (List[str]): text drawn at the (x, y) corner of each box.
        centers (bool | dict): if truthy, a circle is drawn at each box
            center; a dict is merged into the circle keyword arguments.
        fill (bool): if True the box face uses the edge color, otherwise the
            face is fully transparent.
        ax: axes to draw on, defaults to ``plt.gca()``.
        lw (float): linewidth

    Example:
        >>> import kwimage
        >>> bboxes = kwimage.Boxes([[.1, .1, .6, .3], [.3, .5, .5, .6]], 'xywh')
        >>> draw_boxes(bboxes)
        >>> #kwplot.autompl()
    """
    import kwplot
    import matplotlib as mpl
    from matplotlib import pyplot as plt
    if ax is None:
        ax = plt.gca()

    xywh = boxes.to_xywh().data
    num_boxes = len(xywh)

    clear_rgba = kwplot.Color((0, 0, 0, 0)).as01('rgba')

    # Normalize alpha to one value per box
    if alpha is None:
        alpha = [1.0] * num_boxes
    elif not ub.iterable(alpha):
        alpha = [alpha] * num_boxes

    # Fewer, larger patch collections are cheaper to draw, so bucket the box
    # indices by their final RGBA color.
    box_rgba = [kwplot.Color(color, alpha=a).as01('rgba') for a in alpha]
    color_groups = ub.group_items(range(len(box_rgba)), box_rgba)

    for rgba, group_idxs in color_groups.items():
        rectkw = dict(ec=rgba, fc=(rgba if fill else clear_rgba), lw=lw,
                      linestyle='solid')
        rects = []
        for x, y, w, h in xywh[group_idxs]:
            rects.append(mpl.patches.Rectangle((x, y), w, h, **rectkw))
        ax.add_collection(
            mpl.collections.PatchCollection(rects, match_original=True))

    if centers not in (None, False):
        centerkw = {
            # 'radius': 1,
            'fill': True
        }
        if isinstance(centers, dict):
            centerkw.update(centers)
        xy_centers = boxes.xy_center
        for rgba, group_idxs in color_groups.items():
            # TODO: radius based on size of bbox
            dots = []
            for x, y in xy_centers[group_idxs]:
                dots.append(
                    mpl.patches.Circle((x, y), ec=None, fc=rgba, **centerkw))
            ax.add_collection(
                mpl.collections.PatchCollection(dots, match_original=True))

    if labels:
        textkw = dict(
            horizontalalignment='left',
            verticalalignment='top',
            backgroundcolor=(0, 0, 0, .8),
            color='white',
            fontproperties=mpl.font_manager.FontProperties(
                size=6, family='monospace'),
        )
        # Pair every box corner with its label before drawing anything
        placements = [(x1, y1, label)
                      for (x1, y1, _w, _h), label in zip(xywh, labels)]
        for x1, y1, label in placements:
            ax.text(x1, y1, label, **textkw)
def draw_line_segments(pts1, pts2, ax=None, **kwargs):
    """
    Draw ``N`` line segments, one between each aligned pair of points.

    Args:
        pts1 (ndarray): Nx2 start points
        pts2 (ndarray): Nx2 end points
        ax (None): axes to draw on, defaults to ``plt.gca()``
        **kwargs: ``lw`` / ``linewidth`` (default 1.0), ``alpha``
            (default 1.0) and ``color`` / ``colors``; anything else is
            forwarded to ``LineCollection``.

    Example:
        >>> import numpy as np
        >>> import kwplot
        >>> pts1 = np.array([(.1, .8), (.6, .8)])
        >>> pts2 = np.array([(.6, .7), (.4, .1)])
        >>> kwplot.figure(fnum=None)
        >>> draw_line_segments(pts1, pts2)
        >>> # xdoc: +REQUIRES(--show)
        >>> import matplotlib.pyplot as plt
        >>> ax = plt.gca()
        >>> ax.set_xlim(0, 1)
        >>> ax.set_ylim(0, 1)
        >>> kwplot.show_if_requested()
    """
    import matplotlib.pyplot as plt
    import matplotlib as mpl
    if ax is None:
        ax = plt.gca()
    assert len(pts1) == len(pts2), 'unaligned'

    endpoint_pairs = list(zip(pts1, pts2))

    # ``lw`` wins over ``linewidth``; both keys are removed from kwargs
    fallback_width = kwargs.pop('linewidth', 1.0)
    width = kwargs.pop('lw', fallback_width)
    opacity = kwargs.pop('alpha', 1.0)

    if 'color' in kwargs:
        # Mirror the singular spelling onto the ``colors`` keyword
        kwargs['colors'] = kwargs['color']

    segment_collection = mpl.collections.LineCollection(
        endpoint_pairs, linewidths=width, alpha=opacity, **kwargs)
    ax.add_collection(segment_collection)
def plot_matrix(matrix, index=None, columns=None, rot=90, ax=None, grid=True,
                label=None, zerodiag=False, cmap='viridis', showvals=False,
                showzero=True, logscale=False, xlabel=None, ylabel=None,
                fnum=None, pnum=None):
    """
    Helper for plotting confusion matrices

    Rows of the matrix are drawn along the y-axis and columns along the
    x-axis (the layout used by ``Axes.matshow``).

    Args:
        matrix (ndarray | pd.DataFrame) : 2D data. If a data frame then
            index, columns, xlabel, and ylabel will be defaulted to sensible
            values.
        index (Sequence): labels for the rows (y-axis ticks).
        columns (Sequence): labels for the columns (x-axis ticks).
        rot (float): rotation applied to the x-axis tick labels.
        ax: axes to draw on. If None, a figure is requested from
            ``kwplot.figure(fnum=fnum, pnum=pnum)`` and cleared.
        grid (bool): if True, draw white lines around each cell.
        label (str): label for the colorbar.
        zerodiag (bool): if True, the diagonal is subtracted from the values
            before drawing.
        cmap: colormap name passed to ``mpl.cm.get_cmap``.
        showvals (bool): if True, write each value as text in its cell.
        showzero (bool): if False, zero values are not written as text and
            (unless ``logscale``) the first colormap entry is set to black.
        logscale (bool): if True, color with a ``LogNorm`` whose minimum is
            the smallest positive value.
        xlabel (str): x-axis label.
        ylabel (str): y-axis label.
        fnum, pnum: forwarded to ``kwplot.figure`` when ``ax`` is None.

    Returns:
        the axes that was drawn on

    TODO:
        - [ ] Replace internals with seaborn

    Example:
        >>> # xdoctest: +REQUIRES(module:pandas)
        >>> from kwplot.mpl_draw import *  # NOQA
        >>> import pandas as pd
        >>> classes = ['cls1', 'cls2', 'cls3']
        >>> matrix = np.array([[2, 2, 1], [3, 1, 0], [1, 0, 0]])
        >>> matrix = pd.DataFrame(matrix, index=classes, columns=classes)
        >>> matrix.index.name = 'real'
        >>> matrix.columns.name = 'pred'
        >>> plot_matrix(matrix, showvals=True)
        >>> # xdoc: +REQUIRES(--show)
        >>> import matplotlib.pyplot as plt
        >>> import kwplot
        >>> kwplot.autompl()
        >>> plot_matrix(matrix, showvals=True)

    Example:
        >>> # xdoctest: +REQUIRES(module:pandas)
        >>> from kwplot.mpl_draw import *  # NOQA
        >>> matrix = np.array([[2, 2, 1], [3, 1, 0], [1, 0, 0]])
        >>> plot_matrix(matrix)
        >>> # xdoc: +REQUIRES(--show)
        >>> import matplotlib.pyplot as plt
        >>> import kwplot
        >>> kwplot.autompl()
        >>> plot_matrix(matrix)

    Example:
        >>> # xdoctest: +REQUIRES(module:pandas)
        >>> from kwplot.mpl_draw import *  # NOQA
        >>> matrix = np.array([[2, 2, 1], [3, 1, 0], [1, 0, 0]])
        >>> classes = ['cls1', 'cls2', 'cls3']
        >>> plot_matrix(matrix, index=classes, columns=classes)
    """
    import pandas as pd
    import matplotlib as mpl
    import matplotlib.cm  # NOQA

    assert len(matrix.shape) == 2

    if isinstance(matrix, pd.DataFrame):
        values = matrix.values
        # Default each axis independently so that passing only one of
        # index / columns does not leave the other as None.
        if index is None:
            index = matrix.index
        if columns is None:
            columns = matrix.columns
        if xlabel is None and ylabel is None:
            # User supplied labels (e.g. plain lists) may not have a name
            ylabel = getattr(index, 'name', None)
            xlabel = getattr(columns, 'name', None)
    else:
        values = matrix
        if index is None:
            index = np.arange(matrix.shape[0])
        if columns is None:
            columns = np.arange(matrix.shape[1])

    if ax is None:
        import kwplot
        fig = kwplot.figure(fnum=fnum, pnum=pnum)
        fig.clear()
        ax = fig.gca()

    if zerodiag:
        # The subtraction allocates a new array; the input is not modified
        values = values - np.diag(np.diag(values))

    if logscale:
        from matplotlib.colors import LogNorm
        vmin = values[values > 0].min().min()
        norm = LogNorm(vmin=vmin, vmax=values.max())
    else:
        norm = None

    cmap = copy.copy(mpl.cm.get_cmap(cmap))  # copy the default cmap
    cmap.set_bad((0, 0, 0))
    if not showzero and not logscale:
        # hack zero to be black. ``copy.copy`` is shallow, so the ``colors``
        # sequence may still be shared with the colormap it was copied from;
        # rebind a fresh list instead of mutating the shared one in place.
        colors = list(cmap.colors)
        colors[0] = [0, 0, 0]
        cmap.colors = colors

    aximg = ax.matshow(values, interpolation='none', cmap=cmap, norm=norm)
    ax.grid(False)
    cax = ax.figure.colorbar(aximg, ax=ax)
    if label is not None:
        cax.set_label(label)

    num_rows = len(index)
    num_cols = len(columns)

    # matshow places matrix columns along x and matrix rows along y, so the
    # x ticks take the column labels and the y ticks take the row labels.
    ax.set_xticks(list(range(num_cols)))
    ax.set_xticklabels([str(lbl)[0:100] for lbl in columns])
    for lbl in ax.get_xticklabels():
        lbl.set_rotation(rot)
        lbl.set_horizontalalignment('center')

    ax.set_yticks(list(range(num_rows)))
    ax.set_yticklabels([str(lbl)[0:100] for lbl in index])
    for lbl in ax.get_yticklabels():
        lbl.set_horizontalalignment('right')
        lbl.set_verticalalignment('center')

    # Grid lines around the pixels
    if grid:
        offset = -.5
        x_extent = [offset, num_cols + offset]
        y_extent = [offset, num_rows + offset]
        segments = []
        # One vertical line on the left edge of every column
        for c in range(num_cols):
            xpos = c + offset
            segments.append([(xpos, y_extent[0]), (xpos, y_extent[1])])
        # One horizontal line on the top edge of every row
        for r in range(num_rows):
            ypos = r + offset
            segments.append([(x_extent[0], ypos), (x_extent[1], ypos)])
        bingrid = mpl.collections.LineCollection(segments, color='w',
                                                 linewidths=1)
        ax.add_collection(bingrid)

    if showvals:
        for r in range(num_rows):
            for c in range(num_cols):
                val = values[r, c]
                if val == 0 and not showzero:
                    # zeros are only written when showzero is requested
                    continue
                ax.text(c, r, val, va='center', ha='center', color='white')

    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    return ax
def draw_points(xy, color='blue', class_idxs=None, classes=None, ax=None,
                alpha=None, radius=1, **kwargs):
    """
    Draw points as circle patches on a matplotlib axes.

    Points that resolve to the same RGBA color are batched into a single
    ``PatchCollection``.

    Args:
        xy (ndarray): of points. Reshaped to ``(-1, 2)``.
        color: a color accepted by ``kwimage.Color``, or one of the special
            strings ``'distinct'`` (one distinct color per point) or
            ``'classes'`` (one distinct color per class).
        class_idxs (Sequence[int]): class index of each point; required when
            ``color='classes'``.
        classes (Sequence): the classes; required when ``color='classes'``.
        ax: axes to draw on, defaults to ``plt.gca()``.
        alpha (float | List[float]): a single alpha or one per point.
            Defaults to 1.0 for every point.
        radius (float): circle radius.
        **kwargs: extra keyword arguments for ``mpl.patches.Circle``. A
            given ``fc`` overrides ``color`` and triggers a warning.

    Returns:
        List: the patch collections that were added to the axes

    Raises:
        Exception: if ``color='classes'`` but ``class_idxs`` or ``classes``
            is missing.

    Example:
        >>> from kwplot.mpl_draw import *  # NOQA
        >>> import kwimage
        >>> xy = kwimage.Points.random(10).xy
        >>> draw_points(xy, radius=0.01)
        >>> draw_points(xy, class_idxs=np.random.randint(0, 3, 10),
        >>>         radius=0.01, classes=['a', 'b', 'c'], color='classes')

    Ignore:
        >>> import kwplot
        >>> kwplot.autompl()
    """
    import kwimage
    import matplotlib as mpl
    from matplotlib import pyplot as plt
    if ax is None:
        ax = plt.gca()

    xy = xy.reshape(-1, 2)

    # More grouped patches == more efficient runtime
    if alpha is None:
        alpha = [1.0] * len(xy)
    elif not ub.iterable(alpha):
        alpha = [alpha] * len(xy)

    if color == 'distinct':
        colors = kwimage.Color.distinct(len(alpha))
    elif color == 'classes':
        # TODO: read colors from categories if they exist
        if class_idxs is None or classes is None:
            raise Exception('cannot draw class colors without class_idxs and classes')
        try:
            cls_colors = kwimage.Color.distinct(len(classes))
        except KeyError:
            raise Exception('cannot draw class colors without class_idxs and classes')
        colors = list(ub.take(cls_colors, class_idxs))
    else:
        colors = [color] * len(alpha)

    ptcolors = [kwimage.Color(c, alpha=a).as01('rgba')
                for c, a in zip(colors, alpha)]
    color_groups = ub.group_items(range(len(ptcolors)), ptcolors)

    circlekw = {
        'radius': radius,
        'fill': True,
        'ec': None,
    }
    if 'fc' in kwargs:
        import warnings
        # ``warnings.warn`` is the stdlib entry point for emitting a warning
        warnings.warn(
            'Warning: specifying fc to Points.draw overrides '
            'the color argument. Use color instead')
    circlekw.update(kwargs)
    # fc is passed explicitly per group below, so it must not also be
    # present in circlekw
    fc = circlekw.pop('fc', None)  # hack

    collections = []
    for pcolor, idxs in color_groups.items():
        # hack for fc
        if fc is not None:
            pcolor = fc
        patches = [
            mpl.patches.Circle((x, y), fc=pcolor, **circlekw)
            for x, y in xy[idxs]
        ]
        col = mpl.collections.PatchCollection(patches, match_original=True)
        collections.append(col)
        ax.add_collection(col)
    return collections
# DEPRECATED FUNCTIONS. STILL EXISTS FOR BACKWARDS COMPAT

# backwards compat: these names are listed in ``__all__`` and are provided
# by kwimage.
from kwimage import draw_boxes_on_image, draw_clf_on_image, draw_text_on_image  # NOQA