2017-04-19 5 views
0

Aidez-moi s'il vous plaît! Je travaille sur un projet en utilisant deeplearning4j. L'exemple MNIST fonctionne très bien mais j'ai une erreur avec mon dataset. Mon jeu de données a deux sorties.Erreur de code en utilisant MNIST exemple de deeplearning4j

int height = 45; 
int width = 800; 
int channels = 1; 
int rngseed = 123; 
Random randNumGen = new Random(rngseed); 
int batchSize = 128; 
int outputNum = 2; 
int numEpochs = 15; 
File trainData = new File("C:/Users/JHP/Desktop/learningData/training"); 
File testData = new File("C:/Users/JHP/Desktop/learningData/testing"); 
FileSplit train = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen); 
FileSplit test = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen); 
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); 

ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker); 
ImageRecordReader recordReader2 = new ImageRecordReader(height, width, channels, labelMaker); 
recordReader.initialize(train); 
recordReader2.initialize(test); 

DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, outputNum); 
DataSetIterator testIter = new RecordReaderDataSetIterator(recordReader2, batchSize, 1, outputNum); 

DataNormalization scaler = new ImagePreProcessingScaler(0, 1); 
scaler.fit(dataIter); 
dataIter.setPreProcessor(scaler); 

System.out.println("Build model...."); 
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() 
     .seed(rngseed) 
     .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) 
     .iterations(1) 
     .learningRate(0.006) 
     .updater(Updater.NESTEROVS).momentum(0.9) 
     .regularization(true).l2(1e-4) 
     .list() 
     .layer(0, new DenseLayer.Builder() 
       .nIn(height * width) 
       .nOut(1000) 
       .activation(Activation.RELU) 
       .weightInit(WeightInit.XAVIER) 
       .build() 
       ) 
     .layer(1, newOutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD) 
       .nIn(1000) 
       .nOut(outputNum) 
       .activation(Activation.SOFTMAX) 
       .weightInit(WeightInit.XAVIER) 
       .build() 
       ) 
     .pretrain(false).backprop(true) 
     .build(); 

MultiLayerNetwork model = new MultiLayerNetwork(conf); 
model.init(); 
model.setListeners(new ScoreIterationListener(1)); 

System.out.println("Train model...."); 
for (int i = 0; i < numEpochs; i++) { 
    try { 
     model.fit(dataIter); 
    } catch (Exception e) { 
     System.out.println(e); 
    } 
} 

erreur est

org.deeplearning4j.exception.DL4JInvalidInputException: entrée qui est pas une matrice; matrice attendu (rang 2), obtenu tableau de rang 4 avec la forme [128, 1, 45, 800]

+0

Je pense qu'il est nécessaire de changer la fonction DataSetIterator à une autre fonction. Dans le cas de l'exemple MNIST, c'est comme importer des données dans une fonction. ** DataSetIterator mnistTrain = nouveau MnistDataSetIterator (batchSize, true, rngseed); ** Je ne sais pas quelle fonction utiliser. – user7887249

+0

@ TriV TriV Merci beaucoup de me laisser savoir ce qu'il faut améliorer! Je ne le savais pas parce que j'utilisais le débordement de pile pour la première fois. Merci beaucoup! – user7887249

Répondre

0

Vous initialiser le mauvais réseau de neurones. Si vous regardez de plus près chaque exemple cnn exemple dans le repo des exemples dl4j (indice: c'est la source canonique de l'endroit où vous devriez tirer le code, toute autre chose sera probablement invalide ou obsolète:) Vous remarquerez dans tous nos exemples, nous avons une configuration inputType: https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/convolution/LenetMnistExample.java#L114

Il existe différents types que vous devriez utiliser vous ne devriez jamais ensemble nDans manuellement. Juste nOut. Pour mnist, nous utilisons un appartement convolutif et le convertissons en un ensemble de données 4d automatiquement pour vous. Mnist commence comme un vecteur plat, mais un cnn ne comprend que les données 3d. Nous faisons cette transition et refaçonnons pour vous.