Shyrma mkl matmul (#250)

* - provide matmul code based on mkl api

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correct typo in mkl matmul op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - take into account empty arrays in mkl matmul op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - fix bug in mkl matmul and group all matmul tests in one file

Signed-off-by: Yurii <iuriish@yahoo.com>
master
Yurii Shyrma 2020-02-18 07:58:01 +02:00 committed by GitHub
parent c8882cbfa5
commit 22c7aa9acf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 1229 additions and 920 deletions

View File

@ -5,7 +5,7 @@ project(mkldnn-download NONE)
include(ExternalProject)
ExternalProject_Add(mkldnn
GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
GIT_TAG v1.1.3
GIT_TAG v1.2
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
CONFIGURE_COMMAND ""

View File

@ -29,142 +29,128 @@
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
//////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) {
const int iSize = (int) block.getIArguments()->size();
int transX = iSize > 0 ? INT_ARG(0) : 0;
int transY = iSize > 1 ? INT_ARG(1) : 0;
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
const int xRank = x->rankOf();
const int yRank = y->rankOf();
const int zRank = z->rankOf();
const int iSize = (int) block.getIArguments()->size();
int transX = iSize > 0 ? INT_ARG(0) : 0;
int transY = iSize > 1 ? INT_ARG(1) : 0;
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
if (transZ) {
x = INPUT_VARIABLE(1);
y = INPUT_VARIABLE(0);
bool temp = transX;
transX = !transY;
transY = !temp;
}
const int xRank = x->rankOf();
const int yRank = y->rankOf();
const int zRank = z->rankOf();
const int xLastDim = transX ? -2 : -1;
const int yLastDim = transY ? -2 : -1;
const int xLastButOneDim = transX ? -1 : -2;
const int yLastButOneDim = transY ? -1 : -2;
if (transZ) {
x = INPUT_VARIABLE(1);
y = INPUT_VARIABLE(0);
bool temp = transX;
transX = !transY;
transY = !temp;
}
// ******* input validation ******* //
REQUIRE_TRUE(xRank > 0 && yRank > 0, 0,
"MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !",
xRank, yRank);
const int xLastDim = transX ? -2 : -1;
const int yLastDim = transY ? -2 : -1;
const int xLastButOneDim = transX ? -1 : -2;
const int yLastButOneDim = transY ? -1 : -2;
if (xRank == 1 && yRank == 1) { // dot case, output is scalar (or vector with length = 1)
REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0,
"MATMUL OP: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !",
x->lengthOf(), y->lengthOf());
} else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector
REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0,
"MATMUL OP: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !",
ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
} else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0,
"MATMUL OP: input arrays have inconsistent shapes for matrix-vector product: x %s, y %s !",
ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
} else {
REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0,
"MATMUL OP: input and output arrays must have the same rank, but got instead: x rank = %i, y rank = %i, z rank = %i !",
xRank, yRank, zRank);
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) &&
x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && y->sizeAt(yLastDim) == z->sizeAt(-1), 0,
"MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !",
ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(),
ShapeUtils::shapeAsString(z).c_str());
// ******* input validation ******* //
REQUIRE_TRUE(xRank > 0 && yRank > 0, 0, "MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !", xRank, yRank);
if (xRank > 2) // outer dims must be the same
for (int i = 0; i < xRank - 2; ++i)
REQUIRE_TRUE(x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0,
"MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !",
ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(),
ShapeUtils::shapeAsString(z).c_str());
}
// ******* end of input validation ******* //
if (xRank == 1 && yRank == 1) { // dot case, output is scalar (or vector with length = 1)
REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, "MATMUL OP: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !", x->lengthOf(), y->lengthOf());
} else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector
REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0, "MATMUL OP: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
} else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0, "MATMUL OP: input arrays have inconsistent shapes for matrix-vector product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
} else {
REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0, "MATMUL OP: input and output arrays must have the same rank, but got instead: x rank = %i, y rank = %i, z rank = %i !", xRank, yRank, zRank);
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) && x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && y->sizeAt(yLastDim) == z->sizeAt(-1), 0, "MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str());
MmulHelper::matmul(x, y, z, transX, transY);
if (xRank > 2) // outer dims must be the same
for (int i = 0; i < xRank - 2; ++i)
REQUIRE_TRUE(x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0, "MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str());
}
// ******* end of input validation ******* //
return Status::OK();
}
MmulHelper::matmul(x, y, z, transX, transY);
DECLARE_SYN(mMul, matmul);
return Status::OK();
}
DECLARE_SYN(mmul, matmul);
DECLARE_SYN(mMul, matmul);
DECLARE_SYN(gemm, matmul);
DECLARE_SYN(mmul, matmul);
DECLARE_SYN(gemv, matmul);
DECLARE_SYN(gemm, matmul);
DECLARE_SYN(dot, matmul);
DECLARE_SYN(gemv, matmul);
DECLARE_SYN(dot, matmul);
DECLARE_SHAPE_FN(matmul) {
//////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(matmul) {
auto xShapeInfo = inputShape->at(0);
auto yShapeInfo = inputShape->at(1);
auto xShapeInfo = inputShape->at(0);
auto yShapeInfo = inputShape->at(1);
const int iSize = (int) block.getIArguments()->size();
int transX = iSize > 0 ? INT_ARG(0) : 0;
int transY = iSize > 1 ? INT_ARG(1) : 0;
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
const int iSize = (int) block.getIArguments()->size();
int transX = iSize > 0 ? INT_ARG(0) : 0;
int transY = iSize > 1 ? INT_ARG(1) : 0;
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
REQUIRE_TRUE(xShapeInfo[0] > 0 && yShapeInfo[0] > 0, 0,
"MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !",
xShapeInfo[0], yShapeInfo[0]);
REQUIRE_TRUE(xShapeInfo[0] > 0 && yShapeInfo[0] > 0, 0,
"MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !",
xShapeInfo[0], yShapeInfo[0]);
if (transZ) {
xShapeInfo = inputShape->at(1);
yShapeInfo = inputShape->at(0);
bool temp = transX;
transX = !transY;
transY = !temp;
}
if (transZ) {
xShapeInfo = inputShape->at(1);
yShapeInfo = inputShape->at(0);
bool temp = transX;
transX = !transY;
transY = !temp;
}
auto zShapeOnly = ShapeUtils::evalShapeForMatmul(xShapeInfo, yShapeInfo, transX, transY);
auto zShapeOnly = ShapeUtils::evalShapeForMatmul(xShapeInfo, yShapeInfo, transX, transY);
auto dtypeX = ArrayOptions::dataType(xShapeInfo);
auto dtypeY = ArrayOptions::dataType(yShapeInfo);
auto dtypeX = ArrayOptions::dataType(xShapeInfo);
auto dtypeY = ArrayOptions::dataType(yShapeInfo);
auto xOrder = shape::order(xShapeInfo);
auto yOrder = shape::order(yShapeInfo);
auto zOrder = xOrder == 'c' && yOrder == 'c' ? 'c' : 'f';
auto xOrder = shape::order(xShapeInfo);
auto yOrder = shape::order(yShapeInfo);
auto zOrder = xOrder == 'c' && yOrder == 'c' ? 'c' : 'f';
// we just pick the higher data type out of X and Y
auto dtypeZ = dtypeX > dtypeY ? dtypeX : dtypeY;
// we just pick the higher data type out of X and Y
auto dtypeZ = dtypeX > dtypeY ? dtypeX : dtypeY;
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtypeZ, zOrder, zShapeOnly);
return SHAPELIST(newShape);
}
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtypeZ, zOrder, zShapeOnly);
return SHAPELIST(newShape);
}
DECLARE_TYPES(matmul) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_FLOATS})
->setAllowedInputTypes(1, {ALL_FLOATS})
->setAllowedOutputTypes(0, {ALL_FLOATS});
}
//////////////////////////////////////////////////////////////////////
DECLARE_TYPES(matmul) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_FLOATS})
->setAllowedInputTypes(1, {ALL_FLOATS})
->setAllowedOutputTypes(0, {ALL_FLOATS});
}
//////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto eps = INPUT_VARIABLE(2);
auto dldx = OUTPUT_VARIABLE(0);
auto dldy = OUTPUT_VARIABLE(1);
CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto eps = INPUT_VARIABLE(2);
auto dldx = OUTPUT_VARIABLE(0);
auto dldy = OUTPUT_VARIABLE(1);
const int iSize = (int) block.getIArguments()->size();
int transX = iSize > 0 ? INT_ARG(0) : 0;
int transY = iSize > 1 ? INT_ARG(1) : 0;
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
const int iSize = (int) block.getIArguments()->size();
int transX = iSize > 0 ? INT_ARG(0) : 0;
int transY = iSize > 1 ? INT_ARG(1) : 0;
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
/*
In: x=[a,b], y=[b,c]
@ -177,34 +163,35 @@ F F T [a,b] [b,c] [c,a] [c,a]
*/
nd4j::ops::matmul op;
op.execute({eps, y}, {dldx}, {}, {transZ, !transY, transX}, {});
op.execute({x, eps}, {dldy}, {}, {!transX, transZ, transY}, {});
nd4j::ops::matmul op;
op.execute({eps, y}, {dldx}, {}, {transZ, !transY, transX}, {});
op.execute({x, eps}, {dldy}, {}, {!transX, transZ, transY}, {});
return Status::OK();
}
return Status::OK();
}
//////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(matmul_bp) {
Nd4jLong *xShapeInfo;
Nd4jLong *yShapeInfo;
DECLARE_SHAPE_FN(matmul_bp) {
Nd4jLong *xShapeInfo;
Nd4jLong *yShapeInfo;
COPY_SHAPE(inputShape->at(0), xShapeInfo);
COPY_SHAPE(inputShape->at(1), yShapeInfo);
COPY_SHAPE(inputShape->at(0), xShapeInfo);
COPY_SHAPE(inputShape->at(1), yShapeInfo);
return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(yShapeInfo));
}
return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(yShapeInfo));
}
//////////////////////////////////////////////////////////////////////
DECLARE_TYPES(matmul_bp) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_FLOATS})
->setAllowedInputTypes(1, {ALL_FLOATS})
->setAllowedInputTypes(2, {ALL_FLOATS})
->setAllowedOutputTypes(0, {ALL_FLOATS})
->setAllowedOutputTypes(1, {ALL_FLOATS});
}
DECLARE_TYPES(matmul_bp) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_FLOATS})
->setAllowedInputTypes(1, {ALL_FLOATS})
->setAllowedInputTypes(2, {ALL_FLOATS})
->setAllowedOutputTypes(0, {ALL_FLOATS})
->setAllowedOutputTypes(1, {ALL_FLOATS});
}
}
}
}

