
CNN - Convolutional neural networks

  • Segmentation: label every pixel with its semantic class.
  • Convolutions are linear operations, so a non-linear function such as ReLU is applied after each convolutional layer. Stacking several convolutional layers lets the network capture low-level features in the early layers and higher-level features in the later ones.
  • Stride: the step by which the kernel moves over the image. For a 7x7 image and a 3x3 kernel, stride=1 moves the kernel 1 pixel at a time, so it fits in 5 positions horizontally (and 5 vertically). With stride=2 it moves 2 pixels at a time, so it fits in 3 positions horizontally, and so on.
  • Padding: add zeros around the borders of the input.
  • When convolving a kernel with the input, a dot product is computed between the flattened kernel and the flattened input patch under it, producing a single number for each cell of the output.
  • Pooling layers: they reduce the spatial size of the input volume while preserving its depth, e.g. a 4x4x3 volume is reduced to 2x2x3. A common implementation is max pooling, which keeps the maximum value of each 2x2 window.
  • FC: a final fully connected layer applied to the flattened volume.
  • ResNet skip connections: assuming L1->L2->L3, instead of L3 receiving only the output of L2 as input, it receives the sum of the outputs of L1 and L2. During backpropagation, the gradient flows from L3 to L2, and also directly from L3 to L1.
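
The stride and dot-product notes above can be sketched in plain Python. This is a minimal illustration, not how PyTorch implements convolution; the function name `conv2d_valid` and the toy 7x7 input are assumptions made for the example.

```python
def conv2d_valid(image, kernel, stride=1):
    """Slide a square kernel over a 2D image (lists of lists) without padding."""
    k = len(kernel)
    n_rows, n_cols = len(image), len(image[0])
    output = []
    for top in range(0, n_rows - k + 1, stride):
        row = []
        for left in range(0, n_cols - k + 1, stride):
            # Dot product between the flattened kernel and the flattened patch
            total = 0
            for i in range(k):
                for j in range(k):
                    total += kernel[i][j] * image[top + i][left + j]
            row.append(total)
        output.append(row)
    return output


# Hypothetical 7x7 input holding the values 0..48, and a 3x3 kernel of ones
image = [[r * 7 + c for c in range(7)] for r in range(7)]
kernel = [[1] * 3 for _ in range(3)]

out_s1 = conv2d_valid(image, kernel, stride=1)
out_s2 = conv2d_valid(image, kernel, stride=2)
print(len(out_s1), len(out_s1[0]))  # 5 5
print(len(out_s2), len(out_s2[0]))  # 3 3
```

With stride 1 the kernel fits in 5 positions per axis, and with stride 2 in 3, matching the 7x7 example in the notes.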

Install required packages

In [13]:
!pip install torch torchvision matplotlib albumentations tqdm pandas kagglehub scikit-learn

Image classification

Import

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
%matplotlib inline
In [15]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
In [16]:
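# ToTensor converts each MNIST image into a float tensor of shape (channels, height, width) with values in [0, 1].
# The DataLoader later adds the batch dimension, giving (# of images, channels, height, width).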
transform = transforms.ToTensor()
In [18]:
# Train Data
train_data = datasets.MNIST(root='./cnn_data', train=True, download=True, transform=transform)
In [20]:
# Test Data
test_data = datasets.MNIST(root='./cnn_data', train=False, download=True, transform=transform)
In [21]:
train_data
Out[21]:
Dataset MNIST
    Number of datapoints: 60000
    Root location: ./cnn_data
    Split: Train
    StandardTransform
Transform: ToTensor()
In [22]:
test_data
Out[22]:
Dataset MNIST
    Number of datapoints: 10000
    Root location: ./cnn_data
    Split: Test
    StandardTransform
Transform: ToTensor()
In [23]:
# Create a small batch size for images...let's say 10
train_loader = DataLoader(train_data, batch_size=10, shuffle=True)
test_loader = DataLoader(test_data, batch_size=10, shuffle=False)
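
As a quick check on what a batch size of 10 implies, the number of batches per epoch can be computed by hand. This is a small sketch; the helper name `batches_per_epoch` is hypothetical, and it assumes the `DataLoader` default of keeping the last partial batch (`drop_last=False`).

```python
import math


def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches one pass over the data produces."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


train_batches = batches_per_epoch(60000, 10)
test_batches = batches_per_epoch(10000, 10)
print(train_batches, test_batches)  # 6000 1000
```

So one training epoch runs 6000 batches, which is why a progress message every 600 batches gives ten messages per epoch.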
     
In [24]:
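# nn.Conv2d(in_channels, out_channels, kernel_size, stride)
conv1 = nn.Conv2d(1, 6, 3, 1)   # 1 grayscale channel in, 6 feature maps out, 3x3 kernel, stride 1
conv2 = nn.Conv2d(6, 16, 3, 1)  # 6 feature maps in, 16 out, 3x3 kernel, stride 1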
In [25]:
# Grab 1 MNIST record/image (the first one in the training set)
X_Train, y_train = train_data[0]
In [26]:
X_Train.shape
Out[26]:
torch.Size([1, 28, 28])
In [27]:
# Add a batch dimension: (1, 28, 28) -> (1, 1, 28, 28)
x = X_Train.view(1,1,28,28)
In [28]:
# Perform our first convolution
x = F.relu(conv1(x)) # Rectified Linear Unit for our activation function
In [29]:
# 1 single image, 6 is the filters we asked for, 26x26
x.shape
Out[29]:
torch.Size([1, 6, 26, 26])
In [30]:
# pass thru the pooling layer
x = F.max_pool2d(x,2,2) # kernel of 2 and stride of 2
In [31]:
x.shape
Out[31]:
torch.Size([1, 6, 13, 13])
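
The pooling step above halves each spatial dimension (26x26 becomes 13x13). What a 2x2 max pool does to a single channel can be sketched in plain Python; the helper name `max_pool_channel` and the 4x4 values are made up for illustration and are not part of the notebook.

```python
def max_pool_channel(feature_map, size=2, stride=2):
    """Max pooling over one channel given as a list of lists (no padding)."""
    n_rows, n_cols = len(feature_map), len(feature_map[0])
    pooled = []
    for top in range(0, n_rows - size + 1, stride):
        row = []
        for left in range(0, n_cols - size + 1, stride):
            # Collect the values inside the current window and keep the largest
            window = [feature_map[top + i][left + j]
                      for i in range(size) for j in range(size)]
            row.append(max(window))
        pooled.append(row)
    return pooled


# Hypothetical 4x4 channel
channel = [[1, 3, 2, 0],
           [4, 6, 5, 1],
           [7, 2, 9, 8],
           [0, 1, 3, 4]]
pooled = max_pool_channel(channel)
print(pooled)  # [[6, 5], [7, 9]]
```

When the input size is odd, the last row and column do not fill a complete window and are dropped in this sketch, so an 11x11 channel would come out as 5x5.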
In [32]:
# Do our second convolutional layer
x = F.relu(conv2(x))
In [33]:
x.shape
Out[33]:
torch.Size([1, 16, 11, 11])
In [34]:
# Pooling layer
x = F.max_pool2d(x,2,2)
In [35]:
x.shape 
Out[35]:
torch.Size([1, 16, 5, 5])
In [36]:
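# Size check: 28-2 = 26 after conv1, /2 = 13 after pooling, -2 = 11 after conv2, /2 = 5.5.
# max_pool2d rounds down, which is why the actual output above is 5x5.
((28-2) / 2 - 2) / 2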
Out[36]:
5.5
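
The 5.5 above comes from plain division, while the layers themselves round down. A small sketch of the usual output-size formula with floor division reproduces the shapes seen so far; the helper names `conv_out` and `pool_out` are hypothetical.

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    """Output size of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2, stride=2):
    """Output size of unpadded pooling; floor division drops leftover pixels."""
    return (size - kernel) // stride + 1


size = 28
trace = [size]
# Same order as the cells above: conv1, pool, conv2, pool
for layer in (conv_out, pool_out, conv_out, pool_out):
    size = layer(size)
    trace.append(size)

flattened = 16 * size * size
print(trace)      # [28, 26, 13, 11, 5]
print(flattened)  # 400
```

The final 16 x 5 x 5 = 400 values are what a fully connected layer receives once the volume is flattened.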
In [37]:
# Model Class
class ConvolutionalNetwork(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(1,6,3,1)
    self.conv2 = nn.Conv2d(6,16,3,1)
    # Fully Connected Layer
    self.fc1 = nn.Linear(5*5*16, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, X):
    X = F.relu(self.conv1(X))
    X = F.max_pool2d(X,2,2) # 2x2 kernel and stride 2
    # Second Pass
    X = F.relu(self.conv2(X))
    X = F.max_pool2d(X,2,2) # 2x2 kernel and stride 2

    # Flatten each sample into a vector of 16*5*5 = 400 features
    X = X.view(-1, 16*5*5) # -1 lets the batch size vary

    # Fully Connected Layers
    X = F.relu(self.fc1(X))
    X = F.relu(self.fc2(X))
    X = self.fc3(X)
    return F.log_softmax(X, dim=1)
In [38]:
# Create an Instance of our Model
torch.manual_seed(41)
model = ConvolutionalNetwork()
model
Out[38]:
ConvolutionalNetwork(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
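
The layer shapes printed above determine how many trainable parameters the model has. The count can be worked out by hand, as in this sketch; the helper names are hypothetical, and it assumes every layer has a bias, which is the PyTorch default.

```python
def conv_params(in_channels, out_channels, kernel_size):
    """Weights plus one bias per output channel for a square-kernel Conv2d."""
    return in_channels * out_channels * kernel_size * kernel_size + out_channels


def linear_params(in_features, out_features):
    """Weights plus one bias per output feature for a Linear layer."""
    return in_features * out_features + out_features


param_counts = {
    "conv1": conv_params(1, 6, 3),
    "conv2": conv_params(6, 16, 3),
    "fc1": linear_params(400, 120),
    "fc2": linear_params(120, 84),
    "fc3": linear_params(84, 10),
}
total_params = sum(param_counts.values())
print(param_counts)
print(total_params)  # 60074
```

Most of the parameters sit in `fc1`, which connects the 400 flattened features to 120 units.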
In [39]:
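# Loss function and optimizer
# Note: nn.CrossEntropyLoss applies log_softmax internally. The model already returns
# log-probabilities, and log_softmax leaves log-probabilities unchanged, so the loss is
# still correct here; nn.NLLLoss would be the matching choice for this model output.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # the smaller the learning rate, the longer training takes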
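
The model ends in `log_softmax` while the criterion is `nn.CrossEntropyLoss`, which itself combines a log-softmax with a negative log-likelihood. The plain-Python sketch below, with made-up scores and hypothetical helper names, suggests why this still gives the same loss: applying log-softmax to values that are already log-probabilities leaves them unchanged.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax over a list of scores."""
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]


def nll(log_probs, target):
    """Negative log-likelihood of the target class."""
    return -log_probs[target]


def cross_entropy(scores, target):
    """Cross entropy as log-softmax followed by negative log-likelihood."""
    return nll(log_softmax(scores), target)


# Hypothetical raw scores for 3 classes, with class 0 as the true label
logits = [2.0, 0.5, -1.0]
target = 0

log_probs = log_softmax(logits)
loss_from_logits = cross_entropy(logits, target)
loss_from_log_probs = cross_entropy(log_probs, target)
print(loss_from_logits, loss_from_log_probs)
```

The two printed values agree up to floating-point error, so the extra `log_softmax` is redundant rather than harmful.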
     
In [40]:
import time
start_time = time.time()

# Variables to track progress
epochs = 5
train_losses = []
test_losses = []
train_correct = []
test_correct = []

# Loop over the epochs
for i in range(epochs):
  trn_corr = 0
  tst_corr = 0

  # Train
  for b,(X_train, y_train) in enumerate(train_loader):
    b+=1 # start our batches at 1
    y_pred = model(X_train) # forward pass on the image batch; a CNN takes the 2D images as they are, no flattening
    loss = criterion(y_pred, y_train) # how far off are we? Compare the predictions to the correct labels in y_train

    predicted = torch.max(y_pred.data, 1)[1] # index of the highest score in each row, i.e. the predicted class
    batch_corr = (predicted == y_train).sum() # how many we got correct in this batch. True = 1, False = 0, summed
    trn_corr += batch_corr # running total of correct predictions for this epoch

    # Update our parameters
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Print out some results
    if b%600 == 0:
      print(f'Epoch: {i}  Batch: {b}  Loss: {loss.item()}')

  # Record the loss of the last training batch of this epoch (detached so the graph is not kept alive)
  train_losses.append(loss.detach())
  train_correct.append(trn_corr)

  # Test
  with torch.no_grad(): # no gradient tracking is needed for evaluation
    for b,(X_test, y_test) in enumerate(test_loader):
      y_val = model(X_test)
      predicted = torch.max(y_val.data, 1)[1] # predicted class for each test image
      tst_corr += (predicted == y_test).sum() # T=1 F=0 and sum away

  # Record the loss of the last test batch of this epoch
  loss = criterion(y_val, y_test)
  test_losses.append(loss)
  test_correct.append(tst_corr)

current_time = time.time()
total = current_time - start_time
print(f'Training Took: {total/60} minutes!')
Epoch: 0  Batch: 600  Loss: 0.16236090660095215
Epoch: 0  Batch: 1200  Loss: 0.1614551544189453
Epoch: 0  Batch: 1800  Loss: 0.5041669607162476
Epoch: 0  Batch: 2400  Loss: 0.10426308214664459
Epoch: 0  Batch: 3000  Loss: 0.007077544927597046
Epoch: 0  Batch: 3600  Loss: 0.3652905821800232
Epoch: 0  Batch: 4200  Loss: 0.0037760832346975803
Epoch: 0  Batch: 4800  Loss: 0.0013758999994024634
Epoch: 0  Batch: 5400  Loss: 0.04505983740091324
Epoch: 0  Batch: 6000  Loss: 0.0005676982691511512
Epoch: 1  Batch: 600  Loss: 0.0040985336527228355
Epoch: 1  Batch: 1200  Loss: 0.26344460248947144
Epoch: 1  Batch: 1800  Loss: 0.0015710301231592894
Epoch: 1  Batch: 2400  Loss: 0.002920893719419837
Epoch: 1  Batch: 3000  Loss: 0.025759723037481308
Epoch: 1  Batch: 3600  Loss: 0.5109620094299316
Epoch: 1  Batch: 4200  Loss: 0.03832785040140152
Epoch: 1  Batch: 4800  Loss: 0.0004107930581085384
Epoch: 1  Batch: 5400  Loss: 0.00042201072210446
Epoch: 1  Batch: 6000  Loss: 0.4828798770904541
Epoch: 2  Batch: 600  Loss: 0.05119292065501213
Epoch: 2  Batch: 1200  Loss: 0.0038511897437274456
Epoch: 2  Batch: 1800  Loss: 0.0014970828779041767
Epoch: 2  Batch: 2400  Loss: 0.004896306432783604
Epoch: 2  Batch: 3000  Loss: 0.008256791159510612
Epoch: 2  Batch: 3600  Loss: 0.0004542603564914316
Epoch: 2  Batch: 4200  Loss: 0.26317352056503296
Epoch: 2  Batch: 4800  Loss: 0.001537359319627285
Epoch: 2  Batch: 5400  Loss: 0.006470398046076298
Epoch: 2  Batch: 6000  Loss: 0.1095944419503212
Epoch: 3  Batch: 600  Loss: 0.03505117446184158
Epoch: 3  Batch: 1200  Loss: 0.001770398230291903
Epoch: 3  Batch: 1800  Loss: 0.005633163265883923
Epoch: 3  Batch: 2400  Loss: 1.2385697118588723e-05
Epoch: 3  Batch: 3000  Loss: 0.0020400488283485174
Epoch: 3  Batch: 3600  Loss: 0.00025755190290510654
Epoch: 3  Batch: 4200  Loss: 0.011785751208662987
Epoch: 3  Batch: 4800  Loss: 2.1826495867571793e-05
Epoch: 3  Batch: 5400  Loss: 0.0542694516479969
Epoch: 3  Batch: 6000  Loss: 0.017493976280093193
Epoch: 4  Batch: 600  Loss: 0.00170081143733114
Epoch: 4  Batch: 1200  Loss: 0.08627621829509735
Epoch: 4  Batch: 1800  Loss: 0.34209710359573364
Epoch: 4  Batch: 2400  Loss: 0.00012794588110409677
Epoch: 4  Batch: 3000  Loss: 0.0014115956146270037
Epoch: 4  Batch: 3600  Loss: 0.47453540563583374
Epoch: 4  Batch: 4200  Loss: 0.0017075622454285622
Epoch: 4  Batch: 4800  Loss: 0.0002860472886823118
Epoch: 4  Batch: 5400  Loss: 0.018258752301335335
Epoch: 4  Batch: 6000  Loss: 0.0014418931677937508
Training Took: 0.6474462866783142 minutes!
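
The log above prints the loss of individual batches, which is why consecutive values jump around (for example roughly 0.0004 followed by 0.48 at the end of epoch 1). A per-epoch average is usually easier to read. The sketch below shows one possible way to accumulate it; the `RunningMean` class and the sample values are hypothetical and not part of the notebook.

```python
class RunningMean:
    """Accumulates a sample-weighted mean of per-batch values."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        # value is the mean over a batch of n samples
        self.total += value * n
        self.count += n

    @property
    def mean(self):
        return self.total / self.count if self.count else float("nan")


# Hypothetical per-batch losses, each computed over a batch of 10 images
epoch_loss = RunningMean()
for batch_loss in (0.5, 0.3, 0.1):
    epoch_loss.update(batch_loss, n=10)

print(round(epoch_loss.mean, 4))  # 0.3
```

In a training loop, `update(loss.item(), n=len(y_train))` would be called once per batch and `mean` read at the end of the epoch.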
In [41]:
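# Graph the loss recorded at the end of each epoch (the loss of the last batch, as stored in the training loop)
train_loss_values = [tl.item() for tl in train_losses]
test_loss_values = [tl.item() for tl in test_losses]
plt.plot(train_loss_values, label="Training Loss")
plt.plot(test_loss_values, label="Validation Loss")
plt.title("Loss at Epoch")
plt.legend()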
Out[41]:
<matplotlib.legend.Legend at 0x17f16f200>
[Figure: "Loss at Epoch", training loss and validation loss per epoch]
In [42]:
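# Graph the accuracy at the end of each epoch.
# There are 60,000 training images, so count/600 is the accuracy in percent; with 10,000 test images it is count/100.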
plt.plot([t/600 for t in train_correct], label="Training Accuracy")
plt.plot([t/100 for t in test_correct], label="Validation Accuracy")
plt.title("Accuracy at the end of each Epoch")
plt.legend()
Out[42]:
<matplotlib.legend.Legend at 0x17f1d8c80>
[Figure: "Accuracy at the end of each Epoch", training accuracy and validation accuracy per epoch]
In [43]:
test_load_everything = DataLoader(test_data, batch_size=10000, shuffle=False)
In [44]:
with torch.no_grad():
  correct = 0
  for X_test, y_test in test_load_everything:
    y_val = model(X_test)
    predicted = torch.max(y_val, 1)[1]
    correct += (predicted == y_test).sum()
In [45]:
# Percentage of test images classified correctly
correct.item()/len(test_data)*100
Out[45]:
98.64
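
A single accuracy number hides which digits get confused with which. A confusion matrix breaks the same counts down per class. The plain-Python sketch below uses made-up labels and a hypothetical helper, `confusion_counts`, to show the idea; rows are true classes and columns are predicted classes, the same layout that `sklearn.metrics.confusion_matrix` uses.

```python
def confusion_counts(true_labels, predicted_labels, num_classes):
    """matrix[t][p] counts samples of true class t predicted as class p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(true_labels, predicted_labels):
        matrix[t][p] += 1
    return matrix


# Hypothetical labels for a 3-class problem
y_true = [0, 1, 2, 2, 1, 0, 2]
y_hat = [0, 1, 2, 1, 1, 0, 2]

matrix = confusion_counts(y_true, y_hat, 3)
n_correct = sum(matrix[c][c] for c in range(3))  # the diagonal holds the correct predictions
accuracy = 100 * n_correct / len(y_true)
print(matrix)  # [[2, 0, 0], [0, 2, 0], [0, 1, 2]]
print(round(accuracy, 2))  # 85.71
```

The off-diagonal entry in the last row shows one sample of class 2 predicted as class 1, detail that the overall accuracy cannot show.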
In [46]:
# Grab an image
test_data[1978]
Out[46]:
(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.4392, 0.6902, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1529, 0.2314,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.7922, 0.9255, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5490, 0.9843,
           0.1608, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0353, 0.8392, 0.9255, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6196, 0.9961,
           0.1725, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.1412, 0.9961, 0.7333, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8941, 0.9059,
           0.0824, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.3804, 0.9961, 0.5804, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8941, 0.8235,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.8275, 0.9961, 0.4196, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1098, 0.9412, 0.8235,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0314, 0.3098,
           0.7569, 0.7922, 0.9608, 0.9961, 0.2392, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2667, 0.9961, 0.6784,
           0.0000, 0.0000, 0.0000, 0.0039, 0.0706, 0.6392, 0.8235, 0.9961,
           0.9961, 0.9961, 0.9961, 0.9294, 0.1412, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.7176, 0.9961, 0.3804,
           0.2039, 0.2627, 0.3804, 0.4039, 0.9961, 1.0000, 0.9961, 0.9725,
           0.7686, 0.9451, 0.9961, 0.4196, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0863, 0.8314, 0.9961, 0.9961,
           0.9961, 0.9961, 0.9961, 0.9961, 0.9961, 0.9961, 0.5647, 0.1294,
           0.0000, 0.8588, 0.9961, 0.2039, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1686, 0.7529, 0.9961,
           0.9961, 0.9961, 0.9765, 0.6863, 0.5686, 0.0000, 0.0000, 0.0000,
           0.1373, 0.9529, 0.9961, 0.2078, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0078, 0.0314,
           0.0314, 0.0314, 0.0314, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.3059, 0.9961, 0.9451, 0.1333, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.7529, 0.9961, 0.9608, 0.1569, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2196,
           0.9843, 0.9961, 0.7843, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3020,
           0.9961, 0.9961, 0.2157, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7529,
           0.9961, 0.8510, 0.0314, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9294,
           0.9961, 0.5490, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0196, 0.9333,
           0.9961, 0.2196, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2667, 0.9922,
           0.9882, 0.1333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9294,
           0.4353, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000]]]),
 4)
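
The values in the tensor above lie between 0 and 1 because `transforms.ToTensor()` divides 8-bit pixel intensities by 255. The sketch below reproduces a few of the printed values; the helper name is hypothetical, and the raw intensities (112, 176, 254) are inferred by inverting that scaling, not read from the dataset files.

```python
def to_unit_interval(pixel):
    """Scale an 8-bit pixel intensity (0..255) to the range [0, 1]."""
    return pixel / 255.0


for pixel in (0, 112, 176, 254, 255):
    print(pixel, round(to_unit_interval(pixel), 4))
# 0 0.0
# 112 0.4392
# 176 0.6902
# 254 0.9961
# 255 1.0
```

The second element of the tuple, 4, is the label of this image.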
In [47]:
# Grab just the data
test_data[1978][0]
In [48]:
# Reshape from (1, 28, 28) to (28, 28), dropping the channel dimension
test_data[1978][0].reshape(28, 28)
Out[48]:
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.4392, 0.6902, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1529, 0.2314, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.7922, 0.9255, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5490, 0.9843, 0.1608,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0353, 0.8392, 0.9255, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6196, 0.9961, 0.1725,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.1412, 0.9961, 0.7333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8941, 0.9059, 0.0824,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.3804, 0.9961, 0.5804, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8941, 0.8235, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.8275, 0.9961, 0.4196, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1098, 0.9412, 0.8235, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0314, 0.3098, 0.7569, 0.7922,
         0.9608, 0.9961, 0.2392, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2667, 0.9961, 0.6784, 0.0000,
         0.0000, 0.0000, 0.0039, 0.0706, 0.6392, 0.8235, 0.9961, 0.9961, 0.9961,
         0.9961, 0.9294, 0.1412, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.7176, 0.9961, 0.3804, 0.2039,
         0.2627, 0.3804, 0.4039, 0.9961, 1.0000, 0.9961, 0.9725, 0.7686, 0.9451,
         0.9961, 0.4196, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0863, 0.8314, 0.9961, 0.9961, 0.9961,
         0.9961, 0.9961, 0.9961, 0.9961, 0.9961, 0.5647, 0.1294, 0.0000, 0.8588,
         0.9961, 0.2039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1686, 0.7529, 0.9961, 0.9961,
         0.9961, 0.9765, 0.6863, 0.5686, 0.0000, 0.0000, 0.0000, 0.1373, 0.9529,
         0.9961, 0.2078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0078, 0.0314, 0.0314,
         0.0314, 0.0314, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3059, 0.9961,
         0.9451, 0.1333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7529, 0.9961,
         0.9608, 0.1569, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2196, 0.9843, 0.9961,
         0.7843, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3020, 0.9961, 0.9961,
         0.2157, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7529, 0.9961, 0.8510,
         0.0314, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9294, 0.9961, 0.5490,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0196, 0.9333, 0.9961, 0.2196,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2667, 0.9922, 0.9882, 0.1333,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9294, 0.4353, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000]])
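The reshape above only drops the leading channel dimension; the pixel values are unchanged. A minimal sketch of that, using a random tensor as a stand-in for `test_data[1978][0]` (the stand-in and the variable names are assumptions, not part of the notebook):

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for test_data[1978][0]: one channel, 28x28 pixels
image = torch.rand(1, 28, 28)

# Two ways to drop the channel dimension
via_reshape = image.reshape(28, 28)
via_squeeze = image.squeeze(0)

print(via_reshape.shape)                        # shape after reshape
print(torch.equal(via_reshape, via_squeeze))    # both hold the same values
```

`squeeze(0)` has the small advantage of not hard-coding the image size, while `reshape(28, 28)` makes the expected size explicit.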
In [49]:
# Show the image
plt.imshow(test_data[1978][0].reshape(28,28))
Out[49]:
<matplotlib.image.AxesImage at 0x17efa5820>
In [50]:
# Pass the image through our model
model.eval()
with torch.no_grad():
    # batch size of 1, 1 color channel, 28x28 image
    new_prediction = model(test_data[1978][0].view(1, 1, 28, 28))
In [51]:
new_prediction
Out[51]:
tensor([[-2.3625e+01, -1.9385e+01, -2.5273e+01, -3.2036e+01, -1.1921e-07,
         -2.0260e+01, -1.7388e+01, -2.0916e+01, -2.0654e+01, -1.6100e+01]])
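The output has shape (1, 10): one row for the single image in the batch and one score per class. The sketch below illustrates the same inference pattern with a small stand-in network; `stand_in_model` and its layers are hypothetical and are not the model trained in this notebook.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in model (NOT the notebook's trained model)
stand_in_model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3),   # 28x28 -> 26x26, 4 channels
    nn.ReLU(),
    nn.Flatten(),                     # 4 * 26 * 26 features
    nn.Linear(4 * 26 * 26, 10),       # 10 class scores
    nn.LogSoftmax(dim=1),
)

# Hypothetical stand-in for a single test image
image = torch.rand(1, 28, 28)

stand_in_model.eval()                 # switch layers such as dropout to inference mode
with torch.no_grad():                 # skip gradient tracking during inference
    batch = image.unsqueeze(0)        # add a batch dimension: (1, 1, 28, 28)
    output = stand_in_model(batch)

print(batch.shape, output.shape, output.requires_grad)
```

`unsqueeze(0)` produces the same tensor as `view(1, 1, 28, 28)` here, without repeating the image dimensions.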
In [52]:
new_prediction.argmax()
Out[52]:
tensor(4)
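All ten scores are at most 0 and the largest is very close to 0, which is consistent with log-probabilities (for example, from a `log_softmax` output layer). Assuming that is the case, exponentiating them gives class probabilities. A minimal sketch using the values printed above:

```python
import torch

# Scores copied from the Out[51] cell; assumed to be log-probabilities
log_probs = torch.tensor([[-2.3625e+01, -1.9385e+01, -2.5273e+01, -3.2036e+01, -1.1921e-07,
                           -2.0260e+01, -1.7388e+01, -2.0916e+01, -2.0654e+01, -1.6100e+01]])

probs = log_probs.exp()                        # log-probabilities -> probabilities
predicted_class = probs.argmax(dim=1).item()   # index of the most likely class
confidence = probs[0, predicted_class].item()  # probability assigned to that class

print(predicted_class)
print(confidence)
print(probs.sum().item())                      # close to 1 if the assumption holds
```

Passing `dim=1` to `argmax` keeps the result per image, which matters once the batch holds more than one sample.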

Bounding box detection¶

TBD

References¶

  • Convolution Visualizer