# -*- coding: utf-8 -*-
"""
Helper for making 3D plots
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import six
def _get_3d_axes(pnum=None):
    """
    Return a 3D axes on the current figure.

    Args:
        pnum (None | int | Tuple[int, int, int]): subplot location forwarded
            to ``Figure.add_subplot``. When None, the current axes is reused
            if it is already 3D; otherwise a new full-figure 3D axes is added.

    Returns:
        Axes3D: the axes to draw on.
    """
    import numbers
    from mpl_toolkits.mplot3d import Axes3D  # NOQA  (registers the '3d' projection)
    import matplotlib.pyplot as plt
    fig = plt.gcf()
    if pnum is not None:
        if isinstance(pnum, numbers.Integral):
            # Allow the compact three-digit form, e.g. 121
            pnum = (pnum,)
        return fig.add_subplot(*pnum, projection='3d')
    # ``plt.gca(projection='3d')`` was removed in Matplotlib 3.6. Emulate its
    # old behavior: reuse the current axes only when it is already 3D.
    # Check ``fig.axes`` first because ``fig.gca()`` would otherwise create a
    # new 2D axes on an empty figure.
    if fig.axes:
        current = fig.gca()
        if getattr(current, 'name', None) == '3d':
            return current
    return fig.add_subplot(111, projection='3d')


def plot_surface3d(xgrid, ygrid, zdata, xlabel=None, ylabel=None, zlabel=None,
                   wire=False, mode=None, contour=False, rstride=1, cstride=1,
                   pnum=None, labelkw=None, xlabelkw=None, ylabelkw=None,
                   zlabelkw=None, titlekw=None, *args, **kwargs):
    r"""
    Plot gridded data as a 3D surface or wireframe.

    Args:
        xgrid, ygrid (ndarray): coordinate grids (e.g. from ``np.meshgrid``).
            Array-likes are converted with ``np.asanyarray``.
        zdata (ndarray): height values on the same grid.
        xlabel, ylabel, zlabel (str, optional): axis labels.
        wire (bool): selects ``mode='wire'`` when ``mode`` is None.
        mode (str, optional): ``'surface'`` (default) or ``'wire'``.
        contour (bool): if True, also draw contour projections along the x,
            y, and z directions, offset 10% of the data range outside it.
        rstride, cstride (int): row / column strides forwarded to Matplotlib.
        pnum (tuple | int, optional): subplot location. If None, the current
            axes is reused when it is 3D; otherwise a new 3D axes is created.
        labelkw (dict, optional): default kwargs for all axis labels.
        xlabelkw, ylabelkw, zlabelkw (dict, optional): per-axis label kwargs;
            each defaults to a copy of ``labelkw``.
        titlekw (dict, optional): kwargs for ``ax.set_title``.
        *args, **kwargs: forwarded to ``plot_surface`` / ``plot_wireframe``.
            ``title`` is popped and used as the axes title. ``cmap`` defaults
            to magma. In surface mode ``linewidth`` defaults to 0.1 unless
            ``linewidth`` or ``lw`` is given.

    Returns:
        Axes3D: the axes that was drawn on.

    Raises:
        NotImplementedError: if ``mode`` is not ``'surface'`` or ``'wire'``.

    References:
        http://matplotlib.org/mpl_toolkits/mplot3d/tutorial.html

    Example:
        >>> # DISABLE_DOCTEST
        >>> import numpy as np
        >>> import kwplot
        >>> import matplotlib as mpl
        >>> import kwimage
        >>> shape=(19, 19)
        >>> sigma1, sigma2 = 2.0, 1.0
        >>> ybasis = np.arange(shape[0])
        >>> xbasis = np.arange(shape[1])
        >>> xgrid, ygrid = np.meshgrid(xbasis, ybasis)
        >>> sigma = [sigma1, sigma2]
        >>> gausspatch = kwimage.gaussian_patch(shape, sigma=sigma)
        >>> title = 'ksize={!r}, sigma={!r}'.format(shape, (sigma1, sigma2))
        >>> kwplot.plot_surface3d(xgrid, ygrid, gausspatch, rstride=1, cstride=1,
        ...                       cmap=mpl.cm.coolwarm, title=title)
        >>> kwplot.show_if_requested()
    """
    import numpy as np

    if mode is None:
        mode = 'wire' if wire else 'surface'
    # Validate before touching matplotlib so that a bad mode does not leave an
    # empty 3D axes behind on the current figure.
    if mode not in ('wire', 'surface'):
        raise NotImplementedError('mode=%r' % (mode,))

    if titlekw is None:
        titlekw = {}
    if labelkw is None:
        labelkw = {}
    if xlabelkw is None:
        xlabelkw = labelkw.copy()
    if ylabelkw is None:
        ylabelkw = labelkw.copy()
    if zlabelkw is None:
        zlabelkw = labelkw.copy()

    # Accept array-likes (e.g. nested lists); asanyarray preserves masked arrays.
    xgrid = np.asanyarray(xgrid)
    ygrid = np.asanyarray(ygrid)
    zdata = np.asanyarray(zdata)
    # TODO: if given long-form (1D) data points, check whether they can be
    # reshaped to a grid, or fall back to ax.scatter3D.

    import matplotlib.cm as cm
    cmap = kwargs.get('cmap', 'magma')
    if isinstance(cmap, six.string_types) and cmap == 'magma':
        cmap = cm.magma
    # After this line kwargs['cmap'] is always set and equals ``cmap``.
    kwargs['cmap'] = cmap
    title = kwargs.pop('title', None)

    ax = _get_3d_axes(pnum)

    if mode == 'wire':
        ax.plot_wireframe(xgrid, ygrid, zdata, *args, rstride=rstride,
                          cstride=cstride, **kwargs)
    else:
        surface_kw = dict(kwargs)
        # A thin default edge width; only applied when the caller did not
        # pass one (or its alias), which previously raised a TypeError.
        if 'linewidth' not in surface_kw and 'lw' not in surface_kw:
            surface_kw['linewidth'] = .1
        ax.plot_surface(xgrid, ygrid, zdata, *args, rstride=rstride,
                        cstride=cstride, **surface_kw)

    if contour:
        # Project contours just outside the data range (10% margin) on the
        # low-x, high-y, and low-z sides.
        xmin, xmax = xgrid.min(), xgrid.max()
        ymin, ymax = ygrid.min(), ygrid.max()
        zmin, zmax = zdata.min(), zdata.max()
        offsets = [
            ('x', xmin - (xmax - xmin) * .1),
            ('y', ymax + (ymax - ymin) * .1),
            ('z', zmin - (zmax - zmin) * .1),
        ]
        for zdir, offset in offsets:
            ax.contour(xgrid, ygrid, zdata, zdir=zdir, offset=offset,
                       cmap=cmap)

    if title is not None:
        ax.set_title(title, **titlekw)
    if xlabel is not None:
        ax.set_xlabel(xlabel, **xlabelkw)
    if ylabel is not None:
        ax.set_ylabel(ylabel, **ylabelkw)
    if zlabel is not None:
        ax.set_zlabel(zlabel, **zlabelkw)
    return ax