kwplot.draw_conv module
Helper for drawing convolutional neural network weights.
This may be removed in the future.
- kwplot.draw_conv.make_conv_images(conv, color=None, norm_per_feat=True)
Convert convolutional weights into an array of visualizable images.
- Parameters
conv (torch.nn.Conv2d) – a torch convolutional layer
color (bool) – if True, output images are colorized
norm_per_feat (bool) – if True normalizes over each feature separately, otherwise normalizes all features together.
- Return type
ndarray
Todo
[ ] better normalization options
Example
>>> # xdoctest: +REQUIRES(module:torch)
>>> import torch
>>> from kwplot.draw_conv import make_conv_images
>>> conv = torch.nn.Conv2d(3, 9, (5, 7))
>>> weights_flat = make_conv_images(conv, norm_per_feat=False)
>>> # xdoctest: +REQUIRES(--show)
>>> import kwimage
>>> import kwplot
>>> stacked = kwimage.stack_images_grid(weights_flat, chunksize=5, overlap=-1)
>>> kwplot.imshow(stacked)
>>> kwplot.show_if_requested()
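As a rough illustration of what norm_per_feat controls, the NumPy sketch below min-max scales a weight array of shape (out_channels, in_channels, kh, kw) either per filter or over the whole layer. The helper normalize_filters is hypothetical and the min-max scheme is an assumption; kwplot's own normalization may differ.

```python
import numpy as np


def normalize_filters(weights, norm_per_feat=True):
    """Min-max scale conv weights into [0, 1] for display.

    Hypothetical helper (not part of kwplot). ``weights`` is assumed to
    have shape (out_channels, in_channels, kh, kw), as in torch.nn.Conv2d.
    """
    weights = np.asarray(weights, dtype=np.float64)
    if norm_per_feat:
        # Scale each output filter (feature) by its own min and max.
        axes = tuple(range(1, weights.ndim))
        lo = weights.min(axis=axes, keepdims=True)
        hi = weights.max(axis=axes, keepdims=True)
    else:
        # Scale every filter by the global min and max of the layer.
        lo = weights.min()
        hi = weights.max()
    # Guard against constant filters, which would otherwise divide by zero.
    span = np.where(hi - lo > 0, hi - lo, 1.0)
    return (weights - lo) / span


rng = np.random.default_rng(0)
w = rng.normal(size=(9, 3, 5, 7))
per_feat = normalize_filters(w, norm_per_feat=True)
together = normalize_filters(w, norm_per_feat=False)
```

Per-filter scaling stretches every filter to the full [0, 1] range, which makes weak filters visible but hides their relative magnitudes; global scaling keeps those magnitudes comparable.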
- kwplot.draw_conv.plot_convolutional_features(conv, limit=144, colorspace='rgb', fnum=None, nCols=None, voxels=False, alpha=0.2, labels=False, normaxis=None, _hack_2drows=False)
Plot the weights of a convolutional layer in a Matplotlib figure.
The convolutional filters (kernels) are arranged in a grid and drawn into a Matplotlib figure, which is returned. If the layer has a single input channel, each filter is shown as an intensity (grayscale) image. If a colorspace is specified and there are three input channels, each filter is shown as a color image, converted to RGB for display.
Otherwise (for example, with 2 or 4+ input channels), the channels are flattened and each one is shown as a separate cell in the grid.
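The channel-layout rule above can be sketched as follows. This is a hypothetical helper (filters_to_images) that assumes weights of shape (out_channels, in_channels, kh, kw) already scaled to [0, 1], and it skips any colorspace conversion by treating the three channels as RGB.

```python
import numpy as np


def filters_to_images(weights, colorspace='rgb'):
    """Split conv weights into 2D images, one per grid cell.

    Hypothetical sketch, not kwplot's code. ``weights`` is assumed to have
    shape (out_channels, in_channels, kh, kw) and to be scaled to [0, 1].
    """
    out_ch, in_ch = weights.shape[:2]
    if in_ch == 3 and colorspace is not None:
        # One color image per filter: move channels last -> (kh, kw, 3).
        return [w.transpose(1, 2, 0) for w in weights]
    # Otherwise every (filter, channel) pair becomes its own grayscale image.
    return [weights[i, j] for i in range(out_ch) for j in range(in_ch)]
```

With a single input channel the fallback branch yields exactly one intensity image per filter, so it covers the grayscale case as well.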
Todo
[ ] refactor to use make_conv_images
- Parameters
conv (torch.nn.modules.conv._ConvNd) – torch convolutional layer with weights to draw
limit (int) – the maximum number of filters drawn in the figure; the first limit filters are drawn and any remaining filters are dropped. Defaults to 144.
colorspace (str) – the colorspace seen by the convolutional filter (if applicable), so it can be converted to RGB for display.
voxels (bool) – if True and the layer is a 3D convolution, draw each filter as voxels
alpha (float) – opacity of the voxels; only applicable if voxels=True
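To make the effect of limit concrete, here is a sketch of how filters might be truncated and laid out. The helper grid_layout is hypothetical: its n_cols argument is assumed to play the role of the undocumented nCols parameter, and the near-square default is a guess rather than kwplot's documented behavior.

```python
import math


def grid_layout(num_images, limit=144, n_cols=None):
    """Return (kept, n_rows, n_cols) for drawing images in a grid.

    Hypothetical helper: keeps the first ``limit`` images (dropping the
    rest) and, when ``n_cols`` is not given, picks a near-square grid.
    kwplot's actual layout may differ.
    """
    kept = min(num_images, limit)
    if n_cols is None:
        # Smallest column count whose square can hold every kept image.
        n_cols = max(1, math.ceil(math.sqrt(kept)))
    n_rows = max(1, math.ceil(kept / n_cols))
    return kept, n_rows, n_cols
```

For example, a layer with 200 filters and the default limit keeps 144 of them in a 12 by 12 grid.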
- Returns
fig - a Matplotlib figure
- Return type
matplotlib.figure.Figure
References
https://matplotlib.org/devdocs/gallery/mplot3d/voxels.html
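The voxel mode follows the Matplotlib voxels gallery referenced above. A minimal sketch, assuming a single-channel 3D kernel with shape (kd, kh, kw) scaled to [0, 1], draws every weight as a gray cube whose opacity is set by alpha. The function draw_kernel_voxels is hypothetical and not part of kwplot.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def draw_kernel_voxels(kernel, alpha=0.2):
    """Draw one 3D grayscale kernel as voxels (sketch, not kwplot's code).

    ``kernel`` is assumed to have shape (kd, kh, kw) with values in [0, 1].
    """
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    # Fill every voxel; color encodes the weight value.
    filled = np.ones(kernel.shape, dtype=bool)
    # Map each weight to a gray RGBA color with a shared transparency.
    colors = np.empty(kernel.shape + (4,))
    colors[..., :3] = kernel[..., None]
    colors[..., 3] = alpha
    ax.voxels(filled, facecolors=colors, edgecolor='k')
    return fig
```

A low alpha such as the default 0.2 lets inner voxels show through the outer shell, at the cost of washed-out colors.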
Example
>>> # xdoctest: +REQUIRES(module:torch)
>>> import torch
>>> from kwplot.draw_conv import plot_convolutional_features
>>> conv = torch.nn.Conv2d(3, 9, (5, 7))
>>> plot_convolutional_features(conv, colorspace=None, fnum=None, limit=2)
Example
>>> # xdoctest: +REQUIRES(--comprehensive)
>>> # xdoctest: +REQUIRES(module:torch)
>>> import torch
>>> from kwplot.draw_conv import plot_convolutional_features
>>> # 2d uncolored gray-images
>>> conv = torch.nn.Conv3d(1, 2, (3, 4, 5))
>>> plot_convolutional_features(conv, colorspace=None, fnum=1, limit=2)
>>> # 2d colored rgb-images
>>> conv = torch.nn.Conv3d(3, 2, (6, 4, 5))
>>> plot_convolutional_features(conv, colorspace='rgb', fnum=1, limit=2)
>>> # 2d uncolored rgb-images
>>> conv = torch.nn.Conv3d(3, 2, (6, 4, 5))
>>> plot_convolutional_features(conv, colorspace=None, fnum=1, limit=2)
>>> # 3d gray voxels
>>> conv = torch.nn.Conv3d(1, 2, (6, 4, 5))
>>> plot_convolutional_features(conv, colorspace=None, fnum=1, voxels=True,
...                             limit=2)
>>> # 3d color voxels
>>> conv = torch.nn.Conv3d(3, 2, (6, 4, 5))
>>> plot_convolutional_features(conv, colorspace='rgb', fnum=1,
...                             voxels=True, alpha=1, limit=3)
>>> # hack the nice resnet weights into 3d-space
>>> # xdoctest: +REQUIRES(--network)
>>> import torchvision
>>> model = torchvision.models.resnet50(pretrained=True)
>>> conv = torch.nn.Conv3d(3, 1, (7, 7, 7))
>>> weights_tohack = model.conv1.weight[0:7].data.numpy()
>>> # normalize each weight for nice colors, then place in the conv3d
>>> for w in weights_tohack:
...     w[:] = (w - w.min()) / (w.max() - w.min())
>>> weights_hacked = weights_tohack.transpose(1, 0, 2, 3)[None, :]
>>> conv.weight.data[:] = torch.FloatTensor(weights_hacked)
>>> plot_convolutional_features(conv, colorspace='rgb', fnum=1, voxels=True, alpha=.6)
>>> plot_convolutional_features(conv, colorspace='rgb', fnum=2, voxels=False, alpha=.9)
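The axis shuffling in the ResNet example is easy to lose track of, so the sketch below replays it on random data with plain NumPy. The random stand-in weights are an assumption in place of the pretrained ResNet-50 conv1 weights (which have shape (64, 3, 7, 7)); only the shapes and the normalization match the example.

```python
import numpy as np

# Stand-in for model.conv1.weight[0:7]: seven 3-channel 7x7 filters.
# Random values replace the pretrained weights here (assumption).
rng = np.random.default_rng(0)
weights_tohack = rng.normal(size=(7, 3, 7, 7)).astype(np.float32)

# Min-max normalize each filter independently, in place.
for w in weights_tohack:
    w[:] = (w - w.min()) / (w.max() - w.min())

# Swap the filter and channel axes so the 7 filters become the depth axis,
# then add a leading out_channels axis: (1, 3, 7, 7, 7), which matches the
# weight shape (out, in, kD, kH, kW) of torch.nn.Conv3d(3, 1, (7, 7, 7)).
weights_hacked = weights_tohack.transpose(1, 0, 2, 3)[None, :]
```

Each depth slice weights_hacked[0, :, k] is then the k-th 2D filter, which is why the voxel plot shows the ResNet filters stacked along one axis.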
Example
>>> # xdoctest: +REQUIRES(--network)
>>> # xdoctest: +REQUIRES(module:torch)
>>> import torchvision
>>> from kwplot.draw_conv import plot_convolutional_features
>>> model = torchvision.models.resnet50(pretrained=True)
>>> conv = model.conv1
>>> plot_convolutional_features(conv, colorspace='rgb', fnum=None)