cavis/libnd4j/include/ops/declarable/generic/blas/matmul.cpp

208 lines
8.4 KiB
C++
Raw Normal View History

2019-06-06 14:21:15 +02:00
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com, created on 07.10.2017.
// @author GS <sgazeos@gmail.com>, modified
// @author Yurii Shyrma (iuriish@yahoo.com), fully rewritten
//
#include <system/op_boilerplate.h>
2019-06-06 14:21:15 +02:00
#if NOT_EXCLUDED(OP_matmul)
#include <ops/declarable/CustomOperations.h>
#include <helpers/MmulHelper.h>
2019-06-06 14:21:15 +02:00
namespace sd {
2019-06-06 14:21:15 +02:00
namespace ops {
//////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
MatMul for gemm/gemv calls (#365) * libnd4j added optional alpha and beta support to matmul Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j typos fixes Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j add optional alpha and beta to matmul_bp Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j one more typo fix Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j added optional alpha and beta to mkl implementation Signed-off-by: Oleg <oleg.semeniv@gmail.com> * MatMul alpha/beta on java side Signed-off-by: raver119 <raver119@gmail.com> * alpha/beta fix in libnd4j Signed-off-by: raver119 <raver119@gmail.com> * alpha/beta fix in matmul_bp Signed-off-by: raver119 <raver119@gmail.com> * restored view validation Signed-off-by: raver119 <raver119@gmail.com> * gemv/gemm now use MatMul op Signed-off-by: raver119 <raver119@gmail.com> * few tests fixed Signed-off-by: raver119 <raver119@gmail.com> * additional INDArray.mmul signature Signed-off-by: raver119 <raver119@gmail.com> * make C order default for INDArray.mmul, unless both A/B have F order Signed-off-by: raver119 <raver119@gmail.com> * Nd4j.gemm validation fix Signed-off-by: raver119 <raver119@gmail.com> * disable mkldnn matmul for xxf with beta != 0 case Signed-off-by: raver119 <raver119@gmail.com> * SimpleRnn workspace fix + timeouts Signed-off-by: Alex Black <blacka101@gmail.com> * two more tests + minor fix in matmul platform check Signed-off-by: raver119 <raver119@gmail.com> * Flaky test fixes Signed-off-by: Alex Black <blacka101@gmail.com> * propagate testresources profile Signed-off-by: raver119 <raver119@gmail.com> * Resources fix + flaky test fix Signed-off-by: Alex Black <blacka101@gmail.com> Co-authored-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com>
2020-04-10 16:57:02 +02:00
int iSize = (int) block.getIArguments()->size();
int transX = iSize > 0 ? INT_ARG(0) : 0;
int transY = iSize > 1 ? INT_ARG(1) : 0;
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
MatMul for gemm/gemv calls (#365) * libnd4j added optional alpha and beta support to matmul Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j typos fixes Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j add optional alpha and beta to matmul_bp Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j one more typo fix Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j added optional alpha and beta to mkl implementation Signed-off-by: Oleg <oleg.semeniv@gmail.com> * MatMul alpha/beta on java side Signed-off-by: raver119 <raver119@gmail.com> * alpha/beta fix in libnd4j Signed-off-by: raver119 <raver119@gmail.com> * alpha/beta fix in matmul_bp Signed-off-by: raver119 <raver119@gmail.com> * restored view validation Signed-off-by: raver119 <raver119@gmail.com> * gemv/gemm now use MatMul op Signed-off-by: raver119 <raver119@gmail.com> * few tests fixed Signed-off-by: raver119 <raver119@gmail.com> * additional INDArray.mmul signature Signed-off-by: raver119 <raver119@gmail.com> * make C order default for INDArray.mmul, unless both A/B have F order Signed-off-by: raver119 <raver119@gmail.com> * Nd4j.gemm validation fix Signed-off-by: raver119 <raver119@gmail.com> * disable mkldnn matmul for xxf with beta != 0 case Signed-off-by: raver119 <raver119@gmail.com> * SimpleRnn workspace fix + timeouts Signed-off-by: Alex Black <blacka101@gmail.com> * two more tests + minor fix in matmul platform check Signed-off-by: raver119 <raver119@gmail.com> * Flaky test fixes Signed-off-by: Alex Black <blacka101@gmail.com> * propagate testresources profile Signed-off-by: raver119 <raver119@gmail.com> * Resources fix + flaky test fix Signed-off-by: Alex Black <blacka101@gmail.com> Co-authored-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com>
2020-04-10 16:57:02 +02:00
// optional use alpha nad beta
iSize = (int)block.getTArguments()->size();
double alpha = iSize > 0 ? T_ARG(0) : 1.0;
double beta = iSize > 1 ? T_ARG(1) : 0.0;
const int xRank = x->rankOf();
const int yRank = y->rankOf();
const int zRank = z->rankOf();
if (transZ) {
x = INPUT_VARIABLE(1);
y = INPUT_VARIABLE(0);
bool temp = transX;
transX = !transY;
transY = !temp;
}
const int xLastDim = transX ? -2 : -1;
const int yLastDim = transY ? -2 : -1;
const int xLastButOneDim = transX ? -1 : -2;
const int yLastButOneDim = transY ? -1 : -2;
// ******* input validation ******* //
REQUIRE_TRUE(xRank > 0 && yRank > 0, 0, "MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !", xRank, yRank);
if (xRank == 1 && yRank == 1) { // dot case, output is scalar (or vector with length = 1)
REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, "MATMUL OP: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !", x->lengthOf(), y->lengthOf());
} else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector
REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0, "MATMUL OP: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
} else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0, "MATMUL OP: input arrays have inconsistent shapes for matrix-vector product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
} else {
REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0, "MATMUL OP: input and output arrays must have the same rank, but got instead: x rank = %i, y rank = %i, z rank = %i !", xRank, yRank, zRank);
REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) && x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && y->sizeAt(yLastDim) == z->sizeAt(-1), 0, "MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str());
if (xRank > 2) // outer dims must be the same
for (int i = 0; i < xRank - 2; ++i)
REQUIRE_TRUE(x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0, "MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str());
}
// ******* end of input validation ******* //
MatMul for gemm/gemv calls (#365) * libnd4j added optional alpha and beta support to matmul Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j typos fixes Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j add optional alpha and beta to matmul_bp Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j one more typo fix Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j added optional alpha and beta to mkl implementation Signed-off-by: Oleg <oleg.semeniv@gmail.com> * MatMul alpha/beta on java side Signed-off-by: raver119 <raver119@gmail.com> * alpha/beta fix in libnd4j Signed-off-by: raver119 <raver119@gmail.com> * alpha/beta fix in matmul_bp Signed-off-by: raver119 <raver119@gmail.com> * restored view validation Signed-off-by: raver119 <raver119@gmail.com> * gemv/gemm now use MatMul op Signed-off-by: raver119 <raver119@gmail.com> * few tests fixed Signed-off-by: raver119 <raver119@gmail.com> * additional INDArray.mmul signature Signed-off-by: raver119 <raver119@gmail.com> * make C order default for INDArray.mmul, unless both A/B have F order Signed-off-by: raver119 <raver119@gmail.com> * Nd4j.gemm validation fix Signed-off-by: raver119 <raver119@gmail.com> * disable mkldnn matmul for xxf with beta != 0 case Signed-off-by: raver119 <raver119@gmail.com> * SimpleRnn workspace fix + timeouts Signed-off-by: Alex Black <blacka101@gmail.com> * two more tests + minor fix in matmul platform check Signed-off-by: raver119 <raver119@gmail.com> * Flaky test fixes Signed-off-by: Alex Black <blacka101@gmail.com> * propagate testresources profile Signed-off-by: raver119 <raver119@gmail.com> * Resources fix + flaky test fix Signed-off-by: Alex Black <blacka101@gmail.com> Co-authored-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com>
2020-04-10 16:57:02 +02:00
MmulHelper::matmul(x, y, z, transX, transY, alpha, beta);
return Status::OK();
}
// Synonyms for the matmul op: each name below is declared as an alias of `matmul`,
// so gemm/gemv/dot-style requests are served by the same implementation above.
DECLARE_SYN(mMul, matmul);
DECLARE_SYN(mmul, matmul);
DECLARE_SYN(gemm, matmul);
DECLARE_SYN(gemv, matmul);
DECLARE_SYN(dot, matmul);
//////////////////////////////////////////////////////////////////////
/**
 * Output shape for matmul.
 *
 * Applies the same transZ operand swap as the op itself, then delegates shape
 * computation to ShapeUtils::evalShapeForMatmul. The output is 'c' ordered only
 * when both (post-swap) inputs are 'c' ordered; otherwise it is 'f' ordered.
 */
