2017-06-16 1 views
7

J'ai trois questions simples. Que se passe-t-il si ma fonction de perte personnalisée n'est pas différentiable?Fonction de perte personnalisée dans PyTorch

  1. Va pytorcher par erreur ou faire autre chose?
  2. Si je déclare une variable de perte dans ma fonction personnalisée qui représentera la perte finale du modèle, dois-je mettre requires_grad = True pour cette variable? ou ça n'a pas d'importance? Si cela n'a pas d'importance, alors pourquoi?
  3. J'ai vu des gens parfois écrire une couche distincte et calculer la perte dans la fonction forward. Quelle approche est préférable, écrire une fonction ou une couche? Pourquoi?

J'ai besoin d'une explication claire et agréable à ces questions pour résoudre mes confusions. S'il vous plaît aider.

Répondre

7

Laissez-moi tenter votre chance.

  1. Cela dépend de ce que vous entendez par "non différentiable". La première définition qui a du sens ici est que PyTorch ne sait pas comment calculer les dégradés. Si vous essayez néanmoins de calculer des dégradés, cela provoquera une erreur. Les deux scénarios possibles sont les suivants:

    a) Vous utilisez une opération PyTorch personnalisée pour laquelle des gradients n'ont pas été implémentés, par ex. torch.svd(). Dans ce cas, vous obtiendrez un TypeError:

    import torch 
    from torch.autograd import Function 
    from torch.autograd import Variable 
    
    A = Variable(torch.randn(10,10), requires_grad=True) 
    u, s, v = torch.svd(A) # raises TypeError 
    

    b) Vous avez mis en place votre propre opération, mais n'a pas défini backward(). Dans ce cas, vous obtiendrez un NotImplementedError:

    class my_function(Function): # forgot to define backward() 
    
        def forward(self, x): 
         return 2 * x 
    
    A = Variable(torch.randn(10,10)) 
    B = my_function()(A) 
    C = torch.sum(B) 
    C.backward() # will raise NotImplementedError 
    

    La deuxième définition qui fait sens est « mathématiquement non-différentiables ». Clairement, une opération mathématiquement non différentiable ne doit pas avoir une méthode backward() implémentée ou un sous-gradient sensible. Considérons par exemple torch.abs() dont backward() méthode renvoie le 0 à 0 sous-gradient:

    A = Variable(torch.Tensor([-1,0,1]),requires_grad=True) 
    B = torch.abs(A) 
    B.backward(torch.Tensor([1,1,1])) 
    A.grad.data 
    

    Pour ces cas, vous devez vous référer à la documentation PyTorch directement et creuser la méthode backward() de l'opération respective directement.

  2. Cela n'a pas d'importance. L'utilisation de requires_grad est d'éviter les calculs inutiles de gradients pour les sous-graphes. S'il existe une seule entrée pour une opération nécessitant un dégradé, sa sortie nécessite également un dégradé. Inversement, seulement si toutes les entrées ne nécessitent pas de gradient, la sortie ne l'exigera pas non plus. Le calcul arrière n'est jamais effectué dans les sous-graphes, où toutes les variables ne nécessitaient pas de gradients. Puisqu'il y a probablement des Variables (par exemple les paramètres d'une sous-classe de nn.Module()), votre variable loss exigera également des gradients automatiquement. Cependant, vous remarquerez que exactement pour comment fonctionne requires_grad (voir ci-dessus à nouveau), vous pouvez seulement changer requires_grad pour les variables de feuilles de votre graphique de toute façon.

  3. Toutes les fonctions de perte de PyTorch personnalisées sont des sous-classes de _Loss qui est une sous-classe de nn.Module.See here. Si vous souhaitez respecter cette convention, vous devez sous-classer _Loss lors de la définition de votre fonction de perte personnalisée. En dehors de la cohérence, un avantage est que votre sous-classe soulèvera un AssertionError, si vous n'avez pas marqué vos variables cibles comme volatile ou requires_grad = False. Un autre avantage est que vous pouvez imbriquer votre fonction de perte dans nn.Sequential(), parce que c'est un nn.Module je recommanderais cette approche pour ces raisons.

+0

De rien. Impossible d'ouvrir le lien malheureusement. – mexmex

+0

J'ai supprimé la question parce que je l'ai résolu. Mais pouvez-vous m'aider dans cette question - https://stackoverflow.com/questions/44580450/cuda-vs-dataparallel-why-the-difference? –