View File

@ -0,0 +1,294 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <numeric>
namespace nd4j {
namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////////
static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY) {
// mkl works with following
// [M,K] x [K,N] = [M,N]
// [bS, M,K] x [bS, K,N] = [bS, M,N]
// possible input cases not supported by mkl, however we'll perform permut/reshape procedures in order to fit requirements
// [4] x [4] = [1] --> [1,4] x [4,1] = [1,1]
// [4] x [4,5] = [5] --> [1,4] x [4,5] = [1,5]
// [4,5] x [5] = [4] --> [4,5] x [5,1] = [4,1]
// [2,3, 4,5] x [2,3, 5,4] = [2,3, 4,4] --> [6, 4,5] x [6, 5,4] = [6, 4,4]
// [2,2,3, 4,5] x [2,2,3, 5,4] = [2,2,3, 4,4] --> [12, 4,5] x [12, 5,4] = [12, 4,4]
const auto xRank = x->rankOf();
const auto yRank = y->rankOf();
const auto zRank = z->rankOf();
std::vector<int> permut;
// fill permutation vector appropriately if transposition is required
if((transX && xRank > 1) || (transY && yRank > 1)) {
const int rank = xRank >= yRank ? xRank : yRank;
permut.resize(rank);
std::iota(std::begin(permut), std::end(permut), 0);
permut[rank-2] = rank - 1;
permut[rank-1] = rank - 2;
}
const NDArray* xT = (transX && xRank > 1) ? new NDArray(x->permute(permut)) : x;
const NDArray* yT = (transY && yRank > 1) ? new NDArray(y->permute(permut)) : y;
const NDArray* xTR = xRank <= 3 ? xT : new NDArray(xT->reshape(xT->ordering(), {xT->lengthOf() / (xT->sizeAt(-2) * xT->sizeAt(-1)), xT->sizeAt(-2), xT->sizeAt(-1)}));
const NDArray* yTR = xRank <= 3 ? yT : new NDArray(yT->reshape(yT->ordering(), {yT->lengthOf() / (yT->sizeAt(-2) * yT->sizeAt(-1)), yT->sizeAt(-2), yT->sizeAt(-1)}));
NDArray* zR = xRank <= 3 ? z : new NDArray(z->reshape(z->ordering(), {z->lengthOf() / (z->sizeAt(-2) * z->sizeAt(-1)), z->sizeAt(-2), z->sizeAt(-1)})/*, false*/);
// [M,K] x [K,N] = [M,N]
const int M = (xRank > 1) ? xTR->sizeAt(-2) : 1;
const int K = (xRank > 1) ? xTR->sizeAt(-1) : xTR->lengthOf();
const int N = (yRank > 1) ? yTR->sizeAt(-1) : 1;
const int bS = (xRank > 2) ? xTR->sizeAt(0) : 1; // [bS, M,K] x [bS, K,N] = [bS, M,N]
dnnl::memory::dims xShape = xRank < 3 ? dnnl::memory::dims({M, K}) : dnnl::memory::dims({bS, M, K});
dnnl::memory::dims yShape = xRank < 3 ? dnnl::memory::dims({K, N}) : dnnl::memory::dims({bS, K, N});
dnnl::memory::dims zShape = xRank < 3 ? dnnl::memory::dims({M, N}) : dnnl::memory::dims({bS, M, N});
dnnl::memory::format_tag format = xRank < 3 ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::abc;
// x type
dnnl::memory::data_type xType;
if(x->dataType() == DataType::FLOAT32)
xType = dnnl::memory::data_type::f32;
else if(x->dataType() == DataType::HALF)
xType = dnnl::memory::data_type::f16;
else if(x->dataType() == DataType::BFLOAT16)
xType = dnnl::memory::data_type::bf16;
else if(x->dataType() == DataType::UINT8)
xType = dnnl::memory::data_type::u8;
else
xType = dnnl::memory::data_type::s8;
// y type
dnnl::memory::data_type yType = xType;
if(y->dataType() == DataType::UINT8)
yType = dnnl::memory::data_type::u8;
else if(y->dataType() == DataType::INT8)
yType = dnnl::memory::data_type::s8;
// z type
dnnl::memory::data_type zType = xType;
if(z->dataType() == DataType::FLOAT32)
zType = dnnl::memory::data_type::f32;
else if(z->dataType() == DataType::INT32)
zType = dnnl::memory::data_type::s32;
else if(z->dataType() == DataType::UINT8)
zType = dnnl::memory::data_type::u8;
else if(z->dataType() == DataType::INT8)
zType = dnnl::memory::data_type::s8;
// memory descriptors for arrays
// x
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format);
if(xTR->ews() != 1 || xTR->ordering() != 'c') {
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = xRank == 1 ? 1 : xTR->strideAt(0);
x_user_md.data.format_desc.blocking.strides[1] = xRank == 1 ? xTR->strideAt(0) : xTR->strideAt(1);
if(xRank > 2)
x_user_md.data.format_desc.blocking.strides[2] = xTR->strideAt(2);
}
// y
dnnl::memory::desc y_mkl_md = dnnl::memory::desc(yShape, yType, dnnl::memory::format_tag::any);
dnnl::memory::desc y_user_md = dnnl::memory::desc(yShape, yType, format);
if(yTR->ews() != 1 || yTR->ordering() != 'c') {
y_user_md.data.format_kind = dnnl_blocked; // overrides format
y_user_md.data.format_desc.blocking.strides[0] = yRank == 1 ? 1 : yTR->strideAt(0);
y_user_md.data.format_desc.blocking.strides[1] = yRank == 1 ? yTR->strideAt(0) : yTR->strideAt(1);
if(yRank > 2)
y_user_md.data.format_desc.blocking.strides[2] = yTR->strideAt(2);
}
// z
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, format);
if(zR->ews() != 1 || zR->ordering() != 'c') {
z_user_md.data.format_kind = dnnl_blocked; // overrides format
z_user_md.data.format_desc.blocking.strides[0] = zRank == 1 ? 1 : zR->strideAt(0);
z_user_md.data.format_desc.blocking.strides[1] = zRank == 1 ? zR->strideAt(0) : zR->strideAt(1);
if(zRank > 2)
z_user_md.data.format_desc.blocking.strides[2] = zR->strideAt(2);
}
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// Create attributes (to handle alpha and beta if necessary)
dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0)
// operation primitive description
dnnl::matmul::desc op_desc(x_mkl_md, y_mkl_md, z_mkl_md);
dnnl::matmul::primitive_desc op_prim_desc(op_desc, attr, engine);
// arguments (memory buffers) necessary for calculations
std::unordered_map<int, dnnl::memory> args;
dnnl::stream stream(engine);
// provide memory buffers and check whether reorder is required
// input
auto x_user_mem = dnnl::memory(x_user_md, engine, xTR->getBuffer());
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
// y
auto y_user_mem = dnnl::memory(y_user_md, engine, yTR->getBuffer());
const bool yReorder = op_prim_desc.weights_desc() != y_user_mem.get_desc();
auto y_mkl_mem = yReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : y_user_mem;
if (yReorder)
dnnl::reorder(y_user_mem, y_mkl_mem).execute(stream, y_user_mem, y_mkl_mem);
args[DNNL_ARG_WEIGHTS] = y_mkl_mem;
// z
auto z_user_mem = dnnl::memory(z_user_md, engine, zR->getBuffer());
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
args[DNNL_ARG_DST] = z_mkl_mem;
// run calculations
dnnl::matmul(op_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (zReorder)
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
stream.wait();
if(zR->getBuffer() != z->getBuffer())
z->assign(zR);
if(zR != z)
delete zR;
if(xTR != xT)
delete xTR;
if(xT != x)
delete xT;
if(yTR != yT)
delete yTR;
if(yT != y)
delete yT;
// shape::printArray(z_mkl_mem.map_data<float>(),8);
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(matmul, ENGINE_CPU) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
if(x->isEmpty() || y->isEmpty())
return Status::OK();
const int iSize = (int) block.getIArguments()->size();
int transX = iSize > 0 ? INT_ARG(0) : 0;
int transY = iSize > 1 ? INT_ARG(1) : 0;
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
const int xRank = x->rankOf();
const int yRank = y->rankOf();
const int zRank = z->rankOf();
if (transZ) {
x = INPUT_VARIABLE(1);
y = INPUT_VARIABLE(0);
bool temp = transX;
transX = !transY;
transY = !temp;
}
const int xLastDim = transX ? -2 : -1;
const int yLastDim = transY ? -2 : -1;
const int xLastButOneDim = transX ? -1 : -2;
const int yLastButOneDim = transY ? -1 : -2;
// ******* input validation ******* //
REQUIRE_TRUE(xRank > 0 && yRank > 0, 0, "MATMUL MKLDNN OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !", xRank, yRank);
if (xRank == 1 && yRank == 1) { // dot case, output is scalar (or vector with length = 1)
REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0,"MATMUL MKLDNN OP: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !",x->lengthOf(), y->lengthOf());
} else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector
REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0, "MATMUL MKLDNN OP: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
} else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0, "MATMUL MKLDNN OP: input arrays have inconsistent shapes for matrix-vector product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
} else {
REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0, "MATMUL MKLDNN OP: input and output arrays must have the same rank, but got instead: x rank = %i, y rank = %i, z rank = %i !", xRank, yRank, zRank);
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) && x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && y->sizeAt(yLastDim) == z->sizeAt(-1), 0, "MATMUL MKLDNN OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str());
if (xRank > 2) // outer dims must be the same
for (int i = 0; i < xRank - 2; ++i)
REQUIRE_TRUE(x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0, "MATMUL MKLDNN OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str());
}
// ******* end of input validation ******* //
matmulMKLDNN(x, y, z, transX, transY);
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(matmul, ENGINE_CPU) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = INPUT_VARIABLE(0);
const DataType xType = x->dataType();
const DataType yType = y->dataType();
const DataType zType = z->dataType();
return block.isUseMKLDNN() &&
(
(xType==DataType::FLOAT32 && yType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
(xType==DataType::HALF && yType==DataType::HALF && zType==DataType::FLOAT32) ||
(xType==DataType::BFLOAT16 && yType==DataType::BFLOAT16 && zType==DataType::BFLOAT16) ||
((xType==DataType::UINT8 || xType==DataType::INT8) && (yType==DataType::UINT8 || yType==DataType::INT8) && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32))
);
}
}
}
}

