2017-10-15 17 views
0

Je voulais voir que les images que j'ai utilisé dans mon réseau étaient OK, donc je sauvé un tas d'entre eux en utilisant le code suivant:Est-ce que le chargeur torche MNIST ne fonctionne pas correctement ou est-ce que je fais quelque chose de mal?

train_set = dset.MNIST(root=root, train=True, transform=transforms.ToTensor(), download=download) 

for it, (img, target) in enumerate(train_loader): 
    X = Variable(img) 
    tar = Variable(target) 
    X = X.view(batch_size, -1) 
    cur_img_batch = X.data.numpy() 
    cur_tar_batch = tar.data.numpy() 
    for i in range(batch_size): 
     cur_img = cur_img_batch[i] 
     im = Image.fromarray(cur_img.reshape((28, 28)).astype('uint8') * 255) 
     if cur_tar_batch[i] == 8: 
      im.save(test_img_dir + 'iter_' + str(it) + '_sample_' + str(i) + '.png') 

Ce n'est pas le code plus propre, mais il sauve tout un tas des images qui sont toutes étiquetées comme '8'. En les ouvrant, je vois que la plupart d'entre eux ressemblent à this, même si une petite minorité d'entre eux sont parfaitement fine.

Ai-je fait quelque chose de mal?

+0

Cette ligne 'cur_img.reshape ((28, 28)). astype ('uint8') * 255'-vous convertir les données en entier avant de multiplier par 255? –

+0

Bien sûr! C'était bien - merci beaucoup :) –

+0

La ligne correcte devrait être: im = Image.fromarray ((cur_img.reshape ((28, 28)) * 255) .astype ('uint8')) –

Répondre

0

D'après les commentaires:

La question était dans cette ligne cur_img.reshape((28, 28)).astype('uint8') * 255, convertir l'image normalisée à entier avant de multiplier par 255, entraînant ainsi des images avec 0 ou 255.

Le code mis à jour:

train_set = dset.MNIST(root=root, train=True, transform=transforms.ToTensor(), download=download) 

for it, (img, target) in enumerate(train_loader): 
    X = Variable(img) 
    tar = Variable(target) 
    X = X.view(batch_size, -1) 
    cur_img_batch = X.data.numpy() 
    cur_tar_batch = tar.data.numpy() 
    for i in range(batch_size): 
     cur_img = cur_img_batch[i] 
     im = Image.fromarray((cur_img.reshape((28, 28)) * 255).astype('uint8')) 
     if cur_tar_batch[i] == 8: 
      im.save(test_img_dir + 'iter_' + str(it) + '_sample_' + str(i) + '.png')