2017-10-11 9 views
0

J'essaie de calculer une passe avant en utilisant un modèle ResNet pré-formé dans pytorch. J'ai du mal à créer un Tenseur de 4-D en mini-lots. Quelqu'un peut-il dire s'il vous plaît quelle est la bonne façon de le faire?comment faire efficacement un mini-lot d'images dans pytorch?

EDIT: J'ai changé le code et cela fonctionne maintenant. Cependant, je pense toujours qu'il devrait y avoir un moyen plus efficace de le faire.

Voici mon code:

import pickle 
import json 
import shutil 
import Image 
import torchvision.models as models 
import torchvision.transformers as transformers 
from torch.autograd import Variable 
from torch import Tensor 
import glob 
import torch 

batch_size = 128 
im_size = 299 

normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406], 
    std=[0.229, 0.224, 0.225] 
) 
preprocess = transforms.Compose([ 
    transforms.Scale(im_size), 
    transforms.CenterCrop(im_size), 
    transforms.ToTensor(), 
    normalize 
]) 


model = models.resnet50(pretrained=True) 

d_batch = make_batch(imgs, batch_size) 

dtype = torch.FloatTensor 
tmp = Variable(torch.randn(batch_size, 3, im_size, im_size).type(dtype), requires_grad=False) 


for batch in tqdm(batches): 
     try: 
       data = [Image.open(img) for img in batch] 
       for idx, item in enumerate(data): 
         tmp[idx] = preprocess(item) 
       batch_result = model(tmp) 
     except Exception,x: 
       print x 
+1

Vous pouvez créer un 4d Tensor comme ceci: torch.Tensor (1,1,1,1). Ou pour ajouter une dimension à n'importe quel tenseur (ou variable), vous pouvez faire t.unsqueeze (0). Mais je ne sais pas comment cela va vous aider. Vous devez nous donner l'erreur, ou plus d'indices où vous êtes coincé. – blckbird

+0

Bienvenue dans StackOverflow. Veuillez lire et suivre les consignes de publication dans la documentation d'aide. [Exemple minimal, complet, vérifiable] (http://stackoverflow.com/help/mcve) s'applique ici. Nous ne pouvons pas vous aider efficacement tant que vous n'afficherez pas votre code MCVE et que vous ne décrivez pas précisément le problème. Nous devrions pouvoir coller votre code posté dans un fichier texte et reproduire le problème que vous avez décrit. – Prune

+1

Avez-vous essayé DataLoader (vous pouvez le trouver dans torch.utils.data) dans pytorch ?? Il fait des minibatches pour vous en utilisant le multitraitement – Kashyap

Répondre

0

En utilisant dataset = torchvision.datasets.ImageFolder(...) vous pouvez charger un jeu de données à partir du dossier d'image. Après cela, vous pouvez utiliser torch.utils.data.DataLoader(dataset, batch_size=batchSize) pour spécifier la taille mini-lot et d'autres choses pour un traitement ultérieur.