View File

@ -84,6 +84,8 @@ namespace nd4j{
DECLARE_PLATFORM(depthwise_conv2d, ENGINE_CPU);
DECLARE_PLATFORM(depthwise_conv2d_bp, ENGINE_CPU);
DECLARE_PLATFORM(matmul, ENGINE_CPU);
}
}

View File

@ -1341,40 +1341,6 @@ TEST_F(DeclarableOpsTests1, MultiplyScalarScalar1) {
delete exp;
}
TEST_F(DeclarableOpsTests1, TestMatMul1) {
auto x = NDArrayFactory::create_<float>('c', {3, 5});
x->linspace(1);
auto y = NDArrayFactory::create_<float>('c', {5, 3});
y->linspace(1);
float _expB[]{135.0f, 310.0f, 485.0f, 150.0f, 350.0f, 550.0f, 165.0f, 390.0f, 615.0f};
Nd4jLong _expS[] {2, 3, 3, 1, 3, 0, 1, 102}; // expected shape
ArrayOptions::setDataType(_expS, nd4j::DataType::FLOAT32);
NDArray exp(_expB, _expS);
auto variableSpace = new VariableSpace();
variableSpace->putVariable(-1, x);
variableSpace->putVariable(-2, y);
variableSpace->putVariable(1, new Variable());
auto block = new Context(1, variableSpace, false);
block->fillInputs({-1, -2});
nd4j::ops::matmul op;
Nd4jStatus status = op.execute(block);
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(variableSpace->hasVariable(1));
auto result = variableSpace->getVariable(1)->getNDArray();
ASSERT_TRUE(result->equalsTo(&exp));
delete block;
delete variableSpace;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestSoftMax_bp_1) {

View File

@ -2800,16 +2800,9 @@ TEST_F(DeclarableOpsTests12, QR_Test_1_1) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, QR_Test_2) {
auto in = NDArrayFactory::create<double>('c', {5,3}, {
12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3.
});
auto expQ = NDArrayFactory::create<double>('c', {5, 3}, {
0.8464148, 0.3912908, -0.3431241, -0.42320737, -0.9040873, 0.02927014, 0.28213826, -0.17042054, -0.93285596, 0.07053456, -0.01404065, 0.00109937, -0.14106913, 0.0166551, 0.10577161
});
auto expR = NDArrayFactory::create<double>('c', {3,3}, {
-14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546
});
auto in = NDArrayFactory::create<double>('c', {5,3}, {12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3.});
auto expQ = NDArrayFactory::create<double>('c', {5, 3}, {0.8464148,0.3912908,-0.3431241,-0.42320737, -0.9040873,0.02927014,0.28213826, -0.17042054, -0.93285596,0.07053456, -0.01404065,0.00109937,-0.14106913,0.0166551,0.10577161});
auto expR = NDArrayFactory::create<double>('c', {3,3}, {-14.177447,-20.666622,13.401566,0.,-175.04254,70.080315,0.,0.,35.201546});
nd4j::ops::qr op;
auto res = op.evaluate({&in}, {}, {}, {false});
@ -2819,8 +2812,6 @@ TEST_F(DeclarableOpsTests12, QR_Test_2) {
auto r = res->at(1);
ASSERT_TRUE(q->isSameShape(expQ));
ASSERT_TRUE(r->isSameShape(expR));
// q->printIndexedBuffer("Orthogonal 5x5");
// r->printIndexedBuffer("Upper triangular 5x3");
nd4j::ops::matmul opMul;
auto res2 = opMul.evaluate({q, r}); //MmulHelper::matmul(q, r, &in, false, false);

View File

@ -682,3 +682,810 @@ TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest8) {
x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y, z);
ASSERT_EQ(e, z);
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test1) {
auto x = NDArrayFactory::create<double>('c', {3, 4});
auto y = NDArrayFactory::create<double>('c', {4, 3});
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 79., 123., 40., 92., 144., 45., 105., 165.});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test2) {
auto x = NDArrayFactory::create<double>('c', {3, 4});
auto y = NDArrayFactory::create<double>('f', {4, 3});
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test3) {
auto x = NDArrayFactory::create<double>('f', {3, 4});
auto y = NDArrayFactory::create<double>('c', {4, 3});
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test4) {
auto x = NDArrayFactory::create<double> ('f', {3, 4});
auto y = NDArrayFactory::create<double>('f', {4, 3});
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test5) {
auto x = NDArrayFactory::create<double>('c', {4, 3});
auto y = NDArrayFactory::create<double>('c', {4, 3});
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {83., 94., 105., 94., 107., 120., 105., 120., 135.});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test6) {
auto x = NDArrayFactory::create<double>('c', {4, 3});
auto y = NDArrayFactory::create<double>('f', {3, 4});
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test7) {
auto x = NDArrayFactory::create<double>('c', {5, 3,4});
auto y = NDArrayFactory::create<double>('f', {5, 3,4});
auto exp = NDArrayFactory::create<double>('f',{5, 3,3}, {3. , 84.6, 281.4, 593.4, 1020.6, 7. , 107.8, 323.8, 655. , 1101.4,11. , 131. , 366.2, 716.6, 1182.2,
7. , 107.8, 323.8, 655. , 1101.4,17.4, 137.4, 372.6, 723. , 1188.6,27.8, 167. , 421.4, 791. , 1275.8,
11. , 131. , 366.2, 716.6, 1182.2,27.8, 167. , 421.4, 791. , 1275.8,44.6, 203. , 476.6, 865.4, 1369.4,});
x.linspace(1.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {0, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test8) {
auto x = NDArrayFactory::create<double>('c', {2,5, 3,4});
auto y = NDArrayFactory::create<double>('f', {2,5, 3,4});
auto exp = NDArrayFactory::create<double>('f',{2,5, 3,3}, {3. , 1563. , 84.6, 2220.6, 281.4, 2993.4, 593.4, 3881.4,1020.6, 4884.6, 7. , 1663. , 107.8, 2339.8, 323.8, 3131.8, 655. , 4039. ,1101.4, 5061.4,
11. , 1763. , 131. , 2459. , 366.2, 3270.2, 716.6, 4196.6,1182.2, 5238.2, 7. , 1663. , 107.8, 2339.8, 323.8, 3131.8, 655. , 4039. ,1101.4, 5061.4,
17.4, 1769.4, 137.4, 2465.4, 372.6, 3276.6, 723. , 4203. ,1188.6, 5244.6, 27.8, 1875.8, 167. , 2591. , 421.4, 3421.4, 791. , 4367. ,1275.8, 5427.8,
11. , 1763. , 131. , 2459. , 366.2, 3270.2, 716.6, 4196.6,1182.2, 5238.2, 27.8, 1875.8, 167. , 2591. , 421.4, 3421.4, 791. , 4367. ,1275.8, 5427.8,
44.6, 1988.6, 203. , 2723. , 476.6, 3572.6, 865.4, 4537.4,1369.4, 5617.4});
x.linspace(1.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {0, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test9) {
auto x = NDArrayFactory::create<double>('c', {2,5, 4,3});
auto y = NDArrayFactory::create<double>('f', {2,5, 3,4});
auto exp = NDArrayFactory::create<double>('f',{2,5, 3,3}, {7. , 1639. , 103. , 2311. , 314.2, 3098.2, 640.6, 4000.6,1082.2, 5018.2, 8. , 1664. , 108.8, 2340.8, 324.8, 3132.8, 656. , 4040. ,1102.4, 5062.4,
9. , 1689. , 114.6, 2370.6, 335.4, 3167.4, 671.4, 4079.4,1122.6, 5106.6, 15.8, 1743.8, 131. , 2435. , 361.4, 3241.4, 707. , 4163. ,1167.8, 5199.8,
18.4, 1770.4, 138.4, 2466.4, 373.6, 3277.6, 724. , 4204. ,1189.6, 5245.6, 21. , 1797. , 145.8, 2497.8, 385.8, 3313.8, 741. , 4245. ,1211.4, 5291.4,
24.6, 1848.6, 159. , 2559. , 408.6, 3384.6, 773.4, 4325.4,1253.4, 5381.4, 28.8, 1876.8, 168. , 2592. , 422.4, 3422.4, 792. , 4368. ,1276.8, 5428.8,
33. , 1905. , 177. , 2625. , 436.2, 3460.2, 810.6, 4410.6,1300.2, 5476.2});
x.linspace(1.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
TEST_F(DeclarableOpsTests14, matmul_test10) {
auto x = NDArrayFactory::create_<float>('c', {3, 5});
x->linspace(1);
auto y = NDArrayFactory::create_<float>('c', {5, 3});
y->linspace(1);
float _expB[]{135.0f, 310.0f, 485.0f, 150.0f, 350.0f, 550.0f, 165.0f, 390.0f, 615.0f};
Nd4jLong _expS[] {2, 3, 3, 1, 3, 0, 1, 102}; // expected shape
ArrayOptions::setDataType(_expS, nd4j::DataType::FLOAT32);
NDArray exp(_expB, _expS);
auto variableSpace = new VariableSpace();
variableSpace->putVariable(-1, x);
variableSpace->putVariable(-2, y);
variableSpace->putVariable(1, new Variable());
auto block = new Context(1, variableSpace, false);
block->fillInputs({-1, -2});
nd4j::ops::matmul op;
Nd4jStatus status = op.execute(block);
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(variableSpace->hasVariable(1));
auto result = variableSpace->getVariable(1)->getNDArray();
ASSERT_TRUE(result->equalsTo(&exp));
delete block;
delete variableSpace;
}
TEST_F(DeclarableOpsTests14, matmul_test11) {
auto A = NDArrayFactory::create<float>('c', {3, 3});
auto B = NDArrayFactory::create<float>('c', {3, 1});
auto exp = NDArrayFactory::create<float>('c', {3, 1}, {14.00f, 32.00f, 50.00f});
A.linspace(1);
B.linspace(1);
nd4j::ops::matmul op;
auto result = op.evaluate({&A, &B}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests14, matmul_test12) {
auto x= NDArrayFactory::create<double>('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12});
auto y= NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12});
auto exp= NDArrayFactory::create<double>('f', {4, 4}, {38.0, 44.0, 50.0, 56.0, 83.0, 98.0, 113.0, 128.0, 128.0, 152.0, 176.0, 200.0, 173.0, 206.0, 239.0, 272.0});
nd4j::ops::matmul op;
auto result = op.evaluate({&x, &y}, {}, {1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests14, matmul_test13) {
auto x= NDArrayFactory::create<double>('c', {1, 3}, {1, 2, 3});
auto y= NDArrayFactory::create<double>('c', {1, 4}, {1, 2, 3, 4});
auto exp= NDArrayFactory::create<double>('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0});
nd4j::ops::matmul op;
auto result = op.evaluate({&x, &y}, {}, {1, 0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
//z->printIndexedBuffer("z");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests14, matmul_test14) {
auto x= NDArrayFactory::create<double>('c', {3, 1}, {1, 2, 3});
auto y= NDArrayFactory::create<double>('c', {4, 1}, {1, 2, 3, 4});
auto exp= NDArrayFactory::create<double>('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0});
nd4j::ops::matmul op;
auto result = op.evaluate({&x, &y}, {}, {0, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
//z->printIndexedBuffer("z");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests14, matmul_test15) {
auto x= NDArrayFactory::create<double>('c', {3, 1}, {1, 2, 3});
auto y= NDArrayFactory::create<double>('c', {1, 4}, {1, 2, 3, 4});
auto exp= NDArrayFactory::create<double>('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0});
nd4j::ops::matmul op;
auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
//z->printIndexedBuffer("z");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests14, matmul_test16) {
auto x= NDArrayFactory::create<double>('c', {4, 1}, {1, 2, 3, 4});
auto y= NDArrayFactory::create<double>('c', {1, 4}, {1, 2, 3, 4});
auto exp= NDArrayFactory::create<double>('f', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16});
nd4j::ops::matmul op;
auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
//z->printIndexedBuffer("z");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests14, matmul_test17) {
auto x = NDArrayFactory::create<double>('c', {1, 2}, {2.0f, 2.0f});
auto y = NDArrayFactory::create<double>('c', {2, 1}, {2.0f, 2.0f});
auto exp = NDArrayFactory::create<double>('c', {1, 1}, {8.0f});
nd4j::ops::matmul op;
auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(exp, *result->at(0));
delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test18) {
auto x = NDArrayFactory::create<double>('c', {1, 4, 3});
auto y = NDArrayFactory::create<double>('f', {1, 3, 4});
auto exp = NDArrayFactory::create<double>('f', {1, 3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test19) {
auto x = NDArrayFactory::create<double>('c', {4, 1});
auto y = NDArrayFactory::create<double>('f', {1, 4});
auto exp = NDArrayFactory::create<double>('f', {1, 1}, {15});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1, 1});
ASSERT_EQ(Status::OK(), results->status());
auto z = results->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test20) {
auto x = NDArrayFactory::create<double>('c', {1, 4, 1});
auto y = NDArrayFactory::create<double>('f', {1, 1, 4});
auto exp = NDArrayFactory::create<double>('f', {1, 1, 1}, {15});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1, 1});
ASSERT_EQ(Status::OK(), results->status());
auto z = results->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test21) {
auto x = NDArrayFactory::create<double>('c', {2, 3});
auto y = NDArrayFactory::create<double>('c', {3, 5});
auto exp = NDArrayFactory::create<double>('f', {5, 2}, {23. , 26. , 29. , 32. , 35., 50. , 57.5, 65. , 72.5, 80.});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {0, 0, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test22) {
auto x = NDArrayFactory::create<double>('c', {3, 2});
auto y = NDArrayFactory::create<double>('c', {3, 5});
auto exp = NDArrayFactory::create<double>('f', {5, 2}, {37. , 41.5, 46. , 50.5, 55., 46. , 52. , 58. , 64. , 70.});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1, 0, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test23) {
auto x = NDArrayFactory::create<double>('c', {3, 2});
auto y = NDArrayFactory::create<double>('c', {3, 5});
auto exp = NDArrayFactory::create<double>('f', {5, 2}, {37. , 41.5, 46. , 50.5, 55., 46. , 52. , 58. , 64. , 70.});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1, 0, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test24) {
auto x = NDArrayFactory::create<double>('c', {2,2, 3,5});
auto y = NDArrayFactory::create<double>('c', {2,2, 4,3});
auto exp = NDArrayFactory::create<double>('f',{2,2, 4,5}, {4.6, 281.8, 89.2, 582.4, 10. , 314.2,108.1, 628.3, 15.4, 346.6,127. , 674.2, 20.8, 379. ,145.9, 720.1, 5.2, 289.6, 93.4, 593.8,
11.5, 322.9,113.2, 640.6, 17.8, 356.2,133. , 687.4, 24.1, 389.5,152.8, 734.2, 5.8, 297.4, 97.6, 605.2, 13. , 331.6,118.3, 652.9,
20.2, 365.8,139. , 700.6, 27.4, 400. ,159.7, 748.3, 6.4, 305.2,101.8, 616.6, 14.5, 340.3,123.4, 665.2, 22.6, 375.4,145. , 713.8,
30.7, 410.5,166.6, 762.4, 7. , 313. ,106. , 628. , 16. , 349. ,128.5, 677.5, 25. , 385. ,151. , 727. , 34. , 421. ,173.5, 776.5});
x.linspace(1.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1, 1, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test25) {
auto x = NDArrayFactory::create<double>('f', {4, 3});
auto y = NDArrayFactory::create<double>('c', {4});
auto exp = NDArrayFactory::create<double>('f',{3}, {7., 8., 9.});
x.linspace(1.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1, 0});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test26) {
auto x = NDArrayFactory::create<double>('f', {3});
auto y = NDArrayFactory::create<double>('c', {4, 3});
auto exp = NDArrayFactory::create<double>('f',{4}, {1.4, 3.2, 5., 6.8});
x.linspace(1.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {0, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test27) {
auto x = NDArrayFactory::create<double>('f', {1, 1});
auto y = NDArrayFactory::create<double>('c', {1, 1});
auto exp = NDArrayFactory::create<double>('f',{1, 1}, {0.2});
x.linspace(2.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test28) {
auto x = NDArrayFactory::create<double>('f', {1, 1});
auto y = NDArrayFactory::create<double>('c', {1, 1});
auto exp = NDArrayFactory::create<double>('f',{1, 1}, {0.2});
x.linspace(2.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1,1,1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test29) {
auto x = NDArrayFactory::create<double>('f', {1});
auto y = NDArrayFactory::create<double>('c', {1, 1});
auto exp = NDArrayFactory::create<double>('f',{1}, {0.2});
x.linspace(2.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test30) {
auto x = NDArrayFactory::create<double>('f', {1,1});
auto y = NDArrayFactory::create<double>('c', {1});
auto exp = NDArrayFactory::create<double>('f',{1}, {0.2});
x.linspace(2.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test31) {
auto x = NDArrayFactory::create<double>('f', {4});
auto y = NDArrayFactory::create<double>('c', {4});
auto exp = NDArrayFactory::create<double>(3.);
x.linspace(1.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test32) {
auto x = NDArrayFactory::create<double>('f', {1}, {2.});
auto y = NDArrayFactory::create<double>('c', {1}, {3.});
auto exp = NDArrayFactory::create<double>(6.);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
TEST_F(DeclarableOpsTests14, matmul_test33) {
auto x = NDArrayFactory::create<double>('c', {4, 3});
auto y = NDArrayFactory::create<double>('c', {4, 1});
auto exp = NDArrayFactory::create<double>('c',{ 3, 1}, {70, 80, 90});
x.linspace(1);
y.linspace(1);
nd4j::ops::matmul op;
auto result = op.evaluate({&x, &y}, {}, {1, 0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests14, matmul_test34) {
auto a = NDArrayFactory::create<double>('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
auto b = NDArrayFactory::create<double>('c', {4}, {1, 2, 3, 4});
auto exp = NDArrayFactory::create<double>('c', {3}, {30, 70, 110});
nd4j::ops::matmul op;
auto result = op.evaluate({&a, &b});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests14, matmul_test35) {
auto a = NDArrayFactory::create<double>('c', {4}, {1, 2, 3, 4});
auto b = NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
auto exp = NDArrayFactory::create<double>('c', {3}, {70, 80, 90});
nd4j::ops::matmul op;
auto result = op.evaluate({&a, &b});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests14, matmul_test36) {
auto a = NDArrayFactory::create<double>('c', {1, 4}, {1, 2, 3, 4});
auto b = NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
auto exp = NDArrayFactory::create<double>('c', {1, 3}, {70, 80, 90});
nd4j::ops::matmul op;
auto result = op.evaluate({&a, &b});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, matmul_test37) {
NDArray a('c', {32, 12, 128, 64}, nd4j::DataType::FLOAT32);
NDArray b('c', {32, 12, 128, 64}, nd4j::DataType::FLOAT32);
NDArray c('c', {32,12,128,128}, nd4j::DataType::FLOAT32);
NDArray cExp('c', {32,12,128,128}, nd4j::DataType::FLOAT32);
a = 1;
b = 1;
cExp = 64; //Each entry in output c is sum of 64 (1.0 x 1.0) multiplications
nd4j::ops::matmul op;
auto status = op.execute({&a, &b}, {&c}, {}, {0,1});
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(cExp.isSameShape(c));
ASSERT_TRUE(cExp.equalsTo(c));
}
// @Test
// public void testMmulRank4_simple(){
// INDArray arr1 = Nd4j.ones(DataType.FLOAT, 32, 12, 128, 64);
// INDArray arr2 = Nd4j.ones(DataType.FLOAT, 32, 12, 128, 64);
// DynamicCustomOp op = DynamicCustomOp.builder("matmul")
// .addInputs(arr1, arr2)
// .addIntegerArguments(0, 1) //Transpose arr2 only
// .build();
// List<LongShapeDescriptor> shapes = op.calculateOutputShape();
// assertEquals(1, shapes.size());
// long[] shape = new long[]{32,12,128,128};
// assertArrayEquals(shape, shapes.get(0).getShape());
// INDArray out = Nd4j.create(DataType.FLOAT, shape);
// op.setOutputArgument(0, out);
// Nd4j.exec(op);
// // System.out.println(out);
// INDArray exp = Nd4j.valueArrayOf(shape, 64.0, DataType.FLOAT); //Each entry in output is sum of 64 (1.0 x 1.0) multiplications
// assertEquals(exp, out);
// }

View File

@ -397,27 +397,6 @@ TEST_F(DeclarableOpsTests2, NLP_Cbow_Test_1) {
delete result;
}
TEST_F(DeclarableOpsTests2, YetAnotherMatmulTest_1) {
auto A = NDArrayFactory::create<float>('c', {3, 3});
auto B = NDArrayFactory::create<float>('c', {3, 1});
auto exp = NDArrayFactory::create<float>('c', {3, 1}, {14.00f, 32.00f, 50.00f});
A.linspace(1);
B.linspace(1);
nd4j::ops::matmul op;
auto result = op.evaluate({&A, &B}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests2, Test_Squeeze_1) {
auto x = NDArrayFactory::create<float>('c', {2, 1, 3, 1, 1, 1, 4});
x.linspace(1);

View File

@ -789,120 +789,6 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_2) {
}
}
TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_1) {
auto x= NDArrayFactory::create<double>('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12});
auto y= NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12});
auto exp= NDArrayFactory::create<double>('f', {4, 4}, {38.0, 44.0, 50.0, 56.0, 83.0, 98.0, 113.0, 128.0, 128.0, 152.0, 176.0, 200.0, 173.0, 206.0, 239.0, 272.0});
nd4j::ops::matmul op;
auto result = op.evaluate({&x, &y}, {}, {1, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_2) {
auto x= NDArrayFactory::create<double>('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12});
auto y= NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12});
auto exp= NDArrayFactory::create<double>('f', {3, 3}, {70.0, 158.0, 246.0, 80.0, 184.0, 288.0, 90.0, 210.0, 330.0});
nd4j::ops::matmul op;
auto result = op.evaluate({&x, &y}, {}, {0, 0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_3) {
auto x= NDArrayFactory::create<double>('c', {1, 3}, {1, 2, 3});
auto y= NDArrayFactory::create<double>('c', {1, 4}, {1, 2, 3, 4});
auto exp= NDArrayFactory::create<double>('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0});
nd4j::ops::matmul op;
auto result = op.evaluate({&x, &y}, {}, {1, 0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
//z->printIndexedBuffer("z");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_4) {
auto x= NDArrayFactory::create<double>('c', {3, 1}, {1, 2, 3});
auto y= NDArrayFactory::create<double>('c', {4, 1}, {1, 2, 3, 4});
auto exp= NDArrayFactory::create<double>('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0});
nd4j::ops::matmul op;
auto result = op.evaluate({&x, &y}, {}, {0, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
//z->printIndexedBuffer("z");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_5) {
auto x= NDArrayFactory::create<double>('c', {3, 1}, {1, 2, 3});
auto y= NDArrayFactory::create<double>('c', {1, 4}, {1, 2, 3, 4});
auto exp= NDArrayFactory::create<double>('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0});
nd4j::ops::matmul op;
auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
//z->printIndexedBuffer("z");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_6) {
auto x= NDArrayFactory::create<double>('c', {4, 1}, {1, 2, 3, 4});
auto y= NDArrayFactory::create<double>('c', {1, 4}, {1, 2, 3, 4});
auto exp= NDArrayFactory::create<double>('f', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16});
nd4j::ops::matmul op;
auto result = op.evaluate({&x, &y});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
//z->printIndexedBuffer("z");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests3, Test_ReverseDivide_1) {
auto x= NDArrayFactory::create<double>('c', {1, 3}, {2, 2, 2});
auto y= NDArrayFactory::create<double>('c', {1, 3}, {4, 6, 8});

View File

@ -809,26 +809,6 @@ TEST_F(DeclarableOpsTests4, Test_Reshape_Again) {
delete result;
}
TEST_F(DeclarableOpsTests4, Test_Gemv_Transpose_1) {
auto x = NDArrayFactory::create<double>('c', {4, 3});
auto y = NDArrayFactory::create<double>('c', {4, 1});
auto exp = NDArrayFactory::create<double>('c',{ 3, 1}, {70, 80, 90});
x.linspace(1);
y.linspace(1);
nd4j::ops::matmul op;
auto result = op.evaluate({&x, &y}, {}, {1, 0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests4, Test_Split_1) {
auto x = NDArrayFactory::create<double>('c', {5, 30});
auto sizes = NDArrayFactory::create<int>('c', {1, 3}, {4, 15, 11});
@ -1166,57 +1146,6 @@ TEST_F(DeclarableOpsTests4, Test_Cross_3) {
delete result;
}
TEST_F(DeclarableOpsTests4, Test_Matmul_YATS_1) {
auto a = NDArrayFactory::create<double>('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
auto b = NDArrayFactory::create<double>('c', {4}, {1, 2, 3, 4});
auto exp = NDArrayFactory::create<double>('c', {3}, {30, 70, 110});
nd4j::ops::matmul op;
auto result = op.evaluate({&a, &b});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests4, Test_Matmul_YATS_2) {
auto a = NDArrayFactory::create<double>('c', {4}, {1, 2, 3, 4});
auto b = NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
auto exp = NDArrayFactory::create<double>('c', {3}, {70, 80, 90});
nd4j::ops::matmul op;
auto result = op.evaluate({&a, &b});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests4, Test_Matmul_YATS_3) {
auto a = NDArrayFactory::create<double>('c', {1, 4}, {1, 2, 3, 4});
auto b = NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
auto exp = NDArrayFactory::create<double>('c', {1, 3}, {70, 80, 90});
nd4j::ops::matmul op;
auto result = op.evaluate({&a, &b});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests4, Test_Add_119) {
auto a = NDArrayFactory::create<double>('c', {1, 4}, {1, 2, 3, 4});
auto b = NDArrayFactory::create<double>('c', {4}, {1, 2, 3, 4});

View File

@ -5019,20 +5019,6 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_7) {
delete result;
}
TEST_F(DeclarableOpsTests7, Test_Matmul_Once_Again) {
auto x = NDArrayFactory::create<double>('c', {1, 2}, {2.0f, 2.0f});
auto y = NDArrayFactory::create<double>('c', {2, 1}, {2.0f, 2.0f});
auto exp = NDArrayFactory::create<double>('c', {1, 1}, {8.0f});
nd4j::ops::matmul op;
auto result = op.evaluate({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(exp, *result->at(0));
delete result;
}
TYPED_TEST(TypedDeclarableOpsTests7, Test_Pnorm_Once_Again) {
auto input = NDArrayFactory::create<TypeParam>('c', {1, 1, 5, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f});
auto exp = NDArrayFactory::create<TypeParam>('c', {1, 1, 5, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f});

View File

@ -932,208 +932,6 @@ TEST_F(DeclarableOpsTests9, tile_test1) {
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test1) {
auto x = NDArrayFactory::create<double>('c', {3, 4});
auto y = NDArrayFactory::create<double>('c', {4, 3});
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 79., 123., 40., 92., 144., 45., 105., 165.});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test2) {
auto x = NDArrayFactory::create<double>('c', {3, 4});
auto y = NDArrayFactory::create<double>('f', {4, 3});
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test3) {
auto x = NDArrayFactory::create<double>('f', {3, 4});
auto y = NDArrayFactory::create<double>('c', {4, 3});
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test4) {
auto x = NDArrayFactory::create<double> ('f', {3, 4});
auto y = NDArrayFactory::create<double>('f', {4, 3});
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test5) {
auto x = NDArrayFactory::create<double>('c', {4, 3});
auto y = NDArrayFactory::create<double>('c', {4, 3});
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {83., 94., 105., 94., 107., 120., 105., 120., 135.});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test6) {
auto x = NDArrayFactory::create<double>('c', {4, 3});
auto y = NDArrayFactory::create<double>('f', {3, 4});
auto exp = NDArrayFactory::create<double>('f', {3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test7) {
auto x = NDArrayFactory::create<double>('c', {5, 3,4});
auto y = NDArrayFactory::create<double>('f', {5, 3,4});
auto exp = NDArrayFactory::create<double>('f',{5, 3,3}, {3. , 84.6, 281.4, 593.4, 1020.6, 7. , 107.8, 323.8, 655. , 1101.4,11. , 131. , 366.2, 716.6, 1182.2,
7. , 107.8, 323.8, 655. , 1101.4,17.4, 137.4, 372.6, 723. , 1188.6,27.8, 167. , 421.4, 791. , 1275.8,
11. , 131. , 366.2, 716.6, 1182.2,27.8, 167. , 421.4, 791. , 1275.8,44.6, 203. , 476.6, 865.4, 1369.4,});
x.linspace(1.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {0, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test8) {
auto x = NDArrayFactory::create<double>('c', {2,5, 3,4});
auto y = NDArrayFactory::create<double>('f', {2,5, 3,4});
auto exp = NDArrayFactory::create<double>('f',{2,5, 3,3}, {3. , 1563. , 84.6, 2220.6, 281.4, 2993.4, 593.4, 3881.4,1020.6, 4884.6, 7. , 1663. , 107.8, 2339.8, 323.8, 3131.8, 655. , 4039. ,1101.4, 5061.4,
11. , 1763. , 131. , 2459. , 366.2, 3270.2, 716.6, 4196.6,1182.2, 5238.2, 7. , 1663. , 107.8, 2339.8, 323.8, 3131.8, 655. , 4039. ,1101.4, 5061.4,
17.4, 1769.4, 137.4, 2465.4, 372.6, 3276.6, 723. , 4203. ,1188.6, 5244.6, 27.8, 1875.8, 167. , 2591. , 421.4, 3421.4, 791. , 4367. ,1275.8, 5427.8,
11. , 1763. , 131. , 2459. , 366.2, 3270.2, 716.6, 4196.6,1182.2, 5238.2, 27.8, 1875.8, 167. , 2591. , 421.4, 3421.4, 791. , 4367. ,1275.8, 5427.8,
44.6, 1988.6, 203. , 2723. , 476.6, 3572.6, 865.4, 4537.4,1369.4, 5617.4});
x.linspace(1.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {0, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test9) {
auto x = NDArrayFactory::create<double>('c', {2,5, 4,3});
auto y = NDArrayFactory::create<double>('f', {2,5, 3,4});
auto exp = NDArrayFactory::create<double>('f',{2,5, 3,3}, {7. , 1639. , 103. , 2311. , 314.2, 3098.2, 640.6, 4000.6,1082.2, 5018.2, 8. , 1664. , 108.8, 2340.8, 324.8, 3132.8, 656. , 4040. ,1102.4, 5062.4,
9. , 1689. , 114.6, 2370.6, 335.4, 3167.4, 671.4, 4079.4,1122.6, 5106.6, 15.8, 1743.8, 131. , 2435. , 361.4, 3241.4, 707. , 4163. ,1167.8, 5199.8,
18.4, 1770.4, 138.4, 2466.4, 373.6, 3277.6, 724. , 4204. ,1189.6, 5245.6, 21. , 1797. , 145.8, 2497.8, 385.8, 3313.8, 741. , 4245. ,1211.4, 5291.4,
24.6, 1848.6, 159. , 2559. , 408.6, 3384.6, 773.4, 4325.4,1253.4, 5381.4, 28.8, 1876.8, 168. , 2592. , 422.4, 3422.4, 792. , 4368. ,1276.8, 5428.8,
33. , 1905. , 177. , 2625. , 436.2, 3460.2, 810.6, 4410.6,1300.2, 5476.2});
x.linspace(1.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, TestDropout_BP_1) {
@ -1325,325 +1123,6 @@ TEST_F(DeclarableOpsTests9, Test_AlphaDropout_BP_1) {
delete ress2;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test10) {
auto x = NDArrayFactory::create<double>('c', {1, 4, 3});
auto y = NDArrayFactory::create<double>('f', {1, 3, 4});
auto exp = NDArrayFactory::create<double>('f', {1, 3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test11) {
auto x = NDArrayFactory::create<double>('c', {4, 1});
auto y = NDArrayFactory::create<double>('f', {1, 4});
auto exp = NDArrayFactory::create<double>('f', {1, 1}, {15});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1, 1});
ASSERT_EQ(Status::OK(), results->status());
auto z = results->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test12) {
auto x = NDArrayFactory::create<double>('c', {1, 4, 1});
auto y = NDArrayFactory::create<double>('f', {1, 1, 4});
auto exp = NDArrayFactory::create<double>('f', {1, 1, 1}, {15});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1, 1});
ASSERT_EQ(Status::OK(), results->status());
auto z = results->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test13) {
auto x = NDArrayFactory::create<double>('c', {2, 3});
auto y = NDArrayFactory::create<double>('c', {3, 5});
auto exp = NDArrayFactory::create<double>('f', {5, 2}, {23. , 26. , 29. , 32. , 35., 50. , 57.5, 65. , 72.5, 80.});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {0, 0, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test14) {
auto x = NDArrayFactory::create<double>('c', {3, 2});
auto y = NDArrayFactory::create<double>('c', {3, 5});
auto exp = NDArrayFactory::create<double>('f', {5, 2}, {37. , 41.5, 46. , 50.5, 55., 46. , 52. , 58. , 64. , 70.});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1, 0, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test15) {
auto x = NDArrayFactory::create<double>('c', {3, 2});
auto y = NDArrayFactory::create<double>('c', {3, 5});
auto exp = NDArrayFactory::create<double>('f', {5, 2}, {37. , 41.5, 46. , 50.5, 55., 46. , 52. , 58. , 64. , 70.});
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1, 0, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test16) {
auto x = NDArrayFactory::create<double>('c', {2,2, 3,5});
auto y = NDArrayFactory::create<double>('c', {2,2, 4,3});
auto exp = NDArrayFactory::create<double>('f',{2,2, 4,5}, {4.6, 281.8, 89.2, 582.4, 10. , 314.2,108.1, 628.3, 15.4, 346.6,127. , 674.2, 20.8, 379. ,145.9, 720.1, 5.2, 289.6, 93.4, 593.8,
11.5, 322.9,113.2, 640.6, 17.8, 356.2,133. , 687.4, 24.1, 389.5,152.8, 734.2, 5.8, 297.4, 97.6, 605.2, 13. , 331.6,118.3, 652.9,
20.2, 365.8,139. , 700.6, 27.4, 400. ,159.7, 748.3, 6.4, 305.2,101.8, 616.6, 14.5, 340.3,123.4, 665.2, 22.6, 375.4,145. , 713.8,
30.7, 410.5,166.6, 762.4, 7. , 313. ,106. , 628. , 16. , 349. ,128.5, 677.5, 25. , 385. ,151. , 727. , 34. , 421. ,173.5, 776.5});
x.linspace(1.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1, 1, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test17) {
auto x = NDArrayFactory::create<double>('f', {4, 3});
auto y = NDArrayFactory::create<double>('c', {4});
auto exp = NDArrayFactory::create<double>('f',{3}, {7., 8., 9.});
x.linspace(1.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1, 0});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test18) {
auto x = NDArrayFactory::create<double>('f', {3});
auto y = NDArrayFactory::create<double>('c', {4, 3});
auto exp = NDArrayFactory::create<double>('f',{4}, {1.4, 3.2, 5., 6.8});
x.linspace(1.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {0, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test19) {
auto x = NDArrayFactory::create<double>('f', {1, 1});
auto y = NDArrayFactory::create<double>('c', {1, 1});
auto exp = NDArrayFactory::create<double>('f',{1, 1}, {0.2});
x.linspace(2.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test20) {
auto x = NDArrayFactory::create<double>('f', {1, 1});
auto y = NDArrayFactory::create<double>('c', {1, 1});
auto exp = NDArrayFactory::create<double>('f',{1, 1}, {0.2});
x.linspace(2.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1,1,1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test21) {
auto x = NDArrayFactory::create<double>('f', {1});
auto y = NDArrayFactory::create<double>('c', {1, 1});
auto exp = NDArrayFactory::create<double>('f',{1}, {0.2});
x.linspace(2.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test22) {
auto x = NDArrayFactory::create<double>('f', {1,1});
auto y = NDArrayFactory::create<double>('c', {1});
auto exp = NDArrayFactory::create<double>('f',{1}, {0.2});
x.linspace(2.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test23) {
auto x = NDArrayFactory::create<double>('f', {4});
auto y = NDArrayFactory::create<double>('c', {4});
auto exp = NDArrayFactory::create<double>(3.);
x.linspace(1.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test24) {
auto x = NDArrayFactory::create<double>('f', {1}, {2.});
auto y = NDArrayFactory::create<double>('c', {1}, {3.});
auto exp = NDArrayFactory::create<double>(6.);
nd4j::ops::matmul op;
auto results = op.evaluate({&x, &y}, {}, {1, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
TEST_F(DeclarableOpsTests9, test_range_int_1) {
auto x0 = NDArrayFactory::create<int>(0);
auto x1 = NDArrayFactory::create<int>(2);

View File

@ -64,8 +64,11 @@ TEST_F(MklDnnTests, helpers_includer) {
nd4j::ops::platforms::PLATFORM_maxpool3dnew_bp_ENGINE_CPU maxpool3d_bp;
nd4j::ops::platforms::PLATFORM_lrn_ENGINE_CPU lrn;
nd4j::ops::platforms::PLATFORM_batchnorm_ENGINE_CPU batchnorm;
printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm});
nd4j::ops::platforms::PLATFORM_matmul_ENGINE_CPU matmul;
printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul});
#endif
}