2017-03-27 2 views
1

je la fonction suivante:fonction Keras lambda produit scalaire mistmatch

def transpose_dot(vects): 
    x, y = vects 
    # <x,x> + <y,y> - 2<x,y> 

    return K.dot(x, K.transpose(y)) 

Quand essayer d'évaluer avec keras cela fonctionne

x = K.variable(np.array(np_x)) 
y = K.variable(np.array(np_x)) 
obj = transpose_dot 
objective_output = obj((x, y)) 
print('-----------------') 
print (K.eval(objective_output)) 

résultat avec:

[[ 1. 1. 1. 2.] 
[ 1. 2. 2. 4.] 
[ 1. 2. 2. 4.] 
[ 2. 4. 4. 8.] 

Mais, lorsque vous essayez de l'utiliser en tant que fonction pour Lambda couche cela ne fonctionne pas.

np_x = [[1, 0], [1, 1], [1, 1], [2, 2]] 
features = np.array([np_x]) 
test_input = Input(shape=np.array(np_x).shape) 
dot_layer= Lambda(transpose_dot, output_shape=(4,4))([test_input, test_input]) 
x = Model(inputs=test_input, outputs=dot_layer) 
x.predict(features, batch_size=1) 

Résultat avec

self.fn() if output_subset is None else\ 
ValueError: Shape mismatch: x has 2 cols (and 4 rows) but y has 4 rows (and 2 cols) 
Apply node that caused the error: Dot22(Reshape{2}.0, Reshape{2}.0) 
Toposort index: 11 
Inputs types: [TensorType(float32, matrix), TensorType(float32, matrix)] 
Inputs shapes: [(4, 2), (4, 2)] 
Inputs strides: [(8, 4), (8, 4)] 
Inputs values: ['not shown', 'not shown'] 
Outputs clients: [[Reshape{4}(Dot22.0, MakeVector{dtype='int64'}.0)]] 

Toute idée de ce que je suis absent ici?

Edit: sortie ajoutée du message d'erreur

+0

Quel est le message d'erreur? Vous avez un) manquant dans la ligne Lambda ... –

+0

@NassimBen, j'ai ajouté le message d'erreur, Fondamentalement, il se plaint de la forme, mais 'x a 2 cols (et 4 lignes) mais y a 4 lignes (et 2 cols) ' – oak

Répondre

0

Avec l'aide de gars à https://github.com/fchollet/keras/ j'ai trouvé mon erreur. La fonction s'attend à obtenir (n, m). Mais lors de l'utilisation d'une fonction Lambda attendez-vous à obtenir (samples, n, m).