2017-09-23 1 views
0

L'entrée de mon réseau est une image RVB avec des dimensions n m, comment puis-je obtenir la sortie pour avoir des dimensions de n m.n * m * 3 image d'entrée vers une étiquette nxm dans PyTorch

class Net(nn.Module): 
    def __init__(self): 
     super(Net, self).__init__() 
     self.conv1 = nn.Conv2d(3, 20, kernel_size = 5) 
     self.conv2 = nn.Conv2d(20, 50, kernel_size = 3) 
     self.conv3 = nn.ConvTranspose2d(50,20, kernel_size = 5) 
     self.conv4 = nn.ConvTranspose2d(20,1, kernel_size = 3) 

    def forward(self, x): 
     x = F.relu(self.conv1(x)) 
     x = F.relu(self.conv2(x)) 
     x = F.relu(self.conv3(x)) 
     x = F.relu(self.conv4(x)) 
    return x 

Je génère actuellement un 1 * n * m. Comment puis-je sortir un n * m?

Répondre

0

Si vous souhaitez remodeler un Tenseur en une taille différente mais avec le même nombre d'éléments, vous pouvez généralement utiliser torch.view.

Pour votre cas, il existe une solution encore plus simple: torch.squeeze renvoie un Tenseur avec toutes les dimensions de taille 1 supprimées.