Softmax operation implementation for mkldnn (#286)

* libnd4j first step of softmax mkldnn implementation

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j raw implementation of mkldnn softmax

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j merge master and added softmax to MklDnnTests

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j some corrections for softmax mkldnn

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j merge branch, fixed problem with negative axis, fixed dnnl::memory::format_tag selection, test cases added

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j minor corrections to avoid risk connected with negative axis usage

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j fixed windows builds, added switcher to use mkldnn sofmax version only for 3D, 4D, 5D, 6D arrays

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j fixed dataType selection per request

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j fix for mac and windows builds

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j builds fix

Signed-off-by: Oleg <oleg.semeniv@gmail.com>
master
Oleh 2020-03-04 18:36:42 +02:00 committed by GitHub
parent 1c89512ec0
commit 4d81af9fe9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 892 additions and 638 deletions

View File

@ -14,14 +14,15 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
// //
// @author saudet // @author saudet
// @author Yurii Shyrma (iuriish@yahoo.com) // @author Yurii Shyrma (iuriish@yahoo.com)
// //
#ifndef DEV_TESTS_MKLDNNUTILS_H #ifndef DEV_TESTS_MKLDNNUTILS_H
#define DEV_TESTS_MKLDNNUTILS_H #define DEV_TESTS_MKLDNNUTILS_H
#include <legacy/NativeOps.h> #include <legacy/NativeOps.h>
#include <array/NDArray.h> #include <array/NDArray.h>
#include <dnnl.hpp> #include <dnnl.hpp>
@ -33,7 +34,7 @@
using namespace samediff; using namespace samediff;
namespace sd{ namespace sd {
namespace ops { namespace ops {
namespace platforms { namespace platforms {
/** /**
@ -86,64 +87,67 @@ namespace sd{
DECLARE_PLATFORM(depthwise_conv2d_bp, ENGINE_CPU); DECLARE_PLATFORM(depthwise_conv2d_bp, ENGINE_CPU);
DECLARE_PLATFORM(matmul, ENGINE_CPU); DECLARE_PLATFORM(matmul, ENGINE_CPU);
DECLARE_PLATFORM(softmax, ENGINE_CPU);
} }
} }
namespace mkldnnUtils { namespace mkldnnUtils {
void poolingMKLDNN(const NDArray *input, NDArray *output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int isNCHW, const dnnl::algorithm mode); void poolingMKLDNN(const NDArray* input, NDArray* output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int isNCHW, const dnnl::algorithm mode);
void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int isNCHW, const dnnl::algorithm mode); void poolingBpMKLDNN(const NDArray* input, const NDArray* gradO, NDArray* gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int isNCHW, const dnnl::algorithm mode);
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst, void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md, dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis); dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
dnnl::engine& getEngine(void *ptr); dnnl::engine& getEngine(void* ptr);
/** /**
* Utility methods for MKLDNN * Utility methods for MKLDNN
*/ */
/* void getMKLDNNMemoryDescConv2d( /* void getMKLDNNMemoryDescConv2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW, int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src, int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation); dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation);
void getMKLDNNMemoryDescConv3d( void getMKLDNNMemoryDescConv3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW, int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src, int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation); dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation);
void getMKLDNNMemoryDescPool2d( void getMKLDNNMemoryDescPool2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW, int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW, int bS, int iC, int iH, int iW, int oC, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r); dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r);
void getMKLDNNMemoryDescPool3d( void getMKLDNNMemoryDescPool3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW, int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r); dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r);
void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst, void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md, dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis); dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
*/ */
} }
} }

View File

@ -0,0 +1,183 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
//
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <system/platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
using namespace dnnl;
namespace sd {
namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////
static void softmaxMKLDNN(const NDArray* x, NDArray* z, const int axis) {
const auto xRank = x->rankOf();
const auto zRank = z->rankOf();
std::vector<int64_t> dimsX(xRank), dimsZ(zRank);
for (auto i = 0; i < xRank; i++) {
dimsX[i] = x->sizeAt(i);
dimsZ[i] = z->sizeAt(i);
}
dnnl::memory::dims xShape = dnnl::memory::dims(dimsX);
dnnl::memory::dims zShape = dnnl::memory::dims(dimsZ);
dnnl::memory::format_tag format = dnnl::memory::format_tag::a; // 1 == xRank
if (2 == xRank && 1 == axis) {
format = dnnl::memory::format_tag::ab;
}
else if (2 == xRank && 0 == axis) {
format = dnnl::memory::format_tag::ba;
}
else if (3 == xRank) {
format = dnnl::memory::format_tag::abc;
}
else if (4 == xRank && 3 == axis) {
format = dnnl::memory::format_tag::abcd;
}
else if (4 == xRank && 1 == axis && dimsX[2] * dimsX[3] > 1) {
format = dnnl::memory::format_tag::acdb;
}
else if (4 == xRank) {
format = dnnl::memory::format_tag::abcd;
}
else if (5 == xRank) {
format = dnnl::memory::format_tag::abcde;
}
else if (6 == xRank) {
format = dnnl::memory::format_tag::abcdef;
}
dnnl::memory::data_type xType = dnnl::memory::data_type::f32;
dnnl::memory::data_type zType = dnnl::memory::data_type::f32;
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, format);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format);
if (x->ews() != 1 || x->ordering() != 'c') {
x_user_md.data.format_kind = dnnl_blocked; // overrides format
for (auto i = 0; i < xRank; ++i) {
x_user_md.data.format_desc.blocking.strides[i] = x->strideAt(i);
}
}
// z
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, format);
if (z->ews() != 1 || z->ordering() != 'c') {
z_user_md.data.format_kind = dnnl_blocked; // overrides format
for (auto i = 0; i < xRank; ++i) {
z_user_md.data.format_desc.blocking.strides[i] = z->strideAt(i);
}
}
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// Create attributes (to handle alpha and beta if necessary)
dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0)
// operation primitive description
// todo check this
dnnl::softmax_forward::desc op_desc(dnnl::prop_kind::forward_inference, x_mkl_md, axis);
dnnl::softmax_forward::primitive_desc op_prim_desc(op_desc, attr, engine);
// arguments (memory buffers) necessary for calculations
std::unordered_map<int, dnnl::memory> args;
dnnl::stream stream(engine);
// provide memory buffers and check whether reorder is required
// input
auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer());
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
// z
auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
args[DNNL_ARG_DST] = z_mkl_mem;
// run calculations
dnnl::softmax_forward(op_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (zReorder)
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
stream.wait();
}
PLATFORM_IMPL(softmax, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
const int rank = input->rankOf();
int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1;
if (dim < 0) {
dim += rank;
}
REQUIRE_TRUE(dim < rank && dim >= 0, 0, "SOFTMAX_MKLDNN OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim);
REQUIRE_TRUE(rank <= 6, 0, "SOFTMAX_MKLDNN OP: the rank of input must be less or qual 4, but got rank = %i instead !", rank);
// mkldnnSoftMax
softmaxMKLDNN(input, output, dim);
return Status::OK();
}
PLATFORM_CHECK(softmax, ENGINE_CPU) {
auto x = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0);
const DataType xType = x->dataType();
const DataType zType = z->dataType();
const int xRank = x->rankOf();
bool bSupportedRanks = (xRank > 2 && xRank < 7);
/*
Source Destination
f32 f32
*/
return block.isUseMKLDNN() && bSupportedRanks && (xType == DataType::FLOAT32 && zType == DataType::FLOAT32);
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -69,6 +69,8 @@ TEST_F(MklDnnTests, helpers_includer) {
sd::ops::platforms::PLATFORM_matmul_ENGINE_CPU matmul; sd::ops::platforms::PLATFORM_matmul_ENGINE_CPU matmul;
printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul}); sd::ops::platforms::PLATFORM_softmax_ENGINE_CPU softmax;
printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax });
#endif #endif
} }