2016-02-25 1 views
2

Je m'entraîne sur l'apprentissage du réseau de neurones. Il y a une fonction que je ne peux pas faire apprendre à mon réseau de neurones: f(x) = max(x_1, x_2). Il semble être une fonction très simple avec 2 entrées et 1 entrée, mais un réseau de neurones à 3 couches formé à plus d'un millier d'échantillons avec 2000 époques le fait complètement faux. J'utilise deeplearning4j.Régresser la fonction max sur un réseau de neurones

Y a-t-il une raison pour laquelle la fonction max serait très difficile à apprendre pour un réseau de neurones ou est-ce que je ne l'ai pas réglé correctement?

Répondre

0

Ce n'est pas si difficile, du moins, si vous restreignez x1 et x2 pour être dans un intervalle, par ex. entre [0,3]. Pris le « RegressionSum » exemple des exemples de deeplearning4j je me suis vite réécrite pour apprendre max somme à la place et il fonctionne très bien me donner que les résultats:

Max(0.6815540048808918,0.3112081053899819) = 0.64 
Max(2.0073597506364407,1.93796211086664) = 2.09 
Max(1.1792029272560556,2.5514324329058233) = 2.58 
Max(2.489185375059013,0.0818746888836388) = 2.46 
Max(2.658169689797984,1.419135581889197) = 2.66 
Max(2.855509810112818,2.9661811672685086) = 2.98 
Max(2.774757710538552,1.3988513143140069) = 2.79 
Max(1.5852295273047565,1.1228662895771744) = 1.56 
Max(0.8403435207065576,2.5595015474951195) = 2.60 
Max(0.06913178775631723,2.61883825802004) = 2.54 

Ci-dessous mon version modifiée de l'exemple RegressionSum, qui était à l'origine de Anwar 15/03/16:

public class RegressionMax { 
    //Random number generator seed, for reproducability 
    public static final int seed = 12345; 
    //Number of iterations per minibatch 
    public static final int iterations = 1; 
    //Number of epochs (full passes of the data) 
    public static final int nEpochs = 200; 
    //Number of data points 
    public static final int nSamples = 10000; 
    //Batch size: i.e., each epoch has nSamples/batchSize parameter updates 
    public static final int batchSize = 100; 
    //Network learning rate 
    public static final double learningRate = 0.01; 
    // The range of the sample data, data in range (0-1 is sensitive for NN, you can try other ranges and see how it effects the results 
    // also try changing the range along with changing the activation function 
    public static int MIN_RANGE = 0; 
    public static int MAX_RANGE = 3; 

    public static final Random rng = new Random(seed); 

    public static void main(String[] args){ 

     //Generate the training data 
     DataSetIterator iterator = getTrainingData(batchSize,rng); 

     //Create the network 
     int numInput = 2; 
     int numOutputs = 1; 
     int nHidden = 10; 
     MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder() 
       .seed(seed) 
       .iterations(iterations) 
       .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) 
       .learningRate(learningRate) 
       .weightInit(WeightInit.XAVIER) 
       .updater(Updater.NESTEROVS).momentum(0.9) 
       .list() 
       .layer(0, new DenseLayer.Builder().nIn(numInput).nOut(nHidden) 
         .activation("tanh") 
         .build()) 
       .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE) 
         .activation("identity") 
         .nIn(nHidden).nOut(numOutputs).build()) 
       .pretrain(false).backprop(true).build() 
     ); 
     net.init(); 
     net.setListeners(new ScoreIterationListener(1)); 


     //Train the network on the full data set, and evaluate in periodically 
     for(int i=0; i<nEpochs; i++){ 
      iterator.reset(); 
      net.fit(iterator); 
     } 

     // Test the max of some numbers (Try different numbers here) 
     Random rand = new Random(); 
     for (int i= 0; i< 10; i++) { 
      double d1 = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble(); 
      double d2 = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble(); 
      INDArray input = Nd4j.create(new double[] { d1, d2 }, new int[] { 1, 2 }); 
      INDArray out = net.output(input, false); 
      System.out.println("Max(" + d1 + "," + d2 + ") = " + out); 
     } 

    } 

    private static DataSetIterator getTrainingData(int batchSize, Random rand){ 
     double [] max = new double[nSamples]; 
     double [] input1 = new double[nSamples]; 
     double [] input2 = new double[nSamples]; 
     for (int i= 0; i< nSamples; i++) { 
      input1[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble(); 
      input2[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble(); 
      max[i] = Math.max(input1[i], input2[i]); 
     } 
     INDArray inputNDArray1 = Nd4j.create(input1, new int[]{nSamples,1}); 
     INDArray inputNDArray2 = Nd4j.create(input2, new int[]{nSamples,1}); 
     INDArray inputNDArray = Nd4j.hstack(inputNDArray1,inputNDArray2); 
     INDArray outPut = Nd4j.create(max, new int[]{nSamples, 1}); 
     DataSet dataSet = new DataSet(inputNDArray, outPut); 
     List<DataSet> listDs = dataSet.asList(); 
     Collections.shuffle(listDs,rng); 
     return new ListDataSetIterator(listDs,batchSize); 

    } 
}