diff --git a/libnd4j/CMakeLists.txt.mkldnn.in b/libnd4j/CMakeLists.txt.mkldnn.in index 3de36dfde..e67b3554b 100644 --- a/libnd4j/CMakeLists.txt.mkldnn.in +++ b/libnd4j/CMakeLists.txt.mkldnn.in @@ -5,7 +5,7 @@ project(mkldnn-download NONE) include(ExternalProject) ExternalProject_Add(mkldnn GIT_REPOSITORY https://github.com/intel/mkl-dnn.git - GIT_TAG v1.1.3 + GIT_TAG v1.2 SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src" BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build" CONFIGURE_COMMAND "" diff --git a/libnd4j/include/ops/declarable/generic/blas/matmul.cpp b/libnd4j/include/ops/declarable/generic/blas/matmul.cpp index 3dd64a113..a673b1988 100644 --- a/libnd4j/include/ops/declarable/generic/blas/matmul.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/matmul.cpp @@ -20,7 +20,7 @@ // @author Yurii Shyrma (iuriish@yahoo.com), fully rewritten // -#include +#include #if NOT_EXCLUDED(OP_matmul) #include @@ -29,142 +29,128 @@ namespace nd4j { namespace ops { - CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) { - const int iSize = (int) block.getIArguments()->size(); - int transX = iSize > 0 ? INT_ARG(0) : 0; - int transY = iSize > 1 ? INT_ARG(1) : 0; - const int transZ = iSize > 2 ? INT_ARG(2) : 0; + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - const int xRank = x->rankOf(); - const int yRank = y->rankOf(); - const int zRank = z->rankOf(); + const int iSize = (int) block.getIArguments()->size(); + int transX = iSize > 0 ? INT_ARG(0) : 0; + int transY = iSize > 1 ? INT_ARG(1) : 0; + const int transZ = iSize > 2 ? INT_ARG(2) : 0; - if (transZ) { - x = INPUT_VARIABLE(1); - y = INPUT_VARIABLE(0); - bool temp = transX; - transX = !transY; - transY = !temp; - } + const int xRank = x->rankOf(); + const int yRank = y->rankOf(); + const int zRank = z->rankOf(); - const int xLastDim = transX ? -2 : -1; - const int yLastDim = transY ? -2 : -1; - const int xLastButOneDim = transX ? -1 : -2; - const int yLastButOneDim = transY ? -1 : -2; + if (transZ) { + x = INPUT_VARIABLE(1); + y = INPUT_VARIABLE(0); + bool temp = transX; + transX = !transY; + transY = !temp; + } - // ******* input validation ******* // - REQUIRE_TRUE(xRank > 0 && yRank > 0, 0, - "MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !", - xRank, yRank); + const int xLastDim = transX ? -2 : -1; + const int yLastDim = transY ? -2 : -1; + const int xLastButOneDim = transX ? -1 : -2; + const int yLastButOneDim = transY ? -1 : -2; - if (xRank == 1 && yRank == 1) { // dot case, output is scalar (or vector with length = 1) - REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, - "MATMUL OP: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !", - x->lengthOf(), y->lengthOf()); - } else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector - REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0, - "MATMUL OP: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !", - ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); - } else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector - REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0, - "MATMUL OP: input arrays have inconsistent shapes for matrix-vector product: x %s, y %s !", - ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); - } else { - REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0, - "MATMUL OP: input and output arrays must have the same rank, but got instead: x rank = %i, y rank = %i, z rank = %i !", - xRank, yRank, zRank); - REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) && - x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && y->sizeAt(yLastDim) == z->sizeAt(-1), 0, - "MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", - ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), - ShapeUtils::shapeAsString(z).c_str()); + // ******* input validation ******* // + REQUIRE_TRUE(xRank > 0 && yRank > 0, 0, "MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !", xRank, yRank); - if (xRank > 2) // outer dims must be the same - for (int i = 0; i < xRank - 2; ++i) - REQUIRE_TRUE(x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0, - "MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", - ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), - ShapeUtils::shapeAsString(z).c_str()); - } - // ******* end of input validation ******* // + if (xRank == 1 && yRank == 1) { // dot case, output is scalar (or vector with length = 1) + REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, "MATMUL OP: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !", x->lengthOf(), y->lengthOf()); + } else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector + REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0, "MATMUL OP: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); + } else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector + REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0, "MATMUL OP: input arrays have inconsistent shapes for matrix-vector product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); + } else { + REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0, "MATMUL OP: input and output arrays must have the same rank, but got instead: x rank = %i, y rank = %i, z rank = %i !", xRank, yRank, zRank); + REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) && x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && y->sizeAt(yLastDim) == z->sizeAt(-1), 0, "MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str()); - MmulHelper::matmul(x, y, z, transX, transY); + if (xRank > 2) // outer dims must be the same + for (int i = 0; i < xRank - 2; ++i) + REQUIRE_TRUE(x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0, "MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str()); + } + // ******* end of input validation ******* // - return Status::OK(); - } + MmulHelper::matmul(x, y, z, transX, transY); - DECLARE_SYN(mMul, matmul); + return Status::OK(); +} - DECLARE_SYN(mmul, matmul); +DECLARE_SYN(mMul, matmul); - DECLARE_SYN(gemm, matmul); +DECLARE_SYN(mmul, matmul); - DECLARE_SYN(gemv, matmul); +DECLARE_SYN(gemm, matmul); - DECLARE_SYN(dot, matmul); +DECLARE_SYN(gemv, matmul); +DECLARE_SYN(dot, matmul); - DECLARE_SHAPE_FN(matmul) { +////////////////////////////////////////////////////////////////////// +DECLARE_SHAPE_FN(matmul) { - auto xShapeInfo = inputShape->at(0); - auto yShapeInfo = inputShape->at(1); + auto xShapeInfo = inputShape->at(0); + auto yShapeInfo = inputShape->at(1); - const int iSize = (int) block.getIArguments()->size(); - int transX = iSize > 0 ? INT_ARG(0) : 0; - int transY = iSize > 1 ? INT_ARG(1) : 0; - const int transZ = iSize > 2 ? INT_ARG(2) : 0; + const int iSize = (int) block.getIArguments()->size(); + int transX = iSize > 0 ? INT_ARG(0) : 0; + int transY = iSize > 1 ? INT_ARG(1) : 0; + const int transZ = iSize > 2 ? INT_ARG(2) : 0; - REQUIRE_TRUE(xShapeInfo[0] > 0 && yShapeInfo[0] > 0, 0, - "MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !", - xShapeInfo[0], yShapeInfo[0]); + REQUIRE_TRUE(xShapeInfo[0] > 0 && yShapeInfo[0] > 0, 0, + "MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !", + xShapeInfo[0], yShapeInfo[0]); - if (transZ) { - xShapeInfo = inputShape->at(1); - yShapeInfo = inputShape->at(0); - bool temp = transX; - transX = !transY; - transY = !temp; - } + if (transZ) { + xShapeInfo = inputShape->at(1); + yShapeInfo = inputShape->at(0); + bool temp = transX; + transX = !transY; + transY = !temp; + } - auto zShapeOnly = ShapeUtils::evalShapeForMatmul(xShapeInfo, yShapeInfo, transX, transY); + auto zShapeOnly = ShapeUtils::evalShapeForMatmul(xShapeInfo, yShapeInfo, transX, transY); - auto dtypeX = ArrayOptions::dataType(xShapeInfo); - auto dtypeY = ArrayOptions::dataType(yShapeInfo); + auto dtypeX = ArrayOptions::dataType(xShapeInfo); + auto dtypeY = ArrayOptions::dataType(yShapeInfo); - auto xOrder = shape::order(xShapeInfo); - auto yOrder = shape::order(yShapeInfo); - auto zOrder = xOrder == 'c' && yOrder == 'c' ? 'c' : 'f'; + auto xOrder = shape::order(xShapeInfo); + auto yOrder = shape::order(yShapeInfo); + auto zOrder = xOrder == 'c' && yOrder == 'c' ? 'c' : 'f'; - // we just pick the higher data type out of X and Y - auto dtypeZ = dtypeX > dtypeY ? dtypeX : dtypeY; + // we just pick the higher data type out of X and Y + auto dtypeZ = dtypeX > dtypeY ? dtypeX : dtypeY; - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtypeZ, zOrder, zShapeOnly); - return SHAPELIST(newShape); - } + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtypeZ, zOrder, zShapeOnly); + return SHAPELIST(newShape); +} - DECLARE_TYPES(matmul) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +////////////////////////////////////////////////////////////////////// +DECLARE_TYPES(matmul) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} +////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto eps = INPUT_VARIABLE(2); + auto dldx = OUTPUT_VARIABLE(0); + auto dldy = OUTPUT_VARIABLE(1); - CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto eps = INPUT_VARIABLE(2); - auto dldx = OUTPUT_VARIABLE(0); - auto dldy = OUTPUT_VARIABLE(1); - - const int iSize = (int) block.getIArguments()->size(); - int transX = iSize > 0 ? INT_ARG(0) : 0; - int transY = iSize > 1 ? INT_ARG(1) : 0; - const int transZ = iSize > 2 ? INT_ARG(2) : 0; + const int iSize = (int) block.getIArguments()->size(); + int transX = iSize > 0 ? INT_ARG(0) : 0; + int transY = iSize > 1 ? INT_ARG(1) : 0; + const int transZ = iSize > 2 ? INT_ARG(2) : 0; /* In: x=[a,b], y=[b,c] @@ -177,34 +163,35 @@ F F T [a,b] [b,c] [c,a] [c,a] */ - nd4j::ops::matmul op; - op.execute({eps, y}, {dldx}, {}, {transZ, !transY, transX}, {}); - op.execute({x, eps}, {dldy}, {}, {!transX, transZ, transY}, {}); + nd4j::ops::matmul op; + op.execute({eps, y}, {dldx}, {}, {transZ, !transY, transX}, {}); + op.execute({x, eps}, {dldy}, {}, {!transX, transZ, transY}, {}); - return Status::OK(); - } + return Status::OK(); +} +////////////////////////////////////////////////////////////////////// +DECLARE_SHAPE_FN(matmul_bp) { + Nd4jLong *xShapeInfo; + Nd4jLong *yShapeInfo; - DECLARE_SHAPE_FN(matmul_bp) { - Nd4jLong *xShapeInfo; - Nd4jLong *yShapeInfo; + COPY_SHAPE(inputShape->at(0), xShapeInfo); + COPY_SHAPE(inputShape->at(1), yShapeInfo); - COPY_SHAPE(inputShape->at(0), xShapeInfo); - COPY_SHAPE(inputShape->at(1), yShapeInfo); + return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(yShapeInfo)); +} - return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(yShapeInfo)); - } +////////////////////////////////////////////////////////////////////// +DECLARE_TYPES(matmul_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_FLOATS}); +} - DECLARE_TYPES(matmul_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_FLOATS}); - } - - } +} } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp new file mode 100644 index 000000000..f47d08b7a --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp @@ -0,0 +1,294 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include +#include + +#include +#include "mkldnnUtils.h" +#include + + +namespace nd4j { +namespace ops { +namespace platforms { + +////////////////////////////////////////////////////////////////////////// +static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY) { + + // mkl works with following + // [M,K] x [K,N] = [M,N] + // [bS, M,K] x [bS, K,N] = [bS, M,N] + + // possible input cases not supported by mkl, however we'll perform permut/reshape procedures in order to fit requirements + // [4] x [4] = [1] --> [1,4] x [4,1] = [1,1] + // [4] x [4,5] = [5] --> [1,4] x [4,5] = [1,5] + // [4,5] x [5] = [4] --> [4,5] x [5,1] = [4,1] + // [2,3, 4,5] x [2,3, 5,4] = [2,3, 4,4] --> [6, 4,5] x [6, 5,4] = [6, 4,4] + // [2,2,3, 4,5] x [2,2,3, 5,4] = [2,2,3, 4,4] --> [12, 4,5] x [12, 5,4] = [12, 4,4] + + const auto xRank = x->rankOf(); + const auto yRank = y->rankOf(); + const auto zRank = z->rankOf(); + + std::vector permut; + + // fill permutation vector appropriately if transposition is required + if((transX && xRank > 1) || (transY && yRank > 1)) { + + const int rank = xRank >= yRank ? xRank : yRank; + permut.resize(rank); + std::iota(std::begin(permut), std::end(permut), 0); + permut[rank-2] = rank - 1; + permut[rank-1] = rank - 2; + } + + const NDArray* xT = (transX && xRank > 1) ? new NDArray(x->permute(permut)) : x; + const NDArray* yT = (transY && yRank > 1) ? new NDArray(y->permute(permut)) : y; + + const NDArray* xTR = xRank <= 3 ? xT : new NDArray(xT->reshape(xT->ordering(), {xT->lengthOf() / (xT->sizeAt(-2) * xT->sizeAt(-1)), xT->sizeAt(-2), xT->sizeAt(-1)})); + const NDArray* yTR = xRank <= 3 ? yT : new NDArray(yT->reshape(yT->ordering(), {yT->lengthOf() / (yT->sizeAt(-2) * yT->sizeAt(-1)), yT->sizeAt(-2), yT->sizeAt(-1)})); + NDArray* zR = xRank <= 3 ? z : new NDArray(z->reshape(z->ordering(), {z->lengthOf() / (z->sizeAt(-2) * z->sizeAt(-1)), z->sizeAt(-2), z->sizeAt(-1)})/*, false*/); + + // [M,K] x [K,N] = [M,N] + const int M = (xRank > 1) ? xTR->sizeAt(-2) : 1; + const int K = (xRank > 1) ? xTR->sizeAt(-1) : xTR->lengthOf(); + const int N = (yRank > 1) ? yTR->sizeAt(-1) : 1; + const int bS = (xRank > 2) ? xTR->sizeAt(0) : 1; // [bS, M,K] x [bS, K,N] = [bS, M,N] + + dnnl::memory::dims xShape = xRank < 3 ? dnnl::memory::dims({M, K}) : dnnl::memory::dims({bS, M, K}); + dnnl::memory::dims yShape = xRank < 3 ? dnnl::memory::dims({K, N}) : dnnl::memory::dims({bS, K, N}); + dnnl::memory::dims zShape = xRank < 3 ? dnnl::memory::dims({M, N}) : dnnl::memory::dims({bS, M, N}); + + dnnl::memory::format_tag format = xRank < 3 ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::abc; + + // x type + dnnl::memory::data_type xType; + if(x->dataType() == DataType::FLOAT32) + xType = dnnl::memory::data_type::f32; + else if(x->dataType() == DataType::HALF) + xType = dnnl::memory::data_type::f16; + else if(x->dataType() == DataType::BFLOAT16) + xType = dnnl::memory::data_type::bf16; + else if(x->dataType() == DataType::UINT8) + xType = dnnl::memory::data_type::u8; + else + xType = dnnl::memory::data_type::s8; + + // y type + dnnl::memory::data_type yType = xType; + if(y->dataType() == DataType::UINT8) + yType = dnnl::memory::data_type::u8; + else if(y->dataType() == DataType::INT8) + yType = dnnl::memory::data_type::s8; + + // z type + dnnl::memory::data_type zType = xType; + if(z->dataType() == DataType::FLOAT32) + zType = dnnl::memory::data_type::f32; + else if(z->dataType() == DataType::INT32) + zType = dnnl::memory::data_type::s32; + else if(z->dataType() == DataType::UINT8) + zType = dnnl::memory::data_type::u8; + else if(z->dataType() == DataType::INT8) + zType = dnnl::memory::data_type::s8; + + // memory descriptors for arrays + + // x + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format); + if(xTR->ews() != 1 || xTR->ordering() != 'c') { + x_user_md.data.format_kind = dnnl_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = xRank == 1 ? 1 : xTR->strideAt(0); + x_user_md.data.format_desc.blocking.strides[1] = xRank == 1 ? xTR->strideAt(0) : xTR->strideAt(1); + if(xRank > 2) + x_user_md.data.format_desc.blocking.strides[2] = xTR->strideAt(2); + } + + // y + dnnl::memory::desc y_mkl_md = dnnl::memory::desc(yShape, yType, dnnl::memory::format_tag::any); + dnnl::memory::desc y_user_md = dnnl::memory::desc(yShape, yType, format); + if(yTR->ews() != 1 || yTR->ordering() != 'c') { + y_user_md.data.format_kind = dnnl_blocked; // overrides format + y_user_md.data.format_desc.blocking.strides[0] = yRank == 1 ? 1 : yTR->strideAt(0); + y_user_md.data.format_desc.blocking.strides[1] = yRank == 1 ? yTR->strideAt(0) : yTR->strideAt(1); + if(yRank > 2) + y_user_md.data.format_desc.blocking.strides[2] = yTR->strideAt(2); + } + + // z + dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, format); + if(zR->ews() != 1 || zR->ordering() != 'c') { + z_user_md.data.format_kind = dnnl_blocked; // overrides format + z_user_md.data.format_desc.blocking.strides[0] = zRank == 1 ? 1 : zR->strideAt(0); + z_user_md.data.format_desc.blocking.strides[1] = zRank == 1 ? zR->strideAt(0) : zR->strideAt(1); + if(zRank > 2) + z_user_md.data.format_desc.blocking.strides[2] = zR->strideAt(2); + } + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // Create attributes (to handle alpha and beta if necessary) + dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0) + + // operation primitive description + dnnl::matmul::desc op_desc(x_mkl_md, y_mkl_md, z_mkl_md); + dnnl::matmul::primitive_desc op_prim_desc(op_desc, attr, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + auto x_user_mem = dnnl::memory(x_user_md, engine, xTR->getBuffer()); + const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); + auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem; + if (xReorder) + dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); + args[DNNL_ARG_SRC] = x_mkl_mem; + + // y + auto y_user_mem = dnnl::memory(y_user_md, engine, yTR->getBuffer()); + const bool yReorder = op_prim_desc.weights_desc() != y_user_mem.get_desc(); + auto y_mkl_mem = yReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : y_user_mem; + if (yReorder) + dnnl::reorder(y_user_mem, y_mkl_mem).execute(stream, y_user_mem, y_mkl_mem); + args[DNNL_ARG_WEIGHTS] = y_mkl_mem; + + // z + auto z_user_mem = dnnl::memory(z_user_md, engine, zR->getBuffer()); + const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); + auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; + args[DNNL_ARG_DST] = z_mkl_mem; + + // run calculations + dnnl::matmul(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (zReorder) + dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + + stream.wait(); + + if(zR->getBuffer() != z->getBuffer()) + z->assign(zR); + + if(zR != z) + delete zR; + if(xTR != xT) + delete xTR; + if(xT != x) + delete xT; + if(yTR != yT) + delete yTR; + if(yT != y) + delete yT; + + // shape::printArray(z_mkl_mem.map_data(),8); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(matmul, ENGINE_CPU) { + + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + if(x->isEmpty() || y->isEmpty()) + return Status::OK(); + + const int iSize = (int) block.getIArguments()->size(); + int transX = iSize > 0 ? INT_ARG(0) : 0; + int transY = iSize > 1 ? INT_ARG(1) : 0; + const int transZ = iSize > 2 ? INT_ARG(2) : 0; + + const int xRank = x->rankOf(); + const int yRank = y->rankOf(); + const int zRank = z->rankOf(); + + if (transZ) { + x = INPUT_VARIABLE(1); + y = INPUT_VARIABLE(0); + bool temp = transX; + transX = !transY; + transY = !temp; + } + + const int xLastDim = transX ? -2 : -1; + const int yLastDim = transY ? -2 : -1; + const int xLastButOneDim = transX ? -1 : -2; + const int yLastButOneDim = transY ? -1 : -2; + + // ******* input validation ******* // + REQUIRE_TRUE(xRank > 0 && yRank > 0, 0, "MATMUL MKLDNN OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !", xRank, yRank); + + if (xRank == 1 && yRank == 1) { // dot case, output is scalar (or vector with length = 1) + REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0,"MATMUL MKLDNN OP: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !",x->lengthOf(), y->lengthOf()); + } else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector + REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0, "MATMUL MKLDNN OP: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); + } else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector + REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0, "MATMUL MKLDNN OP: input arrays have inconsistent shapes for matrix-vector product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); + } else { + REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0, "MATMUL MKLDNN OP: input and output arrays must have the same rank, but got instead: x rank = %i, y rank = %i, z rank = %i !", xRank, yRank, zRank); + REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) && x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && y->sizeAt(yLastDim) == z->sizeAt(-1), 0, "MATMUL MKLDNN OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str()); + + if (xRank > 2) // outer dims must be the same + for (int i = 0; i < xRank - 2; ++i) + REQUIRE_TRUE(x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0, "MATMUL MKLDNN OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str()); + } + // ******* end of input validation ******* // + + matmulMKLDNN(x, y, z, transX, transY); + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(matmul, ENGINE_CPU) { + + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + + auto z = INPUT_VARIABLE(0); + + const DataType xType = x->dataType(); + const DataType yType = y->dataType(); + const DataType zType = z->dataType(); + + + return block.isUseMKLDNN() && + ( + (xType==DataType::FLOAT32 && yType==DataType::FLOAT32 && zType==DataType::FLOAT32) || + (xType==DataType::HALF && yType==DataType::HALF && zType==DataType::FLOAT32) || + (xType==DataType::BFLOAT16 && yType==DataType::BFLOAT16 && zType==DataType::BFLOAT16) || + ((xType==DataType::UINT8 || xType==DataType::INT8) && (yType==DataType::UINT8 || yType==DataType::INT8) && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32)) + ); +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h index c8b34a6c0..10adf533d 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h @@ -84,6 +84,8 @@ namespace nd4j{ DECLARE_PLATFORM(depthwise_conv2d, ENGINE_CPU); DECLARE_PLATFORM(depthwise_conv2d_bp, ENGINE_CPU); + + DECLARE_PLATFORM(matmul, ENGINE_CPU); } } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index dee410a21..507a507af 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -1341,40 +1341,6 @@ TEST_F(DeclarableOpsTests1, MultiplyScalarScalar1) { delete exp; } -TEST_F(DeclarableOpsTests1, TestMatMul1) { - auto x = NDArrayFactory::create_('c', {3, 5}); - x->linspace(1); - - auto y = NDArrayFactory::create_('c', {5, 3}); - y->linspace(1); - - float _expB[]{135.0f, 310.0f, 485.0f, 150.0f, 350.0f, 550.0f, 165.0f, 390.0f, 615.0f}; - Nd4jLong _expS[] {2, 3, 3, 1, 3, 0, 1, 102}; // expected shape - ArrayOptions::setDataType(_expS, nd4j::DataType::FLOAT32); - NDArray exp(_expB, _expS); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, new Variable()); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1, -2}); - - nd4j::ops::matmul op; - - Nd4jStatus status = op.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(variableSpace->hasVariable(1)); - - auto result = variableSpace->getVariable(1)->getNDArray(); - - ASSERT_TRUE(result->equalsTo(&exp)); - - delete block; - delete variableSpace; -} - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestSoftMax_bp_1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index dc672d8e6..e5eaa9a6a 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -2800,16 +2800,9 @@ TEST_F(DeclarableOpsTests12, QR_Test_1_1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, QR_Test_2) { - auto in = NDArrayFactory::create('c', {5,3}, { - 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3. - }); - auto expQ = NDArrayFactory::create('c', {5, 3}, { - 0.8464148, 0.3912908, -0.3431241, -0.42320737, -0.9040873, 0.02927014, 0.28213826, -0.17042054, -0.93285596, 0.07053456, -0.01404065, 0.00109937, -0.14106913, 0.0166551, 0.10577161 - }); - - auto expR = NDArrayFactory::create('c', {3,3}, { - -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546 - }); + auto in = NDArrayFactory::create('c', {5,3}, {12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3.}); + auto expQ = NDArrayFactory::create('c', {5, 3}, {0.8464148,0.3912908,-0.3431241,-0.42320737, -0.9040873,0.02927014,0.28213826, -0.17042054, -0.93285596,0.07053456, -0.01404065,0.00109937,-0.14106913,0.0166551,0.10577161}); + auto expR = NDArrayFactory::create('c', {3,3}, {-14.177447,-20.666622,13.401566,0.,-175.04254,70.080315,0.,0.,35.201546}); nd4j::ops::qr op; auto res = op.evaluate({&in}, {}, {}, {false}); @@ -2819,8 +2812,6 @@ TEST_F(DeclarableOpsTests12, QR_Test_2) { auto r = res->at(1); ASSERT_TRUE(q->isSameShape(expQ)); ASSERT_TRUE(r->isSameShape(expR)); -// q->printIndexedBuffer("Orthogonal 5x5"); -// r->printIndexedBuffer("Upper triangular 5x3"); nd4j::ops::matmul opMul; auto res2 = opMul.evaluate({q, r}); //MmulHelper::matmul(q, r, &in, false, false); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp index 7e3fae4af..25e2d383d 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp @@ -682,3 +682,810 @@ TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest8) { x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y, z); ASSERT_EQ(e, z); } + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test1) { + + auto x = NDArrayFactory::create('c', {3, 4}); + auto y = NDArrayFactory::create('c', {4, 3}); + auto exp = NDArrayFactory::create('f', {3, 3}, {35., 79., 123., 40., 92., 144., 45., 105., 165.}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test2) { + + auto x = NDArrayFactory::create('c', {3, 4}); + auto y = NDArrayFactory::create('f', {4, 3}); + auto exp = NDArrayFactory::create('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test3) { + + auto x = NDArrayFactory::create('f', {3, 4}); + auto y = NDArrayFactory::create('c', {4, 3}); + auto exp = NDArrayFactory::create('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test4) { + + auto x = NDArrayFactory::create ('f', {3, 4}); + auto y = NDArrayFactory::create('f', {4, 3}); + auto exp = NDArrayFactory::create('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test5) { + + auto x = NDArrayFactory::create('c', {4, 3}); + auto y = NDArrayFactory::create('c', {4, 3}); + auto exp = NDArrayFactory::create('f', {3, 3}, {83., 94., 105., 94., 107., 120., 105., 120., 135.}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1}); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test6) { + + auto x = NDArrayFactory::create('c', {4, 3}); + auto y = NDArrayFactory::create('f', {3, 4}); + auto exp = NDArrayFactory::create('f', {3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test7) { + + auto x = NDArrayFactory::create('c', {5, 3,4}); + auto y = NDArrayFactory::create('f', {5, 3,4}); + auto exp = NDArrayFactory::create('f',{5, 3,3}, {3. , 84.6, 281.4, 593.4, 1020.6, 7. , 107.8, 323.8, 655. , 1101.4,11. , 131. , 366.2, 716.6, 1182.2, + 7. , 107.8, 323.8, 655. , 1101.4,17.4, 137.4, 372.6, 723. , 1188.6,27.8, 167. , 421.4, 791. , 1275.8, + 11. , 131. , 366.2, 716.6, 1182.2,27.8, 167. , 421.4, 791. , 1275.8,44.6, 203. , 476.6, 865.4, 1369.4,}); + + x.linspace(1.); + y.linspace(0.1, 0.1); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {0, 1}); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test8) { + + auto x = NDArrayFactory::create('c', {2,5, 3,4}); + auto y = NDArrayFactory::create('f', {2,5, 3,4}); + auto exp = NDArrayFactory::create('f',{2,5, 3,3}, {3. , 1563. , 84.6, 2220.6, 281.4, 2993.4, 593.4, 3881.4,1020.6, 4884.6, 7. , 1663. , 107.8, 2339.8, 323.8, 3131.8, 655. , 4039. ,1101.4, 5061.4, + 11. , 1763. , 131. , 2459. , 366.2, 3270.2, 716.6, 4196.6,1182.2, 5238.2, 7. , 1663. , 107.8, 2339.8, 323.8, 3131.8, 655. , 4039. ,1101.4, 5061.4, + 17.4, 1769.4, 137.4, 2465.4, 372.6, 3276.6, 723. , 4203. ,1188.6, 5244.6, 27.8, 1875.8, 167. , 2591. , 421.4, 3421.4, 791. , 4367. ,1275.8, 5427.8, + 11. , 1763. , 131. , 2459. , 366.2, 3270.2, 716.6, 4196.6,1182.2, 5238.2, 27.8, 1875.8, 167. , 2591. , 421.4, 3421.4, 791. , 4367. ,1275.8, 5427.8, + 44.6, 1988.6, 203. , 2723. , 476.6, 3572.6, 865.4, 4537.4,1369.4, 5617.4}); + + x.linspace(1.); + y.linspace(0.1, 0.1); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {0, 1}); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test9) { + + auto x = NDArrayFactory::create('c', {2,5, 4,3}); + auto y = NDArrayFactory::create('f', {2,5, 3,4}); + auto exp = NDArrayFactory::create('f',{2,5, 3,3}, {7. , 1639. , 103. , 2311. , 314.2, 3098.2, 640.6, 4000.6,1082.2, 5018.2, 8. , 1664. , 108.8, 2340.8, 324.8, 3132.8, 656. , 4040. ,1102.4, 5062.4, + 9. , 1689. , 114.6, 2370.6, 335.4, 3167.4, 671.4, 4079.4,1122.6, 5106.6, 15.8, 1743.8, 131. , 2435. , 361.4, 3241.4, 707. , 4163. ,1167.8, 5199.8, + 18.4, 1770.4, 138.4, 2466.4, 373.6, 3277.6, 724. , 4204. ,1189.6, 5245.6, 21. , 1797. , 145.8, 2497.8, 385.8, 3313.8, 741. , 4245. ,1211.4, 5291.4, + 24.6, 1848.6, 159. , 2559. , 408.6, 3384.6, 773.4, 4325.4,1253.4, 5381.4, 28.8, 1876.8, 168. , 2592. , 422.4, 3422.4, 792. , 4368. ,1276.8, 5428.8, + 33. , 1905. , 177. , 2625. , 436.2, 3460.2, 810.6, 4410.6,1300.2, 5476.2}); + + x.linspace(1.); + y.linspace(0.1, 0.1); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + +TEST_F(DeclarableOpsTests14, matmul_test10) { + + auto x = NDArrayFactory::create_('c', {3, 5}); + x->linspace(1); + + auto y = NDArrayFactory::create_('c', {5, 3}); + y->linspace(1); + + float _expB[]{135.0f, 310.0f, 485.0f, 150.0f, 350.0f, 550.0f, 165.0f, 390.0f, 615.0f}; + Nd4jLong _expS[] {2, 3, 3, 1, 3, 0, 1, 102}; // expected shape + ArrayOptions::setDataType(_expS, nd4j::DataType::FLOAT32); + NDArray exp(_expB, _expS); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + variableSpace->putVariable(1, new Variable()); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); + + nd4j::ops::matmul op; + + Nd4jStatus status = op.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(variableSpace->hasVariable(1)); + + auto result = variableSpace->getVariable(1)->getNDArray(); + + ASSERT_TRUE(result->equalsTo(&exp)); + + delete block; + delete variableSpace; +} + +TEST_F(DeclarableOpsTests14, matmul_test11) { + auto A = NDArrayFactory::create('c', {3, 3}); + auto B = NDArrayFactory::create('c', {3, 1}); + auto exp = NDArrayFactory::create('c', {3, 1}, {14.00f, 32.00f, 50.00f}); + + A.linspace(1); + B.linspace(1); + + nd4j::ops::matmul op; + + auto result = op.evaluate({&A, &B}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +TEST_F(DeclarableOpsTests14, matmul_test12) { + auto x= NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12}); + auto y= NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12}); + auto exp= NDArrayFactory::create('f', {4, 4}, {38.0, 44.0, 50.0, 56.0, 83.0, 98.0, 113.0, 128.0, 128.0, 152.0, 176.0, 200.0, 173.0, 206.0, 239.0, 272.0}); + + nd4j::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {1, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + + delete result; +} + + +TEST_F(DeclarableOpsTests14, matmul_test13) { + auto x= NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); + auto y= NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); + + nd4j::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); + + //z->printIndexedBuffer("z"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +TEST_F(DeclarableOpsTests14, matmul_test14) { + auto x= NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); + auto y= NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); + + nd4j::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {0, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); + + //z->printIndexedBuffer("z"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +TEST_F(DeclarableOpsTests14, matmul_test15) { + auto x= NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); + auto y= NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); + + nd4j::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); + + //z->printIndexedBuffer("z"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +TEST_F(DeclarableOpsTests14, matmul_test16) { + auto x= NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto y= NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp= NDArrayFactory::create('f', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16}); + + nd4j::ops::matmul op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); + + //z->printIndexedBuffer("z"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +TEST_F(DeclarableOpsTests14, matmul_test17) { + auto x = NDArrayFactory::create('c', {1, 2}, {2.0f, 2.0f}); + auto y = NDArrayFactory::create('c', {2, 1}, {2.0f, 2.0f}); + auto exp = NDArrayFactory::create('c', {1, 1}, {8.0f}); + + nd4j::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + ASSERT_EQ(exp, *result->at(0)); + + delete result; +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test18) { + + auto x = NDArrayFactory::create('c', {1, 4, 3}); + auto y = NDArrayFactory::create('f', {1, 3, 4}); + auto exp = NDArrayFactory::create('f', {1, 3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test19) { + + auto x = NDArrayFactory::create('c', {4, 1}); + auto y = NDArrayFactory::create('f', {1, 4}); + auto exp = NDArrayFactory::create('f', {1, 1}, {15}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + ASSERT_EQ(Status::OK(), results->status()); + + auto z = results->at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test20) { + + auto x = NDArrayFactory::create('c', {1, 4, 1}); + auto y = NDArrayFactory::create('f', {1, 1, 4}); + auto exp = NDArrayFactory::create('f', {1, 1, 1}, {15}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + + ASSERT_EQ(Status::OK(), results->status()); + auto z = results->at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test21) { + + auto x = NDArrayFactory::create('c', {2, 3}); + auto y = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create('f', {5, 2}, {23. , 26. , 29. , 32. , 35., 50. , 57.5, 65. , 72.5, 80.}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {0, 0, 1}); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test22) { + + auto x = NDArrayFactory::create('c', {3, 2}); + auto y = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create('f', {5, 2}, {37. , 41.5, 46. , 50.5, 55., 46. , 52. , 58. , 64. , 70.}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 0, 1}); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test23) { + + auto x = NDArrayFactory::create('c', {3, 2}); + auto y = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create('f', {5, 2}, {37. , 41.5, 46. , 50.5, 55., 46. , 52. , 58. , 64. , 70.}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 0, 1}); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test24) { + + auto x = NDArrayFactory::create('c', {2,2, 3,5}); + auto y = NDArrayFactory::create('c', {2,2, 4,3}); + auto exp = NDArrayFactory::create('f',{2,2, 4,5}, {4.6, 281.8, 89.2, 582.4, 10. , 314.2,108.1, 628.3, 15.4, 346.6,127. , 674.2, 20.8, 379. ,145.9, 720.1, 5.2, 289.6, 93.4, 593.8, + 11.5, 322.9,113.2, 640.6, 17.8, 356.2,133. , 687.4, 24.1, 389.5,152.8, 734.2, 5.8, 297.4, 97.6, 605.2, 13. , 331.6,118.3, 652.9, + 20.2, 365.8,139. , 700.6, 27.4, 400. ,159.7, 748.3, 6.4, 305.2,101.8, 616.6, 14.5, 340.3,123.4, 665.2, 22.6, 375.4,145. , 713.8, + 30.7, 410.5,166.6, 762.4, 7. , 313. ,106. , 628. , 16. , 349. ,128.5, 677.5, 25. , 385. ,151. , 727. , 34. , 421. ,173.5, 776.5}); + + x.linspace(1.); + y.linspace(0.1, 0.1); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1, 1}); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test25) { + + auto x = NDArrayFactory::create('f', {4, 3}); + auto y = NDArrayFactory::create('c', {4}); + auto exp = NDArrayFactory::create('f',{3}, {7., 8., 9.}); + + x.linspace(1.); + y.linspace(0.1, 0.1); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 0}); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test26) { + + auto x = NDArrayFactory::create('f', {3}); + auto y = NDArrayFactory::create('c', {4, 3}); + auto exp = NDArrayFactory::create('f',{4}, {1.4, 3.2, 5., 6.8}); + + x.linspace(1.); + y.linspace(0.1, 0.1); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {0, 1}); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test27) { + + auto x = NDArrayFactory::create('f', {1, 1}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('f',{1, 1}, {0.2}); + + x.linspace(2.); + y.linspace(0.1, 0.1); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test28) { + + auto x = NDArrayFactory::create('f', {1, 1}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('f',{1, 1}, {0.2}); + + x.linspace(2.); + y.linspace(0.1, 0.1); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1,1,1}); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test29) { + + auto x = NDArrayFactory::create('f', {1}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('f',{1}, {0.2}); + + x.linspace(2.); + y.linspace(0.1, 0.1); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test30) { + + auto x = NDArrayFactory::create('f', {1,1}); + auto y = NDArrayFactory::create('c', {1}); + auto exp = NDArrayFactory::create('f',{1}, {0.2}); + + x.linspace(2.); + y.linspace(0.1, 0.1); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1}); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test31) { + + auto x = NDArrayFactory::create('f', {4}); + auto y = NDArrayFactory::create('c', {4}); + auto exp = NDArrayFactory::create(3.); + + x.linspace(1.); + y.linspace(0.1, 0.1); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test32) { + + auto x = NDArrayFactory::create('f', {1}, {2.}); + auto y = NDArrayFactory::create('c', {1}, {3.}); + auto exp = NDArrayFactory::create(6.); + + nd4j::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete results; +} + + +TEST_F(DeclarableOpsTests14, matmul_test33) { + auto x = NDArrayFactory::create('c', {4, 3}); + auto y = NDArrayFactory::create('c', {4, 1}); + auto exp = NDArrayFactory::create('c',{ 3, 1}, {70, 80, 90}); + + x.linspace(1); + y.linspace(1); + + nd4j::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + + +TEST_F(DeclarableOpsTests14, matmul_test34) { + auto a = NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto b = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {30, 70, 110}); + + nd4j::ops::matmul op; + auto result = op.evaluate({&a, &b}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +TEST_F(DeclarableOpsTests14, matmul_test35) { + auto a = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto b = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create('c', {3}, {70, 80, 90}); + + nd4j::ops::matmul op; + auto result = op.evaluate({&a, &b}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +TEST_F(DeclarableOpsTests14, matmul_test36) { + auto a = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto b = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create('c', {1, 3}, {70, 80, 90}); + + nd4j::ops::matmul op; + auto result = op.evaluate({&a, &b}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test37) { + + NDArray a('c', {32, 12, 128, 64}, nd4j::DataType::FLOAT32); + NDArray b('c', {32, 12, 128, 64}, nd4j::DataType::FLOAT32); + NDArray c('c', {32,12,128,128}, nd4j::DataType::FLOAT32); + NDArray cExp('c', {32,12,128,128}, nd4j::DataType::FLOAT32); + + a = 1; + b = 1; + cExp = 64; //Each entry in output c is sum of 64 (1.0 x 1.0) multiplications + + nd4j::ops::matmul op; + auto status = op.execute({&a, &b}, {&c}, {}, {0,1}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + ASSERT_TRUE(cExp.isSameShape(c)); + ASSERT_TRUE(cExp.equalsTo(c)); +} + +// @Test +// public void testMmulRank4_simple(){ + +// INDArray arr1 = Nd4j.ones(DataType.FLOAT, 32, 12, 128, 64); +// INDArray arr2 = Nd4j.ones(DataType.FLOAT, 32, 12, 128, 64); + +// DynamicCustomOp op = DynamicCustomOp.builder("matmul") +// .addInputs(arr1, arr2) +// .addIntegerArguments(0, 1) //Transpose arr2 only +// .build(); + +// List shapes = op.calculateOutputShape(); +// assertEquals(1, shapes.size()); +// long[] shape = new long[]{32,12,128,128}; +// assertArrayEquals(shape, shapes.get(0).getShape()); + +// INDArray out = Nd4j.create(DataType.FLOAT, shape); + +// op.setOutputArgument(0, out); +// Nd4j.exec(op); +// // System.out.println(out); + +// INDArray exp = Nd4j.valueArrayOf(shape, 64.0, DataType.FLOAT); //Each entry in output is sum of 64 (1.0 x 1.0) multiplications +// assertEquals(exp, out); +// } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp index 0cf1cea2b..029a392f7 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp @@ -397,27 +397,6 @@ TEST_F(DeclarableOpsTests2, NLP_Cbow_Test_1) { delete result; } -TEST_F(DeclarableOpsTests2, YetAnotherMatmulTest_1) { - auto A = NDArrayFactory::create('c', {3, 3}); - auto B = NDArrayFactory::create('c', {3, 1}); - auto exp = NDArrayFactory::create('c', {3, 1}, {14.00f, 32.00f, 50.00f}); - - A.linspace(1); - B.linspace(1); - - nd4j::ops::matmul op; - - auto result = op.evaluate({&A, &B}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result->status()); - - auto z = result->at(0); - - ASSERT_TRUE(exp.equalsTo(z)); - - delete result; -} - TEST_F(DeclarableOpsTests2, Test_Squeeze_1) { auto x = NDArrayFactory::create('c', {2, 1, 3, 1, 1, 1, 4}); x.linspace(1); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index 04816b2b2..e7e95afcb 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -789,120 +789,6 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_2) { } } -TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_1) { - auto x= NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12}); - auto y= NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12}); - auto exp= NDArrayFactory::create('f', {4, 4}, {38.0, 44.0, 50.0, 56.0, 83.0, 98.0, 113.0, 128.0, 128.0, 152.0, 176.0, 200.0, 173.0, 206.0, 239.0, 272.0}); - - nd4j::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {1, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result->status()); - - auto z = result->at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - - delete result; -} - - -TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_2) { - auto x= NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12}); - auto y= NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12}); - auto exp= NDArrayFactory::create('f', {3, 3}, {70.0, 158.0, 246.0, 80.0, 184.0, 288.0, 90.0, 210.0, 330.0}); - - nd4j::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {0, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result->status()); - - auto z = result->at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - - delete result; -} - - -TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_3) { - auto x= NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); - auto y= NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); - - nd4j::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {1, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result->status()); - - auto z = result->at(0); - - //z->printIndexedBuffer("z"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete result; -} - -TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_4) { - auto x= NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); - auto y= NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); - auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); - - nd4j::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {0, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result->status()); - - auto z = result->at(0); - - //z->printIndexedBuffer("z"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete result; -} - -TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_5) { - auto x= NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); - auto y= NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); - - nd4j::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result->status()); - - auto z = result->at(0); - - //z->printIndexedBuffer("z"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete result; -} - -TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_6) { - auto x= NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); - auto y= NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto exp= NDArrayFactory::create('f', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16}); - - nd4j::ops::matmul op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result->status()); - - auto z = result->at(0); - - //z->printIndexedBuffer("z"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete result; -} - TEST_F(DeclarableOpsTests3, Test_ReverseDivide_1) { auto x= NDArrayFactory::create('c', {1, 3}, {2, 2, 2}); auto y= NDArrayFactory::create('c', {1, 3}, {4, 6, 8}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp index f04d24395..1fb700779 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp @@ -809,26 +809,6 @@ TEST_F(DeclarableOpsTests4, Test_Reshape_Again) { delete result; } -TEST_F(DeclarableOpsTests4, Test_Gemv_Transpose_1) { - auto x = NDArrayFactory::create('c', {4, 3}); - auto y = NDArrayFactory::create('c', {4, 1}); - auto exp = NDArrayFactory::create('c',{ 3, 1}, {70, 80, 90}); - - x.linspace(1); - y.linspace(1); - - nd4j::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {1, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result->status()); - - auto z = result->at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete result; -} - TEST_F(DeclarableOpsTests4, Test_Split_1) { auto x = NDArrayFactory::create('c', {5, 30}); auto sizes = NDArrayFactory::create('c', {1, 3}, {4, 15, 11}); @@ -1166,57 +1146,6 @@ TEST_F(DeclarableOpsTests4, Test_Cross_3) { delete result; } -TEST_F(DeclarableOpsTests4, Test_Matmul_YATS_1) { - auto a = NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto b = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {30, 70, 110}); - - nd4j::ops::matmul op; - auto result = op.evaluate({&a, &b}); - ASSERT_EQ(ND4J_STATUS_OK, result->status()); - - auto z = result->at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete result; -} - -TEST_F(DeclarableOpsTests4, Test_Matmul_YATS_2) { - auto a = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - auto b = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto exp = NDArrayFactory::create('c', {3}, {70, 80, 90}); - - nd4j::ops::matmul op; - auto result = op.evaluate({&a, &b}); - ASSERT_EQ(ND4J_STATUS_OK, result->status()); - - auto z = result->at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete result; -} - -TEST_F(DeclarableOpsTests4, Test_Matmul_YATS_3) { - auto a = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto b = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto exp = NDArrayFactory::create('c', {1, 3}, {70, 80, 90}); - - nd4j::ops::matmul op; - auto result = op.evaluate({&a, &b}); - ASSERT_EQ(ND4J_STATUS_OK, result->status()); - - auto z = result->at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete result; -} - TEST_F(DeclarableOpsTests4, Test_Add_119) { auto a = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); auto b = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp index 0a6f8e5e8..7a9bc1648 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp @@ -5019,20 +5019,6 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_7) { delete result; } -TEST_F(DeclarableOpsTests7, Test_Matmul_Once_Again) { - auto x = NDArrayFactory::create('c', {1, 2}, {2.0f, 2.0f}); - auto y = NDArrayFactory::create('c', {2, 1}, {2.0f, 2.0f}); - auto exp = NDArrayFactory::create('c', {1, 1}, {8.0f}); - - nd4j::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result->status()); - - ASSERT_EQ(exp, *result->at(0)); - - delete result; -} - TYPED_TEST(TypedDeclarableOpsTests7, Test_Pnorm_Once_Again) { auto input = NDArrayFactory::create('c', {1, 1, 5, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f}); auto exp = NDArrayFactory::create('c', {1, 1, 5, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index 11ebc1229..77634b052 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -932,208 +932,6 @@ TEST_F(DeclarableOpsTests9, tile_test1) { delete results; } - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test1) { - - auto x = NDArrayFactory::create('c', {3, 4}); - auto y = NDArrayFactory::create('c', {4, 3}); - auto exp = NDArrayFactory::create('f', {3, 3}, {35., 79., 123., 40., 92., 144., 45., 105., 165.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {}); - auto z = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test2) { - - auto x = NDArrayFactory::create('c', {3, 4}); - auto y = NDArrayFactory::create('f', {4, 3}); - auto exp = NDArrayFactory::create('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {}); - auto z = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test3) { - - auto x = NDArrayFactory::create('f', {3, 4}); - auto y = NDArrayFactory::create('c', {4, 3}); - auto exp = NDArrayFactory::create('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {}); - auto z = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test4) { - - auto x = NDArrayFactory::create ('f', {3, 4}); - auto y = NDArrayFactory::create('f', {4, 3}); - auto exp = NDArrayFactory::create('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {}); - auto z = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test5) { - - auto x = NDArrayFactory::create('c', {4, 3}); - auto y = NDArrayFactory::create('c', {4, 3}); - auto exp = NDArrayFactory::create('f', {3, 3}, {83., 94., 105., 94., 107., 120., 105., 120., 135.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1}); - auto z = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test6) { - - auto x = NDArrayFactory::create('c', {4, 3}); - auto y = NDArrayFactory::create('f', {3, 4}); - auto exp = NDArrayFactory::create('f', {3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1}); - auto z = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test7) { - - auto x = NDArrayFactory::create('c', {5, 3,4}); - auto y = NDArrayFactory::create('f', {5, 3,4}); - auto exp = NDArrayFactory::create('f',{5, 3,3}, {3. , 84.6, 281.4, 593.4, 1020.6, 7. , 107.8, 323.8, 655. , 1101.4,11. , 131. , 366.2, 716.6, 1182.2, - 7. , 107.8, 323.8, 655. , 1101.4,17.4, 137.4, 372.6, 723. , 1188.6,27.8, 167. , 421.4, 791. , 1275.8, - 11. , 131. , 366.2, 716.6, 1182.2,27.8, 167. , 421.4, 791. , 1275.8,44.6, 203. , 476.6, 865.4, 1369.4,}); - - x.linspace(1.); - y.linspace(0.1, 0.1); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {0, 1}); - auto z = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test8) { - - auto x = NDArrayFactory::create('c', {2,5, 3,4}); - auto y = NDArrayFactory::create('f', {2,5, 3,4}); - auto exp = NDArrayFactory::create('f',{2,5, 3,3}, {3. , 1563. , 84.6, 2220.6, 281.4, 2993.4, 593.4, 3881.4,1020.6, 4884.6, 7. , 1663. , 107.8, 2339.8, 323.8, 3131.8, 655. , 4039. ,1101.4, 5061.4, - 11. , 1763. , 131. , 2459. , 366.2, 3270.2, 716.6, 4196.6,1182.2, 5238.2, 7. , 1663. , 107.8, 2339.8, 323.8, 3131.8, 655. , 4039. ,1101.4, 5061.4, - 17.4, 1769.4, 137.4, 2465.4, 372.6, 3276.6, 723. , 4203. ,1188.6, 5244.6, 27.8, 1875.8, 167. , 2591. , 421.4, 3421.4, 791. , 4367. ,1275.8, 5427.8, - 11. , 1763. , 131. , 2459. , 366.2, 3270.2, 716.6, 4196.6,1182.2, 5238.2, 27.8, 1875.8, 167. , 2591. , 421.4, 3421.4, 791. , 4367. ,1275.8, 5427.8, - 44.6, 1988.6, 203. , 2723. , 476.6, 3572.6, 865.4, 4537.4,1369.4, 5617.4}); - - x.linspace(1.); - y.linspace(0.1, 0.1); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {0, 1}); - auto z = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test9) { - - auto x = NDArrayFactory::create('c', {2,5, 4,3}); - auto y = NDArrayFactory::create('f', {2,5, 3,4}); - auto exp = NDArrayFactory::create('f',{2,5, 3,3}, {7. , 1639. , 103. , 2311. , 314.2, 3098.2, 640.6, 4000.6,1082.2, 5018.2, 8. , 1664. , 108.8, 2340.8, 324.8, 3132.8, 656. , 4040. ,1102.4, 5062.4, - 9. , 1689. , 114.6, 2370.6, 335.4, 3167.4, 671.4, 4079.4,1122.6, 5106.6, 15.8, 1743.8, 131. , 2435. , 361.4, 3241.4, 707. , 4163. ,1167.8, 5199.8, - 18.4, 1770.4, 138.4, 2466.4, 373.6, 3277.6, 724. , 4204. ,1189.6, 5245.6, 21. , 1797. , 145.8, 2497.8, 385.8, 3313.8, 741. , 4245. ,1211.4, 5291.4, - 24.6, 1848.6, 159. , 2559. , 408.6, 3384.6, 773.4, 4325.4,1253.4, 5381.4, 28.8, 1876.8, 168. , 2592. , 422.4, 3422.4, 792. , 4368. ,1276.8, 5428.8, - 33. , 1905. , 177. , 2625. , 436.2, 3460.2, 810.6, 4410.6,1300.2, 5476.2}); - - x.linspace(1.); - y.linspace(0.1, 0.1); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1}); - auto z = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, TestDropout_BP_1) { @@ -1325,325 +1123,6 @@ TEST_F(DeclarableOpsTests9, Test_AlphaDropout_BP_1) { delete ress2; } -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test10) { - - auto x = NDArrayFactory::create('c', {1, 4, 3}); - auto y = NDArrayFactory::create('f', {1, 3, 4}); - auto exp = NDArrayFactory::create('f', {1, 3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1}); - auto z = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test11) { - - auto x = NDArrayFactory::create('c', {4, 1}); - auto y = NDArrayFactory::create('f', {1, 4}); - auto exp = NDArrayFactory::create('f', {1, 1}, {15}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1}); - ASSERT_EQ(Status::OK(), results->status()); - - auto z = results->at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test12) { - - auto x = NDArrayFactory::create('c', {1, 4, 1}); - auto y = NDArrayFactory::create('f', {1, 1, 4}); - auto exp = NDArrayFactory::create('f', {1, 1, 1}, {15}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1}); - - ASSERT_EQ(Status::OK(), results->status()); - auto z = results->at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test13) { - - auto x = NDArrayFactory::create('c', {2, 3}); - auto y = NDArrayFactory::create('c', {3, 5}); - auto exp = NDArrayFactory::create('f', {5, 2}, {23. , 26. , 29. , 32. , 35., 50. , 57.5, 65. , 72.5, 80.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {0, 0, 1}); - auto z = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test14) { - - auto x = NDArrayFactory::create('c', {3, 2}); - auto y = NDArrayFactory::create('c', {3, 5}); - auto exp = NDArrayFactory::create('f', {5, 2}, {37. , 41.5, 46. , 50.5, 55., 46. , 52. , 58. , 64. , 70.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 0, 1}); - auto z = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test15) { - - auto x = NDArrayFactory::create('c', {3, 2}); - auto y = NDArrayFactory::create('c', {3, 5}); - auto exp = NDArrayFactory::create('f', {5, 2}, {37. , 41.5, 46. , 50.5, 55., 46. , 52. , 58. , 64. , 70.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 0, 1}); - auto z = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test16) { - - auto x = NDArrayFactory::create('c', {2,2, 3,5}); - auto y = NDArrayFactory::create('c', {2,2, 4,3}); - auto exp = NDArrayFactory::create('f',{2,2, 4,5}, {4.6, 281.8, 89.2, 582.4, 10. , 314.2,108.1, 628.3, 15.4, 346.6,127. , 674.2, 20.8, 379. ,145.9, 720.1, 5.2, 289.6, 93.4, 593.8, - 11.5, 322.9,113.2, 640.6, 17.8, 356.2,133. , 687.4, 24.1, 389.5,152.8, 734.2, 5.8, 297.4, 97.6, 605.2, 13. , 331.6,118.3, 652.9, - 20.2, 365.8,139. , 700.6, 27.4, 400. ,159.7, 748.3, 6.4, 305.2,101.8, 616.6, 14.5, 340.3,123.4, 665.2, 22.6, 375.4,145. , 713.8, - 30.7, 410.5,166.6, 762.4, 7. , 313. ,106. , 628. , 16. , 349. ,128.5, 677.5, 25. , 385. ,151. , 727. , 34. , 421. ,173.5, 776.5}); - - x.linspace(1.); - y.linspace(0.1, 0.1); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1, 1}); - auto z = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test17) { - - auto x = NDArrayFactory::create('f', {4, 3}); - auto y = NDArrayFactory::create('c', {4}); - auto exp = NDArrayFactory::create('f',{3}, {7., 8., 9.}); - - x.linspace(1.); - y.linspace(0.1, 0.1); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 0}); - auto z = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test18) { - - auto x = NDArrayFactory::create('f', {3}); - auto y = NDArrayFactory::create('c', {4, 3}); - auto exp = NDArrayFactory::create('f',{4}, {1.4, 3.2, 5., 6.8}); - - x.linspace(1.); - y.linspace(0.1, 0.1); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {0, 1}); - auto z = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test19) { - - auto x = NDArrayFactory::create('f', {1, 1}); - auto y = NDArrayFactory::create('c', {1, 1}); - auto exp = NDArrayFactory::create('f',{1, 1}, {0.2}); - - x.linspace(2.); - y.linspace(0.1, 0.1); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {}); - auto z = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test20) { - - auto x = NDArrayFactory::create('f', {1, 1}); - auto y = NDArrayFactory::create('c', {1, 1}); - auto exp = NDArrayFactory::create('f',{1, 1}, {0.2}); - - x.linspace(2.); - y.linspace(0.1, 0.1); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1,1,1}); - auto z = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test21) { - - auto x = NDArrayFactory::create('f', {1}); - auto y = NDArrayFactory::create('c', {1, 1}); - auto exp = NDArrayFactory::create('f',{1}, {0.2}); - - x.linspace(2.); - y.linspace(0.1, 0.1); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {}); - auto z = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test22) { - - auto x = NDArrayFactory::create('f', {1,1}); - auto y = NDArrayFactory::create('c', {1}); - auto exp = NDArrayFactory::create('f',{1}, {0.2}); - - x.linspace(2.); - y.linspace(0.1, 0.1); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1}); - auto z = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test23) { - - auto x = NDArrayFactory::create('f', {4}); - auto y = NDArrayFactory::create('c', {4}); - auto exp = NDArrayFactory::create(3.); - - x.linspace(1.); - y.linspace(0.1, 0.1); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1}); - auto z = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, matmul_test24) { - - auto x = NDArrayFactory::create('f', {1}, {2.}); - auto y = NDArrayFactory::create('c', {1}, {3.}); - auto exp = NDArrayFactory::create(6.); - - nd4j::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1}); - auto z = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete results; -} - TEST_F(DeclarableOpsTests9, test_range_int_1) { auto x0 = NDArrayFactory::create(0); auto x1 = NDArrayFactory::create(2); diff --git a/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp b/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp index d83e85f67..b01c9f98a 100644 --- a/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp @@ -64,8 +64,11 @@ TEST_F(MklDnnTests, helpers_includer) { nd4j::ops::platforms::PLATFORM_maxpool3dnew_bp_ENGINE_CPU maxpool3d_bp; nd4j::ops::platforms::PLATFORM_lrn_ENGINE_CPU lrn; + nd4j::ops::platforms::PLATFORM_batchnorm_ENGINE_CPU batchnorm; - printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm}); + nd4j::ops::platforms::PLATFORM_matmul_ENGINE_CPU matmul; + + printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul}); #endif } \ No newline at end of file