PyTorch - 加载数据

  • 简述

    PyTorch 包含一个名为 torchvision 的包,用于加载和准备数据集。它包括两个基本功能,即 Dataset 和 DataLoader,它们有助于数据集的转换和加载。
  • 数据集

    数据集用于从给定的数据集中读取和转换数据点。下面提到了要实现的基本语法 -
    
    trainset = torchvision.datasets.CIFAR10(root = './data', train = True,
       download = True, transform = transform)
    
    DataLoader 用于对数据进行 shuffle 和批处理。它可用于与多处理工作者并行加载数据。
    
    trainloader = torch.utils.data.DataLoader(trainset, batch_size = 4,
       shuffle = True, num_workers = 2)
    

    示例:加载 CSV 文件

    我们使用 Python 包 Panda 来加载 csv 文件。原始文件具有以下格式:(图像名称,68 个地标 - 每个地标都有 ax、y 坐标)。
    
    landmarks_frame = pd.read_csv('faces/face_landmarks.csv')
    n = 65
    img_name = landmarks_frame.iloc[n, 0]
    landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
    landmarks = landmarks.astype('float').reshape(-1, 2)