cavis/libnd4j/include/ops/declarable/platform/mkldnn/lrn.cpp

93 lines
3.9 KiB
C++

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author saudet
// @author raver119@gmail.com
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace dnnl;
namespace nd4j {
namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////////
// MKL-DNN (oneDNN) forward-inference implementation of Local Response
// Normalization for rank-4 inputs.
//
// T args: [0] bias, [1] alpha, [2] beta.
// I args: [0] depth; the normalization window spans 2 * depth + 1 elements.
//
// The memory descriptors (including which axis is normalized) are produced by
// mkldnnUtils::getMKLDNNMemoryDescLrn, which is passed axis rankOf() - 1 here.
PLATFORM_IMPL(lrn, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    REQUIRE_TRUE(input->rankOf() == 4, 0, "lrn: Input rank of 4 expected, but got %i instead",
                 input->rankOf());

    const double bias  = T_ARG(0);
    const double alpha = T_ARG(1);
    const double beta  = T_ARG(2);
    const int depth    = INT_ARG(0);

    // A negative depth would yield a non-positive window size, which oneDNN
    // rejects by throwing; report it as a regular op failure instead.
    REQUIRE_TRUE(depth >= 0, 0, "lrn: depth must be non-negative, but got %i instead", depth);

    // Number of elements in the normalization window.
    const int localSize = 2 * depth + 1;

    // Default-constructed descriptors are value-initialized. Previously an
    // uninitialized dnnl_memory_desc_t was copied into each of them, which is a
    // read of indeterminate values.
    dnnl::memory::desc lrn_src_md, lrn_dst_md, user_src_md, user_dst_md;
    mkldnnUtils::getMKLDNNMemoryDescLrn(input, nullptr, output, &lrn_src_md, nullptr, &lrn_dst_md,
                                        &user_src_md, nullptr, &user_dst_md, input->rankOf() - 1);

    // oneDNN's LRN divides alpha by the window size (alpha / local_size), so it
    // is pre-multiplied here to cancel that scaling.
    auto lrn_desc = lrn_forward::desc(prop_kind::forward_inference, algorithm::lrn_across_channels,
                                      lrn_src_md, localSize, alpha * localSize, beta, bias);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
    dnnl::stream stream(engine);
    auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, engine);

    // Memory objects constructed with a user handle wrap the NDArray buffers in place.
    auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
    auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());

    // Reorder the input into the primitive's preferred layout only when it differs.
    auto lrn_src_memory = user_src_memory;
    if (lrn_prim_desc.src_desc() != user_src_memory.get_desc()) {
        lrn_src_memory = dnnl::memory(lrn_prim_desc.src_desc(), engine);
        reorder(user_src_memory, lrn_src_memory).execute(stream, user_src_memory, lrn_src_memory);
    }

    // If the preferred output layout differs, compute into a temporary and
    // reorder back into the output buffer afterwards.
    const bool dstNeedsReorder = lrn_prim_desc.dst_desc() != user_dst_memory.get_desc();
    auto lrn_dst_memory = user_dst_memory;
    if (dstNeedsReorder) {
        lrn_dst_memory = dnnl::memory(lrn_prim_desc.dst_desc(), engine);
    }

    lrn_forward(lrn_prim_desc).execute(stream, {{DNNL_ARG_SRC, lrn_src_memory},
                                                {DNNL_ARG_DST, lrn_dst_memory}});

    if (dstNeedsReorder) {
        reorder(lrn_dst_memory, user_dst_memory).execute(stream, lrn_dst_memory, user_dst_memory);
    }

    stream.wait();
    return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
// Decides whether the MKL-DNN lrn implementation may be used for this call:
// the block must have MKL-DNN enabled, and MKLDNNStream must accept both the
// input and output arrays.
PLATFORM_CHECK(lrn, ENGINE_CPU) {
    auto x = INPUT_VARIABLE(0);
    auto z = OUTPUT_VARIABLE(0);

    if (!block.isUseMKLDNN())
        return false;

    return nd4j::MKLDNNStream::isSupported({x, z});
}
}
}
}