2017-10-20 6 views
0

Ceci est un code python pour générer un modèle MLP simple multi-ports-deux entrées, deux sorties .Utilisation des fonctions HybridBlock et erport permettent d'utiliser en C++.Pourquoi le modèle multi-ports ne peut pas importer en C++?

Graphique net:

enter image description here

from mxnet import nd 
from mxnet.gluon import nn 
import mxnet as mx 

class HybridNet(nn.HybridBlock): 
    def __init__(self, **kwargs): 
     super(HybridNet, self).__init__(**kwargs) 
     with self.name_scope(): 
      self.dense0 = nn.Dense(3) 
      self.dense1 = nn.Dense(3) 
      self.dense2 = nn.Dense(6) 

    def hybrid_forward(self, F,x,y): 
     result1 = F.relu(self.dense0(x))+F.relu(self.dense1(y)) 
     result2 = F.relu(self.dense2(result1)) 
     return [result1,result2] 

net = HybridNet() 
net.initialize() 
net.hybridize() 
x = nd.random.normal(shape=(4,3)) 
y = nd.random.normal(shape=(4,5)) 
res=net(x,y) 
print "output1:",res[0] 
print "output2:",res[1] 
net.export('model') 

Nous pouvons réimporter le modèle pour vérifier l'exportation est la météo correct.You peut voir les deux résultat est le même.

from collections import namedtuple 
sym = mx.symbol.load('model-symbol.json') 
mod=mx.mod.Module(symbol=sym,data_names=['data0','data1']) 
mod.bind(data_shapes=[('data0',(1,3)),('data1',(1,5))]) 
mod.load_params('model-0000.params') 
Batch=namedtuple('Batch',['data']) 
mod.forward(Batch(data=[x,y])) 
print mod.get_outputs() 

Voir ce que les résultats ressemblent

sym.list_outputs() 

[ 'hybridnet0__plus0_output', 'hybridnet0_relu2_output']

Voici la première partie du code C++ et lever erreur. Je m'assure que le num_input_nodes et num_output_nodes tous les deux sont deux. Et en utilisant MXPredCreatePartialOut pour personnaliser ma sortie multi-tâches.

enter image description here

#include <mxnet/c_predict_api.h> 

#include <iostream> 
#include <fstream> 
#include <string> 
#include <vector> 
#include <assert.h> 

// Read file to buffer 
class BufferFile { 
public: 
    std::string file_path_; 
    int length_; 
    char* buffer_; 

    explicit BufferFile(std::string file_path) 
     :file_path_(file_path) { 

     std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary); 
     if (!ifs) { 
      std::cerr << "Can't open the file. Please check " << file_path << ". \n"; 
      length_ = 0; 
      buffer_ = NULL; 
      return; 
     } 

     ifs.seekg(0, std::ios::end); 
     length_ = ifs.tellg(); 
     ifs.seekg(0, std::ios::beg); 
     std::cout << file_path.c_str() << " ... " << length_ << " bytes\n"; 

     buffer_ = new char[sizeof(char) * length_]; 
     ifs.read(buffer_, length_); 
     ifs.close(); 
    } 

    int GetLength() { 
     return length_; 
    } 
    char* GetBuffer() { 
     return buffer_; 
    } 

    ~BufferFile() { 
     if (buffer_) { 
      delete[] buffer_; 
      buffer_ = NULL; 
     } 
    } 
}; 

int main(int argc, char* argv[]) { 

    // Models path for your model, you have to modify it 
    std::string json_file = "./model-symbol.json"; 
    std::string param_file = "./model-0000.params"; 

    BufferFile json_data(json_file); 
    BufferFile param_data(param_file); 

    // Parameters 
    int dev_type = 1; // 1: cpu, 2: gpu 
    int dev_id = 1; // arbitrary. 
    mx_uint num_input_nodes = 2; 
    mx_uint num_output_nodes = 2; 

    const char* input_key[2] = { "data0" , "data1" }; 
    const char** input_keys = input_key; 
    const char* output_key[2] = { "hybridnet0__plus0" , "hybridnet0_relu2" }; 
    const char** output_keys = output_key; 

    // input-dims 
    int data0_len = 3; 
    int data1_len = 5; 
    const mx_uint input_shape_indptr[4] = { 0,2,2,4 }; 
    const mx_uint input_shape_data[4] = {1,static_cast<mx_uint>(data0_len),1,static_cast<mx_uint>(data1_len) }; 
    PredictorHandle pred_hnd = 0; 

    if (json_data.GetLength() == 0 || param_data.GetLength() == 0) 
     return -1; 

    // Create Predictor 
    assert(0 == MXPredCreatePartialOut(
     (const char*)json_data.GetBuffer(), 
     (const char*)param_data.GetBuffer(), 
     static_cast<size_t>(param_data.GetLength()), 
     dev_type, 
     dev_id, 
     num_input_nodes, 
     input_keys, 
     input_shape_indptr, 
     input_shape_data, 
     num_output_nodes, 
     output_keys, 
     &pred_hnd)); //ERROR HERE 
    assert(pred_hnd); 

    return 0; 
} 

Répondre

1

On dirait que cette ligne ne va pas. const mx_uint input_shape_indptr[4] = { 0,2,2,4 };

Modifier à const mx_uint input_shape_indptr[3] = { 0,2,4 };

+0

Oui, Merci! Est pas error.But .Il Pourquoi suppression '2' dans le' input_shape_indptr'? – partida

+0

Jetez un oeil à ce [code] (https://github.com/apache/incubator-mxnet/blob/master/src/c_api/c_predict_api.cc#L175). 'indptr' contient 0 comme premier élément et accumule des longueurs de données de forme. – reminisce

+0

'input_shape_indptr' indique où les données commencent et se terminent, donc la longueur est pair.Est-ce que c'est juste? – partida