2017-06-06 3 views
0

J'essaie de créer un réseau de neurones Feed Forward avec MXNetR. Mon entrée est une trame de données avec 6380 lignes et 180 colonnes. Mes sorties d'entraînement et de test sont des vecteurs unidimensionnels avec 319 éléments chacun. J'ai exécuté le modèle avec la taille de lot définie sur 1 et le nombre de neurones sur la couche de sortie défini sur 319. Donc, pour chaque lot, je m'attendais à obtenir un vecteur avec 319 éléments. Je vise à minimiser ma fonction de perte, qui est la corrélation entre mon vecteur de sortie prédit et le vecteur de sortie réel.Erreur liée à la forme des données lors de l'exécution de MXNetR

Ci-dessous mon code:

# Define the input data 
    data <- mx.symbol.Variable("data") 

    # Define the first fully connected layer 
    fc1 <- mx.symbol.FullyConnected(data, num_hidden = 100) 
    act.fun <- mx.symbol.Activation(fc1, act_type = "relu") # create a hidden layer with Rectified Linear Unit as its activation function. 
    output <<- mx.symbol.FullyConnected(act.fun, num_hidden = 319) 

    # Customize loss function 
    label <- mx.symbol.Variable("label") 
    lro <- 
     mx.symbol.MakeLoss(mx.symbol.Correlation(mx.symbol.reshape(output 
    ,shape = (1,319)),label)) 

    model <- mx.model.FeedForward.create(symbol=lro, X=train.x, 
             y=train.y, 
             eval.data = list(data = test.x, 
                 label = test.y), 
             num.round=5000, 
             array.batch.size=1, 
             optimizer = "adam", 
             learning.rate = 0.0003, 
             eval.metric = mx.metric.rmse, 
             epoch.end.callback = 
             mx.callback.log.train.metric(20, logger)) 

Et voici l'erreur quand je lance le code ci-dessus:

[15:49:28] /home/cgagnon/src/q5/mxnet/dmlc-core/include/dmlc/./logging.h:304: [15:49:28] src/operator/./correlation-inl.h:176: Check failed: dshape1.ndim() == 4U (2 vs. 4) data should be a 4D tensor 

Stack trace returned 10 entries: 
[bt] (0) /usr/lib64/R/library/mxnet/libs/libmxnet.so(_ZN4dmlc15LogMessageFatalD1Ev+0x29) [0x7f725a8528b9] 
[bt] (1) /usr/lib64/R/library/mxnet/libs/libmxnet.so(_ZNK5mxnet2op15CorrelationProp10InferShapeEPSt6vectorIN4nnvm6TShapeESaIS4_EES7_S7_+0x2a2) [0x7f725b4a8222] 
[bt] (2) /usr/lib64/R/library/mxnet/libs/libmxnet.so(+0xd461f9) [0x7f725b3241f9] 
[bt] (3) /usr/lib64/R/library/mxnet/libs/libmxnet.so(+0x116630f) [0x7f725b74430f] 
[bt] (4) /usr/lib64/R/library/mxnet/libs/libmxnet.so(+0x1167bb2) [0x7f725b745bb2] 
[bt] (5) /usr/lib64/R/library/mxnet/libs/libmxnet.so(_ZN4nnvm11ApplyPassesENS_5GraphERKSt6vectorISsSaISsEE+0x501) [0x7f725b761481] 
[bt] (6) /usr/lib64/R/library/mxnet/libs/libmxnet.so(_ZN4nnvm9ApplyPassENS_5GraphERKSs+0x8e) [0x7f725b699f2e] 
[bt] (7) /usr/lib64/R/library/mxnet/libs/libmxnet.so(_ZN4nnvm4pass10InferShapeENS_5GraphESt6vectorINS_6TShapeESaIS3_EESs+0x240) [0x7f725b69c520] 
[bt] (8) /usr/lib64/R/library/mxnet/libs/libmxnet.so(MXSymbolInferShape+0x281) [0x7f725b6959a1] 
[bt] (9) /usr/lib64/R/library/mxnet/libs/mxnet.so(_ZNK5mxnet1R6Symbol10InferShapeERKN4Rcpp6VectorILi19ENS2_15PreserveStorageEEE+0x6b9) [0x7f724cef6739] 

En ce moment, je suis la moindre idée de la façon dont je devrais corriger cette erreur. J'ai cherché un moyen de remodeler mes ensembles de données afin qu'ils soient des tenseurs 4D mais n'en ont trouvé aucun. Je ne cherche pas une solution explicite pour mon problème, mais toutes les suggestions sur la façon dont je devrais aborder cette erreur seraient grandement appréciées.

Répondre

0

Je ne pouvais pas reproduire le problème sans les données, mais je pense que si vous cherchez simplement à remodeler votre jeu de données dans des tenseurs 4D, vous devriez pouvoir le faire par "symbol.reshape (output, shape = c (1,1,1,319)) ". Je ne sais pas si cela vous aide.

+0

J'ai changé mon code comme vous l'avez suggéré mais le même type d'erreur apparaît toujours. Pour une raison de confidentialité, je ne peux pas partager mon ensemble de données avec vous, mais je crois que l'erreur réside dans la dimension de l'ensemble de données, pas le contenu. – nnguyen24