Softmax operation implementation for mkldnn (#286)
* libnd4j first step of softmax mkldnn implementation Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j raw implementation of mkldnn softmax Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j merge master and added softmax to MklDnnTests Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j some corrections for softmax mkldnn Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j merge branch, fixed problem with negative axis, fixed dnnl::memory::format_tag selection, test cases added Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j minor corrections to avoid risk connected with negative axis usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j fixed windows builds, added switcher to use mkldnn sofmax version only for 3D, 4D, 5D, 6D arrays Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j fixed dataType selection per request Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j fix for mac and windows builds Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j builds fix Signed-off-by: Oleg <oleg.semeniv@gmail.com>master
parent
1c89512ec0
commit
4d81af9fe9
|
@ -14,14 +14,15 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author saudet
|
// @author saudet
|
||||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
#ifndef DEV_TESTS_MKLDNNUTILS_H
|
#ifndef DEV_TESTS_MKLDNNUTILS_H
|
||||||
#define DEV_TESTS_MKLDNNUTILS_H
|
#define DEV_TESTS_MKLDNNUTILS_H
|
||||||
|
|
||||||
|
|
||||||
#include <legacy/NativeOps.h>
|
#include <legacy/NativeOps.h>
|
||||||
#include <array/NDArray.h>
|
#include <array/NDArray.h>
|
||||||
#include <dnnl.hpp>
|
#include <dnnl.hpp>
|
||||||
|
@ -33,7 +34,7 @@
|
||||||
using namespace samediff;
|
using namespace samediff;
|
||||||
|
|
||||||
|
|
||||||
namespace sd{
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace platforms {
|
namespace platforms {
|
||||||
/**
|
/**
|
||||||
|
@ -86,64 +87,67 @@ namespace sd{
|
||||||
DECLARE_PLATFORM(depthwise_conv2d_bp, ENGINE_CPU);
|
DECLARE_PLATFORM(depthwise_conv2d_bp, ENGINE_CPU);
|
||||||
|
|
||||||
DECLARE_PLATFORM(matmul, ENGINE_CPU);
|
DECLARE_PLATFORM(matmul, ENGINE_CPU);
|
||||||
|
|
||||||
|
DECLARE_PLATFORM(softmax, ENGINE_CPU);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace mkldnnUtils {
|
namespace mkldnnUtils {
|
||||||
|
|
||||||
void poolingMKLDNN(const NDArray *input, NDArray *output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int isNCHW, const dnnl::algorithm mode);
|
void poolingMKLDNN(const NDArray* input, NDArray* output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int isNCHW, const dnnl::algorithm mode);
|
||||||
|
|
||||||
void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int isNCHW, const dnnl::algorithm mode);
|
void poolingBpMKLDNN(const NDArray* input, const NDArray* gradO, NDArray* gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int isNCHW, const dnnl::algorithm mode);
|
||||||
|
|
||||||
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
||||||
dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
|
dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
|
||||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
|
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
|
||||||
|
|
||||||
dnnl::engine& getEngine(void *ptr);
|
dnnl::engine& getEngine(void* ptr);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Utility methods for MKLDNN
|
* Utility methods for MKLDNN
|
||||||
*/
|
*/
|
||||||
/* void getMKLDNNMemoryDescConv2d(
|
/* void getMKLDNNMemoryDescConv2d(
|
||||||
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW,
|
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW,
|
||||||
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
||||||
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||||
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
|
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
|
||||||
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
|
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
|
||||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
|
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
|
||||||
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
|
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
|
||||||
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation);
|
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation);
|
||||||
|
|
||||||
void getMKLDNNMemoryDescConv3d(
|
void getMKLDNNMemoryDescConv3d(
|
||||||
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
|
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
|
||||||
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
||||||
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||||
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
|
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
|
||||||
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
|
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
|
||||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
|
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
|
||||||
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
|
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
|
||||||
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation);
|
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation);
|
||||||
|
|
||||||
void getMKLDNNMemoryDescPool2d(
|
void getMKLDNNMemoryDescPool2d(
|
||||||
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
|
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
|
||||||
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
|
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
|
||||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
|
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
|
||||||
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
|
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
|
||||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
|
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
|
||||||
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r);
|
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r);
|
||||||
|
|
||||||
void getMKLDNNMemoryDescPool3d(
|
void getMKLDNNMemoryDescPool3d(
|
||||||
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
|
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
|
||||||
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
|
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
|
||||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
|
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
|
||||||
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
|
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
|
||||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
|
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
|
||||||
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r);
|
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r);
|
||||||
|
|
||||||
void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
||||||
dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
|
dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
|
||||||
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
|
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
|
||||||
*/
|
*/
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,183 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
|
||||||
|
//
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <ops/declarable/OpRegistrator.h>
|
||||||
|
#include <system/platform_boilerplate.h>
|
||||||
|
#include <helpers/MKLDNNStream.h>
|
||||||
|
#include "mkldnnUtils.h"
|
||||||
|
|
||||||
|
using namespace dnnl;
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
static void softmaxMKLDNN(const NDArray* x, NDArray* z, const int axis) {
|
||||||
|
|
||||||
|
const auto xRank = x->rankOf();
|
||||||
|
const auto zRank = z->rankOf();
|
||||||
|
|
||||||
|
std::vector<int64_t> dimsX(xRank), dimsZ(zRank);
|
||||||
|
for (auto i = 0; i < xRank; i++) {
|
||||||
|
dimsX[i] = x->sizeAt(i);
|
||||||
|
dimsZ[i] = z->sizeAt(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
dnnl::memory::dims xShape = dnnl::memory::dims(dimsX);
|
||||||
|
dnnl::memory::dims zShape = dnnl::memory::dims(dimsZ);
|
||||||
|
|
||||||
|
dnnl::memory::format_tag format = dnnl::memory::format_tag::a; // 1 == xRank
|
||||||
|
if (2 == xRank && 1 == axis) {
|
||||||
|
format = dnnl::memory::format_tag::ab;
|
||||||
|
}
|
||||||
|
else if (2 == xRank && 0 == axis) {
|
||||||
|
format = dnnl::memory::format_tag::ba;
|
||||||
|
}
|
||||||
|
else if (3 == xRank) {
|
||||||
|
format = dnnl::memory::format_tag::abc;
|
||||||
|
}
|
||||||
|
else if (4 == xRank && 3 == axis) {
|
||||||
|
format = dnnl::memory::format_tag::abcd;
|
||||||
|
}
|
||||||
|
else if (4 == xRank && 1 == axis && dimsX[2] * dimsX[3] > 1) {
|
||||||
|
format = dnnl::memory::format_tag::acdb;
|
||||||
|
}
|
||||||
|
else if (4 == xRank) {
|
||||||
|
format = dnnl::memory::format_tag::abcd;
|
||||||
|
}
|
||||||
|
else if (5 == xRank) {
|
||||||
|
format = dnnl::memory::format_tag::abcde;
|
||||||
|
}
|
||||||
|
else if (6 == xRank) {
|
||||||
|
format = dnnl::memory::format_tag::abcdef;
|
||||||
|
}
|
||||||
|
|
||||||
|
dnnl::memory::data_type xType = dnnl::memory::data_type::f32;
|
||||||
|
dnnl::memory::data_type zType = dnnl::memory::data_type::f32;
|
||||||
|
|
||||||
|
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, format);
|
||||||
|
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format);
|
||||||
|
|
||||||
|
if (x->ews() != 1 || x->ordering() != 'c') {
|
||||||
|
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||||
|
for (auto i = 0; i < xRank; ++i) {
|
||||||
|
x_user_md.data.format_desc.blocking.strides[i] = x->strideAt(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// z
|
||||||
|
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, dnnl::memory::format_tag::any);
|
||||||
|
dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, format);
|
||||||
|
if (z->ews() != 1 || z->ordering() != 'c') {
|
||||||
|
z_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||||
|
for (auto i = 0; i < xRank; ++i) {
|
||||||
|
z_user_md.data.format_desc.blocking.strides[i] = z->strideAt(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
|
// Create attributes (to handle alpha and beta if necessary)
|
||||||
|
dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0)
|
||||||
|
|
||||||
|
// operation primitive description
|
||||||
|
// todo check this
|
||||||
|
dnnl::softmax_forward::desc op_desc(dnnl::prop_kind::forward_inference, x_mkl_md, axis);
|
||||||
|
|
||||||
|
dnnl::softmax_forward::primitive_desc op_prim_desc(op_desc, attr, engine);
|
||||||
|
|
||||||
|
// arguments (memory buffers) necessary for calculations
|
||||||
|
std::unordered_map<int, dnnl::memory> args;
|
||||||
|
|
||||||
|
dnnl::stream stream(engine);
|
||||||
|
|
||||||
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
|
// input
|
||||||
|
auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer());
|
||||||
|
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
|
||||||
|
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
|
||||||
|
if (xReorder)
|
||||||
|
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||||
|
args[DNNL_ARG_SRC] = x_mkl_mem;
|
||||||
|
|
||||||
|
// z
|
||||||
|
auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());
|
||||||
|
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
|
||||||
|
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
|
||||||
|
args[DNNL_ARG_DST] = z_mkl_mem;
|
||||||
|
|
||||||
|
// run calculations
|
||||||
|
dnnl::softmax_forward(op_prim_desc).execute(stream, args);
|
||||||
|
|
||||||
|
// reorder outputs if necessary
|
||||||
|
if (zReorder)
|
||||||
|
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
PLATFORM_IMPL(softmax, ENGINE_CPU) {
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
const int rank = input->rankOf();
|
||||||
|
int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1;
|
||||||
|
|
||||||
|
if (dim < 0) {
|
||||||
|
dim += rank;
|
||||||
|
}
|
||||||
|
|
||||||
|
REQUIRE_TRUE(dim < rank && dim >= 0, 0, "SOFTMAX_MKLDNN OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(rank <= 6, 0, "SOFTMAX_MKLDNN OP: the rank of input must be less or qual 4, but got rank = %i instead !", rank);
|
||||||
|
|
||||||
|
// mkldnnSoftMax
|
||||||
|
softmaxMKLDNN(input, output, dim);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_CHECK(softmax, ENGINE_CPU) {
|
||||||
|
|
||||||
|
auto x = INPUT_VARIABLE(0);
|
||||||
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
const DataType xType = x->dataType();
|
||||||
|
const DataType zType = z->dataType();
|
||||||
|
|
||||||
|
const int xRank = x->rankOf();
|
||||||
|
bool bSupportedRanks = (xRank > 2 && xRank < 7);
|
||||||
|
/*
|
||||||
|
Source Destination
|
||||||
|
f32 f32
|
||||||
|
*/
|
||||||
|
return block.isUseMKLDNN() && bSupportedRanks && (xType == DataType::FLOAT32 && zType == DataType::FLOAT32);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
|
@ -69,6 +69,8 @@ TEST_F(MklDnnTests, helpers_includer) {
|
||||||
|
|
||||||
sd::ops::platforms::PLATFORM_matmul_ENGINE_CPU matmul;
|
sd::ops::platforms::PLATFORM_matmul_ENGINE_CPU matmul;
|
||||||
|
|
||||||
printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul});
|
sd::ops::platforms::PLATFORM_softmax_ENGINE_CPU softmax;
|
||||||
|
|
||||||
|
printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax });
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
Loading…
Reference in New Issue