如何使用數(shù)據(jù)集DataSet? 在介紹DataLoader之前,需要先了解數(shù)據(jù)集DataSet的使用。Pytorch中集成了很多已經(jīng)處理好的數(shù)據(jù)集,在pytorch的torchvision、torchtext等模塊有一些典型的數(shù)據(jù)集,可以通過配置來下載使用。 以CIFAR10 數(shù)據(jù)集為例,文檔已經(jīng)描
Pytorch中集成了大量已處理好的數(shù)據(jù)集,在torchvision、torchtext等模塊中都有一些典型的數(shù)據(jù)集,用戶可以通過配置來下載并使用這些數(shù)據(jù)集。例如,CIFAR10 數(shù)據(jù)集已經(jīng)被描述得非常清晰了。其中要注意的是 transform 這個參數(shù),可以用來將圖像轉(zhuǎn)換為所需要的格式,比如將PIL格式的圖像轉(zhuǎn)化為tensor格式的圖像。
在介紹 DataLoader 之前,需要先了解如何使用 DataSet。Pytorch 中的DataSet是一個存儲所有數(shù)據(jù)(例如圖像、音頻)的容器。DataLoader 就是另一個具有更好收納功能的容器,其中分隔開來很多小隔間,可以自己設(shè)定一個小隔間有多少個數(shù)據(jù)集的數(shù)據(jù)來組成,每次將數(shù)據(jù)放進收納小隔間的時候要不要把源數(shù)據(jù)集打亂再進行收納等等。
給定了一個數(shù)據(jù)集,用戶可以決定如何從數(shù)據(jù)集里面拿取數(shù)據(jù)來進行訓(xùn)練,比如一次拿取多少數(shù)據(jù)作為一個對象來對數(shù)據(jù)集進行分割,對數(shù)據(jù)集進行分割之前要不要打亂數(shù)據(jù)集等等。DataLoader的結(jié)果就是一個對數(shù)據(jù)集進行分割的大字典列表,列表中的每個對象都是由設(shè)置的多少個數(shù)據(jù)集的對象組合而成的。
首先需要先理解 __getitem__ 方法。__getitem__被稱為魔法方法,在Python中定義一個類的時候,如果想要通過鍵來得到類的輸出值,就需要 __getitem__ 方法。__getitem__ 方法的作用就是在調(diào)用類的時候自動的運行 __getitem__ 方法的內(nèi)容,得結(jié)果并返回。
class Fib():
def __init__(self,start=0,step=1):
self.step=step
def __getitem__(self, key):
a = key+self.step
return a
s=Fib()
s[1]
例如,在Pytorch中的CIFAR10數(shù)據(jù)集中,可以看到源碼中的 __getitem__ 方法是這樣的。
def __getitem__(self, index: int) -> Tuple[Any, Any]:
img, target = self.data[index], self.targets[index]
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
用戶可以在Pytorch的Documents文檔中查看DataLoader的使用方法。其中介紹幾個比較常用的參數(shù),例如 dataset,batch_size,shuffle,num_workers 和 drop_last。其中,batch_size表示在數(shù)據(jù)集容器中一次拿取多少數(shù)據(jù),shuffle表示是否在每次操作的時候打亂數(shù)據(jù)集,一般選擇為True。num_workers表示多線程進行拿取數(shù)據(jù)操作,0表示只在主線程中操作。drop_last表示如果拿取數(shù)據(jù)有余數(shù),是否保留最后剩下的部分。
dataset:就是用戶的數(shù)據(jù)集,構(gòu)建好數(shù)據(jù)集對象后傳入即可。
shuffle:是否在每次操作的時候打亂數(shù)據(jù)集,一般選擇為True。
num_workers: 多線程進行拿取數(shù)據(jù)操作,0表示只在主線程中操作。
drop_last:如果拿取數(shù)據(jù)有余數(shù),是否保留最后剩下的部分。
例如,在后面的代碼中,如果設(shè)置 drop_last=False,那么一共有156次數(shù)據(jù)拿取,并且最后一次剩余的部分不會被丟棄。如果設(shè)置 drop_last=True,那么最后剩余的部分被丟棄,并且拿取次數(shù)也少了一次。
初步使用的代碼如下:
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
test_data=torchvision.datasets.CIFAR10("dataset",train=False,transform=torchvision.transforms.ToTensor())
test_dataloader=DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=True)
writer=SummaryWriter("logs")
step=0
for data in test_dataloader:
images,targets=data
writer.add_images("test_03",images,step)
step=step+1
writer.close()
然后配合使用tensorboard就可以直觀體會到它的使用方法了。
小編推薦閱讀MethodTimer:一個輕量級的.NET運行耗時統(tǒng)計庫
閱讀構(gòu)建人工智能模型基礎(chǔ):TFDS和Keras的完美搭配
閱讀創(chuàng)建鴻蒙應(yīng)用的橫屏顯示直尺應(yīng)用全程解析
閱讀WiFi基礎(chǔ)(七):WiFi漫游與WiFi組網(wǎng)
閱讀遷移學(xué)習(xí):人工智能模型訓(xùn)練的絕學(xué)
閱讀如何使用 Pytorch 中的 DataSet 和 DataLoader
閱讀golang slice相關(guān)常見的性能優(yōu)化手段
閱讀連接Elasticsearch服務(wù)器的Python代碼示例
閱讀國產(chǎn)操作系統(tǒng)上實現(xiàn)RTMP推流攝像頭視頻和麥克風(fēng)聲音到流媒體服務(wù)器
閱讀本站所有軟件,都由網(wǎng)友上傳,如有侵犯你的版權(quán),請發(fā)郵件[email protected]
湘ICP備2022002427號-10 湘公網(wǎng)安備:43070202000427號© 2013~2024 haote.com 好特網(wǎng)