/* ****************************************************************************** * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * See the NOTICE file distributed with this work for additional * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // @author Yurii Shyrma (iuriish@yahoo.com) // #include #include #include #include #include "mkldnnUtils.h" #include namespace sd { namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// static void concatMKLDNN(const std::vector& inArrs, NDArray& output, const int axis) { // data type dnnl::memory::data_type type; if(output.dataType() == DataType::FLOAT32) type = dnnl::memory::data_type::f32; else if(output.dataType() == DataType::HALF) type = dnnl::memory::data_type::f16; else if(output.dataType() == DataType::BFLOAT16) type = dnnl::memory::data_type::bf16; else if(output.dataType() == DataType::UINT8) type = dnnl::memory::data_type::u8; else type = dnnl::memory::data_type::s8; std::vector x_user_md(inArrs.size()), x_mkl_md(inArrs.size()); // inputs for (int i = 0; i < inArrs.size(); ++i) { dnnl::memory::dims dims = inArrs[i]->getShapeAsFlatVector(); x_user_md[i] = x_mkl_md[i] = dnnl::memory::desc(dims, type, mkldnnUtils::getFormat(*inArrs[i])); mkldnnUtils::setBlockStrides(*inArrs[i], x_user_md[i]); } // output dnnl::memory::dims dims = output.getShapeAsFlatVector(); dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, mkldnnUtils::getFormat(output)); mkldnnUtils::setBlockStrides(output, z_user_md); std::unordered_map args; auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); dnnl::concat::primitive_desc op_prim_desc(axis, x_mkl_md, engine); dnnl::stream stream(engine); // inputs for (int i = 0; i < inArrs.size(); ++i) mkldnnUtils::loadDataToMklStream(*inArrs[i], engine, stream, x_user_md[i], op_prim_desc.src_desc(i), args[DNNL_ARG_MULTIPLE_SRC + i]); // outputs auto z_user_mem = mkldnnUtils::loadDataToMklStream(output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]); // primitive execution dnnl::concat(op_prim_desc).execute(stream, args); // reorder output if necessary if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); stream.wait(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(concat, ENGINE_CPU) { REQUIRE_TRUE(block.width() > 0, 0, "CONCAT MKLDNN op: No input arrays were provided"); const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0); const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); // first of all take into account possible presence of empty arrays // also if scalar is present -> copy its value to vector with length=1 std::vector nonEmptyArrs; std::vector arrsToDelete; int index = 0; bool allOfSameType = true; auto rankOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->rankOf() : 0; auto typeOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->dataType() : block.dataType(); for(int i = 0; i < numOfInArrs; ++i) { auto input = INPUT_VARIABLE(i); auto currentRank = input->rankOf(); if(!input->isEmpty()) { allOfSameType &= (typeOfFirstArr == input->dataType()); if(input->rankOf() == 0) { auto vec = new NDArray('c', {1}, input->dataType(), block.launchContext()); vec->assign(input); nonEmptyArrs.push_back(vec); arrsToDelete.push_back(index); } else{ nonEmptyArrs.push_back(input); } ++index; } } const int numOfNonEmptyArrs = nonEmptyArrs.size(); if(numOfNonEmptyArrs == 0){ //All inputs are empty arrays -> return empty, mainly for TF import compatibility (no op) REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "CONCAT MKLDNN op: If all input variables are empty, output must be empty"); return Status::OK(); } const int rank = nonEmptyArrs[0]->rankOf(); // look up to first non-empty array int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e(0) : INT_ARG(0); if(axis < 0){ axis += rank; } // ******** input validation ******** // REQUIRE_TRUE(allOfSameType, 0, "CONCAT MKLDNN op: all of input arrays must have same type !"); REQUIRE_TRUE(nonEmptyArrs[0]->dataType() == OUTPUT_VARIABLE(0)->dataType(), 0, "CONCAT MKLDNN op: output array should have the same type as inputs arrays !"); REQUIRE_TRUE(0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, "CONCAT MKLDNN op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis); for(int i = 1; i < numOfNonEmptyArrs; ++i) REQUIRE_TRUE(nonEmptyArrs[i]->rankOf() == rank, 0, "CONCAT MKLDNN op: all input arrays must have the same rank !"); for(int i = 1; i < numOfNonEmptyArrs; ++i) { for(int dim = 0; dim < rank; ++dim) if(dim != axis) REQUIRE_TRUE(nonEmptyArrs[i]->sizeAt(dim) == nonEmptyArrs[0]->sizeAt(dim), 0, "CONCAT MKLDNN op: all input arrays must have the same dimensions (except those on input axis) !"); } // ******** end of input validation ******** // auto output = OUTPUT_VARIABLE(0); if(numOfNonEmptyArrs == 1) output->assign(nonEmptyArrs[0]); else concatMKLDNN(nonEmptyArrs, *output, axis); // delete dynamically allocated vectors with length=1 for(int index : arrsToDelete) delete nonEmptyArrs[index]; return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(concat, ENGINE_CPU) { auto z = OUTPUT_VARIABLE(0); const auto zType = z->dataType(); const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0); const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); return z->rankOf() < 7 && numOfInArrs <= 3072 && (zType==DataType::FLOAT32 || zType==DataType::HALF || zType==DataType::BFLOAT16 || zType==DataType::UINT8 || zType==DataType::INT8); } } } }