神经网络的数据集
数据集格式
方式一:
首先分成训练集和测试集两个文件夹,每个文件里是输出结果的 label ,每个label文件夹下是对应的图片。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
| datasets ├── train │ ├── label1 │ │ ├── file1 │ │ └── file2 │ ├── label2 │ | ├── file1 │ | └── file2 │ └── label3 │ ├── file1 │ └── file2 ├── test │ ├── label1 │ │ ├── file1 │ │ └── file2 │ └── label2 │ ├── file1 │ └── file2 └──
|
方式二:
文件夹下是训练图片数据文件夹,测试数据文件夹,输出结果文件夹(一般用TXT文档,里面存储图片的目录,图片的 label )
方式三:
直接将图片名称设置为对应的 label 。
创建Datasets
首先就是需要从目录中读取文件,需要有每个的图片和label,继承 Datas类,自己读取文件夹下的数据
方法如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| from torch.utils.data import Dataset from PIL import Image import os
class MyData(Dataset): def __init__(self, root_dir, label_dir): self.root_dir = root_dir self.label_dir = label_dir self.path = os.path.join(self.root_dir, self.label_dir) self.img_path = os.listdir(self.path)
def __getitem__(self, idx): img_name = self.img_path[idx] img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) img = Image.open(img_item_path) label = self.label_dir return img, label
def __len__(self): return len(self.img_path)
|
使用方法如下:
1 2 3 4 5 6 7
| root_path = r"f:\jupyter\data\train" ants_dir = "ants" bees_dir = "bees" ants_dataset = MyData(root_path, ants_dir) bees_dataset = MyData(root_path, bees_dir) print(bees_dataset) train_dataset=ants_dataset+bees_dataset
|
torchvasion中的数据集
torchvasion提供了已经封装好了的数据集文件,可以直接下载使用,可以从pytorch官网查看文档
这里我们使用体积较小的 CIFAR10
我们先来看一下使用的参数:
代码
1 2 3 4
| import torchvision
train_set=torchvision.datasets.CIFAR10(root=r'./dataset',train=True,transform=dataset_transform,download=True) test_set = torchvision.datasets.CIFAR10(root=r'./dataset', train=False, transform=dataset_transform, download=True)
|
- root:数据集的下载目录
- train:是否是训练数据集,False就是测试数据集
- download:是否下载
- transform:图像预处理方式
目录下没有数据集,download=True时,会进行下载,之后运行会进行验证不会下载。
Dataloader
dataloader是一个加载器,设定取数据的方式
datasets相当于扑克牌,dataloader就相当于摸牌方式
参数
代码
1 2 3
| from torch.utils.data import DataLoader
test_loader = DataLoader(dataset=test_set, batch_size=64, shuffle=False, num_workers=0, drop_last=False)
|
- dataset:从中加载数据的数据集。
- batch_size:每批加载多少个样本
- shuffle:是否打乱,设置为 True 以在每个 epoch 重新排列数据
- num_workers:用于数据加载的子进程数。 0 表示数据将在主进程中加载。 (默认值:0)
- drop_last:如果数据集大小不能被批次大小整除,则设置为 True 以删除最后一个不完整的批次。如果 False 并且数据集的大小不能被批大小整除,则最后一批将更小。 (默认:False )