CNN - Convolutional neural networks¶
- Segmentation: label every pixel with its semantics.
- Convolutions are linear operations; a non-linear activation such as ReLU is applied after each convolutional layer. Multiple convolutional layers are stacked: early layers capture low-level features, and later layers capture increasingly high-level features as the input progresses through the network.
- Stride: the step used to move the kernel over the image. For a 7x7 image and a 3x3 kernel, with stride=1 the kernel shifts 1 pixel to the right (and down) at each step, so it is applied at 5 positions horizontally; with stride=2 it shifts 2 pixels at a time and is applied at 3 positions horizontally, and so on (a small sketch after this list demonstrates strides, padding and pooling on a tiny example).
- Padding: add zeros around the borders of the input, typically so the kernel can also cover edge pixels and the output keeps the desired spatial size.
- When convolving the kernel with the input, a dot product is computed between the flattened kernel and the flattened input patch it currently covers, producing a single number for each cell of the output.
- Pooling layers reduce the spatial size of the input volume while preserving its depth, e.g. a 4x4x3 volume is reduced to 2x2x3. A common implementation is max pooling, which keeps the maximum value within each 2x2 window.
- FC: a final fully connected layer that connects to the flattened volume and produces the output scores.
- ResNet skip connections: assuming L1->L2->L3, instead of L3 receiving only the output of L2 as input, it receives the sum of the outputs of L1 and L2. During backpropagation the gradient flows from L3 to L2, but also directly from L3 to L1 (a residual block sketch also follows this list).
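A minimal, self-contained sketch of the ideas above (using torch, which is installed below): one output cell of a convolution is the dot product between the flattened kernel and the flattened input patch it covers, stride and padding control the output size, and 2x2 max pooling halves the spatial size.

import torch
import torch.nn.functional as F

x = torch.arange(49, dtype=torch.float32).reshape(1, 1, 7, 7)  # one 7x7 image, 1 channel
k = torch.ones(1, 1, 3, 3)                                      # one 3x3 kernel

out = F.conv2d(x, k, stride=1)  # no padding, stride 1 -> 5x5 output
print(out.shape)                # torch.Size([1, 1, 5, 5])

# The top-left output cell equals the dot product of the flattened kernel
# with the flattened top-left 3x3 patch of the input
patch = x[0, 0, 0:3, 0:3].flatten()
print(torch.dot(patch, k.flatten()).item(), out[0, 0, 0, 0].item())  # same number

print(F.conv2d(x, k, stride=2).shape)             # stride 2 -> torch.Size([1, 1, 3, 3])
print(F.conv2d(x, k, stride=1, padding=1).shape)  # zero padding of 1 keeps 7x7
print(F.max_pool2d(out, 2, 2).shape)              # 2x2 max pooling -> torch.Size([1, 1, 2, 2])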
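And a sketch of a ResNet-style residual block (the class name and layer sizes are illustrative, not taken from the original paper): the block's input is added back to the output of its convolutions, so during backpropagation the gradient also flows directly through the addition.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    # Illustrative residual block: output = relu(conv2(relu(conv1(x))) + x)
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)  # padding=1 keeps the spatial size
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = self.conv2(out)
        return F.relu(out + x)  # skip connection: add the block's input back in

x = torch.randn(1, 16, 13, 13)
print(ResidualBlock(16)(x).shape)  # torch.Size([1, 16, 13, 13])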
Install required packages¶
In [13]:
!pip install torch albumentations tqdm pandas kagglehub scikit-learn
Requirement already satisfied: torch, albumentations, tqdm, pandas, kagglehub (and their dependencies). Installing collected packages: threadpoolctl, joblib, scikit-learn. Successfully installed joblib-1.4.2 scikit-learn-1.5.2 threadpoolctl-3.5.0
In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
%matplotlib inline
In [15]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
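Note that device is only selected here; the rest of this notebook never moves the model or the batches onto it, so everything below runs on the CPU. Moving data to the device would look like this (a sketch, not used later):

# Tensors (and modules, via .to(device)) must be moved explicitly to run on MPS
x = torch.zeros(1, 1, 28, 28).to(device)
print(x.device)  # mps:0 on Apple Silicon, otherwise cpu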
In [16]:
# Convert MNIST image files into tensors of shape (Color Channels, Height, Width); batching later adds the 4th dimension (# of images, Channels, Height, Width)
transform = transforms.ToTensor()
In [18]:
# Train Data
train_data = datasets.MNIST(root='./cnn_data', train=True, download=True, transform=transform)
Downloading MNIST from https://ossci-datasets.s3.amazonaws.com/mnist/ (the yann.lecun.com mirror fails with an expired SSL certificate) and extracting to ./cnn_data/MNIST/raw
In [20]:
# Test Data
test_data = datasets.MNIST(root='./cnn_data', train=False, download=True, transform=transform)
In [21]:
train_data
Out[21]:
Dataset MNIST
    Number of datapoints: 60000
    Root location: ./cnn_data
    Split: Train
    StandardTransform
Transform: ToTensor()
In [22]:
test_data
Out[22]:
Dataset MNIST
    Number of datapoints: 10000
    Root location: ./cnn_data
    Split: Test
    StandardTransform
Transform: ToTensor()
In [23]:
# Create a small batch size for images...let's say 10
train_loader = DataLoader(train_data, batch_size=10, shuffle=True)
test_loader = DataLoader(test_data, batch_size=10, shuffle=False)
In [24]:
conv1 = nn.Conv2d(1, 6, 3, 1)   # 1 input channel, 6 filters, 3x3 kernel, stride 1
conv2 = nn.Conv2d(6, 16, 3, 1)  # 6 input channels, 16 filters, 3x3 kernel, stride 1
In [25]:
# Grab 1 MNIST record/image
for i, (X_Train, y_train) in enumerate(train_data):
    break
In [26]:
X_Train.shape
Out[26]:
torch.Size([1, 28, 28])
In [27]:
x = X_Train.view(1,1,28,28)
In [28]:
# Perform our first convolution
x = F.relu(conv1(x)) # Rectified Linear Unit for our activation function
In [29]:
# 1 single image, 6 is the filters we asked for, 26x26
x.shape
Out[29]:
torch.Size([1, 6, 26, 26])
In [30]:
# Pass through the pooling layer
x = F.max_pool2d(x,2,2) # kernel size of 2 and stride of 2
In [31]:
x.shape
Out[31]:
torch.Size([1, 6, 13, 13])
In [32]:
# Do our second convolutional layer
x = F.relu(conv2(x))
In [33]:
x.shape
Out[33]:
torch.Size([1, 16, 11, 11])
In [34]:
# Pooling layer
x = F.max_pool2d(x,2,2)
In [35]:
x.shape
Out[35]:
torch.Size([1, 16, 5, 5])
In [36]:
# Two convs (each trims 2 pixels: 28->26, 13->11) and two 2x2 poolings (each halves the size)
((28-2) / 2 - 2) / 2
Out[36]:
5.5
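The 5.5 above halves twice without accounting for rounding; the standard size formula per layer is floor((n + 2*padding - kernel) / stride) + 1, and max pooling floors, so the real final size is 5 (hence the 16*5*5 = 400 inputs of the first fully connected layer defined next). A small sketch (the helper name conv_out is just illustrative):

def conv_out(n, k, s=1, p=0):
    # Spatial output size of a conv or pooling layer
    return (n + 2 * p - k) // s + 1

n = 28
n = conv_out(n, k=3)       # conv1, 3x3 kernel, stride 1 -> 26
n = conv_out(n, k=2, s=2)  # 2x2 max pool, stride 2      -> 13
n = conv_out(n, k=3)       # conv2, 3x3 kernel, stride 1 -> 11
n = conv_out(n, k=2, s=2)  # 2x2 max pool, stride 2      -> 5 (11/2 floored)
print(n)  # 5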
In [37]:
# Model Class
class ConvolutionalNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1,6,3,1)
        self.conv2 = nn.Conv2d(6,16,3,1)
        # Fully Connected Layers
        self.fc1 = nn.Linear(5*5*16, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, X):
        X = F.relu(self.conv1(X))
        X = F.max_pool2d(X,2,2) # 2x2 kernel and stride 2
        # Second Pass
        X = F.relu(self.conv2(X))
        X = F.max_pool2d(X,2,2) # 2x2 kernel and stride 2
        # Re-view to flatten it out
        X = X.view(-1, 16*5*5) # -1 so that we can vary the batch size
        # Fully Connected Layers
        X = F.relu(self.fc1(X))
        X = F.relu(self.fc2(X))
        X = self.fc3(X)
        return F.log_softmax(X, dim=1)
In [38]:
# Create an Instance of our Model
torch.manual_seed(41)
model = ConvolutionalNetwork()
model
Out[38]:
ConvolutionalNetwork(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
In [39]:
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # the smaller the learning rate, the longer training takes
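One detail worth flagging: nn.CrossEntropyLoss expects raw logits and applies log_softmax internally, while the forward() above already ends with log_softmax. Training still works, as the results below show, but the conventional pairing with a log_softmax output would be nn.NLLLoss; a sketch of the two equivalent setups:

# Option A: keep log_softmax in forward() and use the negative log-likelihood loss
criterion = nn.NLLLoss()

# Option B: drop log_softmax from forward() and keep nn.CrossEntropyLoss,
# which applies log_softmax itself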
In [40]:
import time
start_time = time.time()

# Create variables to track things
epochs = 5
train_losses = []
test_losses = []
train_correct = []
test_correct = []

# For loop of epochs
for i in range(epochs):
    trn_corr = 0
    tst_corr = 0

    # Train
    for b, (X_train, y_train) in enumerate(train_loader):
        b += 1 # start our batches at 1
        y_pred = model(X_train) # get predictions for this batch (conv layers take the 2D images directly, no flattening)
        loss = criterion(y_pred, y_train) # how far off are we? Compare predictions to the correct answers in y_train
        predicted = torch.max(y_pred.data, 1)[1] # index of the highest score = predicted class
        batch_corr = (predicted == y_train).sum() # how many we got correct in this batch. True=1, False=0, sum them up
        trn_corr += batch_corr # keep a running total as we train

        # Update our parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print out some results
        if b % 600 == 0:
            print(f'Epoch: {i} Batch: {b} Loss: {loss.item()}')

    train_losses.append(loss)
    train_correct.append(trn_corr)

    # Test
    with torch.no_grad(): # no gradients, so we don't update weights and biases with test data
        for b, (X_test, y_test) in enumerate(test_loader):
            y_val = model(X_test)
            predicted = torch.max(y_val.data, 1)[1] # adding up correct predictions
            tst_corr += (predicted == y_test).sum() # T=1, F=0, sum away

    loss = criterion(y_val, y_test)
    test_losses.append(loss)
    test_correct.append(tst_corr)

current_time = time.time()
total = current_time - start_time
print(f'Training Took: {total/60} minutes!')
Epoch: 0 Batch: 600 Loss: 0.16236090660095215
Epoch: 0 Batch: 1200 Loss: 0.1614551544189453
Epoch: 0 Batch: 1800 Loss: 0.5041669607162476
Epoch: 0 Batch: 2400 Loss: 0.10426308214664459
Epoch: 0 Batch: 3000 Loss: 0.007077544927597046
Epoch: 0 Batch: 3600 Loss: 0.3652905821800232
Epoch: 0 Batch: 4200 Loss: 0.0037760832346975803
Epoch: 0 Batch: 4800 Loss: 0.0013758999994024634
Epoch: 0 Batch: 5400 Loss: 0.04505983740091324
Epoch: 0 Batch: 6000 Loss: 0.0005676982691511512
Epoch: 1 Batch: 600 Loss: 0.0040985336527228355
Epoch: 1 Batch: 1200 Loss: 0.26344460248947144
Epoch: 1 Batch: 1800 Loss: 0.0015710301231592894
Epoch: 1 Batch: 2400 Loss: 0.002920893719419837
Epoch: 1 Batch: 3000 Loss: 0.025759723037481308
Epoch: 1 Batch: 3600 Loss: 0.5109620094299316
Epoch: 1 Batch: 4200 Loss: 0.03832785040140152
Epoch: 1 Batch: 4800 Loss: 0.0004107930581085384
Epoch: 1 Batch: 5400 Loss: 0.00042201072210446
Epoch: 1 Batch: 6000 Loss: 0.4828798770904541
Epoch: 2 Batch: 600 Loss: 0.05119292065501213
Epoch: 2 Batch: 1200 Loss: 0.0038511897437274456
Epoch: 2 Batch: 1800 Loss: 0.0014970828779041767
Epoch: 2 Batch: 2400 Loss: 0.004896306432783604
Epoch: 2 Batch: 3000 Loss: 0.008256791159510612
Epoch: 2 Batch: 3600 Loss: 0.0004542603564914316
Epoch: 2 Batch: 4200 Loss: 0.26317352056503296
Epoch: 2 Batch: 4800 Loss: 0.001537359319627285
Epoch: 2 Batch: 5400 Loss: 0.006470398046076298
Epoch: 2 Batch: 6000 Loss: 0.1095944419503212
Epoch: 3 Batch: 600 Loss: 0.03505117446184158
Epoch: 3 Batch: 1200 Loss: 0.001770398230291903
Epoch: 3 Batch: 1800 Loss: 0.005633163265883923
Epoch: 3 Batch: 2400 Loss: 1.2385697118588723e-05
Epoch: 3 Batch: 3000 Loss: 0.0020400488283485174
Epoch: 3 Batch: 3600 Loss: 0.00025755190290510654
Epoch: 3 Batch: 4200 Loss: 0.011785751208662987
Epoch: 3 Batch: 4800 Loss: 2.1826495867571793e-05
Epoch: 3 Batch: 5400 Loss: 0.0542694516479969
Epoch: 3 Batch: 6000 Loss: 0.017493976280093193
Epoch: 4 Batch: 600 Loss: 0.00170081143733114
Epoch: 4 Batch: 1200 Loss: 0.08627621829509735
Epoch: 4 Batch: 1800 Loss: 0.34209710359573364
Epoch: 4 Batch: 2400 Loss: 0.00012794588110409677
Epoch: 4 Batch: 3000 Loss: 0.0014115956146270037
Epoch: 4 Batch: 3600 Loss: 0.47453540563583374
Epoch: 4 Batch: 4200 Loss: 0.0017075622454285622
Epoch: 4 Batch: 4800 Loss: 0.0002860472886823118
Epoch: 4 Batch: 5400 Loss: 0.018258752301335335
Epoch: 4 Batch: 6000 Loss: 0.0014418931677937508
Training Took: 0.6474462866783142 minutes!
In [41]:
# Graph the loss at the end of each epoch
train_losses = [tl.item() for tl in train_losses]
plt.plot(train_losses, label="Training Loss")
plt.plot(test_losses, label="Validation Loss")
plt.title("Loss at Epoch")
plt.legend()
Out[41]:
<matplotlib.legend.Legend at 0x17f16f200>
In [42]:
# Graph the accuracy at the end of each epoch
plt.plot([t/600 for t in train_correct], label="Training Accuracy") # 60,000 training images, so correct/600 gives a percentage
plt.plot([t/100 for t in test_correct], label="Validation Accuracy") # 10,000 test images, so correct/100 gives a percentage
plt.title("Accuracy at the end of each Epoch")
plt.legend()
Out[42]:
<matplotlib.legend.Legend at 0x17f1d8c80>
In [43]:
test_load_everything = DataLoader(test_data, batch_size=10000, shuffle=False)
In [44]:
with torch.no_grad():
    correct = 0
    for X_test, y_test in test_load_everything:
        y_val = model(X_test)
        predicted = torch.max(y_val, 1)[1]
        correct += (predicted == y_test).sum()
In [45]:
# Percentage of correct predictions
correct.item()/len(test_data)*100
Out[45]:
98.64
In [46]:
# Grab an image
test_data[1978]
Out[46]:
(tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]]), 4)
In [47]:
# Grab just the data
test_data[1978][0]
Out[47]:
tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]])
In [48]:
# Reshape it
test_data[1978][0].reshape(28,28)
Out[48]:
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])
In [49]:
# Show the image
plt.imshow(test_data[1978][0].reshape(28,28))
Out[49]:
<matplotlib.image.AxesImage at 0x17efa5820>
In [50]:
# Pass the image thru our model
model.eval()
with torch.no_grad():
    new_prediction = model(test_data[1978][0].view(1,1,28,28)) # batch size of 1, 1 color channel, 28x28 image
In [51]:
new_prediction
Out[51]:
tensor([[-2.3625e+01, -1.9385e+01, -2.5273e+01, -3.2036e+01, -1.1921e-07, -2.0260e+01, -1.7388e+01, -2.0916e+01, -2.0654e+01, -1.6100e+01]])
In [52]:
new_prediction.argmax()
Out[52]:
tensor(4)
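The model outputs log_softmax values, so exponentiating recovers the class probabilities; a quick sketch:

probs = new_prediction.exp()   # from log-probabilities back to probabilities
print(probs.sum().item())      # ~1.0
print(probs.argmax().item())   # 4, the predicted digit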
Bounding box detection¶
TBD