diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/split.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/split.cpp
index 2176ae711..0e36f0913 100644
--- a/libnd4j/include/ops/declarable/generic/parity_ops/split.cpp
+++ b/libnd4j/include/ops/declarable/generic/parity_ops/split.cpp
@@ -22,6 +22,7 @@
 #if NOT_EXCLUDED(OP_split)
 
 #include <ops/declarable/CustomOperations.h>
+#include <ops/declarable/helpers/transforms.h>
 #include <array>
 
 namespace nd4j {
@@ -65,30 +66,13 @@ namespace ops {
             REQUIRE_TRUE(input->sizeAt(axis) % num_splits == 0, 0, "Split: num_splits has wrong value, remainder of division should be 0, but it's %i", input->sizeAt(axis) % num_splits);
 
-            int pos = 0;
-            int split = input->sizeAt(axis) / num_splits;
-            std::vector<Nd4jLong> indices(2 * input->rankOf());
-
+            std::vector<NDArray*> outArrs(num_splits);
             for (int e = 0; e < num_splits; e++) {
-
-                auto out = OUTPUT_VARIABLE(e);
-
-                for (int d = 0; d < input->rankOf(); d++) {
-                    if (d == axis) {
-                        indices[2*d]     = pos;
-                        indices[2*d + 1] = pos + split;
-                    }
-                    else
-                        indices[2*d] = indices[2*d + 1] = 0;
-                }
-
-                auto sub = (*input)(indices, true);
-
-                out->assign(sub);
-
-                pos += split;
+                outArrs[e] = OUTPUT_VARIABLE(e);
             }
 
+            helpers::split(block.launchContext(), *input, outArrs, axis);
+
             return Status::OK();
         }
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/split.cpp b/libnd4j/include/ops/declarable/helpers/cpu/split.cpp
new file mode 100644
index 000000000..bdae61f16
--- /dev/null
+++ b/libnd4j/include/ops/declarable/helpers/cpu/split.cpp
@@ -0,0 +1,43 @@
+/*******************************************************************************
+ * Copyright (c) 2019-2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+ //
+ // @author Oleh Semeniv (oleg.semeniv@gmail.com)
+ //
+
+
+#include <ops/declarable/helpers/transforms.h>
+#include <ops/specials.h>
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+
+//////////////////////////////////////////////////////////////////////////
+template <typename T>
+static void split_(const NDArray& input, const std::vector<NDArray*>& outArrs, const int axis) {
+    nd4j::SpecialMethods<T>::splitCpuGeneric(input, outArrs, axis);
+}
+
+void split(nd4j::LaunchContext* context, const NDArray& input, std::vector<NDArray*>& outArrs, const int axis) {
+    BUILD_SINGLE_SELECTOR(input.dataType(), split_, (input, outArrs, axis), LIBND4J_TYPES);
+}
+
+BUILD_SINGLE_TEMPLATE(template void split_, (const NDArray& input, const std::vector<NDArray*>& outArrs, const int axis), LIBND4J_TYPES);
+
+}
+}
+}
\ No newline at end of file
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/split.cu b/libnd4j/include/ops/declarable/helpers/cuda/split.cu
new file mode 100644
index 000000000..fa6b46539
--- /dev/null
+++ b/libnd4j/include/ops/declarable/helpers/cuda/split.cu
@@ -0,0 +1,187 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Yurii Shyrma (iuriish@yahoo.com)
+//
+
+#include <ops/declarable/helpers/transforms.h>
+#include <array/ResultSet.h>
+#include <helpers/ShapeUtils.h>
+#include <numeric>
+#include <NDArrayFactory.h>
+#include <helpers/TAD.h>
+#include <exceptions/cuda_exception.h>
+#include <PointersManager.h>
+#include <ConstantTadHelper.h>
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+
+
+///////////////////////////////////////////////////////////////////
+template<typename T>
+__global__ static void splitCuda(const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis) {
+
+    const T* x = reinterpret_cast<const T*>(vx);
+
+    __shared__ Nd4jLong xLen, totalThreads;
+    __shared__ int xRank, zDim;
+
+    if (threadIdx.x == 0) {
+        xLen = shape::length(xShapeInfo);
+        xRank = shape::rank(xShapeInfo);
+        zDim = shape::shapeOf(zTadShapeInfo)[axis];    // same for all output arrays
+        totalThreads = gridDim.x * blockDim.x;
+    }
+    __syncthreads();
+
+    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+    Nd4jLong coords[MAX_RANK];
+
+    for (uint64_t i = tid; i < xLen; i += totalThreads) {
+
+        shape::index2coords(i, xShapeInfo, coords);
+
+        const auto xOffset = shape::getOffset(xShapeInfo, coords);
+
+        auto* z = reinterpret_cast<T*>(reinterpret_cast<void**>(pVz)[coords[axis] / zDim]);
+
+        coords[axis] %= zDim;
+
+        const auto zOffset = shape::getOffset(zTadShapeInfo, coords);
+
+        z[zOffset] = x[xOffset];
+    }
+}
+
+///////////////////////////////////////////////////////////////////
+template<typename T>
+__host__ static void splitCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
+                                       const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis) {
+
+    splitCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, pVz, zTadShapeInfo, axis);
+}
+BUILD_SINGLE_TEMPLATE(template void splitCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis), LIBND4J_TYPES);
+
+//////////////////////////////////////////////////////////////////////////
+void split(nd4j::LaunchContext* context, const NDArray& input, std::vector<NDArray*>& outArrs, const int axis) {
+
+    const int numOfSubArrs = outArrs.size();
+    const auto sizeofT = input.sizeOfT();
+
+    for(int i = 0; i < numOfSubArrs; ++i)
+        outArrs[i]->syncToDevice();
+    input.syncToDevice();
+
+    bool luckCase1 = ((axis == 0 && input.ordering() == 'c') || (axis == input.rankOf() - 1 && input.ordering() == 'f')) && input.ews() == 1;
+
+    if(luckCase1) {
+        for (uint i = 0; i < numOfSubArrs; ++i) {
+            luckCase1 &= outArrs[i]->ordering() == input.ordering() && outArrs[i]->ews() == 1;
+            if(!luckCase1)
+                break;
+        }
+    }
+
+    if(luckCase1) {    // for example {1,10} + {2,10} + {3,10} = {6, 10} order c; or {10,1} + {10,2} + {10,3} = {10, 6} order f
+
+        void* x = static_cast<void*>(input.getSpecialBuffer());
+
+        for (uint i = 0; i < numOfSubArrs; ++i) {
+            const auto memAmountToCopy = outArrs[i]->lengthOf() * sizeofT;
+            cudaMemcpyAsync(static_cast<void*>(outArrs[i]->getSpecialBuffer()), x, memAmountToCopy, cudaMemcpyDeviceToDevice, *context->getCudaStream());
+            x = static_cast<int8_t*>(x) + memAmountToCopy;
+        }
+
+        if(cudaStreamSynchronize(*context->getCudaStream()) != 0)
+            throw std::runtime_error("split cuda: luckCase1 failed!");
+
+        for(int i = 0; i < numOfSubArrs; ++i)
+            outArrs[i]->tickWriteDevice();
+        input.tickReadDevice();
+
+        return;
+    }
+
+    const bool isXcontin = input.strideAt(axis) == 1;
+    bool areOutputsContin = true;
+    bool allSameOrder = true;
+
+    if(isXcontin) {
+        for (uint i = 0; i < outArrs.size(); ++i) {
+            areOutputsContin &= outArrs[i]->strideAt(axis) == 1;
+            allSameOrder &= input.ordering() == outArrs[i]->ordering();
+            if(!areOutputsContin || !allSameOrder)
+                break;
+        }
+    }
+
+    const bool luckCase2 = isXcontin && areOutputsContin && allSameOrder;
+
+    if(luckCase2) {    // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 should have stride = 1 for all output arrays and the input array
+
+        const auto xDim = input.sizeAt(axis);
+        const auto zDim = outArrs[0]->sizeAt(axis);    // same for all outArrs
+
+        for (uint i = 0; i < input.lengthOf() / xDim; ++i) {
+
+            const auto iShift = i * sizeofT;
+            void* x = static_cast<int8_t*>(input.getSpecialBuffer()) + xDim * iShift;
+
+            for (uint j = 0; j < numOfSubArrs; ++j) {
+                void* z = static_cast<int8_t*>(outArrs[j]->getSpecialBuffer()) + zDim * iShift;
+                const auto memSizeToCopy = zDim * sizeofT;
+                cudaMemcpyAsync(z, x, memSizeToCopy, cudaMemcpyDeviceToDevice, *context->getCudaStream());
+                x = static_cast<int8_t*>(x) + memSizeToCopy;
+            }
+        }
+
+        if(cudaStreamSynchronize(*context->getCudaStream()) != 0)
+            throw std::runtime_error("split cuda: luckCase2 failed!");
+    }
+    else {    // general (slower) case
+
+        const int threadsPerBlock = MAX_NUM_THREADS / 2;
+        const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
+
+        // prepare arrays of pointers on buffers and shapes
+        std::vector<void*> hOutBuffers(numOfSubArrs);
+
+        for(int i = 0; i < numOfSubArrs; ++i)
+            hOutBuffers[i] = outArrs[i]->getSpecialBuffer();
+
+        PointersManager manager(context, "helpers::split");
+
+        void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*));
+
+        BUILD_SINGLE_SELECTOR(input.dataType(), splitCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), dOutBuffers, outArrs[0]->specialShapeInfo(), axis), LIBND4J_TYPES);
+
+        manager.synchronize();
+    }
+
+    for(int i = 0; i < numOfSubArrs; ++i)
+        outArrs[i]->tickWriteDevice();
+    input.tickReadDevice();
+}
+
+}
+}
+}
\ No newline at end of file
diff --git a/libnd4j/include/ops/declarable/helpers/transforms.h b/libnd4j/include/ops/declarable/helpers/transforms.h
index 82ff4a73e..8ceb9c6f8 100644
--- a/libnd4j/include/ops/declarable/helpers/transforms.h
+++ b/libnd4j/include/ops/declarable/helpers/transforms.h
@@ -73,6 +73,8 @@ namespace helpers {
 	void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis);
 
 	void tileBP(nd4j::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps);
+
+	void split(nd4j::LaunchContext* context, const NDArray& input, std::vector<NDArray*>& outArrs, const int axis);
 }
 }
 }
diff --git a/libnd4j/include/ops/impl/specials_single.hpp b/libnd4j/include/ops/impl/specials_single.hpp
index ad63ee490..779cb5c2a 100644
--- a/libnd4j/include/ops/impl/specials_single.hpp
+++ b/libnd4j/include/ops/impl/specials_single.hpp
@@ -218,6 +218,99 @@ void SpecialMethods<T>::concatCpuGeneric(int dimension, int numArrays, Nd4jPoint
 }
 
+template<typename T>
+void SpecialMethods<T>::splitCpuGeneric(const NDArray& input, const std::vector<NDArray*>& outArrs, const int axis) {
+
+    int numSplits = outArrs.size();
+
+    const auto sizeofT = input.sizeOfT();
+
+    T* xBuff = input.bufferAsT<T>();
+
+    bool luckCase1 = ((axis == 0 && input.ordering() == 'c') || (axis == input.rankOf() - 1 && input.ordering() == 'f')) && input.ews() == 1;
+
+    if (luckCase1) {
+        for (uint i = 0; i < numSplits; ++i) {
+            luckCase1 &= outArrs[i]->ordering() == input.ordering() && outArrs[i]->ews() == 1;
+            if (!luckCase1)
+                break;
+        }
+    }
+
+    if (luckCase1) {
+
+        T* x = const_cast<T*>(xBuff);
+        for (uint i = 0; i < numSplits; ++i) {
+            const auto memAmountToCopy = outArrs[i]->lengthOf();
+            memcpy(outArrs[i]->bufferAsT<T>(), x, memAmountToCopy * sizeofT);
+            x += memAmountToCopy;
+        }
+        return;
+    }
+
+    const bool isXcontin = input.strideAt(axis) == 1 && input.ordering() == 'c';
+    bool areOutsContin = true;
+    bool allSameOrder = true;
+
+    if (isXcontin) {
+        for (uint i = 0; i < numSplits; ++i) {
+            areOutsContin &= outArrs[i]->strideAt(axis) == 1;
+            allSameOrder &= outArrs[i]->ordering() == input.ordering();
+            if (!areOutsContin || !allSameOrder)
+                break;
+        }
+    }
+
+    const bool luckCase2 = isXcontin && areOutsContin && allSameOrder;
+
+    if (luckCase2) {
+
+        const uint xDim = input.sizeAt(axis);
+
+        for (uint i = 0; i < input.lengthOf() / xDim; ++i) {
+
+            T* x = xBuff + xDim * i;
+
+            for (uint j = 0; j < numSplits; ++j) {
+                const auto zDim = outArrs[j]->sizeAt(axis);
+                T* z = outArrs[j]->bufferAsT<T>() + zDim * i;
+                memcpy(z, x, zDim * sizeofT);
+                z += zDim;
+                x += zDim;
+            }
+        }
+
+        return;
+    }
+
+    uint zDim = outArrs[0]->sizeAt(axis);
+    // general case
+
+    auto func = PRAGMA_THREADS_FOR {
+
+        Nd4jLong coords[MAX_RANK];
+        for (auto i = start; i < stop; i += increment) {
+
+            shape::index2coords(i, input.getShapeInfo(), coords);
+            const auto xOffset = shape::getOffset(input.getShapeInfo(), coords);
+
+            uint outArrIdx = 0;
+
+            while (coords[axis] >= zDim) {
+                coords[axis] -= zDim;
+                ++outArrIdx;
+            }
+
+            T* z = outArrs[outArrIdx]->bufferAsT<T>();
+            const auto zOffset = shape::getOffset(outArrs[outArrIdx]->getShapeInfo(), coords);
+            z[zOffset] = xBuff[xOffset];
+        }
+    };
+
+    samediff::Threads::parallel_for(func, 0, input.lengthOf());
+}
+
+
 /**
 * This kernel accumulates X arrays, and stores result into Z
 *
diff --git a/libnd4j/include/ops/specials.h b/libnd4j/include/ops/specials.h
index d8030db0b..fea31cf6f 100644
--- a/libnd4j/include/ops/specials.h
+++ b/libnd4j/include/ops/specials.h
@@ -67,6 +67,8 @@ namespace nd4j {
         static void decodeBitmapGeneric(void *dx, Nd4jLong N, void *dz, Nd4jLong *zShapeInfo);
         static Nd4jLong encodeBitmapGeneric(void *dx, Nd4jLong *zShapeInfo, Nd4jLong N, int *dz, float threshold);
+
+        static void splitCpuGeneric(const NDArray& input, const std::vector<NDArray*>& outArrs, const int axis);
     };
 
     template
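
Illustrative note (not part of the patch): the luckCase1 branch in both splitCpuGeneric and the CUDA helper relies on the fact that, for a densely packed 'c'-order input split along axis 0 (or an 'f'-order input split along the last axis), each output array occupies one contiguous slice of the input buffer, so the whole split collapses into a sequence of plain memcpy / cudaMemcpyAsync calls. Below is a minimal standalone C++ sketch of that reduction; it uses only hypothetical local names and no libnd4j types, and assumes equal-sized splits as the split op requires.

#include <cstring>
#include <cstdio>
#include <vector>

int main() {
    const int rows = 6, cols = 4;                       // input shape {6,4}, order 'c'
    std::vector<float> input(rows * cols);
    for (int i = 0; i < rows * cols; ++i)
        input[i] = static_cast<float>(i);

    const int numSplits = 3;                            // split along axis 0 into three {2,4} pieces
    const int rowsPerSplit = rows / numSplits;
    std::vector<std::vector<float>> outputs(numSplits, std::vector<float>(rowsPerSplit * cols));

    // Each output is one contiguous chunk of the input, so a straight copy suffices.
    const float* x = input.data();
    for (int e = 0; e < numSplits; ++e) {
        const size_t elementsToCopy = outputs[e].size();            // analogue of lengthOf() for the e-th output
        std::memcpy(outputs[e].data(), x, elementsToCopy * sizeof(float));
        x += elementsToCopy;                                        // advance past the chunk just copied
    }

    std::printf("first element of split 1: %f\n", outputs[1][0]);   // prints 8.0 for this layout
    return 0;
}

The general path (index2coords / getOffset per element) is needed exactly when this contiguity assumption does not hold, which is why both the CPU and CUDA helpers keep it as the fallback.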