2017-05-04 1 views
1

Ceci est probablement une question débutant, mais quand même: Lors de l'exécution d'un classificateur d'image construire avec pytorch, je reçois cette erreur:Pytorch, TypeError: objet() ne prend aucun paramètre

Traceback (most recent call last): 
File "/pytorch/kanji_torch.py", line 47, in <module> 
    network = Network() 
    File "/pytorch/kanji_torch.py", line 113, in __init__ 
    self.conv1 = nn.Conv2d(1, 32, 5) 
    File "/python3.5/site-packages/torch/nn/modules/conv.py", line 233, in __init__ 
    False, _pair(0), groups, bias) 
    File "/python3.5/site-packages/torch/nn/modules/conv.py", line 32, in __init__ 
    out_channels, in_channels // groups, *kernel_size)) 
TypeError: object() takes no parameters 

Je définit la classe de réseau comme this:

class Network(torch.nn.Module): 
    def __init__(self): 
     super(Network, self).__init__() 
     self.conv1 = nn.Conv2d(1, 32, 5) 
     self.pool = nn.MaxPool2d(2, 2) 
     self.conv2 = nn.Conv2d(32, 64, 5) 
     self.pool2 = nn.MaxPool2d(2, 2) 
     self.conv3 = nn.Conv2d(64, 64, 5) 
     self.pool2 = nn.MaxPool2d(2, 2) 
     self.fc1 = nn.Linear(64 * 5 * 5, 512) 
     self.fc2 = nn.Linear(512, 640) 
     self.fc3 = nn.Linear(640, 3756) 

Je suis à peu près certain d'avoir importé correctement tous les modules de bibliothèque pytorch pertinents. (import torch.nn comme nn et
torche importation )

Toute idée de ce que je fais mal?

Merci!

+0

Non, désolé que c'était une erreur, et je l'ai corrigé – Sumaku

+0

Traceback semble pointer le fichier '/ pytorch/blitz.py' (btw je suis intrigué par le fichier'/pytorch/.py'). Dans le code que vous avez collé, la classe s'appelle 'Network' mais la traceback parle de' Net'. Avez-vous mis à jour des choses avant de coller le code? – Arount

+0

qui était bâclée de mon côté. J'ai eu deux scipts avec le même problème. Je les ai fait commuter, et j'ai pensé que j'ai changé tous les mauvais noms. La négligence a été éditée maintenant. merci – Sumaku

Répondre

0

Vous pourriez avoir un problème avec votre version pytorch, quand je lance le code:

class Network(torch.nn.Module): 
    def __init__(self): 
     super(Network, self).__init__() 
     self.conv1 = nn.Conv2d(1, 32, 5) 
     self.pool = nn.MaxPool2d(2, 2) 
     self.conv2 = nn.Conv2d(32, 64, 5) 
     self.pool2 = nn.MaxPool2d(2, 2) 
     self.conv3 = nn.Conv2d(64, 64, 5) 
     self.pool2 = nn.MaxPool2d(2, 2) 
     self.fc1 = nn.Linear(64 * 5 * 5, 512) 
     self.fc2 = nn.Linear(512, 640) 
     self.fc3 = nn.Linear(640, 3756) 
print(network) 

La sortie est:

Network (
    (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1)) 
    (pool): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)) 
    (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1)) 
    (pool2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)) 
    (conv3): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1)) 
    (fc1): Linear (1600 -> 512) 
    (fc2): Linear (512 -> 640) 
    (fc3): Linear (640 -> 3756) 
) 

Je suggère de mettre à jour/réinstaller pytorch.

+0

Cela a fonctionné. Merci! J'ai eu un problème similaire avec torch.Variable, qui est également résolu maintenant. – Sumaku