2017-08-22 4 views
0

Voici un modèle de jouet. J'imprime les paramètres du modèle avant d'appeler le backward exactement une fois, puis j'imprime à nouveau les paramètres du modèle. Les paramètres sont inchangés. Si j'ajoute la ligne model:updateParameters(<learning_rate>) après avoir appelé backward, je vois la mise à jour des paramètres.Torch: Comment les paramètres du modèle sont-ils mis à jour?

Mais dans l'exemple de code que j'ai vu, par exemple https://github.com/torch/demos/blob/master/train-a-digit-classifier/train-on-mnist.lua, personne n'appelle réellement updateParameters. En outre, il ne ressemble pas optim.sgd, optim.adam, ou nn.StochasticGradient jamais appeler updateParameters soit. Qu'est-ce que j'oublie ici? Comment les paramètres sont-ils mis à jour automatiquement? Si je dois appeler updateParameters, pourquoi aucun exemple ne fait cela?

require 'nn' 
require 'optim' 

local model = nn.Sequential() 
model:add(nn.Linear(4, 1, false)) 
local params, grads = model:getParameters() 

local criterion = nn.MSECriterion() 
local inputs = torch.randn(1, 4) 
local labels = torch.Tensor{1} 

print(params) 

model:zeroGradParameters() 
local output = model:forward(inputs) 
local loss = criterion:forward(output, labels) 
local dfdw = criterion:backward(output, labels) 
model:backward(inputs, dfdw) 

-- With the line below uncommented, the parameters are updated: 
-- model:updateParameters(1000) 

print(params) 

Répondre

1

Le backward() est pas censé modifier les paramètres, il calcule simplement les dérivées de la fonction d'erreur par rapport à tous les paramètres du réseau.

En général, la formation est la séquence des étapes:

repeat 
    local output = model:forward(input) --see what model predicts 
    local loss = criterion:forward(output, answer) --see how wrong it is 
    local loss_grad = criterion:backward(output, answer) --see where it is the most wrong 
    model:backward(input,loss_grad) --see how much each particular parameter of network is responsible for error 
    model:updateParameters(learningRate) --fix the parameters based on their wrongness 
    model:zeroGradParameters() --network parameters are different now, so old gradients are of no use now 
until is_user_satisfied() 

updateParameters implémente l'algorithme d'optimisation la plus simple ici (descente de gradient). Si vous le souhaitez, vous pouvez utiliser votre propre fonction à la place. En théorie, vous pouvez effectuer des boucles explicites à travers les stockages réseau pour mettre à jour leurs valeurs. Dans la pratique, vous appelez habituellement getParameters()

local model_parameters,model_parameters_gradient=model:getParameters() 

Ce qui vous donne tenseurs homogènes de toutes les valeurs et les gradients. Ces tenseurs sont des vues à l'intérieur du réseau, de sorte que les changements affectent le réseau. Vous ne pouvez pas savoir quel point du réseau correspond à quelle valeur, mais la plupart des optimiseurs ne s'en soucient pas.

L'utilisation demo de optim.sgd est la suivante:

optim.sgd(
    function_to_return_error_and_its_gradients, 
    model_parameters, 
    optimizer_special_settings) 

Les détails sont couverts dans la démo, mais ici, il est pertinent que optimiseur reçoit le model_parameters comme un paramètre qui donne un accès en écriture au réseau. Et il n'est pas explicitement indiqué dans la documentation, mais dans le source code on voit, que l'optimiseur change les valeurs de son tenseur d'entrée (notez aussi qu'il renvoie le tenseur reçu).

+0

Très complet, merci. Donc, pour réitérer, 'updateParameters' est utilisé pour mettre à jour" manuellement "les paramètres. Et les algorithmes d'optimisation mettent simplement à jour leur référence aux paramètres aplatis car l'appel de 'updateParameters' ne tient pas compte des optimisations. – gwg