From 106524663b2cb0d5e6055a573120bb51012c2e20 Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 2 Sep 2019 15:24:51 +0300 Subject: [PATCH 01/19] fix double consumption of rng on cpu Signed-off-by: raver119 --- libnd4j/include/loops/cpu/random.cpp | 8 -------- 1 file changed, 8 deletions(-) diff --git a/libnd4j/include/loops/cpu/random.cpp b/libnd4j/include/loops/cpu/random.cpp index 889e48181..30bab1327 100644 --- a/libnd4j/include/loops/cpu/random.cpp +++ b/libnd4j/include/loops/cpu/random.cpp @@ -162,9 +162,6 @@ namespace functions { } } } - - // update rng state - rng->rewindH(length); }; @@ -223,8 +220,6 @@ namespace functions { } } } - // update rng state - rng->rewindH(length); } @@ -256,9 +251,6 @@ namespace functions { z[offset] = OpClass::op(i+threadOffset, length, rng, extraArguments); } } - - // update rng state - rng->rewindH(length); } template From cb4c9377b19710ad6e086b56759b93778e8185af Mon Sep 17 00:00:00 2001 From: Yurii Shyrma Date: Mon, 2 Sep 2019 16:25:58 +0300 Subject: [PATCH 02/19] Shyrma docs (#222) * - documenting and profiling matrix_set_diag cuda kernel Signed-off-by: Yurii * - correct formula of pnorm pooling in cuda 2d/3d kernels - remove helper matrix_diag which duplicates work of helper matrix_set_diag Signed-off-by: Yurii --- .../generic/parity_ops/matrixSetDiag.cpp | 7 +- .../generic/parity_ops/matrix_diag.cpp | 64 +++++----- .../ops/declarable/headers/parity_ops.h | 16 ++- .../declarable/helpers/cpu/convolutions.cpp | 2 +- .../declarable/helpers/cpu/matrixSetDiag.cpp | 57 +++++---- .../declarable/helpers/cpu/matrix_diag.cpp | 65 ----------- .../declarable/helpers/cuda/convolutions.cu | 23 ++-- .../declarable/helpers/cuda/matrixSetDiag.cu | 110 +++++++++++------- .../declarable/helpers/cuda/matrix_diag.cu | 95 --------------- .../ops/declarable/helpers/matrixSetDiag.h | 3 +- .../ops/declarable/helpers/matrix_diag.h | 34 ------ .../layers_tests/DeclarableOpsTests3.cpp | 28 ++--- .../tests_cpu/layers_tests/SortCudaTests.cu | 6 +- 13 files changed, 190 insertions(+), 320 deletions(-) delete mode 100644 libnd4j/include/ops/declarable/helpers/cpu/matrix_diag.cpp delete mode 100644 libnd4j/include/ops/declarable/helpers/cuda/matrix_diag.cu delete mode 100644 libnd4j/include/ops/declarable/helpers/matrix_diag.h diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/matrixSetDiag.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/matrixSetDiag.cpp index f63469817..3a52057a5 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/matrixSetDiag.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/matrixSetDiag.cpp @@ -15,7 +15,7 @@ ******************************************************************************/ // -// @author Yurii Shyrma (iuriish@yahoo.com), created on 07.12.2017 +// @author Yurii Shyrma (iuriish@yahoo.com) // #include @@ -38,10 +38,9 @@ CONFIGURABLE_OP_IMPL(matrix_set_diag, 2, 1, false, 0, 0) { for(int i = 0; i < diagonal->rankOf() - 1; ++i) REQUIRE_TRUE(diagonal->sizeAt(i) == input->sizeAt(i), 0, "MATRIX_SET_DIAG op: the shapes of diagonal and input arrays must be equal till last diagonal dimension but one, however got diagonal=%s and input=%s instead !", ShapeUtils::shapeAsString(diagonal).c_str(), ShapeUtils::shapeAsString(input).c_str()); - REQUIRE_TRUE(diagonal->sizeAt(-1) == (int)nd4j::math::nd4j_min(input->sizeAt(-1), input->sizeAt(-2)), - 0, "MATRIX_SET_DIAG op: the value of last dimension of diagonal array must be equal to min(input_last_shape=%i, input_last_but_one_shape=%i), but got %i instead !", 
input->sizeAt(-1), input->sizeAt(-2), diagonal->sizeAt(-1)); + REQUIRE_TRUE(diagonal->sizeAt(-1) == (int)nd4j::math::nd4j_min(input->sizeAt(-1), input->sizeAt(-2)), 0, "MATRIX_SET_DIAG op: the value of last dimension of diagonal array must be equal to min(input_last_shape=%i, input_last_but_one_shape=%i), but got %i instead !", input->sizeAt(-1), input->sizeAt(-2), diagonal->sizeAt(-1)); - helpers::matrixSetDiag(block.launchContext(), input, diagonal, output); + helpers::matrixSetDiag(block.launchContext(), *input, *diagonal, *output, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/matrix_diag.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/matrix_diag.cpp index 8fa5bfa41..c430fd4d2 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/matrix_diag.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/matrix_diag.cpp @@ -15,49 +15,53 @@ ******************************************************************************/ // -// Created to use with batched tensor by GS 3/21/2018 +// @author GS 3/21/2018 +// @author Yurii Shyrma (iuriish@yahoo.com) // #include -#include - +#include namespace nd4j { - namespace ops { - CUSTOM_OP_IMPL(matrix_diag, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { - REQUIRE_TRUE(!input->isScalar(), 0, "CUSTOM_OP matrix_diag: input array must be at list a vector, but scalar was given!"); +CUSTOM_OP_IMPL(matrix_diag, 1, 1, false, 0, 0) { - output->nullify(); - return helpers::matrixDiag(block.launchContext(), input, output); - } + auto diagonal = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - DECLARE_SHAPE_FN(matrix_diag) { - Nd4jLong* outShapeInfo = nullptr; - auto in = inputShape->at(0); - int inRank = shape::rank(in); + REQUIRE_TRUE(!diagonal->isScalar(), 0, "CUSTOM_OP matrix_diag: input diagonal array must be at list a vector, but scalar was given!"); - int outRank = inRank + 1; - auto lastDimension = shape::sizeAt(in, -1); + helpers::matrixSetDiag(block.launchContext(), *output, *diagonal, *output, true); - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); - outShapeInfo[0] = outRank; - for(int i = 0; i < inRank; ++i) - outShapeInfo[i + 1] = shape::sizeAt(in, i); - outShapeInfo[outRank] = lastDimension; + return Status::OK(); +} - ShapeUtils::updateStridesAndType(outShapeInfo, in, shape::order(in)); +DECLARE_SHAPE_FN(matrix_diag) { - return SHAPELIST(CONSTANT(outShapeInfo)); - } + Nd4jLong* outShapeInfo = nullptr; + auto in = inputShape->at(0); + int inRank = shape::rank(in); - DECLARE_TYPES(matrix_diag) { - getOpDescriptor() - ->setAllowedInputTypes(nd4j::DataType::ANY) - ->setSameMode(true); - } + int outRank = inRank + 1; + auto lastDimension = shape::sizeAt(in, -1); + + ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + outShapeInfo[0] = outRank; + for(int i = 0; i < inRank; ++i) + outShapeInfo[i + 1] = shape::sizeAt(in, i); + outShapeInfo[outRank] = lastDimension; + + ShapeUtils::updateStridesAndType(outShapeInfo, in, shape::order(in)); + + return SHAPELIST(CONSTANT(outShapeInfo)); +} + +DECLARE_TYPES(matrix_diag) { + getOpDescriptor() + ->setAllowedInputTypes(nd4j::DataType::ANY) + ->setSameMode(true); +} } } diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index f9278fb36..c86f28499 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ 
b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -76,8 +76,20 @@ namespace nd4j { #endif /** - * Returns a batched matrix tensor with new batched diagonal values. - */ + * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array + * + * Input arrays: + * input: input array, considered as batch of matrices + * diagonal: array containing elements to be inserted into input array, + * following rank condition should be satisfied: diagonal_rank = input_rank - 1, + * the shapes of diagonal and input arrays must be equal except last dimension of input array, + * for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C], + * also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions + * that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2]) + * + * Output array: + * has the same shape as input, corresponding diagonal elements are substituted + */ #if NOT_EXCLUDED(OP_matrix_set_diag) DECLARE_CONFIGURABLE_OP(matrix_set_diag, 2, 1, false, 0, 0); #endif diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp index dd5516461..3d04bc129 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp @@ -2411,7 +2411,7 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d( for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) - pgI[kd + kh + kw] += valO * nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(pIn[kd + kh + kw]), extraParam0 - (T)1.f); + pgI[kd + kh + kw] += valO * nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(pIn[kd + kh + kw]), extraParam0 - (T)1.f) * nd4j::math::nd4j_sgn(pIn[kd + kh + kw]); } else { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/matrixSetDiag.cpp b/libnd4j/include/ops/declarable/helpers/cpu/matrixSetDiag.cpp index 7180a88b3..e974755ac 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/matrixSetDiag.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/matrixSetDiag.cpp @@ -15,7 +15,7 @@ ******************************************************************************/ // -// Created by Yurii Shyrma on 07.12.2017. +// @author Yurii Shyrma (iuriish@yahoo.com) // #include "ResultSet.h" @@ -27,31 +27,48 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////// -// Returns a batched matrix tensor with new batched diagonal values. 
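// Note (editor's illustration, not part of this patch): the one-line pnorm change in
// cpu/convolutions.cpp above, and the matching pooling2dBPCuda/pooling3dBPCuda changes later in
// this patch, add the sign factor that the p-norm pooling gradient needs: since
// d/dx |x|^p = p * |x|^(p-1) * sgn(x), each backprop contribution has to be scaled by sgn(x)
// rather than using |x|^(p-1) alone. A minimal per-element sketch with illustrative names:
//   gradIn[k] += gradOutVal * pow(abs(in[k]), p - 1) * sgn(in[k]);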
-// for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag -template -static void _matrixSetDiag(const NDArray* input, const NDArray* diagonal, NDArray* output) { +template +void matrixSetDiag_(const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad) { - *output = *input; + // input and output are the same array (x == z) when zeroPad = true + // xRank = zRank, xRank = yRank + 1 + // xLen = zLen - const int lastDimSize = input->sizeAt(-1); - const int last2DimSize = input->sizeAt(-1) * input->sizeAt(-2); - const int lastSmallDim = diagonal->sizeAt(-1); - const int batchSize = input->lengthOf()/last2DimSize; + const T* x = input.bufferAsT(); + const T* y = diagonal.bufferAsT(); + T* z = output.bufferAsT(); - for(int i = 0; i < batchSize; ++i ) - for(int j = 0; j < lastSmallDim; ++j) { - output->p(i*last2DimSize + j*(lastDimSize + 1), diagonal->e(i*lastSmallDim + j)); - } - + const Nd4jLong* xShapeInfo = input.getShapeInfo(); + const Nd4jLong* yShapeInfo = diagonal.getShapeInfo(); + const Nd4jLong* zShapeInfo = output.getShapeInfo(); + const bool areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); // shapes are definitely the same, but strides might not + + const int xRank = input.rankOf(); + const auto xLen = input.lengthOf(); + + std::vector coords(xRank); // we use the same coordinates storage both for input and output since their ranks are the same + + PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(coords)) + for (Nd4jLong i = 0; i < xLen; ++i) { + + shape::index2coords(xRank, xShapeInfo + 1, i, xLen, coords.data()); + + const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + xRank + 1, coords.data(), xRank); + const auto zOffset = areSameOffsets ? xOffset : shape::getOffset(0, zShapeInfo + 1, zShapeInfo + xRank + 1, coords.data(), xRank); + + // condition to be on diagonal of innermost matrix + if(coords[xRank - 2] == coords[xRank - 1]) + z[zOffset] = y[shape::getOffset(0, yShapeInfo + 1, yShapeInfo + xRank, coords.data(), xRank - 1)]; + else + z[zOffset] = zeroPad ? static_cast(0) : x[xOffset]; + } } - void matrixSetDiag(nd4j::LaunchContext * context, const NDArray* input, const NDArray* diagonal, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), _matrixSetDiag, (input, diagonal, output), LIBND4J_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void _matrixSetDiag, (const NDArray* input, const NDArray* diagonal, NDArray* output), LIBND4J_TYPES); +////////////////////////////////////////////////////////////////////////// +void matrixSetDiag(nd4j::LaunchContext* context, const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad) { + BUILD_SINGLE_SELECTOR(input.dataType(), matrixSetDiag_, (input, diagonal, output, zeroPad), LIBND4J_TYPES); +} } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag.cpp b/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag.cpp deleted file mode 100644 index 3f9883b54..000000000 --- a/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag.cpp +++ /dev/null @@ -1,65 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by GS on 3/21/2018. -// - -#include "ResultSet.h" -#include -#include - -namespace nd4j { -namespace ops { -namespace helpers { - - -////////////////////////////////////////////////////////////////////////// -// Returns a batched matrix tensor with new batched diagonal values. -// for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag -template -static int _matrixDiag(const NDArray* input, NDArray* output) { - - auto listOut = output->allTensorsAlongDimension({output->rankOf() - 2, output->rankOf() - 1}); - auto listDiag = input->allTensorsAlongDimension({input->rankOf() - 1}); - - if (listOut->size() != listDiag->size()) { - nd4j_printf("matrix_diag: Input matrix has wrong shape.", ""); - return ND4J_STATUS_VALIDATION; - } - int lastDimension = input->sizeAt(-1); - // TODO: tune this properlys - int lO = listOut->size(); - PRAGMA_OMP_PARALLEL_FOR_IF(lO > Environment::getInstance()->tadThreshold()) - for(int i = 0; i < lO; ++i) - for (int e = 0; e < lastDimension; e++) - listOut->at(i)->p(e, e, listDiag->at(i)->e(e)); - - delete listOut; - delete listDiag; - - return Status::OK(); -} - - int matrixDiag(nd4j::LaunchContext * context, const NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return _matrixDiag, (input, output), LIBND4J_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template int _matrixDiag, (const NDArray* input, NDArray* output), LIBND4J_TYPES); - -} -} -} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu index 87e7c4f08..c08551318 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu @@ -957,9 +957,13 @@ __global__ static void pooling2dBPCuda(const void* vx, const Nd4jLong* xShapeInf val *= nd4j::math::nd4j_pow(sum, ((T)1.f - extraParam0) / extraParam0); - for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) - for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) - nd4j::math::atomics::nd4j_atomicAdd(&z[shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank)], val * nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(x[shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank)]), extraParam0 - 1.f)); + for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) { + for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) { + const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank); + const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank); + nd4j::math::atomics::nd4j_atomicAdd(&z[zOffset], val * nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(x[xOffset]), extraParam0 - 1.f) * nd4j::math::nd4j_sgn(x[xOffset])); + } + } } break; } @@ -1123,10 +1127,15 @@ __global__ static void pooling3dBPCuda(const void* vx, const Nd4jLong* xShapeInf val *= nd4j::math::nd4j_pow(sum, ((T)1.f - extraParam0) / extraParam0); - for (coords[2] = dstart; 
coords[2] < dend; coords[2] += dD) - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) - nd4j::math::atomics::nd4j_atomicAdd(&z[shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank)], val * nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(x[shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank)]), extraParam0 - 1.f)); + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) { + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) { + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) { + const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank); + const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank); + nd4j::math::atomics::nd4j_atomicAdd(&z[zOffset], val * nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(x[xOffset]), extraParam0 - 1.f) * nd4j::math::nd4j_sgn(x[xOffset])); + } + } + } } break; } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/matrixSetDiag.cu b/libnd4j/include/ops/declarable/helpers/cuda/matrixSetDiag.cu index 95eb5f439..01baaffb4 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/matrixSetDiag.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/matrixSetDiag.cu @@ -15,63 +15,87 @@ ******************************************************************************/ // -// Created by Yurii Shyrma on 07.12.2017. +// @author Yurii Shyrma (iuriish@yahoo.com) // #include "ResultSet.h" #include +#include -namespace nd4j { -namespace ops { +namespace nd4j { +namespace ops { namespace helpers { +/////////////////////////////////////////////////////////////////// +template +__global__ static void matrixSetDiagCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool zeroPad) { - template - static __global__ void matrixSetDiagKernel(void* outputBuffer, Nd4jLong* outputShape, void const* diagonalBuffer, Nd4jLong* diagonalShape, Nd4jLong lastDimSize, Nd4jLong last2DimSize, Nd4jLong lastSmallDim, Nd4jLong batchSize) { - __shared__ T* z; - __shared__ T const* x; - __shared__ Nd4jLong outLength, diagonalLen; - if (threadIdx.x == 0) { - z = reinterpret_cast(outputBuffer); - x = reinterpret_cast(diagonalBuffer); - outLength = shape::length(outputShape); - diagonalLen = shape::length(diagonalShape); - } - __syncthreads(); + // x - input, shape [A,B,C] + // y - diagonal, shape [A,B] + // z - output, shape [A,B,C] + // input and output are the same array (x == z) when zeroPad = true - for(int i = blockIdx.x; i < batchSize; i+= gridDim.x ) - for(int j = threadIdx.x; j < lastSmallDim; j += blockDim.x) { -// z[i * last2DimSize + j * (lastDimSize + 1)] = x[i * lastSmallDim + j]; - z[shape::getIndexOffset(i * last2DimSize + j * (lastDimSize + 1), outputShape, outLength)] = x[shape::getIndexOffset(i * lastSmallDim + j, diagonalShape, diagonalLen)]; - } - } - ////////////////////////////////////////////////////////////////////////// - // Returns a batched matrix tensor with new batched diagonal values. 
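// Note (editor's illustration, not part of this patch): the rewritten matrixSetDiag helpers
// (the CPU version in cpu/matrixSetDiag.cpp above and the CUDA kernel below) implement one
// per-element rule over the whole input: where the last two coordinates coincide the diagonal
// value is written, otherwise the input value is kept, or zero is written when zeroPad is set
// (which is how the matrix_diag op is now expressed through this helper). A minimal sketch for
// a single 2D slice, with illustrative names:
//   for (int r = 0; r < rows; ++r)
//       for (int c = 0; c < cols; ++c)
//           out[r][c] = (r == c) ? diag[r] : (zeroPad ? T(0) : in[r][c]);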
- // for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag - template - static void _matrixSetDiag(nd4j::LaunchContext * context, const NDArray* input, const NDArray* diagonal, NDArray* output) { - *output = *input; + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); - const int lastDimSize = input->sizeAt(-1); - const int last2DimSize = input->sizeAt(-1) * input->sizeAt(-2); - const int lastSmallDim = diagonal->sizeAt(-1); - const int batchSize = input->lengthOf()/last2DimSize; - auto stream = context->getCudaStream(); - dim3 launchDims(256, 512, 8192); - matrixSetDiagKernel<<>>(output->specialBuffer(), output->specialShapeInfo(), diagonal->getSpecialBuffer(), diagonal->getSpecialShapeInfo(), lastDimSize, last2DimSize, lastSmallDim, batchSize); -//// #pragma omp parallel for if(batchSize > Environment::getInstance()->elementwiseThreshold()) schedule(static) -// for(int i = 0; i < batchSize; ++i ) -// for(int j = 0; j < lastSmallDim; ++j) { -// output->p(i*last2DimSize + j*(lastDimSize + 1), diagonal->e(i*lastSmallDim + j)); -// } + __shared__ int xRank; // xRank = zRank, xRank = yRank + 1 + __shared__ Nd4jLong xLen, *sharedMem; // xLen = zLen + __shared__ bool areSameOffsets; + if (threadIdx.x == 0) { + + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); // shapes are definitely the same, but strides might not + + xRank = shape::rank(xShapeInfo); + xLen = shape::length(xShapeInfo); } - void matrixSetDiag(nd4j::LaunchContext * context, const NDArray* input, const NDArray* diagonal, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), _matrixSetDiag, (context, input, diagonal, output), LIBND4J_TYPES); - } + __syncthreads(); - BUILD_SINGLE_TEMPLATE(template void _matrixSetDiag, (nd4j::LaunchContext * context, const NDArray* input, const NDArray* diagonal, NDArray* output), LIBND4J_TYPES); + auto coords = sharedMem + threadIdx.x * xRank; // we provide (xRank * sizeof(Nd4jLong) * threadIdx.x) amount of shared memory per each thread + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < xLen; i += gridDim.x * blockDim.x) { + + shape::index2coords(xRank, xShapeInfo + 1, i, xLen, coords); + + const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + xRank + 1, coords, xRank); + const auto zOffset = areSameOffsets ? xOffset : shape::getOffset(0, zShapeInfo + 1, zShapeInfo + xRank + 1, coords, xRank); + + // condition to be on diagonal of innermost matrix + if(coords[xRank - 2] == coords[xRank - 1]) + z[zOffset] = y[shape::getOffset(0, yShapeInfo + 1, yShapeInfo + xRank, coords, xRank - 1)]; + else + z[zOffset] = zeroPad ? 
static_cast(0) : x[xOffset]; + } +} + +/////////////////////////////////////////////////////////////////// +template +static void matrixSetDiagCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool zeroPad) { + + matrixSetDiagCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, zeroPad); +} + +/////////////////////////////////////////////////////////////////// +void matrixSetDiag(nd4j::LaunchContext* context, const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad) { + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * input.rankOf() + 128; + + PointersManager manager(context, "matrixSetDiag"); + + NDArray::prepareSpecialUse({&output}, {&input, &diagonal}); + BUILD_SINGLE_SELECTOR(input.dataType(), matrixSetDiagCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), diagonal.getSpecialBuffer(), diagonal.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), zeroPad), LIBND4J_TYPES); + NDArray::registerSpecialUse({&output}, {&input, &diagonal}); + + manager.synchronize(); +} } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/matrix_diag.cu b/libnd4j/include/ops/declarable/helpers/cuda/matrix_diag.cu deleted file mode 100644 index 78304510d..000000000 --- a/libnd4j/include/ops/declarable/helpers/cuda/matrix_diag.cu +++ /dev/null @@ -1,95 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by GS on 3/21/2018. 
-// - -#include "ResultSet.h" -#include -#include -#include -#include -#include -#include -#include - -namespace nd4j { -namespace ops { -namespace helpers { - - - template - static __global__ void matrixDiagKernel(void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength, - Nd4jLong* tadOnlyInputShapeInfo, Nd4jLong *tadInputOffsets, - Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong *tadOutputOffsets) { - int totalThreads = blockDim.x; - for (Nd4jLong i = blockIdx.x; i < numTads; i += gridDim.x) { - auto yOffset = tadInputOffsets[i]; - auto xOffset = tadOutputOffsets[i]; - for (Nd4jLong j = threadIdx.x; j < inputLength; j += totalThreads) { - Nd4jLong coords[2] = {j, j}; - Nd4jLong tadOffset = shape::getOffset(0, shape::shapeOf(tadOnlyOutputShapeInfo), shape::stride(tadOnlyOutputShapeInfo), coords, 2); - //shape::getIndexOffset(j, tadOnlyOutputShapeInfo, inputLength) - *(reinterpret_cast(outputBuffer) + xOffset + tadOffset) = *(reinterpret_cast(inputBuffer) + yOffset + shape::getIndexOffset(j, tadOnlyInputShapeInfo, inputLength)); - } - } - } - ////////////////////////////////////////////////////////////////////////// - // Returns a batched matrix tensor with new batched diagonal values. - // for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag - - template - static int _matrixDiag(nd4j::LaunchContext * context, const NDArray* input, NDArray* output) { - cudaStream_t* stream = context->getCudaStream(); - //auto listOut = output->allTensorsAlongDimension({output->rankOf() - 2, output->rankOf() - 1}); - //auto listDiag = input->allTensorsAlongDimension({input->rankOf() - 1}); - - //auto repeatDelta = shape::prodLong(newShape.data(), rank) / this->lengthOf(); - std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(input->rankOf(), {input->rankOf() - 1}); - const Nd4jLong numTads = ShapeUtils::getNumOfSubArrs(input->getShapeInfo(), dimsToExclude); //this->tensorsAlongDimension({dimension}); - //printf("Repeat delta %lld, numTads %lld\n", repeatDelta, numTads); - //tadOnlyInputShapeInfo, tadInputOffsets, tadOnlyOutputShapeInfo, tadOutputOffsets; - std::vector inputDims({input->rankOf() - 1}); - std::vector outputDims({output->rankOf() - 2, output->rankOf() - 1}); - - auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), inputDims); - auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), outputDims); - - if (!input->isActualOnDeviceSide()) - input->syncToDevice(); - - if (!output->isActualOnDeviceSide()) - output->syncToDevice(); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - - dim3 launchDims(256, 512, 8192); - matrixDiagKernel<<>>(input->getSpecialBuffer(), output->getSpecialBuffer(), numTads, input->sizeAt(-1), packX.specialShapeInfo(), packX.specialOffsets(), packZ.specialShapeInfo(), packZ.specialOffsets()); - - return Status::OK(); - } - - int matrixDiag(nd4j::LaunchContext * context, const NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return _matrixDiag, (context, input, output), LIBND4J_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template int _matrixDiag, (nd4j::LaunchContext * context, const NDArray* input, NDArray* output), LIBND4J_TYPES); - -} -} -} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/matrixSetDiag.h b/libnd4j/include/ops/declarable/helpers/matrixSetDiag.h index ea5a1a4ad..fb7d57d18 100644 --- 
a/libnd4j/include/ops/declarable/helpers/matrixSetDiag.h +++ b/libnd4j/include/ops/declarable/helpers/matrixSetDiag.h @@ -28,8 +28,7 @@ namespace nd4j { namespace ops { namespace helpers { - void matrixSetDiag(nd4j::LaunchContext * context, const NDArray* input, const NDArray* diagonal, NDArray* output); - + void matrixSetDiag(nd4j::LaunchContext* context, const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad); } } diff --git a/libnd4j/include/ops/declarable/helpers/matrix_diag.h b/libnd4j/include/ops/declarable/helpers/matrix_diag.h deleted file mode 100644 index 0cbbcef16..000000000 --- a/libnd4j/include/ops/declarable/helpers/matrix_diag.h +++ /dev/null @@ -1,34 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author GS -// -#ifndef __MATRIX_DIAG_HELPERS__ -#define __MATRIX_DIAG_HELPERS__ -#include -#include - -namespace nd4j { -namespace ops { -namespace helpers { - - int matrixDiag(nd4j::LaunchContext * context, NDArray const* input, NDArray* output); - -} -} -} -#endif diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index 1ec9650f9..7d166f831 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -117,9 +117,9 @@ TEST_F(DeclarableOpsTests3, Test_Unique_1) { auto v = result->at(0); auto i = result->at(1); - v->printIndexedBuffer("Values"); - i->printIndexedBuffer("Indices"); - i->printShapeInfo("Indices shape"); + // v->printIndexedBuffer("Values"); + // i->printIndexedBuffer("Indices"); + // i->printShapeInfo("Indices shape"); ASSERT_TRUE(expV.isSameShape(v)); ASSERT_TRUE(expV.equalsTo(v)); @@ -145,12 +145,12 @@ TEST_F(DeclarableOpsTests3, Test_Unique_2) { auto i = result->at(1); auto c = result->at(2); - v->printShapeInfo(); - v->printIndexedBuffer("Values"); - i->printShapeInfo(); - i->printIndexedBuffer("Indices"); - c->printShapeInfo(); - c->printIndexedBuffer("Counts"); + // v->printShapeInfo(); + // v->printIndexedBuffer("Values"); + // i->printShapeInfo(); + // i->printIndexedBuffer("Indices"); + // c->printShapeInfo(); + // c->printIndexedBuffer("Counts"); ASSERT_TRUE(expV.isSameShape(v)); ASSERT_TRUE(expV.equalsTo(v)); @@ -200,11 +200,11 @@ TEST_F(DeclarableOpsTests3, Test_Norm_1) { auto result1 = op.execute({&x}, {1.}, {1}); ASSERT_EQ(result1->status(), ND4J_STATUS_OK); auto z1 = result1->at(0); - z1->printIndexedBuffer("Z1"); + // z1->printIndexedBuffer("Z1"); auto exp1 = x.reduceAlongDims(reduce::Norm2, dims, false, false); - exp1.printIndexedBuffer("EXP1"); - z1->printShapeInfo("Z1 shape"); - exp1.printShapeInfo("EXP1 shape"); + // exp1.printIndexedBuffer("EXP1"); + // z1->printShapeInfo("Z1 shape"); + // exp1.printShapeInfo("EXP1 shape"); 
ASSERT_TRUE(exp1.isSameShape(z1)); ASSERT_TRUE(exp1.equalsTo(z1)); @@ -714,7 +714,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) { auto exp = MmulHelper::mmul(&x, &y); - exp->printShapeInfo("exp shape"); + // exp->printShapeInfo("exp shape"); nd4j::ops::batched_gemm op; auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3}); diff --git a/libnd4j/tests_cpu/layers_tests/SortCudaTests.cu b/libnd4j/tests_cpu/layers_tests/SortCudaTests.cu index 49c1f7a95..6913722be 100644 --- a/libnd4j/tests_cpu/layers_tests/SortCudaTests.cu +++ b/libnd4j/tests_cpu/layers_tests/SortCudaTests.cu @@ -79,7 +79,7 @@ TEST_F(SortCudaTests, test_linear_sort_by_val_2) { sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), true); k.tickWriteDevice(); v.tickWriteDevice(); - k.printIndexedBuffer("KEYS"); + // k.printIndexedBuffer("KEYS"); ASSERT_EQ(ek, k); ASSERT_EQ(ev, v); } @@ -98,8 +98,8 @@ TEST_F(SortCudaTests, test_tad_sort_by_key_1) { k.tickWriteDevice(); v.tickWriteDevice(); - k.printIndexedBuffer("k"); - v.printIndexedBuffer("v"); + // k.printIndexedBuffer("k"); + // v.printIndexedBuffer("v"); ASSERT_EQ(ek, k); ASSERT_EQ(ev, v); From 18828f97252669ff5f10cd41e8c331a47e77183c Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 2 Sep 2019 16:52:10 +0300 Subject: [PATCH 03/19] cublasHandle sharing + lock Signed-off-by: raver119 --- .../jita/handler/impl/CudaZeroHandler.java | 19 ++++++++++++++++--- .../java/org/nd4j/nativeblas/Nd4jCuda.java | 3 +++ .../java/org/nd4j/nativeblas/Nd4jCpu.java | 3 +++ 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java index fdd40f8cb..23301f4be 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java @@ -16,6 +16,7 @@ package org.nd4j.jita.handler.impl; +import org.nd4j.nativeblas.OpaqueLaunchContext; import org.nd4j.shade.guava.collect.HashBasedTable; import org.nd4j.shade.guava.collect.Table; import lombok.Getter; @@ -105,6 +106,8 @@ public class CudaZeroHandler implements MemoryHandler { private final AllocationStatus INITIAL_LOCATION; + private final List cublasHandles = new ArrayList<>(); + private final AffinityManager affinityManager = Nd4j.getAffinityManager(); /* @@ -162,6 +165,7 @@ public class CudaZeroHandler implements MemoryHandler { int numDevices = NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices(); for (int i = 0; i < numDevices; i++) { deviceAllocations.add(new ConcurrentHashMap()); + cublasHandles.add(null); } if (NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceMajor(0) < 3) { @@ -1176,6 +1180,17 @@ public class CudaZeroHandler implements MemoryHandler { return getCudaContext(); } + + + protected synchronized cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext lc) { + val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); + if (cublasHandles.get(deviceId) == null) + cublasHandles.remove(deviceId); + cublasHandles.add(deviceId, new cublasHandle_t(nativeOps.lcBlasHandle(lc))); + + return cublasHandles.get(deviceId); + } + /** * This method returns CudaContext for 
current thread. If context doesn't exist - it gets created first. * @return @@ -1183,8 +1198,6 @@ public class CudaZeroHandler implements MemoryHandler { public CudaContext getCudaContext() { val lc = nativeOps.defaultLaunchContext(); - // TODO: maybe make ThreadLocal cache for context? - return CudaContext.builder() .bufferScalar(nativeOps.lcScalarPointer(lc)) .bufferReduction(nativeOps.lcReductionPointer(lc)) @@ -1192,7 +1205,7 @@ public class CudaZeroHandler implements MemoryHandler { .bufferSpecial(nativeOps.lcScalarPointer(lc)) .oldStream(new cudaStream_t(nativeOps.lcExecutionStream(lc))) .specialStream(new cudaStream_t(nativeOps.lcCopyStream(lc))) - .cublasHandle(new cublasHandle_t(nativeOps.lcBlasHandle(lc))) + .cublasHandle(getCudaCublasHandle(lc)) .solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc))) .build(); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 15f6c52ef..f3080f05a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -3830,6 +3830,9 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc * @param writeList * @param readList */ + // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list + + // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list /** diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 6983e20f0..8e150f618 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -3830,6 +3830,9 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc * @param writeList * @param readList */ + // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list + + // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list /** From 2129d5bcace50cdd766936d3f4eb37323564ad7d Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 2 Sep 2019 16:52:28 +0300 Subject: [PATCH 04/19] cublasHandle sharing + lock Signed-off-by: raver119 --- .../main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java index 23301f4be..106ac9c3a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java @@ -1184,9 +1184,10 @@ public class CudaZeroHandler implements MemoryHandler { protected synchronized cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext lc) { val deviceId = 
Nd4j.getAffinityManager().getDeviceForCurrentThread(); - if (cublasHandles.get(deviceId) == null) + if (cublasHandles.get(deviceId) == null) { cublasHandles.remove(deviceId); cublasHandles.add(deviceId, new cublasHandle_t(nativeOps.lcBlasHandle(lc))); + } return cublasHandles.get(deviceId); } From 90b62c457917e20480ef3e3ec0d21c6819239650 Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Mon, 2 Sep 2019 17:17:55 +0300 Subject: [PATCH 05/19] Documentation from serialization/deserialization in NLP (#221) * refactoring Signed-off-by: Alexander Stoyakin * Javadocs Signed-off-by: Alexander Stoyakin * Javadoc fixed Signed-off-by: Alexander Stoyakin * Cleanup Signed-off-by: Alexander Stoyakin --- .../loader/WordVectorSerializer.java | 415 +++++++++++++----- 1 file changed, 303 insertions(+), 112 deletions(-) diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java index cce6a740a..210ab7686 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java @@ -24,7 +24,6 @@ import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.apache.commons.io.LineIterator; import org.apache.commons.io.output.CloseShieldOutputStream; -import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.models.embeddings.WeightLookupTable; import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; @@ -52,7 +51,6 @@ import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; import org.deeplearning4j.util.DL4JFileUtils; -import org.nd4j.base.Preconditions; import org.nd4j.compression.impl.NoOp; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -68,8 +66,6 @@ import org.nd4j.util.OneTimeLogger; import java.io.*; import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.nio.file.Paths; import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; @@ -78,6 +74,80 @@ import java.util.zip.*; /** * This is utility class, providing various methods for WordVectors serialization * + * List of available serialization methods (please keep this list consistent with source code): + * + *
    + *
+ * <ul>
+ * <li>Serializers for Word2Vec:</li>
+ * {@link #writeWordVectors(WeightLookupTable, File)}
+ * {@link #writeWordVectors(WeightLookupTable, OutputStream)}
+ * {@link #writeWord2VecModel(Word2Vec, File)}
+ * {@link #writeWord2VecModel(Word2Vec, String)}
+ * {@link #writeWord2VecModel(Word2Vec, OutputStream)}
+ *
+ * <li>Deserializers for Word2Vec:</li>
+ * {@link #readWord2VecModel(File)}
+ * {@link #readWord2VecModel(String)}
+ * {@link #readWord2VecModel(File, boolean)}
+ * {@link #readWord2VecModel(String, boolean)}
+ * {@link #readAsBinaryNoLineBreaks(File)}
+ * {@link #readAsBinary(File)}
+ * {@link #readAsCsv(File)}
+ * {@link #readBinaryModel(File, boolean, boolean)}
+ * {@link #readWord2VecFromText(File, File, File, File, VectorsConfiguration)}
+ * {@link #readWord2Vec(String, boolean)}
+ * {@link #readWord2Vec(File, boolean)}
+ * {@link #readWord2Vec(InputStream, boolean)}
+ *
+ * <li>Serializers for ParaVec:</li>
+ * {@link #writeParagraphVectors(ParagraphVectors, File)}
+ * {@link #writeParagraphVectors(ParagraphVectors, String)}
+ * {@link #writeParagraphVectors(ParagraphVectors, OutputStream)}
+ *
+ * <li>Deserializers for ParaVec:</li>
+ * {@link #readParagraphVectors(File)}
+ * {@link #readParagraphVectors(String)}
+ * {@link #readParagraphVectors(InputStream)}
+ *
+ * <li>Serializers for GloVe:</li>
+ * {@link #writeWordVectors(Glove, File)}
+ * {@link #writeWordVectors(Glove, String)}
+ * {@link #writeWordVectors(Glove, OutputStream)}
+ *
+ * <li>Adapters</li>
+ * {@link #fromTableAndVocab(WeightLookupTable, VocabCache)}
+ * {@link #fromPair(Pair)}
+ * {@link #loadTxt(File)}
+ *
+ * <li>Serializers to tSNE format</li>
+ * {@link #writeTsneFormat(Glove, INDArray, File)}
+ * {@link #writeTsneFormat(Word2Vec, INDArray, File)}
+ *
+ * <li>FastText serializer:</li>
+ * {@link #writeWordVectors(FastText, File)}
+ *
+ * <li>FastText deserializer:</li>
+ * {@link #readWordVectors(File)}
+ *
+ * <li>SequenceVectors serializers:</li>
+ * {@link #writeSequenceVectors(SequenceVectors, OutputStream)}
+ * {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, File)}
+ * {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, String)}
+ * {@link #writeSequenceVectors(SequenceVectors, SequenceElementFactory, OutputStream)}
+ * {@link #writeLookupTable(WeightLookupTable, File)}
+ * {@link #writeVocabCache(VocabCache, File)}
+ * {@link #writeVocabCache(VocabCache, OutputStream)}
+ *
+ * <li>SequenceVectors deserializers:</li>
+ * {@link #readSequenceVectors(File, boolean)}
+ * {@link #readSequenceVectors(String, boolean)}
+ * {@link #readSequenceVectors(SequenceElementFactory, File)}
+ * {@link #readSequenceVectors(InputStream, boolean)}
+ * {@link #readSequenceVectors(SequenceElementFactory, InputStream)}
+ * {@link #readLookupTable(File)}
+ * {@link #readLookupTable(InputStream)}
+ * </ul>
+ * * @author Adam Gibson * @author raver119 * @author alexander@skymind.io @@ -97,7 +167,7 @@ public class WordVectorSerializer { * @throws IOException * @throws NumberFormatException */ - private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException { + /*private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException { InMemoryLookupTable lookupTable; VocabCache cache; INDArray syn0; @@ -142,7 +212,7 @@ public class WordVectorSerializer { ret.setLookupTable(lookupTable); } return ret; - } + }*/ /** * Read a binary word2vec file. @@ -173,8 +243,8 @@ public class WordVectorSerializer { try (BufferedInputStream bis = new BufferedInputStream(GzipUtils.isCompressedFilename(modelFile.getName()) ? new GZIPInputStream(new FileInputStream(modelFile)) : new FileInputStream(modelFile)); DataInputStream dis = new DataInputStream(bis)) { - words = Integer.parseInt(readString(dis)); - size = Integer.parseInt(readString(dis)); + words = Integer.parseInt(ReadHelper.readString(dis)); + size = Integer.parseInt(ReadHelper.readString(dis)); syn0 = Nd4j.create(words, size); cache = new AbstractCache<>(); @@ -188,11 +258,11 @@ public class WordVectorSerializer { float[] vector = new float[size]; for (int i = 0; i < words; i++) { - word = readString(dis); + word = ReadHelper.readString(dis); log.trace("Loading " + word + " with word " + i); for (int j = 0; j < size; j++) { - vector[j] = readFloat(dis); + vector[j] = ReadHelper.readFloat(dis); } if (cache.containsWord(word)) @@ -236,64 +306,6 @@ public class WordVectorSerializer { } - /** - * Read a float from a data input stream Credit to: - * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java - * - * @param is - * @return - * @throws IOException - */ - public static float readFloat(InputStream is) throws IOException { - byte[] bytes = new byte[4]; - is.read(bytes); - return getFloat(bytes); - } - - /** - * Read a string from a data input stream Credit to: - * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java - * - * @param b - * @return - * @throws IOException - */ - public static float getFloat(byte[] b) { - int accum = 0; - accum = accum | (b[0] & 0xff) << 0; - accum = accum | (b[1] & 0xff) << 8; - accum = accum | (b[2] & 0xff) << 16; - accum = accum | (b[3] & 0xff) << 24; - return Float.intBitsToFloat(accum); - } - - /** - * Read a string from a data input stream Credit to: - * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java - * - * @param dis - * @return - * @throws IOException - */ - public static String readString(DataInputStream dis) throws IOException { - byte[] bytes = new byte[MAX_SIZE]; - byte b = dis.readByte(); - int i = -1; - StringBuilder sb = new StringBuilder(); - while (b != 32 && b != 10) { - i++; - bytes[i] = b; - b = dis.readByte(); - if (i == 49) { - sb.append(new String(bytes, "UTF-8")); - i = -1; - bytes = new byte[MAX_SIZE]; - } - } - sb.append(new String(bytes, 0, i + 1, "UTF-8")); - return sb.toString(); - } - /** * This method writes word vectors to the given path. * Please note: this method doesn't load whole vocab/lookupTable into memory, so it's able to process large vocabularies served over network. 
@@ -355,7 +367,7 @@ public class WordVectorSerializer { val builder = new StringBuilder(); val l = element.getLabel(); - builder.append(encodeB64(l)).append(" "); + builder.append(ReadHelper.encodeB64(l)).append(" "); val vec = lookupTable.vector(element.getLabel()); for (int i = 0; i < vec.length(); i++) { builder.append(vec.getDouble(i)); @@ -518,7 +530,7 @@ public class WordVectorSerializer { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) { for (int i = 0; i < vectors.getVocab().numWords(); i++) { VocabWord word = vectors.getVocab().elementAtIndex(i); - StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" "); + StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" "); for (int code : word.getCodes()) { builder.append(code).append(" "); } @@ -536,7 +548,7 @@ public class WordVectorSerializer { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) { for (int i = 0; i < vectors.getVocab().numWords(); i++) { VocabWord word = vectors.getVocab().elementAtIndex(i); - StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" "); + StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" "); for (int point : word.getPoints()) { builder.append(point).append(" "); } @@ -554,7 +566,7 @@ public class WordVectorSerializer { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) { for (int i = 0; i < vectors.getVocab().numWords(); i++) { VocabWord word = vectors.getVocab().elementAtIndex(i); - StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ") + StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ") .append(word.getElementFrequency()).append(" ") .append(vectors.getVocab().docAppearedIn(word.getLabel())); @@ -638,7 +650,7 @@ public class WordVectorSerializer { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) { for (int i = 0; i < vectors.getVocab().numWords(); i++) { VocabWord word = vectors.getVocab().elementAtIndex(i); - StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" "); + StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" "); for (int code : word.getCodes()) { builder.append(code).append(" "); } @@ -656,7 +668,7 @@ public class WordVectorSerializer { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) { for (int i = 0; i < vectors.getVocab().numWords(); i++) { VocabWord word = vectors.getVocab().elementAtIndex(i); - StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" "); + StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" "); for (int point : word.getPoints()) { builder.append(point).append(" "); } @@ -677,7 +689,7 @@ public class WordVectorSerializer { StringBuilder builder = new StringBuilder(); for (VocabWord word : vectors.getVocab().tokens()) { if (word.isLabel()) - builder.append(encodeB64(word.getLabel())).append("\n"); + builder.append(ReadHelper.encodeB64(word.getLabel())).append("\n"); } IOUtils.write(builder.toString().trim(), zipfile, StandardCharsets.UTF_8); @@ -688,7 +700,7 @@ public class WordVectorSerializer { try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) { for (int i = 0; i < vectors.getVocab().numWords(); i++) { VocabWord word = vectors.getVocab().elementAtIndex(i); - builder = new 
StringBuilder(encodeB64(word.getLabel())).append(" ").append(word.getElementFrequency()) + builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ").append(word.getElementFrequency()) .append(" ").append(vectors.getVocab().docAppearedIn(word.getLabel())); writer.println(builder.toString().trim()); @@ -744,7 +756,7 @@ public class WordVectorSerializer { try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) { String line; while ((line = reader.readLine()) != null) { - VocabWord word = vectors.getVocab().tokenFor(decodeB64(line.trim())); + VocabWord word = vectors.getVocab().tokenFor(ReadHelper.decodeB64(line.trim())); if (word != null) { word.markAsLabel(true); } @@ -836,7 +848,7 @@ public class WordVectorSerializer { String line; while ((line = reader.readLine()) != null) { String[] split = line.split(" "); - VocabWord word = w2v.getVocab().tokenFor(decodeB64(split[0])); + VocabWord word = w2v.getVocab().tokenFor(ReadHelper.decodeB64(split[0])); word.setElementFrequency((long) Double.parseDouble(split[1])); word.setSequencesCount((long) Double.parseDouble(split[2])); } @@ -946,7 +958,7 @@ public class WordVectorSerializer { reader = new BufferedReader(new FileReader(h_points)); while ((line = reader.readLine()) != null) { String[] split = line.split(" "); - VocabWord word = vocab.wordFor(decodeB64(split[0])); + VocabWord word = vocab.wordFor(ReadHelper.decodeB64(split[0])); List points = new ArrayList<>(); for (int i = 1; i < split.length; i++) { points.add(Integer.parseInt(split[i])); @@ -960,7 +972,7 @@ public class WordVectorSerializer { reader = new BufferedReader(new FileReader(h_codes)); while ((line = reader.readLine()) != null) { String[] split = line.split(" "); - VocabWord word = vocab.wordFor(decodeB64(split[0])); + VocabWord word = vocab.wordFor(ReadHelper.decodeB64(split[0])); List codes = new ArrayList<>(); for (int i = 1; i < split.length; i++) { codes.add(Byte.parseByte(split[i])); @@ -1704,7 +1716,7 @@ public class WordVectorSerializer { if (line.isEmpty()) line = iter.nextLine(); String[] split = line.split(" "); - String word = decodeB64(split[0]); //split[0].replaceAll(whitespaceReplacement, " "); + String word = ReadHelper.decodeB64(split[0]); //split[0].replaceAll(whitespaceReplacement, " "); VocabWord word1 = new VocabWord(1.0, word); word1.setIndex(cache.numWords()); @@ -1994,7 +2006,13 @@ public class WordVectorSerializer { private static final String SYN1_ENTRY = "syn1.bin"; private static final String SYN1_NEG_ENTRY = "syn1neg.bin"; - + /** + * This method saves specified SequenceVectors model to target OutputStream + * + * @param vectors SequenceVectors model + * @param stream Target output stream + * @param + */ public static void writeSequenceVectors(@NonNull SequenceVectors vectors, @NonNull OutputStream stream) throws IOException { @@ -2040,7 +2058,13 @@ public class WordVectorSerializer { } } - + /** + * This method loads SequenceVectors from specified file path + * + * @param path String + * @param readExtendedTables boolean + * @param + */ public static SequenceVectors readSequenceVectors(@NonNull String path, boolean readExtendedTables) throws IOException { @@ -2050,6 +2074,14 @@ public class WordVectorSerializer { return vectors; } + /** + * This method loads SequenceVectors from specified file path + * + * @param file File + * @param readExtendedTables boolean + * @param + */ + public static SequenceVectors readSequenceVectors(@NonNull File file, boolean readExtendedTables) throws 
IOException { @@ -2058,6 +2090,13 @@ public class WordVectorSerializer { return vectors; } + /** + * This method loads SequenceVectors from specified input stream + * + * @param stream InputStream + * @param readExtendedTables boolean + * @param + */ public static SequenceVectors readSequenceVectors(@NonNull InputStream stream, boolean readExtendedTables) throws IOException { @@ -2381,6 +2420,12 @@ public class WordVectorSerializer { } } + /** + * This method loads Word2Vec model from binary file + * + * @param file File + * @return Word2Vec + */ public static Word2Vec readAsBinary(@NonNull File file) { boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive(); int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency(); @@ -2403,6 +2448,12 @@ public class WordVectorSerializer { } } + /** + * This method loads Word2Vec model from csv file + * + * @param file File + * @return Word2Vec + */ public static Word2Vec readAsCsv(@NonNull File file) { Word2Vec vec; @@ -2491,7 +2542,7 @@ public class WordVectorSerializer { String line; while ((line = reader.readLine()) != null) { String[] split = line.split(" "); - VocabWord word = new VocabWord(Double.valueOf(split[1]), decodeB64(split[0])); + VocabWord word = new VocabWord(Double.valueOf(split[1]), ReadHelper.decodeB64(split[0])); word.setIndex(cnt.getAndIncrement()); word.incrementSequencesCount(Long.valueOf(split[2])); @@ -2669,7 +2720,7 @@ public class WordVectorSerializer { * * In return you get StaticWord2Vec model, which might be used as lookup table only in multi-gpu environment. * - * @param file File should point to previously saved w2v model + * @param inputStream InputStream should point to previously saved w2v model * @return */ public static WordVectors loadStaticModel(InputStream inputStream) throws IOException { @@ -2685,6 +2736,17 @@ public class WordVectorSerializer { } // TODO: this method needs better name :) + /** + * This method restores previously saved w2v model. File can be in one of the following formats: + * 1) Binary model, either compressed or not. Like well-known Google Model + * 2) Popular CSV word2vec text format + * 3) DL4j compressed format + * + * In return you get StaticWord2Vec model, which might be used as lookup table only in multi-gpu environment. 
+ * + * @param file File + * @return + */ public static WordVectors loadStaticModel(@NonNull File file) { if (!file.exists() || file.isDirectory()) throw new RuntimeException( @@ -2843,8 +2905,8 @@ public class WordVectorSerializer { throw new RuntimeException(e); } try { - numWords = Integer.parseInt(readString(stream)); - vectorLength = Integer.parseInt(readString(stream)); + numWords = Integer.parseInt(ReadHelper.readString(stream)); + vectorLength = Integer.parseInt(ReadHelper.readString(stream)); } catch (IOException e) { throw new RuntimeException(e); } @@ -2858,13 +2920,13 @@ public class WordVectorSerializer { @Override public Pair next() { try { - String word = readString(stream); + String word = ReadHelper.readString(stream); VocabWord element = new VocabWord(1.0, word); element.setIndex(idxCounter.getAndIncrement()); float[] vector = new float[vectorLength]; for (int i = 0; i < vectorLength; i++) { - vector[i] = readFloat(stream); + vector[i] = ReadHelper.readFloat(stream); } return Pair.makePair(element, vector); @@ -2913,7 +2975,7 @@ public class WordVectorSerializer { String[] split = nextLine.split(" "); - VocabWord word = new VocabWord(1.0, decodeB64(split[0])); + VocabWord word = new VocabWord(1.0, ReadHelper.decodeB64(split[0])); word.setIndex(idxCounter.getAndIncrement()); float[] vector = new float[split.length - 1]; @@ -2937,26 +2999,12 @@ public class WordVectorSerializer { } } - public static String encodeB64(String word) { - try { - return "B64:" + Base64.encodeBase64String(word.getBytes("UTF-8")).replaceAll("(\r|\n)", ""); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - public static String decodeB64(String word) { - if (word.startsWith("B64:")) { - String arp = word.replaceFirst("B64:", ""); - try { - return new String(Base64.decodeBase64(arp), "UTF-8"); - } catch (Exception e) { - throw new RuntimeException(e); - } - } else - return word; - } - + /** + * This method saves Word2Vec model to output stream + * + * @param word2Vec Word2Vec + * @param stream OutputStream + */ public static void writeWord2Vec(@NonNull Word2Vec word2Vec, @NonNull OutputStream stream) throws IOException { @@ -2968,6 +3016,13 @@ public class WordVectorSerializer { writeSequenceVectors(vectors, stream); } + /** + * This method restores Word2Vec model from file + * + * @param path String + * @param readExtendedTables booleab + * @return Word2Vec + */ public static Word2Vec readWord2Vec(@NonNull String path, boolean readExtendedTables) throws IOException { @@ -2976,6 +3031,12 @@ public class WordVectorSerializer { return word2Vec; } + /** + * This method saves table of weights to file + * + * @param weightLookupTable WeightLookupTable + * @param file File + */ public static void writeLookupTable(WeightLookupTable weightLookupTable, @NonNull File file) throws IOException { try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file), @@ -3038,7 +3099,7 @@ public class WordVectorSerializer { headerRead = true; weightLookupTable = new InMemoryLookupTable.Builder().cache(vocabCache).vectorLength(layerSize).build(); } else { - String label = decodeB64(tokens[0]); + String label = ReadHelper.decodeB64(tokens[0]); int freq = Integer.parseInt(tokens[1]); int rows = Integer.parseInt(tokens[2]); int cols = Integer.parseInt(tokens[3]); @@ -3071,6 +3132,13 @@ public class WordVectorSerializer { return weightLookupTable; } + /** + * This method loads Word2Vec model from file + * + * @param file File + * @param readExtendedTables boolean + * 
@return Word2Vec + */ public static Word2Vec readWord2Vec(@NonNull File file, boolean readExtendedTables) throws IOException { @@ -3078,6 +3146,13 @@ public class WordVectorSerializer { return word2Vec; } + /** + * This method loads Word2Vec model from input stream + * + * @param stream InputStream + * @param readExtendedTable boolean + * @return Word2Vec + */ public static Word2Vec readWord2Vec(@NonNull InputStream stream, boolean readExtendedTable) throws IOException { SequenceVectors vectors = readSequenceVectors(stream, readExtendedTable); @@ -3087,7 +3162,13 @@ public class WordVectorSerializer { word2Vec.setModelUtils(vectors.getModelUtils()); return word2Vec; } - + + /** + * This method loads FastText model to file + * + * @param vectors FastText + * @param path File + */ public static void writeWordVectors(@NonNull FastText vectors, @NonNull File path) throws IOException { ObjectOutputStream outputStream = null; try { @@ -3106,6 +3187,11 @@ public class WordVectorSerializer { } } + /** + * This method unloads FastText model from file + * + * @param path File + */ public static FastText readWordVectors(File path) { FastText result = null; try { @@ -3124,6 +3210,13 @@ public class WordVectorSerializer { return result; } + /** + * This method prints memory usage to log + * + * @param numWords + * @param vectorLength + * @param numTables + */ public static void printOutProjectedMemoryUse(long numWords, int vectorLength, int numTables) { double memSize = numWords * vectorLength * Nd4j.sizeOfDataType() * numTables; @@ -3144,4 +3237,102 @@ public class WordVectorSerializer { OneTimeLogger.info(log, "Projected memory use for model: [{} {}]", String.format("%.2f", value), sfx); } + + /** + * Helper static methods to read data from input stream. + */ + private static class ReadHelper { + /** + * Read a float from a data input stream Credit to: + * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java + * + * @param is + * @return + * @throws IOException + */ + private static float readFloat(InputStream is) throws IOException { + byte[] bytes = new byte[4]; + is.read(bytes); + return getFloat(bytes); + } + + /** + * Read a string from a data input stream Credit to: + * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java + * + * @param b + * @return + * @throws IOException + */ + private static float getFloat(byte[] b) { + int accum = 0; + accum = accum | (b[0] & 0xff) << 0; + accum = accum | (b[1] & 0xff) << 8; + accum = accum | (b[2] & 0xff) << 16; + accum = accum | (b[3] & 0xff) << 24; + return Float.intBitsToFloat(accum); + } + + /** + * Read a string from a data input stream Credit to: + * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java + * + * @param dis + * @return + * @throws IOException + */ + private static String readString(DataInputStream dis) throws IOException { + byte[] bytes = new byte[MAX_SIZE]; + byte b = dis.readByte(); + int i = -1; + StringBuilder sb = new StringBuilder(); + while (b != 32 && b != 10) { + i++; + bytes[i] = b; + b = dis.readByte(); + if (i == 49) { + sb.append(new String(bytes, "UTF-8")); + i = -1; + bytes = new byte[MAX_SIZE]; + } + } + sb.append(new String(bytes, 0, i + 1, "UTF-8")); + return sb.toString(); + } + + private static final String B64 = "B64:"; + + /** + * Encode input string + * + * @param word String + * @return String + */ + private static String encodeB64(String word) { + try { + return B64 + 
Base64.encodeBase64String(word.getBytes("UTF-8")).replaceAll("(\r|\n)", ""); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + /** + * Encode input string + * + * @param word String + * @return String + */ + + private static String decodeB64(String word) { + if (word.startsWith(B64)) { + String arp = word.replaceFirst(B64, ""); + try { + return new String(Base64.decodeBase64(arp), "UTF-8"); + } catch (Exception e) { + throw new RuntimeException(e); + } + } else + return word; + } + } } From d3253aff3f1778654fea589b3a96b133ce134137 Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 2 Sep 2019 20:01:13 +0300 Subject: [PATCH 06/19] dedicated lock for getCudaCublasHandle Signed-off-by: raver119 --- .../jita/handler/impl/CudaZeroHandler.java | 21 ++++++++++++------- .../java/org/nd4j/nativeblas/Nd4jCpu.java | 16 ++++++++++++-- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java index 106ac9c3a..9b8c1012c 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java @@ -1180,16 +1180,23 @@ public class CudaZeroHandler implements MemoryHandler { return getCudaContext(); } + // + private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); - - protected synchronized cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext lc) { + protected cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext lc) { val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); - if (cublasHandles.get(deviceId) == null) { - cublasHandles.remove(deviceId); - cublasHandles.add(deviceId, new cublasHandle_t(nativeOps.lcBlasHandle(lc))); - } + try { + lock.writeLock().lock(); - return cublasHandles.get(deviceId); + if (cublasHandles.get(deviceId) == null) { + cublasHandles.remove(deviceId); + cublasHandles.add(deviceId, new cublasHandle_t(nativeOps.lcBlasHandle(lc))); + } + + return cublasHandles.get(deviceId); + } finally { + lock.writeLock().unlock(); + } } /** diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 8e150f618..9554a94e9 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -16985,8 +16985,20 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #endif /** - * Returns a batched matrix tensor with new batched diagonal values. 
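The CudaZeroHandler hunk above replaces a synchronized method with a dedicated ReentrantReadWriteLock around the lazily created per-device cuBLAS handles; the patch simply takes the write lock on every call, which is correct and keeps the change small. A hedged sketch of the same idea with hypothetical names, extended with the common read-lock fast path for the case where the handle already exists:

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Supplier;

class HandleCache<T> {
    private final List<T> handles = new ArrayList<>();
    private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();

    HandleCache(int numDevices) {
        for (int i = 0; i < numDevices; i++)
            handles.add(null);
    }

    T getOrCreate(int deviceId, Supplier<T> factory) {
        lock.readLock().lock();
        try {
            T h = handles.get(deviceId);
            if (h != null)
                return h;                       // fast path: handle already created
        } finally {
            lock.readLock().unlock();
        }
        lock.writeLock().lock();
        try {
            if (handles.get(deviceId) == null)  // re-check: another thread may have won the race
                handles.set(deviceId, factory.get());
            return handles.get(deviceId);
        } finally {
            lock.writeLock().unlock();
        }
    }
}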
- */ + * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array + * + * Input arrays: + * input: input array, considered as batch of matrices + * diagonal: array containing elements to be inserted into input array, + * following rank condition should be satisfied: diagonal_rank = input_rank - 1, + * the shapes of diagonal and input arrays must be equal except last dimension of input array, + * for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C], + * also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions + * that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2]) + * + * Output array: + * has the same shape as input, corresponding diagonal elements are substituted + */ // #if NOT_EXCLUDED(OP_matrix_set_diag) @Namespace("nd4j::ops") public static class matrix_set_diag extends DeclarableOp { static { Loader.load(); } From ba269a26ab44916da7a7e6d4fdbd382c5fe06952 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Tue, 3 Sep 2019 10:48:59 +1000 Subject: [PATCH 07/19] Small fixes (#223) Signed-off-by: AlexDBlack --- .../org/deeplearning4j/spark/TestKryo.java | 4 +- .../multilayer/TestSparkDl4jMultiLayer.java | 2 +- ...TestSparkMultiLayerParameterAveraging.java | 42 +++++++------------ 3 files changed, 19 insertions(+), 29 deletions(-) diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java index e6688a215..8c5188b70 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java @@ -17,7 +17,6 @@ package org.deeplearning4j.spark; import org.apache.spark.serializer.SerializerInstance; -import org.deeplearning4j.eval.*; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -28,6 +27,9 @@ import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; import org.junit.Test; +import org.nd4j.evaluation.IEvaluation; +import org.nd4j.evaluation.classification.*; +import org.nd4j.evaluation.regression.RegressionEvaluation; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Adam; diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java index f8fe1f4f0..ecf9b937b 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java @@ -19,7 +19,6 @@ package org.deeplearning4j.spark.impl.multilayer; import lombok.extern.slf4j.Slf4j; import org.apache.spark.api.java.JavaRDD; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import 
org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -30,6 +29,7 @@ import org.deeplearning4j.spark.BaseSparkTest; import org.deeplearning4j.spark.api.TrainingMaster; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; import org.junit.Test; +import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java index ed56af9ee..abfd39060 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java @@ -29,15 +29,13 @@ import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.util.MLUtils; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.deeplearning4j.eval.Evaluation; -import org.deeplearning4j.eval.ROC; -import org.deeplearning4j.eval.ROCMultiClass; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BatchNormalization; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution; @@ -56,6 +54,9 @@ import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.evaluation.classification.Evaluation; +import org.nd4j.evaluation.classification.ROC; +import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -63,6 +64,7 @@ import org.nd4j.linalg.dataset.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.io.ClassPathResource; +import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.learning.config.RmsProp; @@ -70,7 +72,6 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import scala.Tuple2; import java.io.File; -import java.nio.file.Files; import java.nio.file.Path; import java.util.*; @@ -121,11 +122,6 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0)); MultiLayerNetwork network2 = master.fitLabeledPoint(data); - Evaluation evaluation = new 
Evaluation(); - evaluation.eval(d.getLabels(), network2.output(d.getFeatures())); - System.out.println(evaluation.stats()); - - } @@ -137,20 +133,15 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { .getAbsolutePath()) .toJavaRDD().map(new TestFn()); - DataSet d = new IrisDataSetIterator(150, 150).next(); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123) - .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) - .miniBatch(true).maxNumLineSearchIterations(10) - .list().layer(0, - new DenseLayer.Builder().nIn(4).nOut(100) - .weightInit(WeightInit.XAVIER) - .activation(Activation.RELU) - .build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).nIn(100).nOut(3) - .activation(Activation.SOFTMAX) - .weightInit(WeightInit.XAVIER).build()) + .updater(new Adam(1e-6)) + .weightInit(WeightInit.XAVIER) + .list() + .layer(new BatchNormalization.Builder().nIn(4).nOut(4).build()) + .layer(new DenseLayer.Builder().nIn(4).nOut(32).activation(Activation.RELU).build()) + .layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(32).nOut(3) + .activation(Activation.SOFTMAX).build()) .build(); @@ -161,10 +152,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { SparkDl4jMultiLayer master = new SparkDl4jMultiLayer(sc, getBasicConf(), new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0)); - MultiLayerNetwork network2 = master.fitLabeledPoint(data); - Evaluation evaluation = new Evaluation(); - evaluation.eval(d.getLabels(), network2.output(d.getFeatures())); - System.out.println(evaluation.stats()); + master.fitLabeledPoint(data); } @Test(timeout = 120000L) @@ -465,8 +453,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { tempDirF.deleteOnExit(); int dataSetObjSize = 1; - int batchSizePerExecutor = 25; - int numSplits = 10; + int batchSizePerExecutor = 16; + int numSplits = 5; int averagingFrequency = 3; int totalExamples = numExecutors() * batchSizePerExecutor * numSplits * averagingFrequency; DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize, totalExamples, false); From 364a6e1a2a2341d2b818d0619f45621416f37bb7 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Tue, 3 Sep 2019 13:35:02 +1000 Subject: [PATCH 08/19] ELU DL4J fixes (#224) Signed-off-by: AlexDBlack --- .../DifferentialFunctionFactory.java | 4 ++-- .../activations/impl/ActivationELU.java | 16 +++------------ .../ops/impl/transforms/gradient/EluBp.java | 3 ++- .../api/ops/impl/transforms/strict/ELU.java | 20 +++++++++++++------ 4 files changed, 21 insertions(+), 22 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 49e760961..ac017beef 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -1562,8 +1562,8 @@ public class DifferentialFunctionFactory { } - public SDVariable eluBp(SDVariable in, SDVariable epsilon) { - return new EluBp(sameDiff(), in, epsilon).outputVariable(); + public SDVariable eluBp(SDVariable in, SDVariable epsilon, double alpha) { + return new EluBp(sameDiff(), 
in, epsilon, alpha).outputVariable(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java index b7ac3887c..b714b1f06 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java @@ -18,14 +18,12 @@ package org.nd4j.linalg.activations.impl; import lombok.EqualsAndHashCode; import lombok.Getter; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp; -import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp; import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.indexing.BooleanIndexing; -import org.nd4j.linalg.indexing.conditions.Conditions; +import org.nd4j.linalg.primitives.Pair; /** * f(x) = alpha * (exp(x) - 1.0); x < 0 @@ -55,15 +53,7 @@ public class ActivationELU extends BaseActivationFunction { */ @Override public INDArray getActivation(INDArray in, boolean training) { - // no support in ELU native to override alpha - if (this.alpha != 1.00) { - INDArray alphaMultiple = Nd4j.getExecutioner().exec(new ELU(in.dup()))[0]; - alphaMultiple.muli(alpha); - BooleanIndexing.replaceWhere(in, alphaMultiple, Conditions.lessThan(0)); - } else { - Nd4j.getExecutioner().execAndReturn(new ELU(in)); - } - return in; + return Nd4j.exec(new ELU(in, in, alpha))[0]; } /* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java index f4624a6ee..0e2a4c6b9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java @@ -33,8 +33,9 @@ public class EluBp extends DynamicCustomOp { public EluBp(){ } - public EluBp(SameDiff sd, SDVariable input, SDVariable gradient){ + public EluBp(SameDiff sd, SDVariable input, SDVariable gradient, double alpha){ super(sd, new SDVariable[]{input, gradient}); + addTArgument(alpha); } public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java index a144e868b..6923639fd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java @@ -23,13 +23,9 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.NodeDef; import 
java.util.Collections; import java.util.List; -import java.util.Map; /** * ELU: Exponential Linear Unit (alpha=1.0)
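For reference, the formulas the new alpha argument feeds into are f(x) = alpha * (exp(x) - 1) for x < 0 and f(x) = x otherwise, with gradient alpha * exp(x) respectively 1; they can be sanity-checked with a few lines of plain Java (illustration only, not the ND4J op itself):

public class EluCheck {

    // forward: matches the ActivationELU javadoc above
    static double elu(double x, double alpha) {
        return x < 0 ? alpha * (Math.exp(x) - 1.0) : x;
    }

    // backward: the factor EluBp multiplies into the incoming gradient (dL/dIn = dL/dOut * f'(x))
    static double eluGrad(double x, double alpha) {
        return x < 0 ? alpha * Math.exp(x) : 1.0;
    }

    public static void main(String[] args) {
        double alpha = 1.0;                                  // DEFAULT_ALPHA in the patch
        for (double x : new double[] {-2.0, -0.5, 0.0, 1.5})
            System.out.printf("x=%5.2f  elu=%8.4f  grad=%8.4f%n", x, elu(x, alpha), eluGrad(x, alpha));
    }
}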
@@ -41,19 +37,31 @@ import java.util.Map; * @author Alex Black */ public class ELU extends DynamicCustomOp { + public static final double DEFAULT_ALPHA = 1.0; + + protected double alpha; + public ELU(SameDiff sameDiff, SDVariable i_v) { super(sameDiff, new SDVariable[]{i_v}); + this.alpha = DEFAULT_ALPHA; + addTArgument(alpha); } public ELU() { } public ELU(INDArray x, INDArray z) { + this(x, z, DEFAULT_ALPHA); + } + + public ELU(INDArray x, INDArray z, double alpha) { super(null, wrapOrNull(x), wrapOrNull(z)); + this.alpha = alpha; + addTArgument(alpha); } public ELU(INDArray x) { - this(x, null); + this(x, null, DEFAULT_ALPHA); } @Override @@ -75,7 +83,7 @@ public class ELU extends DynamicCustomOp { public List doDiff(List i_v) { //ELU: e^x-1 if x<0, x otherwise //dL/dIn = dL/Out * dOut/dIn - return Collections.singletonList(f().eluBp(arg(), i_v.get(0))); + return Collections.singletonList(f().eluBp(arg(), i_v.get(0), alpha)); } @Override From c64b340975ae4171dae0b1f9741f2fc62a13449c Mon Sep 17 00:00:00 2001 From: Robert Altena Date: Tue, 3 Sep 2019 13:06:42 +0900 Subject: [PATCH 09/19] javadoc (#225) Signed-off-by: Robert Altena --- .../nd4j/linalg/api/ndarray/BaseNDArray.java | 242 ------------------ .../linalg/api/ndarray/BaseSparseNDArray.java | 2 - .../org/nd4j/linalg/api/ndarray/INDArray.java | 20 +- 3 files changed, 15 insertions(+), 249 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 0d0af0788..ac642872c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -1195,7 +1195,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { } } - return this; } @@ -3089,12 +3088,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return mmuli(other, result); } - /** - * in place (element wise) division of two matrices - * - * @param other the second ndarray to divide - * @return the result of the divide - */ @Override public INDArray div(INDArray other) { if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { @@ -3104,25 +3097,12 @@ public abstract class BaseNDArray implements INDArray, Iterable { } } - /** - * copy (element wise) division of two matrices - * - * @param other the second ndarray to divide - * @param result the result ndarray - * @return the result of the divide - */ @Override public INDArray div(INDArray other, INDArray result) { validateNumericalArray("div", true); return divi(other, result); } - /** - * copy (element wise) multiplication of two matrices - * - * @param other the second ndarray to multiply - * @return the result of the addition - */ @Override public INDArray mul(INDArray other) { validateNumericalArray("mul", false); @@ -3134,24 +3114,11 @@ public abstract class BaseNDArray implements INDArray, Iterable { } } - /** - * copy (element wise) multiplication of two matrices - * - * @param other the second ndarray to multiply - * @param result the result ndarray - * @return the result of the multiplication - */ @Override public INDArray mul(INDArray other, INDArray result) { return muli(other, result); } - /** - * copy subtraction of two matrices - * - * @param other the second ndarray to subtract - * @return the result of the addition - */ @Override public INDArray 
sub(INDArray other) { validateNumericalArray("sub", false); @@ -3162,24 +3129,11 @@ public abstract class BaseNDArray implements INDArray, Iterable { } } - /** - * copy subtraction of two matrices - * - * @param other the second ndarray to subtract - * @param result the result ndarray - * @return the result of the subtraction - */ @Override public INDArray sub(INDArray other, INDArray result) { return subi(other, result); } - /** - * copy addition of two matrices - * - * @param other the second ndarray to add - * @return the result of the addition - */ @Override public INDArray add(INDArray other) { validateNumericalArray("add", false); @@ -3190,65 +3144,29 @@ public abstract class BaseNDArray implements INDArray, Iterable { } } - /** - * copy addition of two matrices - * - * @param other the second ndarray to add - * @param result the result ndarray - * @return the result of the addition - */ @Override public INDArray add(INDArray other, INDArray result) { validateNumericalArray("add", false); return addi(other, result); } - - /** - * Perform an copy matrix multiplication - * - * @param other the other matrix to perform matrix multiply with - * @param transpose the transpose status of each ndarray - * @return the result of the matrix multiplication - */ @Override public INDArray mmuli(INDArray other, MMulTranspose transpose) { validateNumericalArray("mmuli", false); return dup().mmuli(other, this,transpose); } - /** - * Perform an copy matrix multiplication - * - * @param other the other matrix to perform matrix multiply with - * @return the result of the matrix multiplication - */ @Override public INDArray mmuli(INDArray other) { validateNumericalArray("mmuli", false); return dup().mmuli(other, this); } - - /** - * Perform an in place matrix multiplication - * - * @param other the other matrix to perform matrix multiply with - * @param result the result ndarray - * @return the result of the matrix multiplication - */ @Override public INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose) { return transpose.exec(this, other, result); } - /** - * Perform an copy matrix multiplication - * - * @param other the other matrix to perform matrix multiply with - * @param result the result ndarray - * @return the result of the matrix multiplication - */ @Override public INDArray mmuli(INDArray other, INDArray result) { validateNumericalArray("mmuli", false); @@ -3347,24 +3265,11 @@ public abstract class BaseNDArray implements INDArray, Iterable { return Nd4j.create(shape, stride); } - /** - * in place (element wise) division of two matrices - * - * @param other the second ndarray to divide - * @return the result of the divide - */ @Override public INDArray divi(INDArray other) { return divi(other, this); } - /** - * in place (element wise) division of two matrices - * - * @param other the second ndarray to divide - * @param result the result ndarray - * @return the result of the divide - */ @Override public INDArray divi(INDArray other, INDArray result) { validateNumericalArray("divi", false); @@ -3373,24 +3278,11 @@ public abstract class BaseNDArray implements INDArray, Iterable { return result; } - /** - * in place (element wise) multiplication of two matrices - * - * @param other the second ndarray to multiply - * @return the result of the multiplication - */ @Override public INDArray muli(INDArray other) { return muli(other, this); } - /** - * in place (element wise) multiplication of two matrices - * - * @param other the second ndarray to multiply - * @param result the result 
ndarray - * @return the result of the multiplication - */ @Override public INDArray muli(INDArray other, INDArray result) { validateNumericalArray("muli", false); @@ -3399,12 +3291,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return result; } - /** - * in place subtraction of two matrices - * - * @param other the second ndarray to subtract - * @return the result of the addition - */ @Override public INDArray subi(INDArray other) { return subi(other, this); @@ -3425,24 +3311,11 @@ public abstract class BaseNDArray implements INDArray, Iterable { return result; } - /** - * in place addition of two matrices - * - * @param other the second ndarray to add - * @return the result of the addition - */ @Override public INDArray addi(INDArray other) { return addi(other, this); } - /** - * in place addition of two matrices - * - * @param other the second ndarray to add - * @param result the result ndarray - * @return the result of the addition - */ @Override public INDArray addi(INDArray other, INDArray result) { validateNumericalArray("addi", false); @@ -3451,25 +3324,12 @@ public abstract class BaseNDArray implements INDArray, Iterable { return result; } - /** - * Returns the normmax along the specified dimension - * - * @param dimension the dimension to getScalar the norm1 along - * @param keepDims whether to keep reduced dimensions as dimensions of size 1 - * @return the norm1 along the specified dimension - */ @Override public INDArray normmax(boolean keepDims, int... dimension) { validateNumericalArray("normmax", false); return Nd4j.getExecutioner().exec(new NormMax(this, keepDims, dimension)); } - /** - * Returns the normmax along the specified dimension - * - * @param dimension the dimension to getScalar the norm1 along - * @return the norm1 along the specified dimension - */ @Override public INDArray normmax(int... dimension) { return normmax(false, dimension); @@ -4071,49 +3931,23 @@ public abstract class BaseNDArray implements INDArray, Iterable { return reshape(Nd4j.order(), shape); } - /** - * Returns the product along a given dimension - * - * @param dimension the dimension to getScalar the product along - * @param keepDims whether to keep reduced dimensions as dimensions of size 1 - * @return the product along the specified dimension - */ @Override public INDArray prod(boolean keepDims, int... dimension) { validateNumericalArray("prod", false); return Nd4j.getExecutioner().exec(new Prod(this, keepDims, dimension)); } - /** - * Returns the product along a given dimension - * - * @param dimension the dimension to getScalar the product along - * @return the product along the specified dimension - */ @Override public INDArray prod(int... dimension) { return prod(false, dimension); } - /** - * Returns the overall mean of this ndarray - * - * @param dimension the dimension to getScalar the mean along - * @param keepDims whether to keep reduced dimensions as dimensions of size 1 - * @return the mean along the specified dimension of this ndarray - */ @Override public INDArray mean(boolean keepDims, int... dimension) { validateNumericalArray("mean", false); return Nd4j.getExecutioner().exec(new Mean(this, keepDims, dimension)); } - /** - * Returns the overall mean of this ndarray - * - * @param dimension the dimension to getScalar the mean along - * @return the mean along the specified dimension of this ndarray - */ @Override public INDArray mean(int... 
dimension) { return mean(false, dimension); @@ -4136,50 +3970,24 @@ public abstract class BaseNDArray implements INDArray, Iterable { return mean(result, false, dimension); } - /** - * Returns the overall variance of this ndarray - * - * @param dimension the dimension to getScalar the mean along - * @return the mean along the specified dimension of this ndarray - */ @Override public INDArray var(int... dimension) { validateNumericalArray("var", false); return Nd4j.getExecutioner().exec(new Variance(this, dimension)); } - /** - * Returns the overall variance of this ndarray - * - * @param biasCorrected boolean on whether to apply corrected bias - * @param dimension the dimension to getScalar the mean along - * @return the mean along the specified dimension of this ndarray - */ @Override public INDArray var(boolean biasCorrected, int... dimension) { validateNumericalArray("var", false); return Nd4j.getExecutioner().exec(new Variance(this, biasCorrected, dimension)); } - /** - * Returns the overall max of this ndarray - * - * @param dimension the dimension to getScalar the mean along - * @param keepDims whether to keep reduced dimensions as dimensions of size 1 - * @return the mean along the specified dimension of this ndarray - */ @Override public INDArray max(boolean keepDims, int... dimension) { validateNumericalArray("max", false); return Nd4j.getExecutioner().exec(new Max(this, keepDims, dimension)); } - /** - * Returns the overall max of this ndarray - * - * @param dimension the dimension to getScalar the mean along - * @return the mean along the specified dimension of this ndarray - */ @Override public INDArray max(int... dimension) { return max(false, dimension); @@ -4191,25 +3999,12 @@ public abstract class BaseNDArray implements INDArray, Iterable { return Nd4j.getExecutioner().exec(new AMax(this, dimension)); } - /** - * Returns the overall min of this ndarray - * - * @param dimension the dimension to getScalar the mean along - * @param keepDims whether to keep reduced dimensions as dimensions of size 1 - * @return the mean along the specified dimension of this ndarray - */ @Override public INDArray min(boolean keepDims, int... dimension) { validateNumericalArray("min", false); return Nd4j.getExecutioner().exec(new Min(this, keepDims, dimension)); } - /** - * Returns the overall min of this ndarray - * - * @param dimension the dimension to getScalar the mean along - * @return the mean along the specified dimension of this ndarray - */ @Override public INDArray min(int... dimension) { return min(false, dimension); @@ -4290,39 +4085,17 @@ public abstract class BaseNDArray implements INDArray, Iterable { return sum(result, false, dimension); } - - /** - * Returns the norm1 along the specified dimension - * - * @param dimension the dimension to getScalar the norm1 along - * @return the norm1 along the specified dimension - */ @Override public INDArray norm1(int... dimension) { return norm1(false, dimension); } - - /** - * Returns the norm1 along the specified dimension - * - * @param dimension the dimension to getScalar the norm1 along - * @param keepDims whether to keep reduced dimensions as dimensions of size 1 - * @return the norm1 along the specified dimension - */ @Override public INDArray norm1(boolean keepDims, int... 
dimension) { validateNumericalArray("norm1", false); return Nd4j.getExecutioner().exec(new Norm1(this, keepDims, dimension)); } - - /** - * Standard deviation of an ndarray along a dimension - * - * @param dimension the dimension to getScalar the std along - * @return the standard deviation along a particular dimension - */ @Override public INDArray std(int... dimension) { return std(true, dimension); @@ -4345,32 +4118,17 @@ public abstract class BaseNDArray implements INDArray, Iterable { return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected)).getDouble(0); } - /** - * Returns the norm2 along the specified dimension - * - * @param dimension the dimension to getScalar the norm2 along - * @param keepDims whether to keep reduced dimensions as dimensions of size 1 - * @return the norm2 along the specified dimension - */ @Override public INDArray norm2(boolean keepDims, int... dimension) { validateNumericalArray("norm2", false); return Nd4j.getExecutioner().exec(new Norm2(this, keepDims, dimension)); } - /** - * Returns the norm2 along the specified dimension - * - * @param dimension the dimension to getScalar the norm2 along - * @return the norm2 along the specified dimension - */ @Override public INDArray norm2(int... dimension) { return norm2(false, dimension); } - - /** * Number of columns (shape[1]), throws an exception when * called when not 2d diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java index 6a112b868..1e0772494 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java @@ -1232,8 +1232,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray { return null; } - - @Override public INDArray normmax(boolean keepDims, int... 
dimension) { return null; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java index b842797f9..47e259b94 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java @@ -1404,7 +1404,13 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray add(INDArray other, INDArray result); - + /** + * Perform an copy matrix multiplication + * + * @param other the other matrix to perform matrix multiply with + * @param transpose the transpose status of each ndarray + * @return the result of the matrix multiplication + */ INDArray mmuli(INDArray other, MMulTranspose transpose); /** @@ -1415,7 +1421,13 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray mmuli(INDArray other); - + /** + * Perform an in place matrix multiplication + * + * @param other the other matrix to perform matrix multiply with + * @param result the result ndarray + * @return the result of the matrix multiplication + */ INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose); /** @@ -1497,7 +1509,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray addi(INDArray other, INDArray result); - /** * Returns the max norm (aka infinity norm, equal to the maximum absolute value) along the specified dimension(s) * @@ -1506,7 +1517,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray normmax(int... dimension); - /** * Returns the max norm (aka infinity norm, equal to the maximum absolute value) along the specified dimension(s) * @@ -1585,7 +1595,7 @@ public interface INDArray extends Serializable, AutoCloseable { /** * Calculate the standard deviation for the entire array * - * @return + * @return standard deviation */ Number stdNumber(); From f076a8b285e8523b68af9c47fce6434d4890ff40 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Tue, 3 Sep 2019 14:17:53 +1000 Subject: [PATCH 10/19] Small test compilation fix (#226) Signed-off-by: AlexDBlack --- .../models/WordVectorSerializerTest.java | 12 ++++++------ .../embeddings/loader/WordVectorSerializer.java | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java index f4dd1a6c5..69eae7307 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java @@ -833,14 +833,14 @@ public class WordVectorSerializerTest extends BaseDL4JTest { public void testB64_1() throws Exception { String wordA = "night"; String wordB = "night day"; - String encA = WordVectorSerializer.encodeB64(wordA); - String encB = WordVectorSerializer.encodeB64(wordB); + String encA = WordVectorSerializer.ReadHelper.encodeB64(wordA); + String encB = WordVectorSerializer.ReadHelper.encodeB64(wordB); - assertEquals(wordA, WordVectorSerializer.decodeB64(encA)); - assertEquals(wordB, 
WordVectorSerializer.decodeB64(encB)); + assertEquals(wordA, WordVectorSerializer.ReadHelper.decodeB64(encA)); + assertEquals(wordB, WordVectorSerializer.ReadHelper.decodeB64(encB)); - assertEquals(wordA, WordVectorSerializer.decodeB64(wordA)); - assertEquals(wordB, WordVectorSerializer.decodeB64(wordB)); + assertEquals(wordA, WordVectorSerializer.ReadHelper.decodeB64(wordA)); + assertEquals(wordB, WordVectorSerializer.ReadHelper.decodeB64(wordB)); } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java index 210ab7686..80ce0bf34 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java @@ -3241,7 +3241,7 @@ public class WordVectorSerializer { /** * Helper static methods to read data from input stream. */ - private static class ReadHelper { + public static class ReadHelper { /** * Read a float from a data input stream Credit to: * https://github.com/NLPchina/Word2VEC_java/blob/master/src/com/ansj/vec/Word2VEC.java @@ -3308,7 +3308,7 @@ public class WordVectorSerializer { * @param word String * @return String */ - private static String encodeB64(String word) { + public static String encodeB64(String word) { try { return B64 + Base64.encodeBase64String(word.getBytes("UTF-8")).replaceAll("(\r|\n)", ""); } catch (Exception e) { @@ -3323,7 +3323,7 @@ public class WordVectorSerializer { * @return String */ - private static String decodeB64(String word) { + public static String decodeB64(String word) { if (word.startsWith(B64)) { String arp = word.replaceFirst(B64, ""); try { From 5be43e725395137755b696998d526d3b153650fa Mon Sep 17 00:00:00 2001 From: Alex Black Date: Tue, 3 Sep 2019 18:54:19 +1000 Subject: [PATCH 11/19] #8182 remove spark version suffix (#227) Signed-off-by: AlexDBlack --- .../datavec-spark-inference-client/pom.xml | 2 +- .../datavec-spark-inference-server/pom.xml | 2 +- datavec/datavec-spark/pom.xml | 2 +- .../deeplearning4j-scaleout/deeplearning4j-aws/pom.xml | 2 +- .../spark/dl4j-spark-nlp-java8/pom.xml | 2 +- .../deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml | 2 +- .../spark/dl4j-spark-parameterserver/pom.xml | 2 +- .../deeplearning4j-scaleout/spark/dl4j-spark/pom.xml | 2 +- deeplearning4j/deeplearning4j-scaleout/spark/pom.xml | 4 ++-- 9 files changed, 10 insertions(+), 10 deletions(-) diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml index db110703b..076c22ab9 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml @@ -38,7 +38,7 @@ org.datavec datavec-spark-inference-server_2.11 - 1.0.0_spark_2-SNAPSHOT + 1.0.0-SNAPSHOT test diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml index 605b13b70..8bef216a7 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml +++ 
b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml @@ -25,7 +25,7 @@ datavec-spark-inference-server_2.11 jar - 1.0.0_spark_2-SNAPSHOT + 1.0.0-SNAPSHOT datavec-spark-inference-server diff --git a/datavec/datavec-spark/pom.xml b/datavec/datavec-spark/pom.xml index 05c505cac..f7143c6ea 100644 --- a/datavec/datavec-spark/pom.xml +++ b/datavec/datavec-spark/pom.xml @@ -24,7 +24,7 @@ 4.0.0 - 1.0.0_spark_2-SNAPSHOT + 1.0.0-SNAPSHOT datavec-spark_2.11 diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml index 0b6b05c26..7c9967ef8 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml @@ -24,7 +24,7 @@ deeplearning4j-aws_2.11 DeepLearning4j-AWS - 1.0.0_spark_2-SNAPSHOT + 1.0.0-SNAPSHOT diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml index 3fded3e4a..8a19b3b68 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml @@ -18,7 +18,7 @@ spark_2.11 org.deeplearning4j - 1.0.0_spark_2-SNAPSHOT + 1.0.0-SNAPSHOT 4.0.0 diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml index a5aff014e..16c4ac298 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml @@ -18,7 +18,7 @@ spark_2.11 org.deeplearning4j - 1.0.0_spark_2-SNAPSHOT + 1.0.0-SNAPSHOT 4.0.0 dl4j-spark-nlp_2.11 diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml index d8f425286..9192bb877 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml @@ -19,7 +19,7 @@ spark_2.11 org.deeplearning4j - 1.0.0_spark_2-SNAPSHOT + 1.0.0-SNAPSHOT 4.0.0 diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml index 8b31872c5..d84947f1e 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml @@ -18,7 +18,7 @@ spark_2.11 org.deeplearning4j - 1.0.0_spark_2-SNAPSHOT + 1.0.0-SNAPSHOT 4.0.0 dl4j-spark_2.11 diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml index f753fefae..bd7226b0e 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml @@ -22,7 +22,7 @@ 4.0.0 spark_2.11 - 1.0.0_spark_2-SNAPSHOT + 1.0.0-SNAPSHOT pom Spark parent @@ -36,7 +36,7 @@ UTF-8 UTF-8 - 1.0.0_spark_2-SNAPSHOT + 1.0.0-SNAPSHOT 2.1.0 From dddc8a11435dd2c0f2c16dfef7e775ab2261d0b7 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 3 Sep 2019 22:00:38 +0300 Subject: [PATCH 12/19] [WIP] Thread safety (#229) * sync after cublas*gemm Signed-off-by: raver119 * mutex for CublasHelper Signed-off-by: raver119 * don't store cublasHandle in LaunchContext, it's per-device anyway Signed-off-by: raver119 * some printout Signed-off-by: 
raver119 * check for field instead Signed-off-by: raver119 * pew-pew Signed-off-by: raver119 * don't release ContextBuffers until device changed Signed-off-by: raver119 * small tweak Signed-off-by: raver119 * some logging in sgemm Signed-off-by: raver119 * stream sync Signed-off-by: raver119 * some more logging Signed-off-by: raver119 * some more error checks Signed-off-by: raver119 * one fancy test Signed-off-by: raver119 * one fancy test Signed-off-by: raver119 * minor AffinityManager fix Signed-off-by: raver119 * cudaEvent error logging improvement Signed-off-by: raver119 * ConstantHelper thread safety Signed-off-by: raver119 * - minor corrections in ConstantTadHelper Signed-off-by: Yurii * ConstantShapeHelper thread safety Signed-off-by: raver119 * ConstantTadHelper.cu updated Signed-off-by: raver119 * logging off Signed-off-by: raver119 * logging off Signed-off-by: raver119 --- .../java/org/deeplearning4j/RandomTests.java | 63 +++++++++++++++++++ libnd4j/include/array/ConstantHolder.h | 4 ++ libnd4j/include/array/impl/ConstantHolder.cpp | 4 ++ .../include/execution/cuda/AffinityManager.cu | 21 ++++--- .../include/execution/cuda/ContextBuffers.cu | 21 +++++-- .../include/execution/cuda/LaunchContext.cu | 13 ++-- libnd4j/include/helpers/ConstantHelper.h | 3 +- libnd4j/include/helpers/ConstantShapeHelper.h | 8 +-- libnd4j/include/helpers/ConstantTadHelper.h | 10 +-- .../include/helpers/cpu/ConstantHelper.cpp | 27 ++++++-- .../helpers/cpu/ConstantShapeHelper.cpp | 12 ++-- .../include/helpers/cpu/ConstantTadHelper.cpp | 12 ++-- libnd4j/include/helpers/cublasHelper.h | 2 + .../include/helpers/cuda/ConstantHelper.cu | 31 ++++++--- .../helpers/cuda/ConstantShapeHelper.cu | 12 ++-- .../include/helpers/cuda/ConstantTadHelper.cu | 14 ++--- .../include/helpers/cuda_off/cublasHelper.cu | 9 ++- .../allocator/pointers/cuda/cudaEvent_t.java | 11 ++-- .../linalg/jcublas/blas/JcublasLevel3.java | 16 ++++- .../ops/executioner/CudaExecutioner.java | 12 ++++ 20 files changed, 227 insertions(+), 78 deletions(-) create mode 100644 deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java new file mode 100644 index 000000000..8f727fdf9 --- /dev/null +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java @@ -0,0 +1,63 @@ +package org.deeplearning4j; + +import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; +import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; +import org.deeplearning4j.nn.api.OptimizationAlgorithm; +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.junit.Ignore; +import org.junit.Test; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.learning.config.RmsProp; +import org.nd4j.linalg.lossfunctions.LossFunctions; + +import java.util.concurrent.CountDownLatch; + +@Ignore +public class RandomTests { + + @Test + public void testReproduce() throws Exception { + + final MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp()) + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() + .layer(0, new 
org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(10) + .activation(Activation.TANH).build()) + .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( + LossFunctions.LossFunction.MCXENT).nIn(10).nOut(10) + .activation(Activation.SOFTMAX).build()) + .build(); + + for (int e = 0; e < 3; e++) { + + int nThreads = 10; + final CountDownLatch l = new CountDownLatch(nThreads); + for (int i = 0; i < nThreads; i++) { + final int j = i; + Thread t = new Thread(new Runnable() { + @Override + public void run() { + try { + MultiLayerNetwork net = new MultiLayerNetwork(conf.clone()); + net.init(); + DataSetIterator iter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(10, false, 12345), 100); + net.fit(iter); + } catch (Throwable t) { + System.out.println("Thread failed: " + j); + t.printStackTrace(); + } finally { + l.countDown(); + } + } + }); + t.start(); + } + + l.await(); + System.out.println("DONE " + e + "\n"); + } + } +} diff --git a/libnd4j/include/array/ConstantHolder.h b/libnd4j/include/array/ConstantHolder.h index 89be279e4..137d26f29 100644 --- a/libnd4j/include/array/ConstantHolder.h +++ b/libnd4j/include/array/ConstantHolder.h @@ -24,11 +24,13 @@ #include #include #include +#include namespace nd4j { class ConstantHolder { private: int _deviceId = 0; + std::mutex _mutex; std::map _buffers; public: @@ -53,6 +55,8 @@ namespace nd4j { template ConstantDataBuffer* getConstantDataBuffer(); + + std::mutex* mutex(); }; } diff --git a/libnd4j/include/array/impl/ConstantHolder.cpp b/libnd4j/include/array/impl/ConstantHolder.cpp index 92cc9df23..5913d57a9 100644 --- a/libnd4j/include/array/impl/ConstantHolder.cpp +++ b/libnd4j/include/array/impl/ConstantHolder.cpp @@ -16,6 +16,10 @@ namespace nd4j { return _buffers.count(dataType) > 0; } + std::mutex* ConstantHolder::mutex() { + return &_mutex; + } + template bool ConstantHolder::hasBuffer() { return hasBuffer(DataTypeUtils::fromT()); diff --git a/libnd4j/include/execution/cuda/AffinityManager.cu b/libnd4j/include/execution/cuda/AffinityManager.cu index 1f028b011..d28c0d6d0 100644 --- a/libnd4j/include/execution/cuda/AffinityManager.cu +++ b/libnd4j/include/execution/cuda/AffinityManager.cu @@ -47,7 +47,7 @@ namespace nd4j { _currentMutex.unlock(); - setCurrentDevice(globalThreadToDevice); + setCurrentNativeDevice(globalThreadToDevice); } // if we already know affinity - just return it @@ -92,6 +92,8 @@ namespace nd4j { void AffinityManager::setCurrentNativeDevice(int deviceId) { auto res = cudaSetDevice(deviceId); + if (res != 0) + throw cuda_exception::build("setCurrentDevice failed", res); } void AffinityManager::setCurrentDevice(int deviceId) { @@ -104,17 +106,22 @@ namespace nd4j { res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaSpecialStream()); if (res != 0) throw cuda_exception::build("setCurrentDevice -> specialSync failed", res); + + if (deviceId != previousDeviceId) { + // discard existing stuff + nd4j_printf("AffinityManager::setCurrentDevice() was invoked, releasing buffers\n", ""); + LaunchContext::releaseBuffers(); + } } - auto res = cudaSetDevice(deviceId); - if (res != 0) - throw cuda_exception::build("cudaSetDevice failed", res); + if (deviceId != previousDeviceId) { + auto res = cudaSetDevice(deviceId); + if (res != 0) + throw cuda_exception::build("cudaSetDevice failed", res); + } // update thread-device affinity globalThreadToDevice = deviceId; - - // discard existing stuff - LaunchContext::releaseBuffers(); } std::atomic AffinityManager::_lastDevice;// = 
std::atomic(initialV); diff --git a/libnd4j/include/execution/cuda/ContextBuffers.cu b/libnd4j/include/execution/cuda/ContextBuffers.cu index 895bb6623..435858462 100644 --- a/libnd4j/include/execution/cuda/ContextBuffers.cu +++ b/libnd4j/include/execution/cuda/ContextBuffers.cu @@ -107,7 +107,6 @@ namespace nd4j { ////// _allocated = false; - _initialized = false; _deviceId = -1; this->_specialStream = nullptr; @@ -116,6 +115,8 @@ namespace nd4j { this->_reductionPointer = nullptr; this->_scalarPointer = nullptr; } + + _initialized = false; } ContextBuffers::~ContextBuffers() { @@ -163,21 +164,21 @@ namespace nd4j { } void* ContextBuffers::reductionBuffer() { - if (_reductionPointer == nullptr) + if (!_initialized) initialize(); return _reductionPointer; } void* ContextBuffers::scalarBuffer() { - if (_scalarPointer == nullptr) + if (!_initialized) initialize(); return _scalarPointer; } void* ContextBuffers::allocationBuffer() { - if (_allocationPointer == nullptr) + if (!_initialized) initialize(); return _allocationPointer; @@ -204,15 +205,23 @@ namespace nd4j { } void* ContextBuffers::execStream() { - if (_execStream == nullptr) + if (!_initialized) { + //nd4j_printf("execStream not initialized\n", ""); initialize(); + } else { + //nd4j_printf("execStream is initialized\n", ""); + } return _execStream; } void* ContextBuffers::specialStream() { - if (_specialStream == nullptr) + if (!_initialized) { + //nd4j_printf("specialStream not initialized\n", ""); initialize(); + } else { + //nd4j_printf("specialStream is initialized\n", ""); + } return _specialStream; } diff --git a/libnd4j/include/execution/cuda/LaunchContext.cu b/libnd4j/include/execution/cuda/LaunchContext.cu index 9d9f2c506..7d1691982 100644 --- a/libnd4j/include/execution/cuda/LaunchContext.cu +++ b/libnd4j/include/execution/cuda/LaunchContext.cu @@ -57,10 +57,6 @@ LaunchContext::LaunchContext() { _deviceID = 0; _isAllocated = true; - - _cublasHandle = CublasHelper::getInstance()->handle(); - - _cusolverHandle = CublasHelper::getInstance()->solver(); } LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) { @@ -89,13 +85,13 @@ LaunchContext::LaunchContext() { _contexts.resize(numDevices); for (int e = 0; e < numDevices; e++) { - AffinityManager::setCurrentDevice(e); + AffinityManager::setCurrentNativeDevice(e); LaunchContext::_contexts[e] = std::make_shared(); } // don't forget to restore device back again - AffinityManager::setCurrentDevice(deviceId); + AffinityManager::setCurrentNativeDevice(deviceId); } _mutex.unlock(); @@ -117,11 +113,11 @@ LaunchContext::LaunchContext() { }; void* LaunchContext::getCublasHandle() const { - return _cublasHandle; + return CublasHelper::getInstance()->handle(); }; void* LaunchContext::getCusolverHandle() const { - return _cusolverHandle; + return CublasHelper::getInstance()->solver(); }; cudaStream_t* LaunchContext::getCudaStream() const { @@ -162,6 +158,7 @@ LaunchContext::LaunchContext() { }; void LaunchContext::releaseBuffers() { + nd4j_printf("LaunchContext::releaseBuffers() was invoked\n", ""); contextBuffers.release(); } diff --git a/libnd4j/include/helpers/ConstantHelper.h b/libnd4j/include/helpers/ConstantHelper.h index a7f7d0c00..6aad7c387 100644 --- a/libnd4j/include/helpers/ConstantHelper.h +++ b/libnd4j/include/helpers/ConstantHelper.h @@ -38,12 +38,13 @@ namespace nd4j { static ConstantHelper* _INSTANCE; ConstantHelper(); - std::vector> _cache; + std::vector> _cache; // tracking of 
per-device constant memory buffers (CUDA only atm) std::vector _devicePointers; std::vector _deviceOffsets; std::mutex _mutex; + std::mutex _mutexHolder; std::vector _counters; public: diff --git a/libnd4j/include/helpers/ConstantShapeHelper.h b/libnd4j/include/helpers/ConstantShapeHelper.h index fe0e52ce5..585db0198 100644 --- a/libnd4j/include/helpers/ConstantShapeHelper.h +++ b/libnd4j/include/helpers/ConstantShapeHelper.h @@ -48,10 +48,10 @@ namespace nd4j { static ConstantShapeHelper* getInstance(); - ConstantDataBuffer& bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector &shape); - ConstantDataBuffer& bufferForShapeInfo(const ShapeDescriptor &descriptor); - ConstantDataBuffer& bufferForShapeInfo(const Nd4jLong *shapeInfo); - ConstantDataBuffer& bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape); + ConstantDataBuffer bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector &shape); + ConstantDataBuffer bufferForShapeInfo(const ShapeDescriptor &descriptor); + ConstantDataBuffer bufferForShapeInfo(const Nd4jLong *shapeInfo); + ConstantDataBuffer bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape); Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType); diff --git a/libnd4j/include/helpers/ConstantTadHelper.h b/libnd4j/include/helpers/ConstantTadHelper.h index d2790998b..79ee7dcd4 100644 --- a/libnd4j/include/helpers/ConstantTadHelper.h +++ b/libnd4j/include/helpers/ConstantTadHelper.h @@ -54,11 +54,11 @@ namespace nd4j { * @param keepUnitiesInShape * @return */ - TadPack& tadForDimensions(const Nd4jLong *originalShape, const std::vector &dimensions, const bool keepUnitiesInShape = false); - TadPack& tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false); - TadPack& tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false); - TadPack& tadForDimensions(ShapeDescriptor &descriptor, std::vector &dimensions, const bool keepUnitiesInShape = false); - TadPack& tadForDimensions(TadDescriptor &descriptor); + TadPack tadForDimensions(const Nd4jLong *originalShape, const std::vector &dimensions, const bool keepUnitiesInShape = false); + TadPack tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false); + TadPack tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false); + TadPack tadForDimensions(ShapeDescriptor &descriptor, std::vector &dimensions, const bool keepUnitiesInShape = false); + TadPack tadForDimensions(TadDescriptor &descriptor); /** * This method returns number of cached TAD shapes/offsets on specific device diff --git a/libnd4j/include/helpers/cpu/ConstantHelper.cpp b/libnd4j/include/helpers/cpu/ConstantHelper.cpp index 43a4f97c1..b2549e93f 100644 --- a/libnd4j/include/helpers/cpu/ConstantHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantHelper.cpp @@ -33,7 +33,8 @@ namespace nd4j { _cache.resize(numDevices); _counters.resize(numDevices); for (int e = 0; e < numDevices; e++) { - std::map map; + std::map map; + _cache[e] = map; _counters[e] = 0L; } @@ -70,15 +71,26 @@ namespace nd4j { ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) { const auto deviceId = getCurrentDevice(); + // we're locking away cache modification + _mutexHolder.lock(); + if 
(_cache[deviceId].count(descriptor) == 0) { - ConstantHolder holder; - _cache[deviceId][descriptor] = holder; + _cache[deviceId][descriptor] = new ConstantHolder(); } - ConstantHolder* holder = &_cache[deviceId][descriptor]; + auto holder = _cache[deviceId][descriptor]; + + // releasing cache lock + _mutexHolder.unlock(); + + + ConstantDataBuffer* result; + + // access to this holder instance is synchronous + holder->mutex()->lock(); if (holder->hasBuffer(dataType)) - return holder->getConstantDataBuffer(dataType); + result = holder->getConstantDataBuffer(dataType); else { auto size = descriptor.length() * DataTypeUtils::sizeOf(dataType); auto cbuff = new int8_t[size]; @@ -94,8 +106,11 @@ namespace nd4j { ConstantDataBuffer dataBuffer(cbuff, nullptr, descriptor.length(), DataTypeUtils::sizeOf(dataType)); holder->addBuffer(dataBuffer, dataType); - return holder->getConstantDataBuffer(dataType); + result = holder->getConstantDataBuffer(dataType); } + holder->mutex()->unlock(); + + return result; } Nd4jLong ConstantHelper::getCachedAmount(int deviceId) { diff --git a/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp b/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp index bdb77ccaa..531b68004 100644 --- a/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp @@ -41,18 +41,18 @@ namespace nd4j { return _INSTANCE; } - ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector &shape) { + ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector &shape) { ShapeDescriptor descriptor(dataType, order, shape); return bufferForShapeInfo(descriptor); } - ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { + ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { ShapeDescriptor descriptor(dataType, order, shape, rank); return bufferForShapeInfo(descriptor); } - ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) { + ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) { int deviceId = 0; _mutex.lock(); @@ -62,19 +62,19 @@ namespace nd4j { ConstantDataBuffer buffer(hPtr, nullptr, shape::shapeInfoLength(hPtr)*sizeof(Nd4jLong), DataType::INT64); ShapeDescriptor descriptor1(descriptor); _cache[deviceId][descriptor1] = buffer; - ConstantDataBuffer &r = _cache[deviceId][descriptor1]; + auto r = _cache[deviceId][descriptor1]; _mutex.unlock(); return r; } else { - ConstantDataBuffer &r = _cache[deviceId].at(descriptor); + auto r = _cache[deviceId].at(descriptor); _mutex.unlock(); return r; } } - ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) { + ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) { ShapeDescriptor descriptor(shapeInfo); return bufferForShapeInfo(descriptor); } diff --git a/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp b/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp index 5100ca3ff..822b5ad0d 100644 --- a/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp @@ -38,25 +38,25 @@ namespace nd4j { return _INSTANCE; } - TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool 
keepUnitiesInShape) { + TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) { return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape); } - TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector &dimensions, const bool keepUnitiesInShape) { + TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector &dimensions, const bool keepUnitiesInShape) { return tadForDimensions(originalShape, const_cast(dimensions.data()), dimensions.size(), keepUnitiesInShape); } - TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) { + TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) { TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape); return tadForDimensions(tadDescriptor); } - TadPack& ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector &dimensions, const bool keepUnitiesInShape) { + TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector &dimensions, const bool keepUnitiesInShape) { TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape); return tadForDimensions(tadDescriptor); } - TadPack& ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { + TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { const int deviceId = 0; _mutex.lock(); @@ -105,7 +105,7 @@ namespace nd4j { return r; } else { - TadPack &r = _cache[deviceId][descriptor]; + TadPack r = _cache[deviceId][descriptor]; _mutex.unlock(); return r; diff --git a/libnd4j/include/helpers/cublasHelper.h b/libnd4j/include/helpers/cublasHelper.h index d4f92881e..94cd2446b 100644 --- a/libnd4j/include/helpers/cublasHelper.h +++ b/libnd4j/include/helpers/cublasHelper.h @@ -24,11 +24,13 @@ #include #include #include +#include namespace nd4j { class CublasHelper { private: static CublasHelper *_INSTANCE; + static std::mutex _mutex; std::vector _cache; std::vector _solvers; diff --git a/libnd4j/include/helpers/cuda/ConstantHelper.cu b/libnd4j/include/helpers/cuda/ConstantHelper.cu index 0c7f2cbc1..0d7bdf64c 100644 --- a/libnd4j/include/helpers/cuda/ConstantHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantHelper.cu @@ -68,7 +68,7 @@ namespace nd4j { throw cuda_exception::build("cudaSetDevice failed", res); auto constant = getConstantSpace(); - std::map devCache; + std::map devCache; _devicePointers[e] = constant; _deviceOffsets[e] = 0; @@ -136,15 +136,24 @@ namespace nd4j { ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) { const auto deviceId = getCurrentDevice(); - if (_cache[deviceId].count(descriptor) == 0) { - ConstantHolder holder; - _cache[deviceId][descriptor] = holder; - } + // all cache modifications are synchronous + _mutexHolder.lock(); - ConstantHolder* holder = &_cache[deviceId][descriptor]; + if (_cache[deviceId].count(descriptor) == 0) { + _cache[deviceId][descriptor] = new ConstantHolder(); + } + auto holder = _cache[deviceId][descriptor]; + + // release cache lock + _mutexHolder.unlock(); + + ConstantDataBuffer* result; + + // access to this holder instance is synchronous + holder->mutex()->lock(); if (holder->hasBuffer(dataType)) { - return holder->getConstantDataBuffer(dataType); + result = 
holder->getConstantDataBuffer(dataType); } else { auto numBytes = descriptor.length() * DataTypeUtils::sizeOf(dataType); auto cbuff = new int8_t[numBytes]; @@ -160,10 +169,14 @@ namespace nd4j { auto dbuff = replicatePointer(cbuff, descriptor.length() * DataTypeUtils::sizeOf(dataType)); ConstantDataBuffer dataBuffer(cbuff, dbuff, descriptor.length(), DataTypeUtils::sizeOf(dataType)); - holder->addBuffer(dataBuffer, dataType); - return holder->getConstantDataBuffer(dataType); + holder->addBuffer(dataBuffer, dataType); + result = holder->getConstantDataBuffer(dataType); } + // release holder lock + holder->mutex()->unlock(); + + return result; } Nd4jLong ConstantHelper::getCachedAmount(int deviceId) { diff --git a/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu b/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu index a1217e0e3..4004b9895 100644 --- a/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu @@ -44,17 +44,17 @@ namespace nd4j { return _INSTANCE; } - ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector &shape) { + ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(nd4j::DataType dataType, char order, const std::vector &shape) { ShapeDescriptor descriptor(dataType, order, shape); return bufferForShapeInfo(descriptor); } - ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { + ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { ShapeDescriptor descriptor(dataType, order, shape, rank); return bufferForShapeInfo(descriptor); } - ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) { + ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) { int deviceId = AffinityManager::currentDeviceId(); _mutex.lock(); @@ -65,19 +65,19 @@ namespace nd4j { ConstantDataBuffer buffer(hPtr, dPtr, shape::shapeInfoLength(hPtr) * sizeof(Nd4jLong), DataType::INT64); ShapeDescriptor descriptor1(descriptor); _cache[deviceId][descriptor1] = buffer; - ConstantDataBuffer &r = _cache[deviceId][descriptor1]; + auto r = _cache[deviceId][descriptor1]; _mutex.unlock(); return r; } else { - ConstantDataBuffer &r = _cache[deviceId].at(descriptor); + ConstantDataBuffer r = _cache[deviceId].at(descriptor); _mutex.unlock(); return r; } } - ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) { + ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) { ShapeDescriptor descriptor(shapeInfo); return bufferForShapeInfo(descriptor); } diff --git a/libnd4j/include/helpers/cuda/ConstantTadHelper.cu b/libnd4j/include/helpers/cuda/ConstantTadHelper.cu index da66975c3..8ea4067f3 100644 --- a/libnd4j/include/helpers/cuda/ConstantTadHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantTadHelper.cu @@ -43,25 +43,25 @@ namespace nd4j { return _INSTANCE; } - TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) { + TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) { return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape); } - TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector 
&dimensions, const bool keepUnitiesInShape) { + TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector &dimensions, const bool keepUnitiesInShape) { return tadForDimensions(originalShape, const_cast(dimensions.data()), dimensions.size(), keepUnitiesInShape); } - TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) { + TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) { TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape); return tadForDimensions(tadDescriptor); } - TadPack& ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector &dimensions, const bool keepUnitiesInShape) { + TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector &dimensions, const bool keepUnitiesInShape) { TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape); return tadForDimensions(tadDescriptor); } - TadPack& ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { + TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { const int deviceId = AffinityManager::currentDeviceId(); _mutex.lock(); @@ -96,14 +96,14 @@ namespace nd4j { TadPack t(shapesBuffer, offsetsBuffer, numOfSubArrs); _cache[deviceId][descriptor] = t; - TadPack &r = _cache[deviceId][descriptor]; + TadPack r = _cache[deviceId][descriptor]; _mutex.unlock(); delete[] shapeInfo; return r; } else { - TadPack &r = _cache[deviceId][descriptor]; + TadPack r = _cache[deviceId][descriptor]; _mutex.unlock(); return r; diff --git a/libnd4j/include/helpers/cuda_off/cublasHelper.cu b/libnd4j/include/helpers/cuda_off/cublasHelper.cu index 6f2cf2084..d9784eaa2 100644 --- a/libnd4j/include/helpers/cuda_off/cublasHelper.cu +++ b/libnd4j/include/helpers/cuda_off/cublasHelper.cu @@ -27,6 +27,7 @@ #include namespace nd4j { + std::mutex CublasHelper::_mutex; static void* handle_() { auto _handle = new cublasHandle_t(); @@ -56,22 +57,24 @@ namespace nd4j { } CublasHelper::CublasHelper() { + //nd4j_printf("Initializing cuBLAS\n",""); auto numDevices = AffinityManager::numberOfDevices(); auto currentDevice = AffinityManager::currentDeviceId(); _cache.resize(numDevices); _solvers.resize(numDevices); for (int e = 0; e < numDevices; e++) { - AffinityManager::setCurrentDevice(e); + AffinityManager::setCurrentNativeDevice(e); _cache[e] = handle_(); _solvers[e] = solver_(); } // don't forget to restore back original device - AffinityManager::setCurrentDevice(currentDevice); + AffinityManager::setCurrentNativeDevice(currentDevice); } CublasHelper::~CublasHelper() { + nd4j_printf("Releasing cuBLAS\n",""); auto numDevices = AffinityManager::numberOfDevices(); for (int e = 0; e < numDevices; e++) @@ -79,8 +82,10 @@ namespace nd4j { } CublasHelper* CublasHelper::getInstance() { + _mutex.lock(); if (!_INSTANCE) _INSTANCE = new nd4j::CublasHelper(); + _mutex.unlock(); return _INSTANCE; } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/pointers/cuda/cudaEvent_t.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/pointers/cuda/cudaEvent_t.java index 52b7d7332..de1920f0a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/pointers/cuda/cudaEvent_t.java +++ 
b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/pointers/cuda/cudaEvent_t.java @@ -18,6 +18,7 @@ package org.nd4j.jita.allocator.pointers.cuda; import lombok.Getter; import lombok.Setter; +import lombok.val; import org.bytedeco.javacpp.Pointer; import org.nd4j.jita.allocator.pointers.CudaPointer; import org.nd4j.linalg.exception.ND4JException; @@ -69,8 +70,9 @@ public class cudaEvent_t extends CudaPointer { if (res == 0) throw new ND4JException("CUDA exception happened. Terminating. Last op: [" + Nd4j.getExecutioner().getLastOp() +"]"); - if (NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode() != 0) - throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage()); + val code = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode(); + if (code != 0) + throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage() + "; Error code: " + code); } } @@ -78,8 +80,9 @@ public class cudaEvent_t extends CudaPointer { if (!isDestroyed()) { int res = NativeOpsHolder.getInstance().getDeviceNativeOps().registerEvent(this, stream); - if (NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode() != 0) - throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage()); + val code = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode(); + if (code != 0) + throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage() + "; Error code: " + code); } } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java index 7f8f9bb51..b06211545 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.jcublas.blas; +import lombok.extern.slf4j.Slf4j; import lombok.val; import org.bytedeco.javacpp.DoublePointer; import org.bytedeco.javacpp.FloatPointer; @@ -52,6 +53,7 @@ import static org.nd4j.linalg.jcublas.blas.CudaBlas.*; * * @author Adam Gibson */ +@Slf4j public class JcublasLevel3 extends BaseLevel3 { private Allocator allocator = AtomicAllocator.getInstance(); private Nd4jBlas nd4jBlas = (Nd4jBlas) Nd4j.factory().blas(); @@ -78,7 +80,7 @@ public class JcublasLevel3 extends BaseLevel3 { int arch = CudaEnvironment.getInstance().getCurrentDeviceArchitecture(); - if ((CUDA_VERSION >= 8000 && (arch == 53 || arch == 60 || arch == 70)) || (CUDA_VERSION >= 8000 && CUDA_VERSION < 9020)) { + if ((CUDA_VERSION >= 8000 && (arch == 53 || arch == 60 || arch >= 70)) || (CUDA_VERSION >= 8000 && CUDA_VERSION < 9020)) { // on these selected archs we run with cublasHgemm __half alphaHalf = new __half(); __half betaHalf = new __half(); @@ -96,7 +98,11 @@ public class JcublasLevel3 extends BaseLevel3 { new FloatPointer(alpha), (ShortPointer) cAPointer.getDevicePointer(), 2, lda, (ShortPointer) cBPointer.getDevicePointer(), 2, ldb, new FloatPointer(beta), (ShortPointer) cCPointer.getDevicePointer(), 2, ldc); + + } + + ctx.getOldStream().synchronize(); } allocator.registerAction(ctx, C, A, B); @@ -114,18 +120,24 @@ public class JcublasLevel3 extends BaseLevel3 { val ctx = allocator.getFlowController().prepareAction(C, A, 
B); + //log.info("Synchronizing CUDA stream"); + ctx.getOldStream().synchronize(); + val cAPointer = new CublasPointer(A, ctx); val cBPointer = new CublasPointer(B, ctx); val cCPointer = new CublasPointer(C, ctx); val handle = ctx.getCublasHandle(); synchronized (handle) { + //log.info("Handle: {}; Stream: {}", handle.address(), ctx.getCublasStream().address()); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); cublasSgemm_v2(new cublasContext(handle), convertTranspose(TransA), convertTranspose(TransB), M, N, K, new FloatPointer(alpha), (FloatPointer) cAPointer.getDevicePointer(), lda, (FloatPointer) cBPointer.getDevicePointer(), ldb, new FloatPointer(beta), (FloatPointer) cCPointer.getDevicePointer(), ldc); + + ctx.getOldStream().synchronize(); } allocator.registerAction(ctx, C, A, B); @@ -244,6 +256,8 @@ public class JcublasLevel3 extends BaseLevel3 { new DoublePointer(alpha), (DoublePointer) cAPointer.getDevicePointer(), lda, (DoublePointer) cBPointer.getDevicePointer(), ldb, new DoublePointer(beta), (DoublePointer) cCPointer.getDevicePointer(), ldc); + + ctx.getOldStream().synchronize(); } allocator.registerAction(ctx, C, A, B); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index c5b02a82f..43bbfbdca 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -2548,6 +2548,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) { + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + OpaqueConstantDataBuffer dbf = nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty); if (nativeOps.lastErrorCode() != 0) @@ -2562,6 +2565,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) { + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + OpaqueTadPack pack = nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length); if (nativeOps.lastErrorCode() != 0) @@ -2577,6 +2583,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public DataBuffer createConstantBuffer(long[] values, DataType desiredType) { + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + OpaqueConstantDataBuffer dbf = nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length); if (nativeOps.lastErrorCode() != 0) @@ -2590,6 +2599,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public DataBuffer createConstantBuffer(double[] values, DataType desiredType) { + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + OpaqueConstantDataBuffer dbf = nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length); if 
(nativeOps.lastErrorCode() != 0) From 25b01f7850e47cd314063f9e9e34c58eea80b17e Mon Sep 17 00:00:00 2001 From: Robert Altena Date: Wed, 4 Sep 2019 12:29:02 +0900 Subject: [PATCH 13/19] javadoc and remove deprecated methods. (#231) Signed-off-by: Robert Altena --- .../nd4j/linalg/api/ndarray/BaseNDArray.java | 82 ------------------- .../linalg/api/ndarray/BaseSparseNDArray.java | 40 --------- .../api/ndarray/BaseSparseNDArrayCOO.java | 17 ---- .../api/ndarray/BaseSparseNDArrayCSR.java | 5 -- .../org/nd4j/linalg/api/ndarray/INDArray.java | 52 ++++-------- .../nd4j/linalg/indexing/NDArrayIndex.java | 51 ------------ 6 files changed, 14 insertions(+), 233 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index ac642872c..771b74615 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -1149,16 +1149,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape(), stride(), elementWiseStride(), order, this.dataType(), isEmpty())); } - @Override - public void setShape(long[] shape) { - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride(), elementWiseStride(), ordering(), this.dataType(), isEmpty())); - } - - @Override - public void setStride(long[] stride) { - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape(), stride, elementWiseStride(), ordering(), this.dataType(), isEmpty())); - } - @Override public void setShapeAndStride(int[] shape, int[] stride) { setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), 0, ordering(), this.dataType(), false)); @@ -1283,29 +1273,16 @@ public abstract class BaseNDArray implements INDArray, Iterable { return scalar.getDouble(0); } - /** - * Returns entropy value for this INDArray - * @return - */ @Override public Number entropyNumber() { return entropy(Integer.MAX_VALUE).getDouble(0); } - /** - * Returns non-normalized Shannon entropy value for this INDArray - * @return - */ @Override public Number shannonEntropyNumber() { return shannonEntropy(Integer.MAX_VALUE).getDouble(0); } - - /** - * Returns log entropy value for this INDArray - * @return - */ @Override public Number logEntropyNumber() { return logEntropy(Integer.MAX_VALUE).getDouble(0); @@ -2297,37 +2274,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return size(0); } - @Override - public INDArray subArray(long[] offsets, int[] shape, int[] stride) { - Nd4j.getCompressor().autoDecompress(this); - int n = shape.length; - - // FIXME: shapeInfo should be used here - if (shape.length < 1) - return create(Nd4j.createBufferDetached(shape)); - if (offsets.length != n) - throw new IllegalArgumentException("Invalid offset " + Arrays.toString(offsets)); - if (stride.length != n) - throw new IllegalArgumentException("Invalid stride " + Arrays.toString(stride)); - - if (Shape.contentEquals(shape, shapeOf())) { - if (ArrayUtil.isZero(offsets)) { - return this; - } else { - throw new IllegalArgumentException("Invalid subArray offsets"); - } - } - - long[] dotProductOffsets = offsets; - int[] dotProductStride = stride; - - 
long offset = Shape.offset(jvmShapeInfo.javaShapeInformation) + NDArrayIndex.offset(dotProductStride, dotProductOffsets); - if (offset >= data().length()) - offset = ArrayUtil.sumLong(offsets); - - return create(data, Arrays.copyOf(shape, shape.length), stride, offset, ordering()); - } - protected INDArray create(DataBuffer buffer) { return Nd4j.create(buffer); } @@ -4016,58 +3962,30 @@ public abstract class BaseNDArray implements INDArray, Iterable { return Nd4j.getExecutioner().exec(new AMin(this, dimension)); } - /** - * Returns the sum along the specified dimension(s) of this ndarray - * - * @param dimension the dimension to getScalar the sum along - * @return the sum along the specified dimension of this ndarray - */ @Override public INDArray sum(int... dimension) { validateNumericalArray("sum", true); return Nd4j.getExecutioner().exec(new Sum(this, dimension)); } - /** - * Returns the sum along the last dimension of this ndarray - * - * @param dimension the dimension to getScalar the sum along - * @return the sum along the specified dimension of this ndarray - */ @Override public INDArray sum(boolean keepDim, int... dimension) { validateNumericalArray("sum", true); return Nd4j.getExecutioner().exec(new Sum(this, null, keepDim, dimension)); } - - /** - * Returns entropy along dimension - * @param dimension - * @return - */ @Override public INDArray entropy(int... dimension) { validateNumericalArray("entropy", false); return Nd4j.getExecutioner().exec(new Entropy(this, dimension)); } - /** - * Returns non-normalized Shannon entropy along dimension - * @param dimension - * @return - */ @Override public INDArray shannonEntropy(int... dimension) { validateNumericalArray("shannonEntropy", false); return Nd4j.getExecutioner().exec(new ShannonEntropy(this, dimension)); } - /** - * Returns log entropy along dimension - * @param dimension - * @return - */ @Override public INDArray logEntropy(int... dimension) { validateNumericalArray("logEntropy", false); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java index 1e0772494..11a005f91 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java @@ -468,16 +468,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray { throw new UnsupportedOperationException(); } - @Override - public void setStride(long... stride) { - throw new UnsupportedOperationException(); - } - - @Override - public void setShape(long... 
shape) { - throw new UnsupportedOperationException(); - } - @Override public INDArray putScalar(long row, long col, double value) { return null; @@ -1284,17 +1274,10 @@ public abstract class BaseSparseNDArray implements ISparseNDArray { @Override public void setShapeAndStride(int[] shape, int[] stride) { - } @Override public void setOrder(char order) { - - } - - @Override - public INDArray subArray(long[] offsets, int[] shape, int[] stride) { - return null; } @Override @@ -1842,49 +1825,26 @@ public abstract class BaseSparseNDArray implements ISparseNDArray { return null; } - /** - * Returns entropy value for this INDArray - * @return - */ @Override public Number entropyNumber() { return entropy(Integer.MAX_VALUE).getDouble(0); } - /** - * Returns non-normalized Shannon entropy value for this INDArray - * @return - */ @Override public Number shannonEntropyNumber() { return shannonEntropy(Integer.MAX_VALUE).getDouble(0); } - - /** - * Returns log entropy value for this INDArray - * @return - */ @Override public Number logEntropyNumber() { return logEntropy(Integer.MAX_VALUE).getDouble(0); } - /** - * Returns entropy along dimension - * @param dimension - * @return - */ @Override public INDArray entropy(int... dimension) { return Nd4j.getExecutioner().exec(new Entropy(this, dimension)); } - /** - * Returns non-normalized Shannon entropy along dimension - * @param dimension - * @return - */ @Override public INDArray shannonEntropy(int... dimension) { return Nd4j.getExecutioner().exec(new ShannonEntropy(this, dimension)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java index 116a4b4f7..85a7ec5ce 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java @@ -1016,13 +1016,6 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray { return extendedFlags; } - @Override - public INDArray subArray(long[] offsets, int[] shape, int[] stride) { - throw new UnsupportedOperationException(); - } - - - /** * Returns the underlying indices of the element of the given index * such as there really are in the original ndarray @@ -1138,16 +1131,6 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray { return null; } - @Override - public void setStride(long... stride) { - - } - - @Override - public void setShape(long... 
shape) { - - } - /** * This method returns true if this INDArray is special case: no-value INDArray * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCSR.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCSR.java index 92e59486c..cf2f9fe3f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCSR.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCSR.java @@ -213,11 +213,6 @@ public abstract class BaseSparseNDArrayCSR extends BaseSparseNDArray { return shapeInformation; } - @Override - public INDArray subArray(long[] offsets, int[] shape, int[] stride) { - throw new UnsupportedOperationException(); - } - @Override public boolean equals(Object o) { //TODO use op diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java index 47e259b94..9288b6d51 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java @@ -1854,63 +1854,47 @@ public interface INDArray extends Serializable, AutoCloseable { /** * Returns entropy value for this INDArray - * @return + * @return entropy value */ Number entropyNumber(); /** * Returns non-normalized Shannon entropy value for this INDArray - * @return + * @return non-normalized Shannon entropy */ Number shannonEntropyNumber(); /** * Returns log entropy value for this INDArray - * @return + * @return log entropy value */ Number logEntropyNumber(); /** * Returns entropy value for this INDArray along specified dimension(s) - * @return + * @param dimension specified dimension(s) + * @return entropy value */ INDArray entropy(int... dimension); /** - * Returns entropy value for this INDArray along specified dimension(s) - * @return + * Returns Shannon entropy value for this INDArray along specified dimension(s) + * @param dimension specified dimension(s) + * @return Shannon entropy */ INDArray shannonEntropy(int... dimension); /** - * Returns entropy value for this INDArray along specified dimension(s) - * @return + * Returns log entropy value for this INDArray along specified dimension(s) + * @param dimension specified dimension(s) + * @return log entropy value */ INDArray logEntropy(int... dimension); - - /** - * stride setter - * @param stride - * @deprecated, use {@link #reshape(int...) } - */ - @Deprecated - void setStride(long... stride); - - /** - * Shape setter - * @param shape - * @deprecated, use {@link #reshape(int...) } - */ - - - @Deprecated - void setShape(long... 
shape); - /** * Shape and stride setter - * @param shape - * @param stride + * @param shape new value for shape + * @param stride new value for stride */ void setShapeAndStride(int[] shape, int[] stride); @@ -1919,15 +1903,7 @@ public interface INDArray extends Serializable, AutoCloseable { * @param order the ordering to set */ void setOrder(char order); - - /** - * @param offsets - * @param shape - * @param stride - * @return - */ - INDArray subArray(long[] offsets, int[] shape, int[] stride); - + /** * Returns the elements at the specified indices * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/NDArrayIndex.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/NDArrayIndex.java index c21993548..40aa692eb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/NDArrayIndex.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/NDArrayIndex.java @@ -96,57 +96,6 @@ public abstract class NDArrayIndex implements INDArrayIndex { return offset(arr.stride(), Indices.offsets(arr.shape(), indices)); } - /** - * Set the shape and stride for - * new axes based dimensions - * @param arr the array to update - * the shape/strides for - * @param indexes the indexes to update based on - */ - public static void updateForNewAxes(INDArray arr, INDArrayIndex... indexes) { - int numNewAxes = NDArrayIndex.numNewAxis(indexes); - if (numNewAxes >= 1 && (indexes[0].length() > 1 || indexes[0] instanceof NDArrayIndexAll)) { - List newShape = new ArrayList<>(); - List newStrides = new ArrayList<>(); - int currDimension = 0; - for (int i = 0; i < indexes.length; i++) { - if (indexes[i] instanceof NewAxis) { - newShape.add(1L); - newStrides.add(0L); - } else { - newShape.add(arr.size(currDimension)); - newStrides.add(arr.size(currDimension)); - currDimension++; - } - } - - while (currDimension < arr.rank()) { - newShape.add((long) currDimension); - newStrides.add((long) currDimension); - currDimension++; - } - - long[] newShapeArr = Longs.toArray(newShape); - long[] newStrideArr = Longs.toArray(newStrides); - - // FIXME: this is wrong, it breaks shapeInfo immutability - arr.setShape(newShapeArr); - arr.setStride(newStrideArr); - - - } else { - if (numNewAxes > 0) { - long[] newShape = Longs.concat(ArrayUtil.toLongArray(ArrayUtil.nTimes(numNewAxes, 1)), arr.shape()); - long[] newStrides = Longs.concat(new long[numNewAxes], arr.stride()); - arr.setShape(newShape); - arr.setStride(newStrides); - } - } - - } - - - /** * Compute the offset given an array of offsets. 
* The offset is computed(for both fortran an d c ordering) as: From 6cc887bee94d1313547f14ed83b784e58ebe9fba Mon Sep 17 00:00:00 2001 From: Alex Black Date: Wed, 4 Sep 2019 16:36:11 +1000 Subject: [PATCH 14/19] Rename flatbuffers DataType to DType (#228) * Rename flatbuffers DataType enum to DType Signed-off-by: Alex Black * Rename flatbuffers DataType enum to DType Signed-off-by: Alex Black * Updates for flatbuffers datatype enum renaming Signed-off-by: Alex Black --- libnd4j/blas/cpu/GraphExecutioner.cpp | 2 +- libnd4j/include/array/DataTypeUtils.h | 2 +- libnd4j/include/array/impl/DataTypeUtils.cpp | 2 +- .../include/graph/generated/array_generated.h | 102 +++++++++--------- .../graph/generated/array_generated.js | 10 +- .../nd4j/graph/{DataType.cs => DType.cs} | 2 +- .../nd4j/graph/{DataType.java => DType.java} | 4 +- .../nd4j/graph/{DataType.py => DType.py} | 2 +- .../graph/generated/nd4j/graph/FlatArray.cs | 6 +- .../graph/generated/nd4j/graph/FlatNode.cs | 8 +- .../generated/nd4j/graph/FlatVariable.cs | 6 +- .../include/graph/generated/node_generated.js | 6 +- .../graph/generated/variable_generated.h | 10 +- .../graph/generated/variable_generated.js | 8 +- libnd4j/include/graph/impl/FlatUtils.cpp | 2 +- libnd4j/include/graph/impl/Variable.cpp | 6 +- libnd4j/include/graph/scheme/array.fbs | 4 +- libnd4j/include/graph/scheme/node.fbs | 2 +- .../include/graph/scheme/uigraphstatic.fbs | 2 +- libnd4j/include/graph/scheme/variable.fbs | 2 +- .../layers_tests/FlatBuffersTests.cpp | 6 +- .../tests_cpu/layers_tests/VariableTests.cpp | 14 +-- .../samediff/serde/FlatBuffersMapper.java | 58 +++++----- .../nd4j/graph/{DataType.java => DType.java} | 4 +- 24 files changed, 135 insertions(+), 135 deletions(-) rename libnd4j/include/graph/generated/nd4j/graph/{DataType.cs => DType.cs} (93%) rename libnd4j/include/graph/generated/nd4j/graph/{DataType.java => DType.java} (95%) rename libnd4j/include/graph/generated/nd4j/graph/{DataType.py => DType.py} (93%) rename nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/{DataType.java => DType.java} (95%) diff --git a/libnd4j/blas/cpu/GraphExecutioner.cpp b/libnd4j/blas/cpu/GraphExecutioner.cpp index 6f97bc024..ef45a3e0c 100644 --- a/libnd4j/blas/cpu/GraphExecutioner.cpp +++ b/libnd4j/blas/cpu/GraphExecutioner.cpp @@ -583,7 +583,7 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace) auto fName = builder.CreateString(*(var->getName())); auto id = CreateIntPair(builder, var->id(), var->index()); - auto fv = CreateFlatVariable(builder, id, fName, static_cast(array->dataType()), 0, fArray); + auto fv = CreateFlatVariable(builder, id, fName, static_cast(array->dataType()), 0, fArray); variables_vector.push_back(fv); arrays++; diff --git a/libnd4j/include/array/DataTypeUtils.h b/libnd4j/include/array/DataTypeUtils.h index 8346442eb..2a52ba6f5 100644 --- a/libnd4j/include/array/DataTypeUtils.h +++ b/libnd4j/include/array/DataTypeUtils.h @@ -38,7 +38,7 @@ namespace nd4j { public: static int asInt(DataType type); static DataType fromInt(int dtype); - static DataType fromFlatDataType(nd4j::graph::DataType dtype); + static DataType fromFlatDataType(nd4j::graph::DType dtype); FORCEINLINE static std::string asString(DataType dataType); template diff --git a/libnd4j/include/array/impl/DataTypeUtils.cpp b/libnd4j/include/array/impl/DataTypeUtils.cpp index f0b261039..cdf688b25 100644 --- a/libnd4j/include/array/impl/DataTypeUtils.cpp +++ b/libnd4j/include/array/impl/DataTypeUtils.cpp @@ -27,7 +27,7 @@ namespace nd4j { 
return (DataType) val; } - DataType DataTypeUtils::fromFlatDataType(nd4j::graph::DataType dtype) { + DataType DataTypeUtils::fromFlatDataType(nd4j::graph::DType dtype) { return (DataType) dtype; } diff --git a/libnd4j/include/graph/generated/array_generated.h b/libnd4j/include/graph/generated/array_generated.h index 5848c0ac4..b581240ad 100644 --- a/libnd4j/include/graph/generated/array_generated.h +++ b/libnd4j/include/graph/generated/array_generated.h @@ -40,56 +40,56 @@ inline const char *EnumNameByteOrder(ByteOrder e) { return EnumNamesByteOrder()[index]; } -enum DataType { - DataType_INHERIT = 0, - DataType_BOOL = 1, - DataType_FLOAT8 = 2, - DataType_HALF = 3, - DataType_HALF2 = 4, - DataType_FLOAT = 5, - DataType_DOUBLE = 6, - DataType_INT8 = 7, - DataType_INT16 = 8, - DataType_INT32 = 9, - DataType_INT64 = 10, - DataType_UINT8 = 11, - DataType_UINT16 = 12, - DataType_UINT32 = 13, - DataType_UINT64 = 14, - DataType_QINT8 = 15, - DataType_QINT16 = 16, - DataType_BFLOAT16 = 17, - DataType_UTF8 = 50, - DataType_MIN = DataType_INHERIT, - DataType_MAX = DataType_UTF8 +enum DType { + DType_INHERIT = 0, + DType_BOOL = 1, + DType_FLOAT8 = 2, + DType_HALF = 3, + DType_HALF2 = 4, + DType_FLOAT = 5, + DType_DOUBLE = 6, + DType_INT8 = 7, + DType_INT16 = 8, + DType_INT32 = 9, + DType_INT64 = 10, + DType_UINT8 = 11, + DType_UINT16 = 12, + DType_UINT32 = 13, + DType_UINT64 = 14, + DType_QINT8 = 15, + DType_QINT16 = 16, + DType_BFLOAT16 = 17, + DType_UTF8 = 50, + DType_MIN = DType_INHERIT, + DType_MAX = DType_UTF8 }; -inline const DataType (&EnumValuesDataType())[19] { - static const DataType values[] = { - DataType_INHERIT, - DataType_BOOL, - DataType_FLOAT8, - DataType_HALF, - DataType_HALF2, - DataType_FLOAT, - DataType_DOUBLE, - DataType_INT8, - DataType_INT16, - DataType_INT32, - DataType_INT64, - DataType_UINT8, - DataType_UINT16, - DataType_UINT32, - DataType_UINT64, - DataType_QINT8, - DataType_QINT16, - DataType_BFLOAT16, - DataType_UTF8 +inline const DType (&EnumValuesDType())[19] { + static const DType values[] = { + DType_INHERIT, + DType_BOOL, + DType_FLOAT8, + DType_HALF, + DType_HALF2, + DType_FLOAT, + DType_DOUBLE, + DType_INT8, + DType_INT16, + DType_INT32, + DType_INT64, + DType_UINT8, + DType_UINT16, + DType_UINT32, + DType_UINT64, + DType_QINT8, + DType_QINT16, + DType_BFLOAT16, + DType_UTF8 }; return values; } -inline const char * const *EnumNamesDataType() { +inline const char * const *EnumNamesDType() { static const char * const names[] = { "INHERIT", "BOOL", @@ -147,9 +147,9 @@ inline const char * const *EnumNamesDataType() { return names; } -inline const char *EnumNameDataType(DataType e) { +inline const char *EnumNameDType(DType e) { const size_t index = static_cast(e); - return EnumNamesDataType()[index]; + return EnumNamesDType()[index]; } struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -165,8 +165,8 @@ struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector *buffer() const { return GetPointer *>(VT_BUFFER); } - DataType dtype() const { - return static_cast(GetField(VT_DTYPE, 0)); + DType dtype() const { + return static_cast(GetField(VT_DTYPE, 0)); } ByteOrder byteOrder() const { return static_cast(GetField(VT_BYTEORDER, 0)); @@ -192,7 +192,7 @@ struct FlatArrayBuilder { void add_buffer(flatbuffers::Offset> buffer) { fbb_.AddOffset(FlatArray::VT_BUFFER, buffer); } - void add_dtype(DataType dtype) { + void add_dtype(DType dtype) { fbb_.AddElement(FlatArray::VT_DTYPE, static_cast(dtype), 0); } void 
add_byteOrder(ByteOrder byteOrder) { @@ -214,7 +214,7 @@ inline flatbuffers::Offset CreateFlatArray( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset> shape = 0, flatbuffers::Offset> buffer = 0, - DataType dtype = DataType_INHERIT, + DType dtype = DType_INHERIT, ByteOrder byteOrder = ByteOrder_LE) { FlatArrayBuilder builder_(_fbb); builder_.add_buffer(buffer); @@ -228,7 +228,7 @@ inline flatbuffers::Offset CreateFlatArrayDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector *shape = nullptr, const std::vector *buffer = nullptr, - DataType dtype = DataType_INHERIT, + DType dtype = DType_INHERIT, ByteOrder byteOrder = ByteOrder_LE) { return nd4j::graph::CreateFlatArray( _fbb, diff --git a/libnd4j/include/graph/generated/array_generated.js b/libnd4j/include/graph/generated/array_generated.js index 8a2b644e6..b98410a9e 100644 --- a/libnd4j/include/graph/generated/array_generated.js +++ b/libnd4j/include/graph/generated/array_generated.js @@ -23,7 +23,7 @@ nd4j.graph.ByteOrder = { /** * @enum */ -nd4j.graph.DataType = { +nd4j.graph.DType = { INHERIT: 0, BOOL: 1, FLOAT8: 2, @@ -123,11 +123,11 @@ nd4j.graph.FlatArray.prototype.bufferArray = function() { }; /** - * @returns {nd4j.graph.DataType} + * @returns {nd4j.graph.DType} */ nd4j.graph.FlatArray.prototype.dtype = function() { var offset = this.bb.__offset(this.bb_pos, 8); - return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT; + return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT; }; /** @@ -205,10 +205,10 @@ nd4j.graph.FlatArray.startBufferVector = function(builder, numElems) { /** * @param {flatbuffers.Builder} builder - * @param {nd4j.graph.DataType} dtype + * @param {nd4j.graph.DType} dtype */ nd4j.graph.FlatArray.addDtype = function(builder, dtype) { - builder.addFieldInt8(2, dtype, nd4j.graph.DataType.INHERIT); + builder.addFieldInt8(2, dtype, nd4j.graph.DType.INHERIT); }; /** diff --git a/libnd4j/include/graph/generated/nd4j/graph/DataType.cs b/libnd4j/include/graph/generated/nd4j/graph/DType.cs similarity index 93% rename from libnd4j/include/graph/generated/nd4j/graph/DataType.cs rename to libnd4j/include/graph/generated/nd4j/graph/DType.cs index 9cd9518c9..00e399b50 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/DataType.cs +++ b/libnd4j/include/graph/generated/nd4j/graph/DType.cs @@ -5,7 +5,7 @@ namespace nd4j.graph { -public enum DataType : sbyte +public enum DType : sbyte { INHERIT = 0, BOOL = 1, diff --git a/libnd4j/include/graph/generated/nd4j/graph/DataType.java b/libnd4j/include/graph/generated/nd4j/graph/DType.java similarity index 95% rename from libnd4j/include/graph/generated/nd4j/graph/DataType.java rename to libnd4j/include/graph/generated/nd4j/graph/DType.java index 369c1b6ae..20d3d475b 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/DataType.java +++ b/libnd4j/include/graph/generated/nd4j/graph/DType.java @@ -2,8 +2,8 @@ package nd4j.graph; -public final class DataType { - private DataType() { } +public final class DType { + private DType() { } public static final byte INHERIT = 0; public static final byte BOOL = 1; public static final byte FLOAT8 = 2; diff --git a/libnd4j/include/graph/generated/nd4j/graph/DataType.py b/libnd4j/include/graph/generated/nd4j/graph/DType.py similarity index 93% rename from libnd4j/include/graph/generated/nd4j/graph/DataType.py rename to libnd4j/include/graph/generated/nd4j/graph/DType.py index e07aace5d..24cadf44e 100644 
--- a/libnd4j/include/graph/generated/nd4j/graph/DataType.py +++ b/libnd4j/include/graph/generated/nd4j/graph/DType.py @@ -2,7 +2,7 @@ # namespace: graph -class DataType(object): +class DType(object): INHERIT = 0 BOOL = 1 FLOAT8 = 2 diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatArray.cs b/libnd4j/include/graph/generated/nd4j/graph/FlatArray.cs index a19325fb7..60d836aeb 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatArray.cs +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatArray.cs @@ -33,13 +33,13 @@ public struct FlatArray : IFlatbufferObject public ArraySegment? GetBufferBytes() { return __p.__vector_as_arraysegment(6); } #endif public sbyte[] GetBufferArray() { return __p.__vector_as_array(6); } - public DataType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } } + public DType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } } public ByteOrder ByteOrder { get { int o = __p.__offset(10); return o != 0 ? (ByteOrder)__p.bb.GetSbyte(o + __p.bb_pos) : ByteOrder.LE; } } public static Offset CreateFlatArray(FlatBufferBuilder builder, VectorOffset shapeOffset = default(VectorOffset), VectorOffset bufferOffset = default(VectorOffset), - DataType dtype = DataType.INHERIT, + DType dtype = DType.INHERIT, ByteOrder byteOrder = ByteOrder.LE) { builder.StartObject(4); FlatArray.AddBuffer(builder, bufferOffset); @@ -58,7 +58,7 @@ public struct FlatArray : IFlatbufferObject public static VectorOffset CreateBufferVector(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte(data[i]); return builder.EndVector(); } public static VectorOffset CreateBufferVectorBlock(FlatBufferBuilder builder, sbyte[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); } public static void StartBufferVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); } - public static void AddDtype(FlatBufferBuilder builder, DataType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); } + public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); } public static void AddByteOrder(FlatBufferBuilder builder, ByteOrder byteOrder) { builder.AddSbyte(3, (sbyte)byteOrder, 0); } public static Offset EndFlatArray(FlatBufferBuilder builder) { int o = builder.EndObject(); diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs index c1068811d..0810d2e6e 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs @@ -97,14 +97,14 @@ public struct FlatNode : IFlatbufferObject public ArraySegment? GetOpNameBytes() { return __p.__vector_as_arraysegment(36); } #endif public byte[] GetOpNameArray() { return __p.__vector_as_array(36); } - public DataType OutputTypes(int j) { int o = __p.__offset(38); return o != 0 ? (DataType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DataType)0; } + public DType OutputTypes(int j) { int o = __p.__offset(38); return o != 0 ? (DType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DType)0; } public int OutputTypesLength { get { int o = __p.__offset(38); return o != 0 ? __p.__vector_len(o) : 0; } } #if ENABLE_SPAN_T public Span GetOutputTypesBytes() { return __p.__vector_as_span(38); } #else public ArraySegment? 
GetOutputTypesBytes() { return __p.__vector_as_arraysegment(38); } #endif - public DataType[] GetOutputTypesArray() { return __p.__vector_as_array(38); } + public DType[] GetOutputTypesArray() { return __p.__vector_as_array(38); } public FlatArray? Scalar { get { int o = __p.__offset(40); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } } public static Offset CreateFlatNode(FlatBufferBuilder builder, @@ -196,8 +196,8 @@ public struct FlatNode : IFlatbufferObject public static void StartOutputNamesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); } public static void AddOpName(FlatBufferBuilder builder, StringOffset opNameOffset) { builder.AddOffset(16, opNameOffset.Value, 0); } public static void AddOutputTypes(FlatBufferBuilder builder, VectorOffset outputTypesOffset) { builder.AddOffset(17, outputTypesOffset.Value, 0); } - public static VectorOffset CreateOutputTypesVector(FlatBufferBuilder builder, DataType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); } - public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DataType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); } + public static VectorOffset CreateOutputTypesVector(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); } + public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); } public static void StartOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); } public static void AddScalar(FlatBufferBuilder builder, Offset scalarOffset) { builder.AddOffset(18, scalarOffset.Value, 0); } public static Offset EndFlatNode(FlatBufferBuilder builder) { diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.cs b/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.cs index d5f8014f2..9764668a0 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.cs +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatVariable.cs @@ -25,7 +25,7 @@ public struct FlatVariable : IFlatbufferObject public ArraySegment? GetNameBytes() { return __p.__vector_as_arraysegment(6); } #endif public byte[] GetNameArray() { return __p.__vector_as_array(6); } - public DataType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } } + public DType Dtype { get { int o = __p.__offset(8); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } } public long Shape(int j) { int o = __p.__offset(10); return o != 0 ? __p.bb.GetLong(__p.__vector(o) + j * 8) : (long)0; } public int ShapeLength { get { int o = __p.__offset(10); return o != 0 ? 
__p.__vector_len(o) : 0; } } #if ENABLE_SPAN_T @@ -41,7 +41,7 @@ public struct FlatVariable : IFlatbufferObject public static Offset CreateFlatVariable(FlatBufferBuilder builder, Offset idOffset = default(Offset), StringOffset nameOffset = default(StringOffset), - DataType dtype = DataType.INHERIT, + DType dtype = DType.INHERIT, VectorOffset shapeOffset = default(VectorOffset), Offset ndarrayOffset = default(Offset), int device = 0, @@ -60,7 +60,7 @@ public struct FlatVariable : IFlatbufferObject public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(7); } public static void AddId(FlatBufferBuilder builder, Offset idOffset) { builder.AddOffset(0, idOffset.Value, 0); } public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); } - public static void AddDtype(FlatBufferBuilder builder, DataType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); } + public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); } public static void AddShape(FlatBufferBuilder builder, VectorOffset shapeOffset) { builder.AddOffset(3, shapeOffset.Value, 0); } public static VectorOffset CreateShapeVector(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); for (int i = data.Length - 1; i >= 0; i--) builder.AddLong(data[i]); return builder.EndVector(); } public static VectorOffset CreateShapeVectorBlock(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); builder.Add(data); return builder.EndVector(); } diff --git a/libnd4j/include/graph/generated/node_generated.js b/libnd4j/include/graph/generated/node_generated.js index a7b2e264f..bd2274dad 100644 --- a/libnd4j/include/graph/generated/node_generated.js +++ b/libnd4j/include/graph/generated/node_generated.js @@ -312,11 +312,11 @@ nd4j.graph.FlatNode.prototype.opName = function(optionalEncoding) { /** * @param {number} index - * @returns {nd4j.graph.DataType} + * @returns {nd4j.graph.DType} */ nd4j.graph.FlatNode.prototype.outputTypes = function(index) { var offset = this.bb.__offset(this.bb_pos, 38); - return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DataType} */ (0); + return offset ? 
/** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DType} */ (0); }; /** @@ -686,7 +686,7 @@ nd4j.graph.FlatNode.addOutputTypes = function(builder, outputTypesOffset) { /** * @param {flatbuffers.Builder} builder - * @param {Array.} data + * @param {Array.} data * @returns {flatbuffers.Offset} */ nd4j.graph.FlatNode.createOutputTypesVector = function(builder, data) { diff --git a/libnd4j/include/graph/generated/variable_generated.h b/libnd4j/include/graph/generated/variable_generated.h index e441c17dc..ca1a705a0 100644 --- a/libnd4j/include/graph/generated/variable_generated.h +++ b/libnd4j/include/graph/generated/variable_generated.h @@ -65,8 +65,8 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::String *name() const { return GetPointer(VT_NAME); } - DataType dtype() const { - return static_cast(GetField(VT_DTYPE, 0)); + DType dtype() const { + return static_cast(GetField(VT_DTYPE, 0)); } const flatbuffers::Vector *shape() const { return GetPointer *>(VT_SHAPE); @@ -106,7 +106,7 @@ struct FlatVariableBuilder { void add_name(flatbuffers::Offset name) { fbb_.AddOffset(FlatVariable::VT_NAME, name); } - void add_dtype(DataType dtype) { + void add_dtype(DType dtype) { fbb_.AddElement(FlatVariable::VT_DTYPE, static_cast(dtype), 0); } void add_shape(flatbuffers::Offset> shape) { @@ -137,7 +137,7 @@ inline flatbuffers::Offset CreateFlatVariable( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset id = 0, flatbuffers::Offset name = 0, - DataType dtype = DataType_INHERIT, + DType dtype = DType_INHERIT, flatbuffers::Offset> shape = 0, flatbuffers::Offset ndarray = 0, int32_t device = 0, @@ -157,7 +157,7 @@ inline flatbuffers::Offset CreateFlatVariableDirect( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset id = 0, const char *name = nullptr, - DataType dtype = DataType_INHERIT, + DType dtype = DType_INHERIT, const std::vector *shape = nullptr, flatbuffers::Offset ndarray = 0, int32_t device = 0, diff --git a/libnd4j/include/graph/generated/variable_generated.js b/libnd4j/include/graph/generated/variable_generated.js index 3f128e4fc..9012af2de 100644 --- a/libnd4j/include/graph/generated/variable_generated.js +++ b/libnd4j/include/graph/generated/variable_generated.js @@ -76,11 +76,11 @@ nd4j.graph.FlatVariable.prototype.name = function(optionalEncoding) { }; /** - * @returns {nd4j.graph.DataType} + * @returns {nd4j.graph.DType} */ nd4j.graph.FlatVariable.prototype.dtype = function() { var offset = this.bb.__offset(this.bb_pos, 8); - return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT; + return offset ? 
/** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT; }; /** @@ -150,10 +150,10 @@ nd4j.graph.FlatVariable.addName = function(builder, nameOffset) { /** * @param {flatbuffers.Builder} builder - * @param {nd4j.graph.DataType} dtype + * @param {nd4j.graph.DType} dtype */ nd4j.graph.FlatVariable.addDtype = function(builder, dtype) { - builder.addFieldInt8(2, dtype, nd4j.graph.DataType.INHERIT); + builder.addFieldInt8(2, dtype, nd4j.graph.DType.INHERIT); }; /** diff --git a/libnd4j/include/graph/impl/FlatUtils.cpp b/libnd4j/include/graph/impl/FlatUtils.cpp index bc8ff7e33..ec76cb4d2 100644 --- a/libnd4j/include/graph/impl/FlatUtils.cpp +++ b/libnd4j/include/graph/impl/FlatUtils.cpp @@ -111,7 +111,7 @@ namespace nd4j { auto bo = static_cast(BitwiseUtils::asByteOrder()); - return CreateFlatArray(builder, fShape, fBuffer, static_cast(array.dataType()), bo); + return CreateFlatArray(builder, fShape, fBuffer, static_cast(array.dataType()), bo); } } } \ No newline at end of file diff --git a/libnd4j/include/graph/impl/Variable.cpp b/libnd4j/include/graph/impl/Variable.cpp index 6dd881f11..e54112783 100644 --- a/libnd4j/include/graph/impl/Variable.cpp +++ b/libnd4j/include/graph/impl/Variable.cpp @@ -219,7 +219,7 @@ namespace nd4j { throw std::runtime_error("CONSTANT variable must have NDArray bundled"); auto ar = flatVariable->ndarray(); - if (ar->dtype() == DataType_UTF8) { + if (ar->dtype() == DType_UTF8) { _ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar); } else { _ndarray = nd4j::graph::FlatUtils::fromFlatArray(ar); @@ -320,7 +320,7 @@ namespace nd4j { auto fBuffer = builder.CreateVector(array->asByteVector()); // packing array - auto fArray = CreateFlatArray(builder, fShape, fBuffer, (nd4j::graph::DataType) array->dataType()); + auto fArray = CreateFlatArray(builder, fShape, fBuffer, (nd4j::graph::DType) array->dataType()); // packing id/index of this var auto fVid = CreateIntPair(builder, this->_id, this->_index); @@ -331,7 +331,7 @@ namespace nd4j { stringId = builder.CreateString(this->_name); // returning array - return CreateFlatVariable(builder, fVid, stringId, static_cast(array->dataType()), 0, fArray); + return CreateFlatVariable(builder, fVid, stringId, static_cast(array->dataType()), 0, fArray); } else { throw std::runtime_error("Variable::asFlatVariable isn't possible for NDArrayList"); } diff --git a/libnd4j/include/graph/scheme/array.fbs b/libnd4j/include/graph/scheme/array.fbs index f415ffb08..91e338500 100644 --- a/libnd4j/include/graph/scheme/array.fbs +++ b/libnd4j/include/graph/scheme/array.fbs @@ -23,7 +23,7 @@ enum ByteOrder:byte { } // DataType for arrays/buffers -enum DataType:byte { +enum DType:byte { INHERIT, BOOL, FLOAT8, @@ -49,7 +49,7 @@ enum DataType:byte { table FlatArray { shape:[long]; // shape in Nd4j format buffer:[byte]; // byte buffer with data - dtype:DataType; // data type of actual data within buffer + dtype:DType; // data type of actual data within buffer byteOrder:ByteOrder; // byte order of buffer } diff --git a/libnd4j/include/graph/scheme/node.fbs b/libnd4j/include/graph/scheme/node.fbs index 6117e7125..930702f6d 100644 --- a/libnd4j/include/graph/scheme/node.fbs +++ b/libnd4j/include/graph/scheme/node.fbs @@ -48,7 +48,7 @@ table FlatNode { opName:string; //Used to help resolving the class. 
In a few cases, multiple classes/opNames are mapped to same hash, and might have different config/properties/differentiability // output data types (optional) - outputTypes:[DataType]; + outputTypes:[DType]; //Scalar value - used for scalar ops. Should be single value only. scalar:FlatArray; diff --git a/libnd4j/include/graph/scheme/uigraphstatic.fbs b/libnd4j/include/graph/scheme/uigraphstatic.fbs index cce0da4ad..814c28fa5 100644 --- a/libnd4j/include/graph/scheme/uigraphstatic.fbs +++ b/libnd4j/include/graph/scheme/uigraphstatic.fbs @@ -51,7 +51,7 @@ table UIVariable { id:IntPair; //Existing IntPair class name:string; type:VarType; //Use existing VarType: VARIABLE, CONSTANT, ARRAY, PLACEHOLDER - datatype:DataType; + datatype:DType; shape:[long]; controlDeps:[string]; //Input control dependencies: variable x -> this outputOfOp:string; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of diff --git a/libnd4j/include/graph/scheme/variable.fbs b/libnd4j/include/graph/scheme/variable.fbs index 43f343c7c..31eafafa7 100644 --- a/libnd4j/include/graph/scheme/variable.fbs +++ b/libnd4j/include/graph/scheme/variable.fbs @@ -30,7 +30,7 @@ enum VarType:byte { table FlatVariable { id:IntPair; // ID of the Variable, in format of IntPair.first is node Id, IntPair.second is output index of the node name:string; // symbolic ID of the Variable (if defined) - dtype:DataType; + dtype:DType; shape:[long]; // shape is absolutely optional. either shape or ndarray might be set ndarray:FlatArray; diff --git a/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp b/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp index cf9f2914e..49dd0657d 100644 --- a/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp @@ -94,10 +94,10 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) { auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector()); auto fBuffer = builder.CreateVector(array->asByteVector()); - auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_FLOAT); + auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_FLOAT); auto fVid = CreateIntPair(builder, -1); - auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray); + auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, 0, fArray); std::vector outputs1, outputs2, inputs1, inputs2; outputs1.push_back(2); @@ -265,7 +265,7 @@ TEST_F(FlatBuffersTest, ExplicitOutputTest1) { auto name1 = builder.CreateString("wow1"); - auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DataType::FLOAT); + auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DType::FLOAT); std::vector> variables_vector; variables_vector.push_back(fXVar); diff --git a/libnd4j/tests_cpu/layers_tests/VariableTests.cpp b/libnd4j/tests_cpu/layers_tests/VariableTests.cpp index e31347b0e..fcdd1db3c 100644 --- a/libnd4j/tests_cpu/layers_tests/VariableTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/VariableTests.cpp @@ -73,9 +73,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_1) { auto fBuffer = builder.CreateVector(vec); auto fVid = CreateIntPair(builder, 1, 12); - auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_FLOAT); + auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_FLOAT); - auto flatVar = CreateFlatVariable(builder, fVid, 
0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray); + auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, 0, fArray); builder.Finish(flatVar); @@ -107,9 +107,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_2) { auto fBuffer = builder.CreateVector(vec); auto fVid = CreateIntPair(builder, 1, 12); - auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_DOUBLE); + auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_DOUBLE); - auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_DOUBLE, 0, fArray); + auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_DOUBLE, 0, fArray); builder.Finish(flatVar); @@ -144,9 +144,9 @@ TEST_F(VariableTests, Test_FlatVariableDataType_3) { auto fBuffer = builder.CreateVector(vec); auto fVid = CreateIntPair(builder, 1, 12); - auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DataType::DataType_DOUBLE); + auto fArray = CreateFlatArray(builder, fShape, fBuffer, nd4j::graph::DType::DType_DOUBLE); - auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_DOUBLE, 0, fArray); + auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_DOUBLE, 0, fArray); builder.Finish(flatVar); @@ -180,7 +180,7 @@ TEST_F(VariableTests, Test_FlatVariableDataType_4) { auto fShape = builder.CreateVector(original.getShapeAsFlatVector()); auto fVid = CreateIntPair(builder, 37, 12); - auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER); + auto flatVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DType::DType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER); builder.Finish(flatVar); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java index 6faf29bfc..cce38cf24 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java @@ -31,7 +31,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.VariableType; import org.nd4j.autodiff.samediff.internal.Variable; import org.nd4j.base.Preconditions; -import org.nd4j.graph.DataType; +import org.nd4j.graph.DType; import org.nd4j.graph.FlatArray; import org.nd4j.graph.FlatNode; import org.nd4j.graph.FlatProperties; @@ -66,33 +66,33 @@ public class FlatBuffersMapper { public static byte getDataTypeAsByte(@NonNull org.nd4j.linalg.api.buffer.DataType type) { switch (type) { case FLOAT: - return DataType.FLOAT; + return DType.FLOAT; case DOUBLE: - return DataType.DOUBLE; + return DType.DOUBLE; case HALF: - return DataType.HALF; + return DType.HALF; case INT: - return DataType.INT32; + return DType.INT32; case LONG: - return DataType.INT64; + return DType.INT64; case BOOL: - return DataType.BOOL; + return DType.BOOL; case SHORT: - return DataType.INT16; + return DType.INT16; case BYTE: - return DataType.INT8; + return DType.INT8; case UBYTE: - return DataType.UINT8; + return DType.UINT8; case UTF8: - return DataType.UTF8; + return DType.UTF8; case UINT16: - return DataType.UINT16; + return DType.UINT16; case UINT32: - return DataType.UINT32; + return DType.UINT32; case 
UINT64: - return DataType.UINT64; + return DType.UINT64; case BFLOAT16: - return DataType.BFLOAT16; + return DType.BFLOAT16; default: throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]"); } @@ -102,33 +102,33 @@ public class FlatBuffersMapper { * This method converts enums for DataType */ public static org.nd4j.linalg.api.buffer.DataType getDataTypeFromByte(byte val) { - if (val == DataType.FLOAT) { + if (val == DType.FLOAT) { return org.nd4j.linalg.api.buffer.DataType.FLOAT; - } else if (val == DataType.DOUBLE) { + } else if (val == DType.DOUBLE) { return org.nd4j.linalg.api.buffer.DataType.DOUBLE; - } else if (val == DataType.HALF) { + } else if (val == DType.HALF) { return org.nd4j.linalg.api.buffer.DataType.HALF; - } else if (val == DataType.INT32) { + } else if (val == DType.INT32) { return org.nd4j.linalg.api.buffer.DataType.INT; - } else if (val == DataType.INT64) { + } else if (val == DType.INT64) { return org.nd4j.linalg.api.buffer.DataType.LONG; - } else if (val == DataType.INT8) { + } else if (val == DType.INT8) { return org.nd4j.linalg.api.buffer.DataType.BYTE; - } else if (val == DataType.BOOL) { + } else if (val == DType.BOOL) { return org.nd4j.linalg.api.buffer.DataType.BOOL; - } else if (val == DataType.UINT8) { + } else if (val == DType.UINT8) { return org.nd4j.linalg.api.buffer.DataType.UBYTE; - } else if (val == DataType.INT16) { + } else if (val == DType.INT16) { return org.nd4j.linalg.api.buffer.DataType.SHORT; - } else if (val == DataType.UTF8) { + } else if (val == DType.UTF8) { return org.nd4j.linalg.api.buffer.DataType.UTF8; - } else if (val == DataType.UINT16) { + } else if (val == DType.UINT16) { return org.nd4j.linalg.api.buffer.DataType.UINT16; - } else if (val == DataType.UINT32) { + } else if (val == DType.UINT32) { return org.nd4j.linalg.api.buffer.DataType.UINT32; - } else if (val == DataType.UINT64) { + } else if (val == DType.UINT64) { return org.nd4j.linalg.api.buffer.DataType.UINT64; - } else if (val == DataType.BFLOAT16){ + } else if (val == DType.BFLOAT16){ return org.nd4j.linalg.api.buffer.DataType.BFLOAT16; } else { throw new RuntimeException("Unknown datatype: " + val); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/DataType.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/DType.java similarity index 95% rename from nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/DataType.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/DType.java index 17a0752f0..2617ce8f6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/DataType.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/DType.java @@ -2,8 +2,8 @@ package org.nd4j.graph; -public final class DataType { - private DataType() { } +public final class DType { + private DType() { } public static final byte INHERIT = 0; public static final byte BOOL = 1; public static final byte FLOAT8 = 2; From e9454b888298646378c6ac552921371553f26ec5 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 4 Sep 2019 00:44:01 -0700 Subject: [PATCH 15/19] SDCNN cleanup pass (#230) * SDCNN cleanup Signed-off-by: Ryan Nett * NonNull annotations Signed-off-by: Ryan Nett * better javadoc, NonNull fix for sconv Signed-off-by: Ryan Nett * update builders to fix names Signed-off-by: Ryan Nett * fixes Signed-off-by: Ryan Nett * even more fixes Signed-off-by: Ryan Nett * fix for null bias Signed-off-by: 
Ryan Nett --- .../DifferentialFunctionFactory.java | 48 +- .../org/nd4j/autodiff/samediff/ops/SDCNN.java | 498 ++++++++---------- .../impl/layers/convolution/AvgPooling2D.java | 21 +- .../ops/impl/layers/convolution/Conv1D.java | 24 +- .../ops/impl/layers/convolution/Conv2D.java | 31 +- .../layers/convolution/Conv2DDerivative.java | 4 +- .../ops/impl/layers/convolution/Conv3D.java | 38 +- .../layers/convolution/Conv3DDerivative.java | 4 +- .../ops/impl/layers/convolution/DeConv2D.java | 28 +- .../convolution/DeConv2DDerivative.java | 4 +- .../impl/layers/convolution/DeConv2DTF.java | 24 +- .../ops/impl/layers/convolution/DeConv3D.java | 14 +- .../layers/convolution/DepthwiseConv2D.java | 22 +- .../LocalResponseNormalization.java | 20 +- .../LocalResponseNormalizationDerivative.java | 4 +- .../impl/layers/convolution/MaxPooling2D.java | 15 +- .../impl/layers/convolution/Pooling2D.java | 35 +- .../convolution/Pooling2DDerivative.java | 9 +- .../ops/impl/layers/convolution/SConv2D.java | 15 +- .../layers/convolution/SConv2DDerivative.java | 4 +- .../nd4j/linalg/convolution/Convolution.java | 32 +- .../opvalidation/LayerOpValidation.java | 23 +- 22 files changed, 466 insertions(+), 451 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index ac017beef..3086b0f1b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -469,7 +469,7 @@ public class DifferentialFunctionFactory { * @return */ public SDVariable localResponseNormalization(SDVariable input, LocalResponseNormalizationConfig lrnConfig) { - LocalResponseNormalization lrn = LocalResponseNormalization.builder() + LocalResponseNormalization lrn = LocalResponseNormalization.sameDiffBuilder() .inputFunctions(new SDVariable[]{input}) .sameDiff(sameDiff()) .config(lrnConfig) @@ -487,7 +487,7 @@ public class DifferentialFunctionFactory { * @return */ public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) { - Conv1D conv1D = Conv1D.builder() + Conv1D conv1D = Conv1D.sameDiffBuilder() .inputFunctions(new SDVariable[]{input, weights}) .sameDiff(sameDiff()) .config(conv1DConfig) @@ -496,6 +496,34 @@ public class DifferentialFunctionFactory { return conv1D.outputVariable(); } + /** + * Conv1d operation. + * + * @param input the inputs to conv1d + * @param weights conv1d weights + * @param bias conv1d bias + * @param conv1DConfig the configuration + * @return + */ + public SDVariable conv1d(SDVariable input, SDVariable weights, SDVariable bias, Conv1DConfig conv1DConfig) { + + SDVariable[] args; + + if(bias == null){ + args = new SDVariable[]{input, weights}; + } else { + args = new SDVariable[]{input, weights, bias}; + } + + Conv1D conv1D = Conv1D.sameDiffBuilder() + .inputFunctions(args) + .sameDiff(sameDiff()) + .config(conv1DConfig) + .build(); + + return conv1D.outputVariable(); + } + /** * Conv2d operation. 
* @@ -504,7 +532,7 @@ public class DifferentialFunctionFactory { * @return */ public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) { - Conv2D conv2D = Conv2D.builder() + Conv2D conv2D = Conv2D.sameDiffBuilder() .inputFunctions(inputs) .sameDiff(sameDiff()) .config(conv2DConfig) @@ -530,7 +558,7 @@ public class DifferentialFunctionFactory { * @return */ public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { - AvgPooling2D avgPooling2D = AvgPooling2D.builder() + AvgPooling2D avgPooling2D = AvgPooling2D.sameDiffBuilder() .input(input) .sameDiff(sameDiff()) .config(pooling2DConfig) @@ -547,7 +575,7 @@ public class DifferentialFunctionFactory { * @return */ public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { - MaxPooling2D maxPooling2D = MaxPooling2D.builder() + MaxPooling2D maxPooling2D = MaxPooling2D.sameDiffBuilder() .input(input) .sameDiff(sameDiff()) .config(pooling2DConfig) @@ -590,7 +618,7 @@ public class DifferentialFunctionFactory { * @return */ public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) { - SConv2D sconv2D = SConv2D.sBuilder() + SConv2D sconv2D = SConv2D.sameDiffSBuilder() .inputFunctions(inputs) .sameDiff(sameDiff()) .conv2DConfig(conv2DConfig) @@ -609,7 +637,7 @@ public class DifferentialFunctionFactory { * @return */ public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) { - SConv2D depthWiseConv2D = SConv2D.sBuilder() + SConv2D depthWiseConv2D = SConv2D.sameDiffSBuilder() .inputFunctions(inputs) .sameDiff(sameDiff()) .conv2DConfig(depthConv2DConfig) @@ -627,7 +655,7 @@ public class DifferentialFunctionFactory { * @return */ public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) { - DeConv2D deconv2D = DeConv2D.builder() + DeConv2D deconv2D = DeConv2D.sameDiffBuilder() .inputs(inputs) .sameDiff(sameDiff()) .config(deconv2DConfig) @@ -654,9 +682,9 @@ public class DifferentialFunctionFactory { * @return */ public SDVariable conv3d(SDVariable[] inputs, Conv3DConfig conv3DConfig) { - Conv3D conv3D = Conv3D.builder() + Conv3D conv3D = Conv3D.sameDiffBuilder() .inputFunctions(inputs) - .conv3DConfig(conv3DConfig) + .config(conv3DConfig) .sameDiff(sameDiff()) .build(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java index fab50a937..7b56ca266 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java @@ -16,6 +16,7 @@ package org.nd4j.autodiff.samediff.ops; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*; @@ -38,14 +39,9 @@ public class SDCNN extends SDOps { } /** - * 2D Convolution layer operation - average pooling 2d - * - * @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param pooling2DConfig the configuration for - * @return Result after applying average pooling on the input + * See {@link #avgPooling2d(String, SDVariable, Pooling2DConfig)}. 
*/ - public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { + public SDVariable avgPooling2d(@NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) { return avgPooling2d(null, input, pooling2DConfig); } @@ -58,22 +54,16 @@ public class SDCNN extends SDOps { * @param pooling2DConfig the configuration * @return Result after applying average pooling on the input */ - public SDVariable avgPooling2d(String name, SDVariable input, Pooling2DConfig pooling2DConfig) { + public SDVariable avgPooling2d(String name, @NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) { validateFloatingPoint("avgPooling2d", input); SDVariable ret = f().avgPooling2d(input, pooling2DConfig); return updateVariableNameAndReference(ret, name); } /** - * 3D convolution layer operation - average pooling 3d - * - * @param input the input to average pooling 3d operation - 5d activations in NCDHW format - * (shape [minibatch, channels, depth, height, width]) or NDHWC format - * (shape [minibatch, depth, height, width, channels]) - * @param pooling3DConfig the configuration - * @return Result after applying average pooling on the input + * See {@link #avgPooling3d(String, SDVariable, Pooling3DConfig)}. */ - public SDVariable avgPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) { + public SDVariable avgPooling3d(@NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) { return avgPooling3d(null, input, pooling3DConfig); } @@ -87,7 +77,7 @@ public class SDCNN extends SDOps { * @param pooling3DConfig the configuration * @return Result after applying average pooling on the input */ - public SDVariable avgPooling3d(String name, SDVariable input, Pooling3DConfig pooling3DConfig) { + public SDVariable avgPooling3d(String name, @NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) { validateFloatingPoint("avgPooling3d", input); SDVariable ret = f().avgPooling3d(input, pooling3DConfig); return updateVariableNameAndReference(ret, name); @@ -96,7 +86,7 @@ public class SDCNN extends SDOps { /** * @see #batchToSpace(String, SDVariable, int[], int[][]) */ - public SDVariable batchToSpace(SDVariable x, int[] blocks, int[][] crops) { + public SDVariable batchToSpace(@NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] crops) { return batchToSpace(null, x, blocks, crops); } @@ -111,7 +101,7 @@ public class SDCNN extends SDOps { * @return Output variable * @see #spaceToBatch(String, SDVariable, int[], int[][]) */ - public SDVariable batchToSpace(String name, SDVariable x, int[] blocks, int[][] crops) { + public SDVariable batchToSpace(String name, @NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] crops) { validateNumerical("batchToSpace", x); SDVariable ret = f().batchToSpace(x, blocks, crops); return updateVariableNameAndReference(ret, name); @@ -119,14 +109,9 @@ public class SDCNN extends SDOps { /** - * col2im operation for use in 2D convolution operations. Outputs a 4d array with shape - * [minibatch, inputChannels, height, width] - * - * @param in Input - rank 6 input with shape [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth] - * @param config Convolution configuration for the col2im operation - * @return Col2Im output variable + * See {@link #col2Im(String, SDVariable, Conv2DConfig)}. 
*/ - public SDVariable col2Im(SDVariable in, Conv2DConfig config) { + public SDVariable col2Im(@NonNull SDVariable in, @NonNull Conv2DConfig config) { return col2Im(null, in, config); } @@ -139,33 +124,22 @@ public class SDCNN extends SDOps { * @param config Convolution configuration for the col2im operation * @return Col2Im output variable */ - public SDVariable col2Im(String name, SDVariable in, Conv2DConfig config) { + public SDVariable col2Im(String name, @NonNull SDVariable in, @NonNull Conv2DConfig config) { SDVariable ret = f().col2Im(in, config); return updateVariableNameAndReference(ret, name); } /** - * 1D Convolution layer operation - Conv1d - * - * @param input the input array/activations for the conv1d op - * @param weights weights for conv1d op - rank 3 array with values [kernelSize, inputChannels, outputChannels] - * @param conv1DConfig the configuration - * @return + * See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}, no bias. */ - public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) { - return conv1d(null, input, weights, conv1DConfig); + public SDVariable conv1d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv1DConfig conv1DConfig) { + return conv1d((String) null, input, weights, conv1DConfig); } /** - * Conv1d operation. - * - * @param name name of the operation in SameDiff - * @param input the inputs to conv1d - * @param weights weights for conv1d op - rank 3 array with values [kernelSize, inputChannels, outputChannels] - * @param conv1DConfig the configuration - * @return + * See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}, no bias. */ - public SDVariable conv1d(String name, SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) { + public SDVariable conv1d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv1DConfig conv1DConfig) { validateFloatingPoint("conv1d", input); validateFloatingPoint("conv1d", weights); SDVariable ret = f().conv1d(input, weights, conv1DConfig); @@ -173,21 +147,55 @@ public class SDCNN extends SDOps { } /** - * 2D Convolution operation (without bias) - * - * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels] - * @param config Conv2DConfig configuration - * @return result of conv2d op + * See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}. */ - public SDVariable conv2d(SDVariable layerInput, SDVariable weights, Conv2DConfig config) { + public SDVariable conv1d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) { + return conv1d(null, input, weights, bias, conv1DConfig); + } + + /** + * Conv1d operation. + * + * @param name name of the operation in SameDiff + * @param input the inputs to conv1d + * @param weights weights for conv1d op - rank 3 array with shape [kernelSize, inputChannels, outputChannels] + * @param bias bias for conv1d op - rank 1 array with shape [outputChannels]. May be null. 
+ * @param conv1DConfig the configuration + * @return + */ + public SDVariable conv1d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) { + validateFloatingPoint("conv1d", input); + validateFloatingPoint("conv1d", weights); + validateFloatingPoint("conv1d", bias); + SDVariable ret = f().conv1d(input, weights, bias, conv1DConfig); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. + */ + public SDVariable conv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull Conv2DConfig config) { return conv2d(layerInput, weights, null, config); } + /** + * See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. + */ + public SDVariable conv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull Conv2DConfig config) { + return conv2d(name, layerInput, weights, null, config); + } + + /** + * See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}. + */ + public SDVariable conv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig config) { + return conv2d(null, layerInput, weights, bias, config); + } + /** * 2D Convolution operation with optional bias * + * @param name name of the operation in SameDiff * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels] @@ -195,7 +203,7 @@ public class SDCNN extends SDOps { * @param config Conv2DConfig configuration * @return result of conv2d op */ - public SDVariable conv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, Conv2DConfig config) { + public SDVariable conv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig config) { validateFloatingPoint("conv2d", "input", layerInput); validateFloatingPoint("conv2d", "weights", weights); validateFloatingPoint("conv2d", "bias", bias); @@ -204,18 +212,13 @@ public class SDCNN extends SDOps { arr[1] = weights; if (bias != null) arr[2] = bias; - return conv2d(arr, config); + return conv2d(name, arr, config); } /** - * 2D Convolution operation with optional bias - * - * @param inputs an array with either 2 elements (layerInput, weights) or 3 elements (layerInput, weights, bias) as - * described in {@link #conv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)} - * @param config Conv2DConfig configuration - * @return result of convolution 2d operation + * See {@link #conv2d(String, SDVariable[], Conv2DConfig)}. 
*/ - public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig config) { + public SDVariable conv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig config) { return conv2d(null, inputs, config); } @@ -228,7 +231,7 @@ public class SDCNN extends SDOps { * @param config Conv2DConfig configuration * @return result of convolution 2d operation */ - public SDVariable conv2d(String name, SDVariable[] inputs, Conv2DConfig config) { + public SDVariable conv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig config) { for(SDVariable v : inputs) validateNumerical("conv2d", v); SDVariable ret = f().conv2d(inputs, config); @@ -236,19 +239,26 @@ public class SDCNN extends SDOps { } /** - * Convolution 3D operation without bias - * - * @param input the input to average pooling 3d operation - 5d activations in NCDHW format - * (shape [minibatch, channels, depth, height, width]) or NDHWC format - * (shape [minibatch, depth, height, width, channels]) - * @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. - * @param conv3DConfig the configuration - * @return Conv3d output variable + * See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}, no bias. */ - public SDVariable conv3d(SDVariable input, SDVariable weights, Conv3DConfig conv3DConfig) { + public SDVariable conv3d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv3DConfig conv3DConfig) { return conv3d(null, input, weights, null, conv3DConfig); } + /** + * See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}, no bias. + */ + public SDVariable conv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv3DConfig conv3DConfig) { + return conv3d(name, input, weights, null, conv3DConfig); + } + + /** + * See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}. + */ + public SDVariable conv3d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv3DConfig conv3DConfig) { + return conv3d(null, input, weights, bias, conv3DConfig); + } + /** * Convolution 3D operation with optional bias * @@ -261,7 +271,7 @@ public class SDCNN extends SDOps { * @param conv3DConfig the configuration * @return Conv3d output variable */ - public SDVariable conv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, Conv3DConfig conv3DConfig) { + public SDVariable conv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv3DConfig conv3DConfig) { validateFloatingPoint("conv3d", "input", input); validateFloatingPoint("conv3d", "weights", weights); validateFloatingPoint("conv3d", "bias", bias); @@ -276,51 +286,30 @@ public class SDCNN extends SDOps { } /** - * Convolution 3D operation with optional bias - * - * @param input the input to average pooling 3d operation - 5d activations in NCDHW format - * (shape [minibatch, channels, depth, height, width]) or NDHWC format - * (shape [minibatch, depth, height, width, channels]) - * @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. - * @param bias Optional 1D bias array with shape [outputChannels]. May be null. - * @param conv3DConfig the configuration - * @return Conv3d output variable + * See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}, no bias. 
*/ - public SDVariable conv3d(SDVariable input, SDVariable weights, SDVariable bias, Conv3DConfig conv3DConfig) { - return conv3d(null, input, weights, bias, conv3DConfig); - } - - /** - * Convolution 3D operation without bias - * - * @param name Name of the output variable - * @param input the input to average pooling 3d operation - 5d activations in NCDHW format - * (shape [minibatch, channels, depth, height, width]) or NDHWC format - * (shape [minibatch, depth, height, width, channels]) - * @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. - * @param conv3DConfig the configuration - * @return Conv3d output variable - */ - public SDVariable conv3d(String name, SDVariable input, SDVariable weights, Conv3DConfig conv3DConfig) { - return conv3d(name, input, weights, null, conv3DConfig); - } - - /** - * 2D deconvolution operation without bias - * - * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth]. - * @param deconv2DConfig DeConv2DConfig configuration - * @return result of deconv2d op - */ - public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, DeConv2DConfig deconv2DConfig) { + public SDVariable deconv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull DeConv2DConfig deconv2DConfig) { return deconv2d(layerInput, weights, null, deconv2DConfig); } + /** + * See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}, no bias. + */ + public SDVariable deconv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull DeConv2DConfig deconv2DConfig) { + return deconv2d(name, layerInput, weights, null, deconv2DConfig); + } + + /** + * See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}. + */ + public SDVariable deconv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv2DConfig deconv2DConfig) { + return deconv2d(null, layerInput, weights, bias, deconv2DConfig); + } + /** * 2D deconvolution operation with optional bias * + * @param name name of the operation in SameDiff * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth]. 
@@ -328,7 +317,7 @@ public class SDCNN extends SDOps { * @param deconv2DConfig DeConv2DConfig configuration * @return result of deconv2d op */ - public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, DeConv2DConfig deconv2DConfig) { + public SDVariable deconv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv2DConfig deconv2DConfig) { validateFloatingPoint("deconv2d", "input", layerInput); validateFloatingPoint("deconv2d", "weights", weights); validateFloatingPoint("deconv2d", "bias", bias); @@ -337,18 +326,13 @@ public class SDCNN extends SDOps { arr[1] = weights; if (bias != null) arr[2] = bias; - return deconv2d(arr, deconv2DConfig); + return deconv2d(name, arr, deconv2DConfig); } /** - * 2D deconvolution operation with or without optional bias - * - * @param inputs Inputs to the deconvolution 2d operation - input array of length 2 (layerInput, weights) - * or length 3 (layerInput, weights, bias) as described in {@link #deconv2d(SDVariable[], DeConv2DConfig)} - * @param deconv2DConfig the configuration - * @return result of deconv2d op + * See {@link #deconv2d(String, SDVariable[], DeConv2DConfig)}. */ - public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) { + public SDVariable deconv2d(@NonNull SDVariable[] inputs, @NonNull DeConv2DConfig deconv2DConfig) { return deconv2d(null, inputs, deconv2DConfig); } @@ -361,13 +345,34 @@ public class SDCNN extends SDOps { * @param deconv2DConfig the configuration * @return result of deconv2d op */ - public SDVariable deconv2d(String name, SDVariable[] inputs, DeConv2DConfig deconv2DConfig) { + public SDVariable deconv2d(String name, @NonNull SDVariable[] inputs, @NonNull DeConv2DConfig deconv2DConfig) { for(SDVariable v : inputs) validateNumerical("deconv2d", v); SDVariable ret = f().deconv2d(inputs, deconv2DConfig); return updateVariableNameAndReference(ret, name); } + /** + * See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}, no bias. + */ + public SDVariable deconv3d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) { + return deconv3d(input, weights, null, config); + } + + /** + * See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}, no bias. + */ + public SDVariable deconv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) { + return deconv3d(name, input, weights, null, config); + } + + /** + * See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}. + */ + public SDVariable deconv3d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) { + return deconv3d(null, input, weights, bias, config); + } + /** * 3D CNN deconvolution operation with or without optional bias * @@ -377,7 +382,7 @@ public class SDCNN extends SDOps { * @param bias Bias array - optional, may be null. 
If non-null, must have shape [outputChannels] * @param config Configuration */ - public SDVariable deconv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) { + public SDVariable deconv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) { validateFloatingPoint("conv3d", input); validateFloatingPoint("conv3d", weights); validateFloatingPoint("conv3d", bias); @@ -386,41 +391,9 @@ public class SDCNN extends SDOps { } /** - * 3D CNN deconvolution operation with or without optional bias - * - * @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - * @param weights Weights array - shape [kD, kH, kW, oC, iC] - * @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels] - * @param config Configuration + * See {@link #depthToSpace(String, SDVariable, int, String)}. */ - public SDVariable deconv3d(SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) { - return deconv3d(null, input, weights, bias, config); - } - - /** - * 3D CNN deconvolution operation with no bias - * - * @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - * @param weights Weights array - shape [kD, kH, kW, oC, iC] - * @param config Configuration - */ - public SDVariable deconv3d(SDVariable input, SDVariable weights, DeConv3DConfig config) { - return deconv3d(input, weights, null, config); - } - - /** - * Convolution 2d layer batch to space operation on 4d input.
- * Reduces input channels dimension by rearranging data into a larger spatial dimensions
- * Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2] - * = [mb, 2, 4, 4] - * - * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param blockSize Block size, in the height/width dimension - * @param dataFormat Data format: "NCHW" or "NHWC" - * @return Output variable - */ - public SDVariable depthToSpace(SDVariable x, int blockSize, String dataFormat) { + public SDVariable depthToSpace(@NonNull SDVariable x, @NonNull int blockSize, @NonNull String dataFormat) { return depthToSpace(null, x, blockSize, dataFormat); } @@ -438,27 +411,36 @@ public class SDCNN extends SDOps { * @return Output variable * @see #depthToSpace(String, SDVariable, int, String) */ - public SDVariable depthToSpace(String name, SDVariable x, int blockSize, String dataFormat) { + public SDVariable depthToSpace(String name, @NonNull SDVariable x, @NonNull int blockSize, @NonNull String dataFormat) { SDVariable ret = f().depthToSpace(x, blockSize, dataFormat); return updateVariableNameAndReference(ret, name); } /** - * Depth-wise 2D convolution operation without bias - * - * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] - * @param config Conv2DConfig configuration - * @return result of conv2d op + * See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. */ - public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, Conv2DConfig config) { + public SDVariable depthWiseConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, @NonNull Conv2DConfig config) { return depthWiseConv2d(layerInput, depthWeights, null, config); } + /** + * See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. + */ + public SDVariable depthWiseConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, @NonNull Conv2DConfig config) { + return depthWiseConv2d(name, layerInput, depthWeights, null, config); + } + + /** + * See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}. + */ + public SDVariable depthWiseConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable bias, @NonNull Conv2DConfig config) { + return depthWiseConv2d(null, layerInput, depthWeights, bias, config); + } + /** * Depth-wise 2D convolution operation with optional bias * + * @param name name of the operation in SameDiff * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) * @param depthWeights Depth-wise conv2d weights. 
4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] @@ -466,7 +448,7 @@ public class SDCNN extends SDOps { * @param config Conv2DConfig configuration * @return result of depthwise conv2d op */ - public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable bias, Conv2DConfig config) { + public SDVariable depthWiseConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable bias, @NonNull Conv2DConfig config) { validateFloatingPoint("depthwiseConv2d", "input", layerInput); validateFloatingPoint("depthwiseConv2d", "depth weights", depthWeights); validateFloatingPoint("depthwiseConv2d", "bias", bias); @@ -475,19 +457,13 @@ public class SDCNN extends SDOps { arr[1] = depthWeights; if (bias != null) arr[2] = bias; - return depthWiseConv2d(arr, config); + return depthWiseConv2d(name, arr, config); } /** - * Depth-wise convolution 2D operation. - * - * @param inputs the inputs to depth-wise conv2d. An array with either 2 elements (layerInput, depthWeights) - * or 3 elements (layerInput, depthWeights, bias) as described in - * {@link #depthWiseConv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)} - * @param depthConv2DConfig the configuration - * @return result of depthwise conv2d op + * See {@link #depthWiseConv2d(String, SDVariable[], Conv2DConfig)}. */ - public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) { + public SDVariable depthWiseConv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig depthConv2DConfig) { return depthWiseConv2d(null, inputs, depthConv2DConfig); } @@ -501,7 +477,7 @@ public class SDCNN extends SDOps { * @param depthConv2DConfig the configuration * @return result of depthwise conv2d op */ - public SDVariable depthWiseConv2d(String name, SDVariable[] inputs, Conv2DConfig depthConv2DConfig) { + public SDVariable depthWiseConv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig depthConv2DConfig) { for(SDVariable v : inputs) validateFloatingPoint("depthWiseConv2d", v); SDVariable ret = f().depthWiseConv2d(inputs, depthConv2DConfig); @@ -509,17 +485,10 @@ public class SDCNN extends SDOps { } /** - * TODO doc string - * - * @param df - * @param weights - * @param strides - * @param rates - * @param isSameMode - * @return + * See {@link #dilation2D(String, SDVariable, SDVariable, int[], int[], boolean)}. */ - public SDVariable dilation2D(SDVariable df, SDVariable weights, int[] strides, - int[] rates, boolean isSameMode) { + public SDVariable dilation2D(@NonNull SDVariable df, @NonNull SDVariable weights, @NonNull int[] strides, + @NonNull int[] rates, @NonNull boolean isSameMode) { return dilation2D(null, df, weights, strides, rates, isSameMode); } @@ -534,8 +503,8 @@ public class SDCNN extends SDOps { * @param isSameMode * @return */ - public SDVariable dilation2D(String name, SDVariable df, SDVariable weights, int[] strides, - int[] rates, boolean isSameMode) { + public SDVariable dilation2D(String name, @NonNull SDVariable df, @NonNull SDVariable weights, @NonNull int[] strides, + @NonNull int[] rates, @NonNull boolean isSameMode) { SDVariable ret = f().dilation2D(df, weights, strides, rates, isSameMode); return updateVariableNameAndReference(ret, name); } @@ -555,21 +524,16 @@ public class SDCNN extends SDOps { * @param sameMode If true: use same mode padding. 
If false * @return */ - public SDVariable extractImagePatches(String name, SDVariable input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) { + public SDVariable extractImagePatches(String name, @NonNull SDVariable input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) { SDVariable ret = f().extractImagePatches(input, kH, kW, sH, sW, rH, rW, sameMode); return updateVariableNameAndReference(ret, name); } /** - * im2col operation for use in 2D convolution operations. Outputs a 6d array with shape - * [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth] - * - * @param in Input - rank 4 input with shape [minibatch, inputChannels, height, width] - * @param config Convolution configuration for the im2col operation - * @return Im2Col output variable + * See {@link #im2Col(String, SDVariable, Conv2DConfig)}. */ - public SDVariable im2Col(SDVariable in, Conv2DConfig config) { + public SDVariable im2Col(@NonNull SDVariable in, @NonNull Conv2DConfig config) { return im2Col(null, in, config); } @@ -582,20 +546,16 @@ public class SDCNN extends SDOps { * @param config Convolution configuration for the im2col operation * @return Im2Col output variable */ - public SDVariable im2Col(String name, SDVariable in, Conv2DConfig config) { + public SDVariable im2Col(String name, @NonNull SDVariable in, @NonNull Conv2DConfig config) { SDVariable ret = f().im2Col(in, config); return updateVariableNameAndReference(ret, name); } /** - * 2D convolution layer operation - local response normalization - * - * @param inputs the inputs to lrn - * @param lrnConfig the configuration - * @return + * See {@link #localResponseNormalization(String, SDVariable, LocalResponseNormalizationConfig)}. */ - public SDVariable localResponseNormalization(SDVariable inputs, LocalResponseNormalizationConfig lrnConfig) { + public SDVariable localResponseNormalization(@NonNull SDVariable inputs, @NonNull LocalResponseNormalizationConfig lrnConfig) { return localResponseNormalization(null, inputs, lrnConfig); } @@ -607,8 +567,8 @@ public class SDCNN extends SDOps { * @param lrnConfig the configuration * @return */ - public SDVariable localResponseNormalization(String name, SDVariable input, - LocalResponseNormalizationConfig lrnConfig) { + public SDVariable localResponseNormalization(String name, @NonNull SDVariable input, + @NonNull LocalResponseNormalizationConfig lrnConfig) { validateFloatingPoint("local response normalization", input); SDVariable ret = f().localResponseNormalization(input, lrnConfig); return updateVariableNameAndReference(ret, name); @@ -616,14 +576,9 @@ public class SDCNN extends SDOps { /** - * 2D Convolution layer operation - max pooling 2d - * - * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param pooling2DConfig the configuration - * @return Result after applying max pooling on the input + * See {@link #maxPooling2d(String, SDVariable, Pooling2DConfig)}. 
*/ - public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { + public SDVariable maxPooling2d(@NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) { return maxPooling2d(null, input, pooling2DConfig); } @@ -636,22 +591,16 @@ public class SDCNN extends SDOps { * @param pooling2DConfig the configuration * @return Result after applying max pooling on the input */ - public SDVariable maxPooling2d(String name, SDVariable input, Pooling2DConfig pooling2DConfig) { + public SDVariable maxPooling2d(String name, @NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) { validateNumerical("maxPooling2d", input); SDVariable ret = f().maxPooling2d(input, pooling2DConfig); return updateVariableNameAndReference(ret, name); } /** - * 3D convolution layer operation - max pooling 3d operation. - * - * @param input the input to average pooling 3d operation - 5d activations in NCDHW format - * (shape [minibatch, channels, depth, height, width]) or NDHWC format - * (shape [minibatch, depth, height, width, channels]) - * @param pooling3DConfig the configuration - * @return Result after applying max pooling on the input + * See {@link #maxPooling3d(String, SDVariable, Pooling3DConfig)}. */ - public SDVariable maxPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) { + public SDVariable maxPooling3d(@NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) { return maxPooling3d(null, input, pooling3DConfig); } @@ -665,7 +614,7 @@ public class SDCNN extends SDOps { * @param pooling3DConfig the configuration * @return Result after applying max pooling on the input */ - public SDVariable maxPooling3d(String name, SDVariable input, Pooling3DConfig pooling3DConfig) { + public SDVariable maxPooling3d(String name, @NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) { validateNumerical("maxPooling3d", input); SDVariable ret = f().maxPooling3d(input, pooling3DConfig); return updateVariableNameAndReference(ret, name); @@ -673,21 +622,30 @@ public class SDCNN extends SDOps { /** - * Separable 2D convolution operation without bias - * - * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] - * @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels] - * May be null - * @param config Conv2DConfig configuration - * @return result of separable convolution 2d operation + * See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. */ - public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable pointWeights, - Conv2DConfig config) { + public SDVariable separableConv2d(SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights, + @NonNull Conv2DConfig config) { return separableConv2d(layerInput, depthWeights, pointWeights, null, config); } + + /** + * See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. 
+ */ + public SDVariable separableConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights, + @NonNull Conv2DConfig config) { + return separableConv2d(layerInput, depthWeights, pointWeights, null, config); + } + + /** + * See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}. + */ + public SDVariable separableConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights, + SDVariable bias, @NonNull Conv2DConfig config) { + return separableConv2d(null, layerInput, depthWeights, pointWeights, bias, config); + } + /** * Separable 2D convolution operation with optional bias * @@ -700,8 +658,8 @@ public class SDCNN extends SDOps { * @param config Conv2DConfig configuration * @return result of separable convolution 2d operation */ - public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable pointWeights, - SDVariable bias, Conv2DConfig config) { + public SDVariable separableConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights, + SDVariable bias, @NonNull Conv2DConfig config) { validateFloatingPoint("separableConv2d", "input", layerInput); validateFloatingPoint("separableConv2d", "depthWeights", depthWeights); validateFloatingPoint("separableConv2d", "pointWeights", pointWeights); @@ -712,18 +670,13 @@ public class SDCNN extends SDOps { arr[2] = pointWeights; if (bias != null) arr[3] = bias; - return sconv2d(arr, config); + return sconv2d(name, arr, config); } /** - * Separable 2D convolution operation with/without optional bias - * - * @param inputs the inputs to separable conv2 operation. Should be length 3 (layerInput, depthWeights, pointWeights) - * or length 4 (layerInput, depthWeights, pointWeights, bias) as described in {@link #separableConv2d(SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)} - * @param conv2DConfig the configuration - * @return result of separable convolution 2d operation + * See {@link #sconv2d(String, SDVariable[], Conv2DConfig)}. 
*/ - public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) { + public SDVariable sconv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig conv2DConfig) { return sconv2d(null, inputs, conv2DConfig); } @@ -736,7 +689,7 @@ public class SDCNN extends SDOps { * @param conv2DConfig the configuration * @return result of separable convolution 2d operation */ - public SDVariable sconv2d(String name, SDVariable[] inputs, Conv2DConfig conv2DConfig) { + public SDVariable sconv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig conv2DConfig) { for(SDVariable v : inputs) validateFloatingPoint("sconv2d", v); SDVariable ret = f().sconv2d(inputs, conv2DConfig); @@ -747,7 +700,7 @@ public class SDCNN extends SDOps { /** * @see #spaceToBatch(String, SDVariable, int[], int[][]) */ - public SDVariable spaceToBatch(SDVariable x, int[] blocks, int[][] padding) { + public SDVariable spaceToBatch(@NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] padding) { return spaceToBatch(null, x, blocks, padding); } @@ -762,7 +715,7 @@ public class SDCNN extends SDOps { * @return Output variable * @see #batchToSpace(String, SDVariable, int[], int[][]) */ - public SDVariable spaceToBatch(String name, SDVariable x, int[] blocks, int[][] padding) { + public SDVariable spaceToBatch(String name, @NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] padding) { SDVariable ret = f().spaceToBatch(x, blocks, padding); return updateVariableNameAndReference(ret, name); } @@ -770,7 +723,7 @@ public class SDCNN extends SDOps { /** * @see #spaceToDepth(String, SDVariable, int, String) */ - public SDVariable spaceToDepth(SDVariable x, int blockSize, String dataFormat) { + public SDVariable spaceToDepth(@NonNull SDVariable x, int blockSize, @NonNull String dataFormat) { return spaceToDepth(null, x, blockSize, dataFormat); } @@ -788,23 +741,39 @@ public class SDCNN extends SDOps { * @return Output variable * @see #depthToSpace(String, SDVariable, int, String) */ - public SDVariable spaceToDepth(String name, SDVariable x, int blockSize, String dataFormat) { + public SDVariable spaceToDepth(String name, @NonNull SDVariable x, int blockSize, @NonNull String dataFormat) { SDVariable ret = f().spaceToDepth(x, blockSize, dataFormat); return updateVariableNameAndReference(ret, name); } /** - * 2D Convolution layer operation - Upsampling 2d with same scale for both dimensions. NCHW input format. + * See {@link #upsampling2d(String, SDVariable, boolean, int, int)}, + * scale is used for both height and width dimensions. * - * @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) - * @param scale Scale to upsample in both H and W dimensions - * @return Upsampled input + * @param scale The scale for both height and width dimensions. */ - public SDVariable upsampling2d(SDVariable input, int scale) { + public SDVariable upsampling2d(@NonNull SDVariable input, int scale) { return upsampling2d(null, input, true, scale, scale); } + /** + * See {@link #upsampling2d(String, SDVariable, boolean, int, int)}, + * scale is used for both height and width dimensions. + * + * @param scale The scale for both height and width dimensions. + */ + public SDVariable upsampling2d(String name, @NonNull SDVariable input, int scale) { + return upsampling2d(name, input, true, scale, scale); + } + + /** + * See {@link #upsampling2d(String, SDVariable, boolean, int, int)}. 
+ */ + public SDVariable upsampling2d(@NonNull SDVariable input, boolean nchw, int scaleH, int scaleW) { + return upsampling2d(null, input, nchw, scaleH, scaleW); + } + /** * 2D Convolution layer operation - Upsampling 2d * @@ -814,33 +783,8 @@ public class SDCNN extends SDOps { * @param scaleW Scale to upsample in width dimension * @return Upsampled input */ - public SDVariable upsampling2d(String name, SDVariable input, boolean nchw, int scaleH, int scaleW) { + public SDVariable upsampling2d(String name, @NonNull SDVariable input, boolean nchw, int scaleH, int scaleW) { SDVariable ret = f().upsampling2d(input, nchw, scaleH, scaleW); return updateVariableNameAndReference(ret, name); } - - /** - * 2D Convolution layer operation - Upsampling 2d with same scale for both dimensions. NCHW input format. - * - * @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) - * @param scale Scale to upsample in both H and W dimensions - * @return Upsampled input - */ - public SDVariable upsampling2d(String name, SDVariable input, int scale) { - return upsampling2d(name, input, true, scale, scale); - } - - /** - * 2D Convolution layer operation - Upsampling 2d - * - * @param input Input - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) - * or NHWC format (shape [minibatch, height, width, channels]) - * @param nchw If true: input is in NCHW (minibatch, channels, height, width) format. False: NHWC format - * @param scaleH Scale to upsample in height dimension - * @param scaleW Scale to upsample in width dimension - * @return Upsampled input - */ - public SDVariable upsampling2d(SDVariable input, boolean nchw, int scaleH, int scaleW) { - return upsampling2d(null, input, nchw, scaleH, scaleW); - } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java index ac13c6224..2f295cc6a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Builder; import lombok.Getter; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; import onnx.Onnx; @@ -53,19 +54,19 @@ public class AvgPooling2D extends DynamicCustomOp { } - @Builder(builderMethodName = "builder") - public AvgPooling2D(SameDiff sameDiff, SDVariable input, INDArray arrayInput, INDArray arrayOutput, Pooling2DConfig config) { - super(null, sameDiff, new SDVariable[]{input}, false); - if (arrayInput != null) { - addInputArgument(arrayInput); - } - if (arrayOutput != null) { - addOutputArgument(arrayOutput); - } + @Builder(builderMethodName = "sameDiffBuilder") + public AvgPooling2D(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) { + super(sameDiff, new SDVariable[]{input}); config.setType(Pooling2D.Pooling2DType.AVG); + this.config = config; + addArgs(); + } + + public AvgPooling2D(@NonNull INDArray input, INDArray output, @NonNull Pooling2DConfig config){ + super(new INDArray[]{input}, wrapOrNull(output)); + config.setType(Pooling2D.Pooling2DType.AVG); - this.sameDiff = sameDiff; this.config = config; 
addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java index 5ae2ac144..2fc814fb3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Builder; import lombok.Getter; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; import onnx.Onnx; @@ -39,6 +40,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; import org.nd4j.linalg.util.ArrayUtil; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -59,18 +61,28 @@ public class Conv1D extends DynamicCustomOp { protected Conv1DConfig config; private static final String INVALID_CONFIGURATION = "Invalid Conv1D configuration : s = %s p = %s "; - @Builder(builderMethodName = "builder") + @Builder(builderMethodName = "sameDiffBuilder") public Conv1D(SameDiff sameDiff, SDVariable[] inputFunctions, - INDArray[] inputArrays, INDArray[] outputs, Conv1DConfig config) { - super(null, inputArrays, outputs); - this.sameDiff = sameDiff; + super(sameDiff, inputFunctions); + initConfig(config); + } + + public Conv1D(INDArray[] inputs, INDArray[] outputs, Conv1DConfig config){ + super(inputs, outputs); + + initConfig(config); + } + + public Conv1D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv1DConfig config){ + this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); + } + + private void initConfig(Conv1DConfig config){ this.config = config; Preconditions.checkState(config.getS() >= 1 && config.getP() >= 0, INVALID_CONFIGURATION, config.getS(), config.getP()); addArgs(); - sameDiff.putOpForId(this.getOwnName(), this); - sameDiff.addArgsFor(inputFunctions, this); } protected void addArgs() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java index 04db5874c..5e077e3fc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Builder; import lombok.Getter; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; import onnx.Onnx; @@ -56,23 +57,32 @@ public class Conv2D extends DynamicCustomOp { protected Conv2DConfig config; private static final String INVALID_CONFIGURATION = "Invalid Conv2D configuration : sW = %s pH = %s dW = %s "; - @Builder(builderMethodName = "builder") + @Builder(builderMethodName = "sameDiffBuilder") public Conv2D(SameDiff 
sameDiff, SDVariable[] inputFunctions, - INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig config) { - super(null, inputArrays, outputs); - this.sameDiff = sameDiff; + super(sameDiff, inputFunctions); + + initConfig(config); + } + + public Conv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){ + super(inputs, outputs); + + initConfig(config); + } + + public Conv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){ + this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); + } + + protected void initConfig(Conv2DConfig config){ this.config = config; Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1, - INVALID_CONFIGURATION, - config.getSH(), config.getPH(), config.getDW()); + INVALID_CONFIGURATION, + config.getSH(), config.getPH(), config.getDW()); addArgs(); - if(sameDiff != null) { - sameDiff.putOpForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point - sameDiff.addArgsFor(inputFunctions, this); - } } protected void addArgs() { @@ -252,7 +262,6 @@ public class Conv2D extends DynamicCustomOp { Conv2DDerivative conv2DDerivative = Conv2DDerivative.derivativeBuilder() .sameDiff(sameDiff) .config(config) - .outputs(outputArguments()) .inputFunctions(inputs.toArray(new SDVariable[inputs.size()])) .build(); List ret = Arrays.asList(conv2DDerivative.outputVariables()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2DDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2DDerivative.java index 8ccbd84eb..cd5ab6556 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2DDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2DDerivative.java @@ -37,8 +37,8 @@ import java.util.List; public class Conv2DDerivative extends Conv2D { @Builder(builderMethodName = "derivativeBuilder") - public Conv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig config) { - super(sameDiff, inputFunctions, inputArrays, outputs, config); + public Conv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig config) { + super(sameDiff, inputFunctions, config); } public Conv2DDerivative() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java index 810974103..8c4e40e8a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Builder; import lombok.Getter; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.nd4j.autodiff.functions.DifferentialFunction; @@ -33,6 +34,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import 
org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -55,25 +57,27 @@ public class Conv3D extends DynamicCustomOp { public Conv3D() { } - @Builder(builderMethodName = "builder") - public Conv3D(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, - Conv3DConfig conv3DConfig) { - super(null, sameDiff, inputFunctions, false); - setSameDiff(sameDiff); + @Builder(builderMethodName = "sameDiffBuilder") + public Conv3D(SameDiff sameDiff, SDVariable[] inputFunctions, Conv3DConfig config) { + super(sameDiff, inputFunctions); + initConfig(config); + } - if (inputs != null) - addInputArgument(inputs); - if (outputs != null) - addOutputArgument(outputs); - this.config = conv3DConfig; + public Conv3D(INDArray[] inputs, INDArray[] outputs, Conv3DConfig config){ + super(inputs, outputs); + initConfig(config); + } + + public Conv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv3DConfig config){ + this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); + } + + private void initConfig(Conv3DConfig config){ + this.config = config; Preconditions.checkState(config.getSW() >= 1 && config.getPH() >= 0 && config.getDW() >= 1, - INVALID_CONFIGURATION, - config.getSW(), config.getPH(), config.getDW()); + INVALID_CONFIGURATION, + config.getSW(), config.getPH(), config.getDW()); addArgs(); - - - //for (val arg: iArgs()) - // System.out.println(getIArgument(arg)); } @@ -259,8 +263,6 @@ public class Conv3D extends DynamicCustomOp { inputs.add(f1.get(0)); Conv3DDerivative conv3DDerivative = Conv3DDerivative.derivativeBuilder() .conv3DConfig(config) - .inputFunctions(args()) - .outputs(outputArguments()) .inputFunctions(inputs.toArray(new SDVariable[inputs.size()])) .sameDiff(sameDiff) .build(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3DDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3DDerivative.java index ee34fca90..ea6312094 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3DDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3DDerivative.java @@ -39,8 +39,8 @@ public class Conv3DDerivative extends Conv3D { public Conv3DDerivative() {} @Builder(builderMethodName = "derivativeBuilder") - public Conv3DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, Conv3DConfig conv3DConfig) { - super(sameDiff, inputFunctions, inputs, outputs, conv3DConfig); + public Conv3DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv3DConfig conv3DConfig) { + super(sameDiff, inputFunctions, conv3DConfig); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java index 65c0fccc3..c69292dd9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Builder; import lombok.Getter; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; import onnx.Onnx; @@ -31,6 +32,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig; import org.nd4j.linalg.util.ArrayUtil; import org.tensorflow.framework.AttrValue; @@ -51,25 +53,25 @@ public class DeConv2D extends DynamicCustomOp { protected DeConv2DConfig config; - @Builder(builderMethodName = "builder") + @Builder(builderMethodName = "sameDiffBuilder") public DeConv2D(SameDiff sameDiff, SDVariable[] inputs, - INDArray[] inputArrays, INDArray[] outputs, DeConv2DConfig config) { - super(null, inputArrays, outputs); - this.sameDiff = sameDiff; + super(sameDiff, inputs); this.config = config; - if (inputArrays != null) { - addInputArgument(inputArrays); - } - if (outputs != null) { - addOutputArgument(outputs); - } - addArgs(); - sameDiff.putOpForId(this.getOwnName(), this); - sameDiff.addArgsFor(inputs, this); + } + + public DeConv2D(INDArray[] inputs, INDArray[] outputs, DeConv2DConfig config){ + super(inputs, outputs); + + this.config = config; + addArgs(); + } + + public DeConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull DeConv2DConfig config){ + this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DDerivative.java index 174d95ed7..04dc1dd2d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DDerivative.java @@ -40,8 +40,8 @@ public class DeConv2DDerivative extends DeConv2D { public DeConv2DDerivative() {} @Builder(builderMethodName = "derivativeBuilder") - public DeConv2DDerivative(SameDiff sameDiff, SDVariable[] inputs, INDArray[] inputArrays, INDArray[] outputs, DeConv2DConfig config) { - super(sameDiff, inputs, inputArrays, outputs, config); + public DeConv2DDerivative(SameDiff sameDiff, SDVariable[] inputs, DeConv2DConfig config) { + super(sameDiff, inputs, config); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java index 085f48365..bc4f996b1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java @@ -53,25 +53,21 @@ public class DeConv2DTF extends 
DynamicCustomOp { protected DeConv2DConfig config; - @Builder(builderMethodName = "builder") + @Builder(builderMethodName = "sameDiffBuilder") public DeConv2DTF(SameDiff sameDiff, SDVariable[] inputs, - INDArray[] inputArrays, INDArray[] outputs, DeConv2DConfig config) { - super(null, inputArrays, outputs); - this.sameDiff = sameDiff; + super(sameDiff, inputs); + + this.config = config; + addArgs(); + } + + public DeConv2DTF(INDArray[] inputs, INDArray[] outputs, DeConv2DConfig config){ + super(inputs, outputs); + this.config = config; - - if (inputArrays != null) { - addInputArgument(inputArrays); - } - if (outputs != null) { - addOutputArgument(outputs); - } - addArgs(); - sameDiff.putOpForId(this.getOwnName(), this); - sameDiff.addArgsFor(inputs, this); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java index 20b28da5a..077f6a64b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java @@ -28,6 +28,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig; import org.nd4j.linalg.util.ArrayUtil; import org.tensorflow.framework.AttrValue; @@ -53,12 +54,23 @@ public class DeConv3D extends DynamicCustomOp { protected DeConv3DConfig config; - public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, DeConv3DConfig config) { + public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) { super(sameDiff, toArr(input, weights, bias)); this.config = config; addArgs(); } + public DeConv3D(INDArray[] inputs, INDArray[] outputs, DeConv3DConfig config){ + super(inputs, outputs); + + this.config = config; + addArgs(); + } + + public DeConv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull DeConv3DConfig config){ + this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); + } + private static SDVariable[] toArr(SDVariable input, SDVariable weights, SDVariable bias){ if(bias != null){ return new SDVariable[]{input, weights, bias}; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java index 92a39f188..ec2bb1d3f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Builder; import lombok.Getter; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; import onnx.Onnx; @@ -35,6 +36,7 @@ 
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig; import org.nd4j.linalg.util.ArrayUtil; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -53,17 +55,25 @@ public class DepthwiseConv2D extends DynamicCustomOp { protected Conv2DConfig config; - @Builder(builderMethodName = "builder") + @Builder(builderMethodName = "sameDiffBuilder") public DepthwiseConv2D(SameDiff sameDiff, SDVariable[] inputFunctions, - INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig config) { - super(null, inputArrays, outputs); - this.sameDiff = sameDiff; + super(sameDiff, inputFunctions); + this.config = config; addArgs(); - sameDiff.putOpForId(this.getOwnName(), this); //Normally called in DynamicCustomOp constructor, via setInstanceId - but sameDiff field is null at that point - sameDiff.addArgsFor(inputFunctions, this); + } + + public DepthwiseConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){ + super(inputs, outputs); + + this.config = config; + addArgs(); + } + + public DepthwiseConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){ + this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } public DepthwiseConv2D() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java index 421598d13..8dfb7131a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Builder; import lombok.Getter; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; import onnx.Onnx; @@ -48,18 +49,19 @@ public class LocalResponseNormalization extends DynamicCustomOp { protected LocalResponseNormalizationConfig config; - @Builder(builderMethodName = "builder") - public LocalResponseNormalization(SameDiff sameDiff, SDVariable[] inputFunctions, - INDArray[] inputs, INDArray[] outputs,boolean inPlace, + @Builder(builderMethodName = "sameDiffBuilder") + public LocalResponseNormalization(SameDiff sameDiff, SDVariable[] inputFunctions, boolean inPlace, LocalResponseNormalizationConfig config) { super(null,sameDiff, inputFunctions, inPlace); + + this.config = config; + addArgs(); + } + + public LocalResponseNormalization(@NonNull INDArray input, INDArray output, @NonNull LocalResponseNormalizationConfig config){ + super(new INDArray[]{input}, wrapOrNull(output)); + this.config = config; - if(inputs != null) { - addInputArgument(inputs); - } - if(outputs!= null) { - addOutputArgument(outputs); - } addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalizationDerivative.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalizationDerivative.java index 2159f87fa..c2e6aad15 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalizationDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalizationDerivative.java @@ -33,8 +33,8 @@ import java.util.List; @Slf4j public class LocalResponseNormalizationDerivative extends LocalResponseNormalization { @Builder(builderMethodName = "derivativeBuilder") - public LocalResponseNormalizationDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, boolean inPlace, LocalResponseNormalizationConfig config) { - super(sameDiff, inputFunctions, inputs, outputs, inPlace, config); + public LocalResponseNormalizationDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, boolean inPlace, LocalResponseNormalizationConfig config) { + super(sameDiff, inputFunctions, inPlace, config); } public LocalResponseNormalizationDerivative() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java index b321334a5..09e928d2f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java @@ -51,27 +51,18 @@ public class MaxPooling2D extends DynamicCustomOp { public MaxPooling2D() { } - @Builder(builderMethodName = "builder") + @Builder(builderMethodName = "sameDiffBuilder") @SuppressWarnings("Used in lombok") - public MaxPooling2D(SameDiff sameDiff, SDVariable input, INDArray arrayInput, INDArray arrayOutput, Pooling2DConfig config) { + public MaxPooling2D(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) { super(null, sameDiff, new SDVariable[]{input}, false); - if (arrayInput != null) { - addInputArgument(arrayInput); - } - if (arrayOutput != null) { - addOutputArgument(arrayOutput); - } config.setType(Pooling2D.Pooling2DType.MAX); - this.config = config; - this.sameDiff = sameDiff; - addArgs(); } public MaxPooling2D(INDArray input, INDArray output, @NonNull Pooling2DConfig config){ - super(null, new INDArray[]{input}, output == null ? 
null : new INDArray[]{output}); + super(null, new INDArray[]{input}, wrapOrNull(output)); config.setType(Pooling2D.Pooling2DType.MAX); this.config = config; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2D.java index c45d106e7..ab2984969 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2D.java @@ -16,8 +16,14 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; import lombok.Builder; import lombok.Getter; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; import onnx.Onnx; @@ -33,9 +39,6 @@ import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; -import java.lang.reflect.Field; -import java.util.*; - /** * Pooling2D operation @@ -70,21 +73,27 @@ public class Pooling2D extends DynamicCustomOp { public Pooling2D() {} - @Builder(builderMethodName = "builder") + @Builder(builderMethodName = "sameDiffBuilder") @SuppressWarnings("Used in lombok") - public Pooling2D(SameDiff sameDiff, SDVariable[] inputs,INDArray[] arrayInputs, INDArray[] arrayOutputs,Pooling2DConfig config) { - super(null,sameDiff, inputs, false); - if(arrayInputs != null) { - addInputArgument(arrayInputs); - } + public Pooling2D(SameDiff sameDiff, SDVariable[] inputs, + Pooling2DConfig config) { + super(null, sameDiff, inputs, false); - if(arrayOutputs != null) { - addOutputArgument(arrayOutputs); - } + this.config = config; + addArgs(); + } - this.config = config; + public Pooling2D(@NonNull INDArray[] inputs, INDArray[] outputs, @NonNull Pooling2DConfig config){ + super(inputs, outputs); + this.config = config; + addArgs(); + } + public Pooling2D(@NonNull INDArray input, INDArray output, @NonNull Pooling2DConfig config){ + super(new INDArray[]{input}, wrapOrNull(output)); + + this.config = config; addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2DDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2DDerivative.java index 6fdb40215..aa58603e1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2DDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2DDerivative.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Builder; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -36,8 +37,12 @@ import java.util.List; @Slf4j public class Pooling2DDerivative extends Pooling2D { @Builder(builderMethodName = "derivativeBuilder") - public Pooling2DDerivative(SameDiff sameDiff, SDVariable[] inputs, INDArray[] arrayInputs, INDArray[] arrayOutputs, Pooling2DConfig config) { - super(sameDiff, inputs, arrayInputs, arrayOutputs, config); + public 
Pooling2DDerivative(SameDiff sameDiff, SDVariable[] inputs, Pooling2DConfig config) { + super(sameDiff, inputs, config); + } + + public Pooling2DDerivative(@NonNull INDArray input, @NonNull INDArray grad, INDArray output, Pooling2DConfig config){ + super(new INDArray[]{input, grad}, wrapOrNull(output), config); } public Pooling2DDerivative() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java index 745caccba..d4ef84e88 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Builder; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -39,9 +40,17 @@ import java.util.List; @Slf4j public class SConv2D extends Conv2D { - @Builder(builderMethodName = "sBuilder") - public SConv2D(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig conv2DConfig) { - super(sameDiff, inputFunctions, inputArrays, outputs, conv2DConfig); + @Builder(builderMethodName = "sameDiffSBuilder") + public SConv2D(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig conv2DConfig) { + super(sameDiff, inputFunctions, conv2DConfig); + } + + public SConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){ + super(inputs, outputs, config); + } + + public SConv2D(@NonNull INDArray input, @NonNull INDArray depthWeights, INDArray pointWeights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){ + this(wrapFilterNull(input, depthWeights, pointWeights, bias), wrapOrNull(output), config); } public SConv2D() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2DDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2DDerivative.java index e25dae144..a30a58d95 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2DDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2DDerivative.java @@ -38,8 +38,8 @@ import java.util.List; public class SConv2DDerivative extends SConv2D { @Builder(builderMethodName = "sDerviativeBuilder") - public SConv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputArrays, INDArray[] outputs, Conv2DConfig conv2DConfig) { - super(sameDiff, inputFunctions, inputArrays, outputs, conv2DConfig); + public SConv2DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, Conv2DConfig conv2DConfig) { + super(sameDiff, inputFunctions, conv2DConfig); } public SConv2DDerivative() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java index cab411916..b31e6e036 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java @@ -235,24 +235,20 @@ public class Convolution { public static INDArray pooling2D(INDArray img, int kh, int kw, int sy, int sx, int ph, int pw, int dh, int dw, boolean isSameMode, Pooling2D.Pooling2DType type, Pooling2D.Divisor divisor, double extra, int virtualHeight, int virtualWidth, INDArray out) { - Pooling2D pooling = Pooling2D.builder() - .arrayInputs(new INDArray[]{img}) - .arrayOutputs(new INDArray[]{out}) - .config(Pooling2DConfig.builder() - .dH(dh) - .dW(dw) - .extra(extra) - .kH(kh) - .kW(kw) - .pH(ph) - .pW(pw) - .isSameMode(isSameMode) - .sH(sy) - .sW(sx) - .type(type) - .divisor(divisor) - .build()) - .build(); + Pooling2D pooling = new Pooling2D(img, out, Pooling2DConfig.builder() + .dH(dh) + .dW(dw) + .extra(extra) + .kH(kh) + .kW(kw) + .pH(ph) + .pW(pw) + .isSameMode(isSameMode) + .sH(sy) + .sW(sx) + .type(type) + .divisor(divisor) + .build()); Nd4j.getExecutioner().execAndReturn(pooling); return out; } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index 057f610bd..539901a41 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -389,10 +389,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); INDArray input = Nd4j.create(inSize); - AvgPooling2D avgPooling2D = AvgPooling2D.builder() - .arrayInput(input) - .config(conf) - .build(); + AvgPooling2D avgPooling2D = new AvgPooling2D(input, null, conf); val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D); @@ -410,10 +407,7 @@ public class LayerOpValidation extends BaseOpValidation { //Test backprop: - Pooling2DDerivative avg2dDeriv = Pooling2DDerivative.derivativeBuilder() - .arrayInputs(new INDArray[]{input, grad}) - .config(conf) - .build(); + Pooling2DDerivative avg2dDeriv = new Pooling2DDerivative(input, grad, null, conf); val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv); assertEquals(1, outSizesBP.size()); @@ -435,10 +429,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); INDArray input = Nd4j.create(inSize); - AvgPooling2D avgPooling2D = AvgPooling2D.builder() - .arrayInput(input) - .config(conf) - .build(); + AvgPooling2D avgPooling2D = new AvgPooling2D(input, null, conf); val outSizes = Nd4j.getExecutioner().calculateOutputShape(avgPooling2D); assertEquals(1, outSizes.size()); @@ -454,11 +445,7 @@ public class LayerOpValidation extends BaseOpValidation { INDArray grad = Nd4j.create(exp); //Test backprop: - Pooling2DDerivative avg2dDeriv = Pooling2DDerivative.derivativeBuilder() - .arrayInputs(new INDArray[]{input, grad}) //Original input, and output gradient (eps - same shape as output) - .arrayOutputs(new INDArray[]{Nd4j.create(inSize)}) //Output for BP: same shape as original input - .config(conf) - .build(); + Pooling2DDerivative avg2dDeriv = new Pooling2DDerivative(input, grad, Nd4j.create(inSize), conf); val outSizesBP = Nd4j.getExecutioner().calculateOutputShape(avg2dDeriv); assertEquals(1, outSizesBP.size()); @@ -749,7 +736,7 @@ public class LayerOpValidation extends BaseOpValidation { .isSameMode(false) .build(); - SDVariable out = sd.cnn().conv2d(vars, c); + SDVariable out = sd.cnn().conv2d("conv", 
vars, c); out = sd.nn().tanh("out", out); INDArray outArr = sd.execAndEndResult(); From d41018751b9dc58225c041504bc2fab7ec112ffc Mon Sep 17 00:00:00 2001 From: Alex Black Date: Wed, 4 Sep 2019 19:11:17 +1000 Subject: [PATCH 16/19] Small fix (#233) Signed-off-by: Alex Black --- .../nn/layers/mkldnn/MKLDNNSubsamplingHelper.java | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNSubsamplingHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNSubsamplingHelper.java index 3edbf0b28..abb84e965 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNSubsamplingHelper.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNSubsamplingHelper.java @@ -87,11 +87,7 @@ public class MKLDNNSubsamplingHelper implements SubsamplingHelper { break; } - Pooling2DDerivative d = Pooling2DDerivative.derivativeBuilder() - .config(conf) - .arrayInputs(new INDArray[]{input, epsilon}) - .arrayOutputs(new INDArray[]{gradAtInput}) - .build(); + Pooling2DDerivative d = new Pooling2DDerivative(input, epsilon, gradAtInput, conf); Nd4j.exec(d); return new Pair(new DefaultGradient(), gradAtInput); From a90c7dd9956a2d49bdbc190fabdb5800b0c6b47b Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 4 Sep 2019 14:41:08 +0300 Subject: [PATCH 17/19] [WIP] Last set of changes (#234) * mmul op instead of cublasSgemm Signed-off-by: raver119 * transB Signed-off-by: raver119 * jcpp handles Signed-off-by: raver119 * bitwise and/or/xor Signed-off-by: raver119 * bitwise and/or/xor mapping Signed-off-by: raver119 * cuda/cublas version check Signed-off-by: raver119 * add expected version Signed-off-by: raver119 * cuda/cublas version check in java Signed-off-by: raver119 * one more error check Signed-off-by: raver119 * build fix Signed-off-by: raver119 * build fix Signed-off-by: raver119 * build fix Signed-off-by: raver119 * one more fix Signed-off-by: raver119 * skip CUDA version check for now Signed-off-by: raver119 * better wording Signed-off-by: raver119 * few more tweaks Signed-off-by: raver119 * few more tweaks Signed-off-by: raver119 --- libnd4j/blas/BlasVersionHelper.h | 40 +++++++++ libnd4j/blas/CMakeLists.txt | 8 +- libnd4j/blas/Environment.cpp | 9 ++- libnd4j/blas/Environment.h | 7 ++ libnd4j/blas/NativeOps.h | 2 +- libnd4j/blas/cpu/NativeOps.cpp | 4 + libnd4j/blas/cuda/BlasVersionHelper.cu | 29 +++++++ libnd4j/blas/cuda/NativeOps.cu | 12 +++ .../generic/bitwise/bitwise_and.cpp | 50 ++++++++++++ .../declarable/generic/bitwise/bitwise_or.cpp | 50 ++++++++++++ .../generic/bitwise/bitwise_xor.cpp | 50 ++++++++++++ .../include/ops/declarable/headers/bitwise.h | 33 ++++++++ .../converters/ImportClassMapping.java | 3 + .../impl/transforms/custom/BitwiseAnd.java | 78 ++++++++++++++++++ .../ops/impl/transforms/custom/BitwiseOr.java | 78 ++++++++++++++++++ .../impl/transforms/custom/BitwiseXor.java | 78 ++++++++++++++++++ .../deallocation/DeallocatorService.java | 2 +- .../java/org/nd4j/nativeblas/NativeOps.java | 2 + .../org/nd4j/nativeblas/NativeOpsHolder.java | 2 +- .../java/org/nd4j/nativeblas/Nd4jBlas.java | 3 +- .../nd4j/linalg/jcublas/JCublasBackend.java | 1 + .../linalg/jcublas/JCublasNDArrayFactory.java | 16 ++++ .../linalg/jcublas/blas/JcublasLevel3.java | 15 ++-- .../java/org/nd4j/nativeblas/Nd4jCuda.java | 9 ++- .../java/org/nd4j/nativeblas/Nd4jCpu.java | 81 ++++++++++++++++++- 25 
files changed, 646 insertions(+), 16 deletions(-) create mode 100644 libnd4j/blas/BlasVersionHelper.h create mode 100644 libnd4j/blas/cuda/BlasVersionHelper.cu create mode 100644 libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp create mode 100644 libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp create mode 100644 libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseAnd.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseOr.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseXor.java diff --git a/libnd4j/blas/BlasVersionHelper.h b/libnd4j/blas/BlasVersionHelper.h new file mode 100644 index 000000000..93e8d75e3 --- /dev/null +++ b/libnd4j/blas/BlasVersionHelper.h @@ -0,0 +1,40 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SAMEDIFF_BLASVERSIONHELPER_H +#define SAMEDIFF_BLASVERSIONHELPER_H + +#include +#include +#include + +namespace nd4j { + class ND4J_EXPORT BlasVersionHelper { + public: + int _blasMajorVersion = 0; + int _blasMinorVersion = 0; + int _blasPatchVersion = 0; + + BlasVersionHelper(); + ~BlasVersionHelper() = default; + }; +} + +#endif //DEV_TESTS_BLASVERSIONHELPER_H diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt index add8960a3..a93ac0d26 100755 --- a/libnd4j/blas/CMakeLists.txt +++ b/libnd4j/blas/CMakeLists.txt @@ -253,20 +253,20 @@ if(CUDA_BLAS) file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu) if (NOT BUILD_TESTS) - CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA} + CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA} ${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp - Environment.cpp Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} + Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES}) else() set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_TESTS=true") - CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA} + CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu 
Environment.cpp ${LOOPS_SOURCES_CUDA} ${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp - Environment.cpp Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} + Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES}) endif() diff --git a/libnd4j/blas/Environment.cpp b/libnd4j/blas/Environment.cpp index 1c3dd2d9e..b4b2db4ae 100644 --- a/libnd4j/blas/Environment.cpp +++ b/libnd4j/blas/Environment.cpp @@ -35,7 +35,7 @@ #include #include - +#include "BlasVersionHelper.h" #endif namespace nd4j { @@ -66,6 +66,13 @@ namespace nd4j { #endif #ifdef __CUDABLAS__ + BlasVersionHelper ver; + _blasMajorVersion = ver._blasMajorVersion; + _blasMinorVersion = ver._blasMinorVersion; + _blasPatchVersion = ver._blasPatchVersion; + printf("ND4J CUDA build version: %i.%i.%i\n", _blasMajorVersion, _blasMinorVersion, _blasPatchVersion); + fflush(stdout); + int devCnt = 0; cudaGetDeviceCount(&devCnt); auto devProperties = new cudaDeviceProp[devCnt]; diff --git a/libnd4j/blas/Environment.h b/libnd4j/blas/Environment.h index 5092b6190..ac4dfa678 100644 --- a/libnd4j/blas/Environment.h +++ b/libnd4j/blas/Environment.h @@ -56,6 +56,13 @@ namespace nd4j{ Environment(); ~Environment(); public: + /** + * These 3 fields are mostly for CUDA/cuBLAS version tracking + */ + int _blasMajorVersion = 0; + int _blasMinorVersion = 0; + int _blasPatchVersion = 0; + static Environment* getInstance(); bool isVerbose(); diff --git a/libnd4j/blas/NativeOps.h b/libnd4j/blas/NativeOps.h index 9bca7bb10..ef46e7752 100755 --- a/libnd4j/blas/NativeOps.h +++ b/libnd4j/blas/NativeOps.h @@ -647,7 +647,7 @@ ND4J_EXPORT void setOmpNumThreads(int threads); ND4J_EXPORT void setOmpMinThreads(int threads); - +ND4J_EXPORT bool isBlasVersionMatches(int major, int minor, int build); /** * diff --git a/libnd4j/blas/cpu/NativeOps.cpp b/libnd4j/blas/cpu/NativeOps.cpp index 86bc04fc4..e016d58fe 100644 --- a/libnd4j/blas/cpu/NativeOps.cpp +++ b/libnd4j/blas/cpu/NativeOps.cpp @@ -728,6 +728,10 @@ void execReduce3Tad(Nd4jPointer *extraPointers, } } +bool isBlasVersionMatches(int major, int minor, int build) { + return true; +} + /** * * @param opNum diff --git a/libnd4j/blas/cuda/BlasVersionHelper.cu b/libnd4j/blas/cuda/BlasVersionHelper.cu new file mode 100644 index 000000000..1f80a0cc0 --- /dev/null +++ b/libnd4j/blas/cuda/BlasVersionHelper.cu @@ -0,0 +1,29 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "../BlasVersionHelper.h" + +namespace nd4j { + BlasVersionHelper::BlasVersionHelper() { + _blasMajorVersion = __CUDACC_VER_MAJOR__; + _blasMinorVersion = __CUDACC_VER_MINOR__; + _blasPatchVersion = __CUDACC_VER_BUILD__; + } +} \ No newline at end of file diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index a29613b61..ec88de2e5 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -3357,6 +3357,18 @@ void deleteTadPack(nd4j::TadPack* ptr) { delete ptr; } +bool isBlasVersionMatches(int major, int minor, int build) { + auto result = major == Environment::getInstance()->_blasMajorVersion && minor == Environment::getInstance()->_blasMinorVersion && build == Environment::getInstance()->_blasPatchVersion; + + if (!result) { + nd4j_printf("CUDA/cuBLAS version mismatch. Expected: %i.%i.%i but got %i.%i.%i instead\n", Environment::getInstance()->_blasMajorVersion, Environment::getInstance()->_blasMinorVersion, Environment::getInstance()->_blasPatchVersion, major, minor, build); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(152); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("CUDA/cuBLAS version mismatch"); + } + + return result; +} + nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length) { return nd4j::ConstantHelper::getInstance()->constantBuffer(ConstantDescriptor(data, length), dtype); } diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp new file mode 100644 index 000000000..52d01429f --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp @@ -0,0 +1,50 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_bitwise_and) + +#include +#include +#include + +namespace nd4j { + namespace ops { + BROADCASTABLE_OP_IMPL(bitwise_and, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x,y,z); + + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntAnd, pairwise::IntOps::IntAnd, broadcast::IntOps::IntAnd), y, z, false); + + return Status::OK(); + } + + DECLARE_TYPES(bitwise_and) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setSameMode(true); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp new file mode 100644 index 000000000..b8469d83a --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp @@ -0,0 +1,50 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_bitwise_or) + +#include +#include +#include + +namespace nd4j { + namespace ops { + BROADCASTABLE_OP_IMPL(bitwise_or, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x,y,z); + + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntOr, pairwise::IntOps::IntOr, broadcast::IntOps::IntOr), y, z, false); + + return Status::OK(); + } + + DECLARE_TYPES(bitwise_or) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setSameMode(true); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp new file mode 100644 index 000000000..f7f3f479a --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp @@ -0,0 +1,50 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_bitwise_xor) + +#include +#include +#include + +namespace nd4j { + namespace ops { + BROADCASTABLE_OP_IMPL(bitwise_xor, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x,y,z); + + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntXor, pairwise::IntOps::IntXor, broadcast::IntOps::IntXor), y, z, false); + + return Status::OK(); + } + + DECLARE_TYPES(bitwise_xor) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setSameMode(true); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/bitwise.h b/libnd4j/include/ops/declarable/headers/bitwise.h index a6362a73f..cb395b496 100644 --- a/libnd4j/include/ops/declarable/headers/bitwise.h +++ b/libnd4j/include/ops/declarable/headers/bitwise.h @@ -81,6 +81,39 @@ namespace nd4j { DECLARE_BROADCASTABLE_OP(cyclic_rshift_bits, 0, 0); #endif + /** + * This operation applies bitwise AND + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ + #if NOT_EXCLUDED(OP_bitwise_and) + DECLARE_BROADCASTABLE_OP(bitwise_and, 0, 0); + #endif + + /** + * This operation applies bitwise OR + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ + #if NOT_EXCLUDED(OP_bitwise_or) + DECLARE_BROADCASTABLE_OP(bitwise_or, 0, 0); + #endif + + /** + * This operation applies bitwise XOR + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ + #if NOT_EXCLUDED(OP_bitwise_xor) + DECLARE_BROADCASTABLE_OP(bitwise_xor, 0, 0); + #endif + /** * This operation returns hamming distance based on bits * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 5bfba7a48..3a89b7339 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -353,6 +353,9 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class, + org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd.class, + org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor.class, + org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseAnd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseAnd.java new file mode 100644 index 000000000..d81a72c1f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseAnd.java @@ -0,0 +1,78 @@ 
+/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; + +import java.util.Collections; +import java.util.List; + +/** + * Bit-wise AND operation, broadcastable + * + * @author raver119@gmail.com + */ +public class BitwiseAnd extends BaseDynamicTransformOp { + + public BitwiseAnd(SameDiff sameDiff, SDVariable x, SDVariable y) { + super(sameDiff, new SDVariable[] {x, y} ,false); + } + + public BitwiseAnd(INDArray x, INDArray y, INDArray output) { + super(new INDArray[]{x, y}, new INDArray[]{output}); + } + + public BitwiseAnd(INDArray x, INDArray y) { + this(x, y,x.ulike()); + } + + public BitwiseAnd() {} + + @Override + public String opName() { + return "bitwise_and"; + } + + + @Override + public String onnxName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + @Override + public String tensorflowName() { + return "bitwise_and"; + } + + + @Override + public List doDiff(List i_v) { + throw new UnsupportedOperationException("Not yet implemented: " + opName()); + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0)); + return Collections.singletonList(dataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseOr.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseOr.java new file mode 100644 index 000000000..85dd5c31b --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseOr.java @@ -0,0 +1,78 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; + +import java.util.Collections; +import java.util.List; + +/** + * Bit-wise OR operation, broadcastable + * + * @author raver119@gmail.com + */ +public class BitwiseOr extends BaseDynamicTransformOp { + + public BitwiseOr(SameDiff sameDiff, SDVariable x, SDVariable y) { + super(sameDiff, new SDVariable[] {x, y} ,false); + } + + public BitwiseOr(INDArray x, INDArray y, INDArray output) { + super(new INDArray[]{x, y}, new INDArray[]{output}); + } + + public BitwiseOr(INDArray x, INDArray y) { + this(x, y,x.ulike()); + } + + public BitwiseOr() {} + + @Override + public String opName() { + return "bitwise_or"; + } + + + @Override + public String onnxName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + @Override + public String tensorflowName() { + return "bitwise_or"; + } + + + @Override + public List doDiff(List i_v) { + throw new UnsupportedOperationException("Not yet implemented: " + opName()); + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0)); + return Collections.singletonList(dataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseXor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseXor.java new file mode 100644 index 000000000..136ca9b62 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitwiseXor.java @@ -0,0 +1,78 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; + +import java.util.Collections; +import java.util.List; + +/** + * Bit-wise XOR operation, broadcastable + * + * @author raver119@gmail.com + */ +public class BitwiseXor extends BaseDynamicTransformOp { + + public BitwiseXor(SameDiff sameDiff, SDVariable x, SDVariable y) { + super(sameDiff, new SDVariable[] {x, y} ,false); + } + + public BitwiseXor(INDArray x, INDArray y, INDArray output) { + super(new INDArray[]{x, y}, new INDArray[]{output}); + } + + public BitwiseXor(INDArray x, INDArray y) { + this(x, y,x.ulike()); + } + + public BitwiseXor() {} + + @Override + public String opName() { + return "bitwise_xor"; + } + + + @Override + public String onnxName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + @Override + public String tensorflowName() { + return "bitwise_xor"; + } + + + @Override + public List doDiff(List i_v) { + throw new UnsupportedOperationException("Not yet implemented: " + opName()); + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0)); + return Collections.singletonList(dataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java index 26d850366..30c68d578 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java @@ -54,7 +54,7 @@ public class DeallocatorService { deallocatorThreads = new Thread[numThreads]; queues = new ReferenceQueue[numThreads]; for (int e = 0; e < numThreads; e++) { - log.debug("Starting deallocator thread {}", e + 1); + log.trace("Starting deallocator thread {}", e + 1); queues[e] = new ReferenceQueue<>(); int deviceId = e % numDevices; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java index 576cea78a..e694587b0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java @@ -1151,4 +1151,6 @@ public interface NativeOps { int lastErrorCode(); String lastErrorMessage(); + + boolean isBlasVersionMatches(int major, int minor, int build); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java index 98bdb90fa..ae31ea7b8 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java @@ -101,7 +101,7 @@ public class NativeOpsHolder { } //deviceNativeOps.setOmpNumThreads(4); - log.info("Number of threads used for NativeOps: {}", deviceNativeOps.ompGetMaxThreads()); + log.info("Number of threads used for OpenMP: {}", deviceNativeOps.ompGetMaxThreads()); } catch (Exception | Error e) { throw new RuntimeException( "ND4J is probably missing dependencies. For more information, please refer to: http://nd4j.org/getstarted.html", diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/Nd4jBlas.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/Nd4jBlas.java index 23abf1d40..5de827d1a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/Nd4jBlas.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/Nd4jBlas.java @@ -51,7 +51,8 @@ public abstract class Nd4jBlas implements Blas { numThreads = NativeOpsHolder.getCores(Runtime.getRuntime().availableProcessors()); setMaxThreads(numThreads); } - log.info("Number of threads used for BLAS: {}", getMaxThreads()); + + log.info("Number of threads used for OpenMP BLAS: {}", getMaxThreads()); } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasBackend.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasBackend.java index 28910ae6a..a6a5a45e4 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasBackend.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasBackend.java @@ -52,6 +52,7 @@ public class JCublasBackend extends Nd4jBackend { throw new RuntimeException("No CUDA devices were found in system"); } Loader.load(org.bytedeco.cuda.global.cublas.class); + return true; } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java index 9e9dc34b2..73daa679d 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java @@ -108,6 +108,22 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); +/* + val major = new int[1]; + val minor = new int[1]; + val build = new int[1]; + org.bytedeco.cuda.global.cublas.cublasGetProperty(0, major); + org.bytedeco.cuda.global.cublas.cublasGetProperty(1, minor); + org.bytedeco.cuda.global.cublas.cublasGetProperty(2, build); + + val pew = new int[100]; + org.bytedeco.cuda.global.cudart.cudaDriverGetVersion(pew); + + nativeOps.isBlasVersionMatches(major[0], minor[0], build[0]); + + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + */ } @Override diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java 
b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java index b06211545..a7c2ab245 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java @@ -28,10 +28,12 @@ import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t; import org.nd4j.jita.conf.CudaEnvironment; import org.nd4j.linalg.api.blas.impl.BaseLevel3; +import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil; +import org.nd4j.linalg.api.ops.impl.reduce.Mmul; import org.nd4j.linalg.factory.DataTypeValidation; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.jcublas.CublasPointer; @@ -113,16 +115,18 @@ public class JcublasLevel3 extends BaseLevel3 { @Override protected void sgemm(char Order, char TransA, char TransB, int M, int N, int K, float alpha, INDArray A, int lda, INDArray B, int ldb, float beta, INDArray C, int ldc) { - //A = Shape.toOffsetZero(A); - //B = Shape.toOffsetZero(B); + /* + val ctx = AtomicAllocator.getInstance().getDeviceContext(); + val handle = ctx.getCublasHandle(); + synchronized (handle) { + Nd4j.exec(new Mmul(A, B, C, MMulTranspose.builder().transposeA(false).transposeB(false).build())); + } + */ Nd4j.getExecutioner().push(); val ctx = allocator.getFlowController().prepareAction(C, A, B); - //log.info("Synchronizing CUDA stream"); - ctx.getOldStream().synchronize(); - val cAPointer = new CublasPointer(A, ctx); val cBPointer = new CublasPointer(B, ctx); val cCPointer = new CublasPointer(C, ctx); @@ -141,6 +145,7 @@ public class JcublasLevel3 extends BaseLevel3 { } allocator.registerAction(ctx, C, A, B); + OpExecutionerUtil.checkForAny(C); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index f3080f05a..0ddcb6266 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -557,6 +557,13 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper { /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. 
*/ public Environment(Pointer p) { super(p); } + /** + * These 3 fields are mostly for CUDA/cuBLAS version tracking + */ + public native int _blasMajorVersion(); public native Environment _blasMajorVersion(int setter); + public native int _blasMinorVersion(); public native Environment _blasMinorVersion(int setter); + public native int _blasPatchVersion(); public native Environment _blasPatchVersion(int setter); + public static native Environment getInstance(); public native @Cast("bool") boolean isVerbose(); @@ -1874,7 +1881,7 @@ public native void setOmpNumThreads(int threads); public native void setOmpMinThreads(int threads); - +public native @Cast("bool") boolean isBlasVersionMatches(int major, int minor, int build); /** * diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 9554a94e9..dabac7001 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -557,6 +557,13 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public Environment(Pointer p) { super(p); } + /** + * These 3 fields are mostly for CUDA/cuBLAS version tracking + */ + public native int _blasMajorVersion(); public native Environment _blasMajorVersion(int setter); + public native int _blasMinorVersion(); public native Environment _blasMinorVersion(int setter); + public native int _blasPatchVersion(); public native Environment _blasPatchVersion(int setter); + public static native Environment getInstance(); public native @Cast("bool") boolean isVerbose(); @@ -1874,7 +1881,7 @@ public native void setOmpNumThreads(int threads); public native void setOmpMinThreads(int threads); - +public native @Cast("bool") boolean isBlasVersionMatches(int major, int minor, int build); /** * @@ -21929,6 +21936,78 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif + /** + * This operation applies bitwise AND + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_bitwise_and) + @Namespace("nd4j::ops") public static class bitwise_and extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public bitwise_and(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public bitwise_and(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public bitwise_and position(long position) { + return (bitwise_and)super.position(position); + } + + public bitwise_and() { super((Pointer)null); allocate(); } + private native void allocate(); + } +// #endif + + /** + * This operation applies bitwise OR + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_bitwise_or) + @Namespace("nd4j::ops") public static class bitwise_or extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public bitwise_or(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. 
*/ + public bitwise_or(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public bitwise_or position(long position) { + return (bitwise_or)super.position(position); + } + + public bitwise_or() { super((Pointer)null); allocate(); } + private native void allocate(); + } +// #endif + + /** + * This operation applies bitwise XOR + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_bitwise_xor) + @Namespace("nd4j::ops") public static class bitwise_xor extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public bitwise_xor(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public bitwise_xor(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public bitwise_xor position(long position) { + return (bitwise_xor)super.position(position); + } + + public bitwise_xor() { super((Pointer)null); allocate(); } + private native void allocate(); + } +// #endif + /** * This operation returns hamming distance based on bits * From 548044a1e29325202c7e01412563b14e8344cc21 Mon Sep 17 00:00:00 2001 From: shugeo Date: Wed, 4 Sep 2019 14:57:59 +0300 Subject: [PATCH 18/19] Shugeo doc (#235) * Actualized doc to tnse ops. * Added comments for dynamic_stitch op. * Added comments to dynamic_stitch op implementation. * Modified comment for unstack_list op. * Added doc for space_to_depth and depth_to_space ops. * Added doc for space_to_batch op. * Enlarge test type for adjustSaturation. * Added doc for runner. --- .../generic/parity_ops/dynamic_stitch.cpp | 15 +++-- .../ops/declarable/headers/BarnesHutTsne.h | 9 ++- libnd4j/include/ops/declarable/headers/list.h | 2 +- .../ops/declarable/headers/parity_ops.h | 64 +++++++++++++++++-- .../ops/declarable/helpers/cuda/dynamic.cu | 11 ++-- .../layers_tests/DeclarableOpsTests13.cpp | 7 +- 6 files changed, 85 insertions(+), 23 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/dynamic_stitch.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/dynamic_stitch.cpp index 70310f643..e6913dc34 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/dynamic_stitch.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/dynamic_stitch.cpp @@ -29,21 +29,26 @@ namespace ops { CUSTOM_OP_IMPL(dynamic_stitch, 2, 1, false, 0, 0) { int numOfData = block.width(); // int k = 0; + // checking input data size REQUIRE_TRUE(numOfData % 2 == 0, 0, "dynamic_stitch: The input params should contains" " both indeces and data lists with same length."); + // split input data list on two equal parts numOfData /= 2; + // form input lists to use with helpers - both indices and float data inputs auto output = OUTPUT_VARIABLE(0); std::vector inputs(numOfData); std::vector indices(numOfData); + for (int e = 0; e < numOfData; e++) { auto data = INPUT_VARIABLE(numOfData + e); auto index = INPUT_VARIABLE(e); + inputs[e] = data; indices[e] = index; } - + // run helper return helpers::dynamicStitchFunctor(block.launchContext(), inputs, indices, output); } @@ -59,17 +64,17 @@ namespace ops { numOfData /= 2; // only index part it's needed to review auto restShape = inputShape->at(numOfData); auto firstShape = inputShape->at(0); + // check up inputs to avoid non-int indices and calculate max value from indices to output shape length for(int i = 0; i < 
numOfData; i++) { auto input = INPUT_VARIABLE(i); REQUIRE_TRUE(input->isZ(), 0, "dynamic_stitch: Indices should be integer, but %d type given.", (int)input->dataType() ); - // FIXME: we have reduce::Max, cinsider using it instead auto maxV = input->reduceNumber(reduce::Max); if (maxV.e(0) > maxValue) maxValue = maxV.e(0); } - - int outRank = shape::rank(restShape) - shape::rank(firstShape) + 1; + // calculate output rank - difference between indices shape and data shape + int outRank = shape::rank(restShape) - shape::rank(firstShape) + 1; // at least 1D tensor std::vector outShape(outRank); - + // fill up output shape template: the first to max index, and rests - to vals from the first data input outShape[0] = maxValue + 1; for(int i = 1; i < outRank; ++i) outShape[i] = shape::sizeAt(restShape, i); diff --git a/libnd4j/include/ops/declarable/headers/BarnesHutTsne.h b/libnd4j/include/ops/declarable/headers/BarnesHutTsne.h index 89e4c385a..d3a4c042d 100644 --- a/libnd4j/include/ops/declarable/headers/BarnesHutTsne.h +++ b/libnd4j/include/ops/declarable/headers/BarnesHutTsne.h @@ -33,12 +33,13 @@ namespace nd4j { * 0: 1D row-vector (or with shape (1, m)) * 1: 1D integer vector with slice nums * 2: 1D float-point values vector with same shape as above + * 3: 2D float-point matrix with data to search * * Int args: * 0: N - number of slices * * Output: - * 0: 1D vector with edge forces for input and values + * 0: 2D matrix with the same shape and type as the 3th argument */ #if NOT_EXCLUDED(OP_barnes_edge_forces) DECLARE_CUSTOM_OP(barnes_edge_forces, 4, 1, false, 0, 1); @@ -52,9 +53,11 @@ namespace nd4j { * 0: 1D int row-vector * 1: 1D int col-vector * 2: 1D float vector with values - * + * * Output: - * 0: symmetric 2D matrix with given values on given places + * 0: 1D int result row-vector + * 1: 1D int result col-vector + * 2: a float-point tensor with shape 1xN, with values from the last input vector */ #if NOT_EXCLUDED(OP_barnes_symmetrized) DECLARE_CUSTOM_OP(barnes_symmetrized, 3, 3, false, 0, -1); diff --git a/libnd4j/include/ops/declarable/headers/list.h b/libnd4j/include/ops/declarable/headers/list.h index 01c2d225c..756895a1f 100644 --- a/libnd4j/include/ops/declarable/headers/list.h +++ b/libnd4j/include/ops/declarable/headers/list.h @@ -120,7 +120,7 @@ namespace nd4j { #endif /** - * This operation unstacks given NDArray into NDArrayList + * This operation unstacks given NDArray into NDArrayList by the first dimension */ #if NOT_EXCLUDED(OP_unstack_list) DECLARE_LIST_OP(unstack_list, 1, 1, 0, 0); diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index c86f28499..bb7f306bd 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -594,21 +594,46 @@ namespace nd4j { /** + * This operation rearranges data from depth into blocks of spatial data. This is the reverse transformation + * of space_to_depth op. This op output is a copy of the input tensor where values from the depth dimension + * are moved in spatial blocks to the height and width dimensions. Int attr 0 indicates the input + * block size and how the data is moved. 
+ * Input: + * 0 - 4D tensor on given type + * Output: + * 0 - 4D tensor of given type and proper shape * - * - * + * Int arguments: + * 0 - block size + * 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels } + * 1 ("NCHW"): shape{ batch, channels, height, width } + * 2 ("NCHW_VECT_C"): int8 shape{ batch, channels / 4, height, width, 4 } + * optional (default 0) */ #if NOT_EXCLUDED(OP_depth_to_space) - DECLARE_CUSTOM_OP(depth_to_space, 1, 1, false, 0, 2); + DECLARE_CUSTOM_OP(depth_to_space, 1, 1, false, 0, -1); #endif /** + * This operation rearranges blocks of spatial data, into depth.This op output is a copy of the input tensor + * where values from the height and width dimensions are moved to the depth dimension. Int attr 0 indicates + * the input block size. * + * Input: + * - 4D tensor of given type + * Output: + * - 4D tensor * + * Int arguments: + * 0 - block size + * 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels } + * 1 ("NCHW"): shape{ batch, channels, height, width } + * 2 ("NCHW_VECT_C"): int8 shape{ batch, channels / 4, height, width, 4 } + * optional (default 0) * */ #if NOT_EXCLUDED(OP_space_to_depth) - DECLARE_CUSTOM_OP(space_to_depth, 1, 1, false, 0, 2); + DECLARE_CUSTOM_OP(space_to_depth, 1, 1, false, 0, -1); #endif /** @@ -622,13 +647,42 @@ namespace nd4j { #endif /** + * Zero-pads and then rearranges (permutes) blocks of spatial data into batch. More specifically, this op + * outputs a copy of the input tensor where values from the height and width dimensions are moved to the + * batch dimension. After the zero-padding, both height and width of the input must be divisible by the block + * size. * + * Inputs: + * 0 - input tensor + * 1 - 2D paddings tensor (shape {M, 2}) + * + * Output: + * - result tensor + * + * Int args: + * 0 - block size (M) * */ #if NOT_EXCLUDED(OP_space_to_batch) DECLARE_CUSTOM_OP(space_to_batch, 2, 1, false, 0, 1); #endif + /* + * This operation divides "spatial" dimensions [1, ..., M] of the input into a grid of blocks of shape + * block_shape, and interleaves these blocks with the "batch" dimension (0) such that in the output, + * the spatial dimensions [1, ..., M] correspond to the position within the grid, and the batch dimension + * combines both the position within a spatial block and the original batch position. Prior to division into + * blocks, the spatial dimensions of the input are optionally zero padded according to paddings. + * + * Inputs: + * 0 - input (N-D tensor) + * 1 - block_shape - int 1D tensor with M length + * 2 - paddings - int 2D tensor with shape {M, 2} + * + * Output: + * - N-D tensor with the same type as input 0. + * + * */ #if NOT_EXCLUDED(OP_space_to_batch_nd) DECLARE_CUSTOM_OP(space_to_batch_nd, 3, 1, false, 0, 0); #endif @@ -973,7 +1027,7 @@ namespace nd4j { * return value: * tensor with min values according to indices sets. 
*/ - #if NOT_EXCLUDED(OP_segment_min_bp) + #if NOT_EXCLUDED(OP_segment_min) DECLARE_CUSTOM_OP(segment_min, 2, 1, false, 0, 0); #endif #if NOT_EXCLUDED(OP_segment_min_bp) diff --git a/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu b/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu index 7d520478e..75b541b72 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu @@ -118,19 +118,19 @@ namespace nd4j { PointersManager pm(context, "dynamicPartition"); - if (sourceDimsLen) { + if (sourceDimsLen) { // non-linear case std::vector sourceDims(sourceDimsLen); for (int i = sourceDimsLen; i > 0; i--) sourceDims[sourceDimsLen - i] = input->rankOf() - i; - + //compute tad array for given dimensions auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), sourceDims); std::vector outBuffers(outSize); std::vector tadShapes(outSize); std::vector tadOffsets(outSize); std::vector numTads(outSize); - + // fill up dimensions array for before kernel for (unsigned int i = 0; i < outSize; i++) { outputs[i].first = outputList[i]; std::vector outDims(outputs[i].first->rankOf() - 1); @@ -151,10 +151,10 @@ namespace nd4j { auto dOutBuffers = reinterpret_cast(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *))); auto dOutTadShapes = reinterpret_cast(pm.replicatePointer(tadShapes.data(), tadShapes.size() * sizeof(Nd4jLong *))); auto dOutTadOffsets = reinterpret_cast(pm.replicatePointer(tadOffsets.data(), tadOffsets.size() * sizeof(Nd4jLong *))); - + // run kernel on device dynamicPartitionTadKernel<<<256, 256, 1024, *context->getCudaStream()>>>(input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), shape::length(packX.primaryShapeInfo()), indices->getSpecialBuffer(), indices->getSpecialShapeInfo(), indices->lengthOf(), dOutBuffers, dOutTadShapes, dOutTadOffsets, outSize); - } else { + } else { // linear case auto numThreads = 256; auto shmemSize = numThreads * sizeof(Y) * 2 + 1024; @@ -169,7 +169,6 @@ namespace nd4j { auto dOutBuffers = reinterpret_cast(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *))); auto dOutShapes = reinterpret_cast(pm.replicatePointer(outShapes.data(), outShapes.size() * sizeof(Nd4jLong *))); - dynamicPartitionScalarKernel<<<256, numThreads, shmemSize, *context->getCudaStream()>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), indices->getSpecialBuffer(), indices-> getSpecialShapeInfo(), dOutBuffers, dOutShapes, outSize); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index 87ac417be..2ef9e2309 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -544,8 +544,8 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_1) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustSaturation_2) { - NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32); - NDArray exp ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, nd4j::DataType::FLOAT32); + NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::DOUBLE); + NDArray exp ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, nd4j::DataType::DOUBLE); nd4j::ops::adjust_saturation op; auto results = 
op.execute({&input}, {10}, {2}); @@ -553,7 +553,8 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_2) { ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); - // result->printIndexedBuffer(); +// result->printIndexedBuffer("Result2"); +// exp.printIndexedBuffer("Expect2"); ASSERT_TRUE(exp.isSameShape(result)); ASSERT_TRUE(exp.equalsTo(result)); From 03c52ef9ddc3ef57dd75262fccc835fca3089b84 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Wed, 4 Sep 2019 22:34:31 +1000 Subject: [PATCH 19/19] Add SameDiff.bitwise namespace (#232) * #8196 add SameDiff.bitwise namespace Signed-off-by: AlexDBlack * Add BitsHammingDistance, add remaining bitwise ops to bitwise namespace Signed-off-by: AlexDBlack * fix Signed-off-by: AlexDBlack --- .../DifferentialFunctionFactory.java | 16 ++ .../org/nd4j/autodiff/samediff/SameDiff.java | 12 + .../nd4j/autodiff/samediff/ops/SDBitwise.java | 205 ++++++++++++++++++ .../converters/ImportClassMapping.java | 1 + .../custom/BitsHammingDistance.java | 37 ++++ .../transforms/custom/CyclicRShiftBits.java | 2 +- .../transforms/custom/CyclicShiftBits.java | 2 +- .../impl/transforms/custom/RShiftBits.java | 2 +- .../ops/impl/transforms/custom/ShiftBits.java | 2 +- 9 files changed, 275 insertions(+), 4 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitsHammingDistance.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 3086b0f1b..621dac941 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -1288,6 +1288,22 @@ public class DifferentialFunctionFactory { return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable(); } + public SDVariable bitwiseHammingDist(SDVariable x, SDVariable y) { + return new BitsHammingDistance(sameDiff(), x, y).outputVariable(); + } + + public SDVariable bitwiseAnd(SDVariable x, SDVariable y){ + return new BitwiseAnd(sameDiff(), x, y).outputVariable(); + } + + public SDVariable bitwiseOr(SDVariable x, SDVariable y){ + return new BitwiseOr(sameDiff(), x, y).outputVariable(); + } + + public SDVariable bitwiseXor(SDVariable x, SDVariable y){ + return new BitwiseXor(sameDiff(), x, y).outputVariable(); + } + public SDVariable eq(SDVariable iX, SDVariable i_y) { return new EqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index e09ceda75..0b5a4c03f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -188,6 +188,11 @@ public class SameDiff extends SDBaseOps { */ public final SDImage image = new SDImage(this); + /** + * Op creator object for bitwise operations + */ + public final SDBitwise bitwise = new 
SDBitwise(this); + /** * Op creator object for math operations */ @@ -237,6 +242,13 @@ public class SameDiff extends SDBaseOps { return image; } + /** + * Op creator object for bitwise operations + */ + public SDBitwise bitwise(){ + return bitwise; + } + /** * For import, many times we have variables diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java new file mode 100644 index 000000000..0857b2b42 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java @@ -0,0 +1,205 @@ +package org.nd4j.autodiff.samediff.ops; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; + +import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger; + +/** + * + */ +public class SDBitwise extends SDOps { + public SDBitwise(SameDiff sameDiff) { + super(sameDiff); + } + + /** + * See {@link #leftShift(String, SDVariable, SDVariable)} + */ + public SDVariable leftShift(@NonNull SDVariable x, @NonNull SDVariable y){ + return leftShift(null, x, y); + } + + /** + * Bitwise left shift operation. Supports broadcasting. + * + * @param name Name of the output variable. May be null. + * @param x Input to be bit shifted (must be an integer type) + * @param y Amount to shift elements of x array (must be an integer type) + * @return Bitwise shifted input x + */ + public SDVariable leftShift(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise left shift", x); + validateInteger("bitwise left shift", y); + + SDVariable ret = f().shift(x, y); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #rightShift(String, SDVariable, SDVariable)} + */ + public SDVariable rightShift(SDVariable x, SDVariable y){ + return rightShift(null, x, y); + } + + /** + * Bitwise right shift operation. Supports broadcasting. + * + * @param name Name of the output variable. May be null. + * @param x Input to be bit shifted (must be an integer type) + * @param y Amount to shift elements of x array (must be an integer type) + * @return Bitwise shifted input x + */ + public SDVariable rightShift(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise right shift", x); + validateInteger("bitwise right shift", y); + + SDVariable ret = f().rshift(x, y); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #leftShiftCyclic(String, SDVariable, SDVariable)} + */ + public SDVariable leftShiftCyclic(SDVariable x, SDVariable y){ + return leftShiftCyclic(null, x, y); + } + + /** + * Bitwise left cyclical shift operation. Supports broadcasting. + * Unlike {@link #leftShift(String, SDVariable, SDVariable)} the bits will "wrap around": + * {@code leftShiftCyclic(01110000, 2) -> 11000001} + * + * @param name Name of the output variable. May be null. 
+ * @param x Input to be bit shifted (must be an integer type) + * @param y Amount to shift elements of x array (must be an integer type) + * @return Bitwise cyclic shifted input x + */ + public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise left shift (cyclic)", x); + validateInteger("bitwise left shift (cyclic)", y); + + SDVariable ret = f().rotl(x, y); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #rightShiftCyclic(String, SDVariable, SDVariable)} + */ + public SDVariable rightShiftCyclic(SDVariable x, SDVariable y){ + return rightShiftCyclic(null, x, y); + } + + /** + * Bitwise right cyclical shift operation. Supports broadcasting. + * Unlike {@link #rightShift(String, SDVariable, SDVariable)} the bits will "wrap around": + * {@code rightShiftCyclic(00001110, 2) -> 10000011} + * + * @param name Name of the output variable. May be null. + * @param x Input to be bit shifted (must be an integer type) + * @param y Amount to shift elements of x array (must be an integer type) + * @return Bitwise cyclic shifted input x + */ + public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise right shift (cyclic)", x); + validateInteger("bitwise right shift (cyclic)", y); + + SDVariable ret = f().rotr(x, y); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #bitsHammingDistance(String, SDVariable, SDVariable)} + */ + public SDVariable bitsHammingDistance(SDVariable x, SDVariable y){ + return bitsHammingDistance(null, x, y); + } + + /** + * Bitwise Hamming distance reduction over all elements of both input arrays.
+ * For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1) + * + * @param name Name of the output variable. May be null. + * @param x First input array. Must be integer type. + * @param y First input array. Must be integer type, same type as x + * @return + */ + public SDVariable bitsHammingDistance(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise hamming distance", x); + validateInteger("bitwise hamming distance", y); + + SDVariable ret = f().bitwiseHammingDist(x, y); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #and(String, SDVariable, SDVariable)} + */ + public SDVariable and(SDVariable x, SDVariable y){ + return and(null, x, y); + } + + /** + * Bitwise AND operation. Supports broadcasting. + * + * @param name Name of the output variable. May be null. + * @param x First input array. Must be integer type. + * @param y First input array. Must be integer type, same type as x + * @return Bitwise AND array + */ + public SDVariable and(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise AND", x); + validateInteger("bitwise AND", y); + + SDVariable ret = f().bitwiseAnd(x, y); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #or(String, SDVariable, SDVariable)} + */ + public SDVariable or(SDVariable x, SDVariable y){ + return or(null, x, y); + } + + /** + * Bitwise OR operation. Supports broadcasting. + * + * @param name Name of the output variable. May be null. + * @param x First input array. Must be integer type. + * @param y First input array. Must be integer type, same type as x + * @return Bitwise OR array + */ + public SDVariable or(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise OR", x); + validateInteger("bitwise OR", y); + + SDVariable ret = f().bitwiseOr(x, y); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #xor(String, SDVariable, SDVariable)} + */ + public SDVariable xor(SDVariable x, SDVariable y){ + return xor(null, x, y); + } + + /** + * Bitwise XOR operation (exclusive OR). Supports broadcasting. + * + * @param name Name of the output variable. May be null. + * @param x First input array. Must be integer type. + * @param y First input array. 
Must be integer type, same type as x + * @return Bitwise XOR array + */ + public SDVariable xor(String name, SDVariable x, SDVariable y){ + validateInteger("bitwise XOR", x); + validateInteger("bitwise XOR", y); + + SDVariable ret = f().bitwiseXor(x, y); + return updateVariableNameAndReference(ret, name); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 3a89b7339..19b534a97 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -353,6 +353,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class, + org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance.class, org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd.class, org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor.class, org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitsHammingDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitsHammingDistance.java new file mode 100644 index 000000000..1fa749830 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BitsHammingDistance.java @@ -0,0 +1,37 @@ +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +public class BitsHammingDistance extends DynamicCustomOp { + + public BitsHammingDistance(){ } + + public BitsHammingDistance(@NonNull SameDiff sd, @NonNull SDVariable x, @NonNull SDVariable y){ + super(sd, new SDVariable[]{x, y}); + } + + public BitsHammingDistance(@NonNull INDArray x, @NonNull INDArray y){ + super(new INDArray[]{x, y}, null); + } + + @Override + public String opName() { + return "bits_hamming_distance"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected 2 input datatypes, got %s", dataTypes); + Preconditions.checkState(dataTypes.get(0).isIntType() && dataTypes.get(1).isIntType(), "Input datatypes must be integer type, got %s", dataTypes); + return Collections.singletonList(DataType.LONG); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java index 3a9173654..a8b4ebbb0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java @@ -61,7 +61,7 @@ public class CyclicRShiftBits extends BaseDynamicTransformOp { @Override public String tensorflowName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java index 20b6f6955..ea7ae1715 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java @@ -61,7 +61,7 @@ public class CyclicShiftBits extends BaseDynamicTransformOp { @Override public String tensorflowName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java index 4435615f5..3cc03d12b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java @@ -61,7 +61,7 @@ public class RShiftBits extends BaseDynamicTransformOp { @Override public String tensorflowName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java index 5501324f2..a9eebb14e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java @@ -61,7 +61,7 @@ public class ShiftBits extends BaseDynamicTransformOp { @Override public String tensorflowName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName()); }
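Usage sketch (illustrative only, not part of the patches above): one way the new isBlasVersionMatches() native call could be exercised from Java, assuming the NativeOpsHolder accessor used elsewhere in nd4j. The hard-coded version numbers are placeholders; a real caller would query the cuBLAS runtime for them (e.g. via cublasGetProperty, as in the commented-out JCublasNDArrayFactory block above).

    import org.nd4j.nativeblas.NativeOps;
    import org.nd4j.nativeblas.NativeOpsHolder;

    public class BlasVersionCheckSketch {
        public static void main(String[] args) {
            // Accessor used elsewhere in nd4j for the per-backend NativeOps bindings.
            NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();

            // Placeholder version numbers - real values would come from the cuBLAS runtime.
            boolean matches = nativeOps.isBlasVersionMatches(10, 1, 0);
            if (!matches) {
                // On the CUDA backend the native side also sets error code 152 plus an error message.
                System.err.println("CUDA/cuBLAS version mismatch reported by libnd4j");
            }
        }
    }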
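Usage sketch (illustrative only): driving the new BitwiseAnd/BitwiseOr/BitwiseXor wrappers directly on INDArrays, assuming Nd4j.exec(CustomOp) is used for execution. The input values are arbitrary, with the expected results noted in comments; the ops accept integer types only, matching the {ALL_INTS} restriction on the C++ side.

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd;
    import org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr;
    import org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor;
    import org.nd4j.linalg.factory.Nd4j;

    public class BitwiseOpsSketch {
        public static void main(String[] args) {
            // Integer inputs - the ops are restricted to integer data types.
            INDArray x = Nd4j.createFromArray(0b0110, 0b1010, 0b1111);
            INDArray y = Nd4j.createFromArray(0b0011, 0b0011, 0b0000);

            INDArray and = Nd4j.exec(new BitwiseAnd(x, y))[0];   // [0b0010, 0b0010, 0b0000]
            INDArray or  = Nd4j.exec(new BitwiseOr(x, y))[0];    // [0b0111, 0b1011, 0b1111]
            INDArray xor = Nd4j.exec(new BitwiseXor(x, y))[0];   // [0b0101, 0b1001, 0b1111]

            System.out.println(and + " " + or + " " + xor);
        }
    }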
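Usage sketch (illustrative only): the same ops through the new SameDiff.bitwise namespace added in the last commit, assuming integer-typed variables; the variable names are arbitrary and SDVariable.eval() is used here purely to show that the resulting graph can be executed.

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.factory.Nd4j;

    public class SDBitwiseSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();

            // Integer-typed variables - every SDBitwise method validates its inputs with validateInteger(...).
            SDVariable x = sd.var("x", Nd4j.createFromArray(0b01110000, 0b00001110));
            SDVariable y = sd.var("y", Nd4j.createFromArray(2, 2));

            SDVariable shifted = sd.bitwise().leftShift("shifted", x, y);        // plain left shift
            SDVariable rotated = sd.bitwise().leftShiftCyclic("rotated", x, y);  // bits wrap around
            SDVariable anded   = sd.bitwise().and("anded", x, y);
            SDVariable hamming = sd.bitwise().bitsHammingDistance("hamming", x, y);

            System.out.println(anded.eval());
            System.out.println(hamming.eval());
        }
    }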