J'essaie d'ajouter une nouvelle couche à un réseau existant (en tant que première couche) et de l'entraîner sur l'entrée d'origine. Quand j'ajoute un calque convolutif, tout fonctionne parfaitement mais quand je le change en linéaire, il ne semble pas s'entraîner. Des idées pourquoi? Voici l'ensemble du réseau:Ajout d'une couche linéaire à un modèle existant sur Pytorch
class ActorCritic(torch.nn.Module): #original model
def __init__(self, num_inputs, action_space):
super(ActorCritic, self).__init__()
self.conv1 = nn.Conv2d(num_inputs, 32, 3, stride=2, padding=1)
self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
self.conv3 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
self.conv4 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
self.lstm = nn.LSTMCell(32 * 3 * 3, 256)
num_outputs = action_space.n
self.critic_linear = nn.Linear(256, 1)
self.actor_linear = nn.Linear(256, num_outputs)
def forward(self, inputs):
inputs, (hx, cx) = inputs
x = F.elu(self.conv1(inputs))
x = F.elu(self.conv2(x))
x = F.elu(self.conv3(x))
x = F.elu(self.conv4(x))
x = x.view(-1, 32 * 3 * 3)
hx, cx = self.lstm(x, (hx, cx))
x = hx
return self.critic_linear(x), self.actor_linear(x), (hx, cx)
class TLModel(torch.nn.Module): #new model
def __init__(self, pretrained_model, num_inputs):
super(TLModel, self).__init__()
self.new_layer = nn.Linear(1*1*42*42, 1*1*42*42)
self.pretrained_model = pretrained_model
def forward(self, inputs):
inputs, (hx, cx) = inputs
x = F.elu(self.new_layer(inputs.view(-1, 1*1*42*42)))
return self.pretrained_model((x.view(1,1,42,42), (hx, cx)))
J'ai essayé différentes fonctions d'activation (non seulement ELU). il fonctionne avec conv:
class TLModel(torch.nn.Module):
def __init__(self, pretrained_model, num_inputs):
super(TLModel, self).__init__()
self.new_layer = nn.Conv2d(num_inputs, num_inputs, 1)
self.pretrained_model = pretrained_model
def forward(self, inputs):
inputs, (hx, cx) = inputs
x = F.elu(self.new_layer(inputs))
return self.pretrained_model((x, (hx, cx)))
Le nombre d'entrées est 1 et la taille d'une entrée est 1x1x42x42
En fait, je ne reçois aucun message d'erreur, il ne semble pas s'entraîner avec linéaire (mais fait avec conv). –