2017-06-21 1 views
4

Je résous un problème de classification de texte. J'ai défini mon classificateur en utilisant la classe Estimator avec le mien model_fn. Je souhaite utiliser l'intégration word2vec pré-formée de Google en tant que valeurs initiales, puis l'optimiser davantage pour la tâche en cours.Chargement de word2vec pré-formé pour initialiser embedding_lookup dans l'estimateur model_fn

J'ai vu ce poste: Using a pre-trained word embedding (word2vec or Glove) in TensorFlow
qui explique comment s'y prendre dans le code TensorFlow 'brut'. Cependant, j'aimerais vraiment utiliser la classe Estimator. En guise d'extension, je voudrais ensuite former ce code sur Cloud ML Engine, y at-il un bon moyen de passer dans le fichier assez volumineux avec les valeurs initiales?

Disons que nous avons quelque chose comme:

def build_model_fn(): 
    def _model_fn(features, labels, mode, params): 
     input_layer = features['feat'] #shape=[-1, params["sequence_length"]] 
     #... what goes here to initialize W 

     embedded = tf.nn.embedding_lookup(W, input_layer) 
     ... 
     return predictions 

estimator = tf.contrib.learn.Estimator(
    model_fn=build_model_fn(), 
    model_dir=MODEL_DIR, 
    params=params) 
estimator.fit(input_fn=read_data, max_steps=2500) 

Répondre

7

Plongements sont généralement assez grandes que la seule approche viable est de les utiliser pour initialiser un tf.Variable dans votre graphique. Cela vous permettra de tirer profit des serveurs param dans les distributions, etc.

Pour cela (et toute autre chose), je vous recommande d'utiliser le nouvel estimateur "de base", tf.estimator.Estimator, car cela facilitera beaucoup les choses.

De la réponse dans le lien que vous avez fourni, et sachant que nous voulons une variable pas une constante, nous pouvons soit prendre approche:

(2) Initialiser la variable à l'aide d'un dict d'alimentation ou (3) charger la variable d'un poste de contrôle


Je vais couvrir l'option (3) Tout d'abord, car il est beaucoup plus facile et mieux:

Dans votre model_fn, initialiser simplement une variable en utilisant le Tensor retourné par un appel tf.contrib.framework.load_variable. Cela nécessite:

  1. que vous avez un poste de contrôle de TF valide avec vos incorporations
  2. Vous connaissez le nom complet de la variable Plongement dans le poste de contrôle.

Le code est assez simple:

def model_fn(mode, features, labels, hparams): 
    embeddings = tf.Variable(tf.contrib.framework.load_variable(
     'gs://my-bucket/word2vec_checkpoints/', 
     'a/fully/qualified/scope/embeddings' 
)) 
    .... 
    return tf.estimator.EstimatorSpec(...) 

Cependant, cette approche ne fonctionnera pas pour vous si vos incorporations ne sont pas produits par un autre modèle de TF, d'où l'option (2).


Pour (2), nous devons utiliser tf.train.Scaffold qui est essentiellement un objet de configuration qui contient toutes les options de démarrage d'un tf.Session (qui estimateur cache intentionnellement pour de nombreuses raisons).

Vous pouvez spécifier un Scaffold dans le tf.train.EstimatorSpec vous revenez dans votre model_fn.

Nous créons un espace réservé dans notre model_fn, et en faisons l'opération d'initialisation pour notre variable d'incorporation, puis passons un init_feed_dict via le Scaffold. par exemple.

def model_fn(mode, features, labels, hparams): 
    embed_ph = tf.placeholder(
     shape=[hparams.vocab_size, hparams.embedding_size], 
     dtype=tf.float32) 
    embeddings = tf.Variable(embed_ph) 
    # Define your model 
    return tf.estimator.EstimatorSpec(
     ..., # normal EstimatorSpec args 
     scaffold=tf.train.Scaffold(init_feed_dict={embed_ph: my_embedding_numpy_array}) 
) 

Ce qui se passe ici est le init_feed_dict renseignera les valeurs de l'espace réservé embed_ph lors de l'exécution, ce qui permettra alors embeddings.initialization_op (affectation de l'espace réservé), pour exécuter.


+0

Merci, juste une petite chose: il devrait être 'tf.estimator.EstimatorSpec (..., échafaudage = tf.train.Scaffold (init_feed_dict =. {Embed_ph: my_embedding_numpy_array})' – Tristan

+0

Merci Tristan Spaced que syntaxe même si j'avais l'explication lol. –