kwplot.draw_conv module

Helper for drawing convolutional neural network weights.

This may be removed in the future.

kwplot.draw_conv.make_conv_images(conv, color=None, norm_per_feat=True)

Convert the weights of a convolutional layer into a list of images that can be visualized.
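As a rough illustration of the shape bookkeeping involved, the sketch below assumes weights laid out the way `torch.nn.Conv2d` stores them, `(out_channels, in_channels, kh, kw)`, with a plain NumPy array standing in for `conv.weight`. The helper `weights_to_color_images` is hypothetical and not part of kwplot; it only shows how filters with three input channels can become a list of `(kh, kw, 3)` images.

```python
import numpy as np


def weights_to_color_images(weights):
    """
    Hypothetical helper: turn Conv2d-style weights of shape
    (out_channels, 3, kh, kw) into a list of (kh, kw, 3) images.
    """
    weights = np.asarray(weights)
    if weights.ndim != 4 or weights.shape[1] != 3:
        raise ValueError('expected weights of shape (out_channels, 3, kh, kw)')
    # Move the channel axis last so each filter reads as an H x W x C image
    weights_hwc = weights.transpose(0, 2, 3, 1)
    return [np.ascontiguousarray(w) for w in weights_hwc]


# Random stand-in for conv.weight.data.numpy() of torch.nn.Conv2d(3, 9, (5, 7))
rng = np.random.default_rng(0)
weights = rng.standard_normal((9, 3, 5, 7)).astype(np.float32)
images = weights_to_color_images(weights)
print(len(images), images[0].shape)  # 9 (5, 7, 3)
```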

Parameters
  • conv (torch.nn.Conv2d) – a torch convolutional layer

  • color (bool) – if True, the output images are colorized.

  • norm_per_feat (bool) – if True, each feature is normalized separately; otherwise, all features are normalized together.
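To make the `norm_per_feat` distinction concrete, here is a sketch that assumes plain min-max scaling into `[0, 1]`. The actual normalization used by kwplot may differ, and `minmax_normalize` is a hypothetical name.

```python
import numpy as np


def minmax_normalize(weights, norm_per_feat=True):
    """
    Hypothetical helper: rescale conv weights of shape
    (out_channels, in_channels, kh, kw) into [0, 1] for display.

    norm_per_feat=True  -> each output filter uses its own min and max
    norm_per_feat=False -> one min and max shared by all filters
    """
    weights = np.asarray(weights, dtype=np.float64)
    if norm_per_feat:
        # Reduce over every axis except the output-filter axis
        axes = tuple(range(1, weights.ndim))
        lo = weights.min(axis=axes, keepdims=True)
        hi = weights.max(axis=axes, keepdims=True)
    else:
        lo = weights.min()
        hi = weights.max()
    # Treat constant filters as having unit range to avoid dividing by zero
    span = np.where(hi - lo > 0, hi - lo, 1.0)
    return (weights - lo) / span


# Filter 0 holds values 0..3 and filter 1 holds values 4..7
w = np.arange(8, dtype=float).reshape(2, 1, 2, 2)
per_feat = minmax_normalize(w, norm_per_feat=True)
shared = minmax_normalize(w, norm_per_feat=False)
print(per_feat[0].max(), per_feat[1].min())  # 1.0 0.0
print(round(shared[0].max(), 4))  # 0.4286
```

With per-feature scaling every filter spans the full display range, which makes weak filters visible; shared scaling preserves the relative magnitudes between filters.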

Return type

ndarray

Todo

  • [ ] better normalization options

Example

>>> # xdoctest: +REQUIRES(module:torch)
>>> import torch
>>> from kwplot.draw_conv import make_conv_images
>>> conv = torch.nn.Conv2d(3, 9, (5, 7))
>>> weights_flat = make_conv_images(conv, norm_per_feat=False)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwimage
>>> import kwplot
>>> stacked = kwimage.stack_images_grid(weights_flat, chunksize=5, overlap=-1)
>>> kwplot.imshow(stacked)
>>> kwplot.show_if_requested()

kwplot.draw_conv.plot_convolutional_features(conv, limit=144, colorspace='rgb', fnum=None, nCols=None, voxels=False, alpha=0.2, labels=False, normaxis=None, _hack_2drows=False)

Plots the weights of a convolutional layer in a Matplotlib figure.

The convolutional filters (kernels) are arranged in a grid and drawn into a Matplotlib figure. A filter with one input channel is drawn as an intensity image. If a colorspace is specified and the filters have three input channels, each filter is drawn as an RGB image.
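A minimal sketch of the grid layout, assuming the filters have already been normalized into `[0, 1]` images. `draw_filter_grid` and its `n_cols` argument are hypothetical and do not reproduce the real function's `nCols` or `limit` handling; 2D arrays are drawn as grayscale intensity images and `(h, w, 3)` arrays as RGB images.

```python
import math

import numpy as np
from matplotlib.figure import Figure


def draw_filter_grid(images, limit=144, n_cols=None):
    """
    Hypothetical helper: tile a list of filter images into a grid of
    subplots. 2D arrays are drawn as grayscale intensity images and
    (h, w, 3) arrays as RGB images; values are assumed to be in [0, 1].
    """
    images = list(images)[:limit]  # drop any filters past the limit
    n = len(images)
    if n == 0:
        raise ValueError('no filters to draw')
    if n_cols is None:
        n_cols = math.ceil(math.sqrt(n))  # roughly square grid
    n_rows = math.ceil(n / n_cols)
    fig = Figure(figsize=(1.5 * n_cols, 1.5 * n_rows))
    axes = fig.subplots(n_rows, n_cols, squeeze=False)
    for ax in axes.ravel():
        ax.set_axis_off()
    for ax, img in zip(axes.ravel(), images):
        if img.ndim == 2:
            ax.imshow(img, cmap='gray', vmin=0, vmax=1, interpolation='nearest')
        else:
            ax.imshow(img, interpolation='nearest')
    return fig


rng = np.random.default_rng(0)
images = [rng.random((5, 7, 3)) for _ in range(9)]
fig = draw_filter_grid(images, limit=4)
print(len(fig.axes))  # 4: a 2x2 grid holding the first 4 filters
```

Building a bare `matplotlib.figure.Figure` instead of going through `pyplot` keeps the sketch independent of any interactive backend.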

If the filters have 2 or 4+ input channels, the channels are flattened and each one is shown as a separate entry in the grid.
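The channel flattening can be sketched as a single reshape, assuming Conv2d-style weights of shape `(out_channels, in_channels, kh, kw)`. `flatten_filter_channels` is a hypothetical helper, and the ordering it produces (all channels of one filter before the next filter) is an assumption about how the grid is filled.

```python
import numpy as np


def flatten_filter_channels(weights):
    """
    Hypothetical helper: split Conv2d-style weights of shape
    (out_channels, in_channels, kh, kw) so that every (filter, channel)
    pair becomes its own (kh, kw) grayscale image.
    """
    weights = np.asarray(weights)
    out_c, in_c, kh, kw = weights.shape
    # Row-major reshape keeps all channels of filter 0 first, then filter 1, ...
    return list(weights.reshape(out_c * in_c, kh, kw))


rng = np.random.default_rng(0)
weights = rng.standard_normal((2, 4, 3, 3))  # 2 filters with 4 input channels
tiles = flatten_filter_channels(weights)
print(len(tiles), tiles[0].shape)  # 8 (3, 3)
```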

Todo

  • [ ] refactor to use make_conv_images

Parameters
  • conv (torch.nn.modules.conv._ConvNd) – torch convolutional layer with weights to draw

  • limit (int) – the maximum number of filters drawn in the figure; any filters past the first limit are dropped. Defaults to 144.

  • colorspace (str) – the colorspace of the input seen by the convolutional filters (if applicable), used to convert the filters to RGB for display.

  • voxels (bool) – if True and the layer is a 3D convolution, draw the filters as voxels.

  • alpha (float) – only applicable if voxels=True

  • stride (list) – only applicable if voxels=True
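For the colorspace parameter, one plausible conversion, assuming filters already moved to a `(kh, kw, 3)` layout, is a simple channel reordering. Only `'rgb'` and `'bgr'` are handled here; `to_display_rgb` is hypothetical, and kwplot may support other colorspaces through a different mechanism.

```python
import numpy as np


def to_display_rgb(filter_hwc, colorspace='rgb'):
    """
    Hypothetical helper: map a (kh, kw, 3) filter from the colorspace the
    network was trained on into RGB for display.
    Only 'rgb' and 'bgr' are handled in this sketch.
    """
    space = colorspace.lower()
    if space == 'rgb':
        return filter_hwc
    if space == 'bgr':
        # Reverse the channel axis: (B, G, R) -> (R, G, B)
        return filter_hwc[..., ::-1]
    raise NotImplementedError(
        'colorspace {!r} is not handled in this sketch'.format(colorspace))


# A filter that only responds to the first input channel (blue, in BGR order)
f = np.zeros((5, 7, 3))
f[..., 0] = 1.0
rgb = to_display_rgb(f, colorspace='bgr')
print(rgb[0, 0])  # [0. 0. 1.]
```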

Returns

fig - a Matplotlib figure

Return type

matplotlib.figure.Figure

References

https://matplotlib.org/devdocs/gallery/mplot3d/voxels.html
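The referenced Matplotlib voxel API can be used roughly as follows to show a single 3D filter. This sketch assumes a Conv3d filter with three input channels whose channel axis has been moved last and rescaled to `[0, 1]`. `draw_voxel_filter` is hypothetical, and treating `alpha` as the fourth RGBA component of every voxel's face color is an assumption about how the real function applies it.

```python
import numpy as np
from matplotlib.figure import Figure

# Importing mplot3d registers the '3d' projection on older Matplotlib versions
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401


def draw_voxel_filter(kernel_rgb, alpha=0.2):
    """
    Hypothetical helper: draw one 3D filter as colored voxels.

    kernel_rgb has shape (d, h, w, 3) with values in [0, 1].
    """
    kernel_rgb = np.asarray(kernel_rgb, dtype=float)
    grid_shape = kernel_rgb.shape[:3]
    filled = np.ones(grid_shape, dtype=bool)  # draw every voxel of the kernel
    # Append a constant alpha channel so each voxel gets an RGBA face color
    alpha_chan = np.full(grid_shape + (1,), alpha)
    facecolors = np.concatenate([kernel_rgb, alpha_chan], axis=-1)
    fig = Figure()
    ax = fig.add_subplot(1, 1, 1, projection='3d')
    ax.voxels(filled, facecolors=facecolors)
    return fig


# Random stand-in for torch.nn.Conv3d(3, 2, (6, 4, 5)).weight, shape (2, 3, 6, 4, 5)
rng = np.random.default_rng(0)
weights = rng.standard_normal((2, 3, 6, 4, 5))
kernel = weights[0].transpose(1, 2, 3, 0)  # (6, 4, 5, 3): channels last
kernel = (kernel - kernel.min()) / (kernel.max() - kernel.min())
fig = draw_voxel_filter(kernel, alpha=0.6)
```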

Example

>>> # xdoctest: +REQUIRES(module:torch)
>>> import torch
>>> from kwplot.draw_conv import plot_convolutional_features
>>> conv = torch.nn.Conv2d(3, 9, (5, 7))
>>> plot_convolutional_features(conv, colorspace=None, fnum=None, limit=2)

Example

>>> # xdoctest: +REQUIRES(--comprehensive)
>>> # xdoctest: +REQUIRES(module:torch)
>>> import torch
>>> from kwplot.draw_conv import plot_convolutional_features
>>> # 2d uncolored gray-images
>>> conv = torch.nn.Conv3d(1, 2, (3, 4, 5))
>>> plot_convolutional_features(conv, colorspace=None, fnum=1, limit=2)
>>> # 2d colored rgb-images
>>> conv = torch.nn.Conv3d(3, 2, (6, 4, 5))
>>> plot_convolutional_features(conv, colorspace='rgb', fnum=1, limit=2)
>>> # 2d uncolored rgb-images
>>> conv = torch.nn.Conv3d(3, 2, (6, 4, 5))
>>> plot_convolutional_features(conv, colorspace=None, fnum=1, limit=2)
>>> # 3d gray voxels
>>> conv = torch.nn.Conv3d(1, 2, (6, 4, 5))
>>> plot_convolutional_features(conv, colorspace=None, fnum=1, voxels=True,
...                             limit=2)
>>> # 3d color voxels
>>> conv = torch.nn.Conv3d(3, 2, (6, 4, 5))
>>> plot_convolutional_features(conv, colorspace='rgb', fnum=1,
...                             voxels=True, alpha=1, limit=3)
>>> # hack the nice resnet weights into 3d-space
>>> # xdoctest: +REQUIRES(--network)
>>> import torchvision
>>> model = torchvision.models.resnet50(pretrained=True)
>>> conv = torch.nn.Conv3d(3, 1, (7, 7, 7))
>>> # copy so the in-place normalization below leaves the resnet weights untouched
>>> weights_tohack = model.conv1.weight[0:7].data.numpy().copy()
>>> # normalize each weight for nice colors, then place in the conv3d
>>> for w in weights_tohack:
...     w[:] = (w - w.min()) / (w.max() - w.min())
>>> # (7, 3, 7, 7) -> (1, 3, 7, 7, 7) to match the Conv3d weight shape
>>> weights_hacked = weights_tohack.transpose(1, 0, 2, 3)[None, :]
>>> conv.weight.data[:] = torch.from_numpy(weights_hacked)
>>> plot_convolutional_features(conv, colorspace='rgb', fnum=1, voxels=True, alpha=.6)
>>> plot_convolutional_features(conv, colorspace='rgb', fnum=2, voxels=False, alpha=.9)

Example

>>> # xdoctest: +REQUIRES(--network)
>>> # xdoctest: +REQUIRES(module:torch)
>>> import torchvision
>>> from kwplot.draw_conv import plot_convolutional_features
>>> model = torchvision.models.resnet50(pretrained=True)
>>> conv = model.conv1
>>> plot_convolutional_features(conv, colorspace='rgb', fnum=None)