2017-07-20 2 views
0

Lorsque je tente de comprendre ce qui est à l'intérieur torchvision.datasets.cifar.CIFAR10, j'ai fait quelques code simpleEst-ce que torchvision.datasets.cifar.CIFAR10 est une liste ou pas?

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 
            download=True, transform=transform) 
print(trainset[1]) 
print(trainset[:10]) 
print(type(trainset)) 

Cependant, je suis une erreur lorsque je tente

print(trainset[:10]) 

Les informations d'erreur est

TypeError: Cannot handle this data type 

Je me demande pourquoi je peux utiliser trainset[1], mais pas trainset[:10]?

Répondre

0

Le découpage n'est pas pris en charge par CIFAR10, c'est pourquoi vous obtenez cette erreur. Si vous voulez les 10 premiers, vous devrez le faire à la place:

print([trainset[i] for i in range(10)]) 

Plus d'infos

La principale raison pour laquelle vous pouvez indexer une instance de la classe CIFAR10 est parce que la classe implémente la fonction __getitem__().

Ainsi, lorsque vous appelez trainset[i] vous appelez essentiellement trainset.__getitem__(i)

Maintenant, dans python3, tranchage expressions est également traité par __getitem__() où l'expression de découpage en tranches est passé à __getitem__() comme un objet de tranche.

Ainsi, trainset[2:10] est équivalent à trainset.__getitem__(slice(2, 10))

Et puisque les deux types d'objets différents étant passés à __getitem__ devraient faire tout à fait différentes choses, vous devez les traiter explicitement.

Malheureusement, il ny a pas, comme vous pouvez le voir dans la mise en œuvre de la méthode __getitem__ de classe CIFAR10:

def __getitem__(self, index): 
    if self.train: 
     img, target = self.train_data[index], self.train_labels[index] 
    else: 
     img, target = self.test_data[index], self.test_labels[index] 

    # doing this so that it is consistent with all other datasets 
    # to return a PIL Image 
    img = Image.fromarray(img) 

    if self.transform is not None: 
     img = self.transform(img) 

    if self.target_transform is not None: 
     target = self.target_transform(target) 

    return img, target 
+0

exactement, je ne l'ai essayé avant. Je suis juste curieux quel est le type de torchvision.datasets.cifar.CIFAR10? – davidwangv5

+0

a ajouté plus de détails – entrophy

+0

Oh, je veux dire que cette explication a beaucoup de sens pour moi maintenant, merci beaucoup! – davidwangv5