J'ai construit un projet DL4j. Tout va bien si je DataSet MNIST comme suit:DeepLearning4j et DataVec lire le fichier csv avec l'étiquette
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);
Cependant, je veux passer à mon propre fichier csv avec le format suivant:
A | B | C | X | Y
-------------------------
1 | 100 | 5 | 15 | 6
...
X
et Y
sont les résultats (ou les étiquettes). Comme je prévois d'effectuer une analyse de régression, les deux X
et Y
sont des nombres réels. Donc, je l'ai lu le fichier csv en utilisant le code suivant:
RecordReader recordReaderTrain = new CSVRecordReader(1, ",");
recordReaderTrain.initialize(new FileSplit(new File("src/main/resources/data/Data.csv")));
DataSetIterator dataIterTrain = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 3, 2);
3
dans le code signifie index of the labels
et 2
signifie number of possible labels
. Il n'y a pas beaucoup d'explications sur ces deux paramètres. Je suppose qu'ils veulent dire que les étiquettes commencent à partir de la 4ème colonne et ont 2 étiquettes.
Quand je lance le code, il montre l'exception suivante:
Exception in thread "main" java.lang.ArrayIndexOutOfBoundsException: 14
Je pense qu'il est parce que dl4j ne reconnaît pas 15
comme étiquette. Donc, ma question est: comment puis-je lire correctement le fichier csv pour une analyse de régression?
Merci beaucoup.
Nous vous remercions de votre réponse. Ça fonctionne bien. Puis-je poser une autre question? Comment est-ce que je peux convertir le DataSet dans le DataSetIterator dans une entrée du type '5' par' 3' de graphique de sorte que je puisse employer le réseau de convolution? –
Esprit publiant une deuxième question avec plus de détails? Merci! –