Je le code suivant dans theano
afin de calculer L2
la distanceConversion distance euclidienne Théano au format moteur keras
def distance(square=False):
X = T.fmatrix('X')
Y = T.fmatrix('Y')
squared_euclidean_distances = (X ** 2).sum(1).reshape((X.shape[0], 1)) + (Y ** 2).sum(1).reshape \
((1, Y.shape[0])) - 2 * X.dot(Y.T)
if square:
return theano.function([X, Y], T.sqrt(squared_euclidean_distances))
else:
return theano.function([X, Y], squared_euclidean_distances)
print(distance()([[1, 0], [1, 1]], [[1, 0]]))
résultat avec: [[0] [ 1.]]
qui est la matrice de distance entre le jeu de gauche (tw o vecteurs - [1, 0], [1, 1]) et l'ensemble de droite qui contient le vecteur unique [1,0].
Cela fonctionne bien avec theano même si X et Y ont dim différent comme ci-dessus. Je voudrais obtenir une fonction générale keras
pour produire le même résultat. J'ai essayé:
def distance_matrix(vects):
x, y = vects
# <x,x> + <y,y> - 2<x,y>
x_shape = K.int_shape(x)
y_shape = K.int_shape(y)
return K.reshape(K.sum(K.square(x), axis=1), (x_shape[0], 1)) + \
K.reshape(K.sum(K.square(y), axis=1), (1, y_shape[0])) - \
2 * K.dot(x, y)
mais le code suivant ne produit pas le bon résultat:
x = K.variable(np.array([[1, 0], [1, 1]]))
y = K.variable(np.array([[1, 0]]))
obj = distance_matrix
objective_output = obj((x, y))
print (K.eval(objective_output))
résultat avec
ValueError: Shape mismatch: x has 2 cols (and 4 rows) but y has 4 rows (and 2 cols)
Apply node that caused the error: Dot22Scalar(/variable, /variable, TensorConstant{2.0})
Toposort index: 0
Inputs types: [TensorType(float32, matrix), TensorType(float32, matrix), TensorType(float32, scalar)]
Inputs shapes: [(4, 2), (4, 2),()]
Inputs strides: [(8, 4), (8, 4),()]
Inputs values: ['not shown', 'not shown', array(2.0, dtype=float32)]
Outputs clients: [[Elemwise{Composite{((i0 + i1) - i2)}}[(0, 2)](InplaceDimShuffle{0,x}.0, InplaceDimShuffle{x,0}.0, Dot22Scalar.0)]]
Modifier: sorties ajoutées à coder
Pourriez-vous donner plus de détails sur votre problème? Par exemple. qu'est-ce qui ne fonctionne pas correctement? –
@ MarcinMożejko J'ai ajouté l'exemple de sortie de deux cas d'utilisation ci-dessus – oak
@ MarcinMożejko, merci j'ai trouvé l'erreur, j'ai oublié de transposer Y – oak