2017-08-05 1 views
0

J'ai un réseau avec trois branches parallèles, et je veux partager tous leurs paramètres afin qu'ils soient identiques à la fin de la formation. Soit some_model soit un module nn.Sequential standard constitué de cudnn.SpatialConvolution, nn.PReLU, nn.SpatialBatchNormalization. De plus, il existe un nn.SpatialDropout, mais sa probabilité est 0, donc il n'a aucun effet.Partage de paramètres en réseau avec nn.SpatialBatchNormalization

ptb=nn.ParallelTable() 
ptb:add(some_model) 
ptb:add(some_model:clone('weight','bias', 'gradWeight','gradBias')) 
ptb:add(some_model:clone('weight','bias', 'gradWeight','gradBias')) 

triplet=nn.Sequential() 
triplet:add(ptb) 

Je ne pense pas que la fonction de perte est pertinente, mais juste au cas où, j'utilise nn.DistanceRatioCriterion. Pour vérifier que tous les poids sont correctement partagés, je passe une table de trois exemples identiques {A,A,A} au réseau. Évidemment, si les poids sont correctement partagés, la sortie des trois branches doit être la même. Cela vaut au moment de l'initialisation du réseau, mais une fois que les paramétreurs ont été mis à jour (disons, après une itération de mini-lot), les résultats des trois branches deviennent différents. Grâce à l'inspection couche par couche, j'ai remarqué que cette différence dans la sortie provient des couches nn.SpatialBatchNormalization en some_model. Par conséquent, il semble que les paramètres de ces couches ne sont pas partagés correctement. Après this, j'ai essayé d'appeler clone avec les paramètres supplémentaires running_mean et running_std, mais la sortie des couches batchnorm diffère encore. De plus, cela semble annuler le partage de tous les autres paramètres du réseau. Quelle est la bonne façon de partager les paramètres entre les modules nn.SpatialBatchNormalization?

Répondre

2

Ok, j'ai trouvé la solution! Il semble que le paramètre running_std a été modifié en running_var depuis the discussion I had linked to in the question. Appeler le constructeur avec

ptb:add(some_model:clone('weight','bias', 'gradWeight','gradBias','running_mean','running_var')) 

Résout le problème.