2017-08-07 3 views
0

J'ai construit un projet DL4j. Tout va bien si je DataSet MNIST comme suit:DeepLearning4j et DataVec lire le fichier csv avec l'étiquette

DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed); 
    DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed); 

Cependant, je veux passer à mon propre fichier csv avec le format suivant:

A | B | C | X | Y 
------------------------- 
1 | 100 | 5 | 15 | 6 
... 

X et Y sont les résultats (ou les étiquettes). Comme je prévois d'effectuer une analyse de régression, les deux X et Y sont des nombres réels. Donc, je l'ai lu le fichier csv en utilisant le code suivant:

RecordReader recordReaderTrain = new CSVRecordReader(1, ","); 
    recordReaderTrain.initialize(new FileSplit(new File("src/main/resources/data/Data.csv"))); 
    DataSetIterator dataIterTrain = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 3, 2); 

3 dans le code signifie index of the labels et 2 signifie number of possible labels. Il n'y a pas beaucoup d'explications sur ces deux paramètres. Je suppose qu'ils veulent dire que les étiquettes commencent à partir de la 4ème colonne et ont 2 étiquettes.

Quand je lance le code, il montre l'exception suivante:

Exception in thread "main" java.lang.ArrayIndexOutOfBoundsException: 14 

Je pense qu'il est parce que dl4j ne reconnaît pas 15 comme étiquette. Donc, ma question est: comment puis-je lire correctement le fichier csv pour une analyse de régression?

Merci beaucoup.

Répondre

1

droit que nous avons des exemples de régression: https://github.com/deeplearning4j/dl4j-examples/tree/cc383de91bdf4e28e36859aa2e8749100cd63177/dl4j-examples/src/main/java/org/deeplearning4j/examples/feedforward/regression

Vous devez passer une régression vraie (c'est une partie supplémentaire du constructeur) au RecordReaderDataSetIterator.

+0

Nous vous remercions de votre réponse. Ça fonctionne bien. Puis-je poser une autre question? Comment est-ce que je peux convertir le DataSet dans le DataSetIterator dans une entrée du type '5' par' 3' de graphique de sorte que je puisse employer le réseau de convolution? –

+0

Esprit publiant une deuxième question avec plus de détails? Merci! –