2017-10-19 2 views
0

J'ai modifié le code du tutoriel approfondi & pour lire les grandes entrées du fichier en utilisant tf.contrib.learn.read_batch_examples. Pour accélérer le processus de formation, j'ai mis le read_batch_size et j'ai obtenu une erreur ValueError: Toutes les formes doivent être entièrement définies: [TensorShape ([]), TensorShape ([Dimension (None)])] Mon code:Erreur lors de la définition de read_batch_size dans tf.contrib.learn.read_batch_examples. la valeur par défaut est correcte

def input_fn_pre(batch_size, filename): 
    examples_op = tf.contrib.learn.read_batch_examples(
    filename, 
    batch_size=5000, 
    reader=tf.TextLineReader, 
    num_epochs=5, 
    num_threads=5, 
    read_batch_size=2500, 
    parse_fn=lambda x: tf.decode_csv(x, [tf.constant(['0'], dtype=tf.string)] * len(COLUMNS) * 2500, use_quote_delim=False))         
    examples_dict = {} 

    for i, col in enumerate(COLUMNS): 
    examples_dict[col] = examples_op[:, i] 
    feature_cols = {k: tf.string_to_number(examples_dict[k], out_type=tf.float32) for k in CONTINUOUS_COLUMNS} 
    feature_cols.update({k: dense_to_sparse(examples_dict[k]) for k in CATEGORICAL_COLUMNS}) 
    label = tf.string_to_number(examples_dict[LABEL_COLUMN], out_type=tf.int32) 
    return feature_cols, label 

tout en utilisant le réglage des paramètres par défaut est ok:

def input_fn_pre(batch_size, filename): 
    examples_op = tf.contrib.learn.read_batch_examples(
    filename, 
    batch_size=5000, 
    reader=tf.TextLineReader, 
    num_epochs=5, 
    num_threads=5, 
    parse_fn=lambda x: tf.decode_csv(x, [tf.constant(['0'], dtype=tf.string)] * len(COLUMNS), use_quote_delim=False))         
    examples_dict = {} 

    for i, col in enumerate(COLUMNS): 
    examples_dict[col] = examples_op[:, i] 
    feature_cols = {k: tf.string_to_number(examples_dict[k], out_type=tf.float32) for k in CONTINUOUS_COLUMNS} 
    feature_cols.update({k: dense_to_sparse(examples_dict[k]) for k in CATEGORICAL_COLUMNS}) 
    label = tf.string_to_number(examples_dict[LABEL_COLUMN], out_type=tf.int32) 
    return feature_cols, label 

Il n'y a pas suffisamment d'explications dans le tensorflow doc.

Répondre

0

Je n'ai vu aucune différence entre vos deux extraits de code. Pourriez-vous mettre à jour votre code?

+0

Oh désolé, j'ai mis à jour mon code. –

+0

L'erreur n'est pas évidente en lisant simplement le code. Pourriez-vous poster votre code minimal complet (exécutable)? – Mingxing