PyTorch Datasets

code
ML
This project serves as an introduction to PyTorch and working with datasets.
Author

Matthias Quinn

Published

February 14, 2023

GOAL

Learn more about the Machine Learning framework known as PyTorch.

RESULT

A better understanding of PyTorch and working with dataset functionality.

Datasets & DataLoaders

PyTorch provides two data primitives:

  1. torch.utils.data.DataLoader
  2. torch.utils.data.Dataset

Both work with PyTorch's pre-loaded datasets as well as your own data. A Dataset stores the samples and their corresponding labels, while a DataLoader wraps an iterable around the Dataset so the samples can be accessed easily, typically in batches.
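To make the division of labor concrete, here is a minimal sketch of a custom map-style Dataset wrapped in a DataLoader. The class name `ToyDataset`, the random data, and the batch size of 16 are assumptions chosen for illustration; they do not come from any real dataset.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyDataset(Dataset):
    """Hypothetical map-style dataset of random 'images' and integer labels."""

    def __init__(self, n_samples=100):
        generator = torch.Generator().manual_seed(0)
        # Random stand-ins for 28 x 28 single-channel images and 10 classes
        self.samples = torch.rand(n_samples, 1, 28, 28, generator=generator)
        self.labels = torch.randint(0, 10, (n_samples,), generator=generator)

    def __len__(self):
        # Number of samples the dataset holds
        return len(self.samples)

    def __getitem__(self, idx):
        # Return one (sample, label) pair
        return self.samples[idx], self.labels[idx]


toy_data = ToyDataset()

# The DataLoader handles batching and shuffling on top of the Dataset
loader = DataLoader(toy_data, batch_size=16, shuffle=True)

images, labels = next(iter(loader))
print(images.shape)  # torch.Size([16, 1, 28, 28])
print(labels.shape)  # torch.Size([16])
```

The Dataset only needs to answer "how many samples are there?" and "what is sample i?"; batching and shuffling are left to the DataLoader.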

There are a number of prepackaged datasets that come with PyTorch, like FashionMNIST, and more can be found here:

How to Load a Dataset

Here’s how to load the Fashion-MNIST dataset from TorchVision. Fashion-MNIST is a collection of Zalando’s article images, consisting of 60,000 training examples and 10,000 test examples. Each example comprises a 28 × 28 grayscale image and an associated label from one of 10 classes.

We’ll load the FashionMNIST dataset with the following parameters:

  • root = the path where the train/test data is stored

  • train = selects the split: True for the training set, False for the test set

  • download = if True, downloads the data from the internet when it’s not already available at root

  • transform = the feature transformation to apply to each sample (here ToTensor(), which converts each image to a tensor)

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


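# Training split (60,000 examples), downloaded to ./data if not already present
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

# Test split (10,000 examples), stored under the same root
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)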