raver119 7783012f39
cuDNN integration (#150)
* initial commit

Signed-off-by: raver119 <raver119@gmail.com>

* one file

Signed-off-by: raver119 <raver119@gmail.com>

* few more includes

Signed-off-by: raver119 <raver119@gmail.com>

* m?

Signed-off-by: raver119 <raver119@gmail.com>

* const

Signed-off-by: raver119 <raver119@gmail.com>

* cudnn linkage in tests

Signed-off-by: raver119 <raver119@gmail.com>

* culibos

Signed-off-by: raver119 <raver119@gmail.com>

* static reminder

Signed-off-by: raver119 <raver119@gmail.com>

* platform engine tag

Signed-off-by: raver119 <raver119@gmail.com>

* HAVE_CUDNN moved to config.h.in

Signed-off-by: raver119 <raver119@gmail.com>

* include

Signed-off-by: raver119 <raver119@gmail.com>

* include

Signed-off-by: raver119 <raver119@gmail.com>

* skip cudnn handle creation if there's no cudnn

Signed-off-by: raver119 <raver119@gmail.com>

* meh

Signed-off-by: raver119 <raver119@gmail.com>

* target device in context

Signed-off-by: raver119 <raver119@gmail.com>

* platform engines

Signed-off-by: raver119 <raver119@gmail.com>

* platform engines

Signed-off-by: raver119 <raver119@gmail.com>

* allow multiple -h args

Signed-off-by: raver119 <raver119@gmail.com>

* allow multiple -h args

Signed-off-by: raver119 <raver119@gmail.com>

* move mkldnn out of CPU block

Signed-off-by: raver119 <raver119@gmail.com>

* link to mkldnn on cuda

Signed-off-by: raver119 <raver119@gmail.com>

* less prints

Signed-off-by: raver119 <raver119@gmail.com>

* minor tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* next step

Signed-off-by: raver119 <raver119@gmail.com>

* conv2d NCHW draft

Signed-off-by: raver119 <raver119@gmail.com>

* conv2d biasAdd

Signed-off-by: raver119 <raver119@gmail.com>

* test for MKL/CUDNN combined use

Signed-off-by: raver119 <raver119@gmail.com>

* - provide additional code for conv2d ff based on cudnn api, not tested yet

Signed-off-by: Yurii <iuriish@yahoo.com>

* - further work on conv2d helper based on the cudnn api

Signed-off-by: Yurii <iuriish@yahoo.com>

* - fix several cuda bugs that appeared after the cudnn lib came into use

Signed-off-by: Yurii <iuriish@yahoo.com>

* - implementation of conv2d backprop op based on cudnn api

Signed-off-by: Yurii <iuriish@yahoo.com>

* - implementation of conv3d and conv3d_bp ops based on cudnn api

Signed-off-by: Yurii <iuriish@yahoo.com>

* - bugs fixing in conv3d/conv3d_bp ops (cudnn in use)

Signed-off-by: Yurii <iuriish@yahoo.com>

* - implementation of depthwiseConv2d (ff/bp) op based on cudnn api

Signed-off-by: Yurii <iuriish@yahoo.com>

* - implementation of batchnorm ff op based on cudnn api

Signed-off-by: Yurii <iuriish@yahoo.com>

* - disable cudnn batchnorm temporarily

Signed-off-by: Yurii <iuriish@yahoo.com>

* - add minor change in cmake

Signed-off-by: Yurii <iuriish@yahoo.com>

* engine for depthwise mkldnn

Signed-off-by: raver119 <raver119@gmail.com>

* couple of includes

Signed-off-by: raver119 <raver119@gmail.com>

* - provide permutation to cudnn batchnorm ff when format is NHWC

Signed-off-by: Yurii <iuriish@yahoo.com>

* lgamma fix

Signed-off-by: raver119 <raver119@gmail.com>

* - eliminate memory leak in two tests

Signed-off-by: Yurii <iuriish@yahoo.com>

Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
2020-01-20 21:32:46 +03:00
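
Several of the commits above ("HAVE_CUDNN moved to config.h.in", "skip cudnn handle creation if there's no cudnn") gate all cuDNN usage behind a build-time flag. Below is a minimal sketch of that guard pattern, assuming a HAVE_CUDNN define emitted into the generated config.h; the helper name and its void* return convention are illustrative, not this codebase's actual API.

// Hypothetical sketch: skip cuDNN handle creation when built without cuDNN.
// HAVE_CUDNN is assumed to come from the generated config.h (per config.h.in);
// createCudnnHandleIfAvailable is an illustrative name, not part of the codebase.
#include <cstdio>
#ifdef HAVE_CUDNN
#include <cudnn.h>
#endif

static void* createCudnnHandleIfAvailable() {
#ifdef HAVE_CUDNN
    cudnnHandle_t handle;
    // cudnnCreate allocates a library context tied to the current device
    if (cudnnCreate(&handle) == CUDNN_STATUS_SUCCESS)
        return handle;
    std::printf("cudnnCreate failed, falling back to plain CUDA kernels\n");
    return nullptr;
#else
    // built without cuDNN: no handle to create, callers use the default helpers
    return nullptr;
#endif
}

Returning nullptr in both failure paths lets callers treat "no cuDNN at build time" and "handle creation failed at run time" the same way and fall back to the default implementations.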


/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author saudet
// @author raver119@gmail.com
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>

#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>

using namespace dnnl;

namespace nd4j      {
namespace ops       {
namespace platforms {

    // MKL-DNN (oneDNN) implementation of local response normalization for the CPU engine
    PLATFORM_IMPL(lrn, ENGINE_CPU) {
        auto input  = INPUT_VARIABLE(0);
        auto output = OUTPUT_VARIABLE(0);

        REQUIRE_TRUE(input->rankOf() == 4, 0, "lrn: Input rank of 4 expected, but got %i instead", input->rankOf());

        // op arguments: t-args are {bias, alpha, beta}, the i-arg is the half-window depth
        double alpha = T_ARG(1);
        double beta  = T_ARG(2);
        double bias  = T_ARG(0);
        int    depth = INT_ARG(0);

        // memory descriptors for the user (NDArray) layout and the primitive's preferred layout
        dnnl_memory_desc_t empty;
        dnnl::memory::desc lrn_src_md(empty), lrn_dst_md(empty), user_src_md(empty), user_dst_md(empty);

        mkldnnUtils::getMKLDNNMemoryDescLrn(input, nullptr, output, &lrn_src_md, nullptr, &lrn_dst_md,
                                            &user_src_md, nullptr, &user_dst_md, input->rankOf() - 1);

        // oneDNN takes the full window size (2*depth + 1) and an alpha rescaled by that window
        auto lrn_desc = lrn_forward::desc(prop_kind::forward_inference, algorithm::lrn_across_channels,
                                          lrn_src_md, (2 * depth + 1), alpha * (2 * depth + 1), beta, bias);

        auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
        dnnl::stream stream(engine);
        auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, engine);
        auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
        auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());

        // reorder the source into the primitive's preferred layout if it differs from the user layout
        auto lrn_src_memory = user_src_memory;
        if (lrn_prim_desc.src_desc() != user_src_memory.get_desc()) {
            lrn_src_memory = dnnl::memory(lrn_prim_desc.src_desc(), engine);
            reorder(user_src_memory, lrn_src_memory).execute(stream, user_src_memory, lrn_src_memory);
        }

        // same for the destination: compute into a temporary buffer when layouts differ
        auto lrn_dst_memory = user_dst_memory;
        if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
            lrn_dst_memory = dnnl::memory(lrn_prim_desc.dst_desc(), engine);
        }

        lrn_forward(lrn_prim_desc).execute(stream, {{DNNL_ARG_SRC, lrn_src_memory},
                                                    {DNNL_ARG_DST, lrn_dst_memory}});

        // reorder the result back into the user layout if a temporary destination was used
        if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
            reorder(lrn_dst_memory, user_dst_memory).execute(stream, lrn_dst_memory, user_dst_memory);
        }

        stream.wait();

        return Status::OK();
    };

    // this helper is only applicable when MKL-DNN is enabled and both arrays have supported types
    PLATFORM_CHECK(lrn, ENGINE_CPU) {
        auto input  = INPUT_VARIABLE(0);
        auto output = OUTPUT_VARIABLE(0);

        return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
    }

}
}
}
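
For comparison with the oneDNN path above, here is a standalone sketch of the same cross-channel LRN forward expressed through the cuDNN API that this PR integrates. It is not part of this file: the shape, parameter values, and error handling are placeholders, and inside the codebase such a helper would presumably be registered through the same PLATFORM_IMPL/PLATFORM_CHECK machinery under the CUDA engine tag mentioned in the commit log.

// Standalone sketch (not from the repo): cross-channel LRN forward via cuDNN.
#include <cudnn.h>
#include <cuda_runtime.h>
#include <cstdio>

#define CHECK_CUDNN(call)                                                 \
    do {                                                                  \
        cudnnStatus_t s = (call);                                         \
        if (s != CUDNN_STATUS_SUCCESS) {                                  \
            std::printf("cuDNN error: %s\n", cudnnGetErrorString(s));     \
            return 1;                                                     \
        }                                                                 \
    } while (0)

int main() {
    const int n = 1, c = 8, h = 4, w = 4;   // assumed NCHW shape
    const unsigned depth = 2;               // half-window, as in the op's INT_ARG(0)
    const double alpha = 1e-4, beta = 0.75, bias = 1.0;

    cudnnHandle_t handle;
    CHECK_CUDNN(cudnnCreate(&handle));

    // one descriptor serves both input and output: LRN is shape-preserving
    cudnnTensorDescriptor_t desc;
    CHECK_CUDNN(cudnnCreateTensorDescriptor(&desc));
    CHECK_CUDNN(cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));

    // cuDNN, like oneDNN above, takes the full window size 2*depth + 1;
    // alpha is passed unscaled here -- whether it needs the (2*depth + 1)
    // rescaling used in the oneDNN call depends on cuDNN's normalization formula
    cudnnLRNDescriptor_t lrnDesc;
    CHECK_CUDNN(cudnnCreateLRNDescriptor(&lrnDesc));
    CHECK_CUDNN(cudnnSetLRNDescriptor(lrnDesc, 2 * depth + 1, alpha, beta, bias));

    float *x, *y;
    cudaMalloc(&x, sizeof(float) * n * c * h * w);
    cudaMalloc(&y, sizeof(float) * n * c * h * w);
    cudaMemset(x, 0, sizeof(float) * n * c * h * w);

    // blending factors: y = 1 * lrn(x) + 0 * y
    const float one = 1.0f, zero = 0.0f;
    CHECK_CUDNN(cudnnLRNCrossChannelForward(handle, lrnDesc, CUDNN_LRN_CROSS_CHANNEL_DIM1,
                                            &one, desc, x, &zero, desc, y));

    cudaFree(x); cudaFree(y);
    cudnnDestroyLRNDescriptor(lrnDesc);
    cudnnDestroyTensorDescriptor(desc);
    cudnnDestroy(handle);
    return 0;
}

Whether lrnAlpha requires the same window rescaling as the oneDNN call above would need to be verified against cuDNN's documented normalization formula before treating this as a drop-in equivalent.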