2017-03-08 5 views
2

Je suis en train de tester le modèle RNN de mxnet. Le tutoriel here ne fonctionne pas et le message d'erreur indique que de nombreuses fonctions ont été abandonnées. Je n'ai pas trouvé le tutoriel à jour pour RNN. Il existe encore quelques exemples dans le projet mxnet. Mais pour RNN, les examples montrent seulement comment former un modèle en utilisant un ensemble de formation. Ils ne montrent pas comment utiliser le modèle formé pour faire d'autres prédictions. Le code de formation est la suivante:mxnet: comment faire une prédiction en utilisant un modèle RNN formé

model.fit(
    train_data   = data_train, 
    eval_data   = data_val, 
    eval_metric   = mx.metric.Perplexity(invalid_label), 
    kvstore    = args.kv_store, 
    optimizer   = args.optimizer, 
    optimizer_params = { 'learning_rate': args.lr, 
          'momentum': args.mom, 
          'wd': args.wd }, 
    initializer   = mx.init.Xavier(factor_type="in", magnitude=2.34), 
    num_epoch   = args.num_epochs, 
    batch_end_callback = mx.callback.Speedometer(args.batch_size, args.disp_batches)) 

ce que quelqu'un sait comment utiliser le modèle formé RNN pour faire une inférence ou prédiction?

Je dois clearifier que je cherche comment utiliser RNN modèle pour faire la prédiction, pas CNN ou d'autres modèles.

Merci beaucoup de m'avoir aidé !!!

+0

https://github.com/dmlc/mxnet/blob/master/example/rnn/cudnn_lstm_bucketing.py a à la fois le code de train et d'essai . Est ce que ça aide? –

+2

Non. Mais les exemples dans https://github.com/dmlc/mxnet/tree/master/python/mxnet/module aident. – pfc

+1

@pfc Si vous avez trouvé la réponse, allez-vous répondre à votre propre question pour d'autres personnes qui pourraient avoir besoin de la même aide? – lynguyen

Répondre

1

Habituellement le modèle s'étend BaseModel classe. Et BaseModel a the method predict. La méthode peut fonctionner avec le même type que celui utilisé par la méthode fit: DataIter avec une seule différence, elle ne nécessite pas train_data, seulement eval_data. Ainsi, le processus de prédiction réelle peut être mis en œuvre de manière simple comme ceci:

result = mod.predict(dataiter.next)