2019-09-11 20:50:28 +02:00
|
|
|
/*******************************************************************************
|
|
|
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
|
|
*
|
|
|
|
* This program and the accompanying materials are made available under the
|
|
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
|
|
*
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
|
|
* License for the specific language governing permissions and limitations
|
|
|
|
* under the License.
|
|
|
|
*
|
|
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
|
|
******************************************************************************/
|
|
|
|
|
2020-03-04 17:36:42 +01:00
|
|
|
//
|
|
|
|
// @author saudet
|
|
|
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
|
|
|
//
|
2019-09-11 20:50:28 +02:00
|
|
|
|
|
|
|
#ifndef DEV_TESTS_MKLDNNUTILS_H
|
|
|
|
#define DEV_TESTS_MKLDNNUTILS_H
|
|
|
|
|
2020-03-04 17:36:42 +01:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
#include <legacy/NativeOps.h>
|
|
|
|
#include <array/NDArray.h>
|
2019-11-20 11:23:08 +01:00
|
|
|
#include <dnnl.hpp>
|
2020-03-02 10:49:41 +01:00
|
|
|
#include <helpers/MKLDNNStream.h>
|
2019-09-11 20:50:28 +02:00
|
|
|
#include <graph/Context.h>
|
|
|
|
#include <ops/declarable/PlatformHelper.h>
|
2020-03-02 10:49:41 +01:00
|
|
|
#include <system/platform_boilerplate.h>
|
2019-09-11 20:50:28 +02:00
|
|
|
|
2020-03-09 06:22:49 +01:00
|
|
|
// NOTE(review): a using-directive at header scope pulls every `samediff` name
// into each translation unit that includes this header. Left in place because
// includers may depend on it; prefer qualifying names with `samediff::` instead.
using namespace samediff;
|
2020-01-20 19:32:46 +01:00
|
|
|
|
2019-09-11 20:50:28 +02:00
|
|
|
|
2020-03-04 17:36:42 +01:00
|
|
|
namespace sd {
|
2019-09-11 20:50:28 +02:00
|
|
|
namespace ops {
|
|
|
|
namespace platforms {
|
|
|
|
/**
|
|
|
|
* Here we actually declare our platform helpers
|
|
|
|
*/
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(conv2d, ENGINE_CPU);
|
2019-09-11 20:50:28 +02:00
|
|
|
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(conv2d_bp, ENGINE_CPU);
|
2019-09-11 20:50:28 +02:00
|
|
|
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(avgpool2d, ENGINE_CPU);
|
2019-09-11 20:50:28 +02:00
|
|
|
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(avgpool2d_bp, ENGINE_CPU);
|
2019-09-11 20:50:28 +02:00
|
|
|
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(maxpool2d, ENGINE_CPU);
|
2019-09-11 20:50:28 +02:00
|
|
|
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(maxpool2d_bp, ENGINE_CPU);
|
2019-09-11 20:50:28 +02:00
|
|
|
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(conv3dnew, ENGINE_CPU);
|
2019-09-11 20:50:28 +02:00
|
|
|
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(conv3dnew_bp, ENGINE_CPU);
|
2019-09-11 20:50:28 +02:00
|
|
|
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(maxpool3dnew, ENGINE_CPU);
|
2019-09-11 20:50:28 +02:00
|
|
|
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(maxpool3dnew_bp, ENGINE_CPU);
|
2019-09-11 20:50:28 +02:00
|
|
|
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(avgpool3dnew, ENGINE_CPU);
|
2019-09-11 20:50:28 +02:00
|
|
|
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(avgpool3dnew_bp, ENGINE_CPU);
|
2019-09-11 20:50:28 +02:00
|
|
|
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(lrn, ENGINE_CPU);
|
2019-09-11 20:50:28 +02:00
|
|
|
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(batchnorm, ENGINE_CPU);
|
2019-10-26 13:14:21 +02:00
|
|
|
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(batchnorm_bp, ENGINE_CPU);
|
2019-10-17 19:44:52 +02:00
|
|
|
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(lstmLayer, ENGINE_CPU);
|
2019-11-03 11:37:19 +01:00
|
|
|
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(deconv2d, ENGINE_CPU);
|
2019-11-03 11:37:19 +01:00
|
|
|
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(deconv2d_tf, ENGINE_CPU);
|
2019-11-03 11:37:19 +01:00
|
|
|
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(deconv3d, ENGINE_CPU);
|
2019-11-03 11:37:19 +01:00
|
|
|
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(deconv2d_bp, ENGINE_CPU);
|
2019-11-03 11:37:19 +01:00
|
|
|
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(deconv3d_bp, ENGINE_CPU);
|
2020-01-11 05:36:40 +01:00
|
|
|
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(depthwise_conv2d, ENGINE_CPU);
|
2020-02-06 19:12:54 +01:00
|
|
|
|
2020-01-20 19:32:46 +01:00
|
|
|
DECLARE_PLATFORM(depthwise_conv2d_bp, ENGINE_CPU);
|
2020-02-18 06:58:01 +01:00
|
|
|
|
|
|
|
DECLARE_PLATFORM(matmul, ENGINE_CPU);
|
2020-03-04 17:36:42 +01:00
|
|
|
|
|
|
|
DECLARE_PLATFORM(softmax, ENGINE_CPU);
|
|
|
|
|
2020-03-12 16:25:29 +01:00
|
|
|
DECLARE_PLATFORM(softmax_bp, ENGINE_CPU);
|
|
|
|
|
Tanh mkldnn implementation (#296)
* libnd4j first step of softmax mkldnn implementation
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j raw implementation of mkldnn softmax
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j merge master and added softmax to MklDnnTests
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j some corrections for softmax mkldnn
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j merge branch, fixed problem with negative axis, fixed dnnl::memory::format_tag selection, test cases added
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j minor corrections to avoid risk connected with negative axis usage
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j fixed windows builds, added switcher to use mkldnn sofmax version only for 3D, 4D, 5D, 6D arrays
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j fixed dataType selection per request
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j fix for mac and windows builds
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j builds fix
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j first spet of elementwize tanh implementation on mkldnn
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j fixed typo in error message for softmax MKLDNN, test case added, implementation of tanh on MKLDNN, need supported DataType testing
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j several fixes for tanh and temporary performance test added
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j fixed mkldnn platform loader for tanh
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j MklDnn tanh removed unsupported data types, removed performance test case, added more appropriate equivalence test case, code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j fixed problem with empty input case for MklDnn tanh and softmax
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
2020-03-06 15:11:22 +01:00
|
|
|
DECLARE_PLATFORM(tanh, ENGINE_CPU);
|
|
|
|
|
2020-03-13 17:01:00 +01:00
|
|
|
DECLARE_PLATFORM(tanh_bp, ENGINE_CPU);
|
|
|
|
|
xw_plus_b mkldnn implementation (#247)
* libnd4j first step of mkldnn for xw_plus_b and test of aurora crash in imageHelper
* libnd4j sync folders with master
* libnd4j merge master, raw implementation of xw_plus_b on mkldnn, clean up, need testing and adding checks for corresponded input shapes
* libnd4j corrections and checks added to xw_plus_b mkl
* libnd4j corrected dataType description based on mkl operation description, need more investigation
* libnd4j fixe xw_blus_b mkl implementation, need testing
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j two unit tests added
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j fixed check input dimensions bug
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libndj4 one more test added to cover different order handling
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j added optional int arg support to define weights format, if arg == 1, mkldnn (do not need transpose in mkldnn implementation), else mmul weights format, corrected check points, added unit test
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j merge master
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j some improvements to avoid NDArray transpose in xw_plus_b operation
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j fixed issues connected with weights rank, also added support of one case based on tf (for mkldnn, cpu, cuda), test case added
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j added proper handling of empty inputs (all implementations)
* libnd4j fixed compilation error
* libnd4j several more corrections after conflict solve and fixed typos
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j removed unsupported data types
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j merge master and fixed issues
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j added propagation implementation for xw_plus_b, fixed issue connected with mkl weights data format, avoided data copy in transpose mode, test cases added, manually tested with gradCheck
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j one minor fix of double operation declaration
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j code clean up
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j minor tests fixes
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j fixed build problem, integrate helpers changes
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
Co-authored-by: raver119 <raver119@gmail.com>
2020-03-31 12:03:10 +02:00
|
|
|
DECLARE_PLATFORM(xw_plus_b, ENGINE_CPU);
|
|
|
|
|
|
|
|
DECLARE_PLATFORM(xw_plus_b_bp, ENGINE_CPU);
|
|
|
|
|
2020-05-12 06:47:09 +02:00
|
|
|
DECLARE_PLATFORM(concat, ENGINE_CPU);
|
|
|
|
|
2019-09-11 20:50:28 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
namespace mkldnnUtils {
|
|
|
|
|
2020-03-04 17:36:42 +01:00
|
|
|
void poolingMKLDNN(const NDArray* input, NDArray* output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int isNCHW, const dnnl::algorithm mode);
|
2020-02-06 19:12:54 +01:00
|
|
|
|
2020-03-04 17:36:42 +01:00
|
|
|
void poolingBpMKLDNN(const NDArray* input, const NDArray* gradO, NDArray* gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int isNCHW, const dnnl::algorithm mode);
|
2020-02-06 19:12:54 +01:00
|
|
|
|
|
|
|
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
2020-03-04 17:36:42 +01:00
|
|
|
dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
|
|
|
|
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
|
2020-02-06 19:12:54 +01:00
|
|
|
|
2020-03-04 17:36:42 +01:00
|
|
|
dnnl::engine& getEngine(void* ptr);
|
2020-02-06 19:12:54 +01:00
|
|
|
|
2020-03-12 16:25:29 +01:00
|
|
|
/**
|
|
|
|
* This function creates memory dimentions
|
|
|
|
* @param const pointer to array
|
|
|
|
* @param const array rank
|
|
|
|
* @param reference to memory dimentions
|
|
|
|
*/
|
|
|
|
void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims);
|
|
|
|
/**
|
2020-05-12 06:47:09 +02:00
|
|
|
* This function evaluate memory format tag based on array shapeInfo
|
|
|
|
* @param const array
|
2020-03-12 16:25:29 +01:00
|
|
|
* @return memory format
|
|
|
|
*/
|
2020-05-12 06:47:09 +02:00
|
|
|
dnnl::memory::format_tag getFormat(const NDArray& arr);
|
|
|
|
|
|
|
|
void setBlockStrides(const NDArray& array, dnnl::memory::desc& mklMd, const std::vector<int>& permut = {});
|
2020-03-12 16:25:29 +01:00
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
|
|
/**
|
|
|
|
* This function load and reorder user memory to mkl
|
|
|
|
* @param const pointer to dataset
|
|
|
|
* @param reference to mkl engine
|
|
|
|
* @param reference to mkl stream
|
|
|
|
* @param reference to args container for dnnl
|
|
|
|
* @param reference to user memory description
|
|
|
|
* @param primitive memory descriptor
|
|
|
|
* @param dnnl arg activation enumerator
|
|
|
|
*/
|
2020-05-12 06:47:09 +02:00
|
|
|
dnnl::memory loadDataToMklStream(const NDArray& array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md,
|
2020-03-20 10:11:27 +01:00
|
|
|
dnnl::memory& arg);
|
2020-03-12 16:25:29 +01:00
|
|
|
|
2019-09-11 20:50:28 +02:00
|
|
|
/**
|
|
|
|
* Utility methods for MKLDNN
|
|
|
|
*/
|
2020-03-04 17:36:42 +01:00
|
|
|
/* void getMKLDNNMemoryDescConv2d(
|
|
|
|
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW,
|
|
|
|
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
|
|
|
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
|
|
|
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
|
|
|
|
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
|
|
|
|
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
|
|
|
|
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
|
|
|
|
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation);
|
|
|
|
|
|
|
|
void getMKLDNNMemoryDescConv3d(
|
|
|
|
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
|
|
|
|
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
|
|
|
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
|
|
|
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
|
|
|
|
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
|
|
|
|
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
|
|
|
|
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
|
|
|
|
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation);
|
|
|
|
|
|
|
|
void getMKLDNNMemoryDescPool2d(
|
|
|
|
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
|
|
|
|
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
|
|
|
|
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
|
|
|
|
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
|
|
|
|
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
|
|
|
|
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r);
|
|
|
|
|
|
|
|
void getMKLDNNMemoryDescPool3d(
|
|
|
|
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
|
|
|
|
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
|
|
|
|
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
|
|
|
|
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
|
|
|
|
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
|
|
|
|
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r);
|
|
|
|
|
|
|
|
void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
|
|
|
dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
|
|
|
|
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
|
|
|
|
*/
|
2019-09-11 20:50:28 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#endif //DEV_TESTS_MKLDNNUTILS_H
|