2017-08-16 4 views
1

Je suis la description et l'exemple de création d'un itérateur personnalisé décrits ici : http://mxnet.io/tutorials/basic/data.html — mais l'exemple d'itérateur personnalisé ne fonctionne pas.

Le code suivant provoque une erreur ValueError :

mod.fit(data_iter, num_epoch=5) 

ValueError: Shape of labels 0 does not match shape of predictions 1

Mes questions:

  • Quelqu'un peut-il reproduire ce problème?
  • Est-ce que quelqu'un connaît une solution?

J'utilise Jupyter sur un Mac où tout est fraîchement installé, y compris Python. J'ai également testé directement avec l'interpréteur Python :

Python 3.6.1 |Anaconda custom (x86_64)| (default, May 11 2017, 13:04:09) 
[GCC 4.2.1 Compatible Apple LLVM 6.0 (clang-600.0.57)] on darwin 

code:

import mxnet as mx 
import os 
import subprocess 
import numpy as np 
import matplotlib.pyplot as plt 
import tarfile 

import warnings 
warnings.filterwarnings("ignore", category=DeprecationWarning) 


import numpy as np 
# Build a small random dataset: 100 samples with 3 features each, plus one
# integer label per sample drawn from [0, 10).
data = np.random.rand(100,3)
label = np.random.randint(0, 10, (100,))
# Wrap the in-memory arrays in an NDArrayIter with batches of 30 samples.
data_iter = mx.io.NDArrayIter(data=data, label=label, batch_size=30)
for batch in data_iter:
    # Each DataBatch carries a list of data arrays, a list of label arrays,
    # and `pad`, the number of padding samples in this batch.
    print([batch.data, batch.label, batch.pad])

# Output: 100 samples / 30 per batch -> 4 batches; the last one is padded
# with 4 * 30 - 100 = 20 samples.
#[[<NDArray 30x3 @cpu(0)>], [<NDArray 30 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], [<NDArray 30 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], [<NDArray 30 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], [<NDArray 30 @cpu(0)>], 20]


# Let's save `data` into a CSV file first and try reading it back.
# Note: 'data.csv' is written to the current working directory.
np.savetxt('data.csv', data, delimiter=',')
# Only `data_csv` is given (no label file), so only data and pad are printed.
data_iter = mx.io.CSVIter(data_csv='data.csv', data_shape=(3,), batch_size=30)
for batch in data_iter:
    print([batch.data, batch.pad])

# Output: same batching and padding as the NDArrayIter run above.
#[[<NDArray 30x3 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], 20]

class SimpleIter(mx.io.DataIter):
    """Custom MXNet data iterator that yields a fixed number of generated batches.

    Each batch is built by calling one generator function per data/label
    input with that input's shape, and wrapping the result in ``mx.nd.array``.

    Parameters
    ----------
    data_names : sequence of str
        Names of the data inputs (e.g. ``['data']``).
    data_shapes : sequence of tuple
        Shape of each data input, in the same order as ``data_names``.
    data_gen : sequence of callable
        One function per data input; it is called with the input's shape and
        must return something ``mx.nd.array`` accepts.
    label_names, label_shapes, label_gen
        Same as the three parameters above, for the label inputs.
    num_batches : int, optional
        Number of batches produced before ``StopIteration`` (default 10).
    """

    def __init__(self, data_names, data_shapes, data_gen,
                 label_names, label_shapes, label_gen, num_batches=10):
        # Materialize the (name, shape) pairs as lists. In Python 3 ``zip``
        # returns a one-shot iterator: once something has read
        # ``provide_data`` / ``provide_label`` (``Module.fit`` does so when
        # binding), a bare zip object is exhausted. ``next()`` would then
        # build empty data/label lists, which surfaces as
        # "Shape of labels 0 does not match shape of predictions 1".
        # (Python 2's ``zip`` returns a list, which is why the bug only
        # appeared on Python 3.)
        self._provide_data = list(zip(data_names, data_shapes))
        self._provide_label = list(zip(label_names, label_shapes))
        self.num_batches = num_batches
        # Same reasoning: the generator sequences are re-iterated on every
        # batch, so they must be re-iterable.
        self.data_gen = list(data_gen)
        self.label_gen = list(label_gen)
        self.cur_batch = 0

    def __iter__(self):
        return self

    def reset(self):
        """Rewind the batch counter so iteration starts over."""
        self.cur_batch = 0

    def __next__(self):
        # Python 3 iterator protocol; delegates to the Python 2-style ``next``.
        return self.next()

    @property
    def provide_data(self):
        """List of ``(name, shape)`` pairs describing the data inputs."""
        return self._provide_data

    @property
    def provide_label(self):
        """List of ``(name, shape)`` pairs describing the label inputs."""
        return self._provide_label

    def next(self):
        """Return the next generated ``mx.io.DataBatch``.

        Raises
        ------
        StopIteration
            Once ``num_batches`` batches have been produced since the last
            ``reset()``.
        """
        if self.cur_batch >= self.num_batches:
            raise StopIteration
        self.cur_batch += 1
        data = [mx.nd.array(gen(shape))
                for (_, shape), gen in zip(self._provide_data, self.data_gen)]
        label = [mx.nd.array(gen(shape))
                 for (_, shape), gen in zip(self._provide_label, self.label_gen)]
        return mx.io.DataBatch(data, label)


import mxnet as mx 
# Number of output classes; also used by the label generator further down.
num_classes = 10
# Small MLP: data -> FullyConnected(64) -> ReLU -> FullyConnected(num_classes) -> softmax.
net = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=64)
net = mx.sym.Activation(data=net, name='relu1', act_type="relu")
net = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=num_classes)
# The SoftmaxOutput named 'softmax' adds an implicit 'softmax_label' input
# (visible in list_arguments() below); the data iterator's label name must match it.
net = mx.sym.SoftmaxOutput(data=net, name='softmax')
print(net.list_arguments())
print(net.list_outputs())

# Output:
#['data', 'fc1_weight', 'fc1_bias', 'fc2_weight', 'fc2_bias', 'softmax_label']
#['softmax_output']



import logging 
# Log at INFO level so messages logged during training are shown.
logging.basicConfig(level=logging.INFO)

# Batch size: each batch has n samples of 100 uniform features in [-1, 1)
# and n integer labels in [0, num_classes).
n = 32
# The input names 'data' and 'softmax_label' match the symbol's arguments.
# NOTE: SimpleIter must keep provide_data / provide_label as lists, not bare
# Python 3 zip iterators. Otherwise fit() fails with
# "Shape of labels 0 does not match shape of predictions 1".
data_iter = SimpleIter(['data'], [(n, 100)],
        [lambda s: np.random.uniform(-1, 1, s)],
        ['softmax_label'], [(n,)],
        [lambda s: np.random.randint(0, num_classes, s)])

mod = mx.mod.Module(symbol=net)
# Train for 5 epochs; each epoch uses SimpleIter's default num_batches (10).
mod.fit(data_iter, num_epoch=5)

Erreur:

ValueError        Traceback (most recent call last) 
<ipython-input-57-6ceb7dd11508> in <module>() 
     9 
     10 mod = mx.mod.Module(symbol=net) 
     ---> 11 mod.fit(data_iter, num_epoch=5)
/Users/bernd/anaconda/lib/python3.6/site-packages/mxnet/module/base_module.py in fit(self, train_data, eval_data, eval_metric, epoch_end_callback, batch_end_callback, kvstore, optimizer, optimizer_params, eval_end_callback, eval_batch_end_callback, initializer, arg_params, aux_params, allow_missing, force_rebind, force_init, begin_epoch, num_epoch, validation_metric, monitor) 
      493      end_of_batch = True 
      494 
     --> 495     self.update_metric(eval_metric, data_batch.label) 
      496 
      497     if monitor is not None: 
/Users/bernd/anaconda/lib/python3.6/site-packages/mxnet/module/module.py in update_metric(self, eval_metric, labels) 
      678    Typically ``data_batch.label``. 
      679   """ 
     --> 680   self._exec_group.update_metric(eval_metric, labels) 
      681 
      682  def _sync_params_from_devices(self): 
/Users/bernd/anaconda/lib/python3.6/site-packages/mxnet/module/executor_group.py in update_metric(self, eval_metric, labels) 
      561    labels_ = OrderedDict(zip(self.label_names, labels_slice)) 
      562    preds = OrderedDict(zip(self.output_names, texec.outputs)) 
     --> 563    eval_metric.update_dict(labels_, preds) 
      564 
      565  def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group): 
/Users/bernd/anaconda/lib/python3.6/site-packages/mxnet/metric.py in update_dict(self, label, pred) 
      89    label = label.values() 
      90 
     ---> 91   self.update(label, pred) 
      92 
      93  def update(self, labels, preds): 
/Users/bernd/anaconda/lib/python3.6/site-packages/mxnet/metric.py in update(self, labels, preds) 
      369    Predicted values. 
      370   """ 
     --> 371   check_label_shapes(labels, preds) 
      372 
      373   for label, pred_label in zip(labels, preds): 
/Users/bernd/anaconda/lib/python3.6/site-packages/mxnet/metric.py in check_label_shapes(labels, preds, shape) 
      22  if label_shape != pred_shape: 
      23   raise ValueError("Shape of labels {} does not match shape of " 
     ---> 24       "predictions {}".format(label_shape, pred_shape)) 
      25 
      26 
ValueError: Shape of labels 0 does not match shape of predictions 1 
+0

Je n'arrive pas à reproduire cette erreur quand j'exécute votre code. – gobrewers14

+0

@gobrewers14 : pouvez-vous m'indiquer votre configuration ? Quelles versions et quel OS utilisez-vous ? – user224637

Répondre

0

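C'est un problème lié à la version de Python. J'ai réussi à le faire fonctionner en installant et en compilant tout sous Python 2.7. Les versions 3.x de Python semblent provoquer le problème, et le message d'erreur n'est pas vraiment utile... (Explication : en Python 3, `zip` renvoie un itérateur à usage unique. `provide_data` et `provide_label` sont donc épuisés dès leur première lecture par `fit`, et `next()` produit ensuite des listes de données et d'étiquettes vides. Stocker `list(zip(...))` dans `__init__` corrige le problème sous Python 3.)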