DECLARE_SHAPE_FN(matmul) {
    auto xShapeInfo = inputShape->at(0);
    auto yShapeInfo = inputShape->at(1);

    const int iSize = (int) block.getIArguments()->size();
    int transX = iSize > 0 ? INT_ARG(0) : 0;
    int transY = iSize > 1 ? INT_ARG(1) : 0;
    const int transZ = iSize > 2 ? INT_ARG(2) : 0;

    // shapeInfo[0] is the rank, stored as Nd4jLong: cast it so it matches the %i specifier
    REQUIRE_TRUE(xShapeInfo[0] > 0 && yShapeInfo[0] > 0, 0,
                 "MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !",
                 (int) xShapeInfo[0], (int) yShapeInfo[0]);

    if (transZ) {
        // (op(x) * op(y))^T == op(y)^T * op(x)^T: swap the operands and invert their transpose flags
        xShapeInfo = inputShape->at(1);
        yShapeInfo = inputShape->at(0);
        bool temp = transX;
        transX = !transY;
        transY = !temp;
    }

    auto zShapeOnly = ShapeUtils::evalShapeForMatmul(xShapeInfo, yShapeInfo, transX, transY);

    auto dtypeX = ArrayOptions::dataType(xShapeInfo);
    auto dtypeY = ArrayOptions::dataType(yShapeInfo);

    auto xOrder = shape::order(xShapeInfo);
    auto yOrder = shape::order(yShapeInfo);
    auto zOrder = xOrder == 'c' && yOrder == 'c' ? 'c' : 'f';

    // we just pick the higher data type out of X and Y
    // NOTE(review): this compares raw DataType enum values; whether that ordering reflects
    // precision for mixed integer/floating-point inputs is not established here -- TODO confirm.
    auto dtypeZ = dtypeX > dtypeY ? dtypeX : dtypeY;

    auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(dtypeZ, zOrder, zShapeOnly);
    return SHAPELIST(newShape);
}
//////////////////////////////////////////////////////////////////////
DECLARE_TYPES(matmul) {
getOpDescriptor()
compression ops (#436) * Added declarations for decode/encode_bitmap ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Added implementation for bitmap encoding/decoding ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Added helpers for encode/decode bitmap ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored encodingBitmap helper. Signed-off-by: shugeo <sgazeos@gmail.com> * threshold encode/decode skeleton * helper skeleton * minor import fix * encoder shape fn & op impl * thresholdEncode cpu impl Signed-off-by: raver119@gmail.com <raver119@gmail.com> * thresholdDecode cpu impl Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Only cosmetical changes. Signed-off-by: shugeo <sgazeos@gmail.com> * placeholder Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Added cuda implementation for bitmap decode helper. Signed-off-by: shugeo <sgazeos@gmail.com> * cuda thresholdEstimate Signed-off-by: raver119@gmail.com <raver119@gmail.com> * cuda thresholdDecode Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next step Signed-off-by: raver119@gmail.com <raver119@gmail.com> * - nano cmakelist update (get rid of Clion section) - fixed forgotten throw in AtomicTests Signed-off-by: raver119@gmail.com <raver119@gmail.com> * thesholdEncode cuda impl Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Added tests for bitmap encoding/decoding ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed tests for encode/decode bitmaps. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored decode/encode helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed crashes with bitmap decode/encode helpers. 
Signed-off-by: shugeo <sgazeos@gmail.com> * bitmap encode/decode CPU Signed-off-by: raver119@gmail.com <raver119@gmail.com> * bitmap encode/decode CUDA Signed-off-by: raver119@gmail.com <raver119@gmail.com> * C API removed for threshold/bitmap encode Signed-off-by: raver119@gmail.com <raver119@gmail.com> * EncodeBitmap/DecodeBitmap Java side Signed-off-by: raver119@gmail.com <raver119@gmail.com> * EncodeThreshold/DecodeThreshold Java side Signed-off-by: raver119@gmail.com <raver119@gmail.com> * EncodeThreshold/DecodeThreshold Java side Signed-off-by: raver119@gmail.com <raver119@gmail.com> * few more tests for threshold encoding Signed-off-by: raver119@gmail.com <raver119@gmail.com> * minor test tweak Signed-off-by: raver119@gmail.com <raver119@gmail.com> * two special tests Signed-off-by: raver119@gmail.com <raver119@gmail.com> * encodeBitmap CPU fix Signed-off-by: raver119@gmail.com <raver119@gmail.com> * parallel_long/parallel_double proper spans fix Signed-off-by: raver119@gmail.com <raver119@gmail.com> * encodeThreshold CUDA fix Signed-off-by: raver119@gmail.com <raver119@gmail.com> * nano fix Signed-off-by: raver119@gmail.com <raver119@gmail.com> * grid tweaks Signed-off-by: raver119@gmail.com <raver119@gmail.com> * RTX adaptation for thresholdEncode Signed-off-by: raver119 <raver119@gmail.com> * don't allow threshold encoding for length < 2 Signed-off-by: raver119@gmail.com <raver119@gmail.com> * get rid of NDArrayCompressor in EncodingHandler Signed-off-by: raver119@gmail.com <raver119@gmail.com> * one more minor update of EncodingHandler Signed-off-by: raver119@gmail.com <raver119@gmail.com> * one more minor tweak of EncodingHandler Signed-off-by: raver119@gmail.com <raver119@gmail.com> * - matmul allows integer data types use - EncodingHandler boundary default value - few tests for integer matmul Signed-off-by: raver119@gmail.com <raver119@gmail.com> * minor fix of CUDA bitmap encode Signed-off-by: raver119@gmail.com <raver119@gmail.com> * boundary 
changed to integer everywhere Signed-off-by: raver119@gmail.com <raver119@gmail.com> * boundary changed to integer everywhere Signed-off-by: raver119@gmail.com <raver119@gmail.com> * re-enable CUDA deallocator Signed-off-by: raver119@gmail.com <raver119@gmail.com> * threshold encoder fix for systems without omp Signed-off-by: raver119@gmail.com <raver119@gmail.com> * - encode_threshold now requires non-negative boundary - minor tweak in EncodingHandler Signed-off-by: raver119@gmail.com <raver119@gmail.com> * restore parallelism in decode_bitmap Signed-off-by: raver119@gmail.com <raver119@gmail.com> * fall back to omp for encode_bitmap cpu Signed-off-by: raver119@gmail.com <raver119@gmail.com> * single time casts Signed-off-by: raver119@gmail.com <raver119@gmail.com> * - additional test for encode_threshold - sync buffers to device before calling for shape function Signed-off-by: raver119@gmail.com <raver119@gmail.com> Co-authored-by: shugeo <sgazeos@gmail.com>
2020-05-08 19:59:39 +02:00
->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS})
->setAllowedInputTypes(1, {ALL_FLOATS, ALL_INTS})
->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS});
}
//////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto eps = INPUT_VARIABLE(2);
auto dldx = OUTPUT_VARIABLE(0);
auto dldy = OUTPUT_VARIABLE(1);
MatMul for gemm/gemv calls (#365) * libnd4j added optional alpha and beta support to matmul Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j typos fixes Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j add optional alpha and beta to matmul_bp Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j one more typo fix Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j added optional alpha and beta to mkl implementation Signed-off-by: Oleg <oleg.semeniv@gmail.com> * MatMul alpha/beta on java side Signed-off-by: raver119 <raver119@gmail.com> * alpha/beta fix in libnd4j Signed-off-by: raver119 <raver119@gmail.com> * alpha/beta fix in matmul_bp Signed-off-by: raver119 <raver119@gmail.com> * restored view validation Signed-off-by: raver119 <raver119@gmail.com> * gemv/gemm now use MatMul op Signed-off-by: raver119 <raver119@gmail.com> * few tests fixed Signed-off-by: raver119 <raver119@gmail.com> * additional INDArray.mmul signature Signed-off-by: raver119 <raver119@gmail.com> * make C order default for INDArray.mmul, unless both A/B have F order Signed-off-by: raver119 <raver119@gmail.com> * Nd4j.gemm validation fix Signed-off-by: raver119 <raver119@gmail.com> * disable mkldnn matmul for xxf with beta != 0 case Signed-off-by: raver119 <raver119@gmail.com> * SimpleRnn workspace fix + timeouts Signed-off-by: Alex Black <blacka101@gmail.com> * two more tests + minor fix in matmul platform check Signed-off-by: raver119 <raver119@gmail.com> * Flaky test fixes Signed-off-by: Alex Black <blacka101@gmail.com> * propagate testresources profile Signed-off-by: raver119 <raver119@gmail.com> * Resources fix + flaky test fix Signed-off-by: Alex Black <blacka101@gmail.com> Co-authored-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com>
2020-04-10 16:57:02 +02:00
int iSize = (int) block.getIArguments()->size();
int transX = iSize > 0 ? INT_ARG(0) : 0;
int transY = iSize > 1 ? INT_ARG(1) : 0;
const int transZ = iSize > 2 ? INT_ARG(2) : 0;
2019-06-06 14:21:15 +02:00
MatMul for gemm/gemv calls (#365) * libnd4j added optional alpha and beta support to matmul Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j typos fixes Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j add optional alpha and beta to matmul_bp Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j one more typo fix Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j added optional alpha and beta to mkl implementation Signed-off-by: Oleg <oleg.semeniv@gmail.com> * MatMul alpha/beta on java side Signed-off-by: raver119 <raver119@gmail.com> * alpha/beta fix in libnd4j Signed-off-by: raver119 <raver119@gmail.com> * alpha/beta fix in matmul_bp Signed-off-by: raver119 <raver119@gmail.com> * restored view validation Signed-off-by: raver119 <raver119@gmail.com> * gemv/gemm now use MatMul op Signed-off-by: raver119 <raver119@gmail.com> * few tests fixed Signed-off-by: raver119 <raver119@gmail.com> * additional INDArray.mmul signature Signed-off-by: raver119 <raver119@gmail.com> * make C order default for INDArray.mmul, unless both A/B have F order Signed-off-by: raver119 <raver119@gmail.com> * Nd4j.gemm validation fix Signed-off-by: raver119 <raver119@gmail.com> * disable mkldnn matmul for xxf with beta != 0 case Signed-off-by: raver119 <raver119@gmail.com> * SimpleRnn workspace fix + timeouts Signed-off-by: Alex Black <blacka101@gmail.com> * two more tests + minor fix in matmul platform check Signed-off-by: raver119 <raver119@gmail.com> * Flaky test fixes Signed-off-by: Alex Black <blacka101@gmail.com> * propagate testresources profile Signed-off-by: raver119 <raver119@gmail.com> * Resources fix + flaky test fix Signed-off-by: Alex Black <blacka101@gmail.com> Co-authored-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com>
2020-04-10 16:57:02 +02:00
// optional use alpha nad beta
iSize = (int)block.getTArguments()->size();
double alpha = iSize > 0 ? T_ARG(0) : 1.0;
double beta = iSize > 1 ? T_ARG(1) : 0.0;
2019-06-06 14:21:15 +02:00
/*
In: x=[a,b], y=[b,c]
tX tY tZ x y z dz dLdx dLdy
F F F [a,b] [b,c] [a,c] [a,c] [a,c]*[b,c]T = [a,b] x*yT [a,b]T*[a,c] = [b,c] xT*y
T F F [b,a] [b,c] [a,c] [a,c] ([a,c]*[b,c]T)T = [b,a] (x*yT)T [b,a]*[a,c] = [b,c] x*y
F T F [a,b] [c,b] [a,c] [a,c] ([a,c]*[c,b]) = [a,b] x*y [a,b]T*[a,c] = [b,c] ->T xT*y
T T F [b,a] [c,b] [a,c] [a,c] ([a,c]*[c,b])T = [b,a] (x*y)T [b,a]*[a,c] = [b,c] ->T x*y
F F T [a,b] [b,c] [c,a] [c,a]
*/
sd::ops::matmul op;
MatMul for gemm/gemv calls (#365) * libnd4j added optional alpha and beta support to matmul Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j typos fixes Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j add optional alpha and beta to matmul_bp Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j one more typo fix Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j added optional alpha and beta to mkl implementation Signed-off-by: Oleg <oleg.semeniv@gmail.com> * MatMul alpha/beta on java side Signed-off-by: raver119 <raver119@gmail.com> * alpha/beta fix in libnd4j Signed-off-by: raver119 <raver119@gmail.com> * alpha/beta fix in matmul_bp Signed-off-by: raver119 <raver119@gmail.com> * restored view validation Signed-off-by: raver119 <raver119@gmail.com> * gemv/gemm now use MatMul op Signed-off-by: raver119 <raver119@gmail.com> * few tests fixed Signed-off-by: raver119 <raver119@gmail.com> * additional INDArray.mmul signature Signed-off-by: raver119 <raver119@gmail.com> * make C order default for INDArray.mmul, unless both A/B have F order Signed-off-by: raver119 <raver119@gmail.com> * Nd4j.gemm validation fix Signed-off-by: raver119 <raver119@gmail.com> * disable mkldnn matmul for xxf with beta != 0 case Signed-off-by: raver119 <raver119@gmail.com> * SimpleRnn workspace fix + timeouts Signed-off-by: Alex Black <blacka101@gmail.com> * two more tests + minor fix in matmul platform check Signed-off-by: raver119 <raver119@gmail.com> * Flaky test fixes Signed-off-by: Alex Black <blacka101@gmail.com> * propagate testresources profile Signed-off-by: raver119 <raver119@gmail.com> * Resources fix + flaky test fix Signed-off-by: Alex Black <blacka101@gmail.com> Co-authored-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com>
2020-04-10 16:57:02 +02:00
op.execute({eps, y}, {dldx}, {alpha, beta}, {transZ, !transY, transX}, {});
op.execute({x, eps}, {dldy}, {alpha, beta}, {!transX, transZ, transY}, {});
2019-06-06 14:21:15 +02:00
return Status::OK();
}
2019-06-06 14:21:15 +02:00
//////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(matmul_bp) {
    // Each gradient has exactly the shape of the input it belongs to:
    // output 0 (dL/dx) copies input 0, output 1 (dL/dy) copies input 1.
    Nd4jLong *gradXShape = nullptr;
    Nd4jLong *gradYShape = nullptr;

    COPY_SHAPE(inputShape->at(0), gradXShape);
    COPY_SHAPE(inputShape->at(1), gradYShape);

    return SHAPELIST(CONSTANT(gradXShape), CONSTANT(gradYShape));
}
2019-06-06 14:21:15 +02:00
//////////////////////////////////////////////////////////////////////
DECLARE_TYPES(matmul_bp) {
    // All three inputs (x, y, eps) and both gradient outputs are floating-point only.
    auto descriptor = getOpDescriptor();
    for (int inputIdx = 0; inputIdx < 3; ++inputIdx)
        descriptor->setAllowedInputTypes(inputIdx, {ALL_FLOATS});
    for (int outputIdx = 0; outputIdx < 2; ++outputIdx)
        descriptor->setAllowedOutputTypes(outputIdx, {ALL_FLOATS});
}
2019-06-06 14:21:15 +02:00
}
2019-06-06 14:21:15 +02:00
}
#endif