Shyrma mkl matmul (#250)
* - provide matmul code based on mkl api Signed-off-by: Yurii <iuriish@yahoo.com> * - correct typo in mkl matmul op Signed-off-by: Yurii <iuriish@yahoo.com> * - take into account empty arrays in mkl matmul op Signed-off-by: Yurii <iuriish@yahoo.com> * - fix bug in mkl matmul and group all matmul tests in one file Signed-off-by: Yurii <iuriish@yahoo.com>master
parent
c8882cbfa5
commit
22c7aa9acf
|
@ -5,7 +5,7 @@ project(mkldnn-download NONE)
|
|||
include(ExternalProject)
|
||||
ExternalProject_Add(mkldnn
|
||||
GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
|
||||
GIT_TAG v1.1.3
|
||||
GIT_TAG v1.2
|
||||
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
|
||||
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
|
||||
CONFIGURE_COMMAND ""
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
// @author Yurii Shyrma (iuriish@yahoo.com), fully rewritten
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
#include <op_boilerplate.h>
|
||||
#if NOT_EXCLUDED(OP_matmul)
|
||||
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
|
@ -29,142 +29,128 @@
|
|||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
||||
CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) {
|
||||
|
||||
const int iSize = (int) block.getIArguments()->size();
|
||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
||||
const int xRank = x->rankOf();
|
||||
const int yRank = y->rankOf();
|
||||
const int zRank = z->rankOf();
|
||||
const int iSize = (int) block.getIArguments()->size();
|
||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||
|
||||
if (transZ) {
|
||||
x = INPUT_VARIABLE(1);
|
||||
y = INPUT_VARIABLE(0);
|
||||
bool temp = transX;
|
||||
transX = !transY;
|
||||
transY = !temp;
|
||||
}
|
||||
const int xRank = x->rankOf();
|
||||
const int yRank = y->rankOf();
|
||||
const int zRank = z->rankOf();
|
||||
|
||||
const int xLastDim = transX ? -2 : -1;
|
||||
const int yLastDim = transY ? -2 : -1;
|
||||
const int xLastButOneDim = transX ? -1 : -2;
|
||||
const int yLastButOneDim = transY ? -1 : -2;
|
||||
if (transZ) {
|
||||
x = INPUT_VARIABLE(1);
|
||||
y = INPUT_VARIABLE(0);
|
||||
bool temp = transX;
|
||||
transX = !transY;
|
||||
transY = !temp;
|
||||
}
|
||||
|
||||
// ******* input validation ******* //
|
||||
REQUIRE_TRUE(xRank > 0 && yRank > 0, 0,
|
||||
"MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !",
|
||||
xRank, yRank);
|
||||
const int xLastDim = transX ? -2 : -1;
|
||||
const int yLastDim = transY ? -2 : -1;
|
||||
const int xLastButOneDim = transX ? -1 : -2;
|
||||
const int yLastButOneDim = transY ? -1 : -2;
|
||||
|
||||
if (xRank == 1 && yRank == 1) { // dot case, output is scalar (or vector with length = 1)
|
||||
REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0,
|
||||
"MATMUL OP: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !",
|
||||
x->lengthOf(), y->lengthOf());
|
||||
} else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector
|
||||
REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0,
|
||||
"MATMUL OP: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !",
|
||||
ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
|
||||
} else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector
|
||||
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0,
|
||||
"MATMUL OP: input arrays have inconsistent shapes for matrix-vector product: x %s, y %s !",
|
||||
ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
|
||||
} else {
|
||||
REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0,
|
||||
"MATMUL OP: input and output arrays must have the same rank, but got instead: x rank = %i, y rank = %i, z rank = %i !",
|
||||
xRank, yRank, zRank);
|
||||
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) &&
|
||||
x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && y->sizeAt(yLastDim) == z->sizeAt(-1), 0,
|
||||
"MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !",
|
||||
ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(),
|
||||
ShapeUtils::shapeAsString(z).c_str());
|
||||
// ******* input validation ******* //
|
||||
REQUIRE_TRUE(xRank > 0 && yRank > 0, 0, "MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !", xRank, yRank);
|
||||
|
||||
if (xRank > 2) // outer dims must be the same
|
||||
for (int i = 0; i < xRank - 2; ++i)
|
||||
REQUIRE_TRUE(x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0,
|
||||
"MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !",
|
||||
ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(),
|
||||
ShapeUtils::shapeAsString(z).c_str());
|
||||
}
|
||||
// ******* end of input validation ******* //
|
||||
if (xRank == 1 && yRank == 1) { // dot case, output is scalar (or vector with length = 1)
|
||||
REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, "MATMUL OP: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !", x->lengthOf(), y->lengthOf());
|
||||
} else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector
|
||||
REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0, "MATMUL OP: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
|
||||
} else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector
|
||||
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0, "MATMUL OP: input arrays have inconsistent shapes for matrix-vector product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
|
||||
} else {
|
||||
REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0, "MATMUL OP: input and output arrays must have the same rank, but got instead: x rank = %i, y rank = %i, z rank = %i !", xRank, yRank, zRank);
|
||||
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) && x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && y->sizeAt(yLastDim) == z->sizeAt(-1), 0, "MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str());
|
||||
|
||||
MmulHelper::matmul(x, y, z, transX, transY);
|
||||
if (xRank > 2) // outer dims must be the same
|
||||
for (int i = 0; i < xRank - 2; ++i)
|
||||
REQUIRE_TRUE(x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0, "MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str());
|
||||
}
|
||||
// ******* end of input validation ******* //
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
MmulHelper::matmul(x, y, z, transX, transY);
|
||||
|
||||
DECLARE_SYN(mMul, matmul);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_SYN(mmul, matmul);
|
||||
DECLARE_SYN(mMul, matmul);
|
||||
|
||||
DECLARE_SYN(gemm, matmul);
|
||||
DECLARE_SYN(mmul, matmul);
|
||||
|
||||
DECLARE_SYN(gemv, matmul);
|
||||
DECLARE_SYN(gemm, matmul);
|
||||
|
||||
DECLARE_SYN(dot, matmul);
|
||||
DECLARE_SYN(gemv, matmul);
|
||||
|
||||
DECLARE_SYN(dot, matmul);
|
||||
|
||||
DECLARE_SHAPE_FN(matmul) {
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
DECLARE_SHAPE_FN(matmul) {
|
||||
|
||||
auto xShapeInfo = inputShape->at(0);
|
||||
auto yShapeInfo = inputShape->at(1);
|
||||
auto xShapeInfo = inputShape->at(0);
|
||||
auto yShapeInfo = inputShape->at(1);
|
||||
|
||||
const int iSize = (int) block.getIArguments()->size();
|
||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||
const int iSize = (int) block.getIArguments()->size();
|
||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||
|
||||
REQUIRE_TRUE(xShapeInfo[0] > 0 && yShapeInfo[0] > 0, 0,
|
||||
"MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !",
|
||||
xShapeInfo[0], yShapeInfo[0]);
|
||||
REQUIRE_TRUE(xShapeInfo[0] > 0 && yShapeInfo[0] > 0, 0,
|
||||
"MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !",
|
||||
xShapeInfo[0], yShapeInfo[0]);
|
||||
|
||||
if (transZ) {
|
||||
xShapeInfo = inputShape->at(1);
|
||||
yShapeInfo = inputShape->at(0);
|
||||
bool temp = transX;
|
||||
transX = !transY;
|
||||
transY = !temp;
|
||||
}
|
||||
if (transZ) {
|
||||
xShapeInfo = inputShape->at(1);
|
||||
yShapeInfo = inputShape->at(0);
|
||||
bool temp = transX;
|
||||
transX = !transY;
|
||||
transY = !temp;
|
||||
}
|
||||
|
||||
auto zShapeOnly = ShapeUtils::evalShapeForMatmul(xShapeInfo, yShapeInfo, transX, transY);
|
||||
auto zShapeOnly = ShapeUtils::evalShapeForMatmul(xShapeInfo, yShapeInfo, transX, transY);
|
||||
|
||||
auto dtypeX = ArrayOptions::dataType(xShapeInfo);
|
||||
auto dtypeY = ArrayOptions::dataType(yShapeInfo);
|
||||
auto dtypeX = ArrayOptions::dataType(xShapeInfo);
|
||||
auto dtypeY = ArrayOptions::dataType(yShapeInfo);
|
||||
|
||||
auto xOrder = shape::order(xShapeInfo);
|
||||
auto yOrder = shape::order(yShapeInfo);
|
||||
auto zOrder = xOrder == 'c' && yOrder == 'c' ? 'c' : 'f';
|
||||
auto xOrder = shape::order(xShapeInfo);
|
||||
auto yOrder = shape::order(yShapeInfo);
|
||||
auto zOrder = xOrder == 'c' && yOrder == 'c' ? 'c' : 'f';
|
||||
|
||||
// we just pick the higher data type out of X and Y
|
||||
auto dtypeZ = dtypeX > dtypeY ? dtypeX : dtypeY;
|
||||
// we just pick the higher data type out of X and Y
|
||||
auto dtypeZ = dtypeX > dtypeY ? dtypeX : dtypeY;
|
||||
|
||||
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtypeZ, zOrder, zShapeOnly);
|
||||
return SHAPELIST(newShape);
|
||||
}
|
||||
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtypeZ, zOrder, zShapeOnly);
|
||||
return SHAPELIST(newShape);
|
||||
}
|
||||
|
||||
DECLARE_TYPES(matmul) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, {ALL_FLOATS})
|
||||
->setAllowedInputTypes(1, {ALL_FLOATS})
|
||||
->setAllowedOutputTypes(0, {ALL_FLOATS});
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
DECLARE_TYPES(matmul) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, {ALL_FLOATS})
|
||||
->setAllowedInputTypes(1, {ALL_FLOATS})
|
||||
->setAllowedOutputTypes(0, {ALL_FLOATS});
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
auto eps = INPUT_VARIABLE(2);
|
||||
auto dldx = OUTPUT_VARIABLE(0);
|
||||
auto dldy = OUTPUT_VARIABLE(1);
|
||||
|
||||
CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
auto eps = INPUT_VARIABLE(2);
|
||||
auto dldx = OUTPUT_VARIABLE(0);
|
||||
auto dldy = OUTPUT_VARIABLE(1);
|
||||
|
||||
const int iSize = (int) block.getIArguments()->size();
|
||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||
const int iSize = (int) block.getIArguments()->size();
|
||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||
|
||||
/*
|
||||
In: x=[a,b], y=[b,c]
|
||||
|
@ -177,34 +163,35 @@ F F T [a,b] [b,c] [c,a] [c,a]
|
|||
*/
|
||||
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
op.execute({eps, y}, {dldx}, {}, {transZ, !transY, transX}, {});
|
||||
op.execute({x, eps}, {dldy}, {}, {!transX, transZ, transY}, {});
|
||||
nd4j::ops::matmul op;
|
||||
op.execute({eps, y}, {dldx}, {}, {transZ, !transY, transX}, {});
|
||||
op.execute({x, eps}, {dldy}, {}, {!transX, transZ, transY}, {});
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
DECLARE_SHAPE_FN(matmul_bp) {
|
||||
Nd4jLong *xShapeInfo;
|
||||
Nd4jLong *yShapeInfo;
|
||||
|
||||
DECLARE_SHAPE_FN(matmul_bp) {
|
||||
Nd4jLong *xShapeInfo;
|
||||
Nd4jLong *yShapeInfo;
|
||||
COPY_SHAPE(inputShape->at(0), xShapeInfo);
|
||||
COPY_SHAPE(inputShape->at(1), yShapeInfo);
|
||||
|
||||
COPY_SHAPE(inputShape->at(0), xShapeInfo);
|
||||
COPY_SHAPE(inputShape->at(1), yShapeInfo);
|
||||
return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(yShapeInfo));
|
||||
}
|
||||
|
||||
return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(yShapeInfo));
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
DECLARE_TYPES(matmul_bp) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, {ALL_FLOATS})
|
||||
->setAllowedInputTypes(1, {ALL_FLOATS})
|
||||
->setAllowedInputTypes(2, {ALL_FLOATS})
|
||||
->setAllowedOutputTypes(0, {ALL_FLOATS})
|
||||
->setAllowedOutputTypes(1, {ALL_FLOATS});
|
||||
}
|
||||
|
||||
DECLARE_TYPES(matmul_bp) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, {ALL_FLOATS})
|
||||
->setAllowedInputTypes(1, {ALL_FLOATS})
|
||||
->setAllowedInputTypes(2, {ALL_FLOATS})
|
||||
->setAllowedOutputTypes(0, {ALL_FLOATS})
|
||||
->setAllowedOutputTypes(1, {ALL_FLOATS});
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,294 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
#include <ops/declarable/PlatformHelper.h>
|
||||
#include <ops/declarable/OpRegistrator.h>
|
||||
#include <platform_boilerplate.h>
|
||||
|
||||
#include <helpers/MKLDNNStream.h>
|
||||
#include "mkldnnUtils.h"
|
||||
#include <numeric>
|
||||
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace platforms {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY) {
|
||||
|
||||
// mkl works with following
|
||||
// [M,K] x [K,N] = [M,N]
|
||||
// [bS, M,K] x [bS, K,N] = [bS, M,N]
|
||||
|
||||
// possible input cases not supported by mkl, however we'll perform permut/reshape procedures in order to fit requirements
|
||||
// [4] x [4] = [1] --> [1,4] x [4,1] = [1,1]
|
||||
// [4] x [4,5] = [5] --> [1,4] x [4,5] = [1,5]
|
||||
// [4,5] x [5] = [4] --> [4,5] x [5,1] = [4,1]
|
||||
// [2,3, 4,5] x [2,3, 5,4] = [2,3, 4,4] --> [6, 4,5] x [6, 5,4] = [6, 4,4]
|
||||
// [2,2,3, 4,5] x [2,2,3, 5,4] = [2,2,3, 4,4] --> [12, 4,5] x [12, 5,4] = [12, 4,4]
|
||||
|
||||
const auto xRank = x->rankOf();
|
||||
const auto yRank = y->rankOf();
|
||||
const auto zRank = z->rankOf();
|
||||
|
||||
std::vector<int> permut;
|
||||
|
||||
// fill permutation vector appropriately if transposition is required
|
||||
if((transX && xRank > 1) || (transY && yRank > 1)) {
|
||||
|
||||
const int rank = xRank >= yRank ? xRank : yRank;
|
||||
permut.resize(rank);
|
||||
std::iota(std::begin(permut), std::end(permut), 0);
|
||||
permut[rank-2] = rank - 1;
|
||||
permut[rank-1] = rank - 2;
|
||||
}
|
||||
|
||||
const NDArray* xT = (transX && xRank > 1) ? new NDArray(x->permute(permut)) : x;
|
||||
const NDArray* yT = (transY && yRank > 1) ? new NDArray(y->permute(permut)) : y;
|
||||
|
||||
const NDArray* xTR = xRank <= 3 ? xT : new NDArray(xT->reshape(xT->ordering(), {xT->lengthOf() / (xT->sizeAt(-2) * xT->sizeAt(-1)), xT->sizeAt(-2), xT->sizeAt(-1)}));
|
||||
const NDArray* yTR = xRank <= 3 ? yT : new NDArray(yT->reshape(yT->ordering(), {yT->lengthOf() / (yT->sizeAt(-2) * yT->sizeAt(-1)), yT->sizeAt(-2), yT->sizeAt(-1)}));
|
||||
NDArray* zR = xRank <= 3 ? z : new NDArray(z->reshape(z->ordering(), {z->lengthOf() / (z->sizeAt(-2) * z->sizeAt(-1)), z->sizeAt(-2), z->sizeAt(-1)})/*, false*/);
|
||||
|
||||
// [M,K] x [K,N] = [M,N]
|
||||
const int M = (xRank > 1) ? xTR->sizeAt(-2) : 1;
|
||||
const int K = (xRank > 1) ? xTR->sizeAt(-1) : xTR->lengthOf();
|
||||
const int N = (yRank > 1) ? yTR->sizeAt(-1) : 1;
|
||||
const int bS = (xRank > 2) ? xTR->sizeAt(0) : 1; // [bS, M,K] x [bS, K,N] = [bS, M,N]
|
||||
|
||||
dnnl::memory::dims xShape = xRank < 3 ? dnnl::memory::dims({M, K}) : dnnl::memory::dims({bS, M, K});
|
||||
dnnl::memory::dims yShape = xRank < 3 ? dnnl::memory::dims({K, N}) : dnnl::memory::dims({bS, K, N});
|
||||
dnnl::memory::dims zShape = xRank < 3 ? dnnl::memory::dims({M, N}) : dnnl::memory::dims({bS, M, N});
|
||||
|
||||
dnnl::memory::format_tag format = xRank < 3 ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::abc;
|
||||
|
||||
// x type
|
||||
dnnl::memory::data_type xType;
|
||||
if(x->dataType() == DataType::FLOAT32)
|
||||
xType = dnnl::memory::data_type::f32;
|
||||
else if(x->dataType() == DataType::HALF)
|
||||
xType = dnnl::memory::data_type::f16;
|
||||
else if(x->dataType() == DataType::BFLOAT16)
|
||||
xType = dnnl::memory::data_type::bf16;
|
||||
else if(x->dataType() == DataType::UINT8)
|
||||
xType = dnnl::memory::data_type::u8;
|
||||
else
|
||||
xType = dnnl::memory::data_type::s8;
|
||||
|
||||
// y type
|
||||
dnnl::memory::data_type yType = xType;
|
||||
if(y->dataType() == DataType::UINT8)
|
||||
yType = dnnl::memory::data_type::u8;
|
||||
else if(y->dataType() == DataType::INT8)
|
||||
yType = dnnl::memory::data_type::s8;
|
||||
|
||||
// z type
|
||||
dnnl::memory::data_type zType = xType;
|
||||
if(z->dataType() == DataType::FLOAT32)
|
||||
zType = dnnl::memory::data_type::f32;
|
||||
else if(z->dataType() == DataType::INT32)
|
||||
zType = dnnl::memory::data_type::s32;
|
||||
else if(z->dataType() == DataType::UINT8)
|
||||
zType = dnnl::memory::data_type::u8;
|
||||
else if(z->dataType() == DataType::INT8)
|
||||
zType = dnnl::memory::data_type::s8;
|
||||
|
||||
// memory descriptors for arrays
|
||||
|
||||
// x
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format);
|
||||
if(xTR->ews() != 1 || xTR->ordering() != 'c') {
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = xRank == 1 ? 1 : xTR->strideAt(0);
|
||||
x_user_md.data.format_desc.blocking.strides[1] = xRank == 1 ? xTR->strideAt(0) : xTR->strideAt(1);
|
||||
if(xRank > 2)
|
||||
x_user_md.data.format_desc.blocking.strides[2] = xTR->strideAt(2);
|
||||
}
|
||||
|
||||
// y
|
||||
dnnl::memory::desc y_mkl_md = dnnl::memory::desc(yShape, yType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc y_user_md = dnnl::memory::desc(yShape, yType, format);
|
||||
if(yTR->ews() != 1 || yTR->ordering() != 'c') {
|
||||
y_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
y_user_md.data.format_desc.blocking.strides[0] = yRank == 1 ? 1 : yTR->strideAt(0);
|
||||
y_user_md.data.format_desc.blocking.strides[1] = yRank == 1 ? yTR->strideAt(0) : yTR->strideAt(1);
|
||||
if(yRank > 2)
|
||||
y_user_md.data.format_desc.blocking.strides[2] = yTR->strideAt(2);
|
||||
}
|
||||
|
||||
// z
|
||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, format);
|
||||
if(zR->ews() != 1 || zR->ordering() != 'c') {
|
||||
z_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
z_user_md.data.format_desc.blocking.strides[0] = zRank == 1 ? 1 : zR->strideAt(0);
|
||||
z_user_md.data.format_desc.blocking.strides[1] = zRank == 1 ? zR->strideAt(0) : zR->strideAt(1);
|
||||
if(zRank > 2)
|
||||
z_user_md.data.format_desc.blocking.strides[2] = zR->strideAt(2);
|
||||
}
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
// Create attributes (to handle alpha and beta if necessary)
|
||||
dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0)
|
||||
|
||||
// operation primitive description
|
||||
dnnl::matmul::desc op_desc(x_mkl_md, y_mkl_md, z_mkl_md);
|
||||
dnnl::matmul::primitive_desc op_prim_desc(op_desc, attr, engine);
|
||||
|
||||
// arguments (memory buffers) necessary for calculations
|
||||
std::unordered_map<int, dnnl::memory> args;
|
||||
|
||||
dnnl::stream stream(engine);
|
||||
|
||||
// provide memory buffers and check whether reorder is required
|
||||
|
||||
// input
|
||||
auto x_user_mem = dnnl::memory(x_user_md, engine, xTR->getBuffer());
|
||||
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
|
||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
|
||||
if (xReorder)
|
||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
||||
|
||||
// y
|
||||
auto y_user_mem = dnnl::memory(y_user_md, engine, yTR->getBuffer());
|
||||
const bool yReorder = op_prim_desc.weights_desc() != y_user_mem.get_desc();
|
||||
auto y_mkl_mem = yReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : y_user_mem;
|
||||
if (yReorder)
|
||||
dnnl::reorder(y_user_mem, y_mkl_mem).execute(stream, y_user_mem, y_mkl_mem);
|
||||
args[DNNL_ARG_WEIGHTS] = y_mkl_mem;
|
||||
|
||||
// z
|
||||
auto z_user_mem = dnnl::memory(z_user_md, engine, zR->getBuffer());
|
||||
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
|
||||
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
|
||||
args[DNNL_ARG_DST] = z_mkl_mem;
|
||||
|
||||
// run calculations
|
||||
dnnl::matmul(op_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder outputs if necessary
|
||||
if (zReorder)
|
||||
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
|
||||
|
||||
stream.wait();
|
||||
|
||||
if(zR->getBuffer() != z->getBuffer())
|
||||
z->assign(zR);
|
||||
|
||||
if(zR != z)
|
||||
delete zR;
|
||||
if(xTR != xT)
|
||||
delete xTR;
|
||||
if(xT != x)
|
||||
delete xT;
|
||||
if(yTR != yT)
|
||||
delete yTR;
|
||||
if(yT != y)
|
||||
delete yT;
|
||||
|
||||
// shape::printArray(z_mkl_mem.map_data<float>(),8);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(matmul, ENGINE_CPU) {
|
||||
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
||||
if(x->isEmpty() || y->isEmpty())
|
||||
return Status::OK();
|
||||
|
||||
const int iSize = (int) block.getIArguments()->size();
|
||||
int transX = iSize > 0 ? INT_ARG(0) : 0;
|
||||
int transY = iSize > 1 ? INT_ARG(1) : 0;
|
||||
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
|
||||
|
||||
const int xRank = x->rankOf();
|
||||
const int yRank = y->rankOf();
|
||||
const int zRank = z->rankOf();
|
||||
|
||||
if (transZ) {
|
||||
x = INPUT_VARIABLE(1);
|
||||
y = INPUT_VARIABLE(0);
|
||||
bool temp = transX;
|
||||
transX = !transY;
|
||||
transY = !temp;
|
||||
}
|
||||
|
||||
const int xLastDim = transX ? -2 : -1;
|
||||
const int yLastDim = transY ? -2 : -1;
|
||||
const int xLastButOneDim = transX ? -1 : -2;
|
||||
const int yLastButOneDim = transY ? -1 : -2;
|
||||
|
||||
// ******* input validation ******* //
|
||||
REQUIRE_TRUE(xRank > 0 && yRank > 0, 0, "MATMUL MKLDNN OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !", xRank, yRank);
|
||||
|
||||
if (xRank == 1 && yRank == 1) { // dot case, output is scalar (or vector with length = 1)
|
||||
REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0,"MATMUL MKLDNN OP: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !",x->lengthOf(), y->lengthOf());
|
||||
} else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector
|
||||
REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0, "MATMUL MKLDNN OP: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
|
||||
} else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector
|
||||
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0, "MATMUL MKLDNN OP: input arrays have inconsistent shapes for matrix-vector product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
|
||||
} else {
|
||||
REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0, "MATMUL MKLDNN OP: input and output arrays must have the same rank, but got instead: x rank = %i, y rank = %i, z rank = %i !", xRank, yRank, zRank);
|
||||
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) && x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && y->sizeAt(yLastDim) == z->sizeAt(-1), 0, "MATMUL MKLDNN OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str());
|
||||
|
||||
if (xRank > 2) // outer dims must be the same
|
||||
for (int i = 0; i < xRank - 2; ++i)
|
||||
REQUIRE_TRUE(x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0, "MATMUL MKLDNN OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str());
|
||||
}
|
||||
// ******* end of input validation ******* //
|
||||
|
||||
matmulMKLDNN(x, y, z, transX, transY);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_CHECK(matmul, ENGINE_CPU) {
|
||||
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
|
||||
auto z = INPUT_VARIABLE(0);
|
||||
|
||||
const DataType xType = x->dataType();
|
||||
const DataType yType = y->dataType();
|
||||
const DataType zType = z->dataType();
|
||||
|
||||
|
||||
return block.isUseMKLDNN() &&
|
||||
(
|
||||
(xType==DataType::FLOAT32 && yType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
|
||||
(xType==DataType::HALF && yType==DataType::HALF && zType==DataType::FLOAT32) ||
|
||||
(xType==DataType::BFLOAT16 && yType==DataType::BFLOAT16 && zType==DataType::BFLOAT16) ||
|
||||
((xType==DataType::UINT8 || xType==DataType::INT8) && (yType==DataType::UINT8 || yType==DataType::INT8) && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32))
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -84,6 +84,8 @@ namespace nd4j{
|
|||
DECLARE_PLATFORM(depthwise_conv2d, ENGINE_CPU);
|
||||
|
||||
DECLARE_PLATFORM(depthwise_conv2d_bp, ENGINE_CPU);
|
||||
|
||||
DECLARE_PLATFORM(matmul, ENGINE_CPU);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1341,40 +1341,6 @@ TEST_F(DeclarableOpsTests1, MultiplyScalarScalar1) {
|
|||
delete exp;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests1, TestMatMul1) {
|
||||
auto x = NDArrayFactory::create_<float>('c', {3, 5});
|
||||
x->linspace(1);
|
||||
|
||||
auto y = NDArrayFactory::create_<float>('c', {5, 3});
|
||||
y->linspace(1);
|
||||
|
||||
float _expB[]{135.0f, 310.0f, 485.0f, 150.0f, 350.0f, 550.0f, 165.0f, 390.0f, 615.0f};
|
||||
Nd4jLong _expS[] {2, 3, 3, 1, 3, 0, 1, 102}; // expected shape
|
||||
ArrayOptions::setDataType(_expS, nd4j::DataType::FLOAT32);
|
||||
NDArray exp(_expB, _expS);
|
||||
|
||||
auto variableSpace = new VariableSpace();
|
||||
variableSpace->putVariable(-1, x);
|
||||
variableSpace->putVariable(-2, y);
|
||||
variableSpace->putVariable(1, new Variable());
|
||||
|
||||
auto block = new Context(1, variableSpace, false);
|
||||
block->fillInputs({-1, -2});
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
|
||||
Nd4jStatus status = op.execute(block);
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(variableSpace->hasVariable(1));
|
||||
|
||||
auto result = variableSpace->getVariable(1)->getNDArray();
|
||||
|
||||
ASSERT_TRUE(result->equalsTo(&exp));
|
||||
|
||||
delete block;
|
||||
delete variableSpace;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, TestSoftMax_bp_1) {
|
||||
|
||||
|
|
|
@ -2800,16 +2800,9 @@ TEST_F(DeclarableOpsTests12, QR_Test_1_1) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, QR_Test_2) {
|
||||
|
||||
auto in = NDArrayFactory::create<double>('c', {5,3}, {
|
||||
12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3.
|
||||
});
|
||||
auto expQ = NDArrayFactory::create<double>('c', {5, 3}, {
|
||||
0.8464148, 0.3912908, -0.3431241, -0.42320737, -0.9040873, 0.02927014, 0.28213826, -0.17042054, -0.93285596, 0.07053456, -0.01404065, 0.00109937, -0.14106913, 0.0166551, 0.10577161
|
||||
});
|
||||
|
||||
auto expR = NDArrayFactory::create<double>('c', {3,3}, {
|
||||
-14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546
|
||||
});
|
||||
auto in = NDArrayFactory::create<double>('c', {5,3}, {12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3.});
|
||||
auto expQ = NDArrayFactory::create<double>('c', {5, 3}, {0.8464148,0.3912908,-0.3431241,-0.42320737, -0.9040873,0.02927014,0.28213826, -0.17042054, -0.93285596,0.07053456, -0.01404065,0.00109937,-0.14106913,0.0166551,0.10577161});
|
||||
auto expR = NDArrayFactory::create<double>('c', {3,3}, {-14.177447,-20.666622,13.401566,0.,-175.04254,70.080315,0.,0.,35.201546});
|
||||
|
||||
nd4j::ops::qr op;
|
||||
auto res = op.evaluate({&in}, {}, {}, {false});
|
||||
|
@ -2819,8 +2812,6 @@ TEST_F(DeclarableOpsTests12, QR_Test_2) {
|
|||
auto r = res->at(1);
|
||||
ASSERT_TRUE(q->isSameShape(expQ));
|
||||
ASSERT_TRUE(r->isSameShape(expR));
|
||||
// q->printIndexedBuffer("Orthogonal 5x5");
|
||||
// r->printIndexedBuffer("Upper triangular 5x3");
|
||||
|
||||
nd4j::ops::matmul opMul;
|
||||
auto res2 = opMul.evaluate({q, r}); //MmulHelper::matmul(q, r, &in, false, false);
|
||||
|
|
|
@ -682,3 +682,810 @@ TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest8) {
|
|||
x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y, z);
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test1) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {3, 4});
|
||||
auto y = NDArrayFactory::create<double>('c', {4, 3});
|
||||
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 79., 123., 40., 92., 144., 45., 105., 165.});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test2) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {3, 4});
|
||||
auto y = NDArrayFactory::create<double>('f', {4, 3});
|
||||
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test3) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {3, 4});
|
||||
auto y = NDArrayFactory::create<double>('c', {4, 3});
|
||||
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test4) {
|
||||
|
||||
auto x = NDArrayFactory::create<double> ('f', {3, 4});
|
||||
auto y = NDArrayFactory::create<double>('f', {4, 3});
|
||||
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test5) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {4, 3});
|
||||
auto y = NDArrayFactory::create<double>('c', {4, 3});
|
||||
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {83., 94., 105., 94., 107., 120., 105., 120., 135.});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test6) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {4, 3});
|
||||
auto y = NDArrayFactory::create<double>('f', {3, 4});
|
||||
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test7) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {5, 3,4});
|
||||
auto y = NDArrayFactory::create<double>('f', {5, 3,4});
|
||||
auto exp = NDArrayFactory::create<double>('f',{5, 3,3}, {3. , 84.6, 281.4, 593.4, 1020.6, 7. , 107.8, 323.8, 655. , 1101.4,11. , 131. , 366.2, 716.6, 1182.2,
|
||||
7. , 107.8, 323.8, 655. , 1101.4,17.4, 137.4, 372.6, 723. , 1188.6,27.8, 167. , 421.4, 791. , 1275.8,
|
||||
11. , 131. , 366.2, 716.6, 1182.2,27.8, 167. , 421.4, 791. , 1275.8,44.6, 203. , 476.6, 865.4, 1369.4,});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {0, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test8) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {2,5, 3,4});
|
||||
auto y = NDArrayFactory::create<double>('f', {2,5, 3,4});
|
||||
auto exp = NDArrayFactory::create<double>('f',{2,5, 3,3}, {3. , 1563. , 84.6, 2220.6, 281.4, 2993.4, 593.4, 3881.4,1020.6, 4884.6, 7. , 1663. , 107.8, 2339.8, 323.8, 3131.8, 655. , 4039. ,1101.4, 5061.4,
|
||||
11. , 1763. , 131. , 2459. , 366.2, 3270.2, 716.6, 4196.6,1182.2, 5238.2, 7. , 1663. , 107.8, 2339.8, 323.8, 3131.8, 655. , 4039. ,1101.4, 5061.4,
|
||||
17.4, 1769.4, 137.4, 2465.4, 372.6, 3276.6, 723. , 4203. ,1188.6, 5244.6, 27.8, 1875.8, 167. , 2591. , 421.4, 3421.4, 791. , 4367. ,1275.8, 5427.8,
|
||||
11. , 1763. , 131. , 2459. , 366.2, 3270.2, 716.6, 4196.6,1182.2, 5238.2, 27.8, 1875.8, 167. , 2591. , 421.4, 3421.4, 791. , 4367. ,1275.8, 5427.8,
|
||||
44.6, 1988.6, 203. , 2723. , 476.6, 3572.6, 865.4, 4537.4,1369.4, 5617.4});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {0, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test9) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {2,5, 4,3});
|
||||
auto y = NDArrayFactory::create<double>('f', {2,5, 3,4});
|
||||
auto exp = NDArrayFactory::create<double>('f',{2,5, 3,3}, {7. , 1639. , 103. , 2311. , 314.2, 3098.2, 640.6, 4000.6,1082.2, 5018.2, 8. , 1664. , 108.8, 2340.8, 324.8, 3132.8, 656. , 4040. ,1102.4, 5062.4,
|
||||
9. , 1689. , 114.6, 2370.6, 335.4, 3167.4, 671.4, 4079.4,1122.6, 5106.6, 15.8, 1743.8, 131. , 2435. , 361.4, 3241.4, 707. , 4163. ,1167.8, 5199.8,
|
||||
18.4, 1770.4, 138.4, 2466.4, 373.6, 3277.6, 724. , 4204. ,1189.6, 5245.6, 21. , 1797. , 145.8, 2497.8, 385.8, 3313.8, 741. , 4245. ,1211.4, 5291.4,
|
||||
24.6, 1848.6, 159. , 2559. , 408.6, 3384.6, 773.4, 4325.4,1253.4, 5381.4, 28.8, 1876.8, 168. , 2592. , 422.4, 3422.4, 792. , 4368. ,1276.8, 5428.8,
|
||||
33. , 1905. , 177. , 2625. , 436.2, 3460.2, 810.6, 4410.6,1300.2, 5476.2});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, matmul_test10) {
|
||||
|
||||
auto x = NDArrayFactory::create_<float>('c', {3, 5});
|
||||
x->linspace(1);
|
||||
|
||||
auto y = NDArrayFactory::create_<float>('c', {5, 3});
|
||||
y->linspace(1);
|
||||
|
||||
float _expB[]{135.0f, 310.0f, 485.0f, 150.0f, 350.0f, 550.0f, 165.0f, 390.0f, 615.0f};
|
||||
Nd4jLong _expS[] {2, 3, 3, 1, 3, 0, 1, 102}; // expected shape
|
||||
ArrayOptions::setDataType(_expS, nd4j::DataType::FLOAT32);
|
||||
NDArray exp(_expB, _expS);
|
||||
|
||||
auto variableSpace = new VariableSpace();
|
||||
variableSpace->putVariable(-1, x);
|
||||
variableSpace->putVariable(-2, y);
|
||||
variableSpace->putVariable(1, new Variable());
|
||||
|
||||
auto block = new Context(1, variableSpace, false);
|
||||
block->fillInputs({-1, -2});
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
|
||||
Nd4jStatus status = op.execute(block);
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(variableSpace->hasVariable(1));
|
||||
|
||||
auto result = variableSpace->getVariable(1)->getNDArray();
|
||||
|
||||
ASSERT_TRUE(result->equalsTo(&exp));
|
||||
|
||||
delete block;
|
||||
delete variableSpace;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, matmul_test11) {
|
||||
auto A = NDArrayFactory::create<float>('c', {3, 3});
|
||||
auto B = NDArrayFactory::create<float>('c', {3, 1});
|
||||
auto exp = NDArrayFactory::create<float>('c', {3, 1}, {14.00f, 32.00f, 50.00f});
|
||||
|
||||
A.linspace(1);
|
||||
B.linspace(1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
|
||||
auto result = op.evaluate({&A, &B}, {}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, matmul_test12) {
|
||||
auto x= NDArrayFactory::create<double>('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12});
|
||||
auto y= NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12});
|
||||
auto exp= NDArrayFactory::create<double>('f', {4, 4}, {38.0, 44.0, 50.0, 56.0, 83.0, 98.0, 113.0, 128.0, 128.0, 152.0, 176.0, 200.0, 173.0, 206.0, 239.0, 272.0});
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto result = op.evaluate({&x, &y}, {}, {1, 1});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests14, matmul_test13) {
|
||||
auto x= NDArrayFactory::create<double>('c', {1, 3}, {1, 2, 3});
|
||||
auto y= NDArrayFactory::create<double>('c', {1, 4}, {1, 2, 3, 4});
|
||||
auto exp= NDArrayFactory::create<double>('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0});
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto result = op.evaluate({&x, &y}, {}, {1, 0});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
//z->printIndexedBuffer("z");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, matmul_test14) {
|
||||
auto x= NDArrayFactory::create<double>('c', {3, 1}, {1, 2, 3});
|
||||
auto y= NDArrayFactory::create<double>('c', {4, 1}, {1, 2, 3, 4});
|
||||
auto exp= NDArrayFactory::create<double>('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0});
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto result = op.evaluate({&x, &y}, {}, {0, 1});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
//z->printIndexedBuffer("z");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, matmul_test15) {
|
||||
auto x= NDArrayFactory::create<double>('c', {3, 1}, {1, 2, 3});
|
||||
auto y= NDArrayFactory::create<double>('c', {1, 4}, {1, 2, 3, 4});
|
||||
auto exp= NDArrayFactory::create<double>('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0});
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto result = op.evaluate({&x, &y}, {}, {});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
//z->printIndexedBuffer("z");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, matmul_test16) {
|
||||
auto x= NDArrayFactory::create<double>('c', {4, 1}, {1, 2, 3, 4});
|
||||
auto y= NDArrayFactory::create<double>('c', {1, 4}, {1, 2, 3, 4});
|
||||
auto exp= NDArrayFactory::create<double>('f', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16});
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto result = op.evaluate({&x, &y});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
//z->printIndexedBuffer("z");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, matmul_test17) {
|
||||
auto x = NDArrayFactory::create<double>('c', {1, 2}, {2.0f, 2.0f});
|
||||
auto y = NDArrayFactory::create<double>('c', {2, 1}, {2.0f, 2.0f});
|
||||
auto exp = NDArrayFactory::create<double>('c', {1, 1}, {8.0f});
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto result = op.evaluate({&x, &y}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
ASSERT_EQ(exp, *result->at(0));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test18) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {1, 4, 3});
|
||||
auto y = NDArrayFactory::create<double>('f', {1, 3, 4});
|
||||
auto exp = NDArrayFactory::create<double>('f', {1, 3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test19) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {4, 1});
|
||||
auto y = NDArrayFactory::create<double>('f', {1, 4});
|
||||
auto exp = NDArrayFactory::create<double>('f', {1, 1}, {15});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1, 1});
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
|
||||
auto z = results->at(0);
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test20) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {1, 4, 1});
|
||||
auto y = NDArrayFactory::create<double>('f', {1, 1, 4});
|
||||
auto exp = NDArrayFactory::create<double>('f', {1, 1, 1}, {15});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1, 1});
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
auto z = results->at(0);
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test21) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {2, 3});
|
||||
auto y = NDArrayFactory::create<double>('c', {3, 5});
|
||||
auto exp = NDArrayFactory::create<double>('f', {5, 2}, {23. , 26. , 29. , 32. , 35., 50. , 57.5, 65. , 72.5, 80.});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {0, 0, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test22) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {3, 2});
|
||||
auto y = NDArrayFactory::create<double>('c', {3, 5});
|
||||
auto exp = NDArrayFactory::create<double>('f', {5, 2}, {37. , 41.5, 46. , 50.5, 55., 46. , 52. , 58. , 64. , 70.});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1, 0, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test23) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {3, 2});
|
||||
auto y = NDArrayFactory::create<double>('c', {3, 5});
|
||||
auto exp = NDArrayFactory::create<double>('f', {5, 2}, {37. , 41.5, 46. , 50.5, 55., 46. , 52. , 58. , 64. , 70.});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1, 0, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test24) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {2,2, 3,5});
|
||||
auto y = NDArrayFactory::create<double>('c', {2,2, 4,3});
|
||||
auto exp = NDArrayFactory::create<double>('f',{2,2, 4,5}, {4.6, 281.8, 89.2, 582.4, 10. , 314.2,108.1, 628.3, 15.4, 346.6,127. , 674.2, 20.8, 379. ,145.9, 720.1, 5.2, 289.6, 93.4, 593.8,
|
||||
11.5, 322.9,113.2, 640.6, 17.8, 356.2,133. , 687.4, 24.1, 389.5,152.8, 734.2, 5.8, 297.4, 97.6, 605.2, 13. , 331.6,118.3, 652.9,
|
||||
20.2, 365.8,139. , 700.6, 27.4, 400. ,159.7, 748.3, 6.4, 305.2,101.8, 616.6, 14.5, 340.3,123.4, 665.2, 22.6, 375.4,145. , 713.8,
|
||||
30.7, 410.5,166.6, 762.4, 7. , 313. ,106. , 628. , 16. , 349. ,128.5, 677.5, 25. , 385. ,151. , 727. , 34. , 421. ,173.5, 776.5});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1, 1, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test25) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {4, 3});
|
||||
auto y = NDArrayFactory::create<double>('c', {4});
|
||||
auto exp = NDArrayFactory::create<double>('f',{3}, {7., 8., 9.});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1, 0});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test26) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {3});
|
||||
auto y = NDArrayFactory::create<double>('c', {4, 3});
|
||||
auto exp = NDArrayFactory::create<double>('f',{4}, {1.4, 3.2, 5., 6.8});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {0, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test27) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {1, 1});
|
||||
auto y = NDArrayFactory::create<double>('c', {1, 1});
|
||||
auto exp = NDArrayFactory::create<double>('f',{1, 1}, {0.2});
|
||||
|
||||
x.linspace(2.);
|
||||
y.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test28) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {1, 1});
|
||||
auto y = NDArrayFactory::create<double>('c', {1, 1});
|
||||
auto exp = NDArrayFactory::create<double>('f',{1, 1}, {0.2});
|
||||
|
||||
x.linspace(2.);
|
||||
y.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1,1,1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test29) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {1});
|
||||
auto y = NDArrayFactory::create<double>('c', {1, 1});
|
||||
auto exp = NDArrayFactory::create<double>('f',{1}, {0.2});
|
||||
|
||||
x.linspace(2.);
|
||||
y.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test30) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {1,1});
|
||||
auto y = NDArrayFactory::create<double>('c', {1});
|
||||
auto exp = NDArrayFactory::create<double>('f',{1}, {0.2});
|
||||
|
||||
x.linspace(2.);
|
||||
y.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test31) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {4});
|
||||
auto y = NDArrayFactory::create<double>('c', {4});
|
||||
auto exp = NDArrayFactory::create<double>(3.);
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test32) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {1}, {2.});
|
||||
auto y = NDArrayFactory::create<double>('c', {1}, {3.});
|
||||
auto exp = NDArrayFactory::create<double>(6.);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests14, matmul_test33) {
|
||||
auto x = NDArrayFactory::create<double>('c', {4, 3});
|
||||
auto y = NDArrayFactory::create<double>('c', {4, 1});
|
||||
auto exp = NDArrayFactory::create<double>('c',{ 3, 1}, {70, 80, 90});
|
||||
|
||||
x.linspace(1);
|
||||
y.linspace(1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto result = op.evaluate({&x, &y}, {}, {1, 0});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests14, matmul_test34) {
|
||||
auto a = NDArrayFactory::create<double>('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
auto b = NDArrayFactory::create<double>('c', {4}, {1, 2, 3, 4});
|
||||
auto exp = NDArrayFactory::create<double>('c', {3}, {30, 70, 110});
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto result = op.evaluate({&a, &b});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, matmul_test35) {
|
||||
auto a = NDArrayFactory::create<double>('c', {4}, {1, 2, 3, 4});
|
||||
auto b = NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
auto exp = NDArrayFactory::create<double>('c', {3}, {70, 80, 90});
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto result = op.evaluate({&a, &b});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, matmul_test36) {
|
||||
auto a = NDArrayFactory::create<double>('c', {1, 4}, {1, 2, 3, 4});
|
||||
auto b = NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
auto exp = NDArrayFactory::create<double>('c', {1, 3}, {70, 80, 90});
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto result = op.evaluate({&a, &b});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, matmul_test37) {
|
||||
|
||||
NDArray a('c', {32, 12, 128, 64}, nd4j::DataType::FLOAT32);
|
||||
NDArray b('c', {32, 12, 128, 64}, nd4j::DataType::FLOAT32);
|
||||
NDArray c('c', {32,12,128,128}, nd4j::DataType::FLOAT32);
|
||||
NDArray cExp('c', {32,12,128,128}, nd4j::DataType::FLOAT32);
|
||||
|
||||
a = 1;
|
||||
b = 1;
|
||||
cExp = 64; //Each entry in output c is sum of 64 (1.0 x 1.0) multiplications
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto status = op.execute({&a, &b}, {&c}, {}, {0,1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
|
||||
ASSERT_TRUE(cExp.isSameShape(c));
|
||||
ASSERT_TRUE(cExp.equalsTo(c));
|
||||
}
|
||||
|
||||
// @Test
|
||||
// public void testMmulRank4_simple(){
|
||||
|
||||
// INDArray arr1 = Nd4j.ones(DataType.FLOAT, 32, 12, 128, 64);
|
||||
// INDArray arr2 = Nd4j.ones(DataType.FLOAT, 32, 12, 128, 64);
|
||||
|
||||
// DynamicCustomOp op = DynamicCustomOp.builder("matmul")
|
||||
// .addInputs(arr1, arr2)
|
||||
// .addIntegerArguments(0, 1) //Transpose arr2 only
|
||||
// .build();
|
||||
|
||||
// List<LongShapeDescriptor> shapes = op.calculateOutputShape();
|
||||
// assertEquals(1, shapes.size());
|
||||
// long[] shape = new long[]{32,12,128,128};
|
||||
// assertArrayEquals(shape, shapes.get(0).getShape());
|
||||
|
||||
// INDArray out = Nd4j.create(DataType.FLOAT, shape);
|
||||
|
||||
// op.setOutputArgument(0, out);
|
||||
// Nd4j.exec(op);
|
||||
// // System.out.println(out);
|
||||
|
||||
// INDArray exp = Nd4j.valueArrayOf(shape, 64.0, DataType.FLOAT); //Each entry in output is sum of 64 (1.0 x 1.0) multiplications
|
||||
// assertEquals(exp, out);
|
||||
// }
|
||||
|
|
|
@ -397,27 +397,6 @@ TEST_F(DeclarableOpsTests2, NLP_Cbow_Test_1) {
|
|||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests2, YetAnotherMatmulTest_1) {
|
||||
auto A = NDArrayFactory::create<float>('c', {3, 3});
|
||||
auto B = NDArrayFactory::create<float>('c', {3, 1});
|
||||
auto exp = NDArrayFactory::create<float>('c', {3, 1}, {14.00f, 32.00f, 50.00f});
|
||||
|
||||
A.linspace(1);
|
||||
B.linspace(1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
|
||||
auto result = op.evaluate({&A, &B}, {}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests2, Test_Squeeze_1) {
|
||||
auto x = NDArrayFactory::create<float>('c', {2, 1, 3, 1, 1, 1, 4});
|
||||
x.linspace(1);
|
||||
|
|
|
@ -789,120 +789,6 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_2) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_1) {
|
||||
auto x= NDArrayFactory::create<double>('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12});
|
||||
auto y= NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12});
|
||||
auto exp= NDArrayFactory::create<double>('f', {4, 4}, {38.0, 44.0, 50.0, 56.0, 83.0, 98.0, 113.0, 128.0, 128.0, 152.0, 176.0, 200.0, 173.0, 206.0, 239.0, 272.0});
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto result = op.evaluate({&x, &y}, {}, {1, 1});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_2) {
|
||||
auto x= NDArrayFactory::create<double>('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12});
|
||||
auto y= NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12});
|
||||
auto exp= NDArrayFactory::create<double>('f', {3, 3}, {70.0, 158.0, 246.0, 80.0, 184.0, 288.0, 90.0, 210.0, 330.0});
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto result = op.evaluate({&x, &y}, {}, {0, 0});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_3) {
|
||||
auto x= NDArrayFactory::create<double>('c', {1, 3}, {1, 2, 3});
|
||||
auto y= NDArrayFactory::create<double>('c', {1, 4}, {1, 2, 3, 4});
|
||||
auto exp= NDArrayFactory::create<double>('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0});
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto result = op.evaluate({&x, &y}, {}, {1, 0});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
//z->printIndexedBuffer("z");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_4) {
|
||||
auto x= NDArrayFactory::create<double>('c', {3, 1}, {1, 2, 3});
|
||||
auto y= NDArrayFactory::create<double>('c', {4, 1}, {1, 2, 3, 4});
|
||||
auto exp= NDArrayFactory::create<double>('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0});
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto result = op.evaluate({&x, &y}, {}, {0, 1});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
//z->printIndexedBuffer("z");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_5) {
|
||||
auto x= NDArrayFactory::create<double>('c', {3, 1}, {1, 2, 3});
|
||||
auto y= NDArrayFactory::create<double>('c', {1, 4}, {1, 2, 3, 4});
|
||||
auto exp= NDArrayFactory::create<double>('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0});
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto result = op.evaluate({&x, &y}, {}, {});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
//z->printIndexedBuffer("z");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_6) {
|
||||
auto x= NDArrayFactory::create<double>('c', {4, 1}, {1, 2, 3, 4});
|
||||
auto y= NDArrayFactory::create<double>('c', {1, 4}, {1, 2, 3, 4});
|
||||
auto exp= NDArrayFactory::create<double>('f', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16});
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto result = op.evaluate({&x, &y});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
//z->printIndexedBuffer("z");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests3, Test_ReverseDivide_1) {
|
||||
auto x= NDArrayFactory::create<double>('c', {1, 3}, {2, 2, 2});
|
||||
auto y= NDArrayFactory::create<double>('c', {1, 3}, {4, 6, 8});
|
||||
|
|
|
@ -809,26 +809,6 @@ TEST_F(DeclarableOpsTests4, Test_Reshape_Again) {
|
|||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests4, Test_Gemv_Transpose_1) {
|
||||
auto x = NDArrayFactory::create<double>('c', {4, 3});
|
||||
auto y = NDArrayFactory::create<double>('c', {4, 1});
|
||||
auto exp = NDArrayFactory::create<double>('c',{ 3, 1}, {70, 80, 90});
|
||||
|
||||
x.linspace(1);
|
||||
y.linspace(1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto result = op.evaluate({&x, &y}, {}, {1, 0});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests4, Test_Split_1) {
|
||||
auto x = NDArrayFactory::create<double>('c', {5, 30});
|
||||
auto sizes = NDArrayFactory::create<int>('c', {1, 3}, {4, 15, 11});
|
||||
|
@ -1166,57 +1146,6 @@ TEST_F(DeclarableOpsTests4, Test_Cross_3) {
|
|||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests4, Test_Matmul_YATS_1) {
|
||||
auto a = NDArrayFactory::create<double>('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
auto b = NDArrayFactory::create<double>('c', {4}, {1, 2, 3, 4});
|
||||
auto exp = NDArrayFactory::create<double>('c', {3}, {30, 70, 110});
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto result = op.evaluate({&a, &b});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests4, Test_Matmul_YATS_2) {
|
||||
auto a = NDArrayFactory::create<double>('c', {4}, {1, 2, 3, 4});
|
||||
auto b = NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
auto exp = NDArrayFactory::create<double>('c', {3}, {70, 80, 90});
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto result = op.evaluate({&a, &b});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests4, Test_Matmul_YATS_3) {
|
||||
auto a = NDArrayFactory::create<double>('c', {1, 4}, {1, 2, 3, 4});
|
||||
auto b = NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
auto exp = NDArrayFactory::create<double>('c', {1, 3}, {70, 80, 90});
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto result = op.evaluate({&a, &b});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests4, Test_Add_119) {
|
||||
auto a = NDArrayFactory::create<double>('c', {1, 4}, {1, 2, 3, 4});
|
||||
auto b = NDArrayFactory::create<double>('c', {4}, {1, 2, 3, 4});
|
||||
|
|
|
@ -5019,20 +5019,6 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_7) {
|
|||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests7, Test_Matmul_Once_Again) {
|
||||
auto x = NDArrayFactory::create<double>('c', {1, 2}, {2.0f, 2.0f});
|
||||
auto y = NDArrayFactory::create<double>('c', {2, 1}, {2.0f, 2.0f});
|
||||
auto exp = NDArrayFactory::create<double>('c', {1, 1}, {8.0f});
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto result = op.evaluate({&x, &y}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
ASSERT_EQ(exp, *result->at(0));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TYPED_TEST(TypedDeclarableOpsTests7, Test_Pnorm_Once_Again) {
|
||||
auto input = NDArrayFactory::create<TypeParam>('c', {1, 1, 5, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f});
|
||||
auto exp = NDArrayFactory::create<TypeParam>('c', {1, 1, 5, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f});
|
||||
|
|
|
@ -932,208 +932,6 @@ TEST_F(DeclarableOpsTests9, tile_test1) {
|
|||
delete results;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test1) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {3, 4});
|
||||
auto y = NDArrayFactory::create<double>('c', {4, 3});
|
||||
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 79., 123., 40., 92., 144., 45., 105., 165.});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test2) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {3, 4});
|
||||
auto y = NDArrayFactory::create<double>('f', {4, 3});
|
||||
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test3) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {3, 4});
|
||||
auto y = NDArrayFactory::create<double>('c', {4, 3});
|
||||
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test4) {
|
||||
|
||||
auto x = NDArrayFactory::create<double> ('f', {3, 4});
|
||||
auto y = NDArrayFactory::create<double>('f', {4, 3});
|
||||
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test5) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {4, 3});
|
||||
auto y = NDArrayFactory::create<double>('c', {4, 3});
|
||||
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {83., 94., 105., 94., 107., 120., 105., 120., 135.});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test6) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {4, 3});
|
||||
auto y = NDArrayFactory::create<double>('f', {3, 4});
|
||||
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test7) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {5, 3,4});
|
||||
auto y = NDArrayFactory::create<double>('f', {5, 3,4});
|
||||
auto exp = NDArrayFactory::create<double>('f',{5, 3,3}, {3. , 84.6, 281.4, 593.4, 1020.6, 7. , 107.8, 323.8, 655. , 1101.4,11. , 131. , 366.2, 716.6, 1182.2,
|
||||
7. , 107.8, 323.8, 655. , 1101.4,17.4, 137.4, 372.6, 723. , 1188.6,27.8, 167. , 421.4, 791. , 1275.8,
|
||||
11. , 131. , 366.2, 716.6, 1182.2,27.8, 167. , 421.4, 791. , 1275.8,44.6, 203. , 476.6, 865.4, 1369.4,});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {0, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test8) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {2,5, 3,4});
|
||||
auto y = NDArrayFactory::create<double>('f', {2,5, 3,4});
|
||||
auto exp = NDArrayFactory::create<double>('f',{2,5, 3,3}, {3. , 1563. , 84.6, 2220.6, 281.4, 2993.4, 593.4, 3881.4,1020.6, 4884.6, 7. , 1663. , 107.8, 2339.8, 323.8, 3131.8, 655. , 4039. ,1101.4, 5061.4,
|
||||
11. , 1763. , 131. , 2459. , 366.2, 3270.2, 716.6, 4196.6,1182.2, 5238.2, 7. , 1663. , 107.8, 2339.8, 323.8, 3131.8, 655. , 4039. ,1101.4, 5061.4,
|
||||
17.4, 1769.4, 137.4, 2465.4, 372.6, 3276.6, 723. , 4203. ,1188.6, 5244.6, 27.8, 1875.8, 167. , 2591. , 421.4, 3421.4, 791. , 4367. ,1275.8, 5427.8,
|
||||
11. , 1763. , 131. , 2459. , 366.2, 3270.2, 716.6, 4196.6,1182.2, 5238.2, 27.8, 1875.8, 167. , 2591. , 421.4, 3421.4, 791. , 4367. ,1275.8, 5427.8,
|
||||
44.6, 1988.6, 203. , 2723. , 476.6, 3572.6, 865.4, 4537.4,1369.4, 5617.4});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {0, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test9) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {2,5, 4,3});
|
||||
auto y = NDArrayFactory::create<double>('f', {2,5, 3,4});
|
||||
auto exp = NDArrayFactory::create<double>('f',{2,5, 3,3}, {7. , 1639. , 103. , 2311. , 314.2, 3098.2, 640.6, 4000.6,1082.2, 5018.2, 8. , 1664. , 108.8, 2340.8, 324.8, 3132.8, 656. , 4040. ,1102.4, 5062.4,
|
||||
9. , 1689. , 114.6, 2370.6, 335.4, 3167.4, 671.4, 4079.4,1122.6, 5106.6, 15.8, 1743.8, 131. , 2435. , 361.4, 3241.4, 707. , 4163. ,1167.8, 5199.8,
|
||||
18.4, 1770.4, 138.4, 2466.4, 373.6, 3277.6, 724. , 4204. ,1189.6, 5245.6, 21. , 1797. , 145.8, 2497.8, 385.8, 3313.8, 741. , 4245. ,1211.4, 5291.4,
|
||||
24.6, 1848.6, 159. , 2559. , 408.6, 3384.6, 773.4, 4325.4,1253.4, 5381.4, 28.8, 1876.8, 168. , 2592. , 422.4, 3422.4, 792. , 4368. ,1276.8, 5428.8,
|
||||
33. , 1905. , 177. , 2625. , 436.2, 3460.2, 810.6, 4410.6,1300.2, 5476.2});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, TestDropout_BP_1) {
|
||||
|
||||
|
@ -1325,325 +1123,6 @@ TEST_F(DeclarableOpsTests9, Test_AlphaDropout_BP_1) {
|
|||
delete ress2;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test10) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {1, 4, 3});
|
||||
auto y = NDArrayFactory::create<double>('f', {1, 3, 4});
|
||||
auto exp = NDArrayFactory::create<double>('f', {1, 3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test11) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {4, 1});
|
||||
auto y = NDArrayFactory::create<double>('f', {1, 4});
|
||||
auto exp = NDArrayFactory::create<double>('f', {1, 1}, {15});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1, 1});
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
|
||||
auto z = results->at(0);
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test12) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {1, 4, 1});
|
||||
auto y = NDArrayFactory::create<double>('f', {1, 1, 4});
|
||||
auto exp = NDArrayFactory::create<double>('f', {1, 1, 1}, {15});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1, 1});
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
auto z = results->at(0);
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test13) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {2, 3});
|
||||
auto y = NDArrayFactory::create<double>('c', {3, 5});
|
||||
auto exp = NDArrayFactory::create<double>('f', {5, 2}, {23. , 26. , 29. , 32. , 35., 50. , 57.5, 65. , 72.5, 80.});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {0, 0, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test14) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {3, 2});
|
||||
auto y = NDArrayFactory::create<double>('c', {3, 5});
|
||||
auto exp = NDArrayFactory::create<double>('f', {5, 2}, {37. , 41.5, 46. , 50.5, 55., 46. , 52. , 58. , 64. , 70.});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1, 0, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test15) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {3, 2});
|
||||
auto y = NDArrayFactory::create<double>('c', {3, 5});
|
||||
auto exp = NDArrayFactory::create<double>('f', {5, 2}, {37. , 41.5, 46. , 50.5, 55., 46. , 52. , 58. , 64. , 70.});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.5, 0.5);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1, 0, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test16) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {2,2, 3,5});
|
||||
auto y = NDArrayFactory::create<double>('c', {2,2, 4,3});
|
||||
auto exp = NDArrayFactory::create<double>('f',{2,2, 4,5}, {4.6, 281.8, 89.2, 582.4, 10. , 314.2,108.1, 628.3, 15.4, 346.6,127. , 674.2, 20.8, 379. ,145.9, 720.1, 5.2, 289.6, 93.4, 593.8,
|
||||
11.5, 322.9,113.2, 640.6, 17.8, 356.2,133. , 687.4, 24.1, 389.5,152.8, 734.2, 5.8, 297.4, 97.6, 605.2, 13. , 331.6,118.3, 652.9,
|
||||
20.2, 365.8,139. , 700.6, 27.4, 400. ,159.7, 748.3, 6.4, 305.2,101.8, 616.6, 14.5, 340.3,123.4, 665.2, 22.6, 375.4,145. , 713.8,
|
||||
30.7, 410.5,166.6, 762.4, 7. , 313. ,106. , 628. , 16. , 349. ,128.5, 677.5, 25. , 385. ,151. , 727. , 34. , 421. ,173.5, 776.5});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1, 1, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test17) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {4, 3});
|
||||
auto y = NDArrayFactory::create<double>('c', {4});
|
||||
auto exp = NDArrayFactory::create<double>('f',{3}, {7., 8., 9.});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1, 0});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test18) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {3});
|
||||
auto y = NDArrayFactory::create<double>('c', {4, 3});
|
||||
auto exp = NDArrayFactory::create<double>('f',{4}, {1.4, 3.2, 5., 6.8});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {0, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test19) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {1, 1});
|
||||
auto y = NDArrayFactory::create<double>('c', {1, 1});
|
||||
auto exp = NDArrayFactory::create<double>('f',{1, 1}, {0.2});
|
||||
|
||||
x.linspace(2.);
|
||||
y.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test20) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {1, 1});
|
||||
auto y = NDArrayFactory::create<double>('c', {1, 1});
|
||||
auto exp = NDArrayFactory::create<double>('f',{1, 1}, {0.2});
|
||||
|
||||
x.linspace(2.);
|
||||
y.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1,1,1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test21) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {1});
|
||||
auto y = NDArrayFactory::create<double>('c', {1, 1});
|
||||
auto exp = NDArrayFactory::create<double>('f',{1}, {0.2});
|
||||
|
||||
x.linspace(2.);
|
||||
y.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test22) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {1,1});
|
||||
auto y = NDArrayFactory::create<double>('c', {1});
|
||||
auto exp = NDArrayFactory::create<double>('f',{1}, {0.2});
|
||||
|
||||
x.linspace(2.);
|
||||
y.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test23) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {4});
|
||||
auto y = NDArrayFactory::create<double>('c', {4});
|
||||
auto exp = NDArrayFactory::create<double>(3.);
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, matmul_test24) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {1}, {2.});
|
||||
auto y = NDArrayFactory::create<double>('c', {1}, {3.});
|
||||
auto exp = NDArrayFactory::create<double>(6.);
|
||||
|
||||
nd4j::ops::matmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1, 1});
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests9, test_range_int_1) {
|
||||
auto x0 = NDArrayFactory::create<int>(0);
|
||||
auto x1 = NDArrayFactory::create<int>(2);
|
||||
|
|
|
@ -64,8 +64,11 @@ TEST_F(MklDnnTests, helpers_includer) {
|
|||
nd4j::ops::platforms::PLATFORM_maxpool3dnew_bp_ENGINE_CPU maxpool3d_bp;
|
||||
|
||||
nd4j::ops::platforms::PLATFORM_lrn_ENGINE_CPU lrn;
|
||||
|
||||
nd4j::ops::platforms::PLATFORM_batchnorm_ENGINE_CPU batchnorm;
|
||||
|
||||
printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm});
|
||||
nd4j::ops::platforms::PLATFORM_matmul_ENGINE_CPU matmul;
|
||||
|
||||
printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul});
|
||||
#endif
|
||||
}
|
Loading…
Reference in New Issue