From 9f719488b93c17d71f977579e2ba9b36f4f6a575 Mon Sep 17 00:00:00 2001
From: raver119
Date: Tue, 28 Jan 2020 10:55:06 +0300
Subject: [PATCH 01/17] CUDA sync tweaks (#194)

* ThreadLocal cache for CudaContext

Signed-off-by: raver119

* temp commit

Signed-off-by: raver119

* remove unwanted synchronization

Signed-off-by: raver119
---
 .../jita/allocator/impl/AtomicAllocator.java | 10 ---
 .../jita/handler/impl/CudaZeroHandler.java   | 31 +++++---
 .../ops/executioner/CudaExecutioner.java     |  5 +-
 .../org/nd4j/linalg/dataset/DataSetTest.java | 74 +++++++++----------
 4 files changed, 60 insertions(+), 60 deletions(-)

diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java
index 03e5df160..aaccf9a34 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java
@@ -353,18 +353,8 @@ public class AtomicAllocator implements Allocator {
      */
     @Override
     public void synchronizeHostData(DataBuffer buffer) {
-        // we don't want non-committed ops left behind
-        Nd4j.getExecutioner().commit();
-
-        val oPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
-
         // we actually need synchronization only in device-dependent environment. no-op otherwise. managed by native code
         NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
-
-        val cPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
-
-        //assert oPtr.address() == cPtr.address();
-        //assert buffer.address() == oPtr.address();
     }

diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java
index 6f0944c5a..a8f3a0a3b 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java
@@ -102,6 +102,8 @@ public class CudaZeroHandler implements MemoryHandler {

    private final AffinityManager affinityManager = Nd4j.getAffinityManager();

+    private final transient ThreadLocal<CudaContext> tlContext = new ThreadLocal<>();
+
    /*
        table for Thread, Device, Object allocations of device memory.
Objects should be used to grab Allocation point from allocationsMap */ @@ -1018,18 +1020,25 @@ public class CudaZeroHandler implements MemoryHandler { * @return */ public CudaContext getCudaContext() { - val lc = nativeOps.defaultLaunchContext(); + var ctx = tlContext.get(); + if (ctx == null) { + val lc = nativeOps.defaultLaunchContext(); - return CudaContext.builder() - .bufferScalar(nativeOps.lcScalarPointer(lc)) - .bufferReduction(nativeOps.lcReductionPointer(lc)) - .bufferAllocation(nativeOps.lcAllocationPointer(lc)) - .bufferSpecial(nativeOps.lcScalarPointer(lc)) - .oldStream(new cudaStream_t(nativeOps.lcExecutionStream(lc))) - .specialStream(new cudaStream_t(nativeOps.lcCopyStream(lc))) - .cublasHandle(getCudaCublasHandle(lc)) - .solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc))) - .build(); + ctx = CudaContext.builder() + .bufferScalar(nativeOps.lcScalarPointer(lc)) + .bufferReduction(nativeOps.lcReductionPointer(lc)) + .bufferAllocation(nativeOps.lcAllocationPointer(lc)) + .bufferSpecial(nativeOps.lcScalarPointer(lc)) + .oldStream(new cudaStream_t(nativeOps.lcExecutionStream(lc))) + .specialStream(new cudaStream_t(nativeOps.lcCopyStream(lc))) + .cublasHandle(getCudaCublasHandle(lc)) + .solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc))) + .build(); + + tlContext.set(ctx); + return ctx; + } else + return ctx; } /** diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 1615e4843..04b86dc02 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -1610,8 +1610,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public void commit() { - AtomicAllocator.getInstance().getDeviceContext().syncOldStream(); - AtomicAllocator.getInstance().getDeviceContext().syncSpecialStream(); + val ctx = AtomicAllocator.getInstance().getDeviceContext(); + ctx.syncOldStream(); + ctx.syncSpecialStream(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java index a62dc631e..816003009 100755 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java @@ -738,52 +738,52 @@ public class DataSetTest extends BaseNd4jTest { @Test public void testShuffleNd() { - int numDims = 7; - int nLabels = 3; - Random r = new Random(); + int numDims = 7; + int nLabels = 3; + Random r = new Random(); - int[] shape = new int[numDims]; - int entries = 1; - for (int i = 0; i < numDims; i++) { - //randomly generating shapes bigger than 1 - shape[i] = r.nextInt(4) + 2; - entries *= shape[i]; - } - int labels = shape[0] * nLabels; + int[] shape = new int[numDims]; + int entries = 1; + for (int i = 0; i < numDims; i++) { + //randomly generating shapes bigger than 1 + shape[i] = r.nextInt(4) + 2; + entries *= shape[i]; + } + int labels = shape[0] * nLabels; - INDArray ds_data = Nd4j.linspace(1, entries, entries, DataType.INT).reshape(shape); - INDArray ds_labels = Nd4j.linspace(1, labels, labels, 
DataType.INT).reshape(shape[0], nLabels); + INDArray ds_data = Nd4j.linspace(1, entries, entries, DataType.INT).reshape(shape); + INDArray ds_labels = Nd4j.linspace(1, labels, labels, DataType.INT).reshape(shape[0], nLabels); - DataSet ds = new DataSet(ds_data, ds_labels); - ds.shuffle(); + DataSet ds = new DataSet(ds_data, ds_labels); + ds.shuffle(); - //Checking Nd dataset which is the data - for (int dim = 1; dim < numDims; dim++) { - //get tensor along dimension - the order in every dimension but zero should be preserved - for (int tensorNum = 0; tensorNum < ds_data.tensorsAlongDimension(dim); tensorNum++) { - //the difference between consecutive elements should be equal to the stride - for (int i = 0, j = 1; j < shape[dim]; i++, j++) { - int f_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(i); - int f_next_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(j); - int f_element_diff = f_next_element - f_element; - assertEquals(f_element_diff, ds_data.stride(dim)); + //Checking Nd dataset which is the data + for (int dim = 1; dim < numDims; dim++) { + //get tensor along dimension - the order in every dimension but zero should be preserved + for (int tensorNum = 0; tensorNum < ds_data.tensorsAlongDimension(dim); tensorNum++) { + //the difference between consecutive elements should be equal to the stride + for (int i = 0, j = 1; j < shape[dim]; i++, j++) { + int f_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(i); + int f_next_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(j); + int f_element_diff = f_next_element - f_element; + assertEquals(f_element_diff, ds_data.stride(dim)); + } } } - } - //Checking 2d, features - int dim = 1; - //get tensor along dimension - the order in every dimension but zero should be preserved - for (int tensorNum = 0; tensorNum < ds_labels.tensorsAlongDimension(dim); tensorNum++) { - //the difference between consecutive elements should be equal to the stride - for (int i = 0, j = 1; j < nLabels; i++, j++) { - int l_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(i); - int l_next_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(j); - int l_element_diff = l_next_element - l_element; - assertEquals(l_element_diff, ds_labels.stride(dim)); + //Checking 2d, features + int dim = 1; + //get tensor along dimension - the order in every dimension but zero should be preserved + for (int tensorNum = 0; tensorNum < ds_labels.tensorsAlongDimension(dim); tensorNum++) { + //the difference between consecutive elements should be equal to the stride + for (int i = 0, j = 1; j < nLabels; i++, j++) { + int l_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(i); + int l_next_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(j); + int l_element_diff = l_next_element - l_element; + assertEquals(l_element_diff, ds_labels.stride(dim)); + } } - } } @Test From 2f08af316632406db3c6d23bfb5cbd7ae6848885 Mon Sep 17 00:00:00 2001 From: Fariz Rahman Date: Tue, 28 Jan 2020 12:30:39 +0400 Subject: [PATCH 02/17] Update GpuGraphRunnerTest.java (#195) --- .../org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java index b13ca465f..28cd5b7b2 
100644
--- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java
@@ -42,8 +42,7 @@ public class GpuGraphRunnerTest extends BaseND4JTest {
     @Test
     public void testGraphRunner() throws Exception {
-        byte[] content = IOUtils.toByteArray(new FileInputStream(new File("C:\\Users\\fariz\\code\\dl4j-test-resources\\src\\main\\resources\\tf_graphs\\nd4j_convert\\simple_graph\\frozen_model.pb")));
-        //byte[] content = IOUtils.toByteArray(new ClassPathResource("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").getInputStream());
+        byte[] content = IOUtils.toByteArray(new ClassPathResource("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").getInputStream());

         List<String> inputNames = Arrays.asList("input_0","input_1");

         ConfigProto configProto = ConfigProto.newBuilder()

From 7a7ee4b0217bc031d672d056ec52b36ebc6a4247 Mon Sep 17 00:00:00 2001
From: Yurii Shyrma
Date: Tue, 28 Jan 2020 17:23:07 +0200
Subject: [PATCH 03/17] Shyrma cudnn (#192)

* - implementation of cudnn batchnorm_bp op

Signed-off-by: Yurii

* - testing and fixing bugs in batchnorm_bp based on cudnn api

Signed-off-by: Yurii

* - move pooling mkl code and delete some unnecessary files

Signed-off-by: Yurii

* - implementation and testing cudnn pooling2d ops (avg/max, ff/bp)

Signed-off-by: Yurii

* - implementation and testing cudnn pooling 3d (ff/bp) ops

Signed-off-by: Yurii

* - provide ff step in case of cudnn maxpool3d_bp op

Signed-off-by: Yurii

* - remove half type from set of supported types in mkl depthwise conv op

Signed-off-by: Yurii

* - bring back cudaStreamSynchronize in batchnorm and pooling cudnn ops

Signed-off-by: Yurii

Co-authored-by: raver119
---
 .../ops/declarable/generic/nn/batchnorm.cpp   |   3 +-
 .../generic/nn/pooling/avgpool2d.cpp          |  14 +-
 .../generic/nn/pooling/avgpool3d.cpp          |   8 +-
 .../generic/nn/pooling/maxpool2d.cpp          |   1 +
 .../declarable/platform/cudnn/avgpool2d.cu    | 138 ++++
 .../declarable/platform/cudnn/avgpool3d.cu    | 144 ++++
 .../declarable/platform/cudnn/batchnorm.cu    | 319 +++++++-
 .../declarable/platform/cudnn/cudnnUtils.cu   | 412 ++++++++++
 .../declarable/platform/cudnn/cudnnUtils.h    | 145 ++--
 .../declarable/platform/cudnn/maxpool2d.cu    | 132 +++
 .../declarable/platform/cudnn/maxpool3d.cu    | 140 ++++
 .../platform/mkldnn/avgpooling2d.cpp          | 314 +++++---
 .../platform/mkldnn/avgpooling2d_bp.cpp       | 149 ----
 .../ops/declarable/platform/mkldnn/conv2d.cpp | 305 ++++---
 .../ops/declarable/platform/mkldnn/conv3d.cpp | 418 +++++-----
 .../declarable/platform/mkldnn/deconv2d.cpp   |   4 +-
 .../platform/mkldnn/depthwiseConv2d.cpp       |   2 +-
 .../platform/mkldnn/maxpooling2d.cpp          | 347 +++++---
 .../platform/mkldnn/maxpooling2d_bp.cpp       | 174 ----
 .../platform/mkldnn/maxpooling3d.cpp          | 375 ++++++---
 .../platform/mkldnn/maxpooling_3d_bp.cpp      | 181 -----
 .../platform/mkldnn/mkldnnUtils.cpp           | 753 +++++++++---------
 .../layers_tests/ConvolutionTests1.cpp        |  30 -
 .../layers_tests/ConvolutionTests2.cpp        |  14 +-
 libnd4j/tests_cpu/layers_tests/CuDnnTests.cu  |  20 +
 .../layers_tests/DeclarableOpsTests13.cpp     | 523 ++++++++++++
 .../layers_tests/DeclarableOpsTests15.cpp     |  75 +-
 .../layers_tests/DeclarableOpsTests3.cpp      |  16 -
 .../layers_tests/DeclarableOpsTests4.cpp      | 234 ++++--
 .../layers_tests/DeclarableOpsTests8.cpp      |  36 -
 .../layers_tests/DeclarableOpsTests9.cpp      | 337 --------
 31 files changed, 3521 insertions(+), 2242 deletions(-)
 create mode 100644
libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu create mode 100644 libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu create mode 100644 libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu create mode 100644 libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu create mode 100644 libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu delete mode 100644 libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d_bp.cpp delete mode 100644 libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d_bp.cpp delete mode 100644 libnd4j/include/ops/declarable/platform/mkldnn/maxpooling_3d_bp.cpp diff --git a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp index a8cd17131..3cf088ae9 100644 --- a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp @@ -197,8 +197,7 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) { // ***** calculations ***** // // notations: - // f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * ff_output - // g = dLdO + // f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * ff_output, g = dLdO // stdInv = 1 / (v + eps)^0.5 // N - batch size (product of spatial dimensions) diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp index 4e3314897..873ac545a 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp @@ -31,31 +31,28 @@ namespace ops { CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) { auto input = INPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead", input->rankOf()); + auto output = OUTPUT_VARIABLE(0); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - auto argI = *(block.getIArguments()); - auto output = OUTPUT_VARIABLE(0); const auto kH = INT_ARG(0); const auto kW = INT_ARG(1); const auto sH = INT_ARG(2); const auto sW = INT_ARG(3); - int pH = INT_ARG(4); - int pW = INT_ARG(5); + auto pH = INT_ARG(4); + auto pW = INT_ARG(5); const auto dH = INT_ARG(6); const auto dW = INT_ARG(7); const auto isSameMode = static_cast(INT_ARG(8)); const auto extraParam0 = INT_ARG(9); + const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D op: input should have rank of 4, but got %i instead", input->rankOf()); REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW); int oH = 0; int oW = 0; - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); const int iW = static_cast(isNCHW ? 
input->sizeAt(3) : input->sizeAt(2)); @@ -207,7 +204,6 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) { } return Status::OK(); - } DECLARE_SHAPE_FN(avgpool2d_bp) { diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp index 3f118e002..b72a1f6f7 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp @@ -51,14 +51,14 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) { int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2})); - REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "AVGPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str()); + REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "AVGPOOL3DNEW OP: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str()); if(!isNCDHW) { input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] @@ -176,8 +176,8 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) { std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2})); std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2})); - REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); + REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), 
ShapeUtils::shapeAsString(gradI).c_str()); if(!isNCDHW) { input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp index eb535a098..13ba252e7 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp @@ -32,6 +32,7 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// // maxpool2d corresponds to poolingMode=0 CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) { + auto input = INPUT_VARIABLE(0); REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D OP: input array should have rank of 4, but got %i instead", input->rankOf()); diff --git a/libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu b/libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu new file mode 100644 index 000000000..8ff0bafb1 --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu @@ -0,0 +1,138 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + + +#include "cudnnUtils.h" +#include + +namespace nd4j { +namespace ops { +namespace platforms { + + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(avgpool2d, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + const auto kH = INT_ARG(0); + const auto kW = INT_ARG(1); + const auto sH = INT_ARG(2); + const auto sW = INT_ARG(3); + auto pH = INT_ARG(4); + auto pW = INT_ARG(5); + const auto dH = INT_ARG(6); + const auto dW = INT_ARG(7); + const auto paddingMode = static_cast(INT_ARG(8)); + const auto extraParam0 = INT_ARG(9); + const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D CUDNN op: input should have rank of 4, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); + + int oH = 0; + int oW = 0; + + const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); + const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); + + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); + + if (paddingMode) + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + + const cudnnPoolingMode_t mode = (extraParam0 == 0) ? 
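+    // extraParam0 selects the averaging divisor: 0 means padded cells are excluded
+    // from the count (CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING), any other value
+    // means they are included (CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING).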
CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + + pooling2dCUDNN(block.launchContext(), input, output, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, mode); + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(avgpool2d, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; + + return goodType && input->dataType() == output->dataType(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(avgpool2d_bp, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + const auto kH = INT_ARG(0); // filter(kernel) height + const auto kW = INT_ARG(1); // filter(kernel) width + const auto sH = INT_ARG(2); // strides height + const auto sW = INT_ARG(3); // strides width + auto pH = INT_ARG(4); // paddings height + auto pW = INT_ARG(5); // paddings width + const auto dH = INT_ARG(6); // dilations height + const auto dW = INT_ARG(7); // dilations width + const auto paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + const auto extraParam0 = INT_ARG(9); + const auto isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D_BP CUDNN op: input should have rank of 4, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D_BP CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); + std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP CUDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "AVGPOOL2D_BP CUDNN op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); + + if(paddingMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + + const cudnnPoolingMode_t mode = (extraParam0 == 0) ? 
CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + + pooling2dBpCUDNN(block.launchContext(), input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, mode); + + return Status::OK(); +} + +PLATFORM_CHECK(avgpool2d_bp, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; + + return goodType && (input->dataType() == gradO->dataType()) + && (input->dataType() == gradI->dataType()) + && shape::haveSameShapeAndStrides(input->getShapeInfo(), gradI->getShapeInfo()); +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu b/libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu new file mode 100644 index 000000000..878f306b3 --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu @@ -0,0 +1,144 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + + +#include "cudnnUtils.h" +#include + +namespace nd4j { +namespace ops { +namespace platforms { + + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(avgpool3dnew, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) + + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + int extraParam0 = INT_ARG(13); + int isNCDHW = block.getIArguments()->size() > 14 ? 
!INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); + REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); + + if(paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); + + const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + + pooling3dCUDNN(block.launchContext(), input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, mode); + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(avgpool3dnew, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; + + return goodType && input->dataType() == output->dataType(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon + + const int kD = INT_ARG(0); // filter(kernel) depth + const int kH = INT_ARG(1); // filter(kernel) height + const int kW = INT_ARG(2); // filter(kernel) width + const int sD = INT_ARG(3); // strides depth + const int sH = INT_ARG(4); // strides height + const int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + const int dD = INT_ARG(9); // dilations depth + const int dH = INT_ARG(10); // dilations height + const int dW = INT_ARG(11); // dilations width + const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + const int extraParam0 = INT_ARG(13); // define what divisor to use while averaging + const int isNCDHW = block.getIArguments()->size() > 14 ? 
!INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW_BP CUDNN OP: input should have rank of 5, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW_BP CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); + std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP CUDNN: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "AVGPOOL3DNEW_BP CUDNN: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); + + if(isSameMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); + + const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + + pooling3dBpCUDNN(block.launchContext(), input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, mode); + + return Status::OK(); +} + +PLATFORM_CHECK(avgpool3dnew_bp, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon + + const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; + + return goodType && (input->dataType() == gradO->dataType()) + && (input->dataType() == gradI->dataType()) + && shape::haveSameShapeAndStrides(input->getShapeInfo(), gradI->getShapeInfo()); +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu b/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu index 3bd1357bf..1177d1a3c 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu @@ -97,9 +97,6 @@ static void batchnormCUDNN(const LaunchContext* context, err = cudnnSetTensorNdDescriptor(params, dataType, xRank, paramsShape.data(), paramsStrides.data()); if (err != 0) throw nd4j::cuda_exception::build("batchnormCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for mean/variance/gamma/beta failed", err); - - if (err != 0) throw nd4j::cuda_exception::build("batchnormCUDNN: cudnnSetConvolutionNdDescriptor failed", err); - // provide scaling parameters const float alpha32(1), beta32(0); const double 
alpha64(1), beta64(0); @@ -114,20 +111,127 @@ static void batchnormCUDNN(const LaunchContext* context, x, input->getSpecialBuffer(), z, output->getSpecialBuffer(), params, - gamma ? gamma->getSpecialBuffer(): nullptr, - beta ? beta->getSpecialBuffer() : nullptr, + gamma->getSpecialBuffer(), beta->getSpecialBuffer(), mean->getSpecialBuffer(), variance->getSpecialBuffer(), epsilon); if (err != 0) throw nd4j::cuda_exception::build("batchnormCUDNN: cudnnBatchNormalizationForwardInference failed", err); - // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - // if (cudaErr != 0) - // throw cuda_exception::build("batchnormCUDNN: cudaStreamSynchronize failed !", cudaErr); - + auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + if (cudaErr != 0) + throw cuda_exception::build("batchnormCUDNN: cudaStreamSynchronize failed !", cudaErr); NDArray::registerSpecialUse({output}, {input, mean, variance, gamma, beta}); } +////////////////////////////////////////////////////////////////////////// +static void batchnormBpCUDNN(const LaunchContext* context, + const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* gradO, + NDArray* gradI, NDArray* gradG, NDArray* gradB, + const double epsilon, const bool isSpatialMode) { + + // input, gradO, gradI -> 4D:nchw, 5D:ncdhw + // mean, variance, gamma, beta, gradM, gradV, gradG, gradB -> 1xCx1x1 for 4D and 1xCx1x1x1 for 5D for BATCHNORM_MODE_SPATIAL mode + // -> 1xCxHxW for 4D and 1xCxDxHxW for 5D for BATCHNORM_MODE_PER_ACTIVATION mode + + const cudnnDataType_t dataType = cudnnDataType(input->dataType()); + + const int xRank = input->rankOf(); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: can't set stream for cuDNN", err); + + const std::vector xShape = input->getShapeAsVectorInt(); // input and output have same shapes + + std::vector paramsShape, paramsStrides; // mean, variance, gamma and beta have same shapes + if(isSpatialMode) { // 1xCx1x1 + const int iC = mean->lengthOf(); + const int stride0 = mean->strideAt(0); + paramsShape = xRank == 4 ? std::vector({1, iC, 1, 1}) : std::vector({1, iC, 1, 1, 1}); + paramsStrides = xRank == 4 ? std::vector({iC*stride0, stride0, 1, 1}) : std::vector({iC*stride0, stride0, 1, 1, 1}); + } + else { + paramsShape = mean->getShapeAsVectorInt(); + paramsStrides = xRank == 4 ? 
std::vector({(int)mean->strideAt(0), (int)mean->strideAt(1), (int)mean->strideAt(2), (int)mean->strideAt(3)}) : std::vector({(int)mean->strideAt(0), (int)mean->strideAt(1), (int)mean->strideAt(2), (int)mean->strideAt(3), (int)mean->strideAt(4)}); + } + + std::vector xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3)}; + std::vector dxStrides = {(int)gradI->strideAt(0), (int)gradI->strideAt(1), (int)gradI->strideAt(2), (int)gradI->strideAt(3)}; + std::vector dzStrides = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3)}; + + if(xRank > 4) { // 5D + xStrides.push_back((int)input->strideAt(4)); + dxStrides.push_back((int)gradI->strideAt(4)); + dzStrides.push_back((int)gradO->strideAt(4)); + } + + cudnnTensorFormat_t format = CUDNN_TENSOR_NCHW; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if(input->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(x, format, dataType, xRank, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(x, dataType, xRank, xShape.data(), xStrides.data()); + if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err); + + // gradO descriptor + cudnnTensorDescriptor_t dz; + cudnnCreateTensorDescriptor(&dz); + if(gradO->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(dz, format, dataType, xRank, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(dz, dataType, xRank, xShape.data(), dzStrides.data()); + if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO failed", err); + + // gradI descriptor + cudnnTensorDescriptor_t dx; + cudnnCreateTensorDescriptor(&dx); + if(input->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(dx, format, dataType, xRank, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(dx, dataType, xRank, xShape.data(), dxStrides.data()); + if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradI failed", err); + + // mean, variance, gamma, gradG and gradB descriptor, the same descriptor for all of them + cudnnTensorDescriptor_t params; + cudnnCreateTensorDescriptor(¶ms); + if(mean->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(params, format, dataType, xRank, paramsShape.data()); + else + err = cudnnSetTensorNdDescriptor(params, dataType, xRank, paramsShape.data(), paramsStrides.data()); + if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for mean/variance/gamma/gradG/gradB failed", err); + + // provide scaling parameters + const float alpha32(1), beta32(0); + double alpha64(1), beta64(0); + const void* ptrAlpha = input->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); + const void* ptrBeta = input->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({gradI, gradG, gradB}, {input, mean, variance, gamma, gradO}); + + // calculations + // TODO: we can use cache here + err = cudnnBatchNormalizationBackward(*handle, isSpatialMode ? 
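+                                          // a single cudnnBatchNormalizationBackward call yields gradI (dx),
+                                          // gradG (resultBnScaleDiff) and gradB (resultBnBiasDiff); the two
+                                          // nullptr arguments below (savedMean/savedInvVariance) tell cuDNN
+                                          // to recompute the batch statistics from x instead of reusing
+                                          // values cached by a forward training pass.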
CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION, + ptrAlpha, ptrBeta, ptrAlpha, ptrBeta, + x, input->getSpecialBuffer(), + dz, gradO->getSpecialBuffer(), + dx, gradI->getSpecialBuffer(), + params, + gamma->getSpecialBuffer(), gradG->getSpecialBuffer(), gradB->getSpecialBuffer(), + epsilon, + nullptr/*mean->getSpecialBuffer()*/, nullptr/*variance->getSpecialBuffer()*/); + + if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnBatchNormalizationBackward failed", err); + + auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + if (cudaErr != 0) + throw cuda_exception::build("batchnormBpCUDNN: cudaStreamSynchronize failed !", cudaErr); + + NDArray::registerSpecialUse({gradI, gradG, gradB}, {input, mean, variance, gamma, gradO}); +} + ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(batchnorm, ENGINE_CUDA) { @@ -189,11 +293,21 @@ PLATFORM_IMPL(batchnorm, ENGINE_CUDA) { const bool needPermut = axes.size() == 1 && mean->lengthOf() == input->sizeAt(-1); if(needPermut) { // if NHWC - std::vector perm = {0, 3, 1, 2}; // NHWC -> NCHW + std::vector perm = inRank == 4 ? std::vector({0, 3, 1, 2}) : std::vector({0, 4, 1, 2, 3}); // NHWC -> NCHW input = new NDArray(input->permute(perm)); output = new NDArray(output->permute(perm)); } + // cudnn requires gamma and beta to be non-nullptr + if(!applyScale) { + gamma = new NDArray(mean); + *gamma = 1; + } + if(!applyOffset) { + beta = new NDArray(mean); + *beta = 0; + } + // calculations batchnormCUDNN(block.launchContext(), input, mean, variance, gamma, beta, output, epsilon, axes.size() == 1); @@ -202,6 +316,12 @@ PLATFORM_IMPL(batchnorm, ENGINE_CUDA) { delete output; } + if(!applyScale) + delete gamma; + + if(!applyOffset) + delete beta; + return Status::OK(); } @@ -220,9 +340,6 @@ PLATFORM_CHECK(batchnorm, ENGINE_CUDA) { const int numOfIntArgs = block.getIArguments()->size(); const int xRank = input->rankOf(); - // disable cudnn batchnorm so far - return false; - // *********************************** // if(xRank != 4 && xRank != 5) return false; @@ -269,6 +386,182 @@ PLATFORM_CHECK(batchnorm, ENGINE_CUDA) { return true; } +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(batchnorm_bp, ENGINE_CUDA) { + + NDArray* input = INPUT_VARIABLE(0); + NDArray* mean = INPUT_VARIABLE(1); + NDArray* variance = INPUT_VARIABLE(2); + NDArray* gamma = nullptr; + NDArray* beta = nullptr; + NDArray* gradO = INPUT_VARIABLE(block.width() - 1); // next epsilon + + NDArray* gradI = OUTPUT_VARIABLE(0); + NDArray* gradM = OUTPUT_VARIABLE(1); + NDArray* gradV = OUTPUT_VARIABLE(2); + NDArray* gradG = nullptr; + NDArray* gradB = nullptr; + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + const float epsilon = T_ARG(0); + + if(applyScale) { + gamma = INPUT_VARIABLE(3); + gradG = OUTPUT_VARIABLE(3); + } + if(applyOffset) { + beta = INPUT_VARIABLE(3 + (int)applyScale); + gradB = OUTPUT_VARIABLE(3 + (int)applyScale); + } + + const int numOfIntArgs = block.getIArguments()->size(); + const int inRank = input->rankOf(); + + // get axes args to normalize input array over + std::vector axes; + if(numOfIntArgs > 2) + for(int i = 2; i < numOfIntArgs; ++i) + axes.push_back(INT_ARG(i)); + else + axes.push_back(inRank-1); // default dimension to reduce along is last dimension + + const int numOfAxes = axes.size(); + REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM_BP CUDNN op: too big number of input axes to normalize over, expected 
number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank); + + // evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes + // for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5} + std::vector expShape; + if(numOfAxes == 1) + expShape.push_back(input->sizeAt(axes[0])); + else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3} + expShape = std::vector(inRank, 1); + for(uint i = 0; i < numOfAxes; ++i) + expShape[axes[i]] = input->sizeAt(axes[i]); + } + + REQUIRE_TRUE(mean->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str()); + if(gamma) + REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str()); + if(beta) + REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str()); + + REQUIRE_TRUE(input->isSameShape(gradO), 0, "BATCHNORM_BP CUDNN op: wrong shape of output gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + + // types of all input arrays should be the same (except gradO) + for(int i = 1; i < block.width() - 2; ++i) + REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP CUDNN op: types of arrays (input, mean, variance, gamma, beta) should be the same !"); + + // cudnn supports NCHW format only + const bool needPermut = axes.size() == 1 && mean->lengthOf() != input->sizeAt(1); + + if(needPermut) { // if NHWC + std::vector perm = inRank == 4 ? 
std::vector({0, 3, 1, 2}) : std::vector({0, 4, 1, 2, 3}); // NHWC -> NCHW + input = new NDArray(input->permute(perm)); + gradO = new NDArray(gradO->permute(perm)); + gradI = new NDArray(gradI->permute(perm)); + } + + // cudnn requires gamma, gradG, gradB to be non-nullptr + if(!applyScale) { + gamma = new NDArray(mean); + gradG = new NDArray(mean); + *gamma = 1; + } + if(!applyOffset) + gradB = new NDArray(mean); + + // calculations + batchnormBpCUDNN(block.launchContext(), input, mean, variance, gamma, gradO, gradI, gradG, gradB, epsilon, axes.size() == 1); + + *gradM = 0; // put zeros so far + *gradV = 0; // put zeros so far + + if(needPermut) { + delete input; + delete gradO; + delete gradI; + } + + if(!applyScale) { + delete gamma; + delete gradG; + } + + if(!applyOffset) + delete gradB; + + return Status::OK(); + +} + +PLATFORM_CHECK(batchnorm_bp, ENGINE_CUDA) { + + NDArray* input = INPUT_VARIABLE(0); + NDArray* mean = INPUT_VARIABLE(1); + NDArray* variance = INPUT_VARIABLE(2); + NDArray* gamma = nullptr; + NDArray* beta = nullptr; + NDArray* gradO = INPUT_VARIABLE(block.width() - 1); // next epsilon + + NDArray* gradI = OUTPUT_VARIABLE(0); + NDArray* gradM = OUTPUT_VARIABLE(1); + NDArray* gradV = OUTPUT_VARIABLE(2); + NDArray* gradG = nullptr; + NDArray* gradB = nullptr; + + const int numOfIntArgs = block.getIArguments()->size(); + const int xRank = input->rankOf(); + + // *********************************** // + if(xRank != 4 && xRank != 5) + return false; + + // *********************************** // + const bool badType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; + if(badType) + return false; + + // *********************************** // + // get axes args to normalize input array over + std::vector axes; + if(numOfIntArgs > 2) + for(int i = 2; i < numOfIntArgs; ++i) + axes.push_back(INT_ARG(i)); + else + axes.push_back(xRank-1); // default dimension to reduce along is last dimension + + if(axes.size() != 1 && axes.size() != 3 && axes.size() != 4) + return false; + + // *********************************** // + bool allParamsHaveSameShapeAndStrides = shape::haveSameShapeAndStrides(mean->getShapeInfo(), variance->getShapeInfo()); + if(gamma) + allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->getShapeInfo(), gamma->getShapeInfo()); + if(gradG) + allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->getShapeInfo(), gradG->getShapeInfo()); + if(gradB) + allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->getShapeInfo(), gradB->getShapeInfo()); + + if(!allParamsHaveSameShapeAndStrides) + return false; + + // *********************************** // + bool isFormatGood = false; + if(axes.size() == 1) + isFormatGood = mean->lengthOf() == input->sizeAt(1) || mean->lengthOf() == input->sizeAt(-1); // mean [C] + else { + auto inputShapeModif = input->getShapeAsVector(); // [dim0,dim1,dim2,dim3] 4D or [dim0,dim1,dim2,dim3,dim4] + inputShapeModif[0] = 1; + isFormatGood = mean->isSameShape(inputShapeModif); // mean [1,dim1,dim2,dim3] 4D or [1,dim1,dim2,dim3,dim4] + } + if(!isFormatGood) + return false; + + return true; +} + } } diff --git a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu new file mode 100644 index 000000000..fa7b1ecfa --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu @@ -0,0 +1,412 @@ 
+/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + + +#include "cudnnUtils.h" +#include + +namespace nd4j { +namespace ops { +namespace platforms { + +////////////////////////////////////////////////////////////////////////// +void checkConv2dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI, + const int iH, const int iW, + const int oH, const int oW, + const int kH, const int kW, + const int sH, const int sW, + const int pH, const int pW, + const int dH, const int dW, + const bool isNCHW) { + + const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH); + const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW); + + const bool isPHasymm = pH != (pHsum - pH); + const bool isPWasymm = pW != (pWsum - pW); + + if(!isPHasymm && !isPWasymm) + return; + + std::vector newShape = input->getShapeAsVector(); + + const int iHposition = isNCHW ? 2 : 1; + + if(isPHasymm) + newShape[iHposition] += 1; + if(isPWasymm) + newShape[iHposition + 1] += 1; + + NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext()); + + if(isNCHW) + (*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3)}).assign(input); + else + (*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,0}).assign(input); + + input = newInput; + + if(gradI != nullptr) + gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext()); +} + + +////////////////////////////////////////////////////////////////////////// +void checkConv3dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI, + const int iD, const int iH, const int iW, + const int oD, const int oH, const int oW, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const bool isNCDHW) { + + const auto pDsum = ((oD - 1) * sD + ((kD - 1) * dD + 1) - iD); + const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH); + const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW); + + const bool isPDasymm = pD != (pDsum - pD); + const bool isPHasymm = pH != (pHsum - pH); + const bool isPWasymm = pW != (pWsum - pW); + + if(!isPDasymm && !isPHasymm && !isPWasymm) + return; + + std::vector newShape = input->getShapeAsVector(); + + const int iDposition = isNCDHW ? 
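+    // cuDNN conv/pooling descriptors accept one padding value per spatial dimension,
+    // i.e. symmetric padding only; an asymmetric SAME pad is therefore emulated by
+    // appending one extra plane/row/column to the input, so that the symmetric pad
+    // covers the remainder.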
2 : 1; + + if(isPDasymm) + newShape[iDposition] += 1; + if(isPHasymm) + newShape[iDposition + 1] += 1; + if(isPWasymm) + newShape[iDposition + 2] += 1; + + NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext()); + + if(isNCDHW) + (*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3), 0,input->sizeAt(4)}).assign(input); + else + (*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,input->sizeAt(3), 0,0}).assign(input); + + input = newInput; + + if(gradI != nullptr) + gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext()); +} + +////////////////////////////////////////////////////////////////////////// +void pooling2dCUDNN(const LaunchContext* context, + const NDArray* input, NDArray* output, + const int kH, const int kW, + const int sH, const int sW, + const int pH, const int pW, + const int dH, const int dW, + const bool isNCHW, const cudnnPoolingMode_t mode) { + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: can't set stream for cuDNN", err); + + cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if(input->ews() == 1) + err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); + if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input failed", err); + + // output descriptor + cudnnTensorDescriptor_t z; + cudnnCreateTensorDescriptor(&z); + if(output->ews() == 1) + err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); + else + err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1)); + if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for output failed", err); + + // description of pooling + cudnnPoolingDescriptor_t pooling; + cudnnCreatePoolingDescriptor(&pooling); + err = cudnnSetPooling2dDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, kH, kW, pH, pW, sH, sW); + if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: cudnnSetPooling2dDescriptor failed", err); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); + const void* beta = output->sizeOfT() <= 4 ? 
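+    // cuDNN blending convention: dst = alpha * op(src) + beta * dst; the scalar type
+    // must match the compute type, hence float alpha/beta for half and float tensors,
+    // double alpha/beta for double tensors.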
+    // provide scaling parameters
+    const float  alpha32(1), beta32(0);
+    const double alpha64(1), beta64(0);
+    const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&alpha32) : reinterpret_cast<const void*>(&alpha64);
+    const void* beta  = output->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&beta32)  : reinterpret_cast<const void*>(&beta64);
+
+    NDArray::prepareSpecialUse({output}, {input});
+
+    // run calculation
+    err = cudnnPoolingForward(*handle, pooling, alpha, x, input->getSpecialBuffer(), beta, z, output->specialBuffer());
+    if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: cudnnPoolingForward failed", err);
+
+    auto cudaErr = cudaStreamSynchronize(*context->getCudaStream());
+    if (cudaErr != 0)
+        throw cuda_exception::build("pooling2dCUDNN: cudaStreamSynchronize failed !", cudaErr);
+
+    NDArray::registerSpecialUse({output}, {input});
+}
+
+//////////////////////////////////////////////////////////////////////////
+void pooling2dBpCUDNN(const LaunchContext* context,
+                      const NDArray* input, const NDArray* gradO,
+                      NDArray* gradI,
+                      const int kH, const int kW,
+                      const int sH, const int sW,
+                      const int pH, const int pW,
+                      const int dH, const int dW,
+                      const bool isNCHW, const cudnnPoolingMode_t mode) {
+
+    int bS, iC, iH, iW, oC, oH, oW;                       // batch size, input channels, input height/width, output channels, output height/width;
+    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
+    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+
+    auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
+    cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
+    if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: can't set stream for cuDNN", err);
+
+    cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
+
+    // input and gradI descriptor
+    cudnnTensorDescriptor_t x;
+    cudnnCreateTensorDescriptor(&x);
+    if(input->ews() == 1)
+        err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW);
+    else
+        err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1));
+    if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input/gradI failed", err);
+
+    // gradO descriptor
+    cudnnTensorDescriptor_t dz;
+    cudnnCreateTensorDescriptor(&dz);
+    if(gradO->ews() == 1)
+        err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW);
+    else
+        err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1));
+    if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradO failed", err);
+
+    // description of pooling
+    cudnnPoolingDescriptor_t pooling;
+    cudnnCreatePoolingDescriptor(&pooling);
+    err = cudnnSetPooling2dDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, kH, kW, pH, pW, sH, sW);
+    if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnSetPooling2dDescriptor failed", err);
+
+    // provide scaling parameters
+    const float  alpha32(1), beta32(0);
+    const double alpha64(1), beta64(0);
+    const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&alpha32) : reinterpret_cast<const void*>(&alpha64);
+    const void* beta  = gradO->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&beta32)  : reinterpret_cast<const void*>(&beta64);
+
+    NDArray::prepareSpecialUse({gradI}, {input, gradO});
+
+    // run calculation for gradI
+    err = cudnnPoolingBackward(*handle, pooling, alpha, dz, gradO->getSpecialBuffer(), dz, gradO->getSpecialBuffer(), x, input->getSpecialBuffer(), beta, x, gradI->getSpecialBuffer());
+    if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnPoolingBackward failed", err);
+
+    auto cudaErr = cudaStreamSynchronize(*context->getCudaStream());
+    if (cudaErr != 0)
+        throw cuda_exception::build("pooling2dBpCUDNN: cudaStreamSynchronize failed !", cudaErr);
+
+    NDArray::registerSpecialUse({gradI}, {input, gradO});
+}
+
+//////////////////////////////////////////////////////////////////////////
+void pooling3dCUDNN(const LaunchContext* context,
+                    const NDArray* input, NDArray* output,
+                    const int kD, const int kH, const int kW,
+                    const int sD, const int sH, const int sW,
+                    const int pD, const int pH, const int pW,
+                    const int dD, const int dH, const int dW,
+                    const bool isNCDHW, const cudnnPoolingMode_t mode) {
+
+    auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
+    cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
+    if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: can't set stream for cuDNN", err);
+
+    const int numDims = 5;
+
+    int bS, iC, iD, iH, iW, oC, oD, oH, oW;         // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
+    int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
+    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
+
+    const int pSizes[] = {pD, pH, pW};
+    const int sSizes[] = {sD, sH, sW};
+    const int kSizes[] = {kD, kH, kW};
+
+    const int xShape[] = {bS, iC, iD, iH, iW};
+    const int zShape[] = {bS, oC, oD, oH, oW};
+
+    const int xStrides[] = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)};
+    const int zStrides[] = {(int)output->strideAt(0), (int)output->strideAt(1), (int)output->strideAt(2), (int)output->strideAt(3), (int)output->strideAt(4)};
+
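+    // note: cuDNN has no dedicated 5-D layout enums, so CUDNN_TENSOR_NCHW/NHWC are reused here
+    // to mean channels-first/channels-last when building the Nd tensor descriptors; the explicit
+    // stride arrays above are used instead whenever the NDArray is not contiguous (ews() != 1)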
+    cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
+
+    // input descriptor
+    cudnnTensorDescriptor_t x;
+    cudnnCreateTensorDescriptor(&x);
+    if(input->ews() == 1)
+        err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape);
+    else
+        err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides);
+    if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err);
+
+    // output descriptor
+    cudnnTensorDescriptor_t z;
+    cudnnCreateTensorDescriptor(&z);
+    if(output->ews() == 1)
+        err = cudnnSetTensorNdDescriptorEx(z, format, cudnnDataType(output->dataType()), numDims, zShape);
+    else
+        err = cudnnSetTensorNdDescriptor(z, cudnnDataType(output->dataType()), numDims, zShape, zStrides);
+    if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for output failed", err);
+
+    // description of pooling
+    cudnnPoolingDescriptor_t pooling;
+    cudnnCreatePoolingDescriptor(&pooling);
+    err = cudnnSetPoolingNdDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, numDims - 2, kSizes, pSizes, sSizes);
+    if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnSetPoolingNdDescriptor failed", err);
+
+    // provide scaling parameters
+    const float  alpha32(1), beta32(0);
+    const double alpha64(1), beta64(0);
+    const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&alpha32) : reinterpret_cast<const void*>(&alpha64);
+    const void* beta  = output->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&beta32)  : reinterpret_cast<const void*>(&beta64);
+
+    NDArray::prepareSpecialUse({output}, {input});
+
+    // run calculation
+    err = cudnnPoolingForward(*handle, pooling, alpha, x, input->getSpecialBuffer(), beta, z, output->specialBuffer());
+    if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnPoolingForward failed", err);
+
+    auto cudaErr = cudaStreamSynchronize(*context->getCudaStream());
+    if (cudaErr != 0)
+        throw cuda_exception::build("pooling3dCUDNN: cudaStreamSynchronize failed !", cudaErr);
+
+    NDArray::registerSpecialUse({output}, {input});
+}
+
+//////////////////////////////////////////////////////////////////////////
+void pooling3dBpCUDNN(const LaunchContext* context,
+                      const NDArray* input, const NDArray* gradO,
+                      NDArray* gradI,
+                      const int kD, const int kH, const int kW,
+                      const int sD, const int sH, const int sW,
+                      const int pD, const int pH, const int pW,
+                      const int dD, const int dH, const int dW,
+                      const bool isNCDHW, const cudnnPoolingMode_t mode) {
+
+    auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
+    cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
+    if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: can't set stream for cuDNN", err);
+
+    const int numDims = 5;
+
+    int bS, iC, iD, iH, iW, oC, oD, oH, oW;         // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
+    int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
+    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
+
+    const int pSizes[] = {pD, pH, pW};
+    const int sSizes[] = {sD, sH, sW};
+    const int kSizes[] = {kD, kH, kW};
+
+    const int xShape[]  = {bS, iC, iD, iH, iW};
+    const int dzShape[] = {bS, oC, oD, oH, oW};
+
+    const int xStrides[] = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)};
+    const int dzStrides[] = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3), (int)gradO->strideAt(4)};
+
+    cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
+
+    // input and gradI descriptor
+    cudnnTensorDescriptor_t x;
+    cudnnCreateTensorDescriptor(&x);
+    if(input->ews() == 1)
+        err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape);
+    else
+        err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides);
+    if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input/gradI failed", err);
+
+    // gradO descriptor
+    cudnnTensorDescriptor_t dz;
+    cudnnCreateTensorDescriptor(&dz);
+    if(gradO->ews() == 1)
+        err = cudnnSetTensorNdDescriptorEx(dz, format, cudnnDataType(gradO->dataType()), numDims, dzShape);
+    else
+        err = cudnnSetTensorNdDescriptor(dz, cudnnDataType(gradO->dataType()), numDims, dzShape, dzStrides);
+    if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO failed", err);
+
+    // description of pooling
+    cudnnPoolingDescriptor_t pooling;
+    cudnnCreatePoolingDescriptor(&pooling);
+    err = cudnnSetPoolingNdDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, numDims - 2, kSizes, pSizes, sSizes);
+    if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: cudnnSetPoolingNdDescriptor failed", err);
+
+    // provide scaling parameters
+    const float  alpha32(1), beta32(0);
+    const double alpha64(1), beta64(0);
+    const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&alpha32) : reinterpret_cast<const void*>(&alpha64);
+    const void* beta  = gradO->sizeOfT() <= 4 ? 
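+    // note: cudnnPoolingBackward expects the forward output y alongside x and dy; for max pooling
+    // y is required to recover the argmax positions, so the forward pass is re-run into a temporary
+    // array below, while for average pooling the y argument is ignored and gradO is passed twice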
reinterpret_cast(&beta32) : reinterpret_cast(&beta64); + + // cudnn maxpool2d_bp api requires ff output as one of input arguments + if(mode == CUDNN_POOLING_MAX) { + + NDArray temp(gradO); + + NDArray::prepareSpecialUse({gradI}, {input, gradO, &temp}); + + // run ff calculation + err = cudnnPoolingForward(*handle, pooling, alpha, x, input->getSpecialBuffer(), beta, dz, temp.specialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnPoolingForward failed", err); + + // run bp calculation for gradI + err = cudnnPoolingBackward(*handle, pooling, alpha, dz, temp.getSpecialBuffer(), dz, gradO->getSpecialBuffer(), x, input->getSpecialBuffer(), beta, x, gradI->getSpecialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnPoolingBackward failed", err); + + NDArray::registerSpecialUse({gradI}, {input, gradO, &temp}); + } + else { + + NDArray::prepareSpecialUse({gradI}, {input, gradO}); + + // run bp calculation for gradI + err = cudnnPoolingBackward(*handle, pooling, alpha, dz, gradO->getSpecialBuffer(), dz, gradO->getSpecialBuffer(), x, input->getSpecialBuffer(), beta, x, gradI->getSpecialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnPoolingBackward failed", err); + + NDArray::registerSpecialUse({gradI}, {input, gradO}); + } + + auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + if (cudaErr != 0) + throw cuda_exception::build("pooling3dBpCUDNN: cudaStreamSynchronize failed !", cudaErr); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h index bdff86e24..5c46fb7b0 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h +++ b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h @@ -30,8 +30,8 @@ #include -namespace nd4j { -namespace ops { +namespace nd4j { +namespace ops { namespace platforms { DECLARE_PLATFORM(conv2d, ENGINE_CUDA); @@ -46,6 +46,18 @@ namespace platforms { DECLARE_PLATFORM(batchnorm, ENGINE_CUDA); DECLARE_PLATFORM(batchnorm_bp, ENGINE_CUDA); + DECLARE_PLATFORM(avgpool2d, ENGINE_CUDA); + DECLARE_PLATFORM(avgpool2d_bp, ENGINE_CUDA); + + DECLARE_PLATFORM(maxpool2d, ENGINE_CUDA); + DECLARE_PLATFORM(maxpool2d_bp, ENGINE_CUDA); + + DECLARE_PLATFORM(avgpool3dnew, ENGINE_CUDA); + DECLARE_PLATFORM(avgpool3dnew_bp, ENGINE_CUDA); + + DECLARE_PLATFORM(maxpool3dnew, ENGINE_CUDA); + DECLARE_PLATFORM(maxpool3dnew_bp, ENGINE_CUDA); + ////////////////////////////////////////////////////////////////////////// FORCEINLINE cudnnDataType_t cudnnDataType(nd4j::DataType dataType) { switch (dataType) { @@ -65,91 +77,62 @@ FORCEINLINE cudnnDataType_t cudnnDataType(nd4j::DataType dataType) { } ////////////////////////////////////////////////////////////////////////// -FORCEINLINE void checkConv2dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI, - const int iH, const int iW, - const int oH, const int oW, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const bool isNCHW) { - - const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH); - const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW); - - const bool isPHasymm = pH != (pHsum - pH); - const bool isPWasymm = pW != (pWsum - pW); - - if(!isPHasymm && !isPWasymm) - return; - - std::vector newShape = input->getShapeAsVector(); - - const int iHposition = isNCHW ? 
2 : 1; - - if(isPHasymm) - newShape[iHposition] += 1; - if(isPWasymm) - newShape[iHposition + 1] += 1; - - NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext()); - - if(isNCHW) - (*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3)}).assign(input); - else - (*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,0}).assign(input); - - input = newInput; - - if(gradI != nullptr) - gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext()); -} - +void checkConv2dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI, + const int iH, const int iW, + const int oH, const int oW, + const int kH, const int kW, + const int sH, const int sW, + const int pH, const int pW, + const int dH, const int dW, + const bool isNCHW); ////////////////////////////////////////////////////////////////////////// -FORCEINLINE void checkConv3dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI, - const int iD, const int iH, const int iW, - const int oD, const int oH, const int oW, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const bool isNCDHW) { +void checkConv3dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI, + const int iD, const int iH, const int iW, + const int oD, const int oH, const int oW, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const bool isNCDHW); - const auto pDsum = ((oD - 1) * sD + ((kD - 1) * dD + 1) - iD); - const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH); - const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW); +////////////////////////////////////////////////////////////////////////// +void pooling2dCUDNN(const LaunchContext* context, + const NDArray* input, NDArray* output, + const int kH, const int kW, + const int sH, const int sW, + const int pH, const int pW, + const int dH, const int dW, + const bool isNCHW, const cudnnPoolingMode_t mode); - const bool isPDasymm = pD != (pDsum - pD); - const bool isPHasymm = pH != (pHsum - pH); - const bool isPWasymm = pW != (pWsum - pW); +////////////////////////////////////////////////////////////////////////// +void pooling2dBpCUDNN(const LaunchContext* context, + const NDArray* input, const NDArray* gradO, + NDArray* gradI, + const int kH, const int kW, + const int sH, const int sW, + const int pH, const int pW, + const int dH, const int dW, + const bool isNCHW, const cudnnPoolingMode_t mode); - if(!isPDasymm && !isPHasymm && !isPWasymm) - return; +////////////////////////////////////////////////////////////////////////// +void pooling3dCUDNN(const LaunchContext* context, + const NDArray* input, NDArray* output, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const bool isNCDHW, const cudnnPoolingMode_t mode); - std::vector newShape = input->getShapeAsVector(); - - const int iDposition = isNCDHW ? 
2 : 1;
-
-    if(isPDasymm)
-        newShape[iDposition] += 1;
-    if(isPHasymm)
-        newShape[iDposition + 1] += 1;
-    if(isPWasymm)
-        newShape[iDposition + 2] += 1;
-
-    NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext());
-
-    if(isNCDHW)
-        (*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3), 0,input->sizeAt(4)}).assign(input);
-    else
-        (*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,input->sizeAt(3), 0,0}).assign(input);
-
-    input = newInput;
-
-    if(gradI != nullptr)
-        gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext());
-}
+//////////////////////////////////////////////////////////////////////////
+void pooling3dBpCUDNN(const LaunchContext* context,
+                      const NDArray* input, const NDArray* gradO,
+                      NDArray* gradI,
+                      const int kD, const int kH, const int kW,
+                      const int sD, const int sH, const int sW,
+                      const int pD, const int pH, const int pW,
+                      const int dD, const int dH, const int dW,
+                      const bool isNCDHW, const cudnnPoolingMode_t mode);

}
}
diff --git a/libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu b/libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu
new file mode 100644
index 000000000..6d5affe79
--- /dev/null
+++ b/libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu
@@ -0,0 +1,132 @@
+/*******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Yurii Shyrma (iuriish@yahoo.com)
+//
+
+
+#include "cudnnUtils.h"
+#include <ops/declarable/helpers/convolutions.h>
+
+namespace nd4j {
+namespace ops {
+namespace platforms {
+
+
+//////////////////////////////////////////////////////////////////////////
+PLATFORM_IMPL(maxpool2d, ENGINE_CUDA) {
+
+    auto input = INPUT_VARIABLE(0);
+    auto output = OUTPUT_VARIABLE(0);
+
+    // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - paddingMode;
+    const auto kH = INT_ARG(0);
+    const auto kW = INT_ARG(1);
+    const auto sH = INT_ARG(2);
+    const auto sW = INT_ARG(3);
+          auto pH = INT_ARG(4);
+          auto pW = INT_ARG(5);
+    const auto dH = INT_ARG(6);
+    const auto dW = INT_ARG(7);
+    const auto paddingMode = static_cast<bool>(INT_ARG(8));
+    const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1;       // INT_ARG(10): 0-NCHW, 1-NHWC
+
+    REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D CUDNN op: input should have rank of 4, but got %i instead", input->rankOf());
+    REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
+
+    int oH = 0;
+    int oW = 0;
+
+    const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
+    const int iW = static_cast<int>(isNCHW ? 
input->sizeAt(3) : input->sizeAt(2)); + + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); + + if (paddingMode) + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + + pooling2dCUDNN(block.launchContext(), input, output, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, CUDNN_POOLING_MAX); + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(maxpool2d, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; + + return goodType && input->dataType() == output->dataType(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(maxpool2d_bp, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + const auto kH = INT_ARG(0); // filter(kernel) height + const auto kW = INT_ARG(1); // filter(kernel) width + const auto sH = INT_ARG(2); // strides height + const auto sW = INT_ARG(3); // strides width + auto pH = INT_ARG(4); // paddings height + auto pW = INT_ARG(5); // paddings width + const auto dH = INT_ARG(6); // dilations height + const auto dW = INT_ARG(7); // dilations width + const auto paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + const auto isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D_BP CUDNN op: input should have rank of 4, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D_BP CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); + std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP CUDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "MAXPOOL2D_BP CUDNN op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); + + if(paddingMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + + pooling2dBpCUDNN(block.launchContext(), input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, CUDNN_POOLING_MAX); + + return Status::OK(); +} + +PLATFORM_CHECK(maxpool2d_bp, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, 
iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; + + return goodType && (input->dataType() == gradO->dataType()) + && (input->dataType() == gradI->dataType()) + && shape::haveSameShapeAndStrides(input->getShapeInfo(), gradI->getShapeInfo()); +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu b/libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu new file mode 100644 index 000000000..fc2e38577 --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu @@ -0,0 +1,140 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + + +#include "cudnnUtils.h" +#include + +namespace nd4j { +namespace ops { +namespace platforms { + + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(maxpool3dnew, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) + + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + // int extraParam0 = INT_ARG(13); + int isNCDHW = block.getIArguments()->size() > 14 ? 
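+    // note: the integer argument encodes 0 -> NCDHW and 1 -> NDHWC, while the local flag asks
+    // "is it NCDHW?", hence the negation below; NCDHW is assumed when the argument is absent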
!INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); + REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); + + if(paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); + + pooling3dCUDNN(block.launchContext(), input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, CUDNN_POOLING_MAX); + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(maxpool3dnew, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; + + return goodType && input->dataType() == output->dataType(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon + + const int kD = INT_ARG(0); // filter(kernel) depth + const int kH = INT_ARG(1); // filter(kernel) height + const int kW = INT_ARG(2); // filter(kernel) width + const int sD = INT_ARG(3); // strides depth + const int sH = INT_ARG(4); // strides height + const int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + const int dD = INT_ARG(9); // dilations depth + const int dH = INT_ARG(10); // dilations height + const int dW = INT_ARG(11); // dilations width + const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + // const int extraParam0 = INT_ARG(13); // define what divisor to use while averaging + const int isNCDHW = block.getIArguments()->size() > 14 ? 
!INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW_BP CUDNN OP: input should have rank of 5, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW_BP CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); + std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP CUDNN: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "MAXPOOL3DNEW_BP CUDNN: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); + + if(isSameMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); + + pooling3dBpCUDNN(block.launchContext(), input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, CUDNN_POOLING_MAX); + + return Status::OK(); +} + +PLATFORM_CHECK(maxpool3dnew_bp, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon + + const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; + + return goodType && (input->dataType() == gradO->dataType()) + && (input->dataType() == gradI->dataType()) + && shape::haveSameShapeAndStrides(input->getShapeInfo(), gradI->getShapeInfo()); +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp index bf614bfab..1c1e9d6a4 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp @@ -30,111 +30,231 @@ using namespace dnnl; using namespace samediff; -namespace nd4j { - namespace ops { - namespace platforms { - PLATFORM_IMPL(avgpool2d, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); +namespace nd4j { +namespace ops { +namespace platforms { - REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead", - input->rankOf()); +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(avgpool2d, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - 
auto argI = *(block.getIArguments()); - auto output = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead", + input->rankOf()); - const auto kH = INT_ARG(0); - const auto kW = INT_ARG(1); - const auto sH = INT_ARG(2); - const auto sW = INT_ARG(3); - int pH = INT_ARG(4); - int pW = INT_ARG(5); - const auto dH = INT_ARG(6); - const auto dW = INT_ARG(7); - const auto isSameMode = static_cast(INT_ARG(8)); - const auto extraParam0 = INT_ARG(9); + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + auto argI = *(block.getIArguments()); + auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", - dH, dW); + const auto kH = INT_ARG(0); + const auto kW = INT_ARG(1); + const auto sH = INT_ARG(2); + const auto sW = INT_ARG(3); + int pH = INT_ARG(4); + int pW = INT_ARG(5); + const auto dH = INT_ARG(6); + const auto dW = INT_ARG(7); + const auto isSameMode = static_cast(INT_ARG(8)); + const auto extraParam0 = INT_ARG(9); - int oH = 0; - int oW = 0; + REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", + dH, dW); - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + int oH = 0; + int oW = 0; - const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); - const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); + int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - if (!isNCHW) { - input = new NDArray( - input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - output = new NDArray( - output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } + const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); + const int iW = static_cast(isNCHW ? 
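+    // note: the oneDNN (mkldnn) path below is configured for NCHW, so NHWC inputs and outputs
+    // are first wrapped in permuted ({0, 3, 1, 2}) views; the temporary NDArrays are deleted
+    // again once the primitive has executed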
input->sizeAt(3) : input->sizeAt(2)); - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - if (isSameMode) - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - const int bS = input->sizeAt(0); - const int iC = input->sizeAt(1); - const int oC = output->sizeAt(1); - - auto poolingMode = PoolingType::AVG_POOL; - - dnnl_memory_desc_t empty; - dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_dst_md(empty); - dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - dnnl::algorithm algorithm; - mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, - true, - bS, iC, iH, iW, oC, oH, oW, input, nullptr, output, - algorithm, - &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, - &user_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, - pool_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); - auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); - auto pool_src_memory = user_src_memory; - dnnl::stream stream(engine); - if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { - pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); - reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); - } - auto pool_dst_memory = user_dst_memory; - if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); - } - pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, - {DNNL_ARG_DST, pool_dst_memory}}); - if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); - } - stream.wait(); - - //streams[0].submitAndWait(); - - if (!isNCHW) { - delete input; - delete output; - } - - return Status::OK(); - } - - PLATFORM_CHECK(avgpool2d, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); - } - } + if (!isNCHW) { + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + output = new NDArray( + output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] } + + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); + + if (isSameMode) + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + + const int bS = input->sizeAt(0); + const int iC = input->sizeAt(1); + const int oC = output->sizeAt(1); + + auto poolingMode = PoolingType::AVG_POOL; + + dnnl_memory_desc_t empty; + dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_dst_md(empty); + dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; + dnnl::algorithm algorithm; + mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, + true, + bS, iC, iH, iW, oC, oH, oW, input, nullptr, output, + 
algorithm, + &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, + &user_dst_md, + pool_strides, pool_kernel, pool_padding, pool_padding_r); + auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, + pool_dst_md, + pool_strides, pool_kernel, pool_padding, pool_padding_r); + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); + auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); + auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); + auto pool_src_memory = user_src_memory; + dnnl::stream stream(engine); + if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { + pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); + reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); + } + auto pool_dst_memory = user_dst_memory; + if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { + pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); + } + pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, + {DNNL_ARG_DST, pool_dst_memory}}); + if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { + reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); + } + stream.wait(); + + //streams[0].submitAndWait(); + + if (!isNCHW) { + delete input; + delete output; + } + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(avgpool2d, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int extraParam0 = INT_ARG(9); + int isNCHW = + block.getIArguments()->size() > 10 ? 
!INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, + "AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW); + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + std::string expectedGradOShape = ShapeUtils::shapeAsString( + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1})); + std::string expectedGradIShape = ShapeUtils::shapeAsString( + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1})); + REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, + "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", + expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, + "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", + expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); + + + if (!isNCHW) { + input = new NDArray(input->permute( + {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradI = new NDArray(gradI->permute( + {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradO = new NDArray(gradO->permute( + {0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] + } + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + + auto poolingMode = PoolingType::AVG_POOL; + + dnnl_memory_desc_t empty; + dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); + dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; + dnnl::algorithm algorithm; + mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, + true, + bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm, + &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, + &user_diff_src_md, &user_dst_md, + pool_strides, pool_kernel, pool_padding, pool_padding_r); + auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, + input->buffer() != nullptr ? 
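+    // note: oneDNN requires a forward primitive descriptor as a hint when building the backward
+    // one, so a pooling_forward::desc is assembled here even though only the backward pass runs;
+    // the src descriptor falls back to the diff-src one when no forward input buffer is attached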
pool_src_md : pool_diff_src_md, + pool_dst_md, pool_strides, pool_kernel, pool_padding, + pool_padding_r); + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); + auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, + pool_kernel, pool_padding, pool_padding_r); + auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); + auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer()); + auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer()); + auto poolB_src_memory = userB_src_memory; + dnnl::stream stream(engine); + if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { + poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine); + } + auto poolB_dst_memory = userB_dst_memory; + if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { + poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine); + reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); + } + pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory}, + {DNNL_ARG_DIFF_SRC, poolB_src_memory}}); + if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { + reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory); + } + stream.wait(); + + if (!isNCHW) { + delete input; + delete gradI; + delete gradO; + } + + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(avgpool2d_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); +} + + +} +} } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d_bp.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d_bp.cpp deleted file mode 100644 index af1fd04fd..000000000 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d_bp.cpp +++ /dev/null @@ -1,149 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author saudet -// @author raver119@gmail.com -// - -#include -#include -#include - -#include -#include "mkldnnUtils.h" -#include - -using namespace dnnl; - -namespace nd4j { - namespace ops { - namespace platforms { - PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE( - 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE( - 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE( - 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int extraParam0 = INT_ARG(9); - int isNCHW = - block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, - "AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, - "AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, - indIiH, indWiC, indWoC, indWkH, indOoH); - - std::string expectedGradOShape = ShapeUtils::shapeAsString( - ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1})); - std::string expectedGradIShape = ShapeUtils::shapeAsString( - ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1})); - REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, - "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", - expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, - "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", - expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - - if (!isNCHW) { - input = new NDArray(input->permute( - {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradI = new NDArray(gradI->permute( - {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradO = new NDArray(gradO->permute( - {0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } - - if (isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - auto poolingMode = PoolingType::AVG_POOL; - - dnnl_memory_desc_t empty; - dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); - dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - dnnl::algorithm algorithm; - mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, - true, - 
bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm, - &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, - &user_diff_src_md, &user_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, - input->buffer() != nullptr ? pool_src_md : pool_diff_src_md, - pool_dst_md, pool_strides, pool_kernel, pool_padding, - pool_padding_r); - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, - pool_kernel, pool_padding, pool_padding_r); - auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); - auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer()); - auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer()); - auto poolB_src_memory = userB_src_memory; - dnnl::stream stream(engine); - if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine); - } - auto poolB_dst_memory = userB_dst_memory; - if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { - poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine); - reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); - } - pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory}, - {DNNL_ARG_DIFF_SRC, poolB_src_memory}}); - if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory); - } - stream.wait(); - - if (!isNCHW) { - delete input; - delete gradI; - delete gradO; - } - - - return Status::OK(); - } - - PLATFORM_CHECK(avgpool2d_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); - } - } - } -} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp index ba1711032..559edf2cd 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp @@ -34,24 +34,23 @@ namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////// -static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights, +static void conv2dMKLDNN(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights, const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, - indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, 
sW, dH, dW, paddingMode); dnnl_memory_desc_t empty; - dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md( - empty); - dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md( - empty); + dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(empty); + dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; + mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr, bias, output, @@ -61,13 +60,12 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con &user_bias_md, &user_dst_md, conv_strides, conv_padding, conv_padding_r, conv_dilation); - auto conv_desc = bias != nullptr - ? convolution_forward::desc(prop_kind::forward, + auto conv_desc = bias != nullptr ? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) - : convolution_forward::desc(prop_kind::forward, + : convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, @@ -112,6 +110,135 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con stream.wait(); } +////////////////////////////////////////////////////////////////////// +static void conv2dBpMKLDNN(nd4j::graph::Context &block, + const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO, + NDArray *gradI, NDArray *gradW, NDArray *gradB, + const int kH, const int kW, const int sH,const int sW, int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW) { + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); + + dnnl_memory_desc_t empty; + dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty); + + dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; + + mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, + bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW, + gradB, gradO, + &conv_src_md, &conv_diff_src_md, &conv_weights_md, + &conv_diff_weights_md, &conv_bias_md, &conv_dst_md, + &user_src_md, &user_diff_src_md, &user_weights_md, + &user_diff_weights_md, &user_bias_md, &user_dst_md, + conv_strides, conv_padding, conv_padding_r, conv_dilation); + auto conv_desc = gradB != nullptr + ? 
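+    // note: as with pooling, oneDNN's convolution_backward_weights/_data primitive descriptors
+    // need a forward primitive descriptor as a hint, which is why a forward desc (with or without
+    // bias) is built first even though only the backward passes are executed below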
convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) + : convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); + + auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine( LaunchContext::defaultContext()->engine())); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + dnnl::stream stream(engine); + + if (gradW != nullptr) { + auto convW_desc = gradB != nullptr ? convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) + : convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); + + + auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, conv_prim_desc); + + auto userW_src_memory = dnnl::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer()); + auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer()); + auto userW_dst_memory = dnnl::memory(user_dst_md, engine, const_cast<NDArray *>(gradO)->buffer()); + + auto convW_src_memory = userW_src_memory; + + if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) { + convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine); + reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory, convW_src_memory); + } + + auto convW_weights_memory = userW_weights_memory; + if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { + convW_weights_memory = dnnl::memory(convW_prim_desc.diff_weights_desc(), engine); + } + + auto convW_dst_memory = userW_dst_memory; + if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) { + convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine); + reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, convW_dst_memory); + } + + if (gradB != nullptr) { + auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, gradB->buffer()); + + convolution_backward_weights(convW_prim_desc).execute(stream, + {{DNNL_ARG_SRC, convW_src_memory}, + {DNNL_ARG_DIFF_DST, convW_dst_memory}, + {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}, + {DNNL_ARG_DIFF_BIAS, convW_bias_memory}}); + } + else { + convolution_backward_weights(convW_prim_desc).execute(stream, + {{DNNL_ARG_SRC, convW_src_memory}, + {DNNL_ARG_DIFF_DST, convW_dst_memory}, + {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}}); + } + + if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { + reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory, + userW_weights_memory); + } + + stream.wait(); + } + + if (gradI != nullptr) { + + auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); + + + auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc); + auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer()); + auto userI_weights_memory = dnnl::memory(user_weights_md,
engine, const_cast<NDArray *>(weights)->buffer()); + auto userI_dst_memory = dnnl::memory(user_dst_md, engine, const_cast<NDArray *>(gradO)->buffer()); + + auto convI_src_memory = userI_src_memory; + if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { + convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine); + } + + auto convI_weights_memory = userI_weights_memory; + if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) { + convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine); + reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, convI_weights_memory); + } + + auto convI_dst_memory = userI_dst_memory; + if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) { + convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine); + reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, convI_dst_memory); + } + + convolution_backward_data(convI_prim_desc).execute(stream, + {{DNNL_ARG_DIFF_DST, convI_dst_memory}, + {DNNL_ARG_WEIGHTS, convI_weights_memory}, + {DNNL_ARG_DIFF_SRC, convI_src_memory}}); + + if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { + reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, userI_src_memory); + } + + stream.wait(); + } +} + ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv2d, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) @@ -132,7 +259,7 @@ PLATFORM_IMPL(conv2d, ENGINE_CPU) { int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width - conv2d_mkldnn(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); + conv2dMKLDNN(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); return Status::OK(); } @@ -152,6 +279,7 @@ PLATFORM_CHECK(conv2d, ENGINE_CPU) { ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] int kH = INT_ARG(0); // filter(kernel) height int kW = INT_ARG(1); // filter(kernel) width int sH = INT_ARG(2); // strides height int sW = INT_ARG(3); // strides width int pH = INT_ARG(4); // paddings height int pW = INT_ARG(5); // paddings width int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width @@ -172,158 +300,11 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) { int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ?
!INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - REQUIRE_TRUE(input->rankOf() == 4, 0, - "CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", - input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, - "CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !", - weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 4, 0, - "CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !", - gradO->rankOf()); + REQUIRE_TRUE(input->rankOf() == 4, 0,"CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !",input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0,"CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !",weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0,"CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !",gradO->rankOf()); - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, - indIiH, indWiC, indWoC, indWkH, indOoH); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - dnnl_memory_desc_t empty; - dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), - conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), - user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty); - dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; - mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, - bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW, - gradB, gradO, - &conv_src_md, &conv_diff_src_md, &conv_weights_md, - &conv_diff_weights_md, &conv_bias_md, &conv_dst_md, - &user_src_md, &user_diff_src_md, &user_weights_md, - &user_diff_weights_md, &user_bias_md, &user_dst_md, - conv_strides, conv_padding, conv_padding_r, conv_dilation); - auto conv_desc = gradB != nullptr - ? convolution_forward::desc(prop_kind::forward, - algorithm::convolution_auto, conv_src_md, - conv_weights_md, conv_bias_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, - conv_padding_r) - : convolution_forward::desc(prop_kind::forward, - algorithm::convolution_auto, conv_src_md, - conv_weights_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, - conv_padding_r); - auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine( - LaunchContext::defaultContext()->engine())); - if (gradW != nullptr) { - auto convW_desc = gradB != nullptr - ? 
convolution_backward_weights::desc( - algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) - : convolution_backward_weights::desc( - algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, - conv_prim_desc); - auto userW_src_memory = dnnl::memory(user_src_md, engine, - const_cast(input)->buffer()); - auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer()); - auto userW_dst_memory = dnnl::memory(user_dst_md, engine, - const_cast(gradO)->buffer()); - - auto convW_src_memory = userW_src_memory; - if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) { - convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine); - reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory, - convW_src_memory); - } - - auto convW_weights_memory = userW_weights_memory; - if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { - convW_weights_memory = dnnl::memory(convW_prim_desc.diff_weights_desc(), engine); - } - - auto convW_dst_memory = userW_dst_memory; - if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) { - convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine); - reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, - convW_dst_memory); - } - - if (gradB != nullptr) { - auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, - gradB->buffer()); - convolution_backward_weights(convW_prim_desc).execute(stream, - {{DNNL_ARG_SRC, convW_src_memory}, - {DNNL_ARG_DIFF_DST, convW_dst_memory}, - {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}, - {DNNL_ARG_DIFF_BIAS, convW_bias_memory}}); - } else { - convolution_backward_weights(convW_prim_desc).execute(stream, - {{DNNL_ARG_SRC, convW_src_memory}, - {DNNL_ARG_DIFF_DST, convW_dst_memory}, - {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}}); - } - - if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { - reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory, - userW_weights_memory); - } - - stream.wait(); - } - - if (gradI != nullptr) { - auto convI_desc = - convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md, - conv_weights_md, conv_dst_md, conv_strides, conv_dilation, - conv_padding, conv_padding_r); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, - conv_prim_desc); - auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer()); - auto userI_weights_memory = dnnl::memory(user_weights_md, engine, - const_cast(weights)->buffer()); - auto userI_dst_memory = dnnl::memory(user_dst_md, engine, - const_cast(gradO)->buffer()); - - auto convI_src_memory = userI_src_memory; - if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { - convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine); - } - - auto convI_weights_memory = userI_weights_memory; - if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) { - convI_weights_memory = 
dnnl::memory(convI_prim_desc.weights_desc(), engine); - reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, - convI_weights_memory); - } - - auto convI_dst_memory = userI_dst_memory; - if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) { - convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine); - reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, - convI_dst_memory); - } - - convolution_backward_data(convI_prim_desc).execute(stream, - {{DNNL_ARG_DIFF_DST, convI_dst_memory}, - {DNNL_ARG_WEIGHTS, convI_weights_memory}, - {DNNL_ARG_DIFF_SRC, convI_src_memory}}); - - if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { - reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, - userI_src_memory); - } - - stream.wait(); - }; + conv2dBpMKLDNN(block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp index 0a79df793..747d84c36 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp @@ -34,62 +34,23 @@ namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////// -PLATFORM_IMPL(conv3dnew, ENGINE_CPU) { - auto input = INPUT_VARIABLE( - 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto output = OUTPUT_VARIABLE( - 0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) - - REQUIRE_TRUE(input->rankOf() == 5, 0, - "CUSTOM CONV3D OP: rank of input array must be equal to 5, but got %i instead !", - input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, - "CUSTOM CONV3D OP: rank of weights array must be equal to 5, but got %i instead !", - weights->rankOf()); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = - block.getIArguments()->size() > 13 ? 
!INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW +static void conv3dMKLDNN(nd4j::graph::Context &block, + const NDArray *input, const NDArray *weights, const NDArray *bias, + NDArray *output, + const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, int pD, int pH, int pW, const int dD, const int dH, const int dW, + const int paddingMode, const int isNCDHW) { int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, - indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC}); - REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, - "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", - expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, - "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", - oC, bias->rankOf(), bias->lengthOf()); - - if (isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); dnnl_memory_desc_t empty; - dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md( - empty); - dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md( - empty); + dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md( empty); + dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md( empty); + dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; - mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode, + + mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW, bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, weights, nullptr, bias, output, @@ -98,151 +59,73 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CPU) { &user_src_md, nullptr, &user_weights_md, nullptr, &user_bias_md, &user_dst_md, conv_strides, conv_padding, conv_padding_r, conv_dilation); - auto conv_desc = bias != nullptr - ? convolution_forward::desc(prop_kind::forward, - algorithm::convolution_auto, conv_src_md, - conv_weights_md, conv_bias_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, - conv_padding_r) - : convolution_forward::desc(prop_kind::forward, - algorithm::convolution_auto, conv_src_md, - conv_weights_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, - conv_padding_r); + auto conv_desc = bias != nullptr ? 
convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) + : convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); dnnl::stream stream(engine); + auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine); auto user_src_memory = dnnl::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer()); - auto user_weights_memory = dnnl::memory(user_weights_md, engine, - const_cast<NDArray *>(weights)->buffer()); + auto user_weights_memory = dnnl::memory(user_weights_md, engine, const_cast<NDArray *>(weights)->buffer()); auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); + auto conv_src_memory = user_src_memory; if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) { conv_src_memory = dnnl::memory(conv_prim_desc.src_desc(), engine); reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory); } + auto conv_weights_memory = user_weights_memory; if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) { conv_weights_memory = dnnl::memory(conv_prim_desc.weights_desc(), engine); - reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory, - conv_weights_memory); + reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory, conv_weights_memory); } + auto conv_dst_memory = user_dst_memory; if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) { conv_dst_memory = dnnl::memory(conv_prim_desc.dst_desc(), engine); } + if (bias != nullptr) { - auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine, bias->buffer()); + auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine, bias->getBuffer()); convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory}, {DNNL_ARG_WEIGHTS, conv_weights_memory}, {DNNL_ARG_BIAS, conv_bias_memory}, {DNNL_ARG_DST, conv_dst_memory}}); - } else { + } + else { convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory}, {DNNL_ARG_WEIGHTS, conv_weights_memory}, {DNNL_ARG_DST, conv_dst_memory}}); } - if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) { + + if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory); - } + stream.wait(); - - return Status::OK(); -} - -PLATFORM_CHECK(conv3dnew, ENGINE_CPU) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - if (::optimalLevel() < 2) - return false; - - auto input = INPUT_VARIABLE( - 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always - auto bias = block.width() > 2 ?
INPUT_VARIABLE(2) : nullptr; // [oC] - auto output = OUTPUT_VARIABLE( - 0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) - - return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, weights, bias, output}); } ////////////////////////////////////////////////////////////////////// -PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE( - 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE( - 1); // [kD, kH, kW, iC, oC] always - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE( - 2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE( - 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_VARIABLE( - 1); // [kD, kH, kW, iC, oC] always - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - REQUIRE_TRUE(input->rankOf() == 5, 0, - "CUSTOM CONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !", - input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, - "CUSTOM CONV3D_BP OP: rank of weights array must be equal to 5, but got %i instead !", - weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 5, 0, - "CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !", - gradO->rankOf()); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - int isNDHWC = - block.getIArguments()->size() > 13 ? 
!INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW +static void conv3dBpMKLDNN(nd4j::graph::Context &block, + const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO, + NDArray *gradI, NDArray *gradW, NDArray *gradB, + const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, int pD, int pH, int pW, const int dD, const int dH, const int dW, + const int paddingMode, const int isNCDHW) { int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNDHWC, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, - indIOioC, indIOioD, indWiC, indWoC, indWkD); - - if(isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - int trueoD, trueoH, trueoW; // true output depth/height/width - ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, - dW, iD, iH, iW, isSameMode); - - std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( - {bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); - std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC}); - REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, - "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", - expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, - "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", - expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, - "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", - oC, bias->rankOf(), bias->lengthOf()); - + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); dnnl_memory_desc_t empty; - dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), - conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), - user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty); + dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty); + dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; - mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode, - isNDHWC, + + mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, + isNCDHW, bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, weights, gradW, gradB, gradO, &conv_src_md, &conv_diff_src_md, &conv_weights_md, @@ -250,43 +133,30 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { &user_src_md, &user_diff_src_md, &user_weights_md, 
&user_diff_weights_md, &user_bias_md, &user_dst_md, conv_strides, conv_padding, conv_padding_r, conv_dilation); - auto conv_desc = gradB != nullptr - ? convolution_forward::desc(prop_kind::forward, - algorithm::convolution_auto, conv_src_md, - conv_weights_md, conv_bias_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, - conv_padding_r) - : convolution_forward::desc(prop_kind::forward, - algorithm::convolution_auto, conv_src_md, - conv_weights_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, - conv_padding_r); - auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine( - LaunchContext::defaultContext()->engine())); - if (gradW != nullptr) { - auto convW_desc = gradB != nullptr - ? convolution_backward_weights::desc( - algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) - : convolution_backward_weights::desc( - algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, - conv_prim_desc); - auto userW_src_memory = dnnl::memory(user_src_md, engine, - const_cast(input)->buffer()); + auto conv_desc = gradB != nullptr ? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) + : convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); + + auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine())); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + dnnl::stream stream(engine); + + if (gradW != nullptr) { + + auto convW_desc = gradB != nullptr ? 
convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) + : convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, conv_prim_desc); + + auto userW_src_memory = dnnl::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer()); auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer()); - auto userW_dst_memory = dnnl::memory(user_dst_md, engine, - const_cast<NDArray *>(gradO)->buffer()); + auto userW_dst_memory = dnnl::memory(user_dst_md, engine, const_cast<NDArray *>(gradO)->buffer()); auto convW_src_memory = userW_src_memory; if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) { convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine); - reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory, - convW_src_memory); + reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory, convW_src_memory); } auto convW_weights_memory = userW_weights_memory; @@ -297,65 +167,53 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { auto convW_dst_memory = userW_dst_memory; if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) { convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine); - reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, - convW_dst_memory); + reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, convW_dst_memory); } if (gradB != nullptr) { - auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, - gradB->buffer()); + + auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, gradB->buffer()); + convolution_backward_weights(convW_prim_desc).execute(stream, {{DNNL_ARG_SRC, convW_src_memory}, {DNNL_ARG_DIFF_DST, convW_dst_memory}, {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}, {DNNL_ARG_DIFF_BIAS, convW_bias_memory}}); - } else { + } + else { convolution_backward_weights(convW_prim_desc).execute(stream, {{DNNL_ARG_SRC, convW_src_memory}, {DNNL_ARG_DIFF_DST, convW_dst_memory}, {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}}); } - if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { - reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory, - userW_weights_memory); - } + if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) + reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory, userW_weights_memory); stream.wait(); } if (gradI != nullptr) { - auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto, - conv_diff_src_md, conv_weights_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, - conv_padding_r); + auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, - conv_prim_desc); + auto convI_prim_desc = 
convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc); auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer()); - auto userI_weights_memory = dnnl::memory(user_weights_md, engine, - const_cast<NDArray *>(weights)->buffer()); - auto userI_dst_memory = dnnl::memory(user_dst_md, engine, - const_cast<NDArray *>(gradO)->buffer()); + auto userI_weights_memory = dnnl::memory(user_weights_md, engine, const_cast<NDArray *>(weights)->buffer()); + auto userI_dst_memory = dnnl::memory(user_dst_md, engine, const_cast<NDArray *>(gradO)->buffer()); auto convI_src_memory = userI_src_memory; - if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { + if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine); - } auto convI_weights_memory = userI_weights_memory; if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) { convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine); - reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, - convI_weights_memory); + reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, convI_weights_memory); } auto convI_dst_memory = userI_dst_memory; if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) { convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine); - reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, - convI_dst_memory); + reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, convI_dst_memory); } convolution_backward_data(convI_prim_desc).execute(stream, @@ -363,30 +221,128 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { {DNNL_ARG_WEIGHTS, convI_weights_memory}, {DNNL_ARG_DIFF_SRC, convI_src_memory}}); - if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { - reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, - userI_src_memory); - } - - stream.wait(); + if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) + reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, userI_src_memory); } +} + +////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(conv3dnew, ENGINE_CPU) { + + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) + + REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); + + int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth + int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height + int kW = INT_ARG(2) > 0 ?
INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + + std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC}); + REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); + + conv3dMKLDNN(block, input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW); + + return Status::OK(); +} + +PLATFORM_CHECK(conv3dnew, ENGINE_CPU) { + // we don't want to use mkldnn if cpu doesn't support avx/avx2 + if (::optimalLevel() < 2) + return false; + + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) + + return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, weights, bias, output}); +} + +////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto gradB = block.width() > 3 ?
OUTPUT_VARIABLE(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !", gradO->rankOf()); + + int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth + int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height + int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + + if(paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); + + int trueoD, trueoH, trueoW; // true output depth/height/width + ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); + + std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( {bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); + std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC}); + REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); + + conv3dBpMKLDNN(block, input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW); return Status::OK(); } PLATFORM_CHECK(conv3dnew_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE( - 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE( - 1); // [kD, kH, kW, iC, oC] always - auto bias = block.width() > 3 ?
INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE( - 2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE( - 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_VARIABLE( - 1); // [kD, kH, kW, iC, oC] always + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] return block.isUseMKLDNN() && diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp index 6db569eec..d95052c5a 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp @@ -177,7 +177,7 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N } ////////////////////////////////////////////////////////////////////////// -static void deconv2dBackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, +static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int paddingMode) { @@ -492,7 +492,7 @@ PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) { gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] } - deconv2dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode); + deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode); delete weights; delete gradW; diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp index f3b745d09..fc7a1e9e3 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp @@ -421,7 +421,7 @@ PLATFORM_CHECK(depthwise_conv2d, ENGINE_CPU) { return block.isUseMKLDNN() && mC == 1 && ( (xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) || - (xType==DataType::HALF && wType==DataType::HALF && bType==DataType::HALF && zType==DataType::HALF) || + (xType==DataType::BFLOAT16 && wType==DataType::BFLOAT16 && bType==DataType::BFLOAT16 && zType==DataType::BFLOAT16) || ((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType) ); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp index 975cf7fe1..69aee8fad 100644 --- 
a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp @@ -29,117 +29,258 @@ using namespace dnnl; -namespace nd4j { - namespace ops { - namespace platforms { - PLATFORM_IMPL(maxpool2d, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); +namespace nd4j { +namespace ops { +namespace platforms { - REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead", - input->rankOf()); +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(maxpool2d, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - auto argI = *(block.getIArguments()); - auto output = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead", + input->rankOf()); - const auto kH = INT_ARG(0); - const auto kW = INT_ARG(1); - const auto sH = INT_ARG(2); - const auto sW = INT_ARG(3); - int pH = INT_ARG(4); - int pW = INT_ARG(5); - const auto dH = INT_ARG(6); - const auto dW = INT_ARG(7); - const auto isSameMode = static_cast<bool>(INT_ARG(8)); + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + auto argI = *(block.getIArguments()); + auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", - dH, dW); + const auto kH = INT_ARG(0); + const auto kW = INT_ARG(1); + const auto sH = INT_ARG(2); + const auto sW = INT_ARG(3); + int pH = INT_ARG(4); + int pW = INT_ARG(5); + const auto dH = INT_ARG(6); + const auto dW = INT_ARG(7); + const auto isSameMode = static_cast<bool>(INT_ARG(8)); - int oH = 0; - int oW = 0; + REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D op: dilation must not be zero, but got instead {%i, %i}", + dH, dW); - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + int oH = 0; + int oW = 0; - const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); - const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); + int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - if (!isNCHW) { - input = new NDArray( - input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - output = new NDArray( - output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } + const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); + const int iW = static_cast<int>(isNCHW ?
input->sizeAt(3) : input->sizeAt(2)); - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - if (isSameMode) - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - const int bS = input->sizeAt(0); - const int iC = input->sizeAt(1); - const int oC = output->sizeAt(1); - - auto poolingMode = PoolingType::MAX_POOL; - int extraParam0 = 1; - - dnnl_memory_desc_t empty; - dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_dst_md(empty); - dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - dnnl::algorithm algorithm; - - mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, - true, - bS, iC, iH, iW, oC, oH, oW, input, nullptr, output, - algorithm, - &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, - &user_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - - auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, - pool_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); - auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); - - auto pool_src_memory = user_src_memory; - dnnl::stream stream(engine); - if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { - pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); - reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); - } - - auto pool_dst_memory = user_dst_memory; - if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); - } - - pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, - {DNNL_ARG_DST, pool_dst_memory}}); - - if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); - } - - stream.wait(); - - if (!isNCHW) { - delete input; - delete output; - } - - return Status::OK(); - } - - PLATFORM_CHECK(maxpool2d, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); - } - } + if (!isNCHW) { + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + output = new NDArray( + output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] } + + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); + + if (isSameMode) + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + + const int bS = input->sizeAt(0); + const int iC = input->sizeAt(1); + const int oC = output->sizeAt(1); + + auto poolingMode = PoolingType::MAX_POOL; + int extraParam0 = 1; + + dnnl_memory_desc_t empty; + dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_dst_md(empty); + dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; + dnnl::algorithm algorithm; + + mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, + true, + bS, iC, iH, iW, oC, oH, 
oW, input, nullptr, output, + algorithm, + &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, + &user_dst_md, + pool_strides, pool_kernel, pool_padding, pool_padding_r); + + auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, + pool_dst_md, + pool_strides, pool_kernel, pool_padding, pool_padding_r); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); + auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); + auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); + + auto pool_src_memory = user_src_memory; + dnnl::stream stream(engine); + if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { + pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); + reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); + } + + auto pool_dst_memory = user_dst_memory; + if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { + pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); + } + + pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, + {DNNL_ARG_DST, pool_dst_memory}}); + + if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { + reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); + } + + stream.wait(); + + if (!isNCHW) { + delete input; + delete output; + } + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(maxpool2d, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(maxpool2d_bp, ENGINE_CPU) { + + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int extraParam0 = INT_ARG(9); + int isNCHW = + block.getIArguments()->size() > 10 ? 
!INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "MAXPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, + "MAXPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW); + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + std::string expectedGradOShape = ShapeUtils::shapeAsString( + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1})); + std::string expectedGradIShape = ShapeUtils::shapeAsString( + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1})); + REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, + "MAXPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", + expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, + "MAXPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", + expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); + + + if (!isNCHW) { + input = new NDArray(input->permute( + {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradI = new NDArray(gradI->permute( + {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradO = new NDArray(gradO->permute( + {0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] + } + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + + auto poolingMode = PoolingType::MAX_POOL; + + dnnl_memory_desc_t empty; + dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); + dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; + dnnl::algorithm algorithm; + + mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, + true, + bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm, + &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, + &user_diff_src_md, &user_dst_md, + pool_strides, pool_kernel, pool_padding, pool_padding_r); + + // input is sometimes null, so we can't rely on pool_src_md being valid + auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, + input->buffer() != nullptr ?
pool_src_md : pool_diff_src_md, + pool_dst_md, pool_strides, pool_kernel, pool_padding, + pool_padding_r); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + dnnl::stream stream(engine); + auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); + + auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, + pool_strides, pool_kernel, pool_padding, pool_padding_r); + + auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); + auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer()); + auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer()); + + auto poolB_src_memory = userB_src_memory; + if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { + poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine); + } + + auto poolB_dst_memory = userB_dst_memory; + if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { + poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine); + reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); + } + + auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); + auto pool_src_memory = user_src_memory; + if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { + pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); + reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); + } + + auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); + auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine); + + pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, + {DNNL_ARG_DST, pool_dst_memory}, + {DNNL_ARG_WORKSPACE, pool_workspace_memory}}); + // TODO: verify this pairing; the backward primitive consumes the workspace (max indices) filled by the forward pass just above + pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory}, + {DNNL_ARG_WORKSPACE, pool_workspace_memory}, + {DNNL_ARG_DIFF_SRC, poolB_src_memory}}); + + if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { + reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory); + } + + stream.wait(); + + if (!isNCHW) { + delete input; + delete gradI; + delete gradO; + } + + return Status::OK(); +} + +PLATFORM_CHECK(maxpool2d_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); +} + +} +} } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d_bp.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d_bp.cpp deleted file mode 100644 index 686bdc7fb..000000000 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d_bp.cpp +++ /dev/null @@ -1,174 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author saudet -// @author raver119@gmail.com -// - -#include -#include -#include - -#include -#include "mkldnnUtils.h" -#include - -using namespace dnnl; - -namespace nd4j { - namespace ops { - namespace platforms { - PLATFORM_IMPL(maxpool2d_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE( - 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE( - 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE( - 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int extraParam0 = INT_ARG(9); - int isNCHW = - block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, - "AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, - "AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, - indIiH, indWiC, indWoC, indWkH, indOoH); - - std::string expectedGradOShape = ShapeUtils::shapeAsString( - ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1})); - std::string expectedGradIShape = ShapeUtils::shapeAsString( - ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1})); - REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, - "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", - expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, - "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", - expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - - if (!isNCHW) { - input = new NDArray(input->permute( - {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradI = new NDArray(gradI->permute( - {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradO = new NDArray(gradO->permute( - {0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } - - if (isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - auto poolingMode = PoolingType::MAX_POOL; - - dnnl_memory_desc_t empty; - dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); - dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - dnnl::algorithm algorithm; - 
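[Editor's note, not part of the patch] The new flattened kernels above and this old nested implementation being deleted share one idiom worth calling out: each kernel builds the memory descriptor the primitive prefers, compares it against the user's plain NCHW/NHWC descriptor, and inserts a reorder only on mismatch. Below is a minimal sketch of that idiom, assuming the oneDNN 1.x API used in this patch; the helper name prepare_src and its parameters are illustrative only, not part of this codebase:

    // Return memory in the primitive's preferred layout, reordering only if
    // the user's plain layout differs from what the primitive wants.
    static dnnl::memory prepare_src(const dnnl::pooling_forward::primitive_desc &pd,
                                    dnnl::memory user_src,
                                    dnnl::engine &engine, dnnl::stream &stream) {
        if (pd.src_desc() == user_src.get_desc())
            return user_src;                              // layouts match, zero copy
        dnnl::memory prim_src(pd.src_desc(), engine);     // e.g. blocked nChw8c
        dnnl::reorder(user_src, prim_src).execute(stream, user_src, prim_src);
        return prim_src;
    }

The same comparison is repeated on the way out: if the primitive wrote into its own layout, a second reorder copies the result back into the user's buffer before stream.wait().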
- mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, - true, - bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm, - &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, - &user_diff_src_md, &user_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - - // input is sometimes null, so we can't rely on pool_src_md being valid - auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, - input->buffer() != nullptr ? pool_src_md : pool_diff_src_md, - pool_dst_md, pool_strides, pool_kernel, pool_padding, - pool_padding_r); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - - auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - - auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); - auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer()); - auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer()); - - auto poolB_src_memory = userB_src_memory; - if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine); - } - - auto poolB_dst_memory = userB_dst_memory; - if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { - poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine); - reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); - } - - auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); - auto pool_src_memory = user_src_memory; - if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { - pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); - reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); - } - - auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); - auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine); - - pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, - {DNNL_ARG_DST, pool_dst_memory}, - {DNNL_ARG_WORKSPACE, pool_workspace_memory}}); - // probably wrong, fix that - pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory}, - {DNNL_ARG_WORKSPACE, pool_workspace_memory}, - {DNNL_ARG_DIFF_SRC, poolB_src_memory}}); - - if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory); - } - - stream.wait(); - - if (!isNCHW) { - delete input; - delete gradI; - delete gradO; - } - - return Status::OK(); - } - - PLATFORM_CHECK(maxpool2d_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); - } - } - } -} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp index 604bdcb6b..a37422c55 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp @@ -28,124 +28,273 @@ using namespace dnnl; -namespace nd4j { - namespace ops { - namespace 
platforms { - PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) { - auto input = INPUT_VARIABLE( - 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto output = OUTPUT_VARIABLE( - 0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) +namespace nd4j { +namespace ops { +namespace platforms { - int kD = INT_ARG(0); // filter(kernel) depth - int kH = INT_ARG(1); // filter(kernel) height - int kW = INT_ARG(2); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) { + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) - REQUIRE_TRUE(input->rankOf() == 5, 0, - "MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", - input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, - "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases + int isNCDHW = block.getIArguments()->size() > 14 ? 
!INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, - indIOioC, indIOioD, indWiC, indWoC, indWkD); + REQUIRE_TRUE(input->rankOf() == 5, 0, + "MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( - {bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); - REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, - "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", - expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str()); - // REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW); - // REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW); + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); - if (!isNCDHW) { - input = new NDArray( - input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - output = new NDArray( - output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] - } + std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); + REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, + "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", + expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str()); + // REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW); + // REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW); - if (isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, - dW); - - - auto poolingMode = PoolingType::MAX_POOL; - auto extraParam0 = 1; - - dnnl_memory_desc_t empty; - dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_dst_md(empty); - dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - dnnl::algorithm 
algorithm; - - mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, - extraParam0, true, - bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output, - algorithm, - &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, - &user_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - - auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, - pool_dst_md, pool_strides, pool_kernel, pool_padding, - pool_padding_r); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); - auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); - - auto pool_src_memory = user_src_memory; - if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { - pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); - reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); - } - - auto pool_dst_memory = user_dst_memory; - if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); - } - - pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, - {DNNL_ARG_DST, pool_dst_memory}}); - - if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); - } - - stream.wait(); - - - if (!isNCDHW) { - delete input; - delete output; - } - - return Status::OK(); - } - - PLATFORM_CHECK(maxpool3dnew, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); - } - } + if (!isNCDHW) { + input = new NDArray( + input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] + output = new NDArray( + output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] } + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, + dW); + + + auto poolingMode = PoolingType::MAX_POOL; + auto extraParam0 = 1; + + dnnl_memory_desc_t empty; + dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_dst_md(empty); + dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; + dnnl::algorithm algorithm; + + mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, + extraParam0, true, + bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output, + algorithm, + &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, + &user_dst_md, + pool_strides, pool_kernel, pool_padding, pool_padding_r); + + auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, + pool_dst_md, pool_strides, pool_kernel, pool_padding, + pool_padding_r); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + dnnl::stream stream(engine); + auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); + auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); + auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); + + auto pool_src_memory = user_src_memory; + if 
(pool_prim_desc.src_desc() != user_src_memory.get_desc()) { + pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); + reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); + } + + auto pool_dst_memory = user_dst_memory; + if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { + pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); + } + + pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, + {DNNL_ARG_DST, pool_dst_memory}}); + + if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { + reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); + } + + stream.wait(); + + + if (!isNCDHW) { + delete input; + delete output; + } + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(maxpool3dnew, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon + + const int kD = INT_ARG(0); // filter(kernel) depth + const int kH = INT_ARG(1); // filter(kernel) height + const int kW = INT_ARG(2); // filter(kernel) width + const int sD = INT_ARG(3); // strides depth + const int sH = INT_ARG(4); // strides height + const int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + const int dD = INT_ARG(9); // dilations depth + const int dH = INT_ARG(10); // dilations height + const int dW = INT_ARG(11); // dilations width + const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases + int isNCDHW = block.getIArguments()->size() > 14 ? 
!INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "MAXPOOL3D_BP op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); + + std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); + std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, iD, iH, iW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); + REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, + "MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", + expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, + "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", + expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); + + if (!isNCDHW) { + input = new NDArray(input->permute( + {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] + gradI = new NDArray(gradI->permute( + {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] + gradO = new NDArray(gradO->permute( + {0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW] + } + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, + dW); + + + auto poolingMode = PoolingType::MAX_POOL; + auto extraParam0 = 1; + + dnnl_memory_desc_t empty; + dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); + dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; + dnnl::algorithm algorithm; + + mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, + extraParam0, true, + bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO, + algorithm, + &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, + &user_diff_src_md, &user_dst_md, + pool_strides, pool_kernel, pool_padding, pool_padding_r); + + // input is sometimes null, so we can't rely on pool_src_md being valid + if (input->buffer() == nullptr) { + pool_src_md = pool_diff_src_md; + user_src_md = user_diff_src_md; + } + auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + dnnl::stream stream(engine); + auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); + + auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r); + + auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); +
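[Editor's note, not part of the patch] Unlike average pooling, max-pooling backward in oneDNN consumes a workspace (the recorded arg-max positions), and a workspace descriptor is only available when the forward primitive is created with prop_kind::forward (training) rather than forward_inference. That is why this backward kernel re-runs pooling_forward before pooling_backward below, exactly as the 2D variant does earlier in the patch. A condensed sketch of the required sequence, where src_md/dst_md/diff_src_md and src_mem/dst_mem/diff_dst_mem/diff_src_mem stand in for the pool_/poolB_ descriptors and memories built in this hunk:

    // Training-mode forward fills `workspace` with the max indices ...
    auto fwd_pd = pooling_forward::primitive_desc(
            pooling_forward::desc(prop_kind::forward, algorithm, src_md, dst_md,
                                  strides, kernel, padding, padding_r), engine);
    auto bwd_pd = pooling_backward::primitive_desc(
            pooling_backward::desc(algorithm, diff_src_md, dst_md,
                                   strides, kernel, padding, padding_r),
            engine, fwd_pd);                    // forward pd serves as the hint
    auto workspace = dnnl::memory(fwd_pd.workspace_desc(), engine);
    pooling_forward(fwd_pd).execute(stream, {{DNNL_ARG_SRC, src_mem},
                                             {DNNL_ARG_DST, dst_mem},
                                             {DNNL_ARG_WORKSPACE, workspace}});
    // ... which backward then uses to route gradients to the max locations.
    pooling_backward(bwd_pd).execute(stream, {{DNNL_ARG_DIFF_DST, diff_dst_mem},
                                              {DNNL_ARG_WORKSPACE, workspace},
                                              {DNNL_ARG_DIFF_SRC, diff_src_mem}});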
auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer()); + auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer()); + + auto poolB_src_memory = userB_src_memory; + if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { + poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine); + } + + auto poolB_dst_memory = userB_dst_memory; + if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { + poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine); + reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); + } + + + auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); + + auto pool_src_memory = user_src_memory; + if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { + pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); + reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); + } + + auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); + auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine); + + pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, + {DNNL_ARG_DST, pool_dst_memory}, + {DNNL_ARG_WORKSPACE, pool_workspace_memory}}); + pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory}, + {DNNL_ARG_WORKSPACE, pool_workspace_memory}, + {DNNL_ARG_DIFF_SRC, poolB_src_memory}}); + + + if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { + reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory); + } + + stream.wait(); + + if (!isNCDHW) { + delete input; + delete gradI; + delete gradO; + } + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(maxpool3dnew_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); +} + +} +} } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling_3d_bp.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling_3d_bp.cpp deleted file mode 100644 index b684df1bb..000000000 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling_3d_bp.cpp +++ /dev/null @@ -1,181 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include -#include -#include - -#include -#include "mkldnnUtils.h" -#include - -using namespace dnnl; - -namespace nd4j { - namespace ops { - namespace platforms { - PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE( - 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto gradO = INPUT_VARIABLE( - 1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE( - 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - - const int kD = INT_ARG(0); // filter(kernel) depth - const int kH = INT_ARG(1); // filter(kernel) height - const int kW = INT_ARG(2); // filter(kernel) width - const int sD = INT_ARG(3); // strides depth - const int sH = INT_ARG(4); // strides height - const int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - const int dD = INT_ARG(9); // dilations depth - const int dH = INT_ARG(10); // dilations height - const int dW = INT_ARG(11); // dilations width - const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW - - REQUIRE_TRUE(input->rankOf() == 5, 0, - "MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, - "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, - indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( - {bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); - std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( - {bS, iC, iD, iH, iW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); - REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, - "MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", - expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, - "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", - expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - if (!isNCDHW) { - input = new NDArray(input->permute( - {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - gradI = new NDArray(gradI->permute( - {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - gradO = new NDArray(gradO->permute( - {0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] - } - - if (isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, - dW); - - - auto 
poolingMode = PoolingType::MAX_POOL; - auto extraParam0 = 1; - - dnnl_memory_desc_t empty; - dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); - dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - dnnl::algorithm algorithm; - - mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, - extraParam0, true, - bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO, - algorithm, - &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, - &user_diff_src_md, &user_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - - // input is sometimes null, so we can't rely on pool_src_md being valid - if (input->buffer() == nullptr) { - pool_src_md = pool_diff_src_md; - user_src_md = user_diff_src_md; - } - auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - - auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r); - - auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); - auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer()); - auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer()); - - auto poolB_src_memory = userB_src_memory; - if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine); - } - - auto poolB_dst_memory = userB_dst_memory; - if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { - poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine); - reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); - } - - - auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); - - auto pool_src_memory = user_src_memory; - if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { - pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); - reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); - } - - auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); - auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine); - - pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, - {DNNL_ARG_DST, pool_dst_memory}, - {DNNL_ARG_WORKSPACE, pool_workspace_memory}}); - pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory}, - {DNNL_ARG_WORKSPACE, pool_workspace_memory}, - {DNNL_ARG_DIFF_SRC, poolB_src_memory}}); - - - if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory); - } - - stream.wait(); - - if (!isNCDHW) { - delete input; - delete gradI; - delete gradO; - } - - return Status::OK(); - } - - PLATFORM_CHECK(maxpool3dnew_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); - } - } - } -} \ No 
newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp index 96bbffcf8..0b81de76d 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp @@ -23,383 +23,388 @@ using namespace dnnl; -namespace nd4j { - namespace mkldnnUtils { - void getMKLDNNMemoryDescPool2d( - int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW, - int bS, int iC, int iH, int iW, int oC, int oH, int oW, - const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, - dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) { - dnnl::memory::dims pool_src_tz = { bS, iC, iH, iW }; - dnnl::memory::dims pool_dst_tz = { bS, oC, oH, oW }; +namespace nd4j { +namespace mkldnnUtils { - pool_strides = { sH, sW }; - pool_kernel = { kH, kW }; - pool_padding = { pH, pW }; - pool_padding_r = { (oH - 1) * sH - iH + kH - pH, - (oW - 1) * sW - iW + kW - pW }; +////////////////////////////////////////////////////////////////////////// +void getMKLDNNMemoryDescPool2d( + int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW, + int bS, int iC, int iH, int iW, int oC, int oH, int oW, + const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, + dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, + dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) { + dnnl::memory::dims pool_src_tz = { bS, iC, iH, iW }; + dnnl::memory::dims pool_dst_tz = { bS, oC, oH, oW }; - algorithm = poolingMode == 0 ? algorithm::pooling_max - : extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding - : algorithm::pooling_avg_include_padding; - auto type = dnnl::memory::data_type::f32; - auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any" + pool_strides = { sH, sW }; + pool_kernel = { kH, kW }; + pool_padding = { pH, pW }; + pool_padding_r = { (oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pW }; - if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) { - *pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); - user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3]; - user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1]; - user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2]; - } + algorithm = poolingMode == 0 ? 
algorithm::pooling_max + : extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding + : algorithm::pooling_avg_include_padding; + auto type = dnnl::memory::data_type::f32; + auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any" - if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) { - *pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); - user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3]; - user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1]; - user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2]; - } - - if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) { - *pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format); - *user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format); - user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3]; - user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1]; - user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2]; - } - }; - - - void getMKLDNNMemoryDescPool3d( - int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW, - int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, - const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, - dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) { - dnnl::memory::dims pool_src_tz = { bS, iC, iD, iH, iW }; - dnnl::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW }; - - pool_strides = { sD, sH, sW }; - pool_kernel = { kD, kH, kW }; - pool_padding = { pD, pH, pW }; - pool_padding_r = { (oD - 1) * sD - iD + kD - pD, - (oH - 1) * sH - iH + kH - pH, - (oW - 1) * sW - iW + kW - pW }; - - algorithm = poolingMode == 0 ? algorithm::pooling_max - : extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding - : algorithm::pooling_avg_include_padding; - auto type = dnnl::memory::data_type::f32; - auto format = isNCDHW ? 
dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - auto supposed_to_be_any_format = dnnl::memory::format_tag::nCdhw8c; // doesn't work with "any" - - if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) { - *pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); - user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4]; - user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1]; - user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2]; - user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3]; - } - - if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) { - *pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); - user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4]; - user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1]; - user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2]; - user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3]; - } - - if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) { - *pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format); - *user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format); - user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4]; - user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1]; - user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2]; - user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 
4 : 3]; - } - }; - - - - void getMKLDNNMemoryDescConv2d( - int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW, - int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src, - const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, - dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, - dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, - dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) { - dnnl::memory::dims conv_src_tz = { bS, iC, iH, iW }; - dnnl::memory::dims conv_weights_tz = { oC, iC, kH, kW }; - dnnl::memory::dims conv_bias_tz = { oC }; - dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW }; - - const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d - - conv_strides = { sH, sW }; - conv_padding = { pH, pW }; - conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; - conv_dilation = { dH-1, dW-1}; - - auto type = dnnl::memory::data_type::f32; - auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - auto formatw = dnnl::memory::format_tag::hwio; - - if (src != nullptr && conv_src_md != nullptr) { - *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); - *user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); - user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3]; - user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1]; - user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2]; - } - - if (diff_src != nullptr && conv_diff_src_md != nullptr) { - *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); - *user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); - user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3]; - user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1]; - user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 
3 : 2]; - } - - if (weights != nullptr && conv_weights_md != nullptr) { - *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); - *user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); - user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio" - user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[3]; - user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[2]; - user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0]; - user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1]; - } - - if (diff_weights != nullptr && conv_diff_weights_md != nullptr) { - *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); - *user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); - user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio" - user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[3]; - user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[2]; - user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0]; - user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1]; - } - - if (bias != nullptr && conv_bias_md != nullptr) { - *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any); - *user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x); - } - - if (dst != nullptr && conv_dst_md != nullptr) { - *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any); - *user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format); - user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3]; - user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1]; - user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 
3 : 2]; - } - } - - void getMKLDNNMemoryDescConv3d( - int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool paddingMode, bool isNCDHW, - int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src, - const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, - dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, - dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, - dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) { - dnnl::memory::dims conv_src_tz = { bS, iC, iD, iH, iW }; - dnnl::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW }; - dnnl::memory::dims conv_bias_tz = { oC }; - dnnl::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW }; - - conv_strides = { sD, sH, sW }; - conv_padding = { pD, pH, pW }; - conv_padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; - conv_dilation = { dD-1, dH-1, dW-1}; - - auto type = dnnl::memory::data_type::f32; - auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - auto formatw = dnnl::memory::format_tag::dhwio; - - if (src != nullptr && conv_src_md != nullptr) { - *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); - *user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); - user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4]; - user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1]; - user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2]; - user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3]; - } - - if (diff_src != nullptr && conv_diff_src_md != nullptr) { - *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); - *user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); - user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4]; - user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1]; - user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2]; - user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 
4 : 3]; - } - - if (weights != nullptr && conv_weights_md != nullptr) { - *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); - *user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); - user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio" - user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[4]; - user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[3]; - user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0]; - user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1]; - user_weights_md->data.format_desc.blocking.strides[4] = weights->stridesOf()[2]; - } - - if (diff_weights != nullptr && conv_diff_weights_md != nullptr) { - *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); - *user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); - user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio" - user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[4]; - user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[3]; - user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0]; - user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1]; - user_diff_weights_md->data.format_desc.blocking.strides[4] = diff_weights->stridesOf()[2]; - } - - if (bias != nullptr && conv_bias_md != nullptr) { - *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any); - *user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x); - } - - if (dst != nullptr && conv_dst_md != nullptr) { - *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any); - *user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format); - user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4]; - user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1]; - user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2]; - user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3]; - } - }; - - - // void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst, - // dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md, - // dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) { - // const Nd4jLong* shape = src->getShapeInfo(); - // Nd4jLong rank = shape[0]; - // Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one - // Nd4jLong dim2 = axis >= 2 ? 1 : 2; - // Nd4jLong dim3 = axis >= 3 ? 2 : 3; - // dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? 
(int)shape[dim3 + 1] : 1}; - - // auto type = dnnl::memory::data_type::f32; - // auto format = dnnl::memory::format_tag::nchw; - // auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any" - - // if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) { - // *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); - // *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); - // user_src_md->data.format_kind = dnnl_blocked; // overrides format - // user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; - // user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; - // user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1; - // user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1; - // } - - // if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) { - // *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); - // *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); - // user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format - // user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; - // user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; - // user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; - // user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1; - // } - - // if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) { - // *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); - // *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); - // user_dst_md->data.format_kind = dnnl_blocked; // overrides format - // user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; - // user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; - // user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1; - // user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1; - // } - // }; - - - void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst, - dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) { - const Nd4jLong* shape = src->getShapeInfo(); - long rank = shape[0]; - long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one - long dim2 = axis >= 2 ? 1 : 2; - long dim3 = axis >= 3 ? 2 : 3; - dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; - - auto type = dnnl::memory::data_type::f32; - auto format = axis == 1 ? 
dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
- auto supposed_to_be_any_format = format; // doesn't work with "any"
-
- if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) {
- *lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
- *user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
- user_src_md->data.format_kind = dnnl_blocked;
- user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
- user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
- user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
- user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
- }
-
- if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) {
- *lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
- *user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
- user_diff_src_md->data.format_kind = dnnl_blocked;
- user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
- user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
- user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
- user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
- }
-
- if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) {
- *lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
- *user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
- user_dst_md->data.format_kind = dnnl_blocked;
- user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
- user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
- user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
- user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
- }
- }
-
- dnnl::engine& getEngine(void *ptr) {
- auto eng = reinterpret_cast<dnnl::engine*>(ptr);
- return *eng;
- }
+ if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
+ *pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
+ *user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
+ user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
+ user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0];
+ user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3];
+ user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1];
+ user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2];
 }
+
+ if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
+ *pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
+ *user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
+ user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
+ user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
+ user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
+ user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 
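+ // The manual format_kind/blocking.strides writes above are what let oneDNN read an
+ // arbitrarily-strided NDArray in place instead of forcing a copy. A minimal sketch of
+ // the same idea using oneDNN's stride-based desc constructor (values are illustrative):
+ //
+ //   dnnl::memory::dims dims    = { 2, 3, 4, 5 };    // logical NCHW sizes
+ //   dnnl::memory::dims strides = { 60, 1, 15, 3 };  // strides of an NHWC-laid-out buffer
+ //   auto md = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, strides);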
2 : 1]; + user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2]; + } + + if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) { + *pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format); + *user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format); + user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" + user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0]; + user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3]; + user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1]; + user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2]; + } +}; + +////////////////////////////////////////////////////////////////////////// +void getMKLDNNMemoryDescPool3d( + int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW, + int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, + const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, + dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, + dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) { + dnnl::memory::dims pool_src_tz = { bS, iC, iD, iH, iW }; + dnnl::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW }; + + pool_strides = { sD, sH, sW }; + pool_kernel = { kD, kH, kW }; + pool_padding = { pD, pH, pW }; + pool_padding_r = { (oD - 1) * sD - iD + kD - pD, + (oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pW }; + + algorithm = poolingMode == 0 ? algorithm::pooling_max + : extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding + : algorithm::pooling_avg_include_padding; + auto type = dnnl::memory::data_type::f32; + auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; + auto supposed_to_be_any_format = dnnl::memory::format_tag::nCdhw8c; // doesn't work with "any" + + if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) { + *pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); + *user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); + user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" + user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0]; + user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4]; + user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1]; + user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2]; + user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3]; + } + + if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) { + *pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); + *user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); + user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" + user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 
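+ // pool_padding_r above comes from oneDNN's size contract iX + pX + padding_r == (oX - 1) * sX + kX,
+ // i.e. padding_r = (oX - 1) * sX - iX + kX - pX. Worked example: iH = 5, kH = 2, sH = 2, pH = 0
+ // with ceil-mode output oH = 3 gives padding_r = (3 - 1) * 2 - 5 + 2 - 0 = 1, one implicit
+ // bottom row so the last window fits.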
0 : 0]; + user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4]; + user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1]; + user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2]; + user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3]; + } + + if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) { + *pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format); + *user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format); + user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" + user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0]; + user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4]; + user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1]; + user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2]; + user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3]; + } +}; + +////////////////////////////////////////////////////////////////////////// +void getMKLDNNMemoryDescConv2d( + int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW, + int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src, + const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, + dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, + dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, + dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, + dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) { + dnnl::memory::dims conv_src_tz = { bS, iC, iH, iW }; + dnnl::memory::dims conv_weights_tz = { oC, iC, kH, kW }; + dnnl::memory::dims conv_bias_tz = { oC }; + dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW }; + + const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d + + conv_strides = { sH, sW }; + conv_padding = { pH, pW }; + conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; + conv_dilation = { dH-1, dW-1}; + + auto type = dnnl::memory::data_type::f32; + auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + auto formatw = dnnl::memory::format_tag::hwio; + + if (src != nullptr && conv_src_md != nullptr) { + *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); + *user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); + user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" + user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0]; + user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3]; + user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 
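+ // pWSame above covers the one case where nd4j's SAME padding (paddingMode == 2) and oneDNN
+ // disagree: width dilation. The total SAME padding is
+ //   padW_total = (oW - 1) * sW + (kW - 1) * dW + 1 - iW
+ // and the /2 splits it between conv_padding and conv_padding_r. Numeric check: iW = 7,
+ // kW = 3, sW = 1, dW = 2, oW = 7 -> padW_total = 6 + 4 + 1 - 7 = 4, so pWSame = 2 per side.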
2 : 1]; + user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2]; + } + + if (diff_src != nullptr && conv_diff_src_md != nullptr) { + *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); + *user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); + user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" + user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0]; + user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3]; + user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1]; + user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2]; + } + + if (weights != nullptr && conv_weights_md != nullptr) { + *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); + *user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); + user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio" + user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[3]; + user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[2]; + user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0]; + user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1]; + } + + if (diff_weights != nullptr && conv_diff_weights_md != nullptr) { + *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); + *user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); + user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio" + user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[3]; + user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[2]; + user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0]; + user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1]; + } + + if (bias != nullptr && conv_bias_md != nullptr) { + *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any); + *user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x); + } + + if (dst != nullptr && conv_dst_md != nullptr) { + *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any); + *user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format); + user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" + user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0]; + user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3]; + user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1]; + user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 
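+ // The weight strides above permute nd4j's [kH, kW, iC, oC] (hwio) weight layout into the
+ // [oC, iC, kH, kW] logical order oneDNN expects. The same mapping as a tiny sketch (the
+ // helper name is hypothetical, not part of this file):
+ //
+ //   // s = strides of an hwio-shaped array; result ordered o, i, h, w
+ //   static std::array<Nd4jLong, 4> oihwFromHwio(const Nd4jLong* s) {
+ //       return { s[3], s[2], s[0], s[1] };
+ //   }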
3 : 2]; + } +} + +////////////////////////////////////////////////////////////////////////// +void getMKLDNNMemoryDescConv3d( + int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool paddingMode, bool isNCDHW, + int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src, + const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, + dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, + dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, + dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, + dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) { + dnnl::memory::dims conv_src_tz = { bS, iC, iD, iH, iW }; + dnnl::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW }; + dnnl::memory::dims conv_bias_tz = { oC }; + dnnl::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW }; + + conv_strides = { sD, sH, sW }; + conv_padding = { pD, pH, pW }; + conv_padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; + conv_dilation = { dD-1, dH-1, dW-1}; + + auto type = dnnl::memory::data_type::f32; + auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; + auto formatw = dnnl::memory::format_tag::dhwio; + + if (src != nullptr && conv_src_md != nullptr) { + *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); + *user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); + user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" + user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0]; + user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4]; + user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1]; + user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2]; + user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3]; + } + + if (diff_src != nullptr && conv_diff_src_md != nullptr) { + *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); + *user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); + user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" + user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0]; + user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4]; + user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1]; + user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2]; + user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 
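+ // conv_dilation = { dD-1, dH-1, dW-1 } above bridges two dilation conventions: nd4j counts
+ // the step between kernel taps (1 == dense), while oneDNN counts the zeros inserted between
+ // taps (0 == dense). Either way the effective kernel extent is the same:
+ //   kEff = (k - 1) * dNd4j + 1 == (k - 1) * (dOneDnn + 1) + 1.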
4 : 3]; + } + + if (weights != nullptr && conv_weights_md != nullptr) { + *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); + *user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); + user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio" + user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[4]; + user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[3]; + user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0]; + user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1]; + user_weights_md->data.format_desc.blocking.strides[4] = weights->stridesOf()[2]; + } + + if (diff_weights != nullptr && conv_diff_weights_md != nullptr) { + *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); + *user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); + user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio" + user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[4]; + user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[3]; + user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0]; + user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1]; + user_diff_weights_md->data.format_desc.blocking.strides[4] = diff_weights->stridesOf()[2]; + } + + if (bias != nullptr && conv_bias_md != nullptr) { + *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any); + *user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x); + } + + if (dst != nullptr && conv_dst_md != nullptr) { + *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any); + *user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format); + user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" + user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0]; + user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4]; + user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1]; + user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2]; + user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3]; + } +}; + + +// void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst, +// dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md, +// dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) { +// const Nd4jLong* shape = src->getShapeInfo(); +// Nd4jLong rank = shape[0]; +// Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one +// Nd4jLong dim2 = axis >= 2 ? 1 : 2; +// Nd4jLong dim3 = axis >= 3 ? 2 : 3; +// dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? 
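+// The dim1/dim2/dim3 shuffle above folds a rank-2..4 tensor into a fixed 4D {N, C, D2, D3}
+// view with the requested axis moved into the channel slot and missing trailing dims padded
+// to size 1. Worked example: axis = 3 on a rank-4 NHWC input picks dim1 = 3, dim2 = 1,
+// dim3 = 2, i.e. a logical {N, C, H, W} view backed by the original NHWC strides.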
(int)shape[dim3 + 1] : 1}; + +// auto type = dnnl::memory::data_type::f32; +// auto format = dnnl::memory::format_tag::nchw; +// auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any" + +// if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) { +// *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); +// *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); +// user_src_md->data.format_kind = dnnl_blocked; // overrides format +// user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; +// user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; +// user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1; +// user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1; +// } + +// if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) { +// *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); +// *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); +// user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format +// user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; +// user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; +// user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; +// user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1; +// } + +// if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) { +// *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); +// *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); +// user_dst_md->data.format_kind = dnnl_blocked; // overrides format +// user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; +// user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; +// user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1; +// user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1; +// } +// }; + +////////////////////////////////////////////////////////////////////////// +void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst, + dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) { + const Nd4jLong* shape = src->getShapeInfo(); + long rank = shape[0]; + long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one + long dim2 = axis >= 2 ? 1 : 2; + long dim3 = axis >= 3 ? 2 : 3; + dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; + + auto type = dnnl::memory::data_type::f32; + auto format = axis == 1 ? 
dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + auto supposed_to_be_any_format = format; // doesn't work with "any" + + if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) { + *lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); + *user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format); + user_src_md->data.format_kind = dnnl_blocked; + user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; + user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; + user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1; + user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1; + } + + if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) { + *lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); + *user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format); + user_diff_src_md->data.format_kind = dnnl_blocked; + user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; + user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; + user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; + user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1; + } + + if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) { + *lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); + *user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format); + user_dst_md->data.format_kind = dnnl_blocked; + user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; + user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; + user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1; + user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? 
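+ // Intended call pattern, as a hedged sketch (variable names are illustrative; passing
+ // nullptr for the diff_src slots simply skips that branch above):
+ //
+ //   dnnl::memory::desc lrn_src, lrn_dst, u_src, u_dst;
+ //   getMKLDNNMemoryDescLrn(in, nullptr, out, &lrn_src, nullptr, &lrn_dst,
+ //                          &u_src, nullptr, &u_dst, 1);  // axis 1 == NCHW channel axis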
dst->stridesOf()[dim3] : 1; + } +} + +////////////////////////////////////////////////////////////////////////// +dnnl::engine& getEngine(void *ptr) { + auto eng = reinterpret_cast(ptr); + return *eng; +} + + +} } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp index 9ed9f0ee6..9aafe869e 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -318,36 +318,6 @@ TEST_F(ConvolutionTests1, conv2d_8) { delete results; } -TYPED_TEST(TypedConvolutionTests1, TestAvgFF_TF) { - - auto input = NDArrayFactory::create('c', {4, 10, 10, 3}, {9.37125111f, 2.20166993f, 2.91434479f, 5.43639755f, -2.10573769f, 4.08528662f, 5.86908436f, -4.46203756f, 2.21057916f, 5.35849190f, 0.01394637f, 4.40566349f, 7.07982206f, -0.09633455f, 2.42429352f, 3.97301817f, -1.89553940f, 1.99690318f, 6.33141708f, 0.55401880f, 1.70707977f, - 5.55204201f, -0.03513752f, 1.60011971f, 2.62700319f, -2.74582434f, 3.06697464f, 1.06277943f, -1.16075921f, -0.78095782f, 9.72352791f, -1.22686064f, 1.99644792f, 7.35571337f, 1.40607321f, 0.11390255f, 9.53334427f, 2.28303599f, -1.66728830f, 6.16678810f, -0.04532295f, -1.97708666f, 9.74906158f, 1.46223176f, -1.46734393f, 4.30761862f, - -1.23790228f, 1.24823606f, 6.13938427f, -3.83689475f, -1.19625473f, 7.91535568f, 6.05868721f, -3.22946382f, 8.81633949f, -0.19967777f, 0.66053957f, 2.30919123f, 0.74543846f, -0.39347672f, 11.11058044f, 0.53720862f, 1.52645731f, 5.70012379f, -1.15213466f, 1.16451406f, 7.00526333f, 1.57362783f, -2.44384766f, 5.54213285f, -1.98828590f, - -0.70483637f, 7.88281822f, -3.59875536f, 0.80745387f, 13.41578484f, -1.55507684f, -0.65855008f, 9.32583523f, -0.14544789f, 0.73436141f, 3.61176538f, -1.71268058f, -2.58490300f, 9.09280205f, -3.27405524f, -2.04569697f, 4.44761324f, -0.62955856f, -2.61917663f, 8.04890442f, 0.54579324f, 0.85929775f, 9.82259560f, -1.93825579f, 0.77703512f, - 4.67090321f, -4.79267597f, -2.38906908f, 9.31265545f, 0.96026313f, -1.14109385f, 11.54231834f, -0.01417295f, -0.39500344f, 8.49191666f, 0.55300158f, 2.79490185f, 6.92466164f, 1.72254205f, 2.82222271f, 8.83112717f, 2.95033407f, 2.18054962f, 6.73509789f, -2.22272944f, 0.51127720f, -1.04563558f, 2.15747333f, -2.30959272f, 9.55441570f, - 1.50396204f, 1.77370787f, 7.38146257f, -1.79076433f, 3.20961165f, 7.18864202f, 2.91217351f, 0.43018937f, 7.11078024f, -1.17386127f, -0.16817921f, 6.12327290f, -2.82205725f, 3.30696845f, 13.51291752f, -1.30856836f, -2.38332748f, 11.09487438f, -1.47190213f, -0.53050828f, 4.38285351f, -5.07309771f, 1.50714362f, 5.72274446f, -2.85825086f, - -0.89673209f, 3.73791552f, -0.67708802f, -4.13149452f, -0.00671843f, -0.26566532f, 0.32961160f, 7.14501762f, -1.41608179f, -4.96590328f, 12.26205540f, -0.65158135f, -0.88641000f, 6.95777559f, -0.79058206f, -0.10260171f, 7.87169170f, 1.35921454f, 1.11759663f, 5.46187401f, -2.57214499f, 2.48484039f, 4.04043484f, -2.07137156f, -1.42709637f, - 9.25487137f, -0.12605135f, -2.66949964f, 2.89412403f, 0.74451172f, -2.96250391f, 3.99258423f, 0.27084303f, 0.32213116f, 5.42332172f, -0.44414216f, 1.70881832f, 6.69346905f, 0.53058422f, -4.73146200f, 4.22051668f, 2.24834967f, 0.66996074f, 4.30173683f, 0.11849818f, -4.07520294f, 8.27318478f, -2.54398274f, -2.86705542f, 10.11775303f, - -0.99382895f, 0.65881538f, 7.93556786f, -1.27934420f, -1.69343162f, 9.68042564f, -1.02609646f, -1.18189347f, 5.75370646f, -1.67888868f, -4.48871994f, 4.79537392f, -0.79212248f, -0.19855022f, 6.15060997f, 
-0.01081491f, 3.64454579f, 10.82562447f, 1.58859253f, -2.65847278f, 8.60093212f, -1.59196103f, 0.07635692f, 11.76175690f, -1.17453325f, - 0.10122013f, 6.86458445f, -2.18891335f, -2.74004745f, 8.07066154f, 0.71818852f, -2.03035975f, 6.31053686f, 0.51509416f, 1.39789927f, 9.43515587f, 2.04256630f, 0.13985133f, 4.65010691f, 2.40911126f, -0.36255789f, -3.06867862f, -0.45225358f, -1.56778407f, 6.05917358f, -1.09891272f, 1.77184200f, 6.46248102f, 0.96042323f, -0.24346280f, - 4.63436460f, -4.69907761f, 1.25187206f, 11.46173859f, -2.21917558f, 1.28007793f, 6.92173195f, 2.11268163f, -3.47389889f, 5.08722782f, -3.03950930f, -4.17154264f, 11.30568314f, 0.80361372f, 2.53214502f, 7.18707085f, -4.49114513f, 2.85449266f, 10.14906883f, -0.31974933f, -0.84472644f, -0.52459574f, 0.12921631f, -1.81390119f, 2.76170087f, 1.03982210f, 2.91744232f, -0.29048753f, 5.87453508f, -1.53684759f, 1.85800636f, -0.91404629f, 1.28954852f, 5.11354685f, -2.47475505f, -1.33179152f, 2.58552408f, 1.37316465f, -3.32339454f, 1.54122913f, 3.24953628f, -0.29758382f, 2.82391763f, -1.51142192f, -1.22699404f, 6.75745535f, 0.65452754f, -3.29385471f, 2.06008053f, 2.53172946f, -4.23532820f, -1.53909743f, -0.07010663f, -1.42173731f, 7.29031610f, -0.18448229f, 4.59496164f, 6.73027277f, 0.73441899f, 0.14426160f, 4.14915276f, -2.97010231f, 6.05851364f, 4.95218086f, -2.39145470f, 2.40494704f, 2.10288811f, 0.53503096f, 1.44511235f, 6.66344261f, -3.05803776f, 7.21418667f, 3.30303526f, -0.24163735f, 3.47409391f, 3.64520788f, 2.15189481f, -3.11243272f, 3.62310791f, 0.37379482f, 0.40865007f, -0.83132005f, -4.78246069f, 2.07030797f, 6.51765442f, 3.16178989f, 5.06180477f, 3.78434467f, -0.96689719f, 0.35965276f, 5.89967585f, 1.40294051f, 1.11952639f, 10.59778214f, 0.26739889f, -1.61297631f, 6.24801159f, -0.93914318f, -0.57812452f, 9.92604542f, -0.73025000f, -3.38530874f, 2.45646000f, -2.47949195f, 0.51638460f, 10.65636063f, 1.97816694f, -3.00407791f, 2.66914415f, -0.81951088f, -0.23316640f, 2.40737987f, -2.70007610f, 1.51531935f, 4.08860207f, -0.27552786f, -1.31721711f, 7.11568260f, -3.33498216f, -4.02545023f, 7.22675610f, -0.81690705f, -2.52689576f, 1.04016697f, -0.79291463f, -0.34875512f, 10.00498390f, -4.24167728f, 1.46162593f, 11.82569408f, -1.70359993f, -0.30161047f, 16.44085884f, -0.82253462f, -0.09435523f, 6.13080597f, -0.20259480f, 0.68308711f, 6.15663004f, -6.61776876f, 0.33295766f, 2.55449438f, -0.17819691f, -1.14892209f, 5.56776142f, 1.99279118f, 1.33035934f, 4.45823956f, 3.34916544f, -2.59905386f, 6.16164446f, -2.03881931f, -2.45273542f, 12.46793365f, -2.22743297f, 2.83738565f, 8.48628139f, -1.39347959f, -1.30867767f, 11.08041477f, -4.00363779f, 2.09183025f, 11.30395889f, -2.20504737f, 1.37426853f, 8.98735619f, 1.04676604f, -0.72757077f, 8.28050232f, -6.70741081f, -0.65798020f, 5.68592072f, -0.60760021f, 0.35854483f, 6.26852131f, 1.94100165f, 1.32112014f, 0.80987954f, -1.74617672f, -0.25434083f, 7.16045523f, 1.58884013f, -2.64847064f, 13.14820385f, 1.21393633f, -2.47258949f, 9.41650105f, -0.79384226f, 2.48954105f, 10.95629311f, 0.47723705f, 4.02126694f, 8.02593136f, -2.20726371f, -1.18794477f, 1.50836647f, 0.93118095f, -1.73513174f, 8.85493565f, -2.99670315f, -0.79055870f, 2.39473820f, 2.05046916f, -2.38055134f, 11.82299423f, 0.15609655f, 0.68744308f, 5.66401434f, -0.69281673f, 2.09855556f, 7.74626589f, -0.34283102f, 1.00542057f, 9.95838642f, 0.80161905f, 2.33455157f, 9.80057335f, -0.93561798f, 2.56991577f, 8.29711342f, 0.94213426f, 0.44209945f, 11.70259857f, 0.92710167f, 2.60957146f, 0.24971688f, -0.86529571f, 3.78628922f, 6.80884457f, 
-0.68178189f, 2.21103406f, 3.18895817f, 0.60283208f, -2.92716241f, 6.72060776f, -1.06625068f, 2.56543374f, 9.97404480f, 3.58080721f, -0.94936347f, 10.16736984f, -1.38464379f, 1.18191063f, 6.66179037f, -3.56115270f, 0.32329530f, 10.90870762f, 2.20638227f, 0.19653285f, 7.34650040f, -3.63859272f, -1.03027737f, 5.98829985f, -3.66606474f, -3.89746714f, 8.63469028f, 1.22569811f, 1.63240814f, 3.74385309f, 0.58243257f, -0.56981975f, 3.69260955f, 1.00979900f, -1.44030499f, 8.57058144f, -1.10648811f, 1.20474911f, 5.43133020f, -2.14822555f, -0.07928789f, 11.25825310f, 0.19645604f, -5.49546146f, 10.41917038f, -0.68178523f, -2.99639869f, 6.50054455f, 0.46488351f, -5.42328453f, 9.09500027f, -2.82107449f, 0.05601966f, 15.34610748f, -0.06820253f, 3.86699796f, 10.73316956f, -3.04795432f, -0.14702171f, 5.64813185f, 1.44028485f, -2.47596145f, 0.07280898f, -3.03187990f, -1.35183525f, 9.35835648f, 2.72966957f, 1.88199532f, 10.36187744f, -0.22834805f, -3.26738238f, 6.92025137f, -2.34061313f, 4.77379704f, 5.28559113f, -2.96323752f, -1.76186585f, 5.94436455f, 0.38647744f, -5.73869514f, 6.76849556f, 1.40892124f, -1.19068217f, 5.37919092f, -6.65328646f, 3.62782669f, 12.34744644f, 2.44762444f, -4.19242620f, 6.14906216f, 0.08121119f, 0.61355996f, 2.69666457f, -1.88962626f, -0.55314136f, 1.84937525f, 1.56048691f, 1.17460012f, 3.75674725f, 1.06198275f, -5.74625874f, 5.41645575f, -1.28946674f, -1.51689398f, 4.32400894f, -0.05222082f, -4.83948946f, 1.80747867f, 1.63144708f, -2.73887825f, 1.63975775f, -2.02163982f, -0.16210437f, 2.93518686f, 1.14427686f, -2.83246303f, 4.79283667f, 2.69697428f, -3.12678456f, -1.19225168f, -2.37022972f, -3.09429741f, 1.94225383f, -1.13747168f, -2.55048585f, 5.40242243f, 1.12777328f, 3.43713188f, 3.62658787f, -2.16878843f, 0.30164462f, 2.97407579f, -0.07275413f, -1.31149673f, 4.70066261f, -2.01323795f, 4.85255766f, 4.59128904f, 1.68084168f, 1.60336494f, 6.58138466f, -1.04759812f, 2.69906545f, 3.55769277f, -0.74327278f, 2.65819693f, 5.39528131f, 2.11248922f, -1.06446671f, 5.24546766f, -2.43146014f, 4.58907509f, 0.06521678f, -2.24503994f, 2.45722699f, 6.94863081f, 0.35258654f, 2.83396196f, 9.92525196f, -1.12225175f, -0.34365177f, 7.19116688f, -4.39813757f, 0.46517885f, 13.22028065f, -2.57483673f, -6.37226963f, 7.58046293f, -2.74600363f, 0.42231262f, 8.04881668f, 0.17289802f, -0.53447008f, 16.55157471f, -5.63614368f, 0.39288223f, 3.37079263f, 1.26484549f, -0.12820500f, 8.46440125f, -4.39304399f, 2.97676420f, 0.65650189f, 0.83158541f, -1.11556435f, 6.32885838f, -0.36087769f, 2.80724382f, 9.90292645f, 1.15936041f, 0.20947981f, 6.91249275f, -2.67404819f, 2.93782163f, 6.65656614f, -2.30828357f, 2.98214006f, 6.80611229f, -4.93821478f, -7.66555262f, 7.59763002f, -0.54159302f, 3.87403512f, 12.42607784f, 2.59284401f, -0.23375344f, 8.95293331f, -0.71807784f, 0.61873478f, 8.66713524f, 1.24289191f, -2.37835455f, 2.08071637f, -0.88315344f, -3.41891551f, 6.85245323f, 1.73007369f, 1.02169311f, 7.69170332f, -2.85411978f, 2.69790673f, 8.12906551f, -1.19351399f, -2.26442742f, 12.26104450f, -0.75579089f, -1.73274946f, 10.68729019f, 2.20655656f, -0.90522075f, 12.42165184f, -1.67929137f, 2.44851565f, 9.31565762f, -0.06645700f, 1.52762020f, 6.18427515f, -1.68882596f, 3.70261097f, 3.02252960f, -3.44125366f, -1.31575799f, 2.84617424f, -0.96849400f, -4.52356243f, 9.95027161f, 0.19966406f, -0.78874779f, 8.18595028f, -4.08300209f, 1.75126517f, 0.96418417f, -4.04913044f, -0.95200396f, 12.03637886f, -0.03041124f, 0.41642749f, 8.88267422f, -3.24985337f, -2.24919462f, 7.32566118f, 0.16964148f, -2.74123430f, 7.05264473f, 
-3.30191112f, 0.17163286f, 4.81851053f, -1.64463484f, -0.85933101f, 7.29276276f, 2.34066939f, -2.14860010f, 3.46148157f, -0.01782012f, 1.51504040f, 4.79304934f, 1.85281146f, -1.70663762f, 6.93470192f, -4.15440845f, -1.25983095f, 10.52491760f, 0.42930329f, -1.85146868f, 11.70042324f, -0.41704914f, 3.83796859f, 9.21148491f, -2.79719448f, 0.79470479f, 6.26926661f, -5.85230207f, 3.95105338f, 7.84790897f, -1.38680744f, -1.78099084f, 11.95235348f, -2.99841452f, -1.34507811f, 6.15714645f, -1.07552516f, -2.81228638f, 1.66234732f, -4.55166149f, -1.92601109f, 8.64634514f, -0.48158705f, 3.31595659f, 7.67371941f, 2.56964207f, 0.12107098f, 4.56467867f, -0.93541539f, 1.39432955f, 11.99714088f, 1.05353570f, -2.13099813f, 3.67617917f, 3.45895386f, 1.37365830f, 8.74344158f, -4.17585802f, 1.43908918f, 6.28764772f, 3.97346330f, -0.69144285f, 9.07983303f, -0.41635889f, -0.14965028f, 8.85469818f, 1.11306190f, 2.59440994f, 5.38982344f, -1.07948279f, 1.37252975f, 10.26984596f, -0.09318046f, 2.73104119f, 12.45902252f, -1.55446684f, -2.76124811f, 12.19395065f, -0.51846564f, 1.02764034f, 11.42673588f, -0.95940983f, -0.04781032f, 8.78379822f, -4.88957930f, 0.32534006f, 11.97696400f, -3.35108662f, 1.95104563f, 4.46915388f, -2.32061648f, 3.45230985f, 8.29983711f, 2.81034684f, -2.35529327f, 6.07801294f, -0.98105043f, -0.05359888f, 2.52291036f, -0.01986909f, -2.35321999f, 10.51954269f, 2.11145401f, 3.53506470f, 7.29093266f, 0.03721160f, -1.13496494f, 7.43886709f, -5.84201956f, 2.50796294f, 12.14647675f, 2.77490377f, -2.18896222f, 6.05641937f, 5.32617044f, 1.04221284f, 10.79106712f, -2.95749092f, -2.75414610f, 11.30037117f, -3.40654182f, -2.24673963f, 7.49126101f, 0.70811015f, -6.18003702f, 13.83951187f, -1.01204085f, 1.36298490f, -1.04451632f, 2.42435336f, -0.02346706f, -0.85528886f, 1.04731262f, 0.22192979f, 4.15708160f, 0.34933877f, 0.04814529f, 2.24107265f, 0.49676740f, -1.47752666f, 0.45040059f, -0.70471478f, -1.19759345f, 0.21711677f, 0.88461423f, -2.76830935f, 5.52066898f, 1.97664857f, -1.75381601f, 3.45877838f, 1.52617192f, -1.61350942f, 0.85337949f, 1.97610760f, -3.40310287f, 3.40319014f, -3.38691044f, -0.71319139f, 1.65463758f, -0.60680127f, -1.80700517f, 8.02592373f, 2.59627104f, 2.65895891f, 5.93043184f, -4.48425817f, 3.92670918f, 4.19496679f, -2.28286791f, 6.41634607f, 5.72330523f, 1.16269672f, -0.28753027f, 2.46342492f, 0.36693189f, 0.26712441f, 6.37652683f, -2.50139046f, 2.43923736f, 5.56310415f, 0.98065847f, 1.04267502f, 4.16403675f, -0.04966142f, 4.40897894f, 3.72905660f, -3.46129870f, 3.59962773f, 1.34830284f, -1.76661730f, 0.47943926f, 5.29946661f, -1.12711561f, 1.26970029f, 15.17655945f, -1.50971997f, 5.81345224f, 8.48562050f, -4.36049604f, 2.48144460f, 8.23780441f, -3.46030426f, -0.84656560f, 5.94946814f, 1.12747943f, -2.65683913f, 8.69085693f, 1.31309867f, -2.79958344f, 8.76840591f, -1.56444156f, 1.62710834f, 2.41177034f, -0.72804940f, 5.70619011f, 4.67169666f, -0.86167198f, -1.83803177f, 2.96346045f, 2.82692933f, -2.81557131f, 7.11113358f, -1.90071094f, 2.54244423f, 11.19284058f, -0.06298946f, -1.71517313f, 12.98388577f, 0.84510714f, 3.00816894f, 2.57200313f, 0.03899818f, -1.49330592f, 9.60099125f, -3.59513044f, -1.30045319f, 7.09241819f, -0.65233821f, -2.33627677f, 8.81366920f, 0.84154201f, 1.03312039f, 9.85289097f, 0.19351870f, 1.78496623f, 7.34631205f, -2.16530800f, -0.65016162f, 2.46842360f, 0.24016285f, -1.24308395f, 4.78175163f, -0.97682536f, 2.20942235f, 6.68382788f, 3.76786447f, -1.44454038f, 6.26453733f, -3.23575711f, -2.30137897f, 9.53092670f, -5.55222607f, 3.25999236f, 9.37559509f, 
1.86339056f, -0.23551451f, 10.23400211f, 3.93031883f, -0.52629089f, 7.85724449f, -2.91549587f, 4.46612740f, 5.66530371f, -2.70820427f, 4.81359577f, 10.31247330f, 1.92230141f, 2.53931546f, 0.74986327f, 1.70303428f, 0.48063779f, 5.31099129f, -0.78976244f, 3.75864220f, 4.23051405f, 2.34042454f, -7.98193836f, 9.83987141f, -1.46722627f, 3.54497814f, 10.36455154f, -4.51249075f, 0.77715248f, 7.78694630f, -4.59989023f, -2.49585629f, 9.90296268f, 1.38535416f, 1.17441154f, 10.10452843f, -0.98628229f, 0.60194463f, 9.12639141f, -3.90754628f, 2.88526392f, 7.24123430f, -0.15283313f, -0.75728363f, -1.15116858f, -2.53791571f, 0.77229571f, 6.44114161f, 0.02646767f, 4.95463037f, 7.21066380f, 1.79384065f, 0.73250306f, 8.04447937f, 0.32576546f, -0.79447043f, 10.12717724f, 2.33392906f, 1.30716443f, 12.36073112f, -0.36694977f, -1.20438910f, 7.03105593f, 0.59557682f, 0.69267452f, 10.18113136f, 2.49944925f, -0.42229167f, 8.83143330f, -1.18805945f, -2.87509322f, 4.53596449f, 4.09732771f, -3.39088297f, -1.02536607f, 0.82119560f, -3.47302604f, 9.29991817f, 0.21001509f, 4.97036457f, 9.50018406f, 1.04420102f, 1.96560478f, 10.74769592f, -6.22709799f, 3.11690164f, 5.06759691f, -1.23724771f, -3.05831861f, 8.12925529f, -1.93435478f, -1.10151744f, 9.32263088f, -0.04249470f, -5.98547363f, 10.49398136f, 0.26400441f, -0.78915191f, 13.28219604f, 2.99276900f, 0.74853164f, 2.49364305f, -3.43529654f, 4.05278301f, 2.13498688f, -2.35444307f, -0.79900265f, 4.66968822f, -0.31095147f, 3.60674143f, 12.37222099f, -0.07855003f, -3.30292702f, 12.15215874f, 0.60886210f, 2.87075138f, 7.75271845f, 0.38044083f, 3.34402204f, 6.40583277f, -0.87888050f, 0.67438459f, 6.91080809f, 1.98332930f, -0.08303714f, 8.08630371f, -0.16772588f, -2.74058914f, 7.17253590f, -2.69122696f, 1.48173678f, 8.99470139f, -1.43302310f, -0.88651133f, 2.66944790f, -0.29186964f, 2.00838661f, 5.09587479f, -0.76676071f, -2.88322186f, 8.31110573f, -0.14550979f, -1.37726915f, 10.28355122f, -1.60575438f, -0.04118848f, 9.97510815f, 0.14440438f, -3.24632120f, 9.00034523f, 4.14319563f, -1.31023729f, 7.16950464f, -0.70428526f, 2.01559544f, 7.26155043f, 2.40816474f, 2.09847403f, 7.31264496f, -0.75401551f, 2.13392544f, 7.03648758f, 1.04036045f, -1.15636516f, 1.09634531f, -0.06340861f, -0.58107805f, -0.65623116f, 1.18972754f, -0.80717683f, 1.40118241f, -0.61932516f, -3.60596156f, 1.59904599f, -2.23774099f, -1.13721037f, 3.89620137f, -0.09115922f, -7.51356888f, 2.36975193f, -1.42520905f, -2.34173775f, 3.33830214f, -2.74016523f, -3.04115510f, 6.00119495f, -1.36084354f, -2.45065260f, 4.56992292f, -3.02825928f, -3.74182844f, 5.11069250f, -0.91531068f, -2.31385994f, 1.83399653f, 3.39370203f, -3.60886002f}); - auto exp = NDArrayFactory::create('c', {4, 4, 4, 3}, {7.97172260f, 0.06878620f, 2.27749538f, 7.29276514f, -0.14074677f, 0.65480286f, 5.70313978f, -0.06546132f, 0.35443667f, 3.70382833f, -0.84020567f, 0.63826996f, 8.60301399f, -0.38236514f, 1.55177069f, 7.37542057f, -0.99374938f, -0.29971302f, 8.84352493f, -0.67121059f, 0.43132120f, 4.78175592f, -1.25070143f, -1.91523600f, 6.03855371f, -0.00292124f, -1.11214364f, 7.90158176f, -0.57949901f, -0.96735370f, 7.81192017f, -0.53255427f, -0.48009714f, 3.16953635f, 0.08353355f, -1.54299748f, 3.74821687f, 1.69396687f, 0.72724354f, 5.42915201f, -1.13686812f, -0.71793109f, 5.78376389f, -0.72239977f, -0.60055625f, 2.53636408f, 0.56777251f, -2.07892323f, 6.08064651f, 0.68620735f, 2.54017019f, 5.65828180f, -0.68255502f, 1.47283304f, 6.10842514f, -0.39655915f, 0.28380761f, 1.96707797f, -1.98206317f, 0.94027776f, 4.71811438f, 0.32104525f, -0.92409706f, 
8.34588146f, -1.05581069f, -0.55217457f, 9.58440876f, -0.96549922f, 0.45820439f, 5.65453672f, -2.50953507f, -0.71441835f, 8.03059578f, -0.21281289f, 0.92125505f, 9.26900673f, -0.35963219f, -0.70039093f, 8.59924412f, -1.22358346f, 0.81318003f, 3.85920119f, -0.01305223f, -1.09234154f, 6.33158875f, 1.28094780f, -1.48926139f, 4.94969177f, -0.77126902f, -1.97033751f, 5.64381838f, -0.16285487f, -1.31277227f, 2.39893222f, -1.32902908f, -1.39609122f, 6.47572327f, -0.45267010f, 1.55727172f, 6.70965624f, -1.68735468f, -0.05672536f, 7.25092363f, -0.64613032f, 0.67050058f, 3.60789680f, -2.05948973f, 2.22687531f, 8.15202713f, -0.70148355f, 1.28314006f, 8.14842319f, -1.88807654f, -1.04808438f, 8.45500565f, -0.76425624f, 0.94542569f, 4.56179953f, -0.28786001f, -2.04502511f, 8.46278095f, -0.31019822f, 0.07339200f, 9.34214592f, -0.61948007f, 0.52481830f, 8.32515621f, -1.52418160f, 0.49678251f, 5.11082315f, -1.09908783f, -0.52969611f, 5.27806664f, 0.88632923f, 0.66754371f, 4.75839233f, 0.48928693f, -0.68036932f, 6.56925392f, -0.02949905f, -2.99189186f, 4.46320581f, -0.64534980f, -0.29516968f, 8.60809517f, -1.13120568f, 3.41720533f, 5.84243155f, -1.24109328f, 0.89566326f, 5.99578333f, -0.42496428f, 2.07076764f, 3.17812920f, -0.81566459f, -0.14363396f, 6.55184317f, 0.39633346f, -0.43852386f, 8.70214558f, -2.24613595f, 0.30708700f, 8.73882294f, -0.53545928f, 1.54409575f, 4.49452257f, -0.16509305f, 0.19028664f, 8.24897003f, 0.44750381f, 2.15448594f, 8.97640514f, -0.77728152f, 0.57272542f, 9.03467560f, 0.47173575f, -1.10807717f, 3.30056310f, -0.43268481f, -0.41470885f, 3.53798294f, -0.08546703f, -2.16840744f, 6.18733406f, -0.17871059f, -2.59837723f, 5.94218683f, -1.02990067f, -0.49760687f, 3.76938033f, 0.86383581f, -1.91504073f}); - - nd4j::ops::avgpool2d op; - auto result = op.execute({&input}, {}, {3,3, 3,3, 0,0, 1,1,1, 0,1}, {}); - ASSERT_EQ(Status::OK(), result->status()); - - auto z = result->at(0); - - // z->printIndexedBuffer("z"); - // exp.printIndexedBuffer("e"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete result; -} - ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, sconv2d_1) { float _expB[] = {10025.0f, 10350.0f, 10675.0f, 11000.0f, 11325.0f, 11650.0f, 13275.0f, 13600.0f, 13925.0f, 14250.0f, 14575.0f, 14900.0f, 16525.0f, 16850.0f, 17175.0f, 17500.0f, 17825.0f, 18150.0f, 19775.0f, 20100.0f, 20425.0f, 20750.0f, 21075.0f, 21400.0f, 23025.0f, 23350.0f, 23675.0f, 24000.0f, 24325.0f, 24650.0f, 26275.0f, 26600.0f, 26925.0f, 27250.0f, 27575.0f, 27900.0f, 38775.0f, 40350.0f, 41925.0f, 43500.0f, 45075.0f, 46650.0f, 54525.0f, 56100.0f, 57675.0f, 59250.0f, 60825.0f, 62400.0f, 70275.0f, 71850.0f, 73425.0f, 75000.0f, 76575.0f, 78150.0f, 86025.0f, 87600.0f, 89175.0f, 90750.0f, 92325.0f, 93900.0f, 101775.0f, 103350.0f, 104925.0f, 106500.0f, 108075.0f, 109650.0f, 117525.0f, 119100.0f, 120675.0f, 122250.0f, 123825.0f, 125400.0f, 67525.0f, 70350.0f, 73175.0f, 76000.0f, 78825.0f, 81650.0f, 95775.0f, 98600.0f, 101425.0f, 104250.0f, 107075.0f, 109900.0f, 124025.0f, 126850.0f, 129675.0f, 132500.0f, 135325.0f, 138150.0f, 152275.0f, 155100.0f, 157925.0f, 160750.0f, 163575.0f, 166400.0f, 180525.0f, 183350.0f, 186175.0f, 189000.0f, 191825.0f, 194650.0f, 208775.0f, 211600.0f, 214425.0f, 217250.0f, 220075.0f, 222900.0f, 119400.0f, 120350.0f, 121300.0f, 122250.0f, 123200.0f, 124150.0f, 128900.0f, 129850.0f, 130800.0f, 131750.0f, 132700.0f, 133650.0f, 138400.0f, 139350.0f, 140300.0f, 141250.0f, 142200.0f, 143150.0f, 147900.0f, 148850.0f, 149800.0f, 
150750.0f, 151700.0f, 152650.0f, 157400.0f, 158350.0f, 159300.0f, 160250.0f, 161200.0f, 162150.0f, 166900.0f, 167850.0f, 168800.0f, 169750.0f, 170700.0f, 171650.0f, 273150.0f, 275350.0f, 277550.0f, 279750.0f, 281950.0f, 284150.0f, 295150.0f, 297350.0f, 299550.0f, 301750.0f, 303950.0f, 306150.0f, 317150.0f, 319350.0f, 321550.0f, 323750.0f, 325950.0f, 328150.0f, 339150.0f, 341350.0f, 343550.0f, 345750.0f, 347950.0f, 350150.0f, 361150.0f, 363350.0f, 365550.0f, 367750.0f, 369950.0f, 372150.0f, 383150.0f, 385350.0f, 387550.0f, 389750.0f, 391950.0f, 394150.0f, 426900.0f, 430350.0f, 433800.0f, 437250.0f, 440700.0f, 444150.0f, 461400.0f, 464850.0f, 468300.0f, 471750.0f, 475200.0f, 478650.0f, 495900.0f, 499350.0f, 502800.0f, 506250.0f, 509700.0f, 513150.0f, 530400.0f, 533850.0f, 537300.0f, 540750.0f, 544200.0f, 547650.0f, 564900.0f, 568350.0f, 571800.0f, 575250.0f, 578700.0f, 582150.0f, 599400.0f, 602850.0f, 606300.0f, 609750.0f, 613200.0f, 616650.0f, 75025.0f, 75350.0f, 75675.0f, 76000.0f, 76325.0f, 76650.0f, 78275.0f, 78600.0f, 78925.0f, 79250.0f, 79575.0f, 79900.0f, 81525.0f, 81850.0f, 82175.0f, 82500.0f, 82825.0f, 83150.0f, 84775.0f, 85100.0f, 85425.0f, 85750.0f, 86075.0f, 86400.0f, 88025.0f, 88350.0f, 88675.0f, 89000.0f, 89325.0f, 89650.0f, 91275.0f, 91600.0f, 91925.0f, 92250.0f, 92575.0f, 92900.0f, 353775.0f, 355350.0f, 356925.0f, 358500.0f, 360075.0f, 361650.0f, 369525.0f, 371100.0f, 372675.0f, 374250.0f, 375825.0f, 377400.0f, 385275.0f, 386850.0f, 388425.0f, 390000.0f, 391575.0f, 393150.0f, 401025.0f, 402600.0f, 404175.0f, 405750.0f, 407325.0f, 408900.0f, 416775.0f, 418350.0f, 419925.0f, 421500.0f, 423075.0f, 424650.0f, 432525.0f, 434100.0f, 435675.0f, 437250.0f, 438825.0f, 440400.0f, 632525.0f, 635350.0f, 638175.0f, 641000.0f, 643825.0f, 646650.0f, 660775.0f, 663600.0f, 666425.0f, 669250.0f, 672075.0f, 674900.0f, 689025.0f, 691850.0f, 694675.0f, 697500.0f, 700325.0f, 703150.0f, 717275.0f, 720100.0f, 722925.0f, 725750.0f, 728575.0f, 731400.0f, 745525.0f, 748350.0f, 751175.0f, 754000.0f, 756825.0f, 759650.0f, 773775.0f, 776600.0f, 779425.0f, 782250.0f, 785075.0f, 787900.0f, 309400.0f, 310350.0f, 311300.0f, 312250.0f, 313200.0f, 314150.0f, 318900.0f, 319850.0f, 320800.0f, 321750.0f, 322700.0f, 323650.0f, 328400.0f, 329350.0f, 330300.0f, 331250.0f, 332200.0f, 333150.0f, 337900.0f, 338850.0f, 339800.0f, 340750.0f, 341700.0f, 342650.0f, 347400.0f, 348350.0f, 349300.0f, 350250.0f, 351200.0f, 352150.0f, 356900.0f, 357850.0f, 358800.0f, 359750.0f, 360700.0f, 361650.0f, 713150.0f, 715350.0f, 717550.0f, 719750.0f, 721950.0f, 724150.0f, 735150.0f, 737350.0f, 739550.0f, 741750.0f, 743950.0f, 746150.0f, 757150.0f, 759350.0f, 761550.0f, 763750.0f, 765950.0f, 768150.0f, 779150.0f, 781350.0f, 783550.0f, 785750.0f, 787950.0f, 790150.0f, 801150.0f, 803350.0f, 805550.0f, 807750.0f, 809950.0f, 812150.0f, 823150.0f, 825350.0f, 827550.0f, 829750.0f, 831950.0f, 834150.0f, 1116900.0f, 1120350.0f, 1123800.0f, 1127250.0f, 1130700.0f, 1134150.0f, 1151400.0f, 1154850.0f, 1158300.0f, 1161750.0f, 1165200.0f, 1168650.0f, 1185900.0f, 1189350.0f, 1192800.0f, 1196250.0f, 1199700.0f, 1203150.0f, 1220400.0f, 1223850.0f, 1227300.0f, 1230750.0f, 1234200.0f, 1237650.0f, 1254900.0f, 1258350.0f, 1261800.0f, 1265250.0f, 1268700.0f, 1272150.0f, 1289400.0f, 1292850.0f, 1296300.0f, 1299750.0f, 1303200.0f, 1306650.0f,}; diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp index a16d9cfbd..989d316de 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp +++ 
b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp @@ -970,7 +970,6 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_6) { x.linspace(1); - nd4j::ops::maxpool2d op; auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 1, 1}); @@ -991,7 +990,6 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_7) { x.linspace(1); - nd4j::ops::maxpool2d op; auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1}); @@ -1012,7 +1010,6 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_8) { x.linspace(1); - nd4j::ops::maxpool2d op; auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 0}); @@ -1467,11 +1464,12 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test1) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.6f, 0.f, 2.7f, 2.8f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.9f, 3.f, 0.f, 3.1f, 3.2f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.3f, 3.4f, 0.f, 3.5f, 3.6f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 0.f, 3.9f, 4.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.1f, 4.2f, 0.f, 4.3f, 4.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.5f, 4.6f, 0.f, 4.7f, 4.8f}); + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.6f, 0.f, 2.7f, 2.8f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.9f, 3.f, 0.f, 3.1f, 3.2f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.3f, 3.4f, 0.f, 3.5f, 3.6f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 0.f, 3.9f, 4.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.1f, 4.2f, 0.f, 4.3f, 4.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.5f, 4.6f, 0.f, 4.7f, 4.8f}); + input.linspace(1.); gradO.linspace(0.1, 0.1); diff --git a/libnd4j/tests_cpu/layers_tests/CuDnnTests.cu b/libnd4j/tests_cpu/layers_tests/CuDnnTests.cu index 8809ad894..02e1040aa 100644 --- a/libnd4j/tests_cpu/layers_tests/CuDnnTests.cu +++ b/libnd4j/tests_cpu/layers_tests/CuDnnTests.cu @@ -57,6 +57,17 @@ TEST_F(CuDnnTests, helpers_includer) { nd4j::ops::platforms::PLATFORM_depthwise_conv2d_ENGINE_CUDA depthwise_conv2d; nd4j::ops::platforms::PLATFORM_depthwise_conv2d_bp_ENGINE_CUDA depthwise_conv2d_bp; nd4j::ops::platforms::PLATFORM_batchnorm_ENGINE_CUDA batchnorm; + 
nd4j::ops::platforms::PLATFORM_batchnorm_bp_ENGINE_CUDA batchnorm_bp; + nd4j::ops::platforms::PLATFORM_avgpool2d_ENGINE_CUDA avgpool2d; + nd4j::ops::platforms::PLATFORM_avgpool2d_bp_ENGINE_CUDA avgpool2d_bp; + nd4j::ops::platforms::PLATFORM_maxpool2d_ENGINE_CUDA maxpool2d; + nd4j::ops::platforms::PLATFORM_maxpool2d_bp_ENGINE_CUDA maxpool2d_bp; + nd4j::ops::platforms::PLATFORM_avgpool3dnew_ENGINE_CUDA avgpool3dnew; + nd4j::ops::platforms::PLATFORM_avgpool3dnew_bp_ENGINE_CUDA avgpool3dnew_bp; + nd4j::ops::platforms::PLATFORM_maxpool3dnew_ENGINE_CUDA maxpool3dnew; + nd4j::ops::platforms::PLATFORM_maxpool3dnew_bp_ENGINE_CUDA maxpool3dnew_bp; + + printer({&conv2d}); printer({&conv2d_bp}); @@ -65,6 +76,15 @@ TEST_F(CuDnnTests, helpers_includer) { printer({&depthwise_conv2d}); printer({&depthwise_conv2d_bp}); printer({&batchnorm}); + printer({&batchnorm_bp}); + printer({&avgpool2d}); + printer({&avgpool2d_bp}); + printer({&maxpool2d}); + printer({&maxpool2d_bp}); + printer({&avgpool3dnew}); + printer({&avgpool3dnew_bp}); + printer({&maxpool3dnew}); + printer({&maxpool3dnew_bp}); #endif } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index ee569a07c..18f58c2a1 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -25,6 +25,7 @@ #include #include #include +#include using namespace nd4j; @@ -2247,3 +2248,525 @@ TEST_F(DeclarableOpsTests13, batchnorm_test9) { delete results; } + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test1) { + + NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.1, 1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,3,4}, {-0.000056, -0.000056, -0.000056, -0.000056, -0.000034, -0.000034, -0.000034, -0.000034, -0.000011, -0.000011, -0.000011, -0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000034, 0.000034, 0.000034, 0.000034, 0.000056, 0.000056, 0.000056, 0.000056}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {6.148104, 6.148104, 6.148105, 6.148105}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + variance.assign(0.46666667); + gamma.assign(1.2); + beta.assign(1.); // has no effect on gradient calculations + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test2) { + + NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {3}, {1.05, 1.1, 1.15}, nd4j::DataType::FLOAT32); + NDArray variance('c', {3}, {0.5, 0.6, 
0.7}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {3}, {1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {3}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,3,4}, {-0.601415, -0.521226, -0.441037, -0.360849, -0.456306, -0.395465, -0.334624, -0.273784, 0.396631, 0.343747, + 0.290863, 0.237978, 0.360849, 0.441037, 0.521226, 0.601415, 0.273784, 0.334625, 0.395465, 0.456306, -0.237978, + -0.290863, -0.343746, -0.396631}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {3}, {5.81236 , 7.048771, 12.155388}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + // beta.assign(1.); // has no effect on gradient calculations + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test3) { + + NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}, nd4j::DataType::FLOAT32); + NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {2,1,4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,3,4}, {-0.577002, -0.744041, -0.850999, -0.922373, -0.000000, -0.000000, -0.000000, -0.000000, 0.577002, + 0.744041, 0.850999, 0.922373, -0.386037, -0.350205, -0.312047, -0.271737, -0.000000, -0.000000, + -0.000000, -0.000000, 0.386037, 0.350205, 0.312047, 0.271736}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234 }, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {2,1,4}, {-0.9 , -0.45, 0. 
, 0.45, 4.5 , 4.95, 5.4 , 5.85}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + // beta.assign(1.); // has no effect on gradient calculations + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,0,2}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test4) { + + NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,4}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,4}, {0.162923, -0.289673, 0.354174, -0.386151, -0.162923, 0.289673, -0.354174, 0.386151}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {1.442483, 0.950200, 0.569207, 0.314641}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test5) { + +#if defined(HAVE_CUDNN) +return; +#endif + NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,4,2,2}, {-0.737512, -0.659880, -0.582247, -0.504614, 0.561404, 0.502309, 0.443214, 0.384118, -1.168243, + -1.045270, -0.922297, -0.799324, 1.899026, 1.699128, 1.499231, 1.299333, 0.504614, 0.582247, 0.659880, 0.737512, -0.384118, + -0.443214, -0.502308, -0.561404, 0.799324, 0.922297, 1.045270, 1.168243, -1.299334, -1.499231, -1.699129, -1.899026}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {4.2, 9. 
, 13.8, 18.6}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test6) { + +#if defined(HAVE_CUDNN) +return; +#endif + + NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,2,2,4}, {-4.989124, 2.540357, -1.515022, 0.791769, -3.563660, 1.814540, -1.082159, 0.565549, -2.138196, 1.088724, -0.649295, + 0.339329, -0.712732, 0.362908, -0.216432, 0.113110, 0.712732, -0.362908, 0.216432, -0.113110, 2.138195, -1.088724, 0.649295, + -0.339330, 3.563660,-1.814540, 1.082159, -0.565549, 4.989125, -2.540356, 1.515022, -0.791770}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {9.6, 10.8, 12. 
, 13.2}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test7) { + +#if defined(HAVE_CUDNN) +return; +#endif + + NDArray input ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,2,2,2,4}, {-119.435059, 78.159744, -58.732986, 46.630123, -103.510391, 67.738441, -50.901920, 40.412773, -87.585716, 57.317142, + -43.070854, 34.195419, -71.661041, 46.895844, -35.239792, 27.978071, -55.736359, 36.474548, -27.408726, 21.760721, -39.811687, 26.053242, -19.577662, + 15.543370, -23.887009, 15.631950, -11.746595, 9.326023, -7.962326, 5.210644, -3.915531, 3.108671, 7.962341, -5.210655, 3.915535, -3.108677, 23.887032, + -15.631958, 11.746601, -9.326031, 39.811691, -26.053246, 19.577671, -15.543377, 55.736382, -36.474548, 27.408726, -21.760731, 71.661064, -46.895851, 35.239788, + -27.978077, 87.585732, -57.317154, 43.070866, -34.195431, 103.510384, -67.738464, 50.901920, -40.412777, 119.435097, -78.159744, 58.732998, -46.630131}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {282.38734 , 244.542027, 224.140995, 207.548793}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {57.6, 60. 
, 62.4, 64.8}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,4}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + // dLdI->printBuffer(); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test8) { + +#if defined(HAVE_CUDNN) +return; +#endif + + NDArray input ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,4,2,2,2}, {-34.373802, -32.611046, -30.848286, -29.085529, -27.322769, -25.560009, -23.797251, -22.034491, 36.146996, 34.293301, + 32.439610, 30.585917, 28.732227, 26.878534, 25.024841, 23.171150, -42.876553, -40.677757, -38.478958, -36.280159, -34.081367, -31.882565, -29.683767, + -27.484968, 50.674446, 48.075760, 45.477066, 42.878380, 40.279686, 37.681000, 35.082310, 32.483616, 22.034489, 23.797249, 25.560009, 27.322765, 29.085526, + 30.848286, 32.611046, 34.373802, -23.171146, -25.024837, -26.878536, -28.732231, -30.585918, -32.439613, -34.293297, -36.146996, 27.484982, 29.683773, + 31.882572, 34.081364, 36.280178, 38.478970, 40.677776, 42.876560, -32.483627, -35.082329, -37.681023, -40.279701, -42.878403, -45.477081, -48.075775, -50.674484}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + // dLdI->printBuffer(); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test9) { + + NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,4,2,2}, {0.032378, 0.028967, 0.025558, 0.022147, -0.035056, -0.031364, -0.027669, -0.024006, 0.037742, 0.033766, 0.029791, 
0.025818,
+                                   -0.040429, -0.036172, -0.031913, -0.027656, -0.022155, -0.025564, -0.028974, -0.032359, 0.023982, 0.027677, 0.031373, 0.035063,
+                                   -0.025822, -0.029794, -0.033770, -0.037747, 0.027653, 0.031913, 0.036168, 0.040426}, nd4j::DataType::FLOAT32);
+ NDArray expdLdG('c', {4}, {9.685875, 9.685880, 9.685887, 9.685891}, nd4j::DataType::FLOAT32);
+ NDArray expdLdB('c', {4}, {4.2, 9. , 13.8, 18.6}, nd4j::DataType::FLOAT32);
+
+ input.linspace(1,0.01);
+ gradO.linspace(-0.9, 0.15);
+
+ // calculate mean and variance of input
+ PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test9");
+ std::vector<int> dimensions = {0,2,3};
+ int* dims = reinterpret_cast<int*>(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int)));
+ input.reduceAlongDimension(nd4j::reduce::Mean, mean, dimensions);
+ NDArray::prepareSpecialUse({&variance}, {&input});
+ auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimensions);
+ NativeOpExecutioner::execSummaryStats(input.getContext(), 0,input.getBuffer(), input.getShapeInfo(),input.getSpecialBuffer(), input.getSpecialShapeInfo(),nullptr,variance.getBuffer(), variance.getShapeInfo(),variance.getSpecialBuffer(), variance.getSpecialShapeInfo(), dims, dimensions.size(),packX.platformShapeInfo(), packX.platformOffsets(),false);
+ manager.synchronize();
+ NDArray::registerSpecialUse({&variance}, {&input});
+
+ nd4j::ops::batchnorm_bp op;
+
+ auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
+
+ ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+ auto dLdI = results->at(0);
+ auto dLdG = results->at(3);
+ auto dLdB = results->at(4);
+
+ ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
+ ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4));
+
+ ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
+ ASSERT_TRUE(expdLdG.equalsTo(dLdG));
+
+ ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
+ ASSERT_TRUE(expdLdB.equalsTo(dLdB));
+
+ delete results;
+}
+
+////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests13, batchnorm_bp_test10) {
+
+ NDArray input   ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
+ NDArray mean    ('c', {4}, nd4j::DataType::FLOAT32);
+ NDArray variance('c', {4}, nd4j::DataType::FLOAT32);
+ NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
+ NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
+ NDArray gradO   ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
+
+ NDArray expdLdI('c', {2,2,2,4}, {0.032634, -0.035423, 0.038110, -0.040864, 0.023302, -0.025294, 0.027213, -0.029205, 0.013996, -0.015192, 0.016343,
+                                  -0.017519, 0.004664, -0.005062, 0.005445, -0.005833, -0.004668, 0.005067, -0.005452, 0.005824, -0.013974, 0.015171,
+                                  -0.016325, 0.017508, -0.023309, 0.025301, -0.027221, 0.029197, -0.032639, 0.035428, -0.038118, 0.040878}, nd4j::DataType::FLOAT32);
+ NDArray expdLdG('c', {4}, {10.991656, 10.991631, 10.991643, 10.991632}, nd4j::DataType::FLOAT32);
+ NDArray expdLdB('c', {4}, {9.6, 10.8, 12., 13.2}, nd4j::DataType::FLOAT32);
+
+ input.linspace(1,0.01);
+ gradO.linspace(-0.9, 0.15);
+
+ // calculate mean and variance of input
+ PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test10");
+ std::vector<int> dimensions = {0,1,2};
+ int* dims = reinterpret_cast<int*>(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int)));
+ input.reduceAlongDimension(nd4j::reduce::Mean, mean, dimensions);
+ NDArray::prepareSpecialUse({&variance}, {&input});
+ auto packX =
nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimensions); + NativeOpExecutioner::execSummaryStats(input.getContext(), 0,input.getBuffer(), input.getShapeInfo(),input.getSpecialBuffer(), input.getSpecialShapeInfo(),nullptr,variance.getBuffer(), variance.getShapeInfo(),variance.getSpecialBuffer(), variance.getSpecialShapeInfo(), dims, dimensions.size(),packX.platformShapeInfo(), packX.platformOffsets(),false); + manager.synchronize(); + NDArray::registerSpecialUse({&variance}, {&input}); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test11) { + + NDArray input ('c', {2,3,4,5}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {1,3,4,5}, nd4j::DataType::FLOAT32); + NDArray variance('c', {1,3,4,5}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {1,3,4,5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {1,3,4,5}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,3,4,5}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,3,4,5}, {0.004981, 0.004818, 0.004652, 0.004483, 0.004319, 0.004153, 0.003985, 0.003832, 0.003661, 0.003505, 0.003340, 0.003171, 0.003001, 0.002837, + 0.002670, 0.002505, 0.002337, 0.002167, 0.002003, 0.001835, 0.001666, 0.001499, 0.001327, 0.001162, 0.000996, 0.000830, 0.000664, 0.000498, + 0.000332, 0.000166, -0.0, -0.000166, -0.000333, -0.000500, -0.000668, -0.000835, -0.001003, -0.001168, -0.001337, -0.001502, -0.001670, + -0.001838, -0.002003, -0.002172, -0.002330, -0.002499, -0.002669, -0.002832, -0.003002, -0.003162, -0.003332, -0.003495, -0.003665, -0.003821, + -0.004001, -0.004163, -0.004324, -0.004516, -0.004678, -0.004851, -0.004981, -0.004818, -0.004652, -0.004483, -0.004319, -0.004151, -0.003985, + -0.003836, -0.003661, -0.003505, -0.003338, -0.003171, -0.003004, -0.002837, -0.002670, -0.002503, -0.002337, -0.002170, -0.002003, -0.001835, + -0.001664, -0.001499, -0.001328, -0.001162, -0.000996, -0.000829, -0.000664, -0.000498, -0.000332, -0.000166, 0.0, 0.000166, 0.000334, + 0.000500, 0.000668, 0.000834, 0.001003, 0.001170, 0.001337, 0.001502, 0.001669, 0.001838, 0.002005, 0.002172, 0.002330, 0.002496, 0.002669, + 0.002836, 0.003002, 0.003162, 0.003328, 0.003495, 0.003670, 0.003828, 0.003992, 0.004158, 0.004324, 0.004522, 0.004689, 0.004843}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {1,3,4,5}, {8.999503, 8.999502, 8.999502, 8.999503, 8.999502, 8.999503, 8.999503, 8.999499, 8.999501, 8.999498, 8.999498, 8.999498, 8.999498, 8.999498, 8.999498, + 8.999498, 8.999498, 8.999498, 8.999498, 8.999499, 8.999501, 8.999500, 8.999503, 8.999503, 8.999503, 8.999504, 8.999503, 8.999503, 8.999504, 8.999503, + 8.999504, 8.999504, 8.999499, 8.999500, 8.999497, 8.999498, 8.999496, 8.999496, 8.999496, 8.999498, 8.999498, 8.999496, 8.999496, 8.999496, 8.999501, + 8.999501, 8.999499, 8.999499, 8.999499, 8.999501, 8.999501, 8.999501, 8.999499, 8.999500, 8.999501, 8.999501, 8.999501, 8.999495, 8.999495, 8.999497}, 
nd4j::DataType::FLOAT32);
+ NDArray expdLdB('c', {1,3,4,5}, {7.2, 7.5, 7.8, 8.1, 8.4, 8.7, 9.0, 9.3, 9.6, 9.9, 10.2, 10.5, 10.8, 11.1, 11.4, 11.7, 12.0, 12.3, 12.6, 12.9, 13.2, 13.5, 13.8, 14.1, 14.4, 14.7, 15.0,
+                                  15.3, 15.6, 15.9, 16.2, 16.5, 16.8, 17.1, 17.4, 17.7, 18.0, 18.3, 18.6, 18.9, 19.2, 19.5, 19.8, 20.1, 20.4, 20.7, 21.0, 21.3, 21.6, 21.9, 22.2, 22.5,
+                                  22.8, 23.1, 23.4, 23.7, 24.0, 24.3, 24.6, 24.9}, nd4j::DataType::FLOAT32);
+
+ input.linspace(1,0.01);
+ gradO.linspace(-0.9, 0.15);
+ gamma.linspace(-3, 0.1);
+
+ // calculate mean and variance of input
+ PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test11");
+ std::vector<int> dimensions = {0};
+ int* dims = reinterpret_cast<int*>(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int)));
+ input.reduceAlongDimension(nd4j::reduce::Mean, mean, dimensions, true);
+ NDArray::prepareSpecialUse({&variance}, {&input});
+ auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimensions);
+ NativeOpExecutioner::execSummaryStats(input.getContext(), 0,input.getBuffer(), input.getShapeInfo(),input.getSpecialBuffer(), input.getSpecialShapeInfo(),nullptr,variance.getBuffer(), variance.getShapeInfo(),variance.getSpecialBuffer(), variance.getSpecialShapeInfo(), dims, dimensions.size(),packX.platformShapeInfo(), packX.platformOffsets(),false);
+ manager.synchronize();
+ NDArray::registerSpecialUse({&variance}, {&input});
+
+ nd4j::ops::batchnorm_bp op;
+
+ auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1, 1,2,3});
+
+ ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+ auto dLdI = results->at(0);
+ auto dLdG = results->at(3);
+ auto dLdB = results->at(4);
+
+ ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
+ ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4));
+
+ ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
+ ASSERT_TRUE(expdLdG.equalsTo(dLdG));
+
+ ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
+ ASSERT_TRUE(expdLdB.equalsTo(dLdB));
+
+ delete results;
+}
\ No newline at end of file
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp
index 75db5989c..84dd5d732 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp
@@ -72,71 +72,6 @@ TEST_F(DeclarableOpsTests15, Test_Half_assign_1) {
 ASSERT_EQ(10, x.sumNumber().e(0));
 }
-TEST_F(DeclarableOpsTests15, test_avgpooling_edge_1) {
- int inOutH = 5;// 35;
- int inOutW = 5;// 35;
- int inOutC = 10;// 192;
-
- auto x = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC});
- x.linspace(1.0);
-
- nd4j::ops::avgpool2d op;
- auto result = op.execute({&x}, {}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1});
- ASSERT_EQ(Status::OK(), result->status());
-
- auto z = result->at(0);
-
- int totalPadHeight = (inOutH - 1) * 1 + 3 - inOutH;
- int padTop = totalPadHeight / 2;
- int padBottom = totalPadHeight - totalPadHeight / 2;
-
- int k = 3;
-
- auto m = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC});
- auto c = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC});
-
- for (int h = 0; h < inOutH; h++) {
- for (int w = 0; w < inOutW; w++) {
- int hFrom = h - padTop;
- int wFrom = w - padBottom;
-
- int hTo = hFrom + k;
- int wTo = wFrom + k;
-
- hFrom = nd4j::math::nd4j_max(0, hFrom);
- wFrom = nd4j::math::nd4j_max(0, wFrom);
-
- hTo = nd4j::math::nd4j_min(inOutH, hTo);
- wTo = nd4j::math::nd4j_min(inOutW, wTo);
-
- int idxOut[4];
- int idxIn[4];
- for
(int ch = 0; ch < inOutC; ch++) { - idxOut[1] = h; - idxOut[2] = w; - idxOut[3] = ch; - idxIn[3] = ch; - - for (int kh = hFrom; kh < hTo; kh++) { - for (int kw = wFrom; kw < wTo; kw++) { - idxIn[1] = kh; - idxIn[2] = kw; - - auto inVal = x.e(0, kh, kw, ch); - m.p(0, h, w, ch, inVal + m.e(0, h, w, ch)); - c.p(0, h, w, ch, 1 + c.e(0, h, w, ch)); - } - } - } - } - } - m /= c; - - ASSERT_EQ(m, *z); - - delete result; -} - TEST_F(DeclarableOpsTests15, Test_standarize_1) { auto x = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); auto e = NDArrayFactory::create('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f}); @@ -1097,7 +1032,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_2) { ASSERT_EQ(Status::OK(), result->status()); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + delete result; } @@ -1106,7 +1041,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_3) { // rank 2 NDArray rgbs('c', { 3, 4 }, { -9.4, 9.9, 9.7, 9.0, 1.14, 1.01, 1.11, 9.6, 1.05, 10.0, 1.03, 10.22 }, nd4j::DataType::FLOAT32); NDArray expected('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, nd4j::DataType::FLOAT32); - + nd4j::ops::rgb_to_yuv op; auto result = op.execute({ &rgbs }, {}, { 0 }); auto output = result->at(0); @@ -1170,7 +1105,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_7) { // rank 3 NDArray rgbs('f', { 2, 2, 3 }, { 1.7750e+01f,-7.1062e+01f, -1.0019e+02f, -2.3406e+01f,5.2094e+01f,9.5438e+01f, -6.7461e+00f,3.8562e+01f, 6.5078e+00f, 3.3562e+01f,-5.8844e+01f,2.2750e+01f }, nd4j::DataType::FLOAT32); NDArray expected('f', { 2,2,3 }, { 36.628319,38.600643, -40.624989,18.231001, -14.822637,-2.479566, -8.965780, 2.223851, -16.561626,- 96.205162,-52.255379, -36.527435 }, nd4j::DataType::FLOAT32); - + nd4j::ops::rgb_to_yuv op; auto result = op.execute({ &rgbs }, {}, {}); auto output = result->at(0); @@ -1210,7 +1145,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_2) { ASSERT_EQ(Status::OK(), result->status()); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + delete result; } @@ -1484,7 +1419,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test7) { auto Y = NDArrayFactory::create(2.f); NDArray x('c', { 2, 2, 2 }, nd4j::DataType::FLOAT32); NDArray dLdzC('c', { 2, 2, 2 }, nd4j::DataType::FLOAT32); - + dLdzC.linspace(0.1, 0.1); x = 4.f; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index dacfac127..2ef86710a 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -883,22 +883,6 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_6) { delete result; } -TEST_F(DeclarableOpsTests3, Test_AvgPool_1) { - auto x= NDArrayFactory::create('c', {2, 10, 10, 3}); - x.linspace(1); - - nd4j::ops::avgpool2d op; - // kY kX sY sX pY pX dY dX M P - auto result = op.execute({&x}, {}, {3, 3, 3, 3, 0, 0, 1, 1, 1, 0, 1}); - // 0 1 2 3 4 5 6 7 8 9 10 - auto z = result->at(0); - - // z->printShapeInfo("z shape"); - // z->printIndexedBuffer("z buffr"); - - delete result; -} - TEST_F(DeclarableOpsTests3, Test_ReverseDivide_1) { auto x= NDArrayFactory::create('c', {1, 3}, {2, 2, 2}); auto y= NDArrayFactory::create('c', {1, 3}, {4, 6, 8}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp index 1155c72de..9460a053f 100644 --- 
a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp @@ -56,7 +56,8 @@ public: typedef ::testing::Types TestingTypes; TYPED_TEST_CASE(TypedDeclarableOpsTests4, TestingTypes); -TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_1) { +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_1) { auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, 54.f, 55.f, 58.f, 59.f}); @@ -75,8 +76,8 @@ TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_1) { delete result; } - -TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_2) { +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_2) { auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, 54.f, 55.f, 58.f, 59.f}); @@ -96,7 +97,8 @@ TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_2) { delete result; } -TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_5) { +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_3) { auto x = NDArrayFactory::create('c', {2, 5, 5, 2}); auto exp = NDArrayFactory::create('c', {2, 3, 3, 2}, {7.f, 8.f, 11.f, 12.f, 14.f, 15.f, 27.f, 28.f, 31.f, 32.f, 34.f, 35.f, 42.f, 43.f, 46.f, 47.f, 49.f, 50.f, 57.f, 58.f, 61.f, 62.f, 64.f, 65.f, 77.f, 78.f, 81.f, 82.f, 84.f, 85.f, 92.f, 93.f, 96.f, 97.f, 99.f, 100.f,}); @@ -116,7 +118,8 @@ TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_5) { delete result; } -TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_6) { +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_4) { auto x = NDArrayFactory::create('c', {2, 5, 5, 2}); auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {7.f, 8.f, 11.f, 12.f, 27.f, 28.f, 31.f, 32.f, 57.f, 58.f, 61.f, 62.f, 77.f, 78.f, 81.f, 82.f}); @@ -135,8 +138,8 @@ TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_6) { delete result; } - -TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_8) { +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_5) { auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); auto exp = NDArrayFactory::create('c', {2, 2, 3, 3}, {1.f, 2.5f, 4.5f, 8.5f, 10.f, 12.f, 18.5f, 20.f, 22.f, 26.f, 27.5f, 29.5f, 33.5f, 35.f, 37.f, 43.5f, 45.f, 47.f, 51.f, 52.5f, 54.5f, 58.5f, 60.f, 62.f, 68.5f, 70.f, 72.f, 76.f, 77.5f, 79.5f, 83.5f, 85.f, 87.f, 93.5f, 95.f, 97.f}); @@ -156,8 +159,8 @@ TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_8) { delete result; } - -TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_9) { +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_6) { auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); auto exp = NDArrayFactory::create('c', {2, 2, 3, 3}, {0.25f, 1.25f, 2.25f, 4.25f, 10.f, 12.f, 9.25f, 20.f, 22.f, 6.5f, 13.75f, 14.75, 16.75f, 35.f, 37.f, 21.75f, 45.f, 47.f, 12.75f, 26.25f, 27.25f, 29.25f, 60.f, 62.f, 34.25f, 70.f, 72.f, 19.f, 38.75f, 39.75f, 41.75f, 85.f, 87.f, 46.75f, 95.f, 97.f}); @@ -177,8 +180,8 @@ TYPED_TEST(TypedDeclarableOpsTests4, 
Test_Pooling_Parity_9) { delete result; } - -TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_10) { +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_7) { auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); auto exp = NDArrayFactory::create('c', {2, 2, 3, 3}, {4.f, 6.f, 7.5f, 14.f, 16.f, 17.5f, 21.5f, 23.5f, 25.f, 29.f, 31.f, 32.5f, 39.f, 41.f, 42.5f, 46.5f, 48.5f, 50.f, 54.f, 56.f, 57.5f, 64.f, 66.f, 67.5f, 71.5f, 73.5f, 75.f, 79.f, 81.f, 82.5f, 89.f, 91.f, 92.5f, 96.5f, 98.5f, 100.f}); @@ -198,8 +201,8 @@ TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_10) { delete result; } - -TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_11) { +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_8) { auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); auto exp = NDArrayFactory::create('c', {1, 1, 2, 2}, {3.f, 4.f, 6.f, 7.f}); @@ -219,7 +222,8 @@ TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_11) { delete result; } -TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_12) { +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_9) { auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); auto exp = NDArrayFactory::create('c', {1, 1, 3, 3}, {3.f, 4.f, 4.5f, 6.f, 7.f, 7.5f, 7.5f, 8.5f, 9.f}); @@ -242,7 +246,139 @@ TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_12) { delete result; } +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_10) { + auto input = NDArrayFactory::create('c', {4, 10, 10, 3}, {9.37125111f, 2.20166993f, 2.91434479f, 5.43639755f, -2.10573769f, 4.08528662f, 5.86908436f, -4.46203756f, 2.21057916f, 5.35849190f, 0.01394637f, 4.40566349f, 7.07982206f, -0.09633455f, 2.42429352f, 3.97301817f, -1.89553940f, 1.99690318f, 6.33141708f, 0.55401880f, 1.70707977f, + 5.55204201f, -0.03513752f, 1.60011971f, 2.62700319f, -2.74582434f, 3.06697464f, 1.06277943f, -1.16075921f, -0.78095782f, 9.72352791f, -1.22686064f, 1.99644792f, 7.35571337f, 1.40607321f, 0.11390255f, 9.53334427f, 2.28303599f, -1.66728830f, 6.16678810f, -0.04532295f, -1.97708666f, 9.74906158f, 1.46223176f, -1.46734393f, 4.30761862f, + -1.23790228f, 1.24823606f, 6.13938427f, -3.83689475f, -1.19625473f, 7.91535568f, 6.05868721f, -3.22946382f, 8.81633949f, -0.19967777f, 0.66053957f, 2.30919123f, 0.74543846f, -0.39347672f, 11.11058044f, 0.53720862f, 1.52645731f, 5.70012379f, -1.15213466f, 1.16451406f, 7.00526333f, 1.57362783f, -2.44384766f, 5.54213285f, -1.98828590f, + -0.70483637f, 7.88281822f, -3.59875536f, 0.80745387f, 13.41578484f, -1.55507684f, -0.65855008f, 9.32583523f, -0.14544789f, 0.73436141f, 3.61176538f, -1.71268058f, -2.58490300f, 9.09280205f, -3.27405524f, -2.04569697f, 4.44761324f, -0.62955856f, -2.61917663f, 8.04890442f, 0.54579324f, 0.85929775f, 9.82259560f, -1.93825579f, 0.77703512f, + 4.67090321f, -4.79267597f, -2.38906908f, 9.31265545f, 0.96026313f, -1.14109385f, 11.54231834f, -0.01417295f, -0.39500344f, 8.49191666f, 0.55300158f, 2.79490185f, 6.92466164f, 1.72254205f, 2.82222271f, 8.83112717f, 2.95033407f, 2.18054962f, 6.73509789f, -2.22272944f, 0.51127720f, -1.04563558f, 2.15747333f, -2.30959272f, 9.55441570f, + 1.50396204f, 1.77370787f, 7.38146257f, -1.79076433f, 3.20961165f, 7.18864202f, 2.91217351f, 0.43018937f, 7.11078024f, -1.17386127f, -0.16817921f, 6.12327290f, -2.82205725f, 3.30696845f, 13.51291752f, 
-1.30856836f, -2.38332748f, 11.09487438f, -1.47190213f, -0.53050828f, 4.38285351f, -5.07309771f, 1.50714362f, 5.72274446f, -2.85825086f, + -0.89673209f, 3.73791552f, -0.67708802f, -4.13149452f, -0.00671843f, -0.26566532f, 0.32961160f, 7.14501762f, -1.41608179f, -4.96590328f, 12.26205540f, -0.65158135f, -0.88641000f, 6.95777559f, -0.79058206f, -0.10260171f, 7.87169170f, 1.35921454f, 1.11759663f, 5.46187401f, -2.57214499f, 2.48484039f, 4.04043484f, -2.07137156f, -1.42709637f, + 9.25487137f, -0.12605135f, -2.66949964f, 2.89412403f, 0.74451172f, -2.96250391f, 3.99258423f, 0.27084303f, 0.32213116f, 5.42332172f, -0.44414216f, 1.70881832f, 6.69346905f, 0.53058422f, -4.73146200f, 4.22051668f, 2.24834967f, 0.66996074f, 4.30173683f, 0.11849818f, -4.07520294f, 8.27318478f, -2.54398274f, -2.86705542f, 10.11775303f, + -0.99382895f, 0.65881538f, 7.93556786f, -1.27934420f, -1.69343162f, 9.68042564f, -1.02609646f, -1.18189347f, 5.75370646f, -1.67888868f, -4.48871994f, 4.79537392f, -0.79212248f, -0.19855022f, 6.15060997f, -0.01081491f, 3.64454579f, 10.82562447f, 1.58859253f, -2.65847278f, 8.60093212f, -1.59196103f, 0.07635692f, 11.76175690f, -1.17453325f, + 0.10122013f, 6.86458445f, -2.18891335f, -2.74004745f, 8.07066154f, 0.71818852f, -2.03035975f, 6.31053686f, 0.51509416f, 1.39789927f, 9.43515587f, 2.04256630f, 0.13985133f, 4.65010691f, 2.40911126f, -0.36255789f, -3.06867862f, -0.45225358f, -1.56778407f, 6.05917358f, -1.09891272f, 1.77184200f, 6.46248102f, 0.96042323f, -0.24346280f, + 4.63436460f, -4.69907761f, 1.25187206f, 11.46173859f, -2.21917558f, 1.28007793f, 6.92173195f, 2.11268163f, -3.47389889f, 5.08722782f, -3.03950930f, -4.17154264f, 11.30568314f, 0.80361372f, 2.53214502f, 7.18707085f, -4.49114513f, 2.85449266f, 10.14906883f, -0.31974933f, -0.84472644f, -0.52459574f, 0.12921631f, -1.81390119f, 2.76170087f, 1.03982210f, 2.91744232f, -0.29048753f, 5.87453508f, -1.53684759f, 1.85800636f, -0.91404629f, 1.28954852f, 5.11354685f, -2.47475505f, -1.33179152f, 2.58552408f, 1.37316465f, -3.32339454f, 1.54122913f, 3.24953628f, -0.29758382f, 2.82391763f, -1.51142192f, -1.22699404f, 6.75745535f, 0.65452754f, -3.29385471f, 2.06008053f, 2.53172946f, -4.23532820f, -1.53909743f, -0.07010663f, -1.42173731f, 7.29031610f, -0.18448229f, 4.59496164f, 6.73027277f, 0.73441899f, 0.14426160f, 4.14915276f, -2.97010231f, 6.05851364f, 4.95218086f, -2.39145470f, 2.40494704f, 2.10288811f, 0.53503096f, 1.44511235f, 6.66344261f, -3.05803776f, 7.21418667f, 3.30303526f, -0.24163735f, 3.47409391f, 3.64520788f, 2.15189481f, -3.11243272f, 3.62310791f, 0.37379482f, 0.40865007f, -0.83132005f, -4.78246069f, 2.07030797f, 6.51765442f, 3.16178989f, 5.06180477f, 3.78434467f, -0.96689719f, 0.35965276f, 5.89967585f, 1.40294051f, 1.11952639f, 10.59778214f, 0.26739889f, -1.61297631f, 6.24801159f, -0.93914318f, -0.57812452f, 9.92604542f, -0.73025000f, -3.38530874f, 2.45646000f, -2.47949195f, 0.51638460f, 10.65636063f, 1.97816694f, -3.00407791f, 2.66914415f, -0.81951088f, -0.23316640f, 2.40737987f, -2.70007610f, 1.51531935f, 4.08860207f, -0.27552786f, -1.31721711f, 7.11568260f, -3.33498216f, -4.02545023f, 7.22675610f, -0.81690705f, -2.52689576f, 1.04016697f, -0.79291463f, -0.34875512f, 10.00498390f, -4.24167728f, 1.46162593f, 11.82569408f, -1.70359993f, -0.30161047f, 16.44085884f, -0.82253462f, -0.09435523f, 6.13080597f, -0.20259480f, 0.68308711f, 6.15663004f, -6.61776876f, 0.33295766f, 2.55449438f, -0.17819691f, -1.14892209f, 5.56776142f, 1.99279118f, 1.33035934f, 4.45823956f, 3.34916544f, -2.59905386f, 6.16164446f, -2.03881931f, -2.45273542f, 
12.46793365f, -2.22743297f, 2.83738565f, 8.48628139f, -1.39347959f, -1.30867767f, 11.08041477f, -4.00363779f, 2.09183025f, 11.30395889f, -2.20504737f, 1.37426853f, 8.98735619f, 1.04676604f, -0.72757077f, 8.28050232f, -6.70741081f, -0.65798020f, 5.68592072f, -0.60760021f, 0.35854483f, 6.26852131f, 1.94100165f, 1.32112014f, 0.80987954f, -1.74617672f, -0.25434083f, 7.16045523f, 1.58884013f, -2.64847064f, 13.14820385f, 1.21393633f, -2.47258949f, 9.41650105f, -0.79384226f, 2.48954105f, 10.95629311f, 0.47723705f, 4.02126694f, 8.02593136f, -2.20726371f, -1.18794477f, 1.50836647f, 0.93118095f, -1.73513174f, 8.85493565f, -2.99670315f, -0.79055870f, 2.39473820f, 2.05046916f, -2.38055134f, 11.82299423f, 0.15609655f, 0.68744308f, 5.66401434f, -0.69281673f, 2.09855556f, 7.74626589f, -0.34283102f, 1.00542057f, 9.95838642f, 0.80161905f, 2.33455157f, 9.80057335f, -0.93561798f, 2.56991577f, 8.29711342f, 0.94213426f, 0.44209945f, 11.70259857f, 0.92710167f, 2.60957146f, 0.24971688f, -0.86529571f, 3.78628922f, 6.80884457f, -0.68178189f, 2.21103406f, 3.18895817f, 0.60283208f, -2.92716241f, 6.72060776f, -1.06625068f, 2.56543374f, 9.97404480f, 3.58080721f, -0.94936347f, 10.16736984f, -1.38464379f, 1.18191063f, 6.66179037f, -3.56115270f, 0.32329530f, 10.90870762f, 2.20638227f, 0.19653285f, 7.34650040f, -3.63859272f, -1.03027737f, 5.98829985f, -3.66606474f, -3.89746714f, 8.63469028f, 1.22569811f, 1.63240814f, 3.74385309f, 0.58243257f, -0.56981975f, 3.69260955f, 1.00979900f, -1.44030499f, 8.57058144f, -1.10648811f, 1.20474911f, 5.43133020f, -2.14822555f, -0.07928789f, 11.25825310f, 0.19645604f, -5.49546146f, 10.41917038f, -0.68178523f, -2.99639869f, 6.50054455f, 0.46488351f, -5.42328453f, 9.09500027f, -2.82107449f, 0.05601966f, 15.34610748f, -0.06820253f, 3.86699796f, 10.73316956f, -3.04795432f, -0.14702171f, 5.64813185f, 1.44028485f, -2.47596145f, 0.07280898f, -3.03187990f, -1.35183525f, 9.35835648f, 2.72966957f, 1.88199532f, 10.36187744f, -0.22834805f, -3.26738238f, 6.92025137f, -2.34061313f, 4.77379704f, 5.28559113f, -2.96323752f, -1.76186585f, 5.94436455f, 0.38647744f, -5.73869514f, 6.76849556f, 1.40892124f, -1.19068217f, 5.37919092f, -6.65328646f, 3.62782669f, 12.34744644f, 2.44762444f, -4.19242620f, 6.14906216f, 0.08121119f, 0.61355996f, 2.69666457f, -1.88962626f, -0.55314136f, 1.84937525f, 1.56048691f, 1.17460012f, 3.75674725f, 1.06198275f, -5.74625874f, 5.41645575f, -1.28946674f, -1.51689398f, 4.32400894f, -0.05222082f, -4.83948946f, 1.80747867f, 1.63144708f, -2.73887825f, 1.63975775f, -2.02163982f, -0.16210437f, 2.93518686f, 1.14427686f, -2.83246303f, 4.79283667f, 2.69697428f, -3.12678456f, -1.19225168f, -2.37022972f, -3.09429741f, 1.94225383f, -1.13747168f, -2.55048585f, 5.40242243f, 1.12777328f, 3.43713188f, 3.62658787f, -2.16878843f, 0.30164462f, 2.97407579f, -0.07275413f, -1.31149673f, 4.70066261f, -2.01323795f, 4.85255766f, 4.59128904f, 1.68084168f, 1.60336494f, 6.58138466f, -1.04759812f, 2.69906545f, 3.55769277f, -0.74327278f, 2.65819693f, 5.39528131f, 2.11248922f, -1.06446671f, 5.24546766f, -2.43146014f, 4.58907509f, 0.06521678f, -2.24503994f, 2.45722699f, 6.94863081f, 0.35258654f, 2.83396196f, 9.92525196f, -1.12225175f, -0.34365177f, 7.19116688f, -4.39813757f, 0.46517885f, 13.22028065f, -2.57483673f, -6.37226963f, 7.58046293f, -2.74600363f, 0.42231262f, 8.04881668f, 0.17289802f, -0.53447008f, 16.55157471f, -5.63614368f, 0.39288223f, 3.37079263f, 1.26484549f, -0.12820500f, 8.46440125f, -4.39304399f, 2.97676420f, 0.65650189f, 0.83158541f, -1.11556435f, 6.32885838f, -0.36087769f, 2.80724382f, 
9.90292645f, 1.15936041f, 0.20947981f, 6.91249275f, -2.67404819f, 2.93782163f, 6.65656614f, -2.30828357f, 2.98214006f, 6.80611229f, -4.93821478f, -7.66555262f, 7.59763002f, -0.54159302f, 3.87403512f, 12.42607784f, 2.59284401f, -0.23375344f, 8.95293331f, -0.71807784f, 0.61873478f, 8.66713524f, 1.24289191f, -2.37835455f, 2.08071637f, -0.88315344f, -3.41891551f, 6.85245323f, 1.73007369f, 1.02169311f, 7.69170332f, -2.85411978f, 2.69790673f, 8.12906551f, -1.19351399f, -2.26442742f, 12.26104450f, -0.75579089f, -1.73274946f, 10.68729019f, 2.20655656f, -0.90522075f, 12.42165184f, -1.67929137f, 2.44851565f, 9.31565762f, -0.06645700f, 1.52762020f, 6.18427515f, -1.68882596f, 3.70261097f, 3.02252960f, -3.44125366f, -1.31575799f, 2.84617424f, -0.96849400f, -4.52356243f, 9.95027161f, 0.19966406f, -0.78874779f, 8.18595028f, -4.08300209f, 1.75126517f, 0.96418417f, -4.04913044f, -0.95200396f, 12.03637886f, -0.03041124f, 0.41642749f, 8.88267422f, -3.24985337f, -2.24919462f, 7.32566118f, 0.16964148f, -2.74123430f, 7.05264473f, -3.30191112f, 0.17163286f, 4.81851053f, -1.64463484f, -0.85933101f, 7.29276276f, 2.34066939f, -2.14860010f, 3.46148157f, -0.01782012f, 1.51504040f, 4.79304934f, 1.85281146f, -1.70663762f, 6.93470192f, -4.15440845f, -1.25983095f, 10.52491760f, 0.42930329f, -1.85146868f, 11.70042324f, -0.41704914f, 3.83796859f, 9.21148491f, -2.79719448f, 0.79470479f, 6.26926661f, -5.85230207f, 3.95105338f, 7.84790897f, -1.38680744f, -1.78099084f, 11.95235348f, -2.99841452f, -1.34507811f, 6.15714645f, -1.07552516f, -2.81228638f, 1.66234732f, -4.55166149f, -1.92601109f, 8.64634514f, -0.48158705f, 3.31595659f, 7.67371941f, 2.56964207f, 0.12107098f, 4.56467867f, -0.93541539f, 1.39432955f, 11.99714088f, 1.05353570f, -2.13099813f, 3.67617917f, 3.45895386f, 1.37365830f, 8.74344158f, -4.17585802f, 1.43908918f, 6.28764772f, 3.97346330f, -0.69144285f, 9.07983303f, -0.41635889f, -0.14965028f, 8.85469818f, 1.11306190f, 2.59440994f, 5.38982344f, -1.07948279f, 1.37252975f, 10.26984596f, -0.09318046f, 2.73104119f, 12.45902252f, -1.55446684f, -2.76124811f, 12.19395065f, -0.51846564f, 1.02764034f, 11.42673588f, -0.95940983f, -0.04781032f, 8.78379822f, -4.88957930f, 0.32534006f, 11.97696400f, -3.35108662f, 1.95104563f, 4.46915388f, -2.32061648f, 3.45230985f, 8.29983711f, 2.81034684f, -2.35529327f, 6.07801294f, -0.98105043f, -0.05359888f, 2.52291036f, -0.01986909f, -2.35321999f, 10.51954269f, 2.11145401f, 3.53506470f, 7.29093266f, 0.03721160f, -1.13496494f, 7.43886709f, -5.84201956f, 2.50796294f, 12.14647675f, 2.77490377f, -2.18896222f, 6.05641937f, 5.32617044f, 1.04221284f, 10.79106712f, -2.95749092f, -2.75414610f, 11.30037117f, -3.40654182f, -2.24673963f, 7.49126101f, 0.70811015f, -6.18003702f, 13.83951187f, -1.01204085f, 1.36298490f, -1.04451632f, 2.42435336f, -0.02346706f, -0.85528886f, 1.04731262f, 0.22192979f, 4.15708160f, 0.34933877f, 0.04814529f, 2.24107265f, 0.49676740f, -1.47752666f, 0.45040059f, -0.70471478f, -1.19759345f, 0.21711677f, 0.88461423f, -2.76830935f, 5.52066898f, 1.97664857f, -1.75381601f, 3.45877838f, 1.52617192f, -1.61350942f, 0.85337949f, 1.97610760f, -3.40310287f, 3.40319014f, -3.38691044f, -0.71319139f, 1.65463758f, -0.60680127f, -1.80700517f, 8.02592373f, 2.59627104f, 2.65895891f, 5.93043184f, -4.48425817f, 3.92670918f, 4.19496679f, -2.28286791f, 6.41634607f, 5.72330523f, 1.16269672f, -0.28753027f, 2.46342492f, 0.36693189f, 0.26712441f, 6.37652683f, -2.50139046f, 2.43923736f, 5.56310415f, 0.98065847f, 1.04267502f, 4.16403675f, -0.04966142f, 4.40897894f, 3.72905660f, -3.46129870f, 3.59962773f, 
1.34830284f, -1.76661730f, 0.47943926f, 5.29946661f, -1.12711561f, 1.26970029f, 15.17655945f, -1.50971997f, 5.81345224f, 8.48562050f, -4.36049604f, 2.48144460f, 8.23780441f, -3.46030426f, -0.84656560f, 5.94946814f, 1.12747943f, -2.65683913f, 8.69085693f, 1.31309867f, -2.79958344f, 8.76840591f, -1.56444156f, 1.62710834f, 2.41177034f, -0.72804940f, 5.70619011f, 4.67169666f, -0.86167198f, -1.83803177f, 2.96346045f, 2.82692933f, -2.81557131f, 7.11113358f, -1.90071094f, 2.54244423f, 11.19284058f, -0.06298946f, -1.71517313f, 12.98388577f, 0.84510714f, 3.00816894f, 2.57200313f, 0.03899818f, -1.49330592f, 9.60099125f, -3.59513044f, -1.30045319f, 7.09241819f, -0.65233821f, -2.33627677f, 8.81366920f, 0.84154201f, 1.03312039f, 9.85289097f, 0.19351870f, 1.78496623f, 7.34631205f, -2.16530800f, -0.65016162f, 2.46842360f, 0.24016285f, -1.24308395f, 4.78175163f, -0.97682536f, 2.20942235f, 6.68382788f, 3.76786447f, -1.44454038f, 6.26453733f, -3.23575711f, -2.30137897f, 9.53092670f, -5.55222607f, 3.25999236f, 9.37559509f, 1.86339056f, -0.23551451f, 10.23400211f, 3.93031883f, -0.52629089f, 7.85724449f, -2.91549587f, 4.46612740f, 5.66530371f, -2.70820427f, 4.81359577f, 10.31247330f, 1.92230141f, 2.53931546f, 0.74986327f, 1.70303428f, 0.48063779f, 5.31099129f, -0.78976244f, 3.75864220f, 4.23051405f, 2.34042454f, -7.98193836f, 9.83987141f, -1.46722627f, 3.54497814f, 10.36455154f, -4.51249075f, 0.77715248f, 7.78694630f, -4.59989023f, -2.49585629f, 9.90296268f, 1.38535416f, 1.17441154f, 10.10452843f, -0.98628229f, 0.60194463f, 9.12639141f, -3.90754628f, 2.88526392f, 7.24123430f, -0.15283313f, -0.75728363f, -1.15116858f, -2.53791571f, 0.77229571f, 6.44114161f, 0.02646767f, 4.95463037f, 7.21066380f, 1.79384065f, 0.73250306f, 8.04447937f, 0.32576546f, -0.79447043f, 10.12717724f, 2.33392906f, 1.30716443f, 12.36073112f, -0.36694977f, -1.20438910f, 7.03105593f, 0.59557682f, 0.69267452f, 10.18113136f, 2.49944925f, -0.42229167f, 8.83143330f, -1.18805945f, -2.87509322f, 4.53596449f, 4.09732771f, -3.39088297f, -1.02536607f, 0.82119560f, -3.47302604f, 9.29991817f, 0.21001509f, 4.97036457f, 9.50018406f, 1.04420102f, 1.96560478f, 10.74769592f, -6.22709799f, 3.11690164f, 5.06759691f, -1.23724771f, -3.05831861f, 8.12925529f, -1.93435478f, -1.10151744f, 9.32263088f, -0.04249470f, -5.98547363f, 10.49398136f, 0.26400441f, -0.78915191f, 13.28219604f, 2.99276900f, 0.74853164f, 2.49364305f, -3.43529654f, 4.05278301f, 2.13498688f, -2.35444307f, -0.79900265f, 4.66968822f, -0.31095147f, 3.60674143f, 12.37222099f, -0.07855003f, -3.30292702f, 12.15215874f, 0.60886210f, 2.87075138f, 7.75271845f, 0.38044083f, 3.34402204f, 6.40583277f, -0.87888050f, 0.67438459f, 6.91080809f, 1.98332930f, -0.08303714f, 8.08630371f, -0.16772588f, -2.74058914f, 7.17253590f, -2.69122696f, 1.48173678f, 8.99470139f, -1.43302310f, -0.88651133f, 2.66944790f, -0.29186964f, 2.00838661f, 5.09587479f, -0.76676071f, -2.88322186f, 8.31110573f, -0.14550979f, -1.37726915f, 10.28355122f, -1.60575438f, -0.04118848f, 9.97510815f, 0.14440438f, -3.24632120f, 9.00034523f, 4.14319563f, -1.31023729f, 7.16950464f, -0.70428526f, 2.01559544f, 7.26155043f, 2.40816474f, 2.09847403f, 7.31264496f, -0.75401551f, 2.13392544f, 7.03648758f, 1.04036045f, -1.15636516f, 1.09634531f, -0.06340861f, -0.58107805f, -0.65623116f, 1.18972754f, -0.80717683f, 1.40118241f, -0.61932516f, -3.60596156f, 1.59904599f, -2.23774099f, -1.13721037f, 3.89620137f, -0.09115922f, -7.51356888f, 2.36975193f, -1.42520905f, -2.34173775f, 3.33830214f, -2.74016523f, -3.04115510f, 6.00119495f, -1.36084354f, -2.45065260f, 
4.56992292f, -3.02825928f, -3.74182844f, 5.11069250f, -0.91531068f, -2.31385994f, 1.83399653f, 3.39370203f, -3.60886002f}); + auto exp = NDArrayFactory::create('c', {4, 4, 4, 3}, {7.97172260f, 0.06878620f, 2.27749538f, 7.29276514f, -0.14074677f, 0.65480286f, 5.70313978f, -0.06546132f, 0.35443667f, 3.70382833f, -0.84020567f, 0.63826996f, 8.60301399f, -0.38236514f, 1.55177069f, 7.37542057f, -0.99374938f, -0.29971302f, 8.84352493f, -0.67121059f, 0.43132120f, 4.78175592f, -1.25070143f, -1.91523600f, 6.03855371f, -0.00292124f, -1.11214364f, 7.90158176f, -0.57949901f, -0.96735370f, 7.81192017f, -0.53255427f, -0.48009714f, 3.16953635f, 0.08353355f, -1.54299748f, 3.74821687f, 1.69396687f, 0.72724354f, 5.42915201f, -1.13686812f, -0.71793109f, 5.78376389f, -0.72239977f, -0.60055625f, 2.53636408f, 0.56777251f, -2.07892323f, 6.08064651f, 0.68620735f, 2.54017019f, 5.65828180f, -0.68255502f, 1.47283304f, 6.10842514f, -0.39655915f, 0.28380761f, 1.96707797f, -1.98206317f, 0.94027776f, 4.71811438f, 0.32104525f, -0.92409706f, 8.34588146f, -1.05581069f, -0.55217457f, 9.58440876f, -0.96549922f, 0.45820439f, 5.65453672f, -2.50953507f, -0.71441835f, 8.03059578f, -0.21281289f, 0.92125505f, 9.26900673f, -0.35963219f, -0.70039093f, 8.59924412f, -1.22358346f, 0.81318003f, 3.85920119f, -0.01305223f, -1.09234154f, 6.33158875f, 1.28094780f, -1.48926139f, 4.94969177f, -0.77126902f, -1.97033751f, 5.64381838f, -0.16285487f, -1.31277227f, 2.39893222f, -1.32902908f, -1.39609122f, 6.47572327f, -0.45267010f, 1.55727172f, 6.70965624f, -1.68735468f, -0.05672536f, 7.25092363f, -0.64613032f, 0.67050058f, 3.60789680f, -2.05948973f, 2.22687531f, 8.15202713f, -0.70148355f, 1.28314006f, 8.14842319f, -1.88807654f, -1.04808438f, 8.45500565f, -0.76425624f, 0.94542569f, 4.56179953f, -0.28786001f, -2.04502511f, 8.46278095f, -0.31019822f, 0.07339200f, 9.34214592f, -0.61948007f, 0.52481830f, 8.32515621f, -1.52418160f, 0.49678251f, 5.11082315f, -1.09908783f, -0.52969611f, 5.27806664f, 0.88632923f, 0.66754371f, 4.75839233f, 0.48928693f, -0.68036932f, 6.56925392f, -0.02949905f, -2.99189186f, 4.46320581f, -0.64534980f, -0.29516968f, 8.60809517f, -1.13120568f, 3.41720533f, 5.84243155f, -1.24109328f, 0.89566326f, 5.99578333f, -0.42496428f, 2.07076764f, 3.17812920f, -0.81566459f, -0.14363396f, 6.55184317f, 0.39633346f, -0.43852386f, 8.70214558f, -2.24613595f, 0.30708700f, 8.73882294f, -0.53545928f, 1.54409575f, 4.49452257f, -0.16509305f, 0.19028664f, 8.24897003f, 0.44750381f, 2.15448594f, 8.97640514f, -0.77728152f, 0.57272542f, 9.03467560f, 0.47173575f, -1.10807717f, 3.30056310f, -0.43268481f, -0.41470885f, 3.53798294f, -0.08546703f, -2.16840744f, 6.18733406f, -0.17871059f, -2.59837723f, 5.94218683f, -1.02990067f, -0.49760687f, 3.76938033f, 0.86383581f, -1.91504073f}); + + nd4j::ops::avgpool2d op; + auto result = op.execute({&input}, {}, {3,3, 3,3, 0,0, 1,1,1, 0,1}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + // z->printIndexedBuffer("z"); + // exp.printIndexedBuffer("e"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, avgpool2d_11) { + int inOutH = 5;// 35; + int inOutW = 5;// 35; + int inOutC = 10;// 192; + + auto x = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); + x.linspace(1.0); + + nd4j::ops::avgpool2d op; + auto result = op.execute({&x}, {}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + 
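+ // Sanity notes on the reference check below: with SAME padding and stride 1 the
+ // output keeps the input size, so totalPad = (out - 1) * stride + kernel - in
+ // = (5 - 1) * 1 + 3 - 5 = 2, split as padTop = 1, padBottom = 1. Each output cell
+ // is the mean of the in-bounds part of its 3x3 window; the counter array c records
+ // how many inputs contributed, so m /= c gives the exclude-padding average that the
+ // avgpool2d call above (extraParam0 = 0) is expected to produce. wFrom subtracts
+ // padBottom rather than a separate padLeft; that is safe here only because the
+ // kernel and padding are symmetric in H and W (padLeft == padBottom == 1).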
+ int totalPadHeight = (inOutH - 1) * 1 + 3 - inOutH;
+ int padTop = totalPadHeight / 2;
+ int padBottom = totalPadHeight - totalPadHeight / 2;
+
+ int k = 3;
+
+ auto m = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC});
+ auto c = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC});
+
+ for (int h = 0; h < inOutH; h++) {
+ for (int w = 0; w < inOutW; w++) {
+ int hFrom = h - padTop;
+ int wFrom = w - padBottom;
+
+ int hTo = hFrom + k;
+ int wTo = wFrom + k;
+
+ hFrom = nd4j::math::nd4j_max(0, hFrom);
+ wFrom = nd4j::math::nd4j_max(0, wFrom);
+
+ hTo = nd4j::math::nd4j_min(inOutH, hTo);
+ wTo = nd4j::math::nd4j_min(inOutW, wTo);
+
+ int idxOut[4];
+ int idxIn[4];
+ for (int ch = 0; ch < inOutC; ch++) {
+ idxOut[1] = h;
+ idxOut[2] = w;
+ idxOut[3] = ch;
+ idxIn[3] = ch;
+
+ for (int kh = hFrom; kh < hTo; kh++) {
+ for (int kw = wFrom; kw < wTo; kw++) {
+ idxIn[1] = kh;
+ idxIn[2] = kw;
+
+ auto inVal = x.e(0, kh, kw, ch);
+ m.p(0, h, w, ch, inVal + m.e(0, h, w, ch));
+ c.p(0, h, w, ch, 1 + c.e(0, h, w, ch));
+ }
+ }
+ }
+ }
+ }
+ m /= c;
+
+ ASSERT_EQ(m, *z);
+
+ delete result;
+}
+
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests4, avgpool2d_12) {
+
+ int bS=4, iH=10,iW=10, iC=3, kH=3,kW=3, sH=3,sW=3, pH=0,pW=0, dH=1,dW=1;
+ int oH=4, oW=4;
+ int paddingMode = 1; // 1-SAME, 0-VALID
+ int dataFormat = 1;  // 1-NHWC, 0-NCHW
+
+ auto input = NDArrayFactory::create('c', {bS, iH, iW, iC});
+ auto expected = NDArrayFactory::create('c', {bS, oH, oW, iC}, { 17.5, 18.5, 19.5, 25. , 26. , 27. , 34. , 35. , 36. , 41.5, 42.5, 43.5, 92.5, 93.5, 94.5, 100. , 101. , 102. , 109. , 110. , 111. , 116.5, 117.5, 118.5,
+ 182.5, 183.5, 184.5, 190. , 191. , 192. , 199. , 200. , 201. , 206.5, 207.5, 208.5, 257.5, 258.5, 259.5, 265. , 266. , 267. , 274. , 275. , 276. , 281.5, 282.5, 283.5,
+ 317.5, 318.5, 319.5, 325. , 326. , 327. , 334. , 335. , 336. , 341.5, 342.5, 343.5, 392.5, 393.5, 394.5, 400. , 401. , 402. , 409. , 410. , 411. , 416.5, 417.5, 418.5,
+ 482.5, 483.5, 484.5, 490. , 491. , 492. , 499. , 500. , 501. , 506.5, 507.5, 508.5, 557.5, 558.5, 559.5, 565. , 566. , 567. , 574. , 575. , 576. , 581.5, 582.5, 583.5,
+ 617.5, 618.5, 619.5, 625. , 626. , 627. , 634. , 635. , 636. , 641.5, 642.5, 643.5, 692.5, 693.5, 694.5, 700. , 701. , 702. , 709. , 710. , 711. , 716.5, 717.5, 718.5,
+ 782.5, 783.5, 784.5, 790. , 791. , 792. , 799. , 800. , 801. , 806.5, 807.5, 808.5, 857.5, 858.5, 859.5, 865. , 866. , 867. , 874. , 875. , 876. , 881.5, 882.5, 883.5,
+ 917.5, 918.5, 919.5, 925. , 926. , 927. , 934. , 935. , 936. , 941.5, 942.5, 943.5, 992.5, 993.5, 994.5,1000. , 1001. , 1002. ,1009. , 1010. , 1011. ,1016.5, 1017.5, 1018.5,
+ 1082.5, 1083.5, 1084.5,1090. , 1091. , 1092. ,1099. , 1100. , 1101. ,1106.5, 1107.5, 1108.5,1157.5, 1158.5, 1159.5,1165. , 1166. , 1167. ,1174. , 1175. , 1176. 
,1181.5, 1182.5, 1183.5}); + input.linspace(1.); + input.syncToDevice(); + + nd4j::ops::avgpool2d op; + auto results = op.execute({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat}); + auto output = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + + //output->printIndexedBuffer("output"); + //expected.printIndexedBuffer("expected"); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, biasadd_1) { auto x = NDArrayFactory::create('c', {2, 3, 3, 2}); auto bias = NDArrayFactory::create('c', {2}, {1, 2}); @@ -1652,13 +1788,13 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_3) { ); auto exp = NDArrayFactory::create('c', {2, 2, 2, 4}, { - 0.9824562f, 0.f, 0.03822664f, 0.9824562f, - 0.67488194f, 0.f, 0.18924236f, 0.96960944f, - 0.99330735f, 0.f, 0.f, 0.37139067f, - 0.86567914f, 0.18702209f, 0.05610663f, 0.9520745f, - 0.6154575f, 0.34942827f, 0.45425674f, 0.6154575f, - 0.905509f, 0.f, 0.2824086f, 0.8361251f, - 0.57063663f, 0.41959068f, 0.629386f, 0.3504383f, + 0.9824562f, 0.f, 0.03822664f, 0.9824562f, + 0.67488194f, 0.f, 0.18924236f, 0.96960944f, + 0.99330735f, 0.f, 0.f, 0.37139067f, + 0.86567914f, 0.18702209f, 0.05610663f, 0.9520745f, + 0.6154575f, 0.34942827f, 0.45425674f, 0.6154575f, + 0.905509f, 0.f, 0.2824086f, 0.8361251f, + 0.57063663f, 0.41959068f, 0.629386f, 0.3504383f, 0.9520745f, 0.21039814f, 0.06311944f, 0.3268602f } ); @@ -1680,24 +1816,24 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_4) { auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, { - 5.5f, 0.f, 0.3f, 5.5f, - 1.5f, 0.f, 1.3f, 6.5f, - 8.6f, 0.f, 0.f, 0.4f, - 2.5f, 1.f, 0.3f, 4.5f, - 1.5f, 1.f, 1.3f, 1.5f, - 3.5f, 0.f, 1.3f, 2.5f, - 2.6f, 2.f, 3.f, 1.4f, + 5.5f, 0.f, 0.3f, 5.5f, + 1.5f, 0.f, 1.3f, 6.5f, + 8.6f, 0.f, 0.f, 0.4f, + 2.5f, 1.f, 0.3f, 4.5f, + 1.5f, 1.f, 1.3f, 1.5f, + 3.5f, 0.f, 1.3f, 2.5f, + 2.6f, 2.f, 3.f, 1.4f, 4.5f, 1.f, 0.3f, 0.5f} ); auto exp = NDArrayFactory::create('c', {2, 2, 2, 4}, { - 0.70082176f, 0.f, 0.03822664f, 0.70082176f, - 0.21835658f, 0.f, 0.18924236f, 0.9462118f, - 0.9922489f, 0.f, 0.f, 0.04615111f, - 0.46755522f, 0.18702209f, 0.05610663f, 0.8415994f, - 0.5241424f, 0.34942827f, 0.45425674f, 0.5241424f, - 0.76033086f, 0.f, 0.2824086f, 0.54309344f, - 0.54546785f, 0.41959068f, 0.629386f, 0.29371348f, + 0.70082176f, 0.f, 0.03822664f, 0.70082176f, + 0.21835658f, 0.f, 0.18924236f, 0.9462118f, + 0.9922489f, 0.f, 0.f, 0.04615111f, + 0.46755522f, 0.18702209f, 0.05610663f, 0.8415994f, + 0.5241424f, 0.34942827f, 0.45425674f, 0.5241424f, + 0.76033086f, 0.f, 0.2824086f, 0.54309344f, + 0.54546785f, 0.41959068f, 0.629386f, 0.29371348f, 0.94679165f, 0.21039814f, 0.06311944f, 0.10519907f} ); @@ -1719,28 +1855,28 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_5) { auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, { - 5.5f, 0.f, 0.3f, 5.5f, - 1.5f, 0.f, 1.3f, 6.5f, - 8.6f, 0.f, 0.f, 0.4f, - 2.5f, 1.f, 0.3f, 4.5f, - 1.5f, 1.f, 1.3f, 1.5f, - 3.5f, 0.f, 1.3f, 2.5f, - 2.6f, 2.f, 3.f, 1.4f, + 5.5f, 0.f, 0.3f, 5.5f, + 1.5f, 0.f, 1.3f, 6.5f, + 8.6f, 0.f, 0.f, 0.4f, + 2.5f, 1.f, 0.3f, 4.5f, + 1.5f, 1.f, 1.3f, 1.5f, + 3.5f, 0.f, 1.3f, 2.5f, + 2.6f, 2.f, 3.f, 1.4f, 4.5f, 1.f, 0.3f, 0.5f} ); auto eps = NDArrayFactory::create('c', {2, 2, 2, 4}, { - 0.70082176f, 0.f, 0.03822664f, 0.70082176f, - 0.21835658f, 0.f, 0.18924236f, 0.9462118f, + 0.70082176f, 0.f, 0.03822664f, 0.70082176f, + 0.21835658f, 0.f, 0.18924236f, 0.9462118f, - 
0.9922489f, 0.f, 0.f, 0.04615111f, - 0.46755522f, 0.18702209f, 0.05610663f, 0.8415994f, + 0.9922489f, 0.f, 0.f, 0.04615111f, + 0.46755522f, 0.18702209f, 0.05610663f, 0.8415994f, - 0.5241424f, 0.34942827f, 0.45425674f, 0.5241424f, - 0.76033086f, 0.f, 0.2824086f, 0.54309344f, + 0.5241424f, 0.34942827f, 0.45425674f, 0.5241424f, + 0.76033086f, 0.f, 0.2824086f, 0.54309344f, - 0.54546785f, 0.41959068f, 0.629386f, 0.29371348f, + 0.54546785f, 0.41959068f, 0.629386f, 0.29371348f, 0.94679165f, 0.21039814f, 0.06311944f, 0.10519907f} ); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp index ef495142d..3fb90b480 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp @@ -2459,42 +2459,6 @@ TEST_F(DeclarableOpsTests8, reduceStDevBP_test4) { } -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests8, avgpool2d_test13) { - - int bS=4, iH=10,iW=10, iC=3, kH=3,kW=3, sH=3,sW=3, pH=0,pW=0, dH=1,dW=1; - int oH=4, oW=4; - int paddingMode = 1; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NHWC, 0-NDHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto expected = NDArrayFactory::create('c', {bS, oH, oW, iC}, { 17.5, 18.5, 19.5, 25. , 26. , 27. , 34. , 35. , 36. , 41.5, 42.5, 43.5, 92.5, 93.5, 94.5, 100. , 101. , 102. , 109. , 110. , 111. , 116.5, 117.5, 118.5, - 182.5, 183.5, 184.5, 190. , 191. , 192. , 199. , 200. , 201. , 206.5, 207.5, 208.5, 257.5, 258.5, 259.5, 265. , 266. , 267. , 274. , 275. , 276. , 281.5, 282.5, 283.5, - 317.5, 318.5, 319.5, 325. , 326. , 327. , 334. , 335. , 336. , 341.5, 342.5, 343.5, 392.5, 393.5, 394.5, 400. , 401. , 402. , 409. , 410. , 411. , 416.5, 417.5, 418.5, - 482.5, 483.5, 484.5, 490. , 491. , 492. , 499. , 500. , 501. , 506.5, 507.5, 508.5, 557.5, 558.5, 559.5, 565. , 566. , 567. , 574. , 575. , 576. , 581.5, 582.5, 583.5, - 617.5, 618.5, 619.5, 625. , 626. , 627. , 634. , 635. , 636. , 641.5, 642.5, 643.5, 692.5, 693.5, 694.5, 700. , 701. , 702. , 709. , 710. , 711. , 716.5, 717.5, 718.5, - 782.5, 783.5, 784.5, 790. , 791. , 792. , 799. , 800. , 801. , 806.5, 807.5, 808.5, 857.5, 858.5, 859.5, 865. , 866. , 867. , 874. , 875. , 876. , 881.5, 882.5, 883.5, - 917.5, 918.5, 919.5, 925. , 926. , 927. , 934. , 935. , 936. , 941.5, 942.5, 943.5, 992.5, 993.5, 994.5,1000. , 1001. , 1002. ,1009. , 1010. , 1011. ,1016.5, 1017.5, 1018.5, - 1082.5, 1083.5, 1084.5,1090. , 1091. , 1092. ,1099. , 1100. , 1101. ,1106.5, 1107.5, 1108.5,1157.5, 1158.5, 1159.5,1165. , 1166. , 1167. ,1174. , 1175. , 1176. 
,1181.5, 1182.5, 1183.5}); - input.linspace(1.); - input.syncToDevice(); - - nd4j::ops::avgpool2d op; - auto results = op.execute({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat}); - auto output = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - - //output->printIndexedBuffer("output"); - //expected.printIndexedBuffer("expected"); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - delete results; -} - - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index 6df52fb54..caceaa1cd 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -2894,344 +2894,7 @@ TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) { delete result; } -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, batchnorm_bp_test1) { - NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.1, 1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32); - NDArray variance('c', {4}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {4}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); - NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32); - - NDArray expdLdI('c', {2,3,4}, {-0.000056, -0.000056, -0.000056, -0.000056, -0.000034, -0.000034, -0.000034, -0.000034, -0.000011, -0.000011, -0.000011, -0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000034, 0.000034, 0.000034, 0.000034, 0.000056, 0.000056, 0.000056, 0.000056}, nd4j::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {6.148104, 6.148104, 6.148105, 6.148105}, nd4j::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, nd4j::DataType::FLOAT32); - - input.linspace(0.1, 0.1); - variance.assign(0.46666667); - gamma.assign(1.2); - beta.assign(1.); // has no effect on gradient calculations - gradO.linspace(-0.9, 0.15); - - nd4j::ops::batchnorm_bp op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto dLdI = results->at(0); - auto dLdG = results->at(3); - auto dLdB = results->at(4); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - - delete results; -} - - -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, batchnorm_bp_test2) { - - NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {3}, {1.05, 1.1, 1.15}, nd4j::DataType::FLOAT32); - NDArray variance('c', {3}, {0.5, 0.6, 0.7}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {3}, {1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {3}, nd4j::DataType::FLOAT32); - NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32); - - NDArray expdLdI('c', {2,3,4}, {-0.601415, -0.521226, -0.441037, -0.360849, -0.456306, -0.395465, -0.334624, -0.273784, 0.396631, 0.343747, - 0.290863, 0.237978, 0.360849, 0.441037, 0.521226, 0.601415, 0.273784, 0.334625, 0.395465, 0.456306, -0.237978, - -0.290863, -0.343746, -0.396631}, nd4j::DataType::FLOAT32); - NDArray expdLdG('c', 
{3}, {5.81236 , 7.048771, 12.155388}, nd4j::DataType::FLOAT32); - NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4}, nd4j::DataType::FLOAT32); - - input.linspace(0.1, 0.1); - // beta.assign(1.); // has no effect on gradient calculations - gradO.linspace(-0.9, 0.15); - - nd4j::ops::batchnorm_bp op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto dLdI = results->at(0); - auto dLdG = results->at(3); - auto dLdB = results->at(4); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - - delete results; -} - -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) { - - NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}, nd4j::DataType::FLOAT32); - NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {2,1,4}, nd4j::DataType::FLOAT32); - NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32); - - NDArray expdLdI('c', {2,3,4}, {-0.577002, -0.744041, -0.850999, -0.922373, -0.000000, -0.000000, -0.000000, -0.000000, 0.577002, - 0.744041, 0.850999, 0.922373, -0.386037, -0.350205, -0.312047, -0.271737, -0.000000, -0.000000, - -0.000000, -0.000000, 0.386037, 0.350205, 0.312047, 0.271736}, nd4j::DataType::FLOAT32); - NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234 }, nd4j::DataType::FLOAT32); - NDArray expdLdB('c', {2,1,4}, {-0.9 , -0.45, 0. 
, 0.45, 4.5 , 4.95, 5.4 , 5.85}, nd4j::DataType::FLOAT32); - - input.linspace(0.1, 0.1); - // beta.assign(1.); // has no effect on gradient calculations - gradO.linspace(-0.9, 0.15); - - nd4j::ops::batchnorm_bp op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,0,2}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto dLdI = results->at(0); - auto dLdG = results->at(3); - auto dLdB = results->at(4); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - - delete results; -} - -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, batchnorm_bp_test4) { - - NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); - NDArray gradO ('c', {2,4}, nd4j::DataType::FLOAT32); - - NDArray expdLdI('c', {2,4}, {0.162923, -0.289673, 0.354174, -0.386151, -0.162923, 0.289673, -0.354174, 0.386151}, nd4j::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {1.442483, 0.950200, 0.569207, 0.314641}, nd4j::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, nd4j::DataType::FLOAT32); - - input.linspace(0.1, 0.1); - gradO.linspace(-0.9, 0.15); - - nd4j::ops::batchnorm_bp op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto dLdI = results->at(0); - auto dLdG = results->at(3); - auto dLdB = results->at(4); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - - delete results; -} - -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, batchnorm_bp_test5) { - - NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); - NDArray gradO ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); - - NDArray expdLdI('c', {2,4,2,2}, {-0.737512, -0.659880, -0.582247, -0.504614, 0.561404, 0.502309, 0.443214, 0.384118, -1.168243, - -1.045270, -0.922297, -0.799324, 1.899026, 1.699128, 1.499231, 1.299333, 0.504614, 0.582247, 0.659880, 0.737512, -0.384118, - -0.443214, -0.502308, -0.561404, 0.799324, 0.922297, 1.045270, 1.168243, -1.299334, -1.499231, -1.699129, -1.899026}, nd4j::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, nd4j::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {4.2, 9. 
, 13.8, 18.6}, nd4j::DataType::FLOAT32); - - input.linspace(0.1, 0.1); - gradO.linspace(-0.9, 0.15); - - nd4j::ops::batchnorm_bp op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto dLdI = results->at(0); - auto dLdG = results->at(3); - auto dLdB = results->at(4); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - - delete results; -} - -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, batchnorm_bp_test6) { - - NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); - NDArray gradO ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); - - NDArray expdLdI('c', {2,2,2,4}, {-4.989124, 2.540357, -1.515022, 0.791769, -3.563660, 1.814540, -1.082159, 0.565549, -2.138196, 1.088724, -0.649295, - 0.339329, -0.712732, 0.362908, -0.216432, 0.113110, 0.712732, -0.362908, 0.216432, -0.113110, 2.138195, -1.088724, 0.649295, - -0.339330, 3.563660,-1.814540, 1.082159, -0.565549, 4.989125, -2.540356, 1.515022, -0.791770}, nd4j::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, nd4j::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {9.6, 10.8, 12. , 13.2}, nd4j::DataType::FLOAT32); - - input.linspace(0.1, 0.1); - gradO.linspace(-0.9, 0.15); - - nd4j::ops::batchnorm_bp op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto dLdI = results->at(0); - auto dLdG = results->at(3); - auto dLdB = results->at(4); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - - delete results; -} - -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, batchnorm_bp_test7) { - - NDArray input ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); - NDArray gradO ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32); - - NDArray expdLdI('c', {2,2,2,2,4}, {-119.435059, 78.159744, -58.732986, 46.630123, -103.510391, 67.738441, -50.901920, 40.412773, -87.585716, 57.317142, - -43.070854, 34.195419, -71.661041, 46.895844, -35.239792, 27.978071, -55.736359, 36.474548, -27.408726, 21.760721, -39.811687, 26.053242, -19.577662, - 15.543370, -23.887009, 15.631950, -11.746595, 9.326023, -7.962326, 5.210644, -3.915531, 3.108671, 7.962341, -5.210655, 3.915535, -3.108677, 23.887032, - -15.631958, 11.746601, -9.326031, 39.811691, -26.053246, 19.577671, -15.543377, 55.736382, -36.474548, 27.408726, -21.760731, 71.661064, -46.895851, 35.239788, - 
-27.978077, 87.585732, -57.317154, 43.070866, -34.195431, 103.510384, -67.738464, 50.901920, -40.412777, 119.435097, -78.159744, 58.732998, -46.630131}, nd4j::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {282.38734 , 244.542027, 224.140995, 207.548793}, nd4j::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {57.6, 60. , 62.4, 64.8}, nd4j::DataType::FLOAT32); - - input.linspace(0.1, 0.1); - gradO.linspace(-0.9, 0.15); - - - nd4j::ops::batchnorm_bp op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,4}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto dLdI = results->at(0); - auto dLdG = results->at(3); - auto dLdB = results->at(4); - - // dLdI->printBuffer(); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - - delete results; -} - -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, batchnorm_bp_test8) { - - NDArray input ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); - NDArray gradO ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32); - - NDArray expdLdI('c', {2,4,2,2,2}, {-34.373802, -32.611046, -30.848286, -29.085529, -27.322769, -25.560009, -23.797251, -22.034491, 36.146996, 34.293301, - 32.439610, 30.585917, 28.732227, 26.878534, 25.024841, 23.171150, -42.876553, -40.677757, -38.478958, -36.280159, -34.081367, -31.882565, -29.683767, - -27.484968, 50.674446, 48.075760, 45.477066, 42.878380, 40.279686, 37.681000, 35.082310, 32.483616, 22.034489, 23.797249, 25.560009, 27.322765, 29.085526, - 30.848286, 32.611046, 34.373802, -23.171146, -25.024837, -26.878536, -28.732231, -30.585918, -32.439613, -34.293297, -36.146996, 27.484982, 29.683773, - 31.882572, 34.081364, 36.280178, 38.478970, 40.677776, 42.876560, -32.483627, -35.082329, -37.681023, -40.279701, -42.878403, -45.477081, -48.075775, -50.674484}, nd4j::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, nd4j::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, nd4j::DataType::FLOAT32); - - input.linspace(0.1, 0.1); - gradO.linspace(-0.9, 0.15); - - nd4j::ops::batchnorm_bp op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto dLdI = results->at(0); - auto dLdG = results->at(3); - auto dLdB = results->at(4); - - // dLdI->printBuffer(); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - - delete results; -} /* //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) { From f25056363b7bdb801eb208506b08c219ca5f9755 Mon Sep 17 00:00:00 2001 From: Abdelrauf Date: Tue, 28 Jan 2020 20:00:12 +0400 Subject: [PATCH 04/17] auto-vectorization check for gcc (#172) * Autovectorization tool: - sync output for gnu 
make
 - Reduced html output
 - links for line numbers
 - AutoVectorization.md

Signed-off-by: AbdelRauf

* Detailed report with `-fsave-optimization-record` option

Signed-off-by: AbdelRauf

* Readme

Signed-off-by: AbdelRauf

Co-authored-by: raver119
---
 libnd4j/CMakeLists.txt                      |   2 +-
 libnd4j/README.md                           |   3 +
 .../auto_vectorization/AutoVectorization.md |  49 ++
 libnd4j/auto_vectorization/auto_vect.py     | 546 ++++++++++++++++++
 libnd4j/auto_vectorization/bigGzipJson.pyx  | 354 ++++++++++++
 libnd4j/auto_vectorization/cython_setup.py  |   3 +
 libnd4j/blas/CMakeLists.txt                 |  26 +
 libnd4j/buildnativeoperations.sh            |  21 +-
 8 files changed, 1001 insertions(+), 3 deletions(-)
 create mode 100644 libnd4j/auto_vectorization/AutoVectorization.md
 create mode 100644 libnd4j/auto_vectorization/auto_vect.py
 create mode 100644 libnd4j/auto_vectorization/bigGzipJson.pyx
 create mode 100644 libnd4j/auto_vectorization/cython_setup.py

diff --git a/libnd4j/CMakeLists.txt b/libnd4j/CMakeLists.txt
index c82b0b217..cf9d4ff88 100755
--- a/libnd4j/CMakeLists.txt
+++ b/libnd4j/CMakeLists.txt
@@ -5,7 +5,7 @@ option(NATIVE "Optimize for build machine (might not work on others)" OFF)
 set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake" ${CMAKE_MODULE_PATH})
 #ensure we create lib files
 set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS OFF)
-
+option(CHECK_VECTORIZATION "checks for vectorization" OFF)
 option(BUILD_TESTS "Build tests" OFF)
 option(FLATBUFFERS_BUILD_FLATC "Enable the build of the flatbuffers compiler" OFF)
 set(FLATBUFFERS_BUILD_FLATC "OFF" CACHE STRING "Hack to disable flatc build" FORCE)
diff --git a/libnd4j/README.md b/libnd4j/README.md
index 9cea1b597..ec17c6227 100644
--- a/libnd4j/README.md
+++ b/libnd4j/README.md
@@ -17,8 +17,11 @@ There's few additional arguments for `buildnativeoperations.sh` script you could
 -b release OR -b debug // enables/desables debug builds. release is considered by default
 -j XX // this argument defines how many threads will be used to binaries on your box. i.e. -j 8
 -cc XX// CUDA-only argument, builds only binaries for target GPU architecture. use this for fast builds
+ --check-vectorization auto-vectorization report for developers. (Currently, only GCC is supported)
 ```
+[More about AutoVectorization report](auto_vectorization/AutoVectorization.md)
+
 You can find the compute capability for your card [on the NVIDIA website here](https://developer.nvidia.com/cuda-gpus).

 For example, a GTX 1080 has compute capability 6.1, for which you would use ```-cc 61``` (note no decimal point).
diff --git a/libnd4j/auto_vectorization/AutoVectorization.md b/libnd4j/auto_vectorization/AutoVectorization.md
new file mode 100644
index 000000000..61b98febe
--- /dev/null
+++ b/libnd4j/auto_vectorization/AutoVectorization.md
@@ -0,0 +1,49 @@
+# Auto-vectorization Report
+
+This report tool turns the compiler's auto-vectorization output into a human-friendly report. It is intended to help developers investigate the obstacles the compiler ran into during auto-vectorization.
+
+## Usage
+Add the ```--check-vectorization``` option to a **release** build to get the auto-vectorization report:
+```./buildnativeoperations.sh -a native -j 28 --check-vectorization```
+It will output ```vecmiss.html``` inside the blasbuild/cpu folder.
+
+## Report Format
+Each file name carries info about the optimization attempts for its source code lines.
+Each line number is also expandable (⇲) and lists the distinct failure notes.
+It is possible to click on a line number to see the source code.
+
+| file name | total successful attempts | total failed attempts | ⇲ |
+|---|---|---|--|
+| line number | successful attempts | failed attempts | ⇲ |
+|- failure reasons |
+| line number | successful attempts | failed attempts |⇲ |
+
+##### Requirements
+- GCC (Currently, only GCC is supported)
+- python3
+
+### Detailed report with the `-fsave-optimization-record` option
+If you want more detailed information (for now, the detailed report also names the functions in which failures occurred), use a newer toolchain (GCC > 9), since newer GCC versions have the `-fsave-optimization-record` option.
+`buildnativeoperations.sh`, via CMake, will detect it and switch to the more detailed version.
+Please note that this option is still experimental, so the compiler can fail to output some json.gz files with an error.
+In that case, try to exclude those files from the build.
+Also, the internal structure of the `-fsave-optimization-record` json.gz output may change in the future.
+
+It outputs two files, **vecmiss_fsave.html** and **vecmiss_fsave.html.js**, so to see the report details you need to enable JavaScript in your browser if it was disabled.
+
+##### Requirements for the Detailed report
+- GCC version > 9
+- python3
+- Cython (python3)
+- json (python3)
+- gzip (python3)
+- c++filt
+
+Internally, we use Cython to speed up json.gz file processing (bigGzipJson.pyx), because json.gz files can take a lot of memory when loaded whole.
+
+If you want to use bigGzipJson outside of `buildnativeoperations.sh` and CMake, compile it manually with this command in the auto_vectorization folder:
+`python3 cython_setup.py build_ext --inplace`
+
+json.gz files can also be processed outside of `buildnativeoperations.sh`:
+call `python3 auto_vect.py --fsave` inside the base source folder, where the json.gz files exist.
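For orientation, here is a minimal sketch (illustrative only, not part of the build) of how such per-line compiler messages are aggregated into report rows. The sample paths and messages below are made up; the real parsing lives in `auto_vect.py` (`obtain_info_from` and `get_msg`), which reads the compiler's `file:line:column: message` output from stdin.

```python
import re

# Same pattern auto_vect.py uses to split GCC's
# -fopt-info-vec-optimized-missed lines into file, line, column, message.
mtch = re.compile(r"[^/]*([^:]+)\:(\d+)\:(\d+)\:(.*)")

sample = [  # hypothetical compiler output, not from a real build
    "include/loops/cpu/reduce.cpp:142:7: note: loop vectorized",
    "include/loops/cpu/reduce.cpp:188:9: note: not vectorized: complicated access pattern.",
]

stats = {}  # (file, line) -> [optimized_count, missed_count, failure_reasons]
for raw in sample:
    m = mtch.match(raw)
    if not m:
        continue
    key = (m.group(1).strip(), int(m.group(2)))
    entry = stats.setdefault(key, [0, 0, set()])
    msg = m.group(4).lower()
    if "loop vectorized" in msg:
        entry[0] += 1  # successful attempt
    elif "note: not vectorized:" in msg:
        entry[1] += 1  # failed attempt, keep the distinct reason
        entry[2].add(msg.replace("note: not vectorized:", "").strip())

for (fname, lineno), (opt, missed, reasons) in sorted(stats.items()):
    print(f"{fname}:{lineno}  vectorized={opt}  missed={missed}  {sorted(reasons)}")
```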
+ diff --git a/libnd4j/auto_vectorization/auto_vect.py b/libnd4j/auto_vectorization/auto_vect.py new file mode 100644 index 000000000..f98dc7422 --- /dev/null +++ b/libnd4j/auto_vectorization/auto_vect.py @@ -0,0 +1,546 @@ +''' +@author : Abdelrauf rauf@konduit.ai +''' +import re +import sys +import os +import subprocess +import fnmatch +import json +import gzip +try: + from bigGzipJson import json_gzip_extract_objects +except ImportError: + pass +from pathlib import Path +from multiprocessing import Pool, Manager ,cpu_count +import traceback +import html + +mtch = re.compile(r"[^/]*([^:]+)\:(\d+)\:(\d+)\:(.*)") +replace_msg = re.compile(r"(\d+)?\.?(\d+)?_?\d+\.?(\d+)?") +progress_msg = re.compile(r"\s{0,4}\[\s{0,2}\d+\%\]") +file_dir_strip = str(Path(os.getcwd())) +pp_index = file_dir_strip.rfind("libnd4j") +if pp_index>=0: + file_dir_strip =file_dir_strip[:pp_index+len("libnd4j")] +BASE_URL = "https://github.com/eclipse/deeplearning4j/tree/master/libnd4j/" +if BASE_URL.endswith("/")==False: + BASE_URL = BASE_URL + "/" +#print(file_dir_strip) +class info: + def __repr__(self): + return str(self.__dict__) + +FSAVE_IGNORE_EXTERNALS = True + +def get_cxx_filt_result(strx): + if len(strx)<1: + return "" + res = subprocess.Popen(["c++filt","-i", strx], stdout=subprocess.PIPE).communicate()[0] + res =res.decode('utf-8') + #replace some long names to reduce size + res = res.replace("unsigned long long", "uLL") + res = res.replace("unsigned long int","uL") + res = res.replace("unsigned long", "uL") + res = res.replace("unsigned int", "ui") + res = res.replace("unsigned char", "uchar") + res = res.replace("unsigned short", "ushort") + res = res.replace("long long", "LL") + res = res.replace(", ",",") + return res.strip() + + +def internal_glob(dir, match): + listx = [] + for root, dirnames, filenames in os.walk(dir): + for filename in fnmatch.filter(filenames, match): + listx.append(os.path.join(root, filename)) + return listx + +def get_obj_json_gz(filename): + with gzip.GzipFile(filename, 'r') as f: + return json.loads(f.read().decode('utf-8'))[-1] + + + +def get_msg(msg): + msg = msg.lower().strip() + if "note: not vectorized:" in msg: + msg = replace_msg.sub("_numb",msg.replace("note: not vectorized:","")) + return( 0, 1, msg.strip()) + elif "loop vectorized" in msg: + return (1, 0, None) + # elif msg.startswith("missed")==False: + # msg = replace_msg.sub("_numb",msg) + # return( 0, 0, msg.strip()) + return None + + + + +class File_Info: + ''' + Holds information about vectorized and miss vectorized lines for one file + ''' + + def __init__(self): + self.infos = {} + self.total_opted =0 + self.total_missed = 0 + self.external = False + + + def add_line(self, line_pos): + if line_pos not in self.infos: + v = info() + v.optimized = 0 + v.missed = 0 + v.miss_details = set() + self.infos[line_pos] = v + return v + else: + return self.infos[line_pos] + + + def add_line_fsave(self, line_pos): + if line_pos not in self.infos: + v = info() + v.optimized = 0 + v.missed = 0 + v.miss_details2 = dict() + self.infos[line_pos] = v + return v + else: + return self.infos[line_pos] + + + + def add_fsave(self, line_pos,success, msg, function ,inline_fns=''): + v = self.add_line_fsave(line_pos) + if success and "loop vectorized" in msg: + v.optimized +=1 + self.total_opted +=1 + elif success==False and "not vectorized:" in msg: + #reduce this msg + msg = msg.replace("not vectorized:","") + v.missed +=1 + self.total_missed +=1 + msg = sys.intern(msg) + if msg in v.miss_details2: + ls = v.miss_details2.get(msg) + 
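# miss_details2 maps each interned failure message to the set of function
# indices it was reported for; get_compressed_indices() later delta-encodes
# each set when the HTML report is written.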
ls.add(function) + else: + ls =set() + v.miss_details2[msg]=ls + ls.add(function) + return self + + def add(self, line_pos, msg_x): + v = self.add_line(line_pos) + if msg_x is not None: + v.optimized += msg_x[0] + v.missed += msg_x[1] + self.total_opted += msg_x[0] + self.total_missed += msg_x[1] + if msg_x[2] is not None: + v.miss_details.add(msg_x[2]) + return self + + + def __repr__(self): + return str(self.__dict__) + + + + +def process_gzip_json_mp(args): + process_gzip_json_new(*args) + +def process_gzip_json_new(json_gz_fname,list_Queue): + gz_name = Path(json_gz_fname).stem + #print("::--open and process {0}".format(gz_name)) + queue_count = len(list_Queue) + #print(queue_count) + q = list_Queue[0] + old_fname = '' + total_c = 0 + for x in json_gzip_extract_objects(json_gz_fname,'message','vectorized'): + external_source = True + if len(x['message'])>0 and 'location' in x: + line = int(x['location']['line']) + file_name = x['location']['file'].strip() + if file_dir_strip in file_name: + file_name = file_name.replace(file_dir_strip,'./') + external_source = False + msg = x['message'][0] + success = x['kind'] == 'success' + func = '' if 'function' not in x else x['function'] + + if file_name!=old_fname: + #send our info to the right consumer + queue_ind = hash(file_name) % queue_count + #print("quen index {0}".format(queue_ind)) + q =list_Queue[queue_ind] + old_fname = file_name + total_c +=1 + #print("pp {0} {1}".format(q,(file_name,line,success, msg, func,external_source ))) + if FSAVE_IGNORE_EXTERNALS==True and external_source == True: + continue + q.put((file_name,line,success, msg, func,external_source )) + print("::finished {0:60s} :{1:8d}".format(gz_name,total_c)) + +def consume_processed_mp(args): + return consume_processed_new(*args) + + + +def consume_processed_new(list_Queue , c_index): + + info_ = dict() + func_list = dict() + last_func_index = 0 + q = list_Queue[c_index] + print("::consumer {0}".format(c_index)) + total_c = 0 + r_c = 0 + while True: + #print("try to get new from {0}".format(index)) + obj = q.get() + #print("cc {0} {1}".format(q,obj)) + if obj==None: + break #we received the end + file_name,line,success, msg, func, external_source = obj + try: + #get function index + func_index = -1 + if func in func_list: + func_index = func_list[func] + else: + func_list[func] = last_func_index + func_index = last_func_index + last_func_index +=1 + + if file_name in info_: + info_[file_name].add_fsave(line, success, msg, func_index) + else: + info_[file_name] = File_Info().add_fsave(line, success, msg, func_index) + info_[file_name].external = external_source + total_c +=1 + if total_c - r_c >10000: + r_c = total_c + print("::consumer {0:2d} :{1:10d}".format(c_index,total_c)) + except Exception as e: + print(traceback.format_exc()) + break + + print("::consumer {0:2d} :{1:10d}".format(c_index,total_c)) + #write to temp file + wr_fname= "vecmiss_fsave{0}.html".format(str(c_index) if len(list_Queue)>1 else '') + print("generate report for consumer {0} {1}".format(c_index,len(info_))) + try: + uniq_ind = str(c_index)+'_' if len(list_Queue)>1 else '' + generate_report(wr_fname,info_ ,only_body = False, unique_id_prefix = uniq_ind,fsave_format = True, function_list= func_list) + print(" consumer {0} saved output into {1}".format(c_index,wr_fname)) + except Exception as e: + print(traceback.format_exc()) + + + +def obtain_info_from(input_): + info_ = dict() + for line in input_: + x = mtch.match(line) + external_source = True + if x: + file_name =x.group(1).strip() + if 
file_dir_strip in file_name: + file_name = file_name.replace(file_dir_strip,'') + external_source = False + line_number = int(x.group(2)) + msg = x.group(4).lower() + msg = msg.replace(file_dir_strip,'./') + msg_x = get_msg(msg) + if msg_x is None: + continue + if file_name in info_: + #ignore col_number + info_[file_name].add(line_number,msg_x) + else: + #print("{0} {1}".format(file_name,external_source)) + info_[file_name] = File_Info().add(line_number,msg_x) + info_[file_name].external = external_source + elif progress_msg.match(line): + #actually we redirect only, stderr so this should not happen + print("__"+line.strip()) + elif "error" in line or "Error" in line: + print("****"+line.strip()) + return info_ + + + +def custom_style(fsave): + st = '''''' + +def header(fsave=False): + strx ='\n\n\n\nAuto-Vectorization\n' + strx +=''.format(BASE_URL) + strx +=custom_style(fsave) + strx +='\n\n\n' + return strx + +def footer(): + return '\n' + + +def get_compressed_indices(set_a): + a_len = len(set_a) + if a_len<=1: + if a_len<1: + return '' + return str(set_a)[1:-1] + #we sorted and only saved difference + # 1,14,15,19 --> 1,13,1,4 10bytes=>8bytes + list_sorted = sorted(list(set_a)) + last = list_sorted[0] + str_x = str(list_sorted[0]) + for i in range(1,a_len): + str_x += ','+str(list_sorted[i]-last) + last = list_sorted[i] + return str_x + + + + + +def get_content(k, v, unique_id_prefix = '', fsave_format=False): + inner_str='' + content = '' + inc_id = 0 + for fk,fv in sorted(v.infos.items()): + if fsave_format==True: + inner_str+='
[HTML template markup stripped during extraction: only fragments of get_content()'s row templates and of jscript_head()'s JavaScript survive here. The recoverable structure: each source line is rendered as a row carrying its optimized and missed counts, formatted from fk, fv.optimized, fv.missed, unique_id_prefix, and inc_id; in fsave mode every distinct failure message becomes a "•" bullet whose set of function indices is emitted through get_compressed_indices(df), otherwise the plain fv.miss_details strings are emitted; the assembled rows are appended to `content`. The stripped JavaScript defines an escapeHtml() helper (a chain of .replace() calls for &, <, >, and quotes) and a click handler that splits each row's data-fns attribute, decodes the delta-encoded indices via ind = last_ind + parseInt(funcs[j]), and appends escapeHtml(func_list[ind]) to the modal markup; only its tail survives below:]
    "; + last_ind = ind; + } + if (strx.length > 0) { + content.innerHTML = strx; + modal.className = 'modal open'; + } + + }); + } + + };''' + +def additional_tags(fsave): + if fsave==False: + return '' + # + return ''' + + ''' + +def generate_report(output_name,info_ ,only_body = False, unique_id_prefix='',fsave_format = False , function_list = None): + ''' + Generate Auto-Vectorization Report in html format + ''' + + temp_str ='' + if fsave_format ==True: + # we gonna dump function_list as key list sorted by value + #and use it as jscript array + sorted_funcs_by_index = sorted(function_list.items(), key=lambda x: x[1]) + del function_list + with open(output_name+ ".js","w") as f: + #temp_str =jscript_head() +'{ "fmaps":[' + temp_str = jscript_head() + "\n var func_list = [" + for k,v in sorted_funcs_by_index: + #json.dumps using for escape + #print(str(v)+str(k)) + temp_str+=json.dumps(get_cxx_filt_result(k))+"," + #reduce write calls + if len(temp_str)>8192*2: + f.write(temp_str) + temp_str= '' + if len(temp_str)>0: + f.write(temp_str) + f.write('"-"];'+jscipt_end()) + + + temp_str = '' + with open(output_name,"w") as f: + if only_body==False: + f.write(header(fsave_format)) + f.write(additional_tags(fsave_format)) + nm=0 + for k,v in sorted(info_.items()): # sorted(info_.items(), key=lambda x: x[1].total_opted, reverse=True): + temp_str += get_content(k,v,unique_id_prefix+str(nm),fsave_format) + #reduce io write calls + if len(temp_str)>8192: + f.write(temp_str) + temp_str ='' + nm+=1 + if len(temp_str)>0: + f.write(temp_str) + if only_body==False: + f.write(footer()) + + +def fsave_report_launch(json_gz_list): + + cpus = cpu_count() + if cpus>32: + cpus = 24 + + c_count = 1 # 2 i sufficient # if cpus<=1 else min(4,cpus) + p_count = 3 if cpus<=1 else max(8, cpus - c_count) + + m = Manager() + #consumer Queues + list_Queue = [m.Queue() for index in range(0,c_count)] + with Pool(processes=c_count) as consumers: + #start consumers + cs = consumers.map_async(consume_processed_mp,[(list_Queue, index,) for index in range(0,c_count)]) + with Pool(processes=p_count) as processors: + processors.map(process_gzip_json_mp, [(fname, list_Queue,) for fname in json_gz_list]) + + #send ends to inform our consumers + #send ends + for q in list_Queue: + q.put(None) + + #wait for consumers + cs.wait() + + + + +def main(): + if "--fsave" in sys.argv: + json_gz_list = internal_glob(".","*.json.gz") + fsave_report_launch(json_gz_list) + return + + file_info = obtain_info_from(sys.stdin) + if len(file_info)>0: + #print(file_info) + print("---generating vectorization html report--") + generate_report("vecmiss.html", file_info) + else: + # lets check if we got fsave files + json_gz_list = internal_glob(".","*.json.gz") + fsave_report_launch(json_gz_list) + + + + +if __name__ == '__main__': + main() diff --git a/libnd4j/auto_vectorization/bigGzipJson.pyx b/libnd4j/auto_vectorization/bigGzipJson.pyx new file mode 100644 index 000000000..277bd16ec --- /dev/null +++ b/libnd4j/auto_vectorization/bigGzipJson.pyx @@ -0,0 +1,354 @@ +''' +@author : Abdelrauf rauf@konduit.ai +Simple object xtractor form very big json files +''' + +import sys +from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free + + +cdef char JSON_1 = b':' +cdef char JSON_2 = b',' +cdef char JSON_3 = b'{' +cdef char JSON_4 = b'}' +cdef char JSON_5 = b'[' +cdef char JSON_6 = b']' +cdef char QUOTE = b'"' +cdef char ESCAPE = b"\\" +cdef char SPACE = b' ' +cdef char TAB = b't' +cdef char CR = b'\r' +cdef char NL = b'\n' +cdef char B = b'\b' 
+cdef char EMPTY = b'\0' + + +cdef struct Span: + int b + int e + +cdef inline Span read_unquoted(char *text, int start,int end): + cdef Span sp + cdef int j = start + while j < end: + #if text[j].isspace(): + if text[j] == SPACE or text[j] == NL or text[j] == TAB or text[j] == CR or text[j] == B: + j += 1 + continue + if text[j] != QUOTE and text[j] != JSON_1 and text[j] != JSON_2 and text[j] != JSON_3 and text[j] != JSON_4 and text[j] != JSON_5 and text[j] != JSON_6: + start = j + j += 1 + while j < end: + # read till JSON or white space + if text[j] == SPACE or text[j] == NL or text[j] == TAB or text[j] == CR or text[j] == B: + sp.b = start + sp.e = j + return sp + elif text[j] == JSON_1 or text[j] == JSON_2 or text[j] == JSON_3 or text[j] == JSON_4 or text[j] == JSON_5 or text[j] == JSON_6: + sp.b = start + sp.e = j + return sp + j += 1 + if j == end-1: + sp.b = start + sp.e = end + return sp + break + sp.b = j + sp.e = j + return sp + + +cdef inline Span read_seq_token(char *text,int start,int end): + #read quoted + #skip white_space + cdef Span sp + cdef int j = start + cdef char last_char + cdef char char_x + while j < end: + if text[j] == SPACE or text[j] == NL or text[j] == TAB or text[j] == CR or text[j] == B: + j += 1 + continue + if text[j] == QUOTE: + last_char = EMPTY + #read till another quote + start = j + j += 1 + while j < end: + char_x = text[j] + if char_x == QUOTE and last_char != ESCAPE: + # finished reading + sp.b =start + sp.e = j+1 + return sp + last_char = char_x + j += 1 + if j == end-1: + sp.b = start + sp.e = end + return sp + else: + break + return read_unquoted(text, j, end) + + +def tokenizer_spans(utext): + ''' + we will just return tokenize spans + ''' + token_spans = [] + last_char = b'' + end_i = len(utext) + cdef char *text = utext + i = 0 + cdef Span sp + while i < end_i: + sp = read_seq_token(text, i, end_i) + i = sp.e + if sp.e > sp.b: + token_spans.append((sp.b, sp.e)) + if i < end_i: + #if text[i] in JSON: + if text[i] == JSON_3 or text[i] == JSON_4 or text[i] == JSON_5 or text[i] == JSON_6 or text[i] == JSON_1 or text[i] == JSON_2: + token_spans.append((i, i+1)) + i += 1 + return token_spans + + + + + +cdef class JsonObjXtractor: + ''' + JsonObjXtractor that utilize cython better + ''' + + cdef Span* token_spans + cdef size_t size + + def __cinit__(self, size_t count=4096): + self.token_spans = PyMem_Malloc(count * sizeof(Span)) + self.size = count + if not self.token_spans: + raise MemoryError() + + + def __tokenizer_spans(self,utext, length): + ''' + we will just return token spans length + ''' + + last_char = b'' + end_i = length + cdef char *text = utext + cdef int i = 0 + cdef size_t j = 0 + cdef Span sp + while i < end_i: + sp = read_seq_token(text, i, end_i) + i = sp.e + if sp.e > sp.b: + self.token_spans[j] = sp + j+=1 + if j>self.size: + #we need to reallocate + self.__resize(self.size+self.size//2) + if i < end_i: + #if text[i] in JSON: + if text[i] == JSON_3 or text[i] == JSON_4 or text[i] == JSON_5 or text[i] == JSON_6 or text[i] == JSON_1 or text[i] == JSON_2: + sp.b=i + sp.e=i+1 + self.token_spans[j] = sp + j+=1 + if j>self.size: + #we need to reallocate + self.__resize(self.size+self.size//2) + i += 1 + return j + + + + def try_extract_parent_obj(self, json_bytes, property_name, next_contains_value=b'', debug=False): + ''' + try_extract_parent_obj(json_text, property_name, next_contains_value='', debug=False): + make sure that passed variables encoded to bytes with encode('utf-8') + next_contains_value either direct content or 
followed by '[' + tries to extract the parent object for given named object + if the left brace of the parent object is outside of the current buffer + it will be ignored + if the right brace is outside of the buffer it will be left to be handled by caller + ''' + + look_for_the_left = True + parent_left = [] + parent_right = [] + parent_objects = [] + len_next = len(next_contains_value) + cdef int ind = 0 + cdef int end + cdef int last_start = 0 + property_name = b'"'+property_name+b'"' + cdef int lenx = self.__tokenizer_spans(json_bytes,len(json_bytes)) + cdef char x + cdef int i = -1 + cdef Span sp + while i < lenx-1: + i += 1 + ind = self.token_spans[i].b + x = json_bytes[ind] + #print("-----{0} -- {1} -- {2} ".format(x,parent_left,parent_right)) + if look_for_the_left == False: + if x == JSON_3: + parent_right.append(ind) + elif x == JSON_4: + if len(parent_right) == 0: + #we found parent closing brace + look_for_the_left = True + parent_objects.append((parent_left[-1], ind+1)) + last_start = ind+1 + #print("=============found {0}".format(parent_objects)) + parent_left = [] + parent_right = [] + else: + parent_right.pop() + continue + #search obj + if look_for_the_left: + if x == JSON_3: + parent_left.append(ind) + last_start = ind + elif x == JSON_4: + if len(parent_left) >= 1: + #ignore + parent_left.pop() + + if x == JSON_1: # ':' + #check to see if propertyname + old_property = EMPTY + if i > 1: + sp = self.token_spans[i-1] + old_property = json_bytes[sp.b:sp.e] + if old_property == property_name: + #we found + if len(parent_left) < 1: + #left brace is outside of the buffer + #we have to ignore it + #try to increase buffer + if debug: + print('''left brace of the parent is outside of the buffer and parent is big. + it will be ignored + try to choose disambiguous property names if you are looking for small objects''', file=sys.stderr) + last_start = ind+1 + parent_left = [] + parent_right = [] + continue + else: + #print("++++++ look for the right brace") + if len_next>0 and i+1 < lenx: + i += 1 + ind = self.token_spans[i].b + end = self.token_spans[i].e + m = json_bytes[ind] + + if m == JSON_5: + #print ("----{0} {1}".format(m,JSON_5)) + if i+1 < lenx: + i += 1 + ind = self.token_spans[i].b + end = self.token_spans[i].e + #print ("----{0} == {1}".format(next_contains_value,json_bytes[ind:end])) + if len_next <= end-ind and next_contains_value in json_bytes[ind:end]: + look_for_the_left = False + continue + elif len_next <= end-ind and next_contains_value in json_bytes[ind:end]: + look_for_the_left = False + continue + + #ignore as it does not have that value + parent_left = [] + parent_right = [] + last_start = ind + 1 + else: + look_for_the_left = False + + # lets return last succesful opened brace as the last + # or left brace failure case, safe closed brace + if len(parent_left)>0: + return (parent_objects, parent_left[-1]) + + return (parent_objects, last_start) + + + + def __resize(self, size_t new_count): + cdef Span* mem = PyMem_Realloc(self.token_spans, new_count * sizeof(Span)) + if not mem: + raise MemoryError() + self.token_spans = mem + self.size = new_count + + def __dealloc__(self): + PyMem_Free(self.token_spans) + + + +import json +import gzip +import sys +DEBUG_LOG = False + +def json_gzip_extract_objects(filename, property_name, next_contains_value=''): + strx = b'' + started= False + b_next_contains_value = next_contains_value.encode('utf-8') + b_property_name = property_name.encode('utf-8') + #print(b_property_name) + objXt = JsonObjXtractor() + with 
gzip.open(filename, 'rb') as f: + if DEBUG_LOG: + print("opened {0}".format(filename), file=sys.stderr) + #instead of reading it as line, I will read it as binary bytes + is_End = False + #total = 0 + while is_End==False: + buffer = f.read(8192*2) + + lenx= len(buffer) + #total +=lenx + if lenx<1: + is_End = True + else: + strx = strx + buffer + + objects , last_index = objXt.try_extract_parent_obj(strx,b_property_name,b_next_contains_value) + + # if b_property_name in strx and b_next_contains_value in strx: + # print(strx) + # print(objects) + # print(last_index) + # print("===================================================") + + for start,end in objects: + yield json.loads(strx[start:end]) #.decode('utf-8')) + + + #remove processed + if last_index< len(strx): + strx = strx[last_index:] + + else: + strx = b'' + #print('----+++') + + if(len(strx)>16384*3): + #buffer to big + #try to avoid big parents + if DEBUG_LOG: + print("parent object is too big. please, look for better property name", file=sys.stderr) + + break + + + + diff --git a/libnd4j/auto_vectorization/cython_setup.py b/libnd4j/auto_vectorization/cython_setup.py new file mode 100644 index 000000000..9dc6ef0c1 --- /dev/null +++ b/libnd4j/auto_vectorization/cython_setup.py @@ -0,0 +1,3 @@ +from distutils.core import setup +from Cython.Build import cythonize +setup(ext_modules=cythonize("bigGzipJson.pyx", language_level="3")) diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt index c1c5de399..a54ad52b4 100755 --- a/libnd4j/blas/CMakeLists.txt +++ b/libnd4j/blas/CMakeLists.txt @@ -282,6 +282,32 @@ elseif(CPU_BLAS) set_source_files_properties(../include/helpers/impl/OpTracker.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic") endif() + if(CHECK_VECTORIZATION) + set(VECT_FILES cpu/NativeOps.cpp ${OPS_SOURCES} ${HELPERS_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${LOOPS_SOURCES}) + if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + + if (CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0) + set(CHECK_VECT_FLAGS "-ftree-vectorize -fsave-optimization-record") + #to process fsave-optimization-record we will need our cython version code + message("Build Auto vectorization helpers") + execute_process(COMMAND "python3" "${CMAKE_CURRENT_SOURCE_DIR}/../auto_vectorization/cython_setup.py" "build_ext" "--inplace" WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/../auto_vectorization/" RESULT_VARIABLE ret) + message("build='${ret}'") + + #remove fail cases that gcc fails produce sometimes + file(GLOB_RECURSE FAILURE_CASES false ../include/loops/cpu/compilation_units/reduce3*.cpp) + #message("*****${FAILURE_CASES}") + foreach(FL_ITEM ${FAILURE_CASES}) + message("Removing failure cases ${FL_ITEM}") + list(REMOVE_ITEM VECT_FILES ${FL_ITEM}) + endforeach() + else() + set(CHECK_VECT_FLAGS "-ftree-vectorize -fopt-info-vec-optimized-missed") + endif() + message("CHECK VECTORIZATION ${CHECK_VECT_FLAGS}") + set_source_files_properties( ${VECT_FILES} PROPERTIES COMPILE_FLAGS "${CHECK_VECT_FLAGS}" ) + endif() + endif() + message("CPU BLAS") add_definitions(-D__CPUBLAS__=true) add_library(nd4jobj OBJECT cpu/NativeOps.cpp cpu/GraphExecutioner.cpp diff --git a/libnd4j/buildnativeoperations.sh b/libnd4j/buildnativeoperations.sh index a8b45e918..c07756d8c 100755 --- a/libnd4j/buildnativeoperations.sh +++ b/libnd4j/buildnativeoperations.sh @@ -55,6 +55,7 @@ TESTS="false" VERBOSE="false" VERBOSE_ARG="VERBOSE=1" HELPER= +CHECK_VECTORIZATION="OFF" NAME= while [[ $# > 0 ]] do @@ -114,6 +115,9 @@ case $key in NAME="$value" 
shift # past argument ;; + --check-vectorization) + CHECK_VECTORIZATION="ON" + ;; -j) MAKEJ="$value" shift # past argument @@ -528,14 +532,27 @@ echo MINIFIER = "${MINIFIER_ARG}" echo TESTS = "${TESTS_ARG}" echo NAME = "${NAME_ARG}" echo OPENBLAS_PATH = "$OPENBLAS_PATH" +echo CHECK_VECTORIZATION = "$CHECK_VECTORIZATION" echo HELPERS = "$HELPERS" mkbuilddir pwd -eval $CMAKE_COMMAND "$BLAS_ARG" "$ARCH_ARG" "$NAME_ARG" $HELPERS "$SHARED_LIBS_ARG" "$MINIFIER_ARG" "$OPERATIONS_ARG" "$BUILD_TYPE" "$PACKAGING_ARG" "$EXPERIMENTAL_ARG" "$TESTS_ARG" "$CUDA_COMPUTE" -DOPENBLAS_PATH="$OPENBLAS_PATH" -DDEV=FALSE -DCMAKE_NEED_RESPONSE=YES -DMKL_MULTI_THREADED=TRUE ../.. +eval $CMAKE_COMMAND "$BLAS_ARG" "$ARCH_ARG" "$NAME_ARG" -DCHECK_VECTORIZATION="${CHECK_VECTORIZATION}" $HELPERS "$SHARED_LIBS_ARG" "$MINIFIER_ARG" "$OPERATIONS_ARG" "$BUILD_TYPE" "$PACKAGING_ARG" "$EXPERIMENTAL_ARG" "$TESTS_ARG" "$CUDA_COMPUTE" -DOPENBLAS_PATH="$OPENBLAS_PATH" -DDEV=FALSE -DCMAKE_NEED_RESPONSE=YES -DMKL_MULTI_THREADED=TRUE ../.. + if [ "$PARALLEL" == "true" ]; then MAKE_ARGUMENTS="$MAKE_ARGUMENTS -j $MAKEJ" fi if [ "$VERBOSE" == "true" ]; then MAKE_ARGUMENTS="$MAKE_ARGUMENTS $VERBOSE_ARG" fi -eval $MAKE_COMMAND $MAKE_ARGUMENTS && cd ../../.. + +if [ "$CHECK_VECTORIZATION" == "ON" ]; then + +if [ "$MAKE_COMMAND" == "make" ]; then + MAKE_ARGUMENTS="$MAKE_ARGUMENTS --output-sync=target" +fi +exec 3>&1 +eval $MAKE_COMMAND $MAKE_ARGUMENTS 2>&1 >&3 3>&- | python3 ../../auto_vectorization/auto_vect.py && cd ../../.. +exec 3>&- +else +eval $MAKE_COMMAND $MAKE_ARGUMENTS && cd ../../.. +fi From 5039fb22b7dced128093ed18e01cc185df067e13 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Wed, 29 Jan 2020 21:16:56 +1100 Subject: [PATCH 05/17] Fix datatype issue with GpuGraphRunnerTest (#198) Signed-off-by: AlexDBlack --- .../nd4j/tensorflow/conversion/GpuGraphRunnerTest.java | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java index 28cd5b7b2..a035592df 100644 --- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java @@ -18,6 +18,7 @@ package org.nd4j.tensorflow.conversion; import org.nd4j.BaseND4JTest; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.shade.protobuf.util.JsonFormat; import org.apache.commons.io.IOUtils; import org.junit.Test; @@ -40,6 +41,11 @@ import static org.junit.Assert.assertNotNull; public class GpuGraphRunnerTest extends BaseND4JTest { + @Override + public long getTimeoutMilliseconds() { + return 180000L; + } + @Test public void testGraphRunner() throws Exception { byte[] content = IOUtils.toByteArray(new ClassPathResource("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").getInputStream()); @@ -68,8 +74,8 @@ public class GpuGraphRunnerTest extends BaseND4JTest { assertEquals(2,graphRunner.getInputOrder().size()); assertEquals(1,graphRunner.getOutputOrder().size()); - INDArray input1 = Nd4j.linspace(1,4,4).reshape(4); - INDArray input2 = Nd4j.linspace(1,4,4).reshape(4); + INDArray input1 = Nd4j.linspace(1,4,4).reshape(4).castTo(DataType.FLOAT); + INDArray input2 = Nd4j.linspace(1,4,4).reshape(4).castTo(DataType.FLOAT); Map inputs = new LinkedHashMap<>(); 
inputs.put("input_0",input1); From ba961c7601effbce04a39f7db62f2abdf210868c Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 30 Jan 2020 10:07:24 +0300 Subject: [PATCH 06/17] DataTypes & FlatBuffers (#197) * flatbuffers version upgrade Signed-off-by: raver119 * flatbuffers version upgrade java side Signed-off-by: raver119 * flatbuffers dependency version upgrade java side Signed-off-by: raver119 * MKLDNN version upgrade Signed-off-by: raver119 * DArgs first pass Signed-off-by: raver119 * signatures first pass Signed-off-by: raver119 * signatures second pass Signed-off-by: raver119 * signatures third pass Signed-off-by: raver119 * signatures third pass Signed-off-by: raver119 * signatures fourth pass Signed-off-by: raver119 * signatures fifth pass Signed-off-by: raver119 * flatbuffers UI version upgrade java side Signed-off-by: raver119 * flatbuffers ui update Signed-off-by: raver119 * flatbuffers downgrade Signed-off-by: raver119 * flatbuffers downgrade java side Signed-off-by: raver119 --- libnd4j/CMakeLists.txt.in | 2 +- libnd4j/CMakeLists.txt.mkldnn.in | 2 +- libnd4j/blas/NativeOps.h | 1 + libnd4j/blas/cpu/NativeOps.cpp | 11 +- libnd4j/blas/cuda/NativeOps.cu | 10 +- libnd4j/include/array/DataTypeUtils.h | 4 + libnd4j/include/array/impl/NDArrayList.cpp | 2 +- libnd4j/include/graph/Context.h | 2 + libnd4j/include/graph/ContextPrototype.h | 5 + libnd4j/include/graph/Node.h | 11 +- .../include/graph/generated/array_generated.h | 12 +- .../graph/generated/array_generated.js | 4 +- .../graph/generated/nd4j/graph/DType.cs | 2 + .../graph/generated/nd4j/graph/DType.java | 4 +- .../graph/generated/nd4j/graph/DType.py | 2 + .../graph/generated/nd4j/graph/UIVariable.cs | 6 +- .../graph/generated/uigraphstatic_generated.h | 10 +- .../generated/uigraphstatic_generated.js | 8 +- libnd4j/include/graph/impl/Context.cpp | 12 + .../include/graph/impl/ContextPrototype.cpp | 8 + libnd4j/include/graph/impl/Node.cpp | 18 + .../include/helpers/impl/AttentionHelper.cpp | 4 +- libnd4j/include/op_boilerplate.h | 2 + libnd4j/include/ops/declarable/BooleanOp.h | 5 +- libnd4j/include/ops/declarable/DeclarableOp.h | 21 +- .../declarable/generic/activations/crelu.cpp | 4 +- .../declarable/generic/boolean/where_np.cpp | 2 +- .../generic/broadcastable/floormod.cpp | 2 +- .../generic/nn/dot_product_attention.cpp | 6 +- .../nn/multi_head_dot_product_attention.cpp | 2 +- .../ops/declarable/generic/nn/relu_layer.cpp | 2 +- .../parity_ops/compare_and_bitpack.cpp | 2 +- .../generic/parity_ops/dynamic_parititon.cpp | 4 +- .../generic/parity_ops/embedding_lookup.cpp | 4 +- .../declarable/generic/parity_ops/onehot.cpp | 4 +- .../recurrent/dynamicBidirectionalRNN.cpp | 8 +- .../generic/shape/tile_to_shape.cpp | 2 +- .../declarable/helpers/impl/multiUnique.cpp | 2 +- .../include/ops/declarable/impl/BooleanOp.cpp | 11 +- .../ops/declarable/impl/DeclarableOp.cpp | 171 +++-- .../tests_cpu/layers_tests/AttentionTests.cpp | 12 +- .../tests_cpu/layers_tests/BackpropTests.cpp | 2 +- .../layers_tests/BooleanOpsTests.cpp | 22 +- .../layers_tests/BroadcastableOpsTests.cpp | 40 +- .../layers_tests/ConvolutionTests1.cpp | 142 ++-- .../layers_tests/ConvolutionTests2.cpp | 132 ++-- .../layers_tests/DataTypesValidationTests.cpp | 4 +- .../layers_tests/DeclarableOpsTests1.cpp | 193 +++--- .../layers_tests/DeclarableOpsTests10.cpp | 234 +++---- .../layers_tests/DeclarableOpsTests11.cpp | 210 +++--- .../layers_tests/DeclarableOpsTests12.cpp | 202 +++--- .../layers_tests/DeclarableOpsTests13.cpp | 186 +++--- 
diff --git a/libnd4j/CMakeLists.txt.in b/libnd4j/CMakeLists.txt.in
index 8e8741c86..f351bf1b7 100644
--- a/libnd4j/CMakeLists.txt.in
+++ b/libnd4j/CMakeLists.txt.in
@@ -5,7 +5,7 @@ project(flatbuffers-download NONE)
 
 include(ExternalProject)
 ExternalProject_Add(flatbuffers
     GIT_REPOSITORY https://github.com/google/flatbuffers.git
-    GIT_TAG v1.10.0
+    GIT_TAG v1.11.0
     SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-src"
     BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-build"
     CONFIGURE_COMMAND ""

diff --git a/libnd4j/CMakeLists.txt.mkldnn.in b/libnd4j/CMakeLists.txt.mkldnn.in
index 3069d9efe..3de36dfde 100644
--- a/libnd4j/CMakeLists.txt.mkldnn.in
+++ b/libnd4j/CMakeLists.txt.mkldnn.in
@@ -5,7 +5,7 @@ project(mkldnn-download NONE)
 
 include(ExternalProject)
 ExternalProject_Add(mkldnn
     GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
-    GIT_TAG v1.1.2
+    GIT_TAG v1.1.3
     SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
     BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
     CONFIGURE_COMMAND ""

diff --git a/libnd4j/blas/NativeOps.h b/libnd4j/blas/NativeOps.h
index 862ffa42f..141ecb6ec 100755
--- a/libnd4j/blas/NativeOps.h
+++ b/libnd4j/blas/NativeOps.h
@@ -1607,6 +1607,7 @@ ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *
 ND4J_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
 ND4J_EXPORT void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo);
 ND4J_EXPORT void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo);
+ND4J_EXPORT void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments);
 ND4J_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments);
 ND4J_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments);
 ND4J_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments);
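
The new entry point mirrors the T/I/B-argument setters next to it but transports data types as plain int ordinals, keeping the exported C signature free of C++ enum types; the NativeOps.cpp and NativeOps.cu implementations below cast each ordinal back to nd4j::DataType. A sketch of a native-side caller (requestOutputType is a hypothetical helper and the include paths are illustrative; only setGraphContextDArguments itself comes from the header above):

    #include <NativeOps.h>
    #include <array/DataType.h>

    // hypothetical helper: ask a graph context to produce outputs of the given type
    static void requestOutputType(OpaqueContext* ctx, nd4j::DataType type) {
        int code = (int) type;   // the dtype crosses the C ABI as an int ordinal
        setGraphContextDArguments(ctx, &code, 1);
    }

Ops read the values back through block.getDArguments(), block.numD(), and the D_ARG(i) macro added to op_boilerplate.h later in this patch.
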
diff --git a/libnd4j/blas/cpu/NativeOps.cpp b/libnd4j/blas/cpu/NativeOps.cpp
index 2e203584d..3ba971aa5 100644
--- a/libnd4j/blas/cpu/NativeOps.cpp
+++ b/libnd4j/blas/cpu/NativeOps.cpp
@@ -2130,7 +2130,7 @@ Nd4jStatus realExec(nd4j::ops::DeclarableOp* op, Nd4jPointer* extraPointers, Nd4
         biArgs[e] = bArgs[e];
 
     // hypothetically at this point we have everything filled
-    auto hZ = op->execute(inputs, outputs, ttArgs, iiArgs, biArgs, isInplace);
+    auto hZ = op->execute(inputs, outputs, ttArgs, iiArgs, biArgs, std::vector<nd4j::DataType>(), isInplace);
     //auto hZ = op->execute(inputs, ttArgs, iiArgs, isInplace);
 
@@ -2788,6 +2788,15 @@ void setGraphContextIArguments(nd4j::graph::Context* ptr, Nd4jLong *arguments, i
 void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int numberOfArguments) {
     ptr->setBArguments(arguments, numberOfArguments);
 }
+
+void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments) {
+    std::vector<nd4j::DataType> dtypes(numberOfArguments);
+    for (int e = 0; e < numberOfArguments; e++)
+        dtypes[e] = (nd4j::DataType) arguments[e];
+
+    ptr->setDArguments(dtypes);
+}
+
 void deleteGraphContext(nd4j::graph::Context* ptr) {
     delete ptr;
 }

diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu
index b7995cb75..45de82b32 100755
--- a/libnd4j/blas/cuda/NativeOps.cu
+++ b/libnd4j/blas/cuda/NativeOps.cu
@@ -2831,7 +2831,7 @@ static FORCEINLINE Nd4jStatus realExec(nd4j::ops::DeclarableOp* op, Nd4jPointer*
 
     // hypothetically at this point we have everything filled
-    auto dZ = op->execute(inputs, outputs, ttArgs, iiArgs, bbArgs, isInplace);
+    auto dZ = op->execute(inputs, outputs, ttArgs, iiArgs, bbArgs, std::vector<nd4j::DataType>(), isInplace);
     //auto dZ = op->execute(inputs, ttArgs, iiArgs, isInplace);
 
@@ -3596,6 +3596,14 @@ void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int n
     ptr->setBArguments(arguments, numberOfArguments);
 }
 
+void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments) {
+    std::vector<nd4j::DataType> dtypes(numberOfArguments);
+    for (int e = 0; e < numberOfArguments; e++)
+        dtypes[e] = (nd4j::DataType) arguments[e];
+
+    ptr->setDArguments(dtypes);
+}
+
 void deleteGraphContext(nd4j::graph::Context* ptr) {
     delete ptr;
 }

diff --git a/libnd4j/include/array/DataTypeUtils.h b/libnd4j/include/array/DataTypeUtils.h
index 7561e96cc..5d17c28b0 100644
--- a/libnd4j/include/array/DataTypeUtils.h
+++ b/libnd4j/include/array/DataTypeUtils.h
@@ -95,6 +95,10 @@ namespace nd4j {
     template <typename T>
     // struct scalarTypesForNDarray { static bool const value = std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; };
     struct scalarTypesForNDarray { static bool const value = std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; };
+
+    template <typename T>
+    struct scalarTypesForExecution { static bool const value = std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; };
+
 };

diff --git a/libnd4j/include/array/impl/NDArrayList.cpp b/libnd4j/include/array/impl/NDArrayList.cpp
index cb1461226..81ac9ac2d 100644
--- a/libnd4j/include/array/impl/NDArrayList.cpp
+++ b/libnd4j/include/array/impl/NDArrayList.cpp
@@ -158,7 +158,7 @@ namespace nd4j {
 
             iargs.push_back(_axis);
 
-            auto result = op.execute(inputs, {}, {}, {});
+            auto result = op.evaluate(inputs);
 
             auto array = new NDArray(result->at(0)->dup());

diff --git a/libnd4j/include/graph/Context.h b/libnd4j/include/graph/Context.h
index 5e0f094e1..96b7e1c79 100644
--- a/libnd4j/include/graph/Context.h
+++ b/libnd4j/include/graph/Context.h
@@ -197,10 +197,12 @@ namespace nd4j {
             void setTArguments(double *arguments, int numberOfArguments);
             void setIArguments(Nd4jLong *arguments, int numberOfArguments);
             void setBArguments(bool *arguments, int numberOfArguments);
+            void setDArguments(nd4j::DataType *arguments, int numberOfArguments);
 
             void setTArguments(const std::vector<double> &tArgs);
             void setIArguments(const std::vector<Nd4jLong> &tArgs);
             void setBArguments(const std::vector<bool> &tArgs);
+            void setDArguments(const std::vector<nd4j::DataType> &dArgs);
 
             void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer);

diff --git a/libnd4j/include/graph/ContextPrototype.h b/libnd4j/include/graph/ContextPrototype.h
index bf5d389e4..fac664598 100644
--- a/libnd4j/include/graph/ContextPrototype.h
+++ b/libnd4j/include/graph/ContextPrototype.h
@@ -47,6 +47,9 @@ namespace nd4j {
             std::vector<Nd4jLong> _iArgs;
             std::vector<bool> _bArgs;
             std::vector<int> _axis;
+            std::vector<nd4j::DataType> _dArgs;
+
+            // TODO: remove this field
             nd4j::DataType _dataType = nd4j::DataType::FLOAT32;
             bool _isInplace;
 
@@ -93,6 +96,7 @@ namespace nd4j {
             std::vector<double>* getTArguments();
             std::vector<Nd4jLong>* getIArguments();
             std::vector<bool>* getBArguments();
+            std::vector<nd4j::DataType>* getDArguments();
             std::vector<int>* getAxis();
 
             samediff::Engine engine();
@@ -100,6 +104,7 @@ namespace nd4j {
             size_t numT();
             size_t numI();
             size_t numB();
+            size_t numD();
 
             std::pair* input(int idx);

diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h
index b57998e38..f07bfac18 100644
--- a/libnd4j/include/graph/Node.h
+++ b/libnd4j/include/graph/Node.h
@@ -38,7 +38,9 @@ namespace nd4j {
         class ND4J_EXPORT Node {
         protected:
+            // TODO: this field must be removed
             nd4j::DataType _dataType;
+
             OpType _opType;
             ContextPrototype* _protoContext = nullptr;
             Nd4jLong _opNum;
@@ -61,6 +63,7 @@
             // optional scalar.
used in scalar ops and in summary stats + // TODO: this field must be removed NDArray _scalar; bool _hasExternalOutputs; @@ -87,15 +90,15 @@ namespace nd4j { int _scope_id = 0; std::string _scope_name; + // TODO: these 3 fields should be removed int _rewindNode = -1; std::pair _rewindLayer = {-1, -1}; - Nd4jLong _frameId = -1; public: - Node(nd4j::ops::DeclarableOp *customOp, int id = 0, std::initializer_list input = {}, std::initializer_list output = {}, std::initializer_list dimensions = {}, float scalar = 0.0f, std::initializer_list tArgs = {}, std::initializer_list iArgs = {}); - Node(OpType opType = OpType_TRANSFORM_SAME, int opNum = 0, int id = 0, std::initializer_list input = {}, std::initializer_list output = {}, std::initializer_list dimensions = {}, float scalar = 0.0f, std::initializer_list tArgs = {}, std::initializer_list iArgs = {}); - Node(const nd4j::graph::FlatNode *node); + explicit Node(nd4j::ops::DeclarableOp *customOp, int id = 0, std::initializer_list input = {}, std::initializer_list output = {}, std::initializer_list dimensions = {}, float scalar = 0.0f, std::initializer_list tArgs = {}, std::initializer_list iArgs = {}); + explicit Node(OpType opType = OpType_TRANSFORM_SAME, int opNum = 0, int id = 0, std::initializer_list input = {}, std::initializer_list output = {}, std::initializer_list dimensions = {}, float scalar = 0.0f, std::initializer_list tArgs = {}, std::initializer_list iArgs = {}); + explicit Node(const nd4j::graph::FlatNode *node); ~Node(); bool equals(Node *other); diff --git a/libnd4j/include/graph/generated/array_generated.h b/libnd4j/include/graph/generated/array_generated.h index b581240ad..e3b3bbe60 100644 --- a/libnd4j/include/graph/generated/array_generated.h +++ b/libnd4j/include/graph/generated/array_generated.h @@ -60,11 +60,13 @@ enum DType { DType_QINT16 = 16, DType_BFLOAT16 = 17, DType_UTF8 = 50, + DType_UTF16 = 51, + DType_UTF32 = 52, DType_MIN = DType_INHERIT, - DType_MAX = DType_UTF8 + DType_MAX = DType_UTF32 }; -inline const DType (&EnumValuesDType())[19] { +inline const DType (&EnumValuesDType())[21] { static const DType values[] = { DType_INHERIT, DType_BOOL, @@ -84,7 +86,9 @@ inline const DType (&EnumValuesDType())[19] { DType_QINT8, DType_QINT16, DType_BFLOAT16, - DType_UTF8 + DType_UTF8, + DType_UTF16, + DType_UTF32 }; return values; } @@ -142,6 +146,8 @@ inline const char * const *EnumNamesDType() { "", "", "UTF8", + "UTF16", + "UTF32", nullptr }; return names; diff --git a/libnd4j/include/graph/generated/array_generated.js b/libnd4j/include/graph/generated/array_generated.js index b98410a9e..adf6ce13b 100644 --- a/libnd4j/include/graph/generated/array_generated.js +++ b/libnd4j/include/graph/generated/array_generated.js @@ -42,7 +42,9 @@ nd4j.graph.DType = { QINT8: 15, QINT16: 16, BFLOAT16: 17, - UTF8: 50 + UTF8: 50, + UTF16: 51, + UTF32: 52 }; /** diff --git a/libnd4j/include/graph/generated/nd4j/graph/DType.cs b/libnd4j/include/graph/generated/nd4j/graph/DType.cs index 00e399b50..9062dc881 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/DType.cs +++ b/libnd4j/include/graph/generated/nd4j/graph/DType.cs @@ -26,6 +26,8 @@ public enum DType : sbyte QINT16 = 16, BFLOAT16 = 17, UTF8 = 50, + UTF16 = 51, + UTF32 = 52, }; diff --git a/libnd4j/include/graph/generated/nd4j/graph/DType.java b/libnd4j/include/graph/generated/nd4j/graph/DType.java index 20d3d475b..c1b394ca7 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/DType.java +++ b/libnd4j/include/graph/generated/nd4j/graph/DType.java @@ -23,8 +23,10 @@ public 
final class DType { public static final byte QINT16 = 16; public static final byte BFLOAT16 = 17; public static final byte UTF8 = 50; + public static final byte UTF16 = 51; + public static final byte UTF32 = 52; - public static final String[] names = { "INHERIT", "BOOL", "FLOAT8", "HALF", "HALF2", "FLOAT", "DOUBLE", "INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16", "UINT32", "UINT64", "QINT8", "QINT16", "BFLOAT16", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "UTF8", }; + public static final String[] names = { "INHERIT", "BOOL", "FLOAT8", "HALF", "HALF2", "FLOAT", "DOUBLE", "INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16", "UINT32", "UINT64", "QINT8", "QINT16", "BFLOAT16", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "UTF8", "UTF16", "UTF32", }; public static String name(int e) { return names[e]; } } diff --git a/libnd4j/include/graph/generated/nd4j/graph/DType.py b/libnd4j/include/graph/generated/nd4j/graph/DType.py index 24cadf44e..393ec7c4a 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/DType.py +++ b/libnd4j/include/graph/generated/nd4j/graph/DType.py @@ -22,4 +22,6 @@ class DType(object): QINT16 = 16 BFLOAT16 = 17 UTF8 = 50 + UTF16 = 51 + UTF32 = 52 diff --git a/libnd4j/include/graph/generated/nd4j/graph/UIVariable.cs b/libnd4j/include/graph/generated/nd4j/graph/UIVariable.cs index 572f3e229..4b646b3bb 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/UIVariable.cs +++ b/libnd4j/include/graph/generated/nd4j/graph/UIVariable.cs @@ -26,7 +26,7 @@ public struct UIVariable : IFlatbufferObject #endif public byte[] GetNameArray() { return __p.__vector_as_array(6); } public VarType Type { get { int o = __p.__offset(8); return o != 0 ? (VarType)__p.bb.GetSbyte(o + __p.bb_pos) : VarType.VARIABLE; } } - public DataType Datatype { get { int o = __p.__offset(10); return o != 0 ? (DataType)__p.bb.GetSbyte(o + __p.bb_pos) : DataType.INHERIT; } } + public DType Datatype { get { int o = __p.__offset(10); return o != 0 ? (DType)__p.bb.GetSbyte(o + __p.bb_pos) : DType.INHERIT; } } public long Shape(int j) { int o = __p.__offset(12); return o != 0 ? __p.bb.GetLong(__p.__vector(o) + j * 8) : (long)0; } public int ShapeLength { get { int o = __p.__offset(12); return o != 0 ? 
__p.__vector_len(o) : 0; } } #if ENABLE_SPAN_T @@ -70,7 +70,7 @@ public struct UIVariable : IFlatbufferObject Offset idOffset = default(Offset), StringOffset nameOffset = default(StringOffset), VarType type = VarType.VARIABLE, - DataType datatype = DataType.INHERIT, + DType datatype = DType.INHERIT, VectorOffset shapeOffset = default(VectorOffset), VectorOffset controlDepsOffset = default(VectorOffset), StringOffset outputOfOpOffset = default(StringOffset), @@ -101,7 +101,7 @@ public struct UIVariable : IFlatbufferObject public static void AddId(FlatBufferBuilder builder, Offset idOffset) { builder.AddOffset(0, idOffset.Value, 0); } public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); } public static void AddType(FlatBufferBuilder builder, VarType type) { builder.AddSbyte(2, (sbyte)type, 0); } - public static void AddDatatype(FlatBufferBuilder builder, DataType datatype) { builder.AddSbyte(3, (sbyte)datatype, 0); } + public static void AddDatatype(FlatBufferBuilder builder, DType datatype) { builder.AddSbyte(3, (sbyte)datatype, 0); } public static void AddShape(FlatBufferBuilder builder, VectorOffset shapeOffset) { builder.AddOffset(4, shapeOffset.Value, 0); } public static VectorOffset CreateShapeVector(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); for (int i = data.Length - 1; i >= 0; i--) builder.AddLong(data[i]); return builder.EndVector(); } public static VectorOffset CreateShapeVectorBlock(FlatBufferBuilder builder, long[] data) { builder.StartVector(8, data.Length, 8); builder.Add(data); return builder.EndVector(); } diff --git a/libnd4j/include/graph/generated/uigraphstatic_generated.h b/libnd4j/include/graph/generated/uigraphstatic_generated.h index c980e32ec..8536a18ba 100644 --- a/libnd4j/include/graph/generated/uigraphstatic_generated.h +++ b/libnd4j/include/graph/generated/uigraphstatic_generated.h @@ -266,8 +266,8 @@ struct UIVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VarType type() const { return static_cast(GetField(VT_TYPE, 0)); } - DataType datatype() const { - return static_cast(GetField(VT_DATATYPE, 0)); + DType datatype() const { + return static_cast(GetField(VT_DATATYPE, 0)); } const flatbuffers::Vector *shape() const { return GetPointer *>(VT_SHAPE); @@ -342,7 +342,7 @@ struct UIVariableBuilder { void add_type(VarType type) { fbb_.AddElement(UIVariable::VT_TYPE, static_cast(type), 0); } - void add_datatype(DataType datatype) { + void add_datatype(DType datatype) { fbb_.AddElement(UIVariable::VT_DATATYPE, static_cast(datatype), 0); } void add_shape(flatbuffers::Offset> shape) { @@ -389,7 +389,7 @@ inline flatbuffers::Offset CreateUIVariable( flatbuffers::Offset id = 0, flatbuffers::Offset name = 0, VarType type = VarType_VARIABLE, - DataType datatype = DataType_INHERIT, + DType datatype = DType_INHERIT, flatbuffers::Offset> shape = 0, flatbuffers::Offset>> controlDeps = 0, flatbuffers::Offset outputOfOp = 0, @@ -421,7 +421,7 @@ inline flatbuffers::Offset CreateUIVariableDirect( flatbuffers::Offset id = 0, const char *name = nullptr, VarType type = VarType_VARIABLE, - DataType datatype = DataType_INHERIT, + DType datatype = DType_INHERIT, const std::vector *shape = nullptr, const std::vector> *controlDeps = nullptr, const char *outputOfOp = nullptr, diff --git a/libnd4j/include/graph/generated/uigraphstatic_generated.js b/libnd4j/include/graph/generated/uigraphstatic_generated.js index c05088d1a..c6ec80aa3 100644 --- 
a/libnd4j/include/graph/generated/uigraphstatic_generated.js +++ b/libnd4j/include/graph/generated/uigraphstatic_generated.js @@ -503,11 +503,11 @@ nd4j.graph.UIVariable.prototype.type = function() { }; /** - * @returns {nd4j.graph.DataType} + * @returns {nd4j.graph.DType} */ nd4j.graph.UIVariable.prototype.datatype = function() { var offset = this.bb.__offset(this.bb_pos, 10); - return offset ? /** @type {nd4j.graph.DataType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DataType.INHERIT; + return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.DType.INHERIT; }; /** @@ -668,10 +668,10 @@ nd4j.graph.UIVariable.addType = function(builder, type) { /** * @param {flatbuffers.Builder} builder - * @param {nd4j.graph.DataType} datatype + * @param {nd4j.graph.DType} datatype */ nd4j.graph.UIVariable.addDatatype = function(builder, datatype) { - builder.addFieldInt8(3, datatype, nd4j.graph.DataType.INHERIT); + builder.addFieldInt8(3, datatype, nd4j.graph.DType.INHERIT); }; /** diff --git a/libnd4j/include/graph/impl/Context.cpp b/libnd4j/include/graph/impl/Context.cpp index 5efd13a20..4c7a19133 100644 --- a/libnd4j/include/graph/impl/Context.cpp +++ b/libnd4j/include/graph/impl/Context.cpp @@ -551,6 +551,18 @@ namespace nd4j { bool Context::isInference() { return _execMode == samediff::ExecutionMode::MODE_INFERENCE; } + + void Context::setDArguments(nd4j::DataType *arguments, int numberOfArguments) { + _dArgs.clear(); + for (int e = 0; e < numberOfArguments; e++) + _dArgs.emplace_back(arguments[e]); + } + + void Context::setDArguments(const std::vector &dArgs) { + _dArgs.clear(); + for (auto d:dArgs) + _dArgs.emplace_back(d); + } } } diff --git a/libnd4j/include/graph/impl/ContextPrototype.cpp b/libnd4j/include/graph/impl/ContextPrototype.cpp index 0ddde97f4..e8432aea0 100644 --- a/libnd4j/include/graph/impl/ContextPrototype.cpp +++ b/libnd4j/include/graph/impl/ContextPrototype.cpp @@ -173,5 +173,13 @@ namespace nd4j { return clone; } + + std::vector *ContextPrototype::getDArguments() { + return &_dArgs; + } + + size_t ContextPrototype::numD() { + return _dArgs.size(); + } } } \ No newline at end of file diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 9d2224d2f..47c31cdf7 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -587,6 +587,12 @@ namespace nd4j { block->getIArguments()->emplace_back(node->extraInteger()->Get(e)); } + if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) { + for (int e = 0; e < (int) node->outputTypes()->size(); e++) { + block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e)); + } + } + this->setContextPrototype(block); this->setCustomOp(Node::buildOpByType(_opType, (int) node->input()->size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), (int) _opNum, &_scalar)); block->setOpDescriptor(this->getCustomOp()->getOpDescriptor()); @@ -618,6 +624,12 @@ namespace nd4j { block->getIArguments()->emplace_back(node->extraInteger()->Get(e)); } + if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) { + for (int e = 0; e < (int) node->outputTypes()->size(); e++) { + block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e)); + } + } + this->setContextPrototype(block); this->setCustomOp(Node::buildOpByType(_opType, (int) node->inputPaired()->size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), 
(int) _opNum, &_scalar)); @@ -652,6 +664,12 @@ namespace nd4j { block->getBArguments()->push_back(node->extraBools()->Get(e)); } + if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) { + for (int e = 0; e < (int) node->outputTypes()->size(); e++) { + block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e)); + } + } + for (auto v: _dimensions) block->getAxis()->emplace_back(v); diff --git a/libnd4j/include/helpers/impl/AttentionHelper.cpp b/libnd4j/include/helpers/impl/AttentionHelper.cpp index 3cfee1c08..731c9e56a 100644 --- a/libnd4j/include/helpers/impl/AttentionHelper.cpp +++ b/libnd4j/include/helpers/impl/AttentionHelper.cpp @@ -40,7 +40,7 @@ namespace nd4j { NDArray projected('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context); //[nHeads*hS, batch*timeSteps] nd4j::ops::matmul mmul; - mmul.execute({&projectionPrep, &inputPrep}, {&projected}, {}, {}, {}); + mmul.execute({&projectionPrep, &inputPrep}, {&projected}); projected.reshapei({numHeads, projectedSize, miniBatchSize, seqLength}); projected.permutei({2, 0, 1, 3}); //[minibatch, numHeads, projectedSize, seqLength] @@ -66,7 +66,7 @@ namespace nd4j { nd4j::ops::matmul_bp mmulBp; NDArray dLdProjectionPrep(projectionPrep.shapeInfo(), false, context); NDArray dLdInputPrep(inputPrep.shapeInfo(), false, context); - mmulBp.execute({&projectionPrep, &inputPrep, &epsReshaped}, {&dLdProjectionPrep, &dLdInputPrep}, {}, {}, {}); + mmulBp.execute({&projectionPrep, &inputPrep, &epsReshaped}, std::vector{&dLdProjectionPrep, &dLdInputPrep}, {}, {}, {}); dLdProjectionPrep.reshapei({numHeads, projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)}); dLdProjectionMatrix->assign(dLdProjectionPrep); diff --git a/libnd4j/include/op_boilerplate.h b/libnd4j/include/op_boilerplate.h index 97f33569b..8487f0264 100644 --- a/libnd4j/include/op_boilerplate.h +++ b/libnd4j/include/op_boilerplate.h @@ -1516,7 +1516,9 @@ #define INPUT_LIST(INDEX) reinterpret_cast(block.getVariable(INDEX)->getNDArrayList()) +#define D_ARG(INDEX) block.getDArguments()->at(INDEX) #define INT_ARG(INDEX) block.getIArguments()->at(INDEX) +#define I_ARG(INDEX) INT_ARG(INDEX) #define T_ARG(INDEX) block.getTArguments()->at(INDEX) #define B_ARG(INDEX) block.getBArguments()->at(INDEX) diff --git a/libnd4j/include/ops/declarable/BooleanOp.h b/libnd4j/include/ops/declarable/BooleanOp.h index b741c61c4..c13555407 100644 --- a/libnd4j/include/ops/declarable/BooleanOp.h +++ b/libnd4j/include/ops/declarable/BooleanOp.h @@ -36,9 +36,8 @@ namespace nd4j { public: BooleanOp(const char *name, int numInputs, bool scalar); - bool evaluate(std::initializer_list args); - bool evaluate(std::vector& args); - bool evaluate(nd4j::graph::Context& block); + bool verify(const std::vector& args); + bool verify(nd4j::graph::Context& block); Nd4jStatus execute(Context* block) override; diff --git a/libnd4j/include/ops/declarable/DeclarableOp.h b/libnd4j/include/ops/declarable/DeclarableOp.h index ea1f20d34..ff8fe9e83 100644 --- a/libnd4j/include/ops/declarable/DeclarableOp.h +++ b/libnd4j/include/ops/declarable/DeclarableOp.h @@ -169,13 +169,22 @@ namespace nd4j { */ virtual Nd4jStatus execute(Context* block); - nd4j::ResultSet* execute(std::initializer_list inputs, std::initializer_list tArgs, std::initializer_list iArgs, std::initializer_list bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32); - Nd4jStatus execute(std::initializer_list inputs, std::initializer_list outputs , 
std::initializer_list tArgs, std::initializer_list iArgs, std::initializer_list bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32); - Nd4jStatus execute(nd4j::graph::RandomGenerator& rng, std::initializer_list inputs, std::initializer_list outputs , std::initializer_list tArgs, std::initializer_list iArgs, std::initializer_list bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32); + Nd4jStatus execute(const std::vector &inputs, const std::vector &outputs); - nd4j::ResultSet* execute(const std::vector& inputs, const std::vector& tArgs, const std::vector& iArgs, const std::vector& bArgs = std::vector(), bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32); - Nd4jStatus execute(std::vector& inputs, std::vector& outputs , std::vector& tArgs, std::vector& iArgs, std::vector& bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32); - Nd4jStatus execute(nd4j::graph::RandomGenerator& rng, std::vector& inputs, std::vector& outputs, std::vector& tArgs, std::vector& iArgs, std::vector& bArgs, bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32); + template + Nd4jStatus execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list tArgs); + + Nd4jStatus execute(const std::vector &inputs, const std::vector &outputs, const std::vector &tArgs, const std::vector &iArgs, const std::vector &bArgs = std::vector(), const std::vector &dArgs = std::vector(), bool isInplace = false); + + + nd4j::ResultSet* evaluate(const std::vector &inputs); + + template + nd4j::ResultSet* evaluate(const std::vector &inputs, std::initializer_list args); + + nd4j::ResultSet* evaluate(const std::vector &inputs, const std::vector &tArgs, const std::vector &iArgs, const std::vector &bArgs = std::vector(), const std::vector &dArgs = std::vector(), bool isInplace = false); + + Nd4jStatus execute(nd4j::graph::RandomGenerator& rng, const std::vector& inputs, const std::vector& outputs, const std::vector& tArgs, const std::vector& iArgs, const std::vector& bArgs, const std::vector &dArgs = std::vector(), bool isInplace = false, nd4j::DataType type = nd4j::DataType::FLOAT32); nd4j::ResultSet* execute(const nd4j::OpArgsHolder& holder, bool isInplace = false); diff --git a/libnd4j/include/ops/declarable/generic/activations/crelu.cpp b/libnd4j/include/ops/declarable/generic/activations/crelu.cpp index 8ce3cbf75..a0ba6aa11 100644 --- a/libnd4j/include/ops/declarable/generic/activations/crelu.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/crelu.cpp @@ -73,7 +73,7 @@ namespace nd4j { // at first step we build fwd activation nd4j::ops::crelu op; - auto tmpResult = op.execute({input}, {}, {}, {}); + auto tmpResult = op.evaluate({input}); if (tmpResult->status() != ND4J_STATUS_OK) return tmpResult->status(); @@ -84,7 +84,7 @@ namespace nd4j { helpers::reluDerivative(block.launchContext(), actv, epsilonNext); // now we split updated array into 2 chunks along last dimension nd4j::ops::concat_bp opc; - auto dec = opc.execute({input, input, actv}, {}, {-1}, {}); + auto dec = opc.evaluate({input, input, actv}, {-1}); if (dec->status() != ND4J_STATUS_OK) return dec->status(); diff --git a/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp b/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp index c06ef07d1..aa6169bae 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp @@ -103,7 +103,7 @@ namespace nd4j { 
// if (output->isEmpty()) Nd4jLong width = condition->rankOf(); nd4j::ops::Where op; - std::unique_ptr res(op.execute({condition}, {}, {}, {})); + std::unique_ptr res(op.evaluate({condition})); REQUIRE_OK(res->status()); NDArray* whereTrue = res->at(0); if (whereTrue->isEmpty()) diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp index d442d89e7..062d3cfab 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp @@ -66,7 +66,7 @@ namespace nd4j { auto gradY = OUTPUT_VARIABLE(1); gradX->assign(epsNext); nd4j::ops::floormod op; - std::unique_ptr tmpResult(op.execute({x, y}, {}, {}, {})); + std::unique_ptr tmpResult(op.evaluate({x, y})); if (gradY->rankOf() == gradX->rankOf()) epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult->at(0), *gradY); diff --git a/libnd4j/include/ops/declarable/generic/nn/dot_product_attention.cpp b/libnd4j/include/ops/declarable/generic/nn/dot_product_attention.cpp index 4b97d58cd..726083deb 100644 --- a/libnd4j/include/ops/declarable/generic/nn/dot_product_attention.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/dot_product_attention.cpp @@ -91,7 +91,7 @@ namespace ops { } nd4j::ops::softmax softmax; - softmax.execute({weights}, {weights}, {}, {-2}, {}, true); + softmax.execute({weights}, std::vector{weights}, {}, {-2}, {}, {}, true); mmul.execute({values, weights}, {output}, {}, {}, {}); @@ -189,7 +189,7 @@ namespace ops { nd4j::ops::matmul_bp mmul_bp; NDArray dLdw(weights.getShapeInfo(), block.workspace()); - mmul_bp.execute({values, &weights, eps}, {dLdv, &dLdw}, {}, {}, {}); + mmul_bp.execute({values, &weights, eps}, std::vector{dLdv, &dLdw}, {}, {}, {}); NDArray dLds(preSoftmax.shapeInfo(), block.workspace()); nd4j::ops::softmax_bp softmax_bp; @@ -198,7 +198,7 @@ namespace ops { if(normalization) dLds /= factor; - mmul_bp.execute({keys, queries, &dLds}, {dLdk, dLdq}, {}, {1}, {}); + mmul_bp.execute({keys, queries, &dLds}, std::vector{dLdk, dLdq}, {}, {1}, {}); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/nn/multi_head_dot_product_attention.cpp b/libnd4j/include/ops/declarable/generic/nn/multi_head_dot_product_attention.cpp index 2123317b5..cff8545b2 100644 --- a/libnd4j/include/ops/declarable/generic/nn/multi_head_dot_product_attention.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/multi_head_dot_product_attention.cpp @@ -239,7 +239,7 @@ namespace ops { auto epsPostReshape = epsPerm.reshape(eps->ordering(), {miniBatchSize * queryCount, outSize}); nd4j::ops::matmul_bp matmulBp; NDArray dLdPreWo(attnResults.shapeInfo(), false, block.launchContext()); - matmulBp.execute({&attnResults, Wo, &epsPostReshape}, {&dLdPreWo, dLdWo}, {}, {}, {}); + matmulBp.execute({&attnResults, Wo, &epsPostReshape}, std::vector{&dLdPreWo, dLdWo}, {}, {}, {}); // dLdAttn dLdPreWo.reshapei({miniBatchSize, queryCount, numHeads, projectedValues.sizeAt(2)}); diff --git a/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp b/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp index 4e62abc60..cfc080117 100644 --- a/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp @@ -40,7 +40,7 @@ namespace nd4j { //nd4j_printf("Matrix x(%ix%i), Matrix w(%ix%i), b(1x%i)\n", x->sizeAt(0), x->sizeAt(1), w->sizeAt(0), w->sizeAt(1), b->lengthOf()); nd4j::ops::xw_plus_b op; - std::unique_ptr 
result(op.execute({x, w, b}, {}, {}, {})); + std::unique_ptr result(op.evaluate({x, w, b})); REQUIRE_TRUE(Status::OK() == result->status(), 0, "relu_layer: xw_plus_b op failed on input data."); auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0; diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp index b9fe7fef9..1a30e0c91 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp @@ -34,7 +34,7 @@ namespace nd4j { auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0); bitcast res; - auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, false); + auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, {}, false); if (tZ != &z0) { delete tZ; } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/dynamic_parititon.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/dynamic_parititon.cpp index 20670142a..49c9ed5e8 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/dynamic_parititon.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/dynamic_parititon.cpp @@ -112,7 +112,7 @@ namespace ops { NDArray originalIndices(*indices); //->ordering(), indices->shapeInfo(), indices->dataType()); originalIndices.linspace(0); ops::dynamic_partition op; - auto res = op.execute({&originalIndices, indices}, {}, {numPartition}); + auto res = op.evaluate({&originalIndices, indices}, {numPartition}); REQUIRE_TRUE(res->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning."); ops::dynamic_stitch stichOp; std::vector partitions(numPartition * 2); @@ -121,7 +121,7 @@ namespace ops { partitions[i + numPartition] = gradOutList[i]; } - auto result = stichOp.execute(partitions, {}, {numPartition}, {}, false); + auto result = stichOp.evaluate(partitions, {numPartition}); REQUIRE_TRUE(result->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning."); result->at(0)->reshapei(outputList[0]->getShapeAsVector()); outputList[1]->assign(indices); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/embedding_lookup.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/embedding_lookup.cpp index 822b4b91b..9df3d52b8 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/embedding_lookup.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/embedding_lookup.cpp @@ -66,7 +66,7 @@ CUSTOM_OP_IMPL(embedding_lookup, 2, 1, false, 0, 1) { nd4j::ops::gather op; - std::unique_ptr result(op.execute({input, indeces}, {}, {0}, {})); + std::unique_ptr result(op.evaluate({input, indeces}, {0})); REQUIRE_TRUE(result->status() == Status::OK(), 0, "embedding_lookup: cannot retrieve results from gather op."); REQUIRE_TRUE(result->at(0)->isSameShape(output), 0, "embedding_lookup: wrong shape of return from gather op."); output->assign(result->at(0)); @@ -94,7 +94,7 @@ DECLARE_SHAPE_FN(embedding_lookup) { for (int e = 1; e < outRank; e++) shapeInfo[e] = shape::sizeAt(inShapeInfo, e); - auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), shape::order(inShapeInfo), shapeInfo); + auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), shapeInfo); return SHAPELIST(outShapeInfo); } diff --git 
a/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp index 3e17d097d..49d91275f 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp @@ -74,6 +74,8 @@ namespace nd4j { DECLARE_SHAPE_FN(onehot) { auto inShape = inputShape->at(0); + nd4j::DataType dtype = block.numD() > 0 ? D_ARG(0) : nd4j::DataType::FLOAT32; + int depth = -1; Nd4jLong axis = -1; @@ -99,7 +101,7 @@ namespace nd4j { shape.push_back(shape::shapeOf(inShape)[e]); shape.insert(shape.begin() + axis, depth); - newShape = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), 'c', rank + 1, shape.data()); + newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', rank + 1, shape.data()); return SHAPELIST(newShape); } diff --git a/libnd4j/include/ops/declarable/generic/recurrent/dynamicBidirectionalRNN.cpp b/libnd4j/include/ops/declarable/generic/recurrent/dynamicBidirectionalRNN.cpp index eb1a01861..7f536a9ea 100644 --- a/libnd4j/include/ops/declarable/generic/recurrent/dynamicBidirectionalRNN.cpp +++ b/libnd4j/include/ops/declarable/generic/recurrent/dynamicBidirectionalRNN.cpp @@ -84,7 +84,7 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) { // forward steps nd4j::ops::dynamic_rnn dynamicRnn; - auto resultsFW = dynamicRnn.execute({x, WxFW, WhFW, bFW, h0FW, maxTimeStep}, {}, {timeMajor}, {}, false, x->dataType()); + auto resultsFW = dynamicRnn.evaluate({x, WxFW, WhFW, bFW, h0FW, maxTimeStep}, {timeMajor}); hFW->assign(resultsFW->at(0)); // [time x bS x numUnitsFW] or [bS x time x numUnitsFW] hFWFinal->assign(resultsFW->at(1)); @@ -97,17 +97,17 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) { // reverse x nd4j::ops::reverse_sequence reverse; - auto resultsIn = timeMajor ? reverse.execute({x, seqLen}, {}, {0, 1}, {}, false, x->dataType()) : reverse.execute({x, seqLen}, {}, {1, 0}, {}, false, x->dataType()); + auto resultsIn = timeMajor ? reverse.evaluate({x, seqLen}, {0, 1}) : reverse.evaluate({x, seqLen}, {1, 0}); REQUIRE_TRUE (resultsIn->status() == ND4J_STATUS_OK, 0, "dynamic_bidirectional_rnn: there is a problem with reverse on the sequence."); auto revInput = resultsIn->at(0); // backward steps - auto resultsBW = dynamicRnn.execute({revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {}, {timeMajor}, {}); + auto resultsBW = dynamicRnn.evaluate({revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {timeMajor}); auto hBWtemp = resultsBW->at(0); // [time x bS x numUnitsBW] or [ bS x time xnumUnitsBW] hBWFinal->assign(resultsBW->at(1)); // reverse hBWtemp - auto resultsOut = timeMajor ? reverse.execute({hBWtemp, seqLen}, {}, {0, 1}, {}) : reverse.execute({hBWtemp, seqLen}, {}, {1, 0}, {}); + auto resultsOut = timeMajor ? 
reverse.evaluate({hBWtemp, seqLen}, {0, 1}) : reverse.evaluate({hBWtemp, seqLen}, {1, 0}); hBW->assign(resultsOut->at(0)); delete resultsOut; diff --git a/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp b/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp index 878c4c0a3..cc88fb46c 100644 --- a/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp @@ -48,7 +48,7 @@ namespace ops { auto conv = ArrayUtils::toLongVector(*block.getIArguments()); - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), shape::order(in), conv); + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(in), shape::order(in), conv); return SHAPELIST(newShape); } diff --git a/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp b/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp index a7b521601..cfa2fe53a 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp @@ -51,7 +51,7 @@ namespace helpers { throw std::runtime_error("multiUnique: cannot execute concat op properly."); nd4j::ops::unique opUnique; - auto uResult = opUnique.execute({&arrayFull}, {}, {}, {}); + auto uResult = opUnique.evaluate({&arrayFull}); if (Status::OK() != uResult->status()) throw std::runtime_error("multiUnique: cannot execute unique op properly."); diff --git a/libnd4j/include/ops/declarable/impl/BooleanOp.cpp b/libnd4j/include/ops/declarable/impl/BooleanOp.cpp index 436cddda3..ea4838f01 100644 --- a/libnd4j/include/ops/declarable/impl/BooleanOp.cpp +++ b/libnd4j/include/ops/declarable/impl/BooleanOp.cpp @@ -36,7 +36,7 @@ namespace nd4j { return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::BOOL)); } - bool BooleanOp::evaluate(nd4j::graph::Context &block) { + bool BooleanOp::verify(nd4j::graph::Context &block) { // check if scalar or not // validation? @@ -58,11 +58,6 @@ namespace nd4j { } } - bool BooleanOp::evaluate(std::initializer_list args) { - std::vector vec(args); - return this->evaluate(vec); - } - bool BooleanOp::prepareOutputs(Context& ctx) { auto variableSpace = ctx.getVariableSpace(); @@ -120,7 +115,7 @@ namespace nd4j { return ND4J_STATUS_KERNEL_FAILURE; } - bool BooleanOp::evaluate(std::vector &args) { + bool BooleanOp::verify(const std::vector &args) { VariableSpace variableSpace; int cnt = -1; @@ -135,7 +130,7 @@ namespace nd4j { Context block(1, &variableSpace, false); block.fillInputs(in); - return this->evaluate(block); + return this->verify(block); } } } diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index 8d5cb90d4..6f26c1095 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -15,7 +15,7 @@ ******************************************************************************/ // -// Created by raver119 on 07.10.2017. 
+// @author raver119@gmail.com // #include @@ -27,6 +27,7 @@ #include #include #include +#include namespace nd4j { namespace ops { @@ -164,6 +165,9 @@ namespace nd4j { // we build list of input shapes if (ctx.isFastPath()) { for (const auto p:ctx.fastpath_in()) { + if (p == nullptr) + continue; + inSha.push_back(p->getShapeInfo()); } } else { @@ -357,6 +361,9 @@ namespace nd4j { std::vector inputTypes(block.width()); if (block.isFastPath()) { for (auto array: block.fastpath_in()) { + if (array == nullptr) + continue; + inputTypes[inT++] = array->dataType(); if (!_descriptor->checkInputMatch(cnt, array->dataType())) { auto ctype = DataTypeUtils::asString(array->dataType()); @@ -394,6 +401,9 @@ namespace nd4j { if (block.isFastPath()) { int index = 0; for (auto array: block.fastpath_out()) { + if (array == nullptr) + continue; + auto cType = array->dataType(); if (_descriptor->isSameMode()) { @@ -762,39 +772,7 @@ namespace nd4j { return ND4J_STATUS_OK; } - nd4j::ResultSet* nd4j::ops::DeclarableOp::execute(std::initializer_list inputs, std::initializer_list tArgs, std::initializer_list iArgs, std::initializer_list bArgs, bool isInplace, nd4j::DataType type) { - std::vector ins(inputs); - std::vector tas(tArgs); - std::vector ias(iArgs); - std::vector bas(bArgs); - return this->execute(ins, tas, ias, bas, isInplace, type); - } - - Nd4jStatus nd4j::ops::DeclarableOp::execute(std::initializer_list inputs, std::initializer_list outputs , std::initializer_list tArgs, std::initializer_list iArgs, std::initializer_list bArgs, bool isInplace, nd4j::DataType type) { - std::vector ins(inputs); - std::vector ous(outputs); - std::vector tas(tArgs); - std::vector ias(iArgs); - std::vector bas(bArgs); - return this->execute(ins, ous, tas, ias, bas, isInplace, type); - } - - Nd4jStatus nd4j::ops::DeclarableOp::execute(nd4j::graph::RandomGenerator& rng, std::initializer_list inputs, std::initializer_list outputs , std::initializer_list tArgs, std::initializer_list iArgs, std::initializer_list bArgs, bool isInplace, nd4j::DataType type) { - std::vector ins(inputs); - std::vector ous(outputs); - std::vector tas(tArgs); - std::vector ias(iArgs); - std::vector bas(bArgs); - return this->execute(rng, ins, ous, tas, ias, bas, isInplace, type); - } - - Nd4jStatus nd4j::ops::DeclarableOp::execute(std::vector& inputs, std::vector& outputs, std::vector& tArgs, std::vector& iArgs, std::vector& bArgs, bool isInplace, nd4j::DataType type) { - // TODO: nullptr here might be replaced - nd4j::graph::RandomGenerator rng(0, 0); - return execute(rng, inputs, outputs, tArgs, iArgs, bArgs, isInplace, type); - } - - Nd4jStatus nd4j::ops::DeclarableOp::execute(nd4j::graph::RandomGenerator& rng, std::vector& inputs, std::vector& outputs, std::vector& tArgs, std::vector& iArgs, std::vector& bArgs, bool isInplace, nd4j::DataType type) { + Nd4jStatus nd4j::ops::DeclarableOp::execute(nd4j::graph::RandomGenerator& rng, const std::vector& inputs, const std::vector& outputs, const std::vector& tArgs, const std::vector& iArgs, const std::vector& bArgs, const std::vector& dArgs, bool isInplace, nd4j::DataType type) { VariableSpace variableSpace; FlowPath fp; variableSpace.setFlowPath(&fp); @@ -838,12 +816,124 @@ namespace nd4j { for (int e = 0; e < bArgs.size(); e++) block.getBArguments()->push_back(static_cast(bArgs.at(e))); + for (int e = 0; e < dArgs.size(); e++) + block.getDArguments()->push_back(dArgs.at(e)); + Nd4jStatus result = this->execute(&block); return result; } - nd4j::ResultSet* nd4j::ops::DeclarableOp::execute(const 
std::vector<NDArray*>& inputs, const std::vector<double>& tArgs, const std::vector<Nd4jLong>& iArgs, const std::vector<bool>& bArgs, bool isInplace, nd4j::DataType type) {
+    Nd4jStatus DeclarableOp::execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs) {
+        return execute(inputs, outputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
+    }
+
+    template <>
+    Nd4jStatus DeclarableOp::execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, std::initializer_list<double> tArgs) {
+        std::vector<double> realArgs(tArgs);
+        return execute(inputs, outputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
+    }
+
+    template <>
+    Nd4jStatus DeclarableOp::execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, std::initializer_list<float> tArgs) {
+        std::vector<double> realArgs;
+        for (auto v:tArgs)
+            realArgs.emplace_back(v);
+
+        return execute(inputs, outputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
+    }
+
+    template <>
+    Nd4jStatus DeclarableOp::execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, std::initializer_list<Nd4jLong> iArgs) {
+        std::vector<Nd4jLong> realArgs(iArgs);
+        return execute(inputs, outputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
+    }
+
+    template <>
+    Nd4jStatus DeclarableOp::execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, std::initializer_list<int> iArgs) {
+        std::vector<Nd4jLong> realArgs;
+        for (auto v:iArgs)
+            realArgs.emplace_back(v);
+
+        return execute(inputs, outputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
+    }
+
+    template <>
+    Nd4jStatus DeclarableOp::execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, std::initializer_list<bool> bArgs) {
+        std::vector<bool> realArgs(bArgs);
+        return execute(inputs, outputs, std::vector<double>(), std::vector<Nd4jLong>(), realArgs, std::vector<nd4j::DataType>());
+    }
+
+    Nd4jStatus DeclarableOp::execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs, const std::vector<nd4j::DataType> &dArgs, bool isInplace) {
+        Context ctx(1);
+
+        for (int e = 0; e < inputs.size(); e++) {
+            if (inputs[e] == nullptr)
+                break;
+
+            ctx.setInputArray(e, inputs[e]);
+        }
+
+        for (int e = 0; e < outputs.size(); e++) {
+            if (outputs[e] == nullptr)
+                break;
+
+            ctx.setOutputArray(e, outputs[e]);
+        }
+
+        if (isInplace)
+            ctx.markInplace(isInplace);
+
+        ctx.setIArguments(iArgs);
+        ctx.setTArguments(tArgs);
+        ctx.setBArguments(bArgs);
+        ctx.setDArguments(dArgs);
+
+        return execute(&ctx);
+    }
+
+    nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray*> &inputs) {
+        return evaluate(inputs, std::vector<double>(), std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
+    }
+
+    template <>
+    nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray*> &inputs, std::initializer_list<int> iArgs) {
+        std::vector<Nd4jLong> realArgs;
+        for (auto v:iArgs)
+            realArgs.emplace_back(v);
+
+        return evaluate(inputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
+    }
+
+    template <>
+    nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray*> &inputs, std::initializer_list<Nd4jLong> iArgs) {
+        std::vector<Nd4jLong> realArgs(iArgs);
+        return evaluate(inputs, std::vector<double>(), realArgs, std::vector<bool>(), std::vector<nd4j::DataType>());
+    }
+
+    template <>
+    nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray*> &inputs, std::initializer_list<float> tArgs) {
+        std::vector<double> realArgs;
+        for (auto v:tArgs)
+            realArgs.emplace_back(v);
+
+        return evaluate(inputs, realArgs, std::vector<Nd4jLong>(), std::vector<bool>(), std::vector<nd4j::DataType>());
+    }
+
+    template <>
+    nd4j::ResultSet *DeclarableOp::evaluate(const std::vector<NDArray*> &inputs, std::initializer_list<double> tArgs) {
+        std::vector<double> realArgs(tArgs);
+        return evaluate(inputs, realArgs,
std::vector(), std::vector(), std::vector());; + } + + template <> + nd4j::ResultSet *DeclarableOp::evaluate(const std::vector &inputs, std::initializer_list bArgs) { + std::vector realArgs(bArgs); + return evaluate(inputs, std::vector(), std::vector(), realArgs, std::vector());; + } + + nd4j::ResultSet *DeclarableOp::evaluate(const std::vector &inputs, const std::vector &tArgs, const std::vector &iArgs, const std::vector &bArgs, const std::vector &dArgs, bool isInplace) { VariableSpace variableSpace; //ResultSet arrayList; FlowPath fp; @@ -862,21 +952,23 @@ namespace nd4j { } Context block(1, &variableSpace, false); - block.setDataType(0, type); + block.setDataType(0, nd4j::DataType::FLOAT32); block.fillInputs(in); block.markInplace(isInplace); - // block.setRNG(ProviderRNG::getInstance().getRNG()); + // block.setRNG(ProviderRNG::getInstance().getRNG()); for (int e = 0; e < tArgs.size(); e++) block.getTArguments()->emplace_back(tArgs.at(e)); - for (int e = 0; e < iArgs.size(); e++) block.getIArguments()->emplace_back(iArgs.at(e)); for (int e = 0; e < bArgs.size(); e++) block.getBArguments()->push_back(bArgs.at(e)); + for (int e = 0; e < dArgs.size(); e++) + block.getDArguments()->push_back(dArgs.at(e)); + Nd4jStatus status = this->execute(&block); auto arrayList = new ResultSet(); if (isInplace) @@ -907,7 +999,8 @@ namespace nd4j { } nd4j::ResultSet* nd4j::ops::DeclarableOp::execute(const nd4j::OpArgsHolder& holder, bool isInplace) { - return execute(holder.getInArrs(), holder.getTArgs(), holder.getIArgs(), holder.getBArgs(), isInplace, nd4j::DataType::DOUBLE); + // FIXME: add DArgs to OpArgsHolder + return evaluate(holder.getInArrs(), holder.getTArgs(), holder.getIArgs(), holder.getBArgs(), std::vector(), isInplace); } Nd4jStatus nd4j::ops::DeclarableOp::validateInputDimensionsMatch(Context& block) { diff --git a/libnd4j/tests_cpu/layers_tests/AttentionTests.cpp b/libnd4j/tests_cpu/layers_tests/AttentionTests.cpp index 102d9d313..aa9d941ea 100644 --- a/libnd4j/tests_cpu/layers_tests/AttentionTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/AttentionTests.cpp @@ -43,7 +43,7 @@ TEST_F(AttentionTests, basic_dot_product_attention) { auto queries = NDArrayFactory::create('c', {10, 4, 1}); nd4j::ops::dot_product_attention op; - auto result = op.execute({&queries, &keys, &values}, {}, {1, 0}, {}); + auto result = op.evaluate({&queries, &keys, &values}, {1, 0}); ASSERT_EQ(Status::OK(), result->status()); delete result; @@ -71,7 +71,7 @@ TEST_F(AttentionTests, basic_dot_product_attention_with_weights) { auto queries = NDArrayFactory::create('c', {10, 4, 1}); nd4j::ops::dot_product_attention op; - auto result = op.execute({&queries, &keys, &values}, {}, {1, 1}, {}); + auto result = op.evaluate({&queries, &keys, &values}, {1, 1}); ASSERT_EQ(Status::OK(), result->status()); delete result; @@ -85,7 +85,7 @@ TEST_F(AttentionTests, basic_dot_product_attention_with_mask) { mask.assign(1.); nd4j::ops::dot_product_attention op; - auto result = op.execute({&queries, &keys, &values, &mask}, {}, {1, 0}, {}); + auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0}); ASSERT_EQ(Status::OK(), result->status()); delete result; @@ -117,7 +117,7 @@ TEST_F(AttentionTests, multi_head_input_dot_product_attention_with_mask) { mask.assign(1.); nd4j::ops::dot_product_attention op; - auto result = op.execute({&queries, &keys, &values, &mask}, {}, {1, 0}, {}); + auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0}); ASSERT_EQ(Status::OK(), result->status()); delete result; @@ -153,7 +153,7 @@ 
TEST_F(AttentionTests, basic_multi_head_dot_product_attention) { auto Wo = NDArrayFactory::create('c', {2* 3, 4}); nd4j::ops::multi_head_dot_product_attention op; - auto result = op.execute({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo}, {}, {1, 0}, {}); + auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo}, {1, 0}); ASSERT_EQ(Status::OK(), result->status()); delete result; @@ -197,7 +197,7 @@ TEST_F(AttentionTests, basic_multi_head_dot_product_attention_with_mask) { nd4j::ops::multi_head_dot_product_attention op; - auto result = op.execute({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &mask}, {}, {1, 0}, {}); + auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &mask}, {1, 0}); ASSERT_EQ(Status::OK(), result->status()); delete result; diff --git a/libnd4j/tests_cpu/layers_tests/BackpropTests.cpp b/libnd4j/tests_cpu/layers_tests/BackpropTests.cpp index c0b777f29..88bd9b286 100644 --- a/libnd4j/tests_cpu/layers_tests/BackpropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BackpropTests.cpp @@ -37,7 +37,7 @@ TEST_F(BackpropTests, Test_Add_1) { NDArray e('c', {2, 3, 4}, nd4j::DataType::FLOAT32); nd4j::ops::add_bp op; - auto result = op.execute({&x, &y, &e}, {}, {}, {}); + auto result = op.evaluate({&x, &y, &e}); ASSERT_EQ(Status::OK(), result->status()); diff --git a/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp index 3cf9eeb04..38aada40f 100644 --- a/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp @@ -38,7 +38,7 @@ TEST_F(BooleanOpsTests, LtTest_1) { nd4j::ops::lt_scalar op; - ASSERT_TRUE(op.evaluate({x, y})); + ASSERT_TRUE(op.verify({x, y})); delete x; delete y; @@ -51,7 +51,7 @@ TEST_F(BooleanOpsTests, LtTest_2) { nd4j::ops::lt_scalar op; - ASSERT_FALSE(op.evaluate({x, y})); + ASSERT_FALSE(op.verify({x, y})); delete x; delete y; @@ -62,7 +62,7 @@ TEST_F(BooleanOpsTests, Is_non_decreasing_1) { nd4j::ops::is_non_decreasing op; - ASSERT_TRUE(op.evaluate({&x})); + ASSERT_TRUE(op.verify({&x})); } @@ -71,7 +71,7 @@ TEST_F(BooleanOpsTests, Is_non_decreasing_2) { nd4j::ops::is_non_decreasing op; - ASSERT_FALSE(op.evaluate({&x})); + ASSERT_FALSE(op.verify({&x})); } @@ -80,7 +80,7 @@ TEST_F(BooleanOpsTests, Is_strictly_increasing_1) { nd4j::ops::is_strictly_increasing op; - ASSERT_TRUE(op.evaluate({&x})); + ASSERT_TRUE(op.verify({&x})); } @@ -89,7 +89,7 @@ TEST_F(BooleanOpsTests, Is_strictly_increasing_2) { nd4j::ops::is_strictly_increasing op; - ASSERT_FALSE(op.evaluate({&x})); + ASSERT_FALSE(op.verify({&x})); } @@ -98,7 +98,7 @@ TEST_F(BooleanOpsTests, Is_strictly_increasing_3) { nd4j::ops::is_strictly_increasing op; - ASSERT_FALSE(op.evaluate({&x})); + ASSERT_FALSE(op.verify({&x})); } TEST_F(BooleanOpsTests, Is_strictly_increasing_5) { @@ -107,7 +107,7 @@ TEST_F(BooleanOpsTests, Is_strictly_increasing_5) { nd4j::ops::is_strictly_increasing op; - ASSERT_TRUE(op.evaluate({&x})); + ASSERT_TRUE(op.verify({&x})); } TEST_F(BooleanOpsTests, Is_strictly_increasing_6) { @@ -118,7 +118,7 @@ TEST_F(BooleanOpsTests, Is_strictly_increasing_6) { nd4j::ops::is_strictly_increasing op; - ASSERT_FALSE(op.evaluate({&x})); + ASSERT_FALSE(op.verify({&x})); } TEST_F(BooleanOpsTests, Is_numeric_tensor_1) { @@ -126,7 +126,7 @@ TEST_F(BooleanOpsTests, Is_numeric_tensor_1) { nd4j::ops::is_numeric_tensor op; - ASSERT_TRUE(op.evaluate({&x})); + ASSERT_TRUE(op.verify({&x})); } TEST_F(BooleanOpsTests, test_where_1) { @@ -136,7 +136,7 @@ TEST_F(BooleanOpsTests, 
test_where_1) { nd4j::ops::choose op; - auto result = op.execute({&x, &y}, {}, {3}); + auto result = op.evaluate({&x, &y}, {3}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); diff --git a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp index 655683687..0b1daa3af 100644 --- a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp @@ -46,7 +46,7 @@ TEST_F(BroadcastableOpsTests, Test_Add_1) { exp.applyBroadcast(broadcast::Add, {1}, y, exp); nd4j::ops::add op; - auto result = op.execute({&x, &y}, {}, {}, {}); + auto result = op.evaluate({&x, &y}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -73,7 +73,7 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_1) { exp.applyBroadcast(broadcast::Multiply, {1}, y, exp); nd4j::ops::multiply op; - auto result = op.execute({&x, &y}, {}, {}, {}); + auto result = op.evaluate({&x, &y}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -98,7 +98,7 @@ TEST_F(BroadcastableOpsTests, Test_SquaredSubtract_1) { nd4j::ops::squaredsubtract op; - auto result = op.execute({&x, &y}, {}, {}, {}); + auto result = op.evaluate({&x, &y}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -117,7 +117,7 @@ TEST_F(BroadcastableOpsTests, Test_ScalarBroadcast_1) { auto exp = NDArrayFactory::create('c', {1,3}, {1, 0, -1}); nd4j::ops::subtract op; - auto result = op.execute({&x, &y}, {}, {}, {}); + auto result = op.evaluate({&x, &y}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -136,7 +136,7 @@ TEST_F(BroadcastableOpsTests, Test_ScalarBroadcast_2) { auto exp = NDArrayFactory::create('c', {1,3}, {1, 2, 3}); nd4j::ops::add op; - auto result = op.execute({&x, &y}, {}, {}, {}); + auto result = op.evaluate({&x, &y}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -155,7 +155,7 @@ TEST_F(BroadcastableOpsTests, Test_Maximum_1) { auto exp = NDArrayFactory::create('c', {2, 3}, {2, 2, 2, 2, 3, 2}); nd4j::ops::maximum op; - auto result = op.execute({&x, &row}, {}, {}, {}); + auto result = op.evaluate({&x, &row}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -173,7 +173,7 @@ TEST_F(BroadcastableOpsTests, Test_Minimum_1) { auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 1, 1, 1, 1}); nd4j::ops::minimum op; - auto result = op.execute({&x, &col}, {}, {}, {}); + auto result = op.evaluate({&x, &col}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -281,7 +281,7 @@ TEST_F(BroadcastableOpsTests, Test_Scalar_Add_1) { auto exp = NDArrayFactory::create('c', {2, 2}, {3, 4, 5, 6}); nd4j::ops::add op; - auto result = op.execute({&x, &y}, {}, {}, {}); + auto result = op.evaluate({&x, &y}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -331,7 +331,7 @@ TEST_F(BroadcastableOpsTests, Test_Subtract_2) { auto e = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); nd4j::ops::subtract op; - auto result = op.execute({&x, &y}, {}, {}, {}); + auto result = op.evaluate({&x, &y}); auto z = result->at(0); ASSERT_TRUE(e.equalsTo(z)); @@ -509,7 +509,7 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_7) { auto e = NDArrayFactory::create('c', {1}, {8.f}); nd4j::ops::multiply op; - auto result = op.execute({&x, &y}, {}, {}, {}); + auto result = op.evaluate({&x, &y}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -525,7 +525,7 @@ TEST_F(BroadcastableOpsTests, Test_Multiply_8) { auto e = NDArrayFactory::create('c', {1, 1}, {8.f}); nd4j::ops::multiply op; - auto result = 
op.execute({&x, &y}, {}, {}, {}); + auto result = op.evaluate({&x, &y}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -544,7 +544,7 @@ TEST_F(BroadcastableOpsTests, broadcast_add_1) { NDArray exp('c', {1,4}, {2,3,4,5}, nd4j::DataType::DOUBLE); nd4j::ops::add op; - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + auto status = op.execute({&x, &y}, {&z}); ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_TRUE(z.equalsTo(exp)); @@ -559,7 +559,7 @@ TEST_F(BroadcastableOpsTests, broadcast_equals_1) { NDArray exp('c', {3,4}, {0,0,0,0, 1,1,1,1, 1,1,1,1}, nd4j::DataType::BOOL); nd4j::ops::equals op; - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + auto status = op.execute({&x, &y}, {&z}); // z.printIndexedBuffer(); ASSERT_EQ(ND4J_STATUS_OK, status); @@ -603,7 +603,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_3) { NDArray e = NDArrayFactory::create('c', {1, 0, 2});; nd4j::ops::maximum op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}); ASSERT_EQ(Status::OK(), result->status()); @@ -622,7 +622,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_4) { NDArray e = NDArrayFactory::create('c', {1, 0, 2});; nd4j::ops::maximum op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}); ASSERT_EQ(Status::OK(), result->status()); @@ -641,7 +641,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_5) { NDArray e = NDArrayFactory::create('c', {1, 0, 2});; nd4j::ops::realdiv op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}); ASSERT_EQ(Status::OK(), result->status()); @@ -660,7 +660,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_6) { NDArray e = NDArrayFactory::create('c', {1, 0, 2});; nd4j::ops::realdiv op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}); ASSERT_EQ(Status::OK(), result->status()); @@ -679,7 +679,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_7) { NDArray e = NDArrayFactory::create('c', {1, 0, 2, 0});; nd4j::ops::realdiv op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}); ASSERT_EQ(Status::OK(), result->status()); @@ -715,7 +715,7 @@ TEST_F(BroadcastableOpsTests, broadcast_bool_empty_2) { nd4j::ops::greater op; - auto result = op.execute({&x, &y}, {}, {}, {}); + auto result = op.evaluate({&x, &y}); auto z = result->at(0); @@ -741,7 +741,7 @@ TEST_F(BroadcastableOpsTests, broadcast_bool_1) { nd4j::ops::greater op; - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + auto status = op.execute({&x, &y}, {&z}); ASSERT_EQ(ND4J_STATUS_OK, status); diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp index 9aafe869e..a6b99f976 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -140,7 +140,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_2) { input.linspace(1); nd4j::ops::conv2d op; - auto result = op.execute({&input, &weights}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}); + auto result = op.evaluate({&input, &weights}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -172,7 +172,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_3) { weights.linspace(0.1, 0.1); nd4j::ops::conv2d op; - auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 
dataFormat}); auto output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -201,7 +201,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_4) { weights.linspace(0.1, 0.1); nd4j::ops::conv2d op; - auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -231,7 +231,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_5) { weights.permutei({2,3,1,0}); nd4j::ops::conv2d op; - auto results = op.execute({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto output = results->at(0); // output->printIndexedBuffer(); @@ -250,7 +250,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_6) { auto weights = NDArrayFactory::create('c', {1, 2, 12, 2}); nd4j::ops::conv2d op; - auto result = op.execute({&input, &weights}, {}, {-1,-1, 1,1, 0,0, 1,1, 1,1}); + auto result = op.evaluate({&input, &weights}, {}, {-1,-1, 1,1, 0,0, 1,1, 1,1}); ASSERT_EQ(Status::OK(), result->status()); delete result; @@ -271,7 +271,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_7) { weights = 3.; nd4j::ops::conv2d op; - auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto* output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -305,7 +305,7 @@ TEST_F(ConvolutionTests1, conv2d_8) { 1.764169, 2.584944, 2.521004, 1.744296, 1.707578, 2.237938, 2.325231, 0.984485, 1.766936, 1.590640, 1.347524, 1.404648, 1.422042, 1.709862, 1.155412}); nd4j::ops::conv2d op; - auto results = op.execute({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights, &bias}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto output = results->at(0); // output->printBuffer(); @@ -419,7 +419,7 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_2) { nd4j::ops::sconv2d op; - auto resultFF = op.execute({&input, &weightsD, &weightsP}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}, {}); + auto resultFF = op.evaluate({&input, &weightsD, &weightsP}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}); auto z = resultFF->at(0); //z->printShapeInfo("FF shape"); @@ -452,8 +452,8 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_3) { auto expOutput = NDArrayFactory::create('c', {3, 2, 8, 8}); nd4j::ops::sconv2d op; - Nd4jStatus status = op.execute({&input, &weightsD, &weightsP, &bias}, {&output}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0}, {}); - auto result = op.execute({&input, &weightsD, &weightsP, &bias}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0}, {}); + Nd4jStatus status = op.execute({&input, &weightsD, &weightsP, &bias}, {&output}, {1, 1, 1, 1, 0, 0, 1, 1, 0}); + auto result = op.evaluate({&input, &weightsD, &weightsP, &bias}, {1, 1, 1, 1, 0, 0, 1, 1, 0}); auto z = result->at(0); @@ -493,7 +493,7 @@ TEST_F(ConvolutionTests1, sconv2d_4) { 0.962232, 0.980667, 1.623775, 1.417320, 1.845710, 1.237095, 1.762792, 1.352515}); nd4j::ops::sconv2d op; - auto results = op.execute({&input, &weightsD, &weightsP, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weightsD, &weightsP, &biases}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 
dataFormat}); auto* output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -531,7 +531,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2D_BP_Bias_1) { nd4j::ops::conv2d_bp op; - auto results = op.execute({&input, &weights, &bias, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}, {}); + auto results = op.evaluate({&input, &weights, &bias, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}, {}); ASSERT_TRUE(results->size() == 3); @@ -581,7 +581,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2D_BP_NoBias_1) { nd4j::ops::conv2d_bp op; - auto results = op.execute({&input, &weights, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}, {}); + auto results = op.evaluate({&input, &weights, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}, {}); ASSERT_TRUE(results->size() == 2); @@ -664,7 +664,7 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_conv2d_1) { input.linspace(1); nd4j::ops::sconv2d op; - auto resultFF = op.execute({&input, &weightsD}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {}); + auto resultFF = op.evaluate({&input, &weightsD}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {}); auto z = resultFF->at(0); @@ -674,7 +674,7 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_conv2d_1) { nd4j::ops::conv2d op2d; // weightsP.printShapeInfo(); - auto result2D = op2d.execute({z, &weightsP}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); + auto result2D = op2d.evaluate({z, &weightsP}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); auto z2d = result2D->at(0); // z2d->printBuffer(); @@ -717,7 +717,7 @@ TEST_F(ConvolutionTests1, TestDeconv_bp_1) { nd4j::ops::deconv2d_bp op; - auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -771,7 +771,7 @@ TEST_F(ConvolutionTests1, TestDeconv_bp_2) { nd4j::ops::deconv2d_bp op; - auto result = op.execute({&input, &weights, &bias, &epsilon}, {}, {2, 2, 1, 1, 0, 0, 2, 2, 0}); + auto result = op.evaluate({&input, &weights, &bias, &epsilon}, {}, {2, 2, 1, 1, 0, 0, 2, 2, 0}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -791,7 +791,7 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) { bias.linspace(1); nd4j::ops::conv1d op; - auto result_FF = op.execute({&input, &weights, &bias}, {}, {2, 1, 0, 1, 0, 0}); + auto result_FF = op.evaluate({&input, &weights, &bias}, {}, {2, 1, 0, 1, 0, 0}); ASSERT_EQ(ND4J_STATUS_OK, result_FF->status()); @@ -805,7 +805,7 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) { auto epsilonNxt = new NDArray(z->dup()); epsilonNxt->linspace(1); - auto result_BP = op_bp.execute({&input, &weights, &bias, epsilonNxt}, {}, {2, 1, 0, 1, 0, 0}); + auto result_BP = op_bp.evaluate({&input, &weights, &bias, epsilonNxt}, {}, {2, 1, 0, 1, 0, 0}); ASSERT_EQ(ND4J_STATUS_OK, result_BP->status()); auto eps = result_BP->at(0); @@ -833,7 +833,7 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_2) { input.linspace(1); nd4j::ops::conv1d op; - auto result = op.execute({&input, &weights}, {}, {2, 1, 0, 1, 1,0}); + auto result = op.evaluate({&input, &weights}, {}, {2, 1, 0, 1, 1,0}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -860,7 +860,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_1) { weights.linspace(0.1, 0.1); nd4j::ops::conv1d op; - auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); auto output = 
results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -892,7 +892,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_2) { weights.linspace(0.1, 0.1); nd4j::ops::conv1d op; - auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); auto output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -923,7 +923,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_3) { weights.linspace(0.1, 0.1); nd4j::ops::conv1d op; - auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); auto output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -954,7 +954,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_4) { weights.linspace(0.1, 0.1); nd4j::ops::conv1d op; - auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); auto output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -985,7 +985,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_5) { weights.linspace(0.1, 0.1); nd4j::ops::conv1d op; - auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); auto output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -1016,7 +1016,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_6) { weights.linspace(0.1, 0.1); nd4j::ops::conv1d op; - auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); auto output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -1048,7 +1048,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_7) { weights.linspace(0.1, 0.1); nd4j::ops::conv1d op; - auto results = op.execute({&input, &weights}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights}, {kW, sW, pW, dW, paddingMode, dataFormat}); auto output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -1081,7 +1081,7 @@ TEST_F(ConvolutionTests1, conv1d_causal_8) { weights.linspace(0.1, 0.1); nd4j::ops::conv1d op; - auto results = op.execute({&input, &weights}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights}, {kW, sW, pW, dW, paddingMode, dataFormat}); auto output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -1129,7 +1129,7 @@ TEST_F(ConvolutionTests1, Test_Dilation2D_1) { weights.linspace(1); nd4j::ops::dilation2d op; - auto result = op.execute({&input, &weights}, {}, {1, 1,2,2,1, 1,2,2,1}); + auto result = op.evaluate({&input, &weights}, {1, 1,2,2,1, 1,2,2,1}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -1149,7 +1149,7 @@ TEST_F(ConvolutionTests1, Test_Dilation2D_2) { weights.linspace(1); nd4j::ops::dilation2d op; - auto result = op.execute({&input, &weights}, {}, {0, 1,2,2,1, 1,2,2,1}); + auto result = op.evaluate({&input, &weights}, {0, 1,2,2,1, 1,2,2,1}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -1188,7 +1188,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test1) { gradO.linspace(0.01, 0.01); 
nd4j::ops::conv2d_bp op; - auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto gradI = results->at(0); auto gradW = results->at(1); @@ -1231,7 +1231,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test2) { gradO.linspace(0.01, 0.01); nd4j::ops::conv2d_bp op; - auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto gradI = results->at(0); auto gradW = results->at(1); @@ -1276,7 +1276,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test3) { expGradW.permutei({2,3,1,0}); nd4j::ops::conv2d_bp op; - auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto gradI = results->at(0); auto gradW = results->at(1); auto gradB = results->at(2); @@ -1358,7 +1358,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test1) { gradO.linspace(0.01, 0.01); nd4j::ops::conv3dnew_bp op; - auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto gradI = results->at(0); auto gradW = results->at(1); @@ -1406,7 +1406,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test2) { gradO.linspace(0.01, 0.01); nd4j::ops::conv3dnew_bp op; - auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto gradI = results->at(0); auto gradW = results->at(1); @@ -1459,7 +1459,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test3) { expGradW.permutei({2, 3, 4, 1, 0}); nd4j::ops::conv3dnew_bp op; - auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto* gradI = results->at(0); auto* gradW = results->at(1); auto* gradB = results->at(2); @@ -1502,7 +1502,7 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test1) { gradO.linspace(0.01, 0.01); nd4j::ops::depthwise_conv2d_bp op; - auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto* gradI = results->at(0); auto* gradW = results->at(1); @@ -1540,7 +1540,7 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test2) { gradO.linspace(0.01, 0.01); nd4j::ops::depthwise_conv2d_bp op; - auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto* gradI = results->at(0); auto* gradW = results->at(1); @@ -1568,7 +1568,7 @@ 
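The bool specialization shown in the DeclarableOp hunk is presumably mirrored for double and integer initializer lists, which is what lets the converted tests drop the empty placeholder lists: an all-integer brace list binds to IArgs, while spelling out both lists still resolves to the vector overload (inputs, tArgs, iArgs). A hedged sketch with made-up conv2d shapes and parameters (kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, matching the layout used throughout these hunks):

    static void evaluate_dispatch_sketch() {
        auto input   = NDArrayFactory::create<float>('c', {2, 1, 8, 8});   // NCHW, invented shape
        auto weights = NDArrayFactory::create<float>('c', {5, 5, 1, 3});   // kH,kW,iC,oC

        nd4j::ops::conv2d op;

        // integer literals -> the integer initializer_list specialization -> IArgs
        auto r1 = op.evaluate({&input, &weights}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0});
        delete r1;

        // both lists spelled out -> the (inputs, tArgs, iArgs) overload
        auto r2 = op.evaluate({&input, &weights}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0});
        delete r2;
    }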
TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test3) { auto gradB = b.like(); nd4j::ops::depthwise_conv2d_bp op; - auto status = op.execute({&in, &w, &b, &grad}, {&gradI, &gradW, &gradB}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0}, {}); + auto status = op.execute({&in, &w, &b, &grad}, {&gradI, &gradW, &gradB}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0}); ASSERT_EQ(Status::OK(), status); } @@ -1607,7 +1607,7 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test4) { NDArray expGradB('c', {oC}, {-2960., -2970., -2980., -2990., -3000., -3010., -3020., -3030.}, nd4j::DataType::FLOAT32); nd4j::ops::depthwise_conv2d_bp op; - ResultSet* results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + ResultSet* results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); NDArray* gradI = results->at(0); NDArray* gradW = results->at(1); NDArray* gradB = results->at(2); @@ -1662,7 +1662,7 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test5) { NDArray expGradB('c', {oC}, {505., -495., -1495., -2495., -3495., -4494.999512, -5495., -6495.}, nd4j::DataType::FLOAT32); nd4j::ops::depthwise_conv2d_bp op; - ResultSet* results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + ResultSet* results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); NDArray* gradI = results->at(0); NDArray* gradW = results->at(1); NDArray* gradB = results->at(2); @@ -1706,7 +1706,7 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test6) { gradO.linspace(0.01, 0.01); nd4j::ops::depthwise_conv2d_bp op; - auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto* gradI = results->at(0); auto* gradW = results->at(1); @@ -1742,7 +1742,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test1) { weights = 1.; nd4j::ops::conv3dnew op; - auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto* output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -1774,7 +1774,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test2) { weights.linspace(0.1, 0.1); nd4j::ops::conv3dnew op; - auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto* output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -1801,7 +1801,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test3) { weights.linspace(0.1, 0.1); nd4j::ops::conv3dnew op; - auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto* output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -1827,7 +1827,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test4) { expected = 48.; nd4j::ops::conv3dnew op; - auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto results =
op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto* output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -1855,7 +1855,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test5) { bias = 1.; nd4j::ops::conv3dnew op; - auto results = op.execute({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto* output = results->at(0); // output->printIndexedBuffer(); @@ -1884,7 +1884,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test6) { weights = 0.5; nd4j::ops::conv3dnew op; - auto results = op.execute({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto* output = results->at(0); // output->printIndexedBuffer(); @@ -1915,7 +1915,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test7) { weights.permutei({2, 3, 4, 1, 0}); nd4j::ops::conv3dnew op; - auto results = op.execute({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto* output = results->at(0); // output->printIndexedBuffer(); @@ -1944,7 +1944,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test8) { weights.permutei({2, 3, 4, 1, 0}); nd4j::ops::conv3dnew op; - auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto* output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -1961,7 +1961,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test9) { auto e = NDArrayFactory::create('c', {4, 1, 7, 10, 4}); nd4j::ops::conv3dnew op; - auto result = op.execute({&x, &y}, {}, {2,5,5, 5,4,3, 0,0,0, 1,1,1, 1,1}); + auto result = op.evaluate({&x, &y}, {}, {2,5,5, 5,4,3, 0,0,0, 1,1,1, 1,1}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -1977,7 +1977,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test10) { auto exp = NDArrayFactory::create('c', {4, 1, 7, 10, 4}); nd4j::ops::conv3dnew op; - auto result = op.execute({&x, &w}, {}, {2,5,5, 5,4,3, 0,0,0, 1,1,1, 1,1}); + auto result = op.evaluate({&x, &w}, {}, {2,5,5, 5,4,3, 0,0,0, 1,1,1, 1,1}); ASSERT_EQ(Status::OK(), result->status()); ShapeList shapeList({x.shapeInfo(), w.shapeInfo()}); @@ -2039,7 +2039,7 @@ TYPED_TEST(TypedConvolutionTests1, pointwise_conv2d_test1) { bias = 1.; nd4j::ops::pointwise_conv2d op; - auto results = op.execute({&input, &weights, &bias}, {}, {dataFormat}); + auto results = op.evaluate({&input, &weights, &bias}, {}, {dataFormat}); auto* output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -2063,7 +2063,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test11) { weights = 1.; nd4j::ops::conv3dnew op; - auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto* output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -2087,7 +2087,7 @@ 
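For completeness, the split these test conversions rely on: evaluate() returns a heap-allocated ResultSet that the caller must delete, while the status-returning execute() writes into caller-provided outputs. A minimal sketch of the two call styles (shapes invented, helpers as in the tests above):

    static void evaluate_vs_execute_sketch() {
        auto x = NDArrayFactory::create<float>('c', {1, 4}, {1.f, 2.f, 3.f, 4.f});
        auto y = NDArrayFactory::create<float>(1.f);   // scalar
        auto z = x.like();                             // caller-owned output for execute()

        nd4j::ops::add op;

        // evaluate(): the op allocates the outputs, caller frees the ResultSet
        auto result = op.evaluate({&x, &y});
        // ... inspect result->at(0) ...
        delete result;

        // execute(): outputs are supplied up front, only a status comes back
        auto status = op.execute({&x, &y}, {&z});
        // expect status == ND4J_STATUS_OK; z now holds x + 1
    }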
TYPED_TEST(TypedConvolutionTests1, conv3d_test12) { weights = 1.; nd4j::ops::conv3dnew op; - auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); auto* output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -2205,7 +2205,7 @@ TEST_F(ConvolutionTests1, upsampling2d_test1) { 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f}); nd4j::ops::upsampling2d op; - auto results = op.execute({&input}, {}, {factorH, factorW, isNCHW}); + auto results = op.evaluate({&input}, {factorH, factorW, isNCHW}); auto* output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -2233,7 +2233,7 @@ TEST_F(ConvolutionTests1, upsampling2d_test2) { 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f}); nd4j::ops::upsampling2d op; - auto results = op.execute({&input}, {}, {factorH, factorW, isNCHW}); + auto results = op.evaluate({&input}, {factorH, factorW, isNCHW}); auto* output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -2271,7 +2271,7 @@ TEST_F(ConvolutionTests1, upsampling3d_test1) { 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f}); nd4j::ops::upsampling3d op; - auto results = op.execute({&input}, {}, {factorD, factorH, factorW, isNCDHW}); + auto results = op.evaluate({&input}, {factorD, factorH, factorW, isNCDHW}); auto* output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -2305,7 +2305,7 @@ TEST_F(ConvolutionTests1, upsampling3d_test2) { 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f}); nd4j::ops::upsampling3d op; - auto results = op.execute({&input}, {}, {factorD, factorH, factorW, isNCDHW}); + auto results = op.evaluate({&input}, {factorD, factorH, factorW, isNCDHW}); auto* output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -2332,7 +2332,7 @@ TEST_F(ConvolutionTests1, upsampling3d_bp_test1) { expGradI = 8.; nd4j::ops::upsampling3d_bp op; - auto results = op.execute({&input, &gradO}, {}, {isNCDHW}); + auto results = op.evaluate({&input, &gradO}, {isNCDHW}); auto* gradI = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -2359,7 +2359,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2D_input_BP_test1) { nd4j::ops::conv2d_input_bp op; - auto results = op.execute({&inputShape, &weights, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}); + auto results = op.evaluate({&inputShape, &weights, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}); ASSERT_TRUE(results->size() == 1); @@ -2424,7 +2424,7 @@ TEST_F(ConvolutionTests1, upsampling3d_bp_test3) { 4.225355, 4.377341, 4.4398847, 4.710785, 4.4199953, 3.928307, 4.8769503}, nd4j::DataType::FLOAT32); nd4j::ops::upsampling3d_bp op; - auto results = op.execute({&input, &gradO}, 
{}, {isNCDHW}); + auto results = op.evaluate({&input, &gradO}, {isNCDHW}); auto* gradI = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -2457,7 +2457,7 @@ TEST_F(ConvolutionTests1, deconv2d_test1) { weights.linspace(0.1, 0.1); nd4j::ops::deconv2d op; - auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); ASSERT_EQ(Status::OK(), results->status()); auto output = results->at(0); @@ -2490,7 +2490,7 @@ TEST_F(ConvolutionTests1, deconv2d_test2) { weights.linspace(0.1, 0.1); nd4j::ops::deconv2d op; - auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -2522,7 +2522,7 @@ TEST_F(ConvolutionTests1, deconv2d_test3) { bias = 0.2; nd4j::ops::deconv2d op; - auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); ASSERT_EQ(Status::OK(), results->status()); auto output = results->at(0); @@ -2557,7 +2557,7 @@ TEST_F(ConvolutionTests1, deconv2d_test4) { weights.permutei({2,3,1,0}); nd4j::ops::deconv2d op; - auto result = op.execute({&input, &weights}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}); + auto result = op.evaluate({&input, &weights}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}); auto z = result->at(0); // z->printShapeInfo(); @@ -2584,7 +2584,7 @@ TEST_F(ConvolutionTests1, deconv2d_test5) { weights.permutei({2,3,1,0}); nd4j::ops::deconv2d op; - auto result = op.execute({&input, &weights}, {&z}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0},{}); + auto result = op.execute({&input, &weights}, {&z}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}); ASSERT_EQ(ND4J_STATUS_OK, result); @@ -2615,7 +2615,7 @@ TYPED_TEST(TypedConvolutionTests1, deconv2d_test6) { input.linspace(1); nd4j::ops::deconv2d op; - auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); ASSERT_EQ(Status::OK(), results->status()); @@ -2640,7 +2640,7 @@ TEST_F(ConvolutionTests1, deconv2d_test7) { nd4j::ops::deconv2d op; - auto result = op.execute({&input, &weights, &bias}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 1, 0}); + auto result = op.evaluate({&input, &weights, &bias}, {1, 1, 1, 1, 0, 0, 1, 1, 1, 0}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -2683,7 +2683,7 @@ TEST_F(ConvolutionTests1, deconv2d_test8) { 1.471922, 1.484062, 1.212039, 1.144419, 1.266123}); nd4j::ops::deconv2d op; - auto results = op.execute({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&input, &weights, &bias}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); ASSERT_EQ(Status::OK(), results->status()); @@ -2718,7 +2718,7 @@ TYPED_TEST(TypedConvolutionTests1, deconv2d_tf_test1) { weights.linspace(0.1, 0.1); nd4j::ops::deconv2d_tf op; - auto results = op.execute({&outShape, &weights, &input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&outShape, &weights, &input}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto output = results->at(0); ASSERT_EQ(Status::OK(), 
results->status()); diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp index 989d316de..376234019 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp @@ -79,7 +79,7 @@ TEST_F(ConvolutionTests2, im2col_1) { image.linspace(1, 1); nd4j::ops::im2col op; - auto results = op.execute({&image}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode}); + auto results = op.evaluate({&image}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode}); auto column = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -122,7 +122,7 @@ TYPED_TEST(TypedConvolutionTests2, deconv2d_tf_test2) { weights.linspace(0.1, 0.1); nd4j::ops::deconv2d_tf op; - auto results = op.execute({&outShape, &weights, &input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto results = op.evaluate({&outShape, &weights, &input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -140,7 +140,7 @@ TYPED_TEST(TypedConvolutionTests2, Test_DeConv2D_TF_1) { auto exp = NDArrayFactory::create('c', {12, 5, 5, 32}); nd4j::ops::deconv2d_tf op; - auto result = op.execute({&input0, &input1, &input2}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 0, 1}); + auto result = op.evaluate({&input0, &input1, &input2}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 0, 1}); ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(exp, *result->at(0)); @@ -170,7 +170,7 @@ TYPED_TEST(TypedConvolutionTests2, Test_DeConv2D_TF_2) { auto exp = NDArrayFactory::create('c', {3, 8, 8, 16}, {5.98743296f, -2.83037376f, -0.87943113f, 1.41339970f, 1.32433391f, -1.20299149f, -0.02893090f, 2.05326009f, 1.19417048f, 5.58212376f, 3.28139353f, 1.19237995f, -1.09431255f, -2.55264497f, 3.11014652f, 6.81296825f, -2.09029293f, -4.32068443f, -0.52808392f, -1.97968531f, -0.18673831f, 0.84605980f, 4.55825520f, 2.71503139f, 0.15210046f, 0.85310984f, -3.82062817f, 2.76470995f, 3.69004202f, -1.45017099f, -2.59361267f, -1.35094655f, 7.24145126f, -5.25432396f, 0.19920218f, -4.30596399f, 1.35318923f, -3.88142037f, 3.67493343f, 2.25931478f, 2.87630725f, 1.66349852f, 6.21347952f, 0.94105923f, -1.61742055f, -2.35699606f, 0.12850338f, 1.79141688f, -2.09535933f, -6.35418081f, -0.06303531f, -4.38615131f, 0.48237842f, 0.26528549f, 3.38231516f, 3.76315165f, -0.40254810f, -0.23716694f, -6.13381910f, -0.41950428f, -0.89680839f, -1.46491277f, -1.98541689f, -0.99357355f, 5.58237648f, -2.38937521f, -0.00872564f, -2.37138414f, 4.91117287f, -4.51916361f, 0.97943687f, 2.91052818f, -2.50362611f, 1.70252812f, 5.04137802f, 3.57108784f, -1.87532270f, -3.66677809f, -2.38861251f, 5.55765152f, -7.27571774f, -1.68887305f, -0.72266489f, -4.42809057f, -0.92118186f, 1.02381468f, 4.44284725f, 5.17150497f, -0.42438728f, 2.02693963f, -1.36484981f, -1.47912180f, 0.26649538f, -0.02091765f, -2.86906910f, -3.03046989f, 1.35122132f, -3.21707630f, 2.21112418f, 0.24121630f, 3.96940088f, -7.66105747f, 2.76352382f, -0.99061489f, -2.16720009f, -1.63170409f, 1.12701774f, -1.02415371f, -0.90435314f, -1.51372027f, -0.76884907f, 0.39066136f, -0.89562428f, -2.03204703f, 1.28074932f, -2.14551091f, -2.36843777f, 0.46580017f, 0.75451565f, -0.00336730f, -1.06597757f, 3.27195978f, -0.41307712f, -0.10376054f, -1.34102952f, -2.22901654f, 2.31929803f, 1.40851438f, -2.23774385f, 0.20417206f, -1.12153268f, -0.13188094f, -3.96649432f, 2.10269976f, 0.49845099f, 6.18937683f, -0.51783508f, -0.48048639f, -1.92970264f, 3.16670656f, 1.13355756f, 
-0.07890664f, 1.31536257f, -0.43924797f, -0.04562932f, -0.87974954f, 0.75411212f, -2.39745235f, -3.97132111f, 0.37202546f, -2.40399146f, -1.50796390f, -3.08302689f, 0.23075986f, -0.94316757f, 1.34948587f, 0.58591264f, 2.18529797f, 7.97652435f, 2.32798409f, -4.09404373f, 0.89634895f, 0.77697754f, -0.65091681f, -7.05506849f, 5.86194515f, 2.51394033f, 4.69959354f, 0.20835471f, 3.18049693f, -1.29682434f, 3.70832396f, -0.48123091f, -1.67904007f, -1.35418940f, 1.58435583f, -1.13851106f, -1.19225955f, 0.59713769f, -5.80462933f, -7.45143986f, -1.08658695f, 1.03244078f, -1.75307107f, -7.07100582f, 3.85825157f, 1.62127817f, 2.32572675f, 0.56171900f, -0.80591971f, 3.98835945f, 0.15742642f, -2.97832179f, 0.13821673f, -0.72556758f, -0.84936106f, -7.28444147f, 3.94134307f, 0.80779338f, 7.47784615f, 8.23335075f, 4.80595016f, -4.89574575f, 4.03362942f, -6.67522192f, -4.55204487f, 2.12511182f, -2.70781207f, -1.57226098f, -3.08408356f, -0.30812448f, -5.32870674f, -5.13238287f, 0.49605465f, -0.55042171f, 0.46324944f, -3.83545256f, -0.12562510f, -0.20978995f, -0.13068712f, -1.92144060f, -1.68787408f, 5.45581436f, -0.79583496f, -2.38866687f, -3.90546346f, -0.47028148f, -0.14319679f, -3.37016582f, 2.00905991f, -1.21345615f, 1.81376505f, 7.73004007f, 0.74310112f, -4.64536428f, 3.78111577f, -9.05182457f, -0.10674095f, 1.53476238f, 0.63345337f, -0.40907967f, -1.44729769f, -1.87145400f, -2.46623540f, 1.07472968f, 0.77390999f, -3.93438888f, 4.49174690f, -0.96686655f, 1.92278123f, 0.30049133f, -0.02388665f, -1.99777114f, -3.23885751f, 5.87784004f, 2.13776040f, 3.56758308f, -3.37774134f, -3.67526293f, 1.63700044f, -1.69959962f, -0.99112594f, 6.03103638f, 1.67399430f, -1.28699589f, 7.16759014f, 12.63490295f, 3.62937450f, -4.75982571f, 2.17861104f, -2.03065681f, 4.30207729f, -0.46797156f, -2.96022511f, -6.02702332f, 3.09229851f, -1.39771092f, -0.03471333f, 3.22175527f, 5.63565636f, 1.78195477f, -0.63545251f, -3.99497652f, 1.46043062f, 4.60050488f, -2.96651959f, -2.03159475f, -1.52386189f, -0.15129802f, -3.90390921f, -0.63852370f, 0.79210538f, 2.35288715f, -5.55609035f, 5.36427498f, -0.60248077f, -0.26181316f, 5.04884720f, 8.53192806f, 5.05080223f, -6.56371737f, 1.52260923f, -7.13623667f, 6.49414349f, 2.33445597f, -4.11490965f, -6.44347477f, -0.47079402f, -0.63467920f, 2.60399365f, 1.05958164f, 3.66901422f, -1.05657935f, 1.88611507f, -6.37475634f, 2.01480770f, 3.36020517f, -5.11001921f, -0.46132171f, 2.16525555f, 4.21938848f, -2.08346295f, 2.86168146f, 1.26987600f, 6.76066971f, -7.84916353f, 4.11700916f, 0.47985530f, -4.60113716f, 7.42062473f, 6.37472820f, 4.37820530f, -7.12197018f, 0.01357239f, -7.90392113f, 8.32131577f, -0.87593079f, -0.16994858f, -5.86345863f, -0.20697471f, -1.37845206f, 1.63819647f, 1.59720242f, -0.74357712f, -1.88725603f, -1.98357940f, -8.57950306f, -4.10104513f, 3.57231879f, -2.89855957f, -0.11263305f, 2.78033924f, 1.53078973f, -2.93089223f, 0.73189604f, 3.20563078f, 3.92601013f, -5.21916151f, 0.89163935f, -0.42978728f, -6.70888853f, 4.56477976f, 1.20105875f, 3.83393812f, -6.27205181f, 4.05993128f, -7.35513067f, 1.60660768f, -1.21052051f, 1.58191252f, -1.37899971f, -1.20117283f, 2.93301678f, 1.06302834f, 1.38993621f, -1.66884089f, -3.34452581f, 1.04498529f, -4.10412455f, -4.03310585f, 1.61513603f, -1.09388447f, 2.11451387f, -0.94192362f, -0.23287666f, 5.88265705f, -0.83010495f, -2.15317154f, -0.60276151f, -1.49265075f, 3.93397975f, 5.45194483f, 1.45161700f, -2.57401872f, -5.59288931f, 4.29170895f, 1.87151814f, 0.08362055f, -0.28767288f, 1.17675185f, 0.85266006f, 1.30549634f, -5.60830832f, 
0.19398519f, -0.83982587f, 1.75940764f, -5.46077394f, 1.64495635f, 0.17102760f, -0.54459631f, -2.21975255f, -0.37443402f, -2.08474159f, 1.85959935f, 11.19680309f, -0.18611598f, -2.59765387f, 3.06330776f, -1.52183700f, -4.88415241f, -0.75097847f, 2.58201051f, 7.40885210f, 3.58994508f, 1.62457407f, 3.12514591f, -4.36833286f, 1.39830995f, 3.61003447f, -0.63837433f, -3.62661815f, 3.78898096f, 2.92802262f, 5.87374496f, -4.38554621f, -2.53411579f, -2.87311554f, -1.31391978f, -4.26736879f, 3.45099425f, 1.58769250f, 1.73341393f, -1.08842182f, 2.27120280f, -1.78938174f, -2.29940319f, 7.07046986f, 0.51426595f, -6.22928905f, 5.28968811f, 2.31827855f, -4.20915890f, -1.27249205f, 5.92120600f, 3.19458675f, 7.09252501f, 3.96577907f, 6.41484213f, -4.66009521f, 10.00181389f, 0.51108456f, -4.62243366f, -5.18351841f, 2.12961674f, 5.10694027f, 7.29412317f, 0.15912467f, -3.38902974f, -4.01918602f, -2.17383957f, 0.13118666f, 0.27872476f, -0.92317247f, 3.51440644f, 1.84171486f, 1.03378081f, 1.30569839f, -2.09583759f, 9.03952980f, -0.55187917f, -2.04549074f, 1.08294606f, -2.65263700f, -2.93977118f, 1.88909876f, 0.96043622f, 1.76579499f, 3.14314699f, 5.86394691f, 7.36944389f, -7.04524136f, 6.68673229f, -5.52591467f, -2.19745898f, -4.32036924f, 0.52971321f, 2.26268244f, 6.91575766f, -0.94590527f, -3.98923349f, -0.12266219f, 0.24294075f, -1.07783222f, 1.87989080f, -3.57109427f, 1.61553633f, 0.42486978f, 0.75852054f, -6.19481468f, -3.80570698f, 2.39946675f, -1.93851781f, -5.42234039f, -6.34092760f, -2.52374983f, -1.85044456f, 3.92693520f, 0.40042299f, 4.69742584f, 5.40483189f, -1.02398944f, 8.89605045f, 0.64680403f, 0.89943957f, 0.76993859f, -1.88244629f, 1.90714884f, 3.10836840f, -0.17064989f, 0.84892416f, -6.94988108f, 1.92141032f, -1.36458397f, 6.39284658f, 0.45201308f, 2.58823442f, 6.33375788f, -4.76916075f, -8.45738983f, -0.48962492f, 2.40652561f, 4.56602001f, -3.34420681f, 1.86862195f, -7.01420689f, -6.94657421f, -2.47419310f, -4.61693668f, -0.18822384f, -0.36949772f, 2.01374269f, 4.11018658f, -5.11564064f, 8.04294395f, 2.88567662f, -2.87645102f, -1.23238611f, -5.91409397f, -0.62205851f, 1.38689423f, -0.01120412f, 5.25955677f, -1.98474956f, -3.72012186f, 3.00445986f, 4.99141550f, 2.97457719f, 2.70827627f, 6.04544449f, -0.20756161f, -10.87035751f, 0.80454814f, 0.33568168f, -2.48132324f, -2.84452009f, 2.63126230f, -3.99351716f, -7.39294338f, 3.62798953f, -8.65815926f, 2.65992808f, -6.98126554f, 3.09881067f, 0.67735767f, -1.15946686f, 5.63180256f, -0.17694545f, -8.59651184f, 3.75297594f, -2.35913754f, -0.20330384f, 5.49958467f, 1.00861740f, 1.42849684f, 0.00062013f, -0.11073381f, 2.15207863f, 4.07368469f, 1.14344299f, -1.27953362f, 6.64699316f, -0.73672432f, -8.55606937f, -0.19439441f, -4.14319754f, -4.69964647f, -5.86446047f, 2.87106085f, -3.42714882f, -5.00668287f, 6.22464132f, -7.72335291f, 4.05667686f, -5.72637177f, 6.35073948f, -1.29593158f, 0.00813985f, 3.63368607f, -1.05764008f, -7.88486052f, 3.73919106f, 1.41835213f, -1.04935634f, 0.65119827f, 0.03547254f, 1.88996327f, 1.58701086f, -0.56215239f, -0.80187100f, 4.55604362f, -0.67249978f, 1.41084409f, 7.86281586f, -2.38301182f, -8.50535774f, -3.82098866f, -2.40856767f, -5.33439016f, -3.34747362f, 2.69389009f, -1.64118791f, 4.52447939f, 0.04468334f, -1.48768258f, -0.69848812f, -0.71123981f, 3.66259432f, 6.10314512f, 1.37305343f, -0.62758982f, -2.99383426f, 4.20510864f, 1.48497128f, -0.08954811f, 2.43872309f, -0.59880185f, 0.37431365f, 2.45458341f, -3.28401661f, -1.94629693f, -1.93975246f, -0.26385683f, -0.45814323f, -0.18108580f, -3.74811840f, -0.29739976f, 
-2.24116230f, -0.28150487f, -2.24421668f, 3.46930790f, 8.35415077f, 0.05562943f, -2.81079793f, 1.10388446f, -2.82245207f, -2.98102283f, -1.08132946f, 1.19089699f, 8.00183105f, 6.35385323f, 3.72591257f, 4.59467506f, -5.74890900f, 4.42238331f, -3.36533451f, 0.18350232f, 3.05606651f, 1.18788099f, 2.87450886f, 0.27472210f, -2.80111074f, -0.66314960f, -1.96376896f, 0.75167024f, -4.72056293f, 1.10629988f, -5.00775242f, 1.48246133f, -3.91681528f, -1.86573625f, -6.17714882f, -0.67820001f, 5.69730282f, 1.04399037f, -4.93794823f, 3.09619617f, 2.18692017f, -5.54232264f, -3.10046840f, -0.68972743f, 2.81824327f, 3.04334164f, 6.13203907f, 4.14081764f, 1.02573645f, 5.71970081f, -6.01574707f, -2.07346702f, 0.99554527f, 1.69641590f, 0.66776669f, -0.80132431f, -2.03513098f, -3.42513680f, -0.06704485f, -1.87195873f, -5.42428589f, -0.20748445f, -1.52408111f, 0.97084987f, -0.48799962f, -0.45379883f, -0.26652339f, -1.20720732f, 3.94169855f, -3.18480229f, -1.87440264f, -1.18028760f, 0.52011997f, -2.13437462f, -4.52583313f, 1.69722807f, -0.89371562f, 3.37972403f, 6.38838720f, 6.98663378f, -4.05421400f, 6.89512825f, -5.09085655f, -2.16257906f, -3.33272719f, -3.01246452f, 0.37613097f, 1.80455804f, -0.36456174f, -5.32273912f, -1.29978943f, -0.53685790f, -2.12896323f, 2.55506587f, -2.57999182f, 3.40891910f, 1.36033249f, 0.83864629f, -2.88629293f, -7.36048365f, 5.61314154f, 1.32668555f, -2.58041072f, -3.71943092f, 1.60647738f, -2.74816346f, 2.47269106f, 0.85507953f, 8.39183426f, 3.42624784f, -0.01519036f, 5.68412066f, 2.51771593f, 1.03045523f, -2.08733034f, -2.44337177f, 0.81668580f, 1.30275154f, 2.99679208f, -2.91957355f, -1.71337795f, 3.34979844f, 1.51825011f, 5.20375061f, 2.27888370f, 1.38787699f, 4.23474550f, -4.05878592f, -4.85074377f, -0.22794735f, 4.64402294f, 1.24391258f, -2.04935098f, 1.26285601f, -7.51862240f, 0.62138438f, -1.95792389f, -0.96587181f, 0.85141110f, 0.79354531f, 7.93766356f, 6.07677746f, 2.05947518f, 6.55480623f, 1.44032848f, -0.70615625f, -0.07896036f, -5.08359432f, -0.01047915f, -1.89632201f, 2.57555676f, 3.83779287f, 0.42850614f, 1.80754125f, -0.06942326f, 6.35997963f, 6.06101418f, -0.97032297f, 5.71477222f, -6.06671238f, -3.46607208f, -4.98306370f, 2.84659123f, -2.11025190f, -0.04609144f, 5.26831341f, -9.56940651f, -3.67193556f, -1.71143103f, -1.35221267f, -4.26226807f, -6.89146233f, 8.21761799f, 5.69823503f, 2.28137946f, 1.88911343f, -1.44562483f, -1.60295713f, -0.52568185f, -3.31892347f, -2.81997776f, 0.35287106f, 2.98202395f, -1.39432132f, -2.70001364f, -4.14169264f, 3.50194883f, 4.12610435f, 5.52755260f, 2.65859175f, 3.61353087f, -0.83027136f, -5.10652542f, -4.48625374f, 2.06585884f, -2.76383352f, -0.64300913f, 8.19686604f, 0.96106279f, 2.45952058f, 2.47275925f, -1.03288829f, -0.64897656f, -3.77937531f, 4.27940083f, 2.58320260f, -0.57665241f, 1.87247813f, -3.81604433f, -0.24543774f, -1.62118483f, -0.73075479f, -0.48533297f, 2.05016756f, 0.45561486f, 0.03316188f, 0.77791005f, -1.56283605f, 2.36616826f, 5.58082104f, -1.30925488f, -1.06329608f, 2.17189479f, -3.43008828f, -4.71520567f, -2.56184673f, 0.17508316f, -3.25817418f, -0.41749167f, 0.18119079f, -0.73181152f, 3.99792433f, -3.08002281f, -0.99143314f, -1.83520067f, 1.18565679f, 2.98040128f, 5.67814350f, 2.35128760f, 1.41600966f, 4.02718067f, -0.08193968f, 0.64636409f, 1.35931289f, 2.37125754f, 1.75978124f, 3.90977740f, 1.50662971f, -2.84089065f, 1.29824126f, -3.38730979f, -1.61005294f, 0.58292413f, -0.03019404f, -1.57986510f, -0.56102908f, -3.03128719f, 0.51644313f, -2.01147819f, 0.98400700f, 3.00028515f, 0.74579155f, -3.37098312f, 
0.93339360f, -1.29018497f, -2.14695001f, 1.30411184f, 0.71501279f, 7.47793055f, 4.06516457f, 3.50772929f, 3.52762985f, 0.55643129f, 0.32272506f, -4.30955982f, 2.49414706f, 2.07820845f, -0.34377906f, 4.39805031f, 2.77561307f, -3.91292810f, 2.43981409f, 0.18861845f, -2.76658440f, -4.97148752f, 3.25273705f, -0.08929539f, 0.19818619f, -5.83767605f, -0.97381884f, -5.68745661f, -5.42433214f, 3.98769903f, -0.40394354f, -1.83387578f, -0.80109525f, 1.47454357f, -3.14899540f, 0.80130816f, -2.26348829f, 4.06121159f, 6.13077354f, 5.31226397f, 2.94966197f, -3.65217376f, -1.08136678f, -7.14119816f, -0.85269439f, -0.70365787f, -0.81598872f, 3.62807679f, 3.08123684f, -7.82739496f, 4.07951784f, -0.14204243f, -0.66969109f, -5.07225513f, 2.88492823f, 0.47202343f, 0.72683257f, -6.84280777f, 0.41807127f, -5.09785986f, -3.74514675f, 2.03936672f, -1.06096244f, -1.52409148f, -0.97046643f, 2.27491093f, -1.55597985f, -1.29215479f, -0.79737484f, -0.01979581f, 7.65407991f, 5.54527044f, 4.04147148f, -2.64274883f, -1.89246953f, -3.89547634f, -1.06029689f, -2.85982800f, -1.41247237f, 1.55836034f, 3.38194537f, -2.97655582f, 0.87510300f, 1.26282072f, -1.77029657f, -3.57144690f, -4.19456863f, 0.53179169f, -1.42221975f, -3.09144497f, -0.84294832f, -5.02758694f, -2.68011904f, 0.89156240f, -0.34783912f, 4.64484835f, -2.34453487f, -1.28573155f, 0.09990287f, 0.01828218f, -1.79960847f, -1.06579173f, 1.08763921f, 0.43687880f, 3.24747229f, 3.83097172f, 1.07253766f, -1.33810723f, 0.76530832f, 1.58660865f, 5.60743904f, -3.54124737f, -0.89264417f, -3.83942485f, -1.03707337f, -1.61659896f, 1.65349591f, 1.72698796f, 4.96013832f, 0.78927267f, -0.35563886f, -3.48121166f, 3.79677629f, 2.59023166f, 2.74940348f, -2.17589283f, -5.91757107f, 2.43766379f, -4.15906048f, -1.74731481f, -2.49113035f, -0.57349741f, -4.04455185f, -1.46939647f, 2.21418452f, 0.09153593f, 2.23016739f, 7.91880608f, 4.04464149f, 0.07706618f, -2.41892862f, -2.19280314f, 7.61760712f, -5.89153862f, 0.33551922f, -1.70855618f, -0.30561331f, -0.14341974f, -2.48878574f, 1.31269515f, 3.45388412f, -0.02453184f, -0.12132037f, -4.27916241f, 1.25179088f, 4.09455204f, -1.83801770f, -1.86743176f, -4.02864933f, 3.44515228f, -4.39244986f, -0.56988084f, -1.69426417f, 2.18254852f, -4.78135824f, 1.73193693f, -2.27968478f, -1.49523509f, 2.51696730f, 4.03677559f, -2.03679037f, 1.32167840f, -2.22570705f, -2.74843621f, 6.29655170f, -3.67230225f, -1.86765468f, -0.14842367f, -1.21552539f, -0.92038238f, -0.51692355f, 1.08433771f, -0.01929832f, 0.15660909f, 2.31432915f, -3.86507082f, -0.69797570f, 0.13505173f, -1.50951028f, -0.69980979f, -1.51297045f, 3.63725281f, 0.13388813f, 2.73131752f, -0.96528149f, 4.92000961f, -5.92699385f, 1.69444644f, -1.17121375f, -2.33710480f, 1.35302818f, 1.39608085f, 1.68293881f, 0.94960749f, 1.89011908f, -4.08865070f, 0.13722643f, -1.62849212f, -0.19044125f, 1.37906075f, -3.92504406f, -1.45033538f, -0.42085981f, 3.38237071f, -3.06508875f, -1.39420545f, 1.13067436f, 0.92206454f, 0.49917889f, -2.74508023f, -2.19221997f, 1.77914095f, 0.10854459f, -2.62178278f, 2.35042715f, -0.15322030f, -0.67014873f, -1.75627899f, 2.64074945f, 2.76339936f, 2.67275214f, -0.62736398f, 0.58251178f, -4.64895678f, 5.50419283f, 2.53566456f, -2.44196153f, -0.07845879f, -2.80389643f, -0.64810950f, -0.05813205f, 1.67155504f, -2.69673729f, -1.72486305f, -0.53888649f, 1.86805439f, -1.37128329f, -5.37923479f, -2.08133769f, 0.58187997f, -1.39498150f, 0.21874082f, 4.33726025f, 6.29673958f, 0.72312093f, -3.32683516f, 1.73482585f, -0.00766110f, -2.63785434f, -0.13511759f, 4.07195950f, 0.94139838f, 
3.15717316f, 1.53720927f, 1.87664819f, -2.33655119f, 6.18176556f, -2.73912525f, -2.45279956f, 2.20392370f, -0.56854641f, 0.98915887f, -2.64472580f, 2.40633702f, -4.93327999f, -1.28942823f, 0.98247659f, 1.31774998f, 0.07669818f, -5.91169453f, -0.43135011f, 1.27404964f, -0.59787154f, -0.22716975f, 0.74409103f, 10.27316475f, -2.29192710f, -2.19403267f, 3.78925133f, 3.19553399f, -4.42490482f, -0.80781460f, 2.16568565f, -2.54165983f, 2.54885101f, 4.18779039f, 1.73079813f, -1.48891807f, 11.60153770f, -0.98686743f, -2.88813901f, 2.32898521f, -0.36101711f, 2.34522438f, 0.29057693f, 1.39800644f, -4.31848240f, -3.21217132f, 0.11740226f, -1.21613467f, 0.57248503f, -4.44853830f, 1.54665899f, 3.14459944f, 1.76809108f, 0.26693153f, 0.86913753f, 9.47121620f, -2.07677889f, 2.08578467f, 1.30181742f, 1.58683562f, -3.52757788f, -1.32763624f, 0.79821301f, -2.19358301f, 1.17707348f, 6.01983643f, 4.11209440f, -2.04209709f, 7.00413418f, -1.84904683f, -1.32542288f, -0.01298118f, 0.70377320f, 0.27815005f, 2.07879829f, -0.71606725f, -4.94399881f, -2.11898828f, -0.39051518f, -2.21034360f, 3.05337906f, -1.56889665f, 1.97065282f, 2.61320901f, -0.34063196f, -0.57001418f, -2.13183641f, 3.48879004f, -0.12067288f, 0.48568326f, -1.81424558f, 2.28868723f, 1.44802380f, 1.25918829f, -1.76415455f, 5.35742331f, 3.50682044f, 4.71371317f, 5.89110756f, 8.51241302f, 4.07391453f, -0.05887252f, -0.18202400f, 2.27119660f, 6.78274727f, -2.87470293f, -5.14336634f, 0.76443815f, 2.04625130f, -0.43199503f, -1.01353514f, 2.42951298f, 2.35641170f, 0.32345510f, -4.04195738f, -4.77967072f, 0.26564783f, 6.11455107f, -2.53868008f, -3.11839914f, -1.04203856f, 5.17195654f, -4.15338612f, -3.84149241f, 0.48130888f, 3.09706950f, -4.18423653f, 5.26233864f, 3.55831861f, 3.75122595f, 8.14969349f, 6.80038738f, 4.68907356f, -1.40135396f, -3.19287133f, -3.15895939f, 8.77363205f, -4.48793411f, -3.80537176f, -2.40145254f, -2.74341679f, -2.02862644f, 5.33402443f, 9.25365734f, 2.50246119f, 0.32847846f, -1.50564361f, -4.26163197f, -1.40994716f, 2.50708485f, 0.44500345f, -0.62516934f, 4.09846306f, 5.29355669f, -4.02224922f, 0.73442125f, 0.46648952f, 0.67028689f, -6.30715466f, 6.56297970f, 3.80854273f, -5.19078207f, 4.98839283f, 7.59161472f, 0.46010983f, -2.10227895f, 0.29324162f, -2.67019558f, 4.57838106f, -3.02338457f, -3.08647728f, -2.00112700f, -3.81710315f, -0.08346784f, 1.69288683f, 5.68807268f, 3.29351830f, 0.54618967f, 1.83540761f, -5.38810253f, 0.51326782f, 4.40081882f, -4.03805828f, 0.49482727f, -1.36024392f, 2.91845679f, -2.00959015f, 2.47489738f, -1.43354976f, 1.92024410f, -6.55897284f, 1.79488957f, -0.89570928f, -6.13094234f, -0.45504010f, 2.35239482f, 1.29039919f, -4.78849840f, -1.52545333f, -6.50420475f, 2.99257326f, -0.55620033f, 0.26807702f, -2.52090979f, -4.59419632f, 0.57965040f, 2.19423151f, 2.04760551f, -0.57048106f, -2.20812702f, -0.04777686f, 1.38053393f, -2.71448946f, -1.06219673f, -3.62008905f, 1.85719645f, 1.28355026f, -2.76315832f, 1.65295160f, -4.01645803f, -3.10454416f, -0.65713316f, 1.22384977f, -0.70416176f, 4.45064926f, 1.31602776f, 2.06907344f, 2.48872757f, 4.25775290f, 3.50504255f, -0.68262041f, 1.29799378f, -1.01969171f, 2.98593879f, 0.12607655f, 0.37219539f, -0.84196299f, -3.80019331f, -1.82315290f, -0.38489276f, -1.45200360f, -4.00882292f, 0.61042011f, -0.16738498f, 1.33787775f, -2.26938057f, 1.03656030f, 8.89089870f, -1.60370600f, -5.38691807f, 5.72182989f, 2.72854710f, -6.18535757f, -3.13408709f, 2.79175353f, 5.18425512f, 9.46434212f, 2.40110517f, 1.11330092f, -3.57366538f, 4.80967665f, 0.40691876f, -3.65484858f, 
0.92398167f, 2.53852940f, 3.17747331f, 2.14199781f, -1.69107199f, -1.91864693f, -3.18452644f, -2.42408276f, -2.14332366f, -1.35526609f, -4.50732136f, 0.58234072f, -1.81547785f, 0.57311213f, 1.10584176f, -0.97226644f, 11.73174381f, -2.00559855f, -1.81175601f, 2.33131361f, 0.49264961f, -0.42245382f, -1.37528467f, 1.55768061f, 0.21152198f, 13.08896351f, 10.33674145f, 5.77929306f, -6.19886398f, 5.67007637f, -6.61288071f, -2.58029866f, -4.05192375f, 1.77221894f, 0.29821560f, 5.23508501f, -5.09560966f, -0.97536200f, -5.17957878f, 1.02876794f, -4.52072096f, 2.22126532f, -4.81708670f, 0.44538212f, -2.30738068f, 3.15900373f, -4.99227905f, 0.82632786f, 9.65415478f, -0.63819492f, -3.25479436f, -0.13276935f, 0.21337092f, -2.22116399f, -3.04922724f, 0.65568435f, -0.10706246f, 4.58047390f, 7.80782652f, 5.49080181f, -3.97114491f, 6.43327618f, -6.54772758f, -2.10962629f, -0.79831678f, -0.08316499f, 2.48658133f, 4.14070511f, -0.59806836f, -4.58636141f, -0.31166920f, 0.31757897f, -3.92562199f, 0.65357721f, 0.55871534f, 1.71843934f, 1.62395024f, 0.00695819f, -4.56716251f, -3.76420808f, 4.24979544f, -0.86128616f, 0.23126510f, -6.32968998f, 1.83346081f, 3.81335950f, 2.98407745f, -1.80454743f, 6.61764765f, -1.39372075f, -0.86780751f, 7.24317265f, 2.24205112f, 1.05702817f, 0.55431479f, -1.54557061f, 3.36389136f, 4.70898724f, 1.11327887f, -3.78462076f, -3.63381767f, 2.86510396f, 0.74203897f, 0.81488025f, 3.54250598f, 3.24824381f, 3.19000244f, -0.58995843f, -7.05670738f, 3.18306041f, 3.95191574f, 0.81820154f, -1.91068232f, -2.05426741f, -1.05589008f, -3.18377590f, -1.86278260f, -8.80374908f, 0.93416154f, -4.60517359f, 8.38999462f, 5.26356745f, -8.89992714f, 8.95298958f, 4.22590351f, 1.00351548f, -6.90151119f, -8.07641125f, -4.82450199f, 8.02293015f, 4.11661243f, 0.95457208f, -7.07843113f, -4.30524826f, 5.02697992f, 5.21011686f, 0.80132771f, 3.23420191f, 3.82452774f, -2.13171721f, -7.88879967f, 1.31062031f, 1.90848613f, -3.51572514f, -3.75684500f, 3.62577081f, -5.76075602f, -2.79389215f, 0.32598805f, -4.28981733f, 4.21048594f, -3.84532523f, 3.19815183f, -0.40756655f, -2.19974327f, 6.25655174f, 3.42396951f, -1.88986623f, -1.92803884f, -2.97344875f, -0.09756154f, 5.24342251f, -0.72513700f, 1.06113195f, -1.30720282f, 4.69107103f, 0.58984971f, 2.33985567f, 1.46385121f, 3.16576266f, 6.77769995f, -5.92685127f, -12.61141014f, -2.83663774f, 4.90253258f, -6.32688522f, -3.00096869f, 2.38634992f, -7.21459866f, -5.89208746f, 2.84085894f, -1.21792030f, 6.70161343f, -4.00450230f, 5.29881001f, -1.45574808f, 0.77542424f, 1.38336325f, -0.21572059f, -3.38088870f, 2.33249640f, 0.68824625f, -3.68440270f, 0.33481622f, -0.39239681f, 0.14560902f, 1.61039007f, -3.11967754f, 2.49372435f, 2.68783092f, -1.17559779f, 0.95257235f, 4.35451412f, -0.56818569f, -7.32110357f, -7.58534050f, -2.10573673f, -3.34446383f, -0.32183546f, -0.78525496f, -1.76974547f, 5.19060802f, -2.11319876f, -3.41755080f, -0.36864156f, 1.32680905f, 0.45004874f, 6.17223930f, -1.60707474f, 0.46096295f, -3.88852644f, 1.84729624f, -0.03412050f, 0.99224162f, -2.05553341f, 3.47793245f, -0.06305170f, 0.51314175f, -2.91650558f, -1.78121483f, -2.85465693f, 0.24649808f, -2.70376635f, 0.42334458f, -1.13862336f, -0.98409218f, -0.96593523f, 2.22128963f, 0.53402066f, 3.33979344f, 8.57430458f, 2.34217858f, -2.40062976f, 5.81624222f, 1.13290989f, -5.06850052f, -4.72865725f, 1.82859278f, 6.78569555f, 8.56885242f, 2.76462936f, 0.33891773f, -2.81092787f, 0.79498398f, -2.27208567f, 1.55182552f, 2.17166376f, 6.12517643f, 3.56859684f, 0.27685475f, -1.38408327f, -1.03533340f, -3.46618199f, 
0.79240030f, -3.89390516f, -0.55852515f, -1.16367757f, -0.07008934f, -2.20105195f, 3.81210446f, -0.66834474f, 0.43603873f, 10.92334938f, 2.48571420f, -6.34997845f, 4.23135757f, 0.45045292f, -4.13489866f, -3.92324209f, 1.88537407f, 2.57159734f, 9.90973091f, 4.37453461f, 7.34546280f, -2.51120615f, 11.12575245f, -3.23452854f, -2.49947500f, 1.39819741f, -3.78950691f, 2.40617585f, 5.10036278f, -3.55743456f, -6.42888737f, -2.51929998f, -1.90880990f, -1.81618094f, 1.60946512f, -4.09737110f, 1.96408439f, -1.90115595f, 2.44444203f, -2.31254292f, -4.01332951f, 8.65541840f, -0.58626485f, -4.02226830f, 0.43893200f, -3.78272748f, -5.46277428f, 0.01306701f, 0.61185312f, 0.24469066f, 1.30214953f, 5.87789631f, 8.75197792f, -5.31634712f, 3.43556309f, -5.90755081f, 0.54375106f, -2.48162293f, -3.51843548f, 2.55853295f, 5.06387186f, -2.09662485f, -3.00377345f, -3.21781397f, -0.14537808f, -4.65453672f, 1.92747557f, 0.41553855f, 4.09379959f, 0.83387995f, 1.50868511f, -6.54959488f, -8.38881016f, 5.50689125f, -2.88616610f, -1.21597648f, -0.23817590f, 1.50816703f, -2.26873541f, 2.29862142f, -1.61143053f, 5.97371244f, 4.71440220f, -0.20635787f, 8.85926723f, 0.56064367f, -1.04103339f, -4.47060108f, -2.63824081f, 3.06782055f, -2.07702565f, 3.38269401f, -1.59988797f, -3.80122590f, 2.35341501f, 2.69095278f, 3.87612104f, 1.89984226f, 0.95496917f, 3.14841127f, -5.84543085f, -7.24945450f, -2.65708590f, 2.87417006f, 0.97556210f, -3.75203967f, 1.55287778f, -7.43401051f, -1.29005826f, -3.40252638f, -4.01049423f, 2.82721639f, -1.21479535f, 8.54563904f, 7.39749908f, -0.61361837f, 7.60177565f, 1.65812778f, -0.83008504f, -3.60961151f, -7.69062138f, -1.26275063f, -4.17071676f, 5.28448200f, 4.04685593f, -1.18231702f, 1.15276611f, 1.58620787f, 6.75060844f, 3.29332161f, -0.67640316f, 5.78984785f, -3.14913464f, -6.41867924f, -2.58316016f, -2.04366302f, 2.01089478f, -3.81723452f, 3.63843751f, -5.13238430f, -3.79432917f, 4.86581373f, -1.06922054f, 3.95978498f, -0.78166616f, 8.35650539f, 5.35834265f, 0.35594034f, 9.41657066f, -0.84108615f, -6.54425859f, -3.44328952f, -6.55536795f, -0.08963367f, -1.53906262f, 0.17658240f, -0.13108420f, -0.44371247f, -0.78411150f, 2.64754868f, 9.66306782f, 1.70506203f, -0.31588936f, 4.31715870f, -6.16665173f, -10.43371868f, -3.72962189f, 4.35245228f, -1.75867891f, -4.20046234f, 8.62637043f, 1.45946813f, -3.30153608f, 0.85179043f, -2.66643381f, 3.01863337f, -2.52916121f, 8.35405540f, -0.37298933f, -0.89473486f, 6.88681793f, -4.46370125f, -7.50776386f, 3.80255938f, -3.55003357f, 1.43528831f, -2.20383263f, 2.34999895f, 2.03803205f, 1.94830751f, -1.85976326f, 0.97718471f, 5.53710842f, -0.80560827f, 0.23925614f, 5.98795223f, -2.03578377f, -7.77835321f, -2.79955530f, -1.88185954f, -2.49112058f, -0.76095992f, 2.71161270f, -0.55918610f, 0.83789903f, -1.42063200f, -0.61528748f, -4.18273115f, 1.76384258f, 4.21265936f, 5.50964785f, -0.93324339f, 3.83215356f, 1.52210593f, -0.91594946f, 1.31148386f, 3.20160103f, 1.24493563f, -0.72693497f, 1.84716725f, 3.09897518f, -1.34605026f, -1.17511916f, -1.05526352f, -1.08590937f, -1.41319299f, -3.75052118f, -2.67095542f, -0.76179552f, -3.32081509f, -1.04692316f, -1.30194843f, -1.98795474f, 5.01223469f, 0.21895903f, -1.85535169f, 3.12362719f, 0.16198632f, -3.86784005f, -2.03062248f, -0.15415624f, 8.22020721f, 4.83055592f, 4.50315666f, 4.19443417f, 0.42727345f, -4.67786789f, -5.18739986f, 2.53988838f, 3.19683266f, 1.80313504f, 1.94664574f, 0.59795094f, -4.21626759f, 0.50492239f, -0.41232634f, -0.99224532f, -3.94929314f, 1.74060190f, -0.92474866f, -1.00664830f, -6.17397356f, 
-1.33146775f, -3.78111315f, -4.91876888f, 2.50303864f, -0.34890354f, -1.25013232f, 0.38168997f, -1.84135628f, -4.46107960f, -4.05920792f, -2.61709857f, 0.71046209f, 9.80566883f, 6.34086990f, 2.73394704f, -2.03342366f, -2.21424174f, -5.56514263f, -4.74755144f, -2.20672894f, 0.09010231f, 1.70423889f, 3.19200158f, -6.99027634f, 1.14216340f, 0.05824995f, -0.76996505f, -6.51575899f, -0.41109252f, 0.78229940f, 1.36170781f, -5.65170193f, 1.12221193f, -4.60430050f, -4.40174437f, 4.01805925f, 0.10774946f, -2.77991009f, -0.18023163f, 0.02151692f, -1.77023101f, -1.86639869f, -0.69443607f, 4.92290831f, 6.83520412f, 4.27372265f, 6.54272366f, -7.59249687f, -1.40776849f, -3.52368808f, 1.01398587f, -3.58802676f, -0.35658866f, 1.14716864f, 3.75847244f, -2.30159235f, -0.72130895f, -0.24564353f, -1.77531350f, -3.08677864f, -0.73486501f, -1.20357263f, 0.60789430f, -3.46990204f, -0.20668676f, -5.46096087f, -5.22016764f, 0.98259866f, 1.81012678f, 3.92534304f, -2.94997001f, 1.65154219f, 2.27040243f, 0.99095678f, 0.09144652f, -0.99103236f, -1.11210847f, 0.78181303f, 2.38706732f, 2.96695375f, -0.17279971f, 0.31143007f, 1.35465562f, 2.03586054f, 6.19515753f, -3.14652419f, -2.89027119f, -3.26665854f, -1.93043876f, -0.46601450f, 1.07655203f, 1.74946189f, 4.02148342f, 0.69275337f, 0.50094581f, -4.07613230f, 2.98369169f, 4.24537849f, 0.49480581f, -2.02408123f, -2.02068973f, 6.54505825f, -5.19377470f, -0.12596917f, -0.70204186f, -0.98308045f, -3.19708824f, 1.63609934f, 1.35475993f, 0.16313422f, 4.13918924f, 7.69187021f, 3.72601676f, -1.97790039f, -1.16739464f, -3.31835508f, 8.14553452f, -1.78718984f, 1.21505618f, -3.84255409f, -3.21992350f, 0.07376552f, -0.81223297f, 3.57002878f, 1.48521733f, -0.45995998f, 0.30551746f, -3.33944130f, 1.39538884f, 1.84758544f, -0.21494150f, -2.27316713f, -4.37771225f, 6.48841667f, -5.00251961f, -0.45162797f, -5.01056004f, 0.70199943f, -4.60057783f, -2.22394514f, 0.07777429f, -1.49820781f, 3.47308421f, 6.13231564f, 1.18605387f, -4.78924608f, -3.49548388f, -2.73382568f, 6.24617863f, -2.74291611f, -1.03833354f, -2.20752788f, -2.33219409f, 1.48633552f, 1.65796840f, 4.95045471f, 2.58479190f, -0.90922785f, 0.71312457f, -4.44465590f, 1.37020862f, 2.37683725f, 0.18805164f, -3.28422308f, -1.64939332f, 3.64181972f, -3.75277281f, 3.67203593f, -0.11204052f, 2.24140930f, -3.90657187f, 2.56883717f, -1.44016707f, -2.83842611f, -0.29104578f, 2.17757058f, -0.71431804f, 1.36911654f, 0.85083604f, -1.60110259f, -1.97247636f, -1.61163378f, -0.81236130f, -0.38993555f, -3.03631902f, -0.38213277f, 0.06394482f, 3.19348621f, 0.36771113f, 1.36763072f, 2.49159527f, -0.39599860f, -2.69996762f, -0.97561121f, -2.97563028f, -0.49662948f, -0.17564940f, -2.79042959f, 0.72395414f, 2.07260203f, -0.99439794f, -2.20248008f, -0.07389921f, 0.65536159f, 4.73054695f, -0.63917702f, 0.58788192f, -3.60156059f, 6.59609890f, 3.88419437f, -3.38469863f, -3.56237841f, -2.03295064f, 0.07279694f, 3.71804547f, 0.79928309f, -2.13411403f, -1.13909864f, -0.34193408f, -1.00338125f, -1.44231665f, -5.39835978f, -0.45086145f, 1.16064668f, 2.58335257f, 2.10072684f, 4.64244223f, 7.10090065f, 1.01974952f, -4.44687223f, 2.99792576f, 1.10303724f, -1.22736573f, -3.91514421f, 3.07458854f, 2.18765211f, 3.34481716f, 2.46166849f, 2.99648619f, -0.94046807f, 5.55028200f, 0.92199719f, -0.83934361f, -0.72042274f, 0.84869325f, 1.46914721f, 0.85937387f, 4.77306223f, -4.06436539f, -2.59847593f, 2.44828081f, 0.50484699f, -2.71092367f, -6.39010477f, 0.91778028f, 3.25469685f, 1.30310678f, 1.35258150f, 3.56171441f, 7.82435083f, -2.51527429f, -4.24328852f, 2.36876059f, 
1.94595242f, -2.59290171f, -6.62389565f, 3.32567835f, 2.13659120f, 4.09299326f, 3.48293996f, 2.64965177f, -3.19157362f, 13.37204266f, -0.50297594f, -4.57448196f, 3.95582604f, -0.69038916f, 0.10098404f, 1.18737555f, 3.65761185f, -5.69623756f, -2.03357077f, 1.02868807f, -1.38448596f, -0.05690211f, -8.48874187f, 0.56755424f, 1.45485961f, 0.66273880f, 0.06495565f, 1.79539490f, 8.46864319f, -1.22696662f, -1.87585378f, -0.99768794f, 2.72801924f, -0.66980243f, -2.31924677f, 0.33271110f, 0.11666083f, 1.86980045f, 5.95332909f, 7.38583708f, -2.80956483f, 6.79227638f, -6.78070831f, 1.21884382f, -1.40695429f, 0.90236962f, -1.13695288f, 0.50760663f, 1.00955284f, -5.39029121f, 0.24987072f, 2.24283314f, -4.02145576f, 2.18057394f, -3.35627747f, 1.26061773f, 1.30342579f, 0.11311233f, -1.11199212f, -4.06509686f, 5.82649660f, -1.24059582f, 5.51652861f, -1.90937877f, 1.10658336f, -0.47065550f, -2.39167786f, -1.95931304f, 4.12717247f, 1.15396059f, 1.26015663f, 7.97836876f, 7.33633423f, 2.27785325f, -2.83802366f, -2.74850106f, 0.86126029f, 6.18781090f, -1.43707538f, -6.97134876f, -3.25486469f, -1.95214593f, 0.91066706f, 0.89637989f, 1.06481194f, 6.25791073f, 0.81779671f, -1.08384395f, -3.21191931f, 2.04216075f, 4.76030350f, -2.37217665f, -1.42571259f, -6.35876131f, 4.62536526f, -5.40060568f, -3.14868999f, -1.00587153f, 1.80662942f, -7.03201485f, 6.08373499f, 0.99862772f, 2.21717811f, 4.06814623f, 6.02428913f, 5.33422756f, -0.87013257f, -2.22477579f, -2.51505303f, 5.82925224f, -0.82854009f, -4.30698347f, -1.75007713f, 2.08352375f, -2.25235629f, 1.17517352f, 5.77717733f, 2.27472878f, 2.72778273f, -1.95411634f, -4.52602863f, 1.13983536f, 1.16340065f, -2.02740526f, -3.11290503f, -1.94906235f, 1.54855204f, -4.52984142f, 1.97465122f, -1.79415476f, 4.03510094f, -8.45349979f, 10.87430096f, 2.19863629f, -5.39083815f, 5.86213875f, 6.25744534f, 6.52600002f, -4.72149038f, -1.75254321f, -5.51459169f, 7.03155518f, -2.01889277f, -4.58441257f, -3.61226106f, 0.42395937f, -0.93263882f, 2.28703761f, 2.80611467f, 2.59498215f, 0.65989012f, -1.51268566f, -4.49465561f, -4.70453882f, 5.44696808f, -4.37603617f, 0.46670085f, 2.82488608f, 2.18854523f, -2.04817152f, 1.19557285f, 1.53618634f, 4.44758606f, -7.31593513f, 7.43966007f, -3.55480957f, -5.29834652f, 2.14622784f, 1.65194583f, 2.71262598f, -4.86145496f, 0.79726243f, -8.88541985f, 1.19627261f, 0.79660845f, -1.98016644f, 1.03741014f, -3.93128228f, 1.05535269f, 2.01378822f, -0.46086323f, -0.77754641f, -1.43942690f, 0.49809402f, -2.27861357f, -3.29815221f, 0.38201320f, -3.98481083f, 4.88261318f, -0.44555628f, -2.57224536f, 2.35001850f, -2.65835261f, -2.43422794f, -2.97889376f, 1.07349825f, 1.88157082f, 4.74075413f, 0.60376728f, -0.48894715f, -1.15800071f, 4.68110943f, -0.86976886f, 1.49192941f, 0.62665290f, 0.20652676f, 0.53916287f, -1.45706177f, 0.66133004f, 1.34405875f, -4.27689552f, -0.20838106f, -5.14266443f, -1.29718637f, -1.74506426f, -0.86022055f, -3.57553625f, 0.46880072f, -1.25287139f, 3.28596354f, 11.33191013f, 1.23942876f, -3.87616491f, 7.57880497f, -0.22940339f, -5.68512678f, -1.94969654f, 5.85449600f, 3.75705457f, 4.24395847f, 1.60086083f, 2.62553668f, -0.93964291f, 5.84753895f, -0.79931092f, 0.48274064f, 2.07170033f, 3.02243996f, 2.63509989f, -0.76043403f, -1.64048159f, -6.17683458f, -3.09974527f, -2.12773156f, -0.89379883f, 2.82242465f, -1.99981332f, -0.08763933f, 0.01921120f, -1.94142103f, 2.48067307f, 0.41083777f, 8.24922180f, -1.84516132f, -1.39224625f, 5.03956223f, 0.49562740f, -5.28296328f, -0.20005548f, 3.13672113f, 0.51187158f, 7.11563921f, 6.43059587f, 
3.48430967f, -5.37095928f, 8.03863049f, -5.53923941f, -2.16421175f, -3.77641368f, 3.29633045f, 5.04030085f, 2.25945377f, -3.04169011f, -2.16198015f, -2.49559617f, -0.26252726f, -6.99201345f, 2.87374353f, -0.12568980f, 0.23314142f, -1.32087135f, 4.39030552f, -0.24638844f, -4.37242651f, 14.09276772f, 1.23987353f, -1.72249663f, 0.31124914f, -2.13725138f, -3.74915648f, -1.87147236f, 0.47318631f, 1.13337576f, 3.00416899f, 8.82548523f, 4.80538750f, -5.28486395f, 5.51870108f, -5.15801477f, 0.95712411f, -1.50416136f, 2.34657240f, 4.20726633f, 5.56757259f, -3.30645251f, -3.39945269f, -2.68488026f, -2.53525281f, -3.15145874f, 2.74529529f, -0.96283442f, 2.87778258f, 0.22186530f, 1.24905694f, -7.07941198f, -5.45916176f, 3.46988297f, 0.92430985f, -0.98330998f, -2.23672342f, -3.03262734f, 0.73941302f, 0.98004431f, 0.83219361f, 7.17411804f, 4.27849865f, 0.14765590f, 8.61269569f, 9.04497051f, 1.53991723f, -2.08305025f, -4.34939337f, 0.63786775f, 2.60098696f, 0.02432060f, -1.48516297f, -4.06825686f, 5.12420368f, -0.75312757f, 1.96927559f, 4.91575956f, 3.41533065f, 3.62557888f, -4.35002136f, -5.91343403f, 0.45026422f, 4.93286371f, 3.45830250f, -4.39032364f, -0.51697755f, -7.41543341f, -3.06703568f, 1.01196158f, 2.47106576f, 5.54014874f, -4.65312243f, 8.61000633f, 8.25905323f, -1.41497111f, 8.69221878f, 0.40090930f, 1.11325574f, -1.67089832f, -4.01080132f, 1.07925677f, 2.68086481f, -0.73093414f, -1.35081220f, -7.85765076f, -5.98989439f, -0.04651213f, 4.63693142f, 2.07757711f, -0.22652936f, 3.45525455f, -0.69198442f, -10.39761639f, -2.02106953f, 4.77755499f, -2.67665577f, -1.72481167f, 4.49634743f, -2.55717134f, -4.55044937f, 0.46377492f, -3.08933020f, 3.86891365f, -2.79104614f, 8.36974335f, 0.86471701f, -5.39342690f, 12.54906940f, -0.41536295f, -5.29502535f, -3.94430566f, -5.67391300f, -4.65079165f, 2.22505951f, -0.30000746f, 2.27855444f, -4.81604433f, -1.73440599f, 4.68784523f, 5.00208044f, 0.18863934f, -1.74989462f, 3.17923450f, -1.59773099f, -12.59962940f, -1.54495025f, -0.00576371f, 1.79913878f, -2.43449807f, 1.49516344f, -3.90507102f, 1.68647158f, 4.50177765f, -5.32286358f, 3.47539330f, -2.90529680f, 1.61576962f, 0.83679676f, -5.55615807f, 3.78939056f, -4.46644831f, -5.95550919f, 0.37808037f, 0.51334500f, 1.74658906f, -0.82085419f, -0.65387219f, 3.67790437f, 0.03758264f, -2.42622781f, 1.83335185f, 4.73835945f, -0.83536482f, -0.03993917f, 3.78230667f, -4.81265640f, -8.26869011f, -1.30363441f, -2.09106350f, -3.96769738f, -1.89037073f, 0.38682747f, 0.05434489f, 5.72213697f, 0.55685395f, -3.47729349f, -1.11535001f, 2.09416127f, 5.08877802f, 5.72183466f, 1.29632664f, 0.16822398f, -2.43180108f, 3.49967623f, 2.15753818f, -0.26548505f, 3.24446392f, -0.00599277f, 1.08215356f, -0.23225522f, -2.40723038f, 0.18496060f, -3.70608735f, -0.19918591f, -1.64028871f, 0.80792952f, -0.85334057f, -2.52314138f, -3.12099195f, 0.17949918f, -0.82650864f, 2.32224989f, 9.56476116f, -0.20134282f, -0.48428559f, 2.86784410f, 0.07289505f, -3.92880869f, -2.11887884f, 0.59164631f, 6.31267452f, 7.49149418f, 2.88749456f, 2.40504885f, -3.57608175f, -1.48019314f, -0.69410253f, 0.90275228f, -0.34111357f, 2.19190216f, 3.39090061f, 3.39631820f, -5.19105434f, 2.67546582f, -2.56549048f, -0.59797800f, -4.21802664f, 0.63918972f, -0.69969130f, 0.47496963f, -4.30976725f, 0.16531238f, -3.59595251f, -0.76877379f, 11.79971790f, -0.93276632f, -1.48630571f, 8.04754066f, 2.09168458f, -3.77018499f, -4.19337654f, 0.26171905f, 1.99359691f, 8.96759701f, 8.39609814f, 6.19231987f, -5.36037970f, 4.69818354f, -4.22453928f, -4.61665344f, -2.52073431f, 1.34026706f, 
2.80182385f, 2.56681514f, -4.04676390f, -3.01466990f, -4.10480118f, 0.38737059f, -0.37146521f, -2.26529670f, -1.72867084f, 0.93472683f, -2.47562981f, 0.89871657f, -1.67618203f, -0.28950238f, 5.30124855f, -0.14731219f, -0.81319761f, -1.11265934f, 0.11356127f, -2.52802444f, -1.93826056f, 1.06187987f, 1.48062325f, 4.28070498f, 5.69893932f, 9.26904392f, -4.23773003f, 5.78582096f, -6.18445301f, -2.85200453f, -5.30461454f, -4.16009140f, -0.07239690f, 4.11531162f, -1.12266588f, -1.50265646f, 0.47661865f, -1.90043914f, -6.48978710f, 1.71005368f, 0.18256521f, -0.88272136f, -0.51324779f, -0.78045660f, -5.21036625f, -4.11805344f, 3.99454761f, -1.04999924f, -6.99629354f, -5.02737141f, 0.94748145f, -2.35882139f, 4.13982439f, -1.41835535f, 7.56763077f, 3.97024012f, -4.08156776f, 6.90305424f, 0.53571963f, -2.22625160f, -2.09144926f, -4.98530245f, -0.15102190f, 0.59995949f, 3.28562784f, 0.77991986f, -3.08389306f, 3.34046674f, 0.41394949f, 5.10031366f, 2.99692893f, 0.17706826f, 2.85998058f, -6.68330860f, -6.72653008f, -0.04071128f, 3.71085787f, 3.17834806f, -4.88019037f, 6.74075413f, -7.41782188f, -5.22026348f, -1.94595623f, -3.61318684f, 1.85610664f, 1.08613706f, 6.41580677f, 1.46376514f, -4.11524010f, 9.59146214f, -2.92772651f, -1.70753336f, -1.51594138f, -4.88185692f, 1.47331417f, -2.23893595f, 4.98459148f, 1.29359996f, -2.29221845f, -0.99594390f, 3.05759239f, 6.86030054f, 2.40487719f, 3.28339863f, 7.72739315f, -3.60563445f, -9.73502827f, -1.51672328f, -0.08473521f, -2.43673515f, -3.26616001f, 3.63767886f, -11.25394535f, -5.17597103f, -1.27523947f, -7.82669783f, 0.67929745f, -4.50530529f, 5.49323797f, 6.78993320f, -2.28033876f, 4.61412525f, 2.55109429f, -12.38607693f, -0.63024014f, -3.45992327f, -0.84092742f, -0.03252453f, 4.58635283f, 5.28213978f, -1.28417206f, -1.71185923f, -0.26850975f, 8.28257561f, 4.47432184f, 2.72818279f, 8.42217731f, -4.22216320f, -8.95128918f, -1.57179546f, 1.34253705f, -5.47035217f, -5.50866985f, 4.64156532f, -6.11207914f, -5.46734476f, 3.54298997f, -2.79237103f, -0.70766860f, -3.62739944f, 3.22660995f, -2.02262759f, 0.11224222f, 2.63832402f, -0.91955596f, -4.65958309f, -0.29729855f, -1.78957534f, -0.40749407f, 0.51688713f, 0.83725226f, 0.30945438f, 1.20769620f, -1.75219965f, 2.59689760f, 5.01501608f, -1.59034789f, 0.58155286f, 3.75831509f, -5.26110506f, -8.65382767f, -6.19066620f, -0.61932850f, -2.71863723f, -0.87443137f, 3.40582991f, -1.27868056f, 3.51236677f, -2.07806540f, -0.85076392f, -1.14599180f, 1.16361260f, 1.86411846f, 5.86179352f, 0.69029891f, -0.06060839f, 1.54649436f, -0.60351688f, 1.51970077f, 0.04187265f, 1.64540339f, 2.75502157f, 2.46308279f, 1.69071770f, -3.23827076f, 0.92096543f, -3.09458661f, -1.23823690f, 0.24035048f, -0.74456501f, -1.85476089f, -0.32914662f, -2.10325241f, 1.19795251f, -2.05372071f, 1.02114081f, 2.56286955f, 0.42165697f, -1.65826249f, 4.00724554f, -2.18727994f, -1.05848944f, -0.52338278f, -0.28714985f, 8.08780861f, 5.04444599f, 3.51866961f, 3.37445784f, -1.96067202f, -1.21509445f, -3.96595931f, -0.80801201f, 0.76944816f, 1.80147493f, 4.14419460f, -0.12201095f, -2.77788162f, 1.13284469f, -2.05441403f, -0.61129224f, -2.69690657f, 1.91634214f, -2.17146754f, -0.22308528f, -6.02561045f, 0.49161875f, -6.74280357f, -4.62689781f, 2.47910833f, 1.86534905f, -3.24152899f, -1.39898300f, 0.29427958f, -2.16338181f, 0.90073711f, 1.75551236f, 4.42651892f, 8.34437466f, 5.50070190f, 5.68162251f, 1.65345454f, -2.72315669f, -5.43411493f, -0.29380533f, 1.07508349f, -1.73533511f, 2.56912184f, 3.62010550f, -6.30422783f, 1.74158525f, -1.22070909f, -0.80982518f, 
     -4.14757967f, 4.29217434f, 0.70600843f, -2.09282112f, -5.09018898f, -0.11623126f, -5.99775553f, -4.66743088f, 1.61512172f, -1.30276895f, -3.17103505f, -0.26310229f, -1.00843918f, -0.77664804f, -2.05240250f, 0.04728425f, 1.15720487f, 4.01001406f, 7.24615860f, 2.55452180f, -5.76347876f, 0.34683830f, -6.05540276f, -4.70677900f, -0.93182588f, -4.37759733f, 2.93209839f, 1.63947964f, -2.43563962f, 1.35213876f, 0.00670356f, -0.02742785f, -2.16460943f, 1.39449501f, 0.23929763f, 2.37476778f, -4.17733765f, -0.81475425f, -6.15027046f, -5.74441719f, 3.53978682f, 0.66798484f});
     nd4j::ops::deconv2d_tf op;
-    auto result = op.execute({&input0, &input1, &input2}, {}, {7,7, 2,2, 0,0, 1,1, 1,1});
+    auto result = op.evaluate({&input0, &input1, &input2}, {}, {7,7, 2,2, 0,0, 1,1, 1,1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -189,7 +189,7 @@ TEST_F(ConvolutionTests2, Test_Dilation2D_Again_1) {
     nd4j::ops::dilation2d op;
-    auto result = op.execute({&x, &w}, {}, {1, 1,5,7,1, 1,2,3,1});
+    auto result = op.evaluate({&x, &w}, {}, {1, 1,5,7,1, 1,2,3,1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -205,7 +205,7 @@ TEST_F(ConvolutionTests2, Test_Dilation2D_Again_2) {
     auto w = NDArrayFactory::create('c', {11, 7, 4});
     nd4j::ops::dilation2d op;
-    auto result = op.execute({&x, &w}, {}, {0, 1,2,3,1, 1,3,2,1});
+    auto result = op.evaluate({&x, &w}, {}, {0, 1,2,3,1, 1,3,2,1});
     ASSERT_EQ(Status::OK(), result->status());
     delete result;
@@ -246,7 +246,7 @@ TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_1) {
     epsilonNext.applyScalar(scalar::Divide, 100.0, epsilonNext);
     nd4j::ops::sconv2d_bp op;
-    auto resultBP = op.execute({&input, &epsilonNext, &weightsD, &weightsP },{}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {});
+    auto resultBP = op.evaluate({&input, &epsilonNext, &weightsD, &weightsP },{}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {});
     ASSERT_EQ(3, resultBP->size());
@@ -342,7 +342,7 @@ TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_3) {
     auto epsilon = NDArrayFactory::create('c', {3, 3, 16, 16});
     nd4j::ops::sconv2d_bp op;
-    auto result = op.execute({&input, &epsilonNext, &weightsD, &weightsP}, {}, {2, 2, 1, 1, 0, 0, 2, 2, 0});
+    auto result = op.evaluate({&input, &epsilonNext, &weightsD, &weightsP}, {}, {2, 2, 1, 1, 0, 0, 2, 2, 0});
     auto eps = result->at(0);
     auto gWD = result->at(1);
@@ -378,7 +378,7 @@ TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_4) {
     gradO.linspace(0.01, 0.01);
     nd4j::ops::sconv2d_bp op;
-    auto results = op.execute({&input, &gradO, &weightsDepth, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+    auto results = op.evaluate({&input, &gradO, &weightsDepth, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
     auto* gradI = results->at(0);
     auto* gradWD = results->at(1);
@@ -463,7 +463,7 @@ TEST_F(ConvolutionTests2, deconv3d_test1) {
     weights.linspace(0.1, 0.1);
     nd4j::ops::deconv3d op;
-    auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
+    auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
     auto output = results->at(0);
     // output->printBuffer();
@@ -497,7 +497,7 @@ TEST_F(ConvolutionTests2, deconv3d_test2) {
     weights.linspace(0.1, 0.1);
     nd4j::ops::deconv3d op;
-    auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
+    auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -530,7 +530,7 @@ TEST_F(ConvolutionTests2, deconv3d_test3) {
     weights.permutei({2, 3, 4, 1, 0});
     nd4j::ops::deconv3d op;
-    auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
+    auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -557,7 +557,7 @@ TEST_F(ConvolutionTests2, deconv3d_test4) {
     weights.permutei({2, 3, 4, 1, 0});
     nd4j::ops::deconv3d op;
-    auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
+    auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -600,7 +600,7 @@ TEST_F(ConvolutionTests2, deconv3d_test5) {
     bias = 0.2;
     nd4j::ops::deconv3d op;
-    auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
+    auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
     ASSERT_EQ(Status::OK(), results->status());
     auto output = results->at(0);
@@ -633,7 +633,7 @@ TEST_F(ConvolutionTests2, deconv3d_bp_test1) {
     gradO.linspace(0.5);
     nd4j::ops::deconv3d_bp op;
-    auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
+    auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
     auto gradI = results->at(0);
     auto gradW = results->at(1);
@@ -673,7 +673,7 @@ TEST_F(ConvolutionTests2, deconv3d_bp_test2) {
     gradO.linspace(0.5);
     nd4j::ops::deconv3d_bp op;
-    auto results = op.execute({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
+    auto results = op.evaluate({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
     auto gradI = results->at(0);
     auto gradW = results->at(1);
@@ -708,7 +708,7 @@ TEST_F(ConvolutionTests2, deconv3d_bp_test3) {
     gradO.linspace(0.5);
     nd4j::ops::deconv3d_bp op;
-    auto results = op.execute({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
+    auto results = op.evaluate({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
     auto gradI = results->at(0);
     auto gradW = results->at(1);
@@ -743,7 +743,7 @@ TEST_F(ConvolutionTests2, deconv3d_bp_test4) {
     gradO.linspace(0.5);
     nd4j::ops::deconv3d_bp op;
-    auto results = op.execute({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
+    auto results = op.evaluate({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
     auto gradI = results->at(0);
     auto gradW = results->at(1);
@@ -971,7 +971,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_6) {
     x.linspace(1);
     nd4j::ops::maxpool2d op;
-    auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 1, 1});
+    auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 1, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
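Every hunk above makes the same mechanical change: the tests stop calling op.execute() and call op.evaluate() with an identical (inputs, t-args, i-args) argument layout. A minimal sketch of the post-patch test pattern, assuming the libnd4j test headers these files already include (shapes and arguments below are illustrative, not taken from any single hunk):

    // Sketch: evaluate() allocates its outputs and returns a heap-allocated
    // ResultSet*; the caller frees it, as the updated tests do with `delete`.
    #include "testlayers.h"                         // assumed test scaffolding
    #include <NDArrayFactory.h>
    #include <ops/declarable/CustomOperations.h>

    auto x = NDArrayFactory::create<float>('c', {2, 2, 5, 5});
    x.linspace(1);

    nd4j::ops::maxpool2d op;
    auto results = op.evaluate({&x}, {}, {2,2, 2,2, 0,0, 1,1, 1, 1, 0});

    ASSERT_EQ(Status::OK(), results->status());
    auto output = results->at(0);                   // owned by the ResultSet
    delete results;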
@@ -991,7 +991,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_7) {
     x.linspace(1);
     nd4j::ops::maxpool2d op;
-    auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1});
+    auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1011,7 +1011,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_8) {
     x.linspace(1);
     nd4j::ops::maxpool2d op;
-    auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 0});
+    auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1041,7 +1041,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_9) {
     auto input = NDArrayFactory::create('c', {bS, iC, iH, iW});
     nd4j::ops::maxpool2d op;
-    auto results = op.execute({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, 1, 0});
+    auto results = op.evaluate({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, 1, 0});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1068,7 +1068,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_10) {
     0.85722464f, 0.85722464f, 0.85019743f});
     nd4j::ops::maxpool2d op;
-    auto results = op.execute({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode});
+    auto results = op.evaluate({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode});
     auto* output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1088,7 +1088,7 @@ TEST_F(ConvolutionTests2, maxpool2d_11) {
     input.linspace(1.);
     nd4j::ops::maxpool2d op;
-    auto results = op.execute({&input}, {}, {2,2, 1,1, 1,1, 2,2, 1,0,0});
+    auto results = op.evaluate({&input}, {}, {2,2, 1,1, 1,1, 2,2, 1,0,0});
     ASSERT_EQ(Status::OK(), results->status());
@@ -1110,7 +1110,7 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_test1) {
     input.linspace(1.);
     nd4j::ops::avgpool3dnew op;
-    auto results = op.execute({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
+    auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1138,7 +1138,7 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_test2) {
     input.linspace(1.);
     nd4j::ops::avgpool3dnew op;
-    auto results = op.execute({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 0, dataFormat});
+    auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 0, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1163,7 +1163,7 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_test3) {
     input.linspace(1.);
     nd4j::ops::avgpool3dnew op;
-    auto results = op.execute({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
+    auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1203,7 +1203,7 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_test4) {
     input.linspace(1.);
     nd4j::ops::avgpool3dnew op;
-    auto results = op.execute({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
+    auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1227,7 +1227,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_test1) {
     input.linspace(1.);
     nd4j::ops::maxpool3dnew op;
-    auto results = op.execute({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
+    auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1255,7 +1255,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_test2) {
     input.linspace(1.);
     nd4j::ops::maxpool3dnew op;
-    auto results = op.execute({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
+    auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1279,7 +1279,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_test3) {
     input.linspace(1.);
     nd4j::ops::maxpool3dnew op;
-    auto results = op.execute({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
+    auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1309,7 +1309,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_test4) {
     input.linspace(1.);
     nd4j::ops::maxpool3dnew op;
-    auto results = op.execute({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
+    auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1342,7 +1342,7 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test1) {
     gradO = 2.;
     nd4j::ops::avgpool3dnew_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
+    auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1376,7 +1376,7 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test2) {
     gradO = 2.;
     nd4j::ops::avgpool3dnew_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
+    auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
     auto output = results->at(0);
     // output->printBuffer();
@@ -1411,7 +1411,7 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test3) {
     gradO = 2.;
     nd4j::ops::avgpool3dnew_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 0, dataFormat});
+    auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 0, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1443,7 +1443,7 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test4) {
     gradO = 2.;
     nd4j::ops::avgpool3dnew_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 0, dataFormat});
+    auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 0, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1474,7 +1474,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test1) {
     gradO.linspace(0.1, 0.1);
     nd4j::ops::maxpool3dnew_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
+    auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1507,7 +1507,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test2) {
     gradO.linspace(0.1, 0.1);
     nd4j::ops::maxpool3dnew_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
+    auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1539,7 +1539,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test3) {
     gradO.linspace(0.1, 0.1);
     nd4j::ops::maxpool3dnew_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
+    auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1571,7 +1571,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test4) {
     gradO.linspace(0.1, 0.1);
     nd4j::ops::maxpool3dnew_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
+    auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1630,7 +1630,7 @@ TEST_F(ConvolutionTests2, maxpool2d_bp_2) {
     std::initializer_list argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
     nd4j::ops::maxpool2d_bp op;
-    auto results = op.execute({&input, &epsilon}, {}, argI);
+    auto results = op.evaluate({&input, &epsilon}, {}, argI);
     auto output = results->at(0);
     ASSERT_TRUE(expected.isSameShape(output));
@@ -1656,7 +1656,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_3) {
     gradO.linspace(0.1, 0.1);
     nd4j::ops::maxpool2d_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});
+    auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1683,7 +1683,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_4) {
     gradO.linspace(0.1, 0.1);
     nd4j::ops::maxpool2d_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});
+    auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1710,7 +1710,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_5) {
     gradO.linspace(0.1, 0.1);
     nd4j::ops::maxpool2d_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});
+    auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1737,7 +1737,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_6) {
     gradO.linspace(0.1, 0.1);
     nd4j::ops::maxpool2d_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});
+    auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1762,7 +1762,7 @@ TEST_F(ConvolutionTests2, maxpool2d_bp_7) {
     gradO.linspace(0.1, 0.1);
     nd4j::ops::maxpool2d_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});
+    auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});
     // auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1820,7 +1820,7 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_2) {
     std::initializer_list argI = {kH,kW, sH,sW, pH,pW, dW,dH, 1, 1, 0};
     nd4j::ops::avgpool2d_bp op;
-    auto results = op.execute({&input, &epsilon}, {}, argI);
+    auto results = op.evaluate({&input, &epsilon}, {}, argI);
     auto output = results->at(0);
     ASSERT_TRUE(expected.isSameShape(output));
@@ -1849,7 +1849,7 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_3) {
     gradO.linspace(0.1, 0.1);
     nd4j::ops::avgpool2d_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});
+    auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1880,7 +1880,7 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_4) {
     gradO.linspace(0.1, 0.1);
     nd4j::ops::avgpool2d_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});
+    auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1909,7 +1909,7 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_5) {
     gradO.linspace(0.1, 0.1);
     nd4j::ops::avgpool2d_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat});
+    auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1938,7 +1938,7 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_6) {
     gradO.linspace(0.1, 0.1);
     nd4j::ops::avgpool2d_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});
+    auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2002,7 +2002,7 @@ TYPED_TEST(TypedConvolutionTests2, pnormpool2d_bp_2) {
     gradO.linspace(0.1, 0.1);
     nd4j::ops::pnormpool2d_bp op;
-    auto results = op.execute({&input, &gradO}, {eps}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, pnorm, dataFormat});
+    auto results = op.evaluate({&input, &gradO}, {eps}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, pnorm, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2035,7 +2035,7 @@ TYPED_TEST(TypedConvolutionTests2, pnormpool2d_bp_3) {
     gradO.linspace(0.1, 0.1);
     nd4j::ops::pnormpool2d_bp op;
-    auto results = op.execute({&input, &gradO}, {eps}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, pnorm, dataFormat});
+    auto results = op.evaluate({&input, &gradO}, {eps}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, pnorm, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2061,7 +2061,7 @@ TEST_F(ConvolutionTests2, upsampling2d_bp_1) {
     expGradI = 4.;
     nd4j::ops::upsampling2d_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {isNCHW});
+    auto results = op.evaluate({&input, &gradO}, {}, {isNCHW});
     auto* gradI = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2086,7 +2086,7 @@ TEST_F(ConvolutionTests2, upsampling2d_bp_2) {
     expGradI = 4.;
     nd4j::ops::upsampling2d_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {isNCHW});
+    auto results = op.evaluate({&input, &gradO}, {}, {isNCHW});
     auto* gradI = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2117,7 +2117,7 @@ TEST_F(ConvolutionTests2, upsampling2d_bp_3) {
     1.090545, 1.9094483, 1.3611296, 2.1195147, 2.0659215, 1.0423062, 2.3405795, 1.9105877, 1.2203633}, nd4j::DataType::FLOAT32);
     nd4j::ops::upsampling2d_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {isNCHW});
+    auto results = op.evaluate({&input, &gradO}, {}, {isNCHW});
     auto* gradI = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2148,7 +2148,7 @@ TYPED_TEST(TypedConvolutionTests2, depthwise_conv2d_1) {
     weights.linspace(0.1, 0.1);
     nd4j::ops::depthwise_conv2d op;
-    auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+    auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
     auto* output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2178,7 +2178,7 @@ TEST_F(ConvolutionTests2, depthwise_conv2d_2) {
     weights.linspace(0.1, 0.1);
     nd4j::ops::depthwise_conv2d op;
-    auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+    auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
     auto* output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2210,7 +2210,7 @@ TEST_F(ConvolutionTests2, depthwise_conv2d_3) {
     weights.permutei({2,3,1,0});
     nd4j::ops::depthwise_conv2d op;
-    auto results = op.execute({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+    auto results = op.evaluate({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
     auto* output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2267,7 +2267,7 @@ TEST_F(ConvolutionTests2, depthwise_conv2d_5) {
     weights = 0.5;
     nd4j::ops::depthwise_conv2d op;
-    auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+    auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2295,7 +2295,7 @@ TEST_F(ConvolutionTests2, depthwise_conv2d_6) {
     weights = 1.;
     nd4j::ops::depthwise_conv2d op;
-    ResultSet* results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+    auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
     NDArray* output = results->at(0);
     // output.printIndexedBuffer();
@@ -2330,7 +2330,7 @@ TEST_F(ConvolutionTests2, depthwise_conv2d_7) {
     nd4j::ops::depthwise_conv2d op;
-    auto results = op.execute({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+    auto results = op.evaluate({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
     auto* output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
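All of the pooling hunks above pack their hyperparameters into the integer-argument vector, in the order spelled out by the comment preserved in the maxpool2d_bp_2 hunk (kernel, stride, pad, dilation, then mode flags). A hedged reading of the 2d layout follows, with the last two slots labelled from our reading of the calls rather than from anything documented in this patch:

    auto results = op.evaluate({&input}, {}, {
            kH, kW,        // 0,1 - kernel height/width
            sH, sW,        // 2,3 - stride height/width
            pH, pW,        // 4,5 - pad height/width
            dH, dW,        // 6,7 - dilation height/width
            paddingMode,   // 8   - same (1) vs. valid (0) padding
            1,             // 9   - op-specific extra flag (assumption)
            dataFormat});  // 10  - layout flag, NCHW vs. NHWC (assumption)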
@@ -2368,8 +2368,8 @@ TEST_F(ConvolutionTests2, depthwise_conv2d_8) {
     weights.linspace(-2, 0.1);
     nd4j::ops::depthwise_conv2d op;
-    ResultSet* results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
-    NDArray* output = results->at(0);
+    auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+    auto output = results->at(0);
     // output->printBuffer();
     ASSERT_EQ(Status::OK(), results->status());
@@ -2408,8 +2408,8 @@ TEST_F(ConvolutionTests2, depthwise_conv2d_9) {
     weights.linspace(-2, 0.1);
     nd4j::ops::depthwise_conv2d op;
-    ResultSet* results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
-    NDArray* output = results->at(0);
+    auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
+    auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
diff --git a/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp b/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp
index e87dfa125..9ee23c36d 100644
--- a/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp
@@ -46,7 +46,7 @@ TEST_F(DataTypesValidationTests, Basic_Test_1) {
     input.linspace(1);
     nd4j::ops::conv2d op;
-    auto result = op.execute({&input, &weights}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {});
+    auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0});
     ASSERT_EQ(ND4J_STATUS_VALIDATION, result->status());
@@ -62,7 +62,7 @@ TEST_F(DataTypesValidationTests, Basic_Test_2) {
     input.linspace(1);
     nd4j::ops::conv2d op;
-    auto result = op.execute({&input, &weights}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {});
+    auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp
index 1e43081c1..15524a901 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp
@@ -161,7 +161,7 @@ TEST_F(DeclarableOpsTests1, ApplyGradientDescent_1) {
     auto exp = NDArrayFactory::create('c', {3,4});
     exp.linspace(0.9, 0.9);
     nd4j::ops::apply_sgd op;
-    auto result = op.execute({&x, &y}, {1.}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &y}, {1.}, {});
     ASSERT_EQ(result->status(), ND4J_STATUS_OK);
     auto z = result->at(0);
@@ -175,7 +175,7 @@ TEST_F(DeclarableOpsTests1, AssignBroadcastTest_1) {
     auto y = NDArrayFactory::create('c', {1,4}, {0.1,0.2,0.3,0.4});
     auto exp = NDArrayFactory::create('c', {3,4}, {0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4});
     nd4j::ops::assign op;
-    auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &y});
     ASSERT_EQ(result->status(), ND4J_STATUS_OK);
     auto z = result->at(0);
@@ -191,7 +191,7 @@ TEST_F(DeclarableOpsTests1, AssignBroadcastTest_2) {
     auto exp1 = NDArrayFactory::create('c', {3,4}); // zero
     auto exp2 = NDArrayFactory::create('c', {1,4}, {3, 6, 9, 12});
     nd4j::ops::assign_bp op;
-    auto result = op.execute({&x, &y, &eps}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &y, &eps});
     ASSERT_EQ(result->status(), ND4J_STATUS_OK);
     auto z1 = result->at(0);
     auto z2 = result->at(1);
@@ -208,7 +208,7 @@ TEST_F(DeclarableOpsTests1, AXpY_Test_1) {
     auto exp = NDArrayFactory::create('c', {3,4});
     exp.linspace(3, 3);
     nd4j::ops::axpy op;
-    auto result = op.execute({&x, &y}, {2.}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &y}, {2.});
     ASSERT_EQ(result->status(), ND4J_STATUS_OK);
     auto z = result->at(0);
@@ -249,7 +249,7 @@ TEST_F(DeclarableOpsTests1, TestTensorMmul1) {
     NDArray exp('c', {2, 2}, {650.0, 1586.0, 1586.0, 4250.0}, nd4j::DataType::FLOAT32);
     nd4j::ops::tensormmul op;
-    auto results = op.execute({&x, &y}, {}, {2,1,2,2,1,2});
+    auto results = op.evaluate({&x, &y}, {}, {2,1,2,2,1,2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -269,7 +269,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot2) {
     NDArray exp('c', {2, 2}, {2300.0, 2444.0, 2444.0, 2600.0}, nd4j::DataType::FLOAT32);
     nd4j::ops::tensormmul op;
-    auto results = op.execute({&x, &y}, {}, {2,1,2,2,1,2});
+    auto results = op.evaluate({&x, &y}, {}, {2,1,2,2,1,2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -289,7 +289,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot3) {
     NDArray exp('f', {2, 2}, {1090.0, 2818.0, 1168.0, 3040.0}, nd4j::DataType::FLOAT32);
     nd4j::ops::tensormmul op;
-    auto results = op.execute({&x, &y}, {}, {2,1,2,2,1,2});
+    auto results = op.evaluate({&x, &y}, {}, {2,1,2,2,1,2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -309,7 +309,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot4) {
     NDArray exp('f', {2, 2}, {1090.0, 1168.0, 2818.0, 3040.0}, nd4j::DataType::FLOAT32);
     nd4j::ops::tensormmul op;
-    auto results = op.execute({&x, &y}, {}, {2,1,2,2,1,2});
+    auto results = op.evaluate({&x, &y}, {}, {2,1,2,2,1,2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -506,7 +506,7 @@ TEST_F(DeclarableOpsTests1, SubtractTest_2) {
     nd4j::ops::subtract subOp;
-    auto res = subOp.execute({&x, &y}, {}, {});
+    auto res = subOp.evaluate({&x, &y});
     ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
@@ -767,7 +767,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_1) {
     nd4j::ops::reversesubtract subOp;
-    auto res = subOp.execute({&x, &y}, {}, {});
+    auto res = subOp.evaluate({&x, &y});
     ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
     ASSERT_TRUE(res->at(0)->equalsTo(&exp));
@@ -792,7 +792,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_2) {
     nd4j::ops::reversesubtract subOp;
-    auto res = subOp.execute({&x, &y}, {}, {});
+    auto res = subOp.evaluate({&x, &y});
     ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
     ASSERT_TRUE(res->at(0)->equalsTo(&exp));
@@ -815,7 +815,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_3) {
     ASSERT_TRUE(z.equalsTo(&exp));
     nd4j::ops::reversesubtract subOp;
-    auto res = subOp.execute({&x, &y}, {}, {});
+    auto res = subOp.evaluate({&x, &y});
     ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
     ASSERT_TRUE(res->at(0)->equalsTo(&exp));
@@ -841,7 +841,7 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_1) {
     nd4j::ops::reversemod subOp;
-    auto res = subOp.execute({&x, &y}, {}, {});
+    auto res = subOp.evaluate({&x, &y});
     ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
     ASSERT_TRUE(res->at(0)->equalsTo(&exp));
@@ -868,7 +868,7 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_2) {
     nd4j::ops::reversemod subOp;
-    auto res = subOp.execute({&x, &y}, {}, {});
+    auto res = subOp.evaluate({&x, &y});
     ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
     ASSERT_TRUE(res->at(0)->equalsTo(&exp));
@@ -1157,7 +1157,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_1) {
     nd4j::ops::divide div;
-    auto res = div.execute({&x, &y}, {}, {});
+    auto res = div.evaluate({&x, &y});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     ASSERT_TRUE(res->at(0)->equalsTo(exp));
@@ -1176,7 +1176,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_2) {
     exp.assign(3);
     nd4j::ops::divide_no_nan div;
-    auto res = div.execute({&x, &y}, {}, {});
+    auto res = div.evaluate({&x, &y});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     ASSERT_TRUE(res->at(0)->equalsTo(exp));
@@ -1192,7 +1192,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_3) {
     auto exp = NDArrayFactory::create({2, 2, 0, 2, 2});
     nd4j::ops::divide_no_nan div;
-    auto res = div.execute({&x, &y}, {}, {});
+    auto res = div.evaluate({&x, &y});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     ASSERT_TRUE(res->at(0)->equalsTo(exp));
@@ -1212,7 +1212,7 @@ TEST_F(DeclarableOpsTests1, BroadcastReverseDivideTest_1) {
     nd4j::ops::reversedivide div;
-    auto res = div.execute({&x, &y}, {}, {});
+    auto res = div.evaluate({&x, &y});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
@@ -1469,7 +1469,7 @@ TEST_F(DeclarableOpsTests1, Test_Cast_1) {
     yExp.linspace(1);
     nd4j::ops::cast op;
-    auto result = op.execute({&x}, {}, {3});
+    auto result = op.evaluate({&x}, {}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1673,7 +1673,7 @@ TEST_F(DeclarableOpsTests1, Reshape3) {
     auto x = NDArrayFactory::create('c', {3, 4, 5});
     nd4j::ops::reshape op;
-    auto result = op.execute({&x}, {}, {-99, 3, 4, 5});
+    auto result = op.evaluate({&x}, {}, {-99, 3, 4, 5});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1688,7 +1688,7 @@ TEST_F(DeclarableOpsTests1, Reshape4) {
     auto x = NDArrayFactory::create('c', {3, 4, 5});
     nd4j::ops::reshape op;
-    auto result = op.execute({&x}, {}, {3, 4, 5});
+    auto result = op.evaluate({&x}, {}, {3, 4, 5});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1703,7 +1703,7 @@ TEST_F(DeclarableOpsTests1, Reshape5) {
     auto x = NDArrayFactory::create('c', {3, 4, 5});
     nd4j::ops::reshape op;
-    auto result = op.execute({&x}, {}, {5, 4, 3});
+    auto result = op.evaluate({&x}, {}, {5, 4, 3});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1715,7 +1715,7 @@ TEST_F(DeclarableOpsTests1, Reshape6){
     auto exp = NDArrayFactory::create('c', {4, 15});
     nd4j::ops::reshape op;
-    auto result = op.execute({&x}, {}, {4, -1});
+    auto result = op.evaluate({&x}, {}, {4, -1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1732,7 +1732,7 @@ TEST_F(DeclarableOpsTests1, Reshape7){
     auto exp = NDArrayFactory::create('c', {60});
     nd4j::ops::reshape op;
-    auto result = op.execute({&x}, {}, {-1});
+    auto result = op.evaluate({&x}, {}, {-1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -2217,7 +2217,7 @@ TEST_F(DeclarableOpsTests1, IsMax1) {
     exp.p(2, 2, true);
     nd4j::ops::ismax ismaxOp;
-    auto result = ismaxOp.execute({&x}, {}, {1});
+    auto result = ismaxOp.evaluate({&x}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -2239,7 +2239,7 @@ TEST_F(DeclarableOpsTests1, IsMax2) {
     exp.p(2, 2, true);
     nd4j::ops::ismax ismaxOp;
-    auto result = ismaxOp.execute({&x}, {}, {0, 1});
+    auto result = ismaxOp.evaluate({&x}, {}, {0, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -2261,7 +2261,7 @@ TEST_F(DeclarableOpsTests1, IsMax3) {
     //exp.p(2, 2, true);
     nd4j::ops::ismax ismaxOp;
-    auto result = ismaxOp.execute({&x}, {}, {0});
+    auto result = ismaxOp.evaluate({&x}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -2279,7 +2279,7 @@ TEST_F(DeclarableOpsTests1, IsMax4) {
     auto e = NDArrayFactory::create('c', {6}, {false, false, false, true, false, false});
     nd4j::ops::ismax op;
-    auto result = op.execute({&x}, {&z}, {}, {}, {});
+    auto result = op.execute({&x}, {&z});
     ASSERT_EQ(Status::OK(), result);
     ASSERT_EQ(e, z);
@@ -2343,7 +2343,7 @@ TEST_F(DeclarableOpsTests1, sru_test1) {
     mask.assign(1.);
     nd4j::ops::sru op;
-    auto results = op.execute({&input, &weights, &bias, &init, &mask}, {}, {});
+    auto results = op.evaluate({&input, &weights, &bias, &init, &mask});
     ASSERT_TRUE(results->size() == 2);
     auto output = results->at(0);
@@ -2390,7 +2390,7 @@ TEST_F(DeclarableOpsTests1, sru_bp) {
     inGradH.assign(0.5);
     nd4j::ops::sru_bp bp;
-    auto resultsBP = bp.execute({&input, &weights, &bias, &init, &state, &inGradCt, &inGradH, &mask}, {}, {});
+    auto resultsBP = bp.evaluate({&input, &weights, &bias, &init, &state, &inGradCt, &inGradH, &mask}, {}, {});
     ASSERT_TRUE(resultsBP->size() == 4);
     auto gradX = resultsBP->at(0);
@@ -2429,7 +2429,7 @@ TEST_F(DeclarableOpsTests1, sru_bi_1) {
     mask.assign(1.);
     nd4j::ops::sru_bi op;
-    auto results = op.execute({&input, &weights, &bias, &init, &mask}, {}, {});
+    auto results = op.evaluate({&input, &weights, &bias, &init, &mask}, {}, {});
     ASSERT_TRUE(results->size() == 2);
     auto output = results->at(0);
@@ -2480,7 +2480,7 @@ TEST_F(DeclarableOpsTests1, sru_bi_bp_1) {
     inGradH.assign(0.5);
     nd4j::ops::sru_bi_bp bp;
-    auto resultsBP = bp.execute({&input, &weights, &bias, &init, &state, &inGradCt, &inGradH, &mask}, {}, {});
+    auto resultsBP = bp.evaluate({&input, &weights, &bias, &init, &state, &inGradCt, &inGradH, &mask}, {}, {});
     ASSERT_TRUE(resultsBP->size() == 4);
     auto gradX = resultsBP->at(0);
@@ -2504,7 +2504,7 @@ TEST_F(DeclarableOpsTests1, ArgMax1) {
     nd4j::ops::argmax op;
-    auto result = op.execute({&x}, {}, {1});
+    auto result = op.evaluate({&x}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -2525,7 +2525,7 @@ TEST_F(DeclarableOpsTests1, ArgMax2) {
     nd4j::ops::argmax op;
-    auto result = op.execute({&x}, {}, {0});
+    auto result = op.evaluate({&x}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -2547,7 +2547,7 @@ TEST_F(DeclarableOpsTests1, ArgMax3) {
     nd4j::ops::argmax op;
-    auto result = op.execute({&x, &dim}, {}, {});
+    auto result = op.evaluate({&x, &dim}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -2568,7 +2568,7 @@ TEST_F(DeclarableOpsTests1, ArgMax4) {
     nd4j::ops::argmax op;
-    auto result = op.execute({&x, &dim}, {}, {});
+    auto result = op.evaluate({&x, &dim}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -2590,7 +2590,7 @@ TEST_F(DeclarableOpsTests1, ArgMax5) {
     nd4j::ops::argmax op;
-    auto result = op.execute({&x, &dim}, {}, {});
+    auto result = op.evaluate({&x, &dim}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -2610,12 +2610,12 @@ TEST_F(DeclarableOpsTests1, ArgMax6) {
     nd4j::ops::argmax op;
-    auto expected = op.execute({&x}, {}, {2});
+    auto expected = op.evaluate({&x}, {}, {2});
     ASSERT_EQ(Status::OK(), expected->status());
     auto exp = expected->at(0);
-    auto result = op.execute({&x, &dim}, {}, {});
+    auto result = op.evaluate({&x, &dim}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -2636,7 +2636,7 @@ TEST_F(DeclarableOpsTests1, ArgMin1) {
     nd4j::ops::argmin op;
-    auto result = op.execute({&x}, {}, {1});
+    auto result = op.evaluate({&x}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -2659,7 +2659,7 @@ TEST_F(DeclarableOpsTests1, SquareTests1) {
     nd4j::ops::square op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -2677,7 +2677,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_1) {
     nd4j::ops::onehot op;
-    auto result = op.execute({&indices}, {1.0f, 0.0f}, {-1, 3});
+    auto result = op.evaluate({&indices}, {1.0f, 0.0f}, {-1, 3});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
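Two call shapes are visible in the hunks above. Tests that let the op allocate its outputs move from execute() to evaluate(), and the convenience overloads let them drop trailing empty brace lists (subOp.evaluate({&x, &y}) instead of subOp.execute({&x, &y}, {}, {})). The IsMax4 hunk shows the other side of the split: with a caller-supplied output array the call stays on execute(), which returns a plain status. A side-by-side sketch under those assumptions (array values are illustrative):

    nd4j::ops::ismax op;

    // Caller-allocated output: execute() fills z and returns a status code.
    auto x = NDArrayFactory::create<float>('c', {6}, {0.f, 1.f, 2.f, 5.f, 3.f, 4.f});
    auto z = NDArrayFactory::create<bool>('c', {6});
    auto status = op.execute({&x}, {&z});
    ASSERT_EQ(Status::OK(), status);

    // Op-allocated output: evaluate() returns the ResultSet holding the result.
    auto result = op.evaluate({&x});
    ASSERT_EQ(Status::OK(), result->status());
    delete result;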
-2695,7 +2695,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_2) {
     auto exp = NDArrayFactory::create('c', {2, 2, 3}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f});
     nd4j::ops::onehot op;
-    auto result = op.execute({&indices}, {1.0f, 0.0f}, {-1, 3});
+    auto result = op.evaluate({&indices}, {1.0f, 0.0f}, {-1, 3});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -2715,7 +2715,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_3) {
     nd4j::ops::onehot op;
-    auto result = op.execute({&indices}, {1.0f, 0.0f}, {-1, 3});
+    auto result = op.evaluate({&indices}, {1.0f, 0.0f}, {-1, 3});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -2736,7 +2736,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_4) {
     nd4j::ops::onehot op;
-    auto result = op.execute({&indices, &depth}, {1.0f, 0.0f}, {});
+    auto result = op.evaluate({&indices, &depth}, {1.0f, 0.0f}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -2757,7 +2757,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_5) {
     nd4j::ops::onehot op;
-    auto result = op.execute({&indices, &depth, &on, &off}, {}, {});
+    auto result = op.evaluate({&indices, &depth, &on, &off}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -2769,11 +2769,24 @@ TEST_F(DeclarableOpsTests1, OneHotTests_5) {
 }
 TEST_F(DeclarableOpsTests1, OneHotTests_6) {
-    auto indices = NDArrayFactory::create('c', {3}, {0., 1., 2.});
-    auto e = NDArrayFactory::create('c', {3, 3}, {1., 0., 0., 0., 1., 0., 0., 0., 1.});
+    auto indices = NDArrayFactory::create('c', {3}, {0.f, 1.f, 2.f});
+    auto e = NDArrayFactory::create('c', {3, 3}, {1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f});
     nd4j::ops::onehot op;
-    auto result = op.execute({&indices}, {1.0, 0.0}, {0, 3});
+    auto result = op.evaluate({&indices}, {1.0, 0.0}, {0, 3});
+    auto z = result->at(0);
+
+    ASSERT_EQ(e, *z);
+
+    delete result;
+}
+
+TEST_F(DeclarableOpsTests1, OneHotTests_7) {
+    auto indices = NDArrayFactory::create('c', {3}, {0, 1, 2});
+    auto e = NDArrayFactory::create('c', {3, 3}, {1., 0., 0., 0., 1., 0., 0., 0., 1.});
+
+    nd4j::ops::onehot op;
+    auto result = op.evaluate({&indices}, {1.0, 0.0}, {0, 3}, {}, {nd4j::DataType::HALF}, false);
     auto z = result->at(0);
     ASSERT_EQ(e, *z);
     delete result;
@@ -2788,7 +2801,7 @@ TEST_F(DeclarableOpsTests1, FillAs_1) {
     float scalar = 119.f;
     nd4j::ops::fill_as op;
-    auto result = op.execute({&x}, {scalar}, {});
+    auto result = op.evaluate({&x}, {scalar}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -2824,7 +2837,7 @@ TEST_F(DeclarableOpsTests1, Stack_1) {
     NDArray expected(expBuff, expShape);
     nd4j::ops::stack op;
-    auto results = op.execute({&input1, &input2}, {}, {0});
+    auto results = op.evaluate({&input1, &input2}, {}, {0});
     auto output = results->at(0);
     ASSERT_TRUE(expected.isSameShapeStrict(*output));
@@ -2852,7 +2865,7 @@ TEST_F(DeclarableOpsTests1, Stack_2) {
     NDArray expected(expBuff, expShape);
     nd4j::ops::stack op;
-    auto results = op.execute({&input1, &input2}, {}, {1});
+    auto results = op.evaluate({&input1, &input2}, {}, {1});
     auto output = results->at(0);
     ASSERT_TRUE(expected.isSameShapeStrict(*output));
@@ -2880,7 +2893,7 @@ TEST_F(DeclarableOpsTests1, Stack_3) {
     NDArray expected(expBuff, expShape);
     nd4j::ops::stack op;
-    auto results = op.execute({&input1, &input2}, {}, {0});
+    auto results = op.evaluate({&input1, &input2}, {}, {0});
     auto output = results->at(0);
     ASSERT_TRUE(expected.isSameShapeStrict(*output));
@@ -2907,7 +2920,7 @@ TEST_F(DeclarableOpsTests1, Stack_4) {
     NDArray expected(expBuff, expShape);
     nd4j::ops::stack op;
-    auto results = op.execute({&input1, &input2}, {}, {1});
+    auto results = op.evaluate({&input1, &input2}, {}, {1});
     auto output = results->at(0);
     ASSERT_TRUE(expected.isSameShapeStrict(*output));
@@ -2934,7 +2947,7 @@ TEST_F(DeclarableOpsTests1, Stack_5) {
     NDArray expected(expBuff, expShape);
     nd4j::ops::stack op;
-    auto results = op.execute({&input1, &input2}, {}, {0});
+    auto results = op.evaluate({&input1, &input2}, {}, {0});
     auto output = results->at(0);
     ASSERT_TRUE(expected.isSameShapeStrict(*output));
@@ -2961,7 +2974,7 @@ TEST_F(DeclarableOpsTests1, Stack_6) {
     NDArray expected(expBuff, expShape);
     nd4j::ops::stack op;
-    auto results = op.execute({&input1, &input2}, {}, {1});
+    auto results = op.evaluate({&input1, &input2}, {}, {1});
     auto output = results->at(0);
     ASSERT_TRUE(expected.isSameShapeStrict(*output));
@@ -2985,7 +2998,7 @@ TEST_F(DeclarableOpsTests1, Stack_7) {
     NDArray expected(expBuff, expShape);
     nd4j::ops::stack op;
-    auto results = op.execute({&input1, &input1, &input1}, {}, {0});
+    auto results = op.evaluate({&input1, &input1, &input1}, {}, {0});
     auto output = results->at(0);
     ASSERT_TRUE(expected.isSameShapeStrict(*output));
@@ -3008,7 +3021,7 @@ TEST_F(DeclarableOpsTests1, Stack_8) {
     NDArray expected(expBuff, expShape);
     nd4j::ops::stack op;
-    auto results = op.execute({&input1, &input1, &input1}, {}, {0});
+    auto results = op.evaluate({&input1, &input1, &input1}, {}, {0});
     auto output = results->at(0);
     ASSERT_TRUE(expected.isSameShapeStrict(*output));
@@ -3031,7 +3044,7 @@ TEST_F(DeclarableOpsTests1, Stack_9) {
     NDArray expected(expBuff, expShape);
     nd4j::ops::stack op;
-    auto results = op.execute({&input1, &input1, &input1}, {}, {1});
+    auto results = op.evaluate({&input1, &input1, &input1}, {}, {1});
     auto output = results->at(0);
     ASSERT_TRUE(expected.isSameShapeStrict(*output));
@@ -3054,7 +3067,7 @@ TEST_F(DeclarableOpsTests1, Stack_10) {
     NDArray expected(expBuff, expShape);
     nd4j::ops::stack op;
-    auto results = op.execute({&input1, &input1, &input1}, {}, {1});
+    auto results = op.evaluate({&input1, &input1, &input1}, {}, {1});
     auto output = results->at(0);
     //expected.printShapeInfo("exp");
@@ -3079,7 +3092,7 @@ TEST_F(DeclarableOpsTests1, Stack_11) {
     NDArray expected(expBuff, expShape);
     nd4j::ops::stack op;
-    auto results = op.execute({&input1, &input1, &input1}, {}, {});
+    auto results = op.evaluate({&input1, &input1, &input1}, {}, {});
     auto output = results->at(0);
     ASSERT_TRUE(expected.isSameShapeStrict(*output));
@@ -3095,7 +3108,7 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_1) {
     nd4j::ops::range op;
-    auto result = op.execute({}, {}, {1, 5, 1});
+    auto result = op.evaluate({}, {}, {1, 5, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(1, result->size());
@@ -3122,7 +3135,7 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_2) {
     nd4j::ops::range op;
-    auto result = op.execute({&start, &stop, &step}, {}, {});
+    auto result = op.evaluate({&start, &stop, &step}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(1, result->size());
@@ -3142,7 +3155,7 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_3) {
     nd4j::ops::range op;
-    auto result = op.execute({}, {1.f, 5.f, 1.f}, {});
+    auto result = op.evaluate({}, {1.f, 5.f, 1.f}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(1, result->size());
@@ -3161,7 +3174,7 @@ TEST_F(DeclarableOpsTests1, softmax_test1) {
     auto expOutput = NDArrayFactory::create('c', {3, 3}, {1.14195199e-01, 8.43794734e-01, 4.20100661e-02, 2.68454951e-01, 1.80883523e-03, 7.29736214e-01, 9.02116571e-05, 2.68917160e-01, 7.30992629e-01});
     nd4j::ops::softmax op;
-    auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input}, {}, {}, {});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -3177,7 +3190,7 @@ TEST_F(DeclarableOpsTests1, softmax_test2) {
     auto expOutput = NDArrayFactory::create('c', {3, 3, 3}, {4.73142e-02,4.73847e-02,6.69062e-03, 9.50330e-01,8.67881e-04,9.92976e-01, 2.35563e-03,9.51747e-01,3.33106e-04, 4.74259e-02,2.26032e-06,4.74259e-02, 2.91395e-07,9.99998e-01,3.94360e-08, 9.52574e-01,1.12535e-07,9.52574e-01, 7.58256e-10,4.74259e-02,1.22325e-11, 1.00000e+00,1.32293e-11,1.19203e-01, 3.77513e-11,9.52574e-01,8.80797e-01});
     nd4j::ops::softmax op;
-    auto results = op.execute({&input}, {}, {1}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input}, {}, {1}, {});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -3193,7 +3206,7 @@ TEST_F(DeclarableOpsTests1, softmax_test3) {
     auto expOutput = NDArrayFactory::create('c', {3, 3, 3}, {2.47262e-03,1.23395e-04,3.35350e-04, 1.23395e-04,4.53979e-05,1.23395e-04, 6.14417e-06,1.23395e-04,5.56530e-09, 9.97527e-01,1.12521e-07,9.99665e-01, 1.52281e-08,9.99955e-01,2.06090e-09, 9.99994e-01,2.78912e-10,6.69285e-03, 3.05146e-07,9.99876e-01,4.13855e-08, 9.99877e-01,5.60254e-09,9.99877e-01, 7.58251e-10,9.99877e-01,9.93307e-01});
     nd4j::ops::softmax op;
-    auto results = op.execute({&input}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input}, {}, {0}, {});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -3209,7 +3222,7 @@ TEST_F(DeclarableOpsTests1, softmax_test4) {
     auto expOutput = NDArrayFactory::create('c', {1, 5}, {0.01198,0.08855,0.00441,0.24072,0.65434});
     nd4j::ops::softmax op;
-    auto results = op.execute({&input}, {}, {1}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input}, {}, {1}, {});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -3225,7 +3238,7 @@ TEST_F(DeclarableOpsTests1, softmax_test5) {
     auto expOutput = NDArrayFactory::create('c', {1, 5}, {1,1,1,1,1});
     nd4j::ops::softmax op;
-    auto results = op.execute({&input}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input}, {}, {0});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -3241,7 +3254,7 @@ TEST_F(DeclarableOpsTests1, softmax_test6) {
     auto expOutput = NDArrayFactory::create('c', {5, 1}, {0.01198,0.08855,0.00441,0.24072,0.65434});
     nd4j::ops::softmax op;
-    auto results = op.execute({&input}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input}, {}, {0}, {});
    auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -3257,7 +3270,7 @@ TEST_F(DeclarableOpsTests1, softmax_test7) {
     auto expOutput = NDArrayFactory::create('c', {5, 1}, {1,1,1,1,1});
     nd4j::ops::softmax op;
-    auto results = op.execute({&input}, {}, {1}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input}, {}, {1}, {});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -3273,7 +3286,7 @@ TEST_F(DeclarableOpsTests1, softmax_test8) {
     auto expOutput = NDArrayFactory::create('c', {5}, {0.01198,0.08855,0.00441,0.24072,0.65434});
     nd4j::ops::softmax op;
-    auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input}, {}, {}, {});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
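NOTE (reviewer): every hunk in these test files applies the same mechanical rename: the mutating call op.execute(...) becomes op.evaluate(...). As the hunks above show, evaluate takes the same positional lists (input arrays, float tArgs, integer iArgs, then optional boolean bArgs and DataType dArgs, as in the new OneHotTests_7) and returns a heap-allocated ResultSet that the caller still deletes. A minimal sketch of the call pattern, assuming nothing beyond what these call sites show:

    // Sketch of the evaluate() pattern used throughout this patch.
    nd4j::ops::onehot op;
    auto result = op.evaluate({&indices},   // input arrays
                              {1.0f, 0.0f}, // tArgs: on/off values
                              {-1, 3});     // iArgs: axis and depth
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);                 // outputs are owned by the ResultSet
    delete result;                          // caller releases the ResultSet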
@@ -3294,7 +3307,7 @@ TEST_F(DeclarableOpsTests1, Test_Stack_Edge_1) {
     nd4j::ops::stack op;
-    auto result = op.execute({&input}, {}, {0});
+    auto result = op.evaluate({&input}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -3316,7 +3329,7 @@ TEST_F(DeclarableOpsTests1, Test_Stack_Edge_2) {
     nd4j::ops::stack op;
-    auto result = op.execute({&input}, {}, {0});
+    auto result = op.evaluate({&input}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -3338,7 +3351,7 @@ TEST_F(DeclarableOpsTests1, Test_Stack_Edge_3) {
     nd4j::ops::stack op;
-    auto result = op.execute({&input}, {}, {1});
+    auto result = op.evaluate({&input}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -3364,7 +3377,7 @@ TEST_F(DeclarableOpsTests1, Reverse_1 ) {
     NDArray output(shapeInfo);
     nd4j::ops::reverse op;
-    auto results = op.execute({&input}, {}, {0,1,2});
+    auto results = op.evaluate({&input}, {}, {0,1,2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3389,7 +3402,7 @@ TEST_F(DeclarableOpsTests1, Reverse_2 ) {
     NDArray output(shapeInfo);
     nd4j::ops::reverse op;
-    auto results = op.execute({&input}, {}, {}, {}, true);
+    auto results = op.evaluate({&input}, {}, {}, {}, {}, true);
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3414,7 +3427,7 @@ TEST_F(DeclarableOpsTests1, Reverse_3 ) {
     NDArray output(shapeInfo);
     nd4j::ops::reverse op;
-    auto results = op.execute({&input}, {}, {1,2});
+    auto results = op.evaluate({&input}, {}, {1,2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3440,7 +3453,7 @@ TEST_F(DeclarableOpsTests1, Reverse_4 ) {
     NDArray output(shapeInfo);
     nd4j::ops::reverse op;
-    auto results = op.execute({&input}, {}, {0,2});
+    auto results = op.evaluate({&input}, {}, {0,2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3466,7 +3479,7 @@ TEST_F(DeclarableOpsTests1, Reverse_5 ) {
     NDArray output(shapeInfo);
     nd4j::ops::reverse op;
-    auto results = op.execute({&input}, {}, {0,1});
+    auto results = op.evaluate({&input}, {}, {0,1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3491,7 +3504,7 @@ TEST_F(DeclarableOpsTests1, Reverse_6 ) {
     NDArray output(shapeInfo);
     nd4j::ops::reverse op;
-    auto results = op.execute({&input}, {}, {2}, {}, true);
+    auto results = op.evaluate({&input}, {}, {2}, {}, {}, true);
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3518,7 +3531,7 @@ TEST_F(DeclarableOpsTests1, Reverse_7 ) {
     NDArray output(shapeInfo);
     nd4j::ops::reverse op;
-    auto results = op.execute({&input}, {}, {1});
+    auto results = op.evaluate({&input}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3547,7 +3560,7 @@ TEST_F(DeclarableOpsTests1, Reverse_8 ) {
     NDArray output(shapeInfo);
     nd4j::ops::reverse op;
-    auto results = op.execute({&input}, {}, {2,1});
+    auto results = op.evaluate({&input}, {}, {2,1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3573,7 +3586,7 @@ TEST_F(DeclarableOpsTests1, Reverse_9 ) {
     NDArray output(shapeInfo);
     nd4j::ops::reverse op;
-    auto results = op.execute({&input}, {}, {0});
+    auto results = op.evaluate({&input}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3591,7 +3604,7 @@ TEST_F(DeclarableOpsTests1, Reverse_10 ) {
     auto e = NDArrayFactory::create('c', {4, 3}, {0.09966054, 0.1592365, 1.5375735, -1.0355669, 1.144433, 0.677872,0.85020787, -0.67863184, 0.48456487, -1.1660044, 0.20998026, 0.13950661});
     nd4j::ops::reverse op;
-    auto result = op.execute({&x, &i}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &i}, {}, {}, {});
     auto z = result->at(0);
@@ -3612,7 +3625,7 @@ TEST_F(DeclarableOpsTests1, Reverse_11 ) {
     input.linspace(1);
     nd4j::ops::reverse op;
-    auto results = op.execute({&input}, {}, {0, 1, 2});
+    auto results = op.evaluate({&input}, {}, {0, 1, 2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3633,7 +3646,7 @@ TEST_F(DeclarableOpsTests1, Reverse_12 ) {
     //input.linspace(1);
     nd4j::ops::reverse op;
-    auto results = op.execute({&input}, {}, {0});
+    auto results = op.evaluate({&input}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3655,7 +3668,7 @@ TEST_F(DeclarableOpsTests1, Reverse_13 ) {
     //input.linspace(1);
     nd4j::ops::reverse op;
-    auto results = op.execute({&input}, {}, {-1});
+    auto results = op.evaluate({&input}, {}, {-1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3676,7 +3689,7 @@ TEST_F(DeclarableOpsTests1, Reverse_14 ) {
     //input.linspace(1);
     nd4j::ops::reverse op;
-    auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3694,7 +3707,7 @@ TEST_F(DeclarableOpsTests1, Test_Expose_1) {
     nd4j::ops::expose op;
-    auto result = op.execute({&input0, &input1}, {}, {});
+    auto result = op.evaluate({&input0, &input1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp
index 689969543..4875ce8c5 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp
@@ -60,7 +60,7 @@ TEST_F(DeclarableOpsTests10, Test_ArgMax_1) {
     nd4j::ops::argmax op;
-    auto result = op.execute({&x}, {}, {}, {});
+    auto result = op.evaluate({&x});
     ASSERT_EQ(Status::OK(), result->status());
@@ -79,7 +79,7 @@ TEST_F(DeclarableOpsTests10, Test_ArgMax_2) {
     x.linspace(1.0);
     nd4j::ops::argmax op;
-    auto result = op.execute({&x, &y}, {}, {}, {});
+    auto result = op.evaluate({&x, &y});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = *result->at(0);
@@ -98,7 +98,7 @@ TEST_F(DeclarableOpsTests10, Test_And_1) {
     auto e = NDArrayFactory::create('c', {4}, {0, 0, 0, 1});
     nd4j::ops::boolean_and op;
-    auto result = op.execute({&x, &y}, {}, {}, {});
+    auto result = op.evaluate({&x, &y});
     ASSERT_EQ(Status::OK(), result->status());
     ASSERT_EQ(e, *result->at(0));
@@ -112,7 +112,7 @@ TEST_F(DeclarableOpsTests10, Test_Or_1) {
     auto e = NDArrayFactory::create('c', {4}, {1, 1, 0, 1});
     nd4j::ops::boolean_or op;
-    auto result = op.execute({&x, &y}, {}, {}, {});
+    auto result = op.evaluate({&x, &y});
     ASSERT_EQ(Status::OK(), result->status());
     ASSERT_EQ(e, *result->at(0));
@@ -127,7 +127,7 @@ TEST_F(DeclarableOpsTests10, Test_Not_1) {
     auto e = NDArrayFactory::create('c', {4}, {false, false, true, false});
     nd4j::ops::boolean_not op;
-    auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL);
+    auto result = op.evaluate({&x, &y});
     ASSERT_EQ(Status::OK(), result->status());
     auto res = result->at(0);
@@ -141,7 +141,7 @@ TEST_F(DeclarableOpsTests10, Test_Size_at_1) {
     auto e = NDArrayFactory::create(20);
     nd4j::ops::size_at op;
-    auto result = op.execute({&x}, {}, {1});
+    auto result = op.evaluate({&x}, {1});
     ASSERT_EQ(Status::OK(), result->status());
     ASSERT_EQ(e, *result->at(0));
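NOTE (reviewer): a second pattern visible above: where execute spelled out every empty list (op.execute({&x}, {}, {}, {})), the evaluate overloads default them, so the argmax and boolean hunks collapse to op.evaluate({&x}), and Test_Size_at_1 passes its single integer argument directly as op.evaluate({&x}, {1}). Test_Not_1 also drops the trailing inplace flag and explicit nd4j::DataType::BOOL, leaving output typing to the op declaration. Side by side, under those assumptions:

    // Before: all four argument lists written out explicitly.
    auto r1 = op.execute({&x, &y}, {}, {}, {});
    // After: trailing empty lists are defaulted by evaluate().
    auto r2 = op.evaluate({&x, &y});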
@@ -161,7 +161,7 @@ TEST_F(DeclarableOpsTests10, MirrorPad_SGO_Test_1) {
     nd4j::ops::mirror_pad op;
-    auto res = op.execute({&in, &pad}, {10.0}, {0}, {}, false, nd4j::DataType::DOUBLE);
+    auto res = op.evaluate({&in, &pad}, {10.0}, {0});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     ASSERT_TRUE(exp.equalsTo(res->at(0)));
@@ -175,7 +175,7 @@ TEST_F(DeclarableOpsTests10, Unique_SGO_Test_1) {
     auto exp = NDArrayFactory::create({3., 4., 1., 0., 2.});
     nd4j::ops::unique op;
-    auto res = op.execute({&input}, {}, {});
+    auto res = op.evaluate({&input}, {}, {});
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     auto res1 = res->at(0);
     auto res2 = res->at(1);
@@ -192,7 +192,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_1) {
     auto exp = NDArrayFactory::create('c', {6, 2}, {0LL, 0LL, 1LL, 0LL, 1LL, 1LL, 2LL, 0LL, 2LL, 1LL, 2LL, 2LL});
     nd4j::ops::Where op;
-    auto res = op.execute({&input}, {}, {});
+    auto res = op.evaluate({&input}, {}, {});
     ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
     auto resA = res->at(0);
@@ -209,7 +209,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_02) {
     auto exp = NDArrayFactory::create('c', {5, 3}, {0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 1LL, 0LL, 0LL, 1LL, 0LL, 1LL, 1LL, 1LL, 0LL});
     nd4j::ops::Where op;
-    auto res = op.execute({&input}, {}, {});
+    auto res = op.evaluate({&input}, {}, {});
     ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
     auto resA = res->at(0);
@@ -227,7 +227,7 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_1) {
     auto exp2 = NDArrayFactory::create({0, 1, 0, 0, 1});
     auto exp3 = NDArrayFactory::create({0, 1, 0, 1, 0});
     nd4j::ops::where_np op;
-    auto res = op.execute({&cond3d}, {}, {});
+    auto res = op.evaluate({&cond3d}, {}, {});
     ASSERT_TRUE(res->size() == 3);
     ASSERT_EQ(res->status(), ND4J_STATUS_OK);
     auto res1 = res->at(0);
@@ -251,7 +251,7 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_2) {
     auto exp1 = NDArrayFactory::create({0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2});
     auto exp2 = NDArrayFactory::create({0, 1, 4, 0, 1, 2, 3, 4, 1, 2, 3, 4});
     nd4j::ops::where_np op;
-    auto res = op.execute({&cond2d}, {}, {});
+    auto res = op.evaluate({&cond2d}, {}, {});
     ASSERT_TRUE(res->size() == 2);
     ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
     ASSERT_TRUE(exp1.equalsTo(res->at(0)));
@@ -267,7 +267,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_2) {
     auto exp = NDArrayFactory::create('c', {4,1}, {0, 2, 3, 4});
     nd4j::ops::Where op;
-    auto res = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::INT64);
+    auto res = op.evaluate({&input});
     ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
     auto resA = res->at(0);
     // resA->printIndexedBuffer("Result A");
@@ -285,7 +285,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_3) {
     auto exp = NDArrayFactory::create('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0});
     nd4j::ops::Where op;
-    auto res = op.execute({&input}, {}, {});
+    auto res = op.evaluate({&input}, {}, {});
     ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
     auto resA = res->at(0);
     //resA->printIndexedBuffer("Result A");
@@ -303,7 +303,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_4) {
     auto exp = NDArrayFactory::create('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0});
     nd4j::ops::Where op;
-    auto res = op.execute({&input}, {}, {});
+    auto res = op.evaluate({&input}, {}, {});
     ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
     auto resA = res->at(0);
     ASSERT_TRUE(resA->isEmpty());
@@ -322,7 +322,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_5) {
     auto exp = NDArrayFactory::create('c', {3, 1}, {0, 3, 4});
     nd4j::ops::Where op;
-    auto res = op.execute({&input}, {}, {});
+    auto res = op.evaluate({&input}, {}, {});
     ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
     auto resA = res->at(0);
     //ASSERT_TRUE(resA->isEmpty());
@@ -340,7 +340,7 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_4) {
     auto exp = NDArrayFactory::create('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0});
     nd4j::ops::where_np op;
-    auto res = op.execute({&input}, {}, {});
+    auto res = op.evaluate({&input}, {}, {});
     ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
     auto resA = res->at(0);
     ASSERT_TRUE(resA->isEmpty());
@@ -361,7 +361,7 @@ TEST_F(DeclarableOpsTests10, CosineDistance_SGO_Test_1) {
     auto exp = NDArrayFactory::create(0.6);
     nd4j::ops::cosine_distance_loss op;
-    auto res = op.execute({&predictions, &weights, &labels}, {}, {3, 1});
+    auto res = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1});
     ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
     auto resA = res->at(0);
@@ -379,7 +379,7 @@ TEST_F(DeclarableOpsTests10, CosineDistance_SGO_Test_2) {
     auto exp = NDArrayFactory::create(0.6);
     nd4j::ops::cosine_distance_loss op;
-    auto res = op.execute({&predictions, &weights, &labels}, {}, {2, 1});
+    auto res = op.evaluate({&predictions, &weights, &labels}, {}, {2, 1});
     ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
     auto resA = res->at(0);
@@ -402,7 +402,7 @@ TEST_F(DeclarableOpsTests10, TestMarixBandPart_Test_1) {
     exp.p(1, 2, 0, 0.);
     nd4j::ops::matrix_band_part op;
-    auto results = op.execute({&x}, {}, {1, 1}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&x}, {}, {1, 1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
     //results->at(0)->printIndexedBuffer("MBP Test1");
@@ -422,7 +422,7 @@ TEST_F(DeclarableOpsTests10, atan2_test1) {
     0.33172, 0.69614, 0.81846, 0.87776, 0.91253, 0.93533, 0.95141, 0.96336, 0.97259, 0.97993, 0.98591, 1.01266,});
     nd4j::ops::tf_atan2 op;
-    auto result = op.execute({&y, &x}, {}, {});
+    auto result = op.evaluate({&y, &x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -442,7 +442,7 @@ TEST_F(DeclarableOpsTests10, atan2_test2) {
     3.11208, 2.99987, 2.83399, 2.57869, 2.207 , 1.77611, 1.41664, 1.17298, 1.01458, 0.90829, 0.8336 , 0.77879});
     nd4j::ops::tf_atan2 op;
-    auto result = op.execute({&y, &x}, {}, {});
+    auto result = op.evaluate({&y, &x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
     // z->printIndexedBuffer();
@@ -465,7 +465,7 @@ TEST_F(DeclarableOpsTests10, atan2_test3) {
     -1.54128, -1.42907, -1.2632 , -1.00789,-0.63621, -0.20531, 0.15416, 0.39782, 0.55622, 0.6625 , 0.7372 , 0.79201});
     nd4j::ops::tf_atan2 op;
-    auto result = op.execute({&x, &y}, {}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -485,7 +485,7 @@ TEST_F(DeclarableOpsTests10, atan2_test4) {
     3.05688, 3.03942, 3.01293, 2.9681 , 2.18167, 1.87635, 1.50156, 1.14451, 1.13674, 0.97626, 0.84423, 0.7372 });
     nd4j::ops::tf_atan2 op;
-    auto result = op.execute({&x, &y}, {}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -505,7 +505,7 @@ TEST_F(DeclarableOpsTests10, atan2_test5) {
     -1.48608, -1.46862, -1.44214, -1.3973 ,-0.61088, -0.30556, 0.06924, 0.42629, 0.43405, 0.59453, 0.72657, 0.8336 });
     nd4j::ops::tf_atan2 op;
-    auto result = op.execute({&y, &x}, {}, {}, {});
+    auto result = op.evaluate({&y, &x}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -524,7 +524,7 @@ TEST_F(DeclarableOpsTests10, atan2_test6) {
     auto exp = NDArrayFactory::create('c', {1,3,4}, {-2.25712, -1.68608, -1.44214, -0.54006,-2.77695, -2.16855, 0.34972, 0.24585, 2.71267, 1.74453, 1.45312, 0.8336 });
     nd4j::ops::tf_atan2 op;
-    auto result = op.execute({&y, &x}, {}, {}, {});
+    auto result = op.evaluate({&y, &x}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
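NOTE (reviewer): the Where hunks are the clearest case of the dtype cleanup: the old trailing pair (inplace flag plus forced output type, e.g. false, nd4j::DataType::INT64 in Where_SGO_Test_2) is simply dropped, and the op itself now decides the output type. Sketch of the resulting call, reusing the names from Where_SGO_Test_2:

    nd4j::ops::Where op;
    auto res = op.evaluate({&input});           // no inplace flag, no forced dtype
    ASSERT_TRUE(res->status() == ND4J_STATUS_OK);
    auto resA = res->at(0);                     // coordinate output; may be empty
    delete res;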
@@ -546,7 +546,7 @@ TEST_F(DeclarableOpsTests10, IGamma_Test1) {
     0.0000039433403, 0.000086064574, 0.000436067, 0.0012273735});
     nd4j::ops::igamma op;
-    auto result = op.execute({&y, &x}, {}, {}, {});
+    auto result = op.evaluate({&y, &x}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
     // z->printBuffer("OUtput");
@@ -568,7 +568,7 @@ TEST_F(DeclarableOpsTests10, IGamma_Test2) {
     0.999996, 0.999914, 0.999564, 0.998773});
     nd4j::ops::igammac op;
-    auto result = op.execute({&y, &x}, {}, {}, {});
+    auto result = op.evaluate({&y, &x}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
     // z->printBuffer("OUtput");
@@ -591,7 +591,7 @@ TEST_F(DeclarableOpsTests10, LGamma_Test1) {
     });
     nd4j::ops::lgamma op;
-    auto result = op.execute({&x}, {}, {}, {});
+    auto result = op.evaluate({&x}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
     // z->printBuffer("OUtput");
@@ -610,7 +610,7 @@ TEST_F(DeclarableOpsTests10, range_test10) {
     auto exp = NDArrayFactory::create('c', {5}, {0.,1.,2.,3.,4.});
     nd4j::ops::range op;
-    auto result = op.execute({&limit}, {}, {}, {});
+    auto result = op.evaluate({&limit}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -632,7 +632,7 @@ TEST_F(DeclarableOpsTests10, range_test11) {
     auto exp = NDArrayFactory::create('c', {5}, {0.5,1.5,2.5,3.5,4.5});
     nd4j::ops::range op;
-    auto result = op.execute({&start, &limit}, {}, {}, {});
+    auto result = op.evaluate({&start, &limit}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -650,7 +650,7 @@ TEST_F(DeclarableOpsTests10, range_test12) {
     auto exp = NDArrayFactory::create('c', {9}, {0.5f, 1.f , 1.5f, 2.f , 2.5f, 3.f , 3.5f, 4.f , 4.5f});
     nd4j::ops::range op;
-    auto result = op.execute({}, {0.5, 5, 0.5}, {}, {});
+    auto result = op.evaluate({}, {0.5, 5, 0.5}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -671,7 +671,7 @@ TEST_F(DeclarableOpsTests10, top_k_permuted_test1) {
     nd4j::ops::top_k op;
-    auto result = op.execute({&x}, {}, {4}, {false});
+    auto result = op.evaluate({&x}, {}, {4}, {false});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -681,7 +681,7 @@ TEST_F(DeclarableOpsTests10, top_k_permuted_test1) {
     ASSERT_TRUE(expUnsorted.isSameShape(z));
     ASSERT_TRUE(expUnsorted.equalsTo(z));
-    auto result2 = op.execute({&x}, {}, {5}, {true});
+    auto result2 = op.evaluate({&x}, {}, {5}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, result2->status());
@@ -704,7 +704,7 @@ TEST_F(DeclarableOpsTests10, top_k_permuted_test2) {
     nd4j::ops::top_k op;
-    auto result = op.execute({&x}, {}, {5}, {false});
+    auto result = op.evaluate({&x}, {}, {5}, {false});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -714,7 +714,7 @@ TEST_F(DeclarableOpsTests10, top_k_permuted_test2) {
     ASSERT_TRUE(expUnsorted.isSameShape(z));
     ASSERT_TRUE(expUnsorted.equalsTo(z));
-    auto result2 = op.execute({&x}, {}, {5}, {true});
+    auto result2 = op.evaluate({&x}, {}, {5}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, result2->status());
@@ -738,7 +738,7 @@ TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test1
     logits.linspace(0.1, 0.1);
     nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits op;
-    auto results = op.execute({&labels, &logits}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&labels, &logits});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -760,7 +760,7 @@ TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test2
     logits.linspace(0.1, 0.1);
     nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits op;
-    auto results = op.execute({&labels, &logits}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&labels, &logits});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -782,7 +782,7 @@ TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test3
     logits.linspace(0.1, 0.1);
     nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits op;
-    auto results = op.execute({&labels, &logits}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&labels, &logits});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -804,7 +804,7 @@ TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test4
     logits.linspace(0.1, 0.1);
     nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits op;
-    auto results = op.execute({&labels, &logits}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&labels, &logits});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -825,7 +825,7 @@ TEST_F(DeclarableOpsTests10, split_test4) {
     auto exp2 = NDArrayFactory::create('c', {5}, {6.f,7.f,8.f,9.f,10.f});
     nd4j::ops::split op;
-    auto results = op.execute({&input, &axis}, {}, {2}, {});
+    auto results = op.evaluate({&input, &axis}, {}, {2}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -849,7 +849,7 @@ TEST_F(DeclarableOpsTests10, split_test5) {
     auto exp2 = NDArrayFactory::create('c', {3,4}, {5.f,6.f,7.f,8.f, 13.f,14.f,15.f,16.f, 21.f,22.f,23.f,24.f});
     nd4j::ops::split op;
-    auto results = op.execute({&input}, {}, {2,-1},{});
+    auto results = op.evaluate({&input}, {}, {2,-1},{});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -872,7 +872,7 @@ TEST_F(DeclarableOpsTests10, histogram_fixed_width_test1) {
     auto exp = NDArrayFactory::create('c', {5}, {2, 1, 1, 0, 2});
     nd4j::ops::histogram_fixed_width op;
-    auto results = op.execute({&input, &range}, {}, {5}, {});
+    auto results = op.evaluate({&input, &range}, {}, {5}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -892,7 +892,7 @@ TEST_F(DeclarableOpsTests10, histogram_fixed_width_test2) {
     auto exp = NDArrayFactory::create('c', {5}, {5, 2, 5, 3, 9});
     nd4j::ops::histogram_fixed_width op;
-    auto results = op.execute({&input, &range}, {}, {5}, {});
+    auto results = op.evaluate({&input, &range}, {}, {5}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -912,7 +912,7 @@ TEST_F(DeclarableOpsTests10, histogram_fixed_width_test3) {
     auto exp = NDArrayFactory::create('c', {5}, {5, 2, 5, 4, 8});
     nd4j::ops::histogram_fixed_width op;
-    auto results = op.execute({&input, &range}, {}, {5}, {});
+    auto results = op.evaluate({&input, &range}, {}, {5}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -937,7 +937,7 @@ TEST_F(DeclarableOpsTests10, histogram_fixed_width_test4) {
     auto exp = NDArrayFactory::create('c', {5}, {22, 17, 24, 19, 18});
     nd4j::ops::histogram_fixed_width op;
-    auto results = op.execute({&input, &range}, {}, {5}, {});
+    auto results = op.evaluate({&input, &range}, {}, {5}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -963,7 +963,7 @@ TEST_F(DeclarableOpsTests10, histogram_fixed_width_test5) {
     auto exp = NDArrayFactory::create('c', {5}, {23, 15, 24, 17, 21});
     nd4j::ops::histogram_fixed_width op;
-    auto results = op.execute({&input, &range}, {}, {5}, {});
+    auto results = op.evaluate({&input, &range}, {}, {5}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
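NOTE (reviewer): in the split and histogram hunks above, the integer list is the only real configuration that survives the rename: {2} is the output count for split and {5} the bin count for histogram_fixed_width. A sketch of the histogram call exactly as these tests shape it:

    nd4j::ops::histogram_fixed_width op;
    auto results = op.evaluate({&input, &range}, // data plus [min, max] range array
                               {},               // no float args
                               {5});             // iArgs: number of bins
    ASSERT_EQ(ND4J_STATUS_OK, results->status());
    delete results;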
@@ -986,7 +986,7 @@ TEST_F(DeclarableOpsTests10, histogram_fixed_width_test6) {
     auto exp = NDArrayFactory::create('c', {5}, {3, 1, 2, 0, 1});
     nd4j::ops::histogram_fixed_width op;
-    auto results = op.execute({&input, &range, &bins}, {}, {}, {});
+    auto results = op.evaluate({&input, &range, &bins}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1010,7 +1010,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_1) {
     //input.linspace(1.f);
     nd4j::ops::nth_element op;
-    auto results = op.execute({&input, &n}, {}, {});
+    auto results = op.evaluate({&input, &n}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1032,7 +1032,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_2) {
     // input.linspace(1.f);
     nd4j::ops::nth_element op;
-    auto results = op.execute({&input, &n}, {}, {});
+    auto results = op.evaluate({&input, &n}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1054,7 +1054,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_3) {
     //input.linspace(1.f);
     nd4j::ops::nth_element op;
-    auto results = op.execute({&input, &n}, {}, {1}); // with reverse = true
+    auto results = op.evaluate({&input, &n}, {}, {1}); // with reverse = true
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1076,7 +1076,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_4) {
     //input.linspace(1.f);
     nd4j::ops::nth_element op;
-    auto results = op.execute({&input, &n}, {}, {});
+    auto results = op.evaluate({&input, &n}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1097,7 +1097,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_04) {
     input.linspace(1.f);
     nd4j::ops::nth_element op;
-    auto results = op.execute({&input, &n}, {}, {});
+    auto results = op.evaluate({&input, &n}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1118,7 +1118,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_5) {
     // input.linspace(1.f);
     nd4j::ops::nth_element op;
-    auto results = op.execute({&input, &n}, {}, {1});
+    auto results = op.evaluate({&input, &n}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1140,7 +1140,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_6) {
     // input.linspace(1.f);
     nd4j::ops::nth_element op;
-    auto results = op.execute({&input, &n}, {}, {0});
+    auto results = op.evaluate({&input, &n}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1160,7 +1160,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_06) {
     // input.linspace(1.f);
     nd4j::ops::nth_element op;
-    auto results = op.execute({&input, &n}, {}, {1});
+    auto results = op.evaluate({&input, &n}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1186,7 +1186,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_7) {
     //input.linspace(1.f);
     nd4j::ops::nth_element op;
-    auto results = op.execute({&input, &n}, {}, {0});
+    auto results = op.evaluate({&input, &n}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1213,7 +1213,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_8) {
     //input.linspace(1.f);
     nd4j::ops::nth_element op;
-    auto results = op.execute({&input, &n}, {}, {1});
+    auto results = op.evaluate({&input, &n}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1235,7 +1235,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test1) {
     input.linspace(1.f);
     nd4j::ops::broadcast_to op;
-    auto results = op.execute({&input, &shape}, {}, {}, {});
+    auto results = op.evaluate({&input, &shape}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1257,7 +1257,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test2) {
     input.linspace(1.f);
     nd4j::ops::broadcast_to op;
-    auto results = op.execute({&input, &shape}, {}, {}, {});
+    auto results = op.evaluate({&input, &shape}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1279,7 +1279,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test3) {
     input.linspace(1.f);
     nd4j::ops::broadcast_to op;
-    auto results = op.execute({&input, &shape}, {}, {}, {});
+    auto results = op.evaluate({&input, &shape}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1299,7 +1299,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test4) {
     auto exp = NDArrayFactory::create('c', {3,3}, {10.f, 10.f, 10.f,10.f, 10.f, 10.f, 10.f, 10.f, 10.f});
     nd4j::ops::broadcast_to op;
-    auto results = op.execute({&input, &shape}, {}, {}, {});
+    auto results = op.evaluate({&input, &shape}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1319,7 +1319,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test5) {
     auto exp = NDArrayFactory::create('c', {3}, {10.f, 10.f, 10.f});
     nd4j::ops::broadcast_to op;
-    auto results = op.execute({&input, &shape}, {}, {}, {});
+    auto results = op.evaluate({&input, &shape}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1339,7 +1339,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test6) {
     auto exp = NDArrayFactory::create('c', {1}, {10.f});
     nd4j::ops::broadcast_to op;
-    auto results = op.execute({&input, &shape}, {}, {}, {});
+    auto results = op.evaluate({&input, &shape}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1359,7 +1359,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test7) {
     auto exp = NDArrayFactory::create('c', {1}, {10.});
     nd4j::ops::broadcast_to op;
-    auto results = op.execute({&input, &shape}, {}, {}, {});
+    auto results = op.evaluate({&input, &shape}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1381,7 +1381,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test8) {
     input.linspace(1.f);
     nd4j::ops::broadcast_to op;
-    auto results = op.execute({&input, &shape}, {}, {}, {});
+    auto results = op.evaluate({&input, &shape}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1403,7 +1403,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test9) {
     input.linspace(1.f);
     nd4j::ops::broadcast_to op;
-    auto results = op.execute({&input, &shape}, {}, {}, {});
+    auto results = op.evaluate({&input, &shape}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1425,7 +1425,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test10) {
     input.linspace(1.f);
     nd4j::ops::broadcast_to op;
-    auto results = op.execute({&input, &shape}, {}, {}, {});
+    auto results = op.evaluate({&input, &shape}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1481,7 +1481,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) {
     input.linspace(1);
     nd4j::ops::resize_bilinear op;
-    auto results = op.execute({&input}, {}, {10, 10});
+    auto results = op.evaluate({&input}, {}, {10, 10});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1503,7 +1503,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_11) {
     auto size = NDArrayFactory::create({65,65});
     auto ex = NDArrayFactory::create('c', {1,65,65,256});
     nd4j::ops::resize_bilinear op;
-    auto results = op.execute({&input, &size}, {}, {}, {false});
+    auto results = op.evaluate({&input, &size}, {}, {}, {false});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1522,7 +1522,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_12) {
     auto size = NDArrayFactory::create({65,65});
     auto ex = NDArrayFactory::create('c', {1,65,65,256});
     nd4j::ops::resize_bilinear op;
-    auto results = op.execute({&input, &size}, {}, {}, {true});
+    auto results = op.evaluate({&input, &size}, {}, {}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1566,7 +1566,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_1) {
     input.linspace(1);
     nd4j::ops::resize_bilinear op;
-    auto results = op.execute({&input}, {}, {4, 5}, {false, true});
+    auto results = op.evaluate({&input}, {}, {4, 5}, {false, true});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1613,7 +1613,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_2) {
     input.linspace(1);
     nd4j::ops::resize_bilinear op;
-    auto results = op.execute({&input}, {}, {4, 5}, {false, true});
+    auto results = op.evaluate({&input}, {}, {4, 5}, {false, true});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1669,7 +1669,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) {
     input.linspace(1);
     nd4j::ops::resize_bilinear op;
-    auto results = op.execute({&input}, {}, {10, 10});
+    auto results = op.evaluate({&input}, {}, {10, 10});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1824,7 +1824,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) {
     //input.linspace(1);
     nd4j::ops::resize_bilinear op;
-    auto results = op.execute({&input}, {}, {9, 9});
+    auto results = op.evaluate({&input}, {}, {9, 9});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1883,7 +1883,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test2) {
     input.linspace(1);
     nd4j::ops::resize_bilinear op;
-    auto results = op.execute({&input, &size}, {}, {});
+    auto results = op.evaluate({&input, &size}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2013,7 +2013,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test3) {
     input.linspace(1);
     nd4j::ops::resize_bilinear op;
-    auto results = op.execute({&input}, {}, {10, 10}, {true});
+    auto results = op.evaluate({&input}, {}, {10, 10}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2142,7 +2142,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test4) {
     input.linspace(1);
     nd4j::ops::resize_bilinear op;
-    auto results = op.execute({&input, &size}, {}, {}, {true});
+    auto results = op.evaluate({&input, &size}, {}, {}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2166,7 +2166,7 @@ TEST_F(DeclarableOpsTests10, LinSpace_Test1) {
     8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.});
     nd4j::ops::lin_space op;
-    auto result = op.execute({&start, &finish, &num}, {}, {});
+    auto result = op.evaluate({&start, &finish, &num}, {}, {});
     ASSERT_EQ(result->status(), ND4J_STATUS_OK);
     auto res = result->at(0);
@@ -2208,7 +2208,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {
     input.linspace(1);
     nd4j::ops::resize_nearest_neighbor op;
-    auto results = op.execute({&input}, {}, {4, 5}, {false, false});
+    auto results = op.evaluate({&input}, {}, {4, 5}, {false, false});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2256,7 +2256,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1) {
     input.linspace(1);
     nd4j::ops::resize_nearest_neighbor op;
-    auto results = op.execute({&input}, {}, {4, 5});
+    auto results = op.evaluate({&input}, {}, {4, 5});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2304,7 +2304,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1_1) {
     input.linspace(1);
     nd4j::ops::resize_nearest_neighbor op;
-    auto results = op.execute({&input}, {}, {4,5}, {false, true});
+    auto results = op.evaluate({&input}, {}, {4,5}, {false, true});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2351,7 +2351,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test01) {
     input.linspace(1);
     nd4j::ops::resize_nearest_neighbor op;
-    auto results = op.execute({&input}, {}, {4, 5});
+    auto results = op.evaluate({&input}, {}, {4, 5});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2373,7 +2373,7 @@ TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_1) {
     NDArray expected = NDArrayFactory::create(2.5206409f);
     nd4j::ops::reduce_logsumexp op;
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2394,7 +2394,7 @@ TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_2) {
     NDArray expected = NDArrayFactory::create({1.0986123f, 1.8619947f, 1.0986123f});
     nd4j::ops::reduce_logsumexp op;
-    auto results = op.execute({&input}, {}, {0});
+    auto results = op.evaluate({&input}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2414,7 +2414,7 @@ TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_3) {
     NDArray expected = NDArrayFactory::create('c', {1,3}, {1.0986123f, 1.8619947f, 1.0986123f});
     nd4j::ops::reduce_logsumexp op;
-    auto results = op.execute({&input}, {1.f}, {0});
+    auto results = op.evaluate({&input}, {1.f}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2435,7 +2435,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) {
     boxes.linspace(1.f);
     nd4j::ops::non_max_suppression op;
-    auto results = op.execute({&boxes, &scores}, {}, {3});
+    auto results = op.evaluate({&boxes, &scores}, {}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2457,7 +2457,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) {
     NDArray expected = NDArrayFactory::create('c', {3}, {3,0,5});
     nd4j::ops::non_max_suppression op;
-    auto results = op.execute({&boxes, &scales}, {0.5}, {3});
+    auto results = op.evaluate({&boxes, &scales}, {0.5}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2479,7 +2479,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_3) {
     NDArray expected = NDArrayFactory::create('c', {1}, {1});
     nd4j::ops::non_max_suppression op;
-    auto results = op.execute({&boxes, &scales}, {0.5, 0.5}, {2});
+    auto results = op.evaluate({&boxes, &scales}, {0.5, 0.5}, {2});
     ASSERT_EQ(Status::OK(), results->status());
@@ -2502,7 +2502,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_4) {
     NDArray threshold = NDArrayFactory::create(0.5f);
     NDArray scoreThreshold = NDArrayFactory::create(0.5);
     nd4j::ops::non_max_suppression op;
-    auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
+    auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
     ASSERT_EQ(Status::OK(), results->status());
@@ -2524,7 +2524,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_5) {
     NDArray threshold = NDArrayFactory::create(0.5f);
     NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax());
     nd4j::ops::non_max_suppression op;
-    auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
+    auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
     ASSERT_EQ(Status::OK(), results->status());
@@ -2547,7 +2547,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_6) {
     NDArray threshold = NDArrayFactory::create(0.5f);
     NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax());
     nd4j::ops::non_max_suppression_v3 op;
-    auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
+    auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
     ASSERT_EQ(Status::OK(), results->status());
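NOTE (reviewer): the image-resize hunks show the fourth list evaluate accepts, boolean bArgs such as {false, true} in the resize_bilinear and resize_nearest_neighbor tests. The patch never reorders arguments, it only renames the method, so each list keeps its execute-era meaning. Sketch, with the flag semantics left unnamed because the hunks themselves do not spell them out:

    nd4j::ops::resize_bilinear op;
    // iArgs {4, 5}: target output size; bArgs {false, true}: op-specific flags
    auto results = op.evaluate({&input}, {}, {4, 5}, {false, true});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());
    delete results;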
@@ -2571,7 +2571,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_06) {
     NDArray threshold = NDArrayFactory::create(0.5f);
     NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax());
     nd4j::ops::non_max_suppression_v3 op;
-    auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
+    auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
     ASSERT_EQ(Status::OK(), results->status());
@@ -2594,7 +2594,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_7) {
     NDArray threshold = NDArrayFactory::create(0.5f);
     NDArray scoreThreshold = NDArrayFactory::create(0.5f);
     nd4j::ops::non_max_suppression_v3 op;
-    auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
+    auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
     ASSERT_EQ(Status::OK(), results->status());
@@ -2619,7 +2619,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_1) {
     NDArray expected = NDArrayFactory::create('c', {1,}, {3});
     nd4j::ops::non_max_suppression_overlaps op;
-    auto results = op.execute({&boxes, &scores, &max_num}, {0.5, 0.}, {});
+    auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2644,7 +2644,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_2) {
     NDArray expected = NDArrayFactory::create('c', {3,}, {1,1,1});
     nd4j::ops::non_max_suppression_overlaps op;
-    auto results = op.execute({&boxes, &scores, &max_num}, {0.5, 0.}, {});
+    auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2669,7 +2669,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_3) {
     NDArray expected = NDArrayFactory::create('c', {5,}, {1,1,1,1,1});
     nd4j::ops::non_max_suppression_overlaps op;
-    auto results = op.execute({&boxes, &scores, &max_num}, {0.5, 0.}, {});
+    auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2693,7 +2693,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_1) {
     NDArray expected = NDArrayFactory::create('c', {1,1,1,1}, {2.5f});
     nd4j::ops::crop_and_resize op;
-    auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {});
+    auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2718,7 +2718,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_2) {
     NDArray expected = NDArrayFactory::create('c', {1,1,1,1}, {4.f});
     nd4j::ops::crop_and_resize op;
-    auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {1});
+    auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2742,7 +2742,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) {
     NDArray expected('c', {1,3,3,1}, {1.f, 1.5f, 2., 2.f, 2.5f, 3.f, 3.f, 3.5f, 4.f}, nd4j::DataType::FLOAT32);
     nd4j::ops::crop_and_resize op;
-    auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {0});
+    auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2766,7 +2766,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) {
     NDArray expected('c', {1,3,3,1}, {1.f, 2.f, 2.f, 3.f, 4, 4.f, 3.f, 4.f, 4.f}, nd4j::DataType::FLOAT32);
     nd4j::ops::crop_and_resize op;
-    auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {1});
+    auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2790,7 +2790,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_5) {
     NDArray expected('c', {1, 10, 10,3}, nd4j::DataType::FLOAT32);
     nd4j::ops::crop_and_resize op;
-    auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {1});
+    auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2826,7 +2826,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) {
     });
     images.linspace(1.);
     nd4j::ops::draw_bounding_boxes op;
-    auto results = op.execute({&images, &boxes, &colors}, {}, {});
+    auto results = op.evaluate({&images, &boxes, &colors}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2859,7 +2859,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_2) {
     73.1f , 74.1f, 75.1f, 76.1f, 77.1f , 78.1f, 79.1f , 80.1f , 81.1f });
     images.linspace(1.1);
     nd4j::ops::draw_bounding_boxes op;
-    auto results = op.execute({&images, &boxes, &colors}, {}, {});
+    auto results = op.evaluate({&images, &boxes, &colors}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2912,7 +2912,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_3) {
     0.8428f, 0.9441f, 0.9441f, 0.9441f, 0.3491f, 0.5793f, 0.573f , 0.1822f, 0.642f , 0.9143f});
     nd4j::ops::draw_bounding_boxes op;
-    auto results = op.execute({&images, &boxes, &colors}, {}, {});
+    auto results = op.evaluate({&images, &boxes, &colors}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
     auto result = results->at(0);
@@ -2937,7 +2937,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
     NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32);
     nd4j::ops::fake_quant_with_min_max_vars op;
-    auto results = op.execute({&x, &min, &max}, {}, {});
+    auto results = op.evaluate({&x, &min, &max}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2958,7 +2958,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_2) {
     NDArray max = NDArrayFactory::create(0.1);
     nd4j::ops::fake_quant_with_min_max_vars op;
-    auto results = op.execute({&x, &min, &max}, {}, {});
+    auto results = op.evaluate({&x, &min, &max}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2979,7 +2979,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_3) {
     NDArray max = NDArrayFactory::create('c', {1}, {0.1});
     nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
-    auto results = op.execute({&x, &min, &max}, {}, {});
+    auto results = op.evaluate({&x, &min, &max}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3003,7 +3003,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03) {
     NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
     nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
-    auto results = op.execute({&x, &min, &max}, {}, {});
+    auto results = op.evaluate({&x, &min, &max}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3026,7 +3026,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_1) {
     NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
     nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
-    auto results = op.execute({&x, &min, &max}, {}, {8}, {true});
+    auto results = op.evaluate({&x, &min, &max}, {}, {8}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3050,7 +3050,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_2) {
     NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
     nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
-    auto results = op.execute({&x, &min, &max}, {}, {6}, {true});
+    auto results = op.evaluate({&x, &min, &max}, {}, {6}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3073,7 +3073,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_3) {
     NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
     nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
-    auto results = op.execute({&x, &min, &max}, {}, {6}, {false});
+    auto results = op.evaluate({&x, &min, &max}, {}, {6}, {false});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3108,7 +3108,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) {
     NDArray max = NDArrayFactory::create({65.f, 70.f, 90.f});
     x.linspace(1.);
     nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
-    auto results = op.execute({&x, &min, &max}, {}, {});
+    auto results = op.evaluate({&x, &min, &max}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3161,7 +3161,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) {
     NDArray max = NDArrayFactory::create({20.f, 21.f, 22.f, 23.f});
     x.linspace(-60.);
     nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
-    auto results = op.execute({&x, &min, &max}, {}, {});
+    auto results = op.evaluate({&x, &min, &max}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3195,7 +3195,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) {
     NDArray max = NDArrayFactory::create('c', {5}, {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
     // x.linspace(-60.);
     nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
-    auto results = op.execute({&x, &min, &max}, {}, {});
+    auto results = op.evaluate({&x, &min, &max}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3241,7 +3241,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_7) {
     NDArray max = NDArrayFactory::create('c', {1}, {1.f});
     x.linspace(0., 0.01);
     nd4j::ops::fake_quant_with_min_max_vars op;
-    auto results = op.execute({&x, &min, &max}, {}, {});
+    auto results = op.evaluate({&x, &min, &max}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3266,7 +3266,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_8) {
     NDArray max = NDArrayFactory::create('c', {1}, {1.f});
     x.linspace(0., 0.1);
     nd4j::ops::fake_quant_with_min_max_vars op;
-    auto results = op.execute({&x, &min, &max}, {}, {});
+    auto results = op.evaluate({&x, &min, &max}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp
index 9b1dfc068..27f742316 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp
@@ -43,7 +43,7 @@ TEST_F(DeclarableOpsTests11, test_listdiff_1) {
     auto y = NDArrayFactory::create('c',{2}, {3, 1});
     nd4j::ops::listdiff op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     delete result;
@@ -68,7 +68,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test1) {
     weights.assign(0.5);
     nd4j::ops::log_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {0}, {});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
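NOTE (reviewer): the DeclarableOpsTests11 hunks that follow apply the same rename to the loss-gradient tests; there the surviving lists carry real configuration, a float epsilon and an integer reduction mode (0 through 3 across log_loss_grad_test1..13). Sketch of the common shape of these calls:

    nd4j::ops::log_loss_grad op;
    auto results = op.evaluate({&predictions, &weights, &labels},
                               {1e-7},  // tArgs: epsilon
                               {0});    // iArgs: reduction mode
    ASSERT_EQ(ND4J_STATUS_OK, results->status());
    delete results;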
@@ -100,7 +100,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test2) {
     weights.assign(0.5);
     nd4j::ops::log_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -130,7 +130,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test3) {
     weights.assign(0.5);
     nd4j::ops::log_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {1});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -162,7 +162,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test4) {
     weights.assign(0.5);
     nd4j::ops::log_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {1});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -195,7 +195,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test5) {
     weights.assign(0.5);
     nd4j::ops::log_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -227,7 +227,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test6) {
     weights.assign(0.5);
     nd4j::ops::log_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -253,7 +253,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test7) {
     weights.assign(0.5);
     nd4j::ops::log_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -288,7 +288,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test8) {
     weights.p(3, 0.);
     nd4j::ops::log_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -325,7 +325,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test9) {
     weights.assign(0.5);
     nd4j::ops::log_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -357,7 +357,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test10) {
     weights.assign(0.5);
     nd4j::ops::log_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -383,7 +383,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test11) {
     weights.assign(0.5);
     nd4j::ops::log_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -420,7 +420,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test12) {
     nd4j::ops::log_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -459,7 +459,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test13) {
     weights.t(2) = 0.;
     nd4j::ops::log_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -642,7 +642,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test1) {
     auto size = NDArrayFactory::create({30, 30});
     nd4j::ops::resize_bicubic op;
-    auto results = op.execute({&input, &size}, {}, {});
+    auto results = op.evaluate({&input, &size}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
     NDArray* result = results->at(0);
@@ -716,7 +716,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test2) {
     input.linspace(1);
     auto size = NDArrayFactory::create({10, 8});
     nd4j::ops::resize_bicubic op;
-    auto results = op.execute({&input, &size}, {}, {});
+    auto results = op.evaluate({&input, &size}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -753,7 +753,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test3) {
     input.linspace(1);
     auto size = NDArrayFactory::create({6, 6});
     nd4j::ops::resize_bicubic op;
-    auto results = op.execute({&input, &size}, {}, {});
+    auto results = op.evaluate({&input, &size}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -790,7 +790,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test4) {
     input.linspace(1);
     auto size = NDArrayFactory::create({6, 8});
     nd4j::ops::resize_bicubic op;
-    auto results = op.execute({&input, &size}, {}, {});
+    auto results = op.evaluate({&input, &size}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -833,7 +833,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test5) {
     input.linspace(1);
     auto size = NDArrayFactory::create({8, 8});
     nd4j::ops::resize_bicubic op;
-    auto results = op.execute({&input, &size}, {}, {});
+    auto results = op.evaluate({&input, &size}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -963,7 +963,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test6) {
     auto size = NDArrayFactory::create({30, 30});
     nd4j::ops::resize_bicubic op;
-    auto results = op.execute({&input, &size}, {}, {});
+    auto results = op.evaluate({&input, &size}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
     NDArray* result = results->at(0);
@@ -1021,7 +1021,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test7) {
     });
     auto size = NDArrayFactory::create({9, 9});
     nd4j::ops::resize_bicubic op;
-    auto results = op.execute({&input, &size}, {}, {});
+    auto results = op.evaluate({&input, &size}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1074,7 +1074,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) {
     auto size = NDArrayFactory::create({9, 9});
     nd4j::ops::resize_bicubic op;
-    auto results = op.execute({&input, &size}, {}, {}, {true, false});
+    auto results = op.evaluate({&input, &size}, {}, {}, {true, false});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1135,7 +1135,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test1) {
     input.linspace(1);
     auto size = NDArrayFactory::create({6, 6});
     nd4j::ops::resize_area op;
-    auto results = op.execute({&input, &size}, {}, {});
+    auto results = op.evaluate({&input, &size}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1162,7 +1162,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test2) {
     input.linspace(1);
     auto size = NDArrayFactory::create({6, 6});
     nd4j::ops::resize_area op;
-    auto results = op.execute({&input, &size}, {}, {});
+    auto results = op.evaluate({&input, &size}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1190,7 +1190,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test3) {
     input.linspace(1);
     auto size = NDArrayFactory::create({6, 6});
     nd4j::ops::resize_area op;
-    auto results = op.execute({&input, &size}, {}, {});
+    auto results = op.evaluate({&input, &size}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1228,7 +1228,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test4) {
     //input.linspace(1);
     auto size = NDArrayFactory::create({6, 6});
     nd4j::ops::resize_area op;
-    auto results = op.execute({&input, &size}, {}, {});
+    auto results = op.evaluate({&input, &size}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1266,7 +1266,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test5) {
     //input.linspace(1);
     auto size = NDArrayFactory::create({6, 6});
     nd4j::ops::resize_area op;
-    auto results = op.execute({&input, &size}, {}, {});
+    auto results = op.evaluate({&input, &size}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1304,7 +1304,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test6) {
     //input.linspace(1);
     auto size = NDArrayFactory::create({6, 6});
     nd4j::ops::resize_area op;
-    auto results = op.execute({&input, &size}, {}, {}, {true});
+    auto results = op.evaluate({&input, &size}, {}, {}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1342,7 +1342,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test7) {
     //input.linspace(1);
     // auto size = NDArrayFactory::create({6, 6});
     nd4j::ops::resize_area op;
-    auto results = op.execute({&input}, {}, {6, 6}, {true});
+    auto results = op.evaluate({&input}, {}, {6, 6}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1372,7 +1372,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test8) {
     //input.linspace(1);
     // auto size = NDArrayFactory::create({6, 6});
     nd4j::ops::resize_area op;
-    auto results = op.execute({&input}, {}, {6, 6}, {true});
+    auto results = op.evaluate({&input}, {}, {6, 6}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1399,7 +1399,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test9) {
     //input.linspace(1);
     auto size = NDArrayFactory::create({10, 10});
     nd4j::ops::resize_area op;
-    auto results = op.execute({&input, &size}, {}, {});
+    auto results = op.evaluate({&input, &size}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1426,7 +1426,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test10) {
     //input.linspace(1);
     //auto size = NDArrayFactory::create({10, 10});
     nd4j::ops::resize_area op;
-    auto results = op.execute({&input}, {}, {10, 10});
+    auto results = op.evaluate({&input}, {}, {10, 10});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1453,7 +1453,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test11) {
     //input.linspace(1);
     //auto size = NDArrayFactory::create({10, 10});
     nd4j::ops::resize_area op;
-    auto results = op.execute({&input}, {}, {6, 9});
+    auto results = op.evaluate({&input}, {}, {6, 9});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1480,7 +1480,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test12) {
     //input.linspace(1);
     //auto size = NDArrayFactory::create({10, 10});
     nd4j::ops::resize_area op;
-    auto results = op.execute({&input}, {}, {10, 15});
+    auto results = op.evaluate({&input}, {}, {10, 15});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1507,7 +1507,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test13) {
     //input.linspace(1);
     //auto size = NDArrayFactory::create({10, 10});
     nd4j::ops::resize_area op;
-    auto results = op.execute({&input}, {}, {9, 9});
+    auto results = op.evaluate({&input}, {}, {9, 9});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1558,7 +1558,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test1) {
     weights.assign(0.5);
     nd4j::ops::mean_sqerr_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1590,7 +1590,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test2) {
     weights.assign(0.5);
     nd4j::ops::mean_sqerr_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1618,7 +1618,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test3) {
     weights.assign(0.5);
     nd4j::ops::mean_sqerr_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {1});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1650,7 +1650,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test4) {
     weights.assign(0.5);
     nd4j::ops::mean_sqerr_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {1});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1679,7 +1679,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test5) {
     weights.assign(0.5);
     nd4j::ops::mean_sqerr_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1711,7 +1711,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test6) {
     weights.assign(0.5);
     nd4j::ops::mean_sqerr_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1737,7 +1737,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test7) {
     weights.assign(0.5);
     nd4j::ops::mean_sqerr_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1770,7 +1770,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test8) {
     weights.p(3, 0.);
     nd4j::ops::mean_sqerr_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1805,7 +1805,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test9) {
     weights.assign(0.5);
     nd4j::ops::mean_sqerr_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1837,7 +1837,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test10) {
     weights.assign(0.5);
     nd4j::ops::mean_sqerr_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1863,7 +1863,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test11) {
     weights.assign(0.5);
     nd4j::ops::mean_sqerr_loss_grad op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
+    auto results = op.evaluate({&predictions, &weights,
&labels}, {}, {3}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1896,7 +1896,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test12) { weights.t(3) = 0.; nd4j::ops::mean_sqerr_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {3}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1933,7 +1933,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test13) { weights.t(2) = 0.; nd4j::ops::mean_sqerr_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {3}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1956,7 +1956,7 @@ TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test1) { auto y = NDArrayFactory::create('c',{4}, {3, 2, 1, 0}); auto exp = NDArrayFactory::create('c', {4}, {9, 1,1, 9}); nd4j::ops::squaredsubtract op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -1968,7 +1968,7 @@ TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test2) { auto y = NDArrayFactory::create('c',{4}, {3, 2, 1, 0}); auto exp = NDArrayFactory::create('c', {2, 4}, {9, 1,1, 9, 9, 1, 1, 9}); nd4j::ops::squaredsubtract op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; @@ -1980,7 +1980,7 @@ TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test3) { auto exp = NDArrayFactory::create('c', {2, 4}, {-6, -4, 6, 24, -30, -12, 14, 48}); auto eps = NDArrayFactory::create('c', {2, 4}, {1,2,3,4,5,6,7,8}); nd4j::ops::squaredsubtract_bp op; - auto result = op.execute({&x, &y, &eps}, {}, {}); + auto result = op.evaluate({&x, &y, &eps}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; @@ -2003,7 +2003,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test1) { weights.assign(0.5); nd4j::ops::absolute_difference_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {0}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2035,7 +2035,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test2) { weights.assign(0.5); nd4j::ops::absolute_difference_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {0}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2063,7 +2063,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test3) { weights.assign(0.5); nd4j::ops::absolute_difference_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {1}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2095,7 +2095,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test4) { weights.assign(0.5); nd4j::ops::absolute_difference_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {1}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2124,7 +2124,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test5) { weights.assign(0.5); 
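Across the mean_sqerr and absolute_difference suites the trailing I-arg cycles through 0-3. The patch itself only uses the numbers; the names below are inferred from libnd4j's loss-op reduction convention and should be treated as an assumption:

    mode 0 - no reduction (per-element gradients)
    mode 1 - weighted sum over all elements
    mode 2 - weighted mean (sum divided by the sum of weights)
    mode 3 - weighted mean over non-zero weights only

The *_test8/12/13 variants zero out individual weights precisely to exercise the masking in modes 2 and 3, using either element accessor:

    weights.p(3, 0.);            // write by flat index
    weights.t<double>(2) = 0.;   // templated reference accessor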
nd4j::ops::absolute_difference_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {2}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2156,7 +2156,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test6) { weights.assign(0.5); nd4j::ops::absolute_difference_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {2}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2182,7 +2182,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test7) { weights.assign(0.5); nd4j::ops::absolute_difference_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {2}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2215,7 +2215,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test8) { weights.p(3, 0.); nd4j::ops::absolute_difference_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {2}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2250,7 +2250,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test9) { weights.assign(0.5); nd4j::ops::absolute_difference_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {3}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2282,7 +2282,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test10) { weights.assign(0.5); nd4j::ops::absolute_difference_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {3}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2308,7 +2308,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test11) { weights.assign(0.5); nd4j::ops::absolute_difference_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {3}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2341,7 +2341,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test12) { weights.t(3) = 0.; nd4j::ops::absolute_difference_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {3}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2378,7 +2378,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test13) { weights.t(2) = 0.; nd4j::ops::absolute_difference_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {3}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2407,7 +2407,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_1) { y.linspace(1); exp.linspace(2,2); nd4j::ops::add op; - auto results = op.execute({&x, &y}, {}, {}); + auto results = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2428,7 +2428,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_2) { y.linspace(1); exp.linspace(2,2); nd4j::ops::add op; - auto results = op.execute({&x, &y}, {}, {}); + auto results = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); 
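The BFloat16_Test_* conversions are mechanical, but the shape of such a test is easy to reconstruct; a sketch, with the dtype pairing assumed and values chosen so that exp == x + y:

    NDArray x('c', {2, 3}, nd4j::DataType::BFLOAT16);
    NDArray y('c', {2, 3}, nd4j::DataType::BFLOAT16);
    NDArray exp('c', {2, 3}, nd4j::DataType::BFLOAT16);
    x.linspace(1);          // 1, 2, 3, ...
    y.linspace(1);
    exp.linspace(2, 2);     // 2, 4, 6, ... == x + y

    nd4j::ops::add op;
    auto results = op.evaluate({&x, &y}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());
    ASSERT_TRUE(exp.equalsTo(results->at(0)));
    delete results;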
@@ -2449,7 +2449,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_3) { y.linspace(1); exp.linspace(2,2); nd4j::ops::add op; - auto results = op.execute({&x, &y}, {}, {}); + auto results = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2478,7 +2478,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test1) { weights.assign(0.5); nd4j::ops::sigm_cross_entropy_loss_grad op; - auto results = op.execute({&logits, &weights, &labels}, {0.}, {0}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2514,7 +2514,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test2) { weights.assign(0.5); nd4j::ops::sigm_cross_entropy_loss_grad op; - auto results = op.execute({&logits, &weights, &labels}, {0.3}, {0}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2550,7 +2550,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test3) { weights.assign(0.5); nd4j::ops::sigm_cross_entropy_loss_grad op; - auto results = op.execute({&logits, &weights, &labels}, {0.3}, {1}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2582,7 +2582,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test4) { weights.assign(0.5); nd4j::ops::sigm_cross_entropy_loss_grad op; - auto results = op.execute({&logits, &weights, &labels}, {0.3}, {1}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2613,7 +2613,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test5) { weights.assign(0.5); nd4j::ops::sigm_cross_entropy_loss_grad op; - auto results = op.execute({&logits, &weights, &labels}, {0.3}, {2}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2645,7 +2645,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test6) { weights.assign(0.5); nd4j::ops::sigm_cross_entropy_loss_grad op; - auto results = op.execute({&logits, &weights, &labels}, {0.3}, {2}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2671,7 +2671,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test7) { weights.assign(0.5); nd4j::ops::sigm_cross_entropy_loss_grad op; - auto results = op.execute({&logits, &weights, &labels}, {0.3}, {2}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2705,7 +2705,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test8) { weights.p(3, 0.); nd4j::ops::sigm_cross_entropy_loss_grad op; - auto results = op.execute({&logits, &weights, &labels}, {0.3}, {2}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2741,7 +2741,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test9) { weights.assign(0.5); nd4j::ops::sigm_cross_entropy_loss_grad op; - auto results = op.execute({&logits, &weights, &labels}, {0.3}, {3}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2773,7 +2773,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test10) { weights.assign(0.5); nd4j::ops::sigm_cross_entropy_loss_grad op; - auto results = op.execute({&logits, &weights, 
&labels}, {0.3}, {3}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2799,7 +2799,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test11) { weights.assign(0.5); nd4j::ops::sigm_cross_entropy_loss_grad op; - auto results = op.execute({&logits, &weights, &labels}, {0.3}, {3}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2834,7 +2834,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test12) { nd4j::ops::sigm_cross_entropy_loss_grad op; - auto results = op.execute({&logits, &weights, &labels}, {0.3}, {3}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2872,7 +2872,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test13) { weights.t(2) = 0.; nd4j::ops::sigm_cross_entropy_loss_grad op; - auto results = op.execute({&logits, &weights, &labels}, {0.3}, {3}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2901,7 +2901,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_4) { y.linspace(1); exp.linspace(2,2); nd4j::ops::add op; - auto results = op.execute({&x, &y}, {}, {}); + auto results = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2922,7 +2922,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_5) { y.linspace(1); exp.linspace(1); nd4j::ops::subtract op; - auto results = op.execute({&x, &y}, {}, {}); + auto results = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2943,7 +2943,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_6) { y.linspace(1); exp.linspace(1); nd4j::ops::subtract op; - auto results = op.execute({&x, &y}, {}, {}); + auto results = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2968,7 +2968,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test1) { nd4j::ops::softmax_cross_entropy_loss_grad op; - auto results = op.execute({&logits, &weights, &labels}, {0.}, {0}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2999,7 +2999,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test2) { nd4j::ops::softmax_cross_entropy_loss_grad op; - auto results = op.execute({&logits, &weights, &labels}, {0.}, {1}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3030,7 +3030,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test3) { nd4j::ops::softmax_cross_entropy_loss_grad op; - auto results = op.execute({&logits, &weights, &labels}, {0.}, {1}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3061,7 +3061,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test4) { nd4j::ops::softmax_cross_entropy_loss_grad op; - auto results = op.execute({&logits, &weights, &labels}, {0.}, {2}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3092,7 +3092,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test5) { nd4j::ops::softmax_cross_entropy_loss_grad op; - auto results = op.execute({&logits, &weights, &labels}, {0.}, {3}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3}); 
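In the sigm/softmax cross-entropy hunks the single T-arg that rides along ({0.} or {0.3}) is the label-smoothing coefficient. The call shape, with the array setup elided and the output order assumed to be dLdp, dLdw, dLdl:

    nd4j::ops::sigm_cross_entropy_loss_grad op;
    // inputs: logits, weights, labels; T-args: {label smoothing}; I-args: {reduction mode}
    auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());
    auto dLdp = results->at(0);   // gradient w.r.t. logits (order assumed)
    delete results;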
ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3123,7 +3123,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test6) { nd4j::ops::softmax_cross_entropy_loss_grad op; - auto results = op.execute({&logits, &weights, &labels}, {0.3}, {2}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3154,7 +3154,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test7) { nd4j::ops::softmax_cross_entropy_loss_grad op; - auto results = op.execute({&logits, &weights, &labels}, {0.}, {3}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3196,7 +3196,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test8) { nd4j::ops::softmax_cross_entropy_loss_grad op; - auto results = op.execute({&logits, &weights, &labels}, {0.}, {2}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3237,7 +3237,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test1) { nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.execute({&logits, &labels}, {}, {}); + auto results = op.evaluate({&logits, &labels}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3261,7 +3261,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test2) { nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.execute({&logits, &labels}, {}, {1}); + auto results = op.evaluate({&logits, &labels}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3284,7 +3284,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test3) { nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.execute({&logits, &labels}, {}, {0}); + auto results = op.evaluate({&logits, &labels}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3306,7 +3306,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test4) { nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.execute({&logits, &labels}, {}, {1}); + auto results = op.evaluate({&logits, &labels}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3328,7 +3328,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test5) { nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.execute({&logits, &labels}, {}, {0}); + auto results = op.evaluate({&logits, &labels}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3350,7 +3350,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test6) { nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.execute({&logits, &labels}, {}, {0}); + auto results = op.evaluate({&logits, &labels}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3372,7 +3372,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test7) { nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.execute({&logits, &labels}, {}, {0}); + auto results = op.evaluate({&logits, &labels}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3394,7 +3394,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test8) { nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.execute({&logits, &labels}, {}, {0}); + auto results = op.evaluate({&logits, &labels}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, 
results->status()); @@ -3421,7 +3421,7 @@ TEST_F(DeclarableOpsTests11, Multiply_BP_Test1) { dLdpExp.assign(1.0); nd4j::ops::multiply_bp op; - auto results = op.execute({&x, &y, &dLdp}, {}, {}); + auto results = op.evaluate({&x, &y, &dLdp}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3444,7 +3444,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test1) { nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.execute({&labels, &logits}, {}, {}); + auto results = op.evaluate({&labels, &logits}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3468,7 +3468,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test2) { nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.execute({&labels, &logits}, {}, {}); + auto results = op.evaluate({&labels, &logits}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3490,7 +3490,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) { nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.execute({&labels, &logits}, {}, {}); + auto results = op.evaluate({&labels, &logits}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3514,7 +3514,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test4) { nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.execute({&labels, &logits}, {}, {}); + auto results = op.evaluate({&labels, &logits}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3536,7 +3536,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test5) { nd4j::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.execute({&labels, &logits}, {}, {}); + auto results = op.evaluate({&labels, &logits}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index c0ce9f1ab..80a9d67a4 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -44,7 +44,7 @@ TEST_F(DeclarableOpsTests12, test_any_validation_1) { auto y = NDArrayFactory::create('c', {2}, {1, 0}); nd4j::ops::transpose op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -69,7 +69,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test1) { nd4j::ops::cosine_distance_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {0, -1}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, -1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -101,7 +101,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test2) { nd4j::ops::cosine_distance_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {0, 0}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -135,7 +135,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test3) { nd4j::ops::cosine_distance_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {0, 0}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -169,7 +169,7 @@ TEST_F(DeclarableOpsTests12, 
cosine_distance_loss_grad_test4) { nd4j::ops::cosine_distance_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {1, 1}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1, 1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -204,7 +204,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test5) { nd4j::ops::cosine_distance_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {2, 0}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2, 0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -238,7 +238,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test6) { nd4j::ops::cosine_distance_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {3, 1}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -274,7 +274,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test7) { nd4j::ops::cosine_distance_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {2, 0}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2, 0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -310,7 +310,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test8) { nd4j::ops::cosine_distance_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {3, 1}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -346,7 +346,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test9) { nd4j::ops::cosine_distance_loss_grad op; - auto results = op.execute({&predictions, &weights, &labels}, {}, {0, 2}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -422,7 +422,7 @@ TEST_F(DeclarableOpsTests12, TestDivideBP_2) { eps.linspace(1.); nd4j::ops::divide_bp op; - Nd4jStatus status = op.execute({&x, &y, &eps}, {&output1, &output2}, {}, {}, {}); + Nd4jStatus status = op.execute({&x, &y, &eps}, std::vector{&output1, &output2}, {}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_TRUE(output1.equalsTo(exp1)); @@ -443,7 +443,7 @@ TEST_F(DeclarableOpsTests12, TestReverseDivideBP_1) { eps.linspace(1.); nd4j::ops::reversedivide_bp op; - Nd4jStatus status = op.execute({&y, &x, &eps}, {&output2, &output1}, {}, {}, {}); + Nd4jStatus status = op.execute({&y, &x, &eps}, std::vector{&output2, &output1}, {}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, status); //ASSERT_TRUE(output.e(0) == 47.); @@ -467,7 +467,7 @@ TEST_F(DeclarableOpsTests12, TestReverseDivideBP_2) { exp1.assign(1.); exp2.assign(-2.); nd4j::ops::reversedivide_bp op; - Nd4jStatus status = op.execute({&y, &x, &eps}, {&output2, &output1}, {}, {}, {}); + Nd4jStatus status = op.execute({&y, &x, &eps}, std::vector{&output2, &output1}, {}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_TRUE(output1.equalsTo(exp1)); @@ -539,7 +539,7 @@ TEST_F(DeclarableOpsTests12, TestMaximumBP_1) { //exp1.assign(1.); //exp2.assign(-2.); nd4j::ops::maximum_bp op; - Nd4jStatus status = op.execute({&x, &y, &eps}, {&output1, &output2}, {}, {}, {}); + Nd4jStatus status = op.execute({&x, &y, &eps}, std::vector{&output1, &output2}, {}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_TRUE(output1.equalsTo(exp1)); @@ -564,7 +564,7 @@ TEST_F(DeclarableOpsTests12, TestMinimumBP_1) { //exp1.assign(1.); //exp2.assign(-2.); nd4j::ops::minimum_bp op; - Nd4jStatus 
status = op.execute({&x, &y, &eps}, {&output2, &output1}, {}, {}, {}); + Nd4jStatus status = op.execute({&x, &y, &eps}, std::vector{&output2, &output1}, {}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_TRUE(output1.equalsTo(exp1)); @@ -716,7 +716,7 @@ TEST_F(DeclarableOpsTests12, tensormmul_6) { NDArray exp('c', {2,2}, {2,4,6,8}, nd4j::DataType::FLOAT32); nd4j::ops::tensormmul op; - auto results = op.execute({&x, &y}, {}, {1,0, 1,1}); + auto results = op.evaluate({&x, &y}, {}, {1,0, 1,1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -743,7 +743,7 @@ TEST_F(DeclarableOpsTests12, reduceMeanBp_4) { exp = 0.333333; nd4j::ops::reduce_mean_bp op; - auto result = op.execute({&x, &gradO}, {}, {0}); + auto result = op.evaluate({&x, &gradO}, {}, {0}); auto output = result->at(0); // output->printShapeInfo(); @@ -765,7 +765,7 @@ TEST_F(DeclarableOpsTests12, reduceMeanBp_5) { exp = 0.2; nd4j::ops::reduce_mean_bp op; - auto result = op.execute({&x, &gradO}, {}, {1}); + auto result = op.evaluate({&x, &gradO}, {}, {1}); auto output = result->at(0); // output->printShapeInfo(); @@ -783,7 +783,7 @@ TEST_F(DeclarableOpsTests12, reduceSqnormBp_1) { NDArray gradO('c', {8,6,1}, nd4j::DataType::DOUBLE); nd4j::ops::reduce_sqnorm_bp op; - auto result = op.execute({&x, &gradO}, {1}, {2}); + auto result = op.evaluate({&x, &gradO}, {1}, {2}); ASSERT_EQ(Status::OK(), result->status()); delete result; @@ -937,7 +937,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_1) { nd4j::ops::lrn_bp op; - auto results = op.execute({&input, &gradO}, {1., 1., 1}, {5}); + auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {5}); auto gradI = results->at(0); ASSERT_EQ(*gradI, exp); @@ -968,7 +968,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_2) { nd4j::ops::lrn_bp op; - auto results = op.execute({&input, &gradO}, {1., 1., 1}, {2}); + auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {2}); auto gradI = results->at(0); ASSERT_EQ(*gradI, exp); @@ -999,7 +999,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_3) { nd4j::ops::lrn_bp op; - auto results = op.execute({&input, &gradO}, {1., 1., 1}, {7}); + auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {7}); auto gradI = results->at(0); ASSERT_EQ(*gradI, exp); @@ -1030,7 +1030,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_4) { nd4j::ops::lrn_bp op; - auto results = op.execute({&input, &gradO}, {1., 1., 1}, {12}); + auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {12}); auto gradI = results->at(0); ASSERT_EQ(*gradI, exp); @@ -1053,7 +1053,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_5) { nd4j::ops::lrn_bp op; - auto results = op.execute({&input, &gradO}, {1., 1., 0.5}, {2}); + auto results = op.evaluate({&input, &gradO}, {1., 1., 0.5}, {2}); auto gradI = results->at(0); ASSERT_EQ(*gradI, exp); @@ -1072,7 +1072,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_6) { nd4j::ops::lrn_bp op; - auto results = op.execute({&input, &gradO}, {1., 2., 0.5}, {10}); + auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {10}); auto gradI = results->at(0); ASSERT_EQ(*gradI, exp); @@ -1126,7 +1126,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_9) { nd4j::ops::lrn_bp op; - auto results = op.execute({&input, &gradO}, {1., 2., 0.5}, {3}); + auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {3}); auto gradI = results->at(0); // for (int i = 0; i < exp.lengthOf(); ++i) @@ -1146,7 +1146,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_10) { nd4j::ops::lrn_bp op; - auto results = op.execute({&input, &gradO}, {1., 2., 0.5}, {1}); + auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {1}); auto gradI = 
results->at(0); ASSERT_EQ(*gradI, exp); @@ -1167,7 +1167,7 @@ TEST_F(DeclarableOpsTests12, lrn_1) { nd4j::ops::lrn op; - auto results = op.execute({&input}, {1., 2., 0.5}, {2}); + auto results = op.evaluate({&input}, {1., 2., 0.5}, {2}); auto output = results->at(0); ASSERT_EQ(*output, exp); @@ -1183,7 +1183,7 @@ TEST_F(DeclarableOpsTests12, lrn_2) { nd4j::ops::lrn op; - auto results = op.execute({&input}, {0.1, 2., 0.5}, {5}); + auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {5}); auto output = results->at(0); ASSERT_EQ(*output, exp); @@ -1198,7 +1198,7 @@ TEST_F(DeclarableOpsTests12, lrn_3) { nd4j::ops::lrn op; - auto results = op.execute({&input}, {0.1, 2., 0.5}, {5}); + auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {5}); auto output = results->at(0); ASSERT_EQ(*output, exp); @@ -1213,7 +1213,7 @@ TEST_F(DeclarableOpsTests12, lrn_4) { nd4j::ops::lrn op; - auto results = op.execute({&input}, {0.1, 2., 0.5}, {0}); + auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {0}); auto output = results->at(0); ASSERT_EQ(*output, exp); @@ -1228,7 +1228,7 @@ TEST_F(DeclarableOpsTests12, lrn_5) { nd4j::ops::lrn op; - auto results = op.execute({&input}, {0.1, 2., 0.5}, {0}); + auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {0}); auto output = results->at(0); ASSERT_EQ(*output, exp); @@ -1268,7 +1268,7 @@ TEST_F(DeclarableOpsTests12, inTopK_2) { nd4j::ops::in_top_k op; - auto res = op.execute({&input, &idx}, {}, {1}, {}, false, nd4j::DataType::BOOL); + auto res = op.evaluate({&input, &idx}, {}, {1}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); //res->at(0)->printIndexedBuffer("IN_TOP_K output"); @@ -1283,7 +1283,7 @@ TEST_F(DeclarableOpsTests12, inTopK_3) { auto expV = NDArrayFactory::create('c', {2}, {true, false}); nd4j::ops::in_top_k op; - auto result = op.execute({&x, &y}, {}, {2}); + auto result = op.evaluate({&x, &y}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(1, result->size()); @@ -1303,7 +1303,7 @@ TEST_F(DeclarableOpsTests12, inTopK_4) { auto expV = NDArrayFactory::create('c', {6}, {true, false, true, false, false, true}); nd4j::ops::in_top_k op; - auto result = op.execute({&x, &y}, {}, {2}); + auto result = op.evaluate({&x, &y}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(1, result->size()); @@ -1324,7 +1324,7 @@ TEST_F(DeclarableOpsTests12, inTopK_5) { auto expV = NDArrayFactory::create('f', {6}, {true, false, false, false, false, false }); nd4j::ops::in_top_k op; - auto result = op.execute({&x, &y}, {}, {2}); + auto result = op.evaluate({&x, &y}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(1, result->size()); @@ -1345,7 +1345,7 @@ TEST_F(DeclarableOpsTests12, cube_1) { nd4j::ops::cube op; - auto result = op.execute({&x}, {}, {}); + auto result = op.evaluate({&x}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -1368,7 +1368,7 @@ TEST_F(DeclarableOpsTests12, cube_bp_1) { nd4j::ops::cube_bp op; - auto result = op.execute({&x, &gradO}, {}, {}); + auto result = op.evaluate({&x, &gradO}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -1391,7 +1391,7 @@ TEST_F(DeclarableOpsTests12, pad_tests1) { NDArray expected('c', {4,7}, {0,0,0,0,0,0,0, 0,0,1,2,3,0,0, 0,0,4,5,6,0,0, 0,0,0,0,0,0,0}, nd4j::DataType::FLOAT32); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {0}); + auto results = op.evaluate({&input, &paddings}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1418,7 +1418,7 @@ TEST_F(DeclarableOpsTests12, pad_tests2) { auto expected = NDArrayFactory::create(expBuff, 'c', 
{4,7}); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {1}); + auto results = op.evaluate({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1445,7 +1445,7 @@ TEST_F(DeclarableOpsTests12, pad_tests3) { auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {2}); + auto results = op.evaluate({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1476,7 +1476,7 @@ TEST_F(DeclarableOpsTests12, pad_tests4) { auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {0}); + auto results = op.evaluate({&input, &paddings}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1510,7 +1510,7 @@ TEST_F(DeclarableOpsTests12, pad_tests5) { auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {1}); + auto results = op.evaluate({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1537,7 +1537,7 @@ TEST_F(DeclarableOpsTests12, pad_tests6) { auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {2}); + auto results = op.evaluate({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1563,7 +1563,7 @@ TEST_F(DeclarableOpsTests12, pad_tests7) auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {0}); + auto results = op.evaluate({&input, &paddings}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1589,7 +1589,7 @@ TEST_F(DeclarableOpsTests12, pad_tests8) auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {1}); + auto results = op.evaluate({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1615,7 +1615,7 @@ TEST_F(DeclarableOpsTests12, pad_tests9) auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {2}); + auto results = op.evaluate({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1638,7 +1638,7 @@ TEST_F(DeclarableOpsTests12, pad_tests10) { input = 1.f; //input.assign(1.); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {0}); + auto results = op.evaluate({&input, &paddings}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1660,7 +1660,7 @@ TEST_F(DeclarableOpsTests12, pad_tests11) { input.linspace(1.f); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {1}); + auto results = op.evaluate({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1688,7 +1688,7 @@ TEST_F(DeclarableOpsTests12, pad_tests12) { input.linspace(1.f); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {2}); + auto results = op.evaluate({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1710,7 +1710,7 @@ TEST_F(DeclarableOpsTests12, pad_tests13) { input.linspace(1.f); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {1}); + auto results = op.evaluate({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1732,7 +1732,7 @@ TEST_F(DeclarableOpsTests12, pad_tests14) { 
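For the pad conversions the I-arg is the padding mode and the optional T-arg (pad_tests29-33 pass {10.0}) is the constant fill value. The mode numbering follows tf.pad's CONSTANT/REFLECT/SYMMETRIC order; the names are inferred, the patch only uses 0, 1, 2. A sketch:

    auto input = NDArrayFactory::create<float>('c', {2, 3});
    input.linspace(1.f);
    // one {before, after} pair per dimension -> output shape {4, 7}
    auto paddings = NDArrayFactory::create<int>('c', {2, 2}, {1, 1, 2, 2});

    nd4j::ops::pad op;
    auto results = op.evaluate({&input, &paddings}, {10.0}, {0});  // mode 0, fill 10.0
    ASSERT_EQ(ND4J_STATUS_OK, results->status());
    auto padded = results->at(0);
    delete results;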
input.linspace(1.f); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {2}); + auto results = op.evaluate({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1753,7 +1753,7 @@ TEST_F(DeclarableOpsTests12, pad_tests15) { input.linspace(1.f); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {2}); + auto results = op.evaluate({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1774,7 +1774,7 @@ TEST_F(DeclarableOpsTests12, pad_tests16) { input.linspace(1.f); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {1}); + auto results = op.evaluate({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1795,7 +1795,7 @@ TEST_F(DeclarableOpsTests12, pad_tests17) { input.linspace(1.f); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {2}); + auto results = op.evaluate({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1816,7 +1816,7 @@ TEST_F(DeclarableOpsTests12, pad_tests18) { input.linspace(1.f); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {1}); + auto results = op.evaluate({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1837,7 +1837,7 @@ TEST_F(DeclarableOpsTests12, pad_tests19) { input.linspace(1.f); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {1}); + auto results = op.evaluate({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1858,7 +1858,7 @@ TEST_F(DeclarableOpsTests12, pad_tests20) { input.linspace(1.f); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {1}); + auto results = op.evaluate({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1880,7 +1880,7 @@ TEST_F(DeclarableOpsTests12, pad_tests21) { input.linspace(1.f); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {2}); + auto results = op.evaluate({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1903,7 +1903,7 @@ TEST_F(DeclarableOpsTests12, pad_tests22) { input.linspace(1.f); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {0}); + auto results = op.evaluate({&input, &paddings}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1926,7 +1926,7 @@ TEST_F(DeclarableOpsTests12, pad_tests23) { input.linspace(1.f); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {0}); + auto results = op.evaluate({&input, &paddings}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1950,7 +1950,7 @@ TEST_F(DeclarableOpsTests12, pad_tests24) { input.linspace(1.f); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {0}); + auto results = op.evaluate({&input, &paddings}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1972,7 +1972,7 @@ TEST_F(DeclarableOpsTests12, pad_tests25) { input.linspace(1.f); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {2}); + auto results = op.evaluate({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1994,7 +1994,7 @@ TEST_F(DeclarableOpsTests12, pad_tests26) { input.linspace(1.f); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {0}); + auto results = op.evaluate({&input, &paddings}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2054,7 +2054,7 @@ TEST_F(DeclarableOpsTests12, pad_tests29) { nd4j::ops::pad op; - auto res = 
op.execute({&in, &pad}, {10.0}, {0}); + auto res = op.evaluate({&in, &pad}, {10.0}, {0}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_TRUE(exp.equalsTo(res->at(0))); delete res; @@ -2071,7 +2071,7 @@ TEST_F(DeclarableOpsTests12, pad_tests30) { nd4j::ops::pad op; - auto res = op.execute({&in, &pad}, {10.0}, {2}); + auto res = op.evaluate({&in, &pad}, {10.0}, {2}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_TRUE(exp.equalsTo(res->at(0))); delete res; @@ -2089,7 +2089,7 @@ TEST_F(DeclarableOpsTests12, pad_tests31) { nd4j::ops::pad op; - auto res = op.execute({&in, &pad}, {10.0}, {1}); + auto res = op.evaluate({&in, &pad}, {10.0}, {1}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_TRUE(exp.equalsTo(res->at(0))); delete res; @@ -2105,7 +2105,7 @@ TEST_F(DeclarableOpsTests12, pad_tests32) { nd4j::ops::pad op; - auto res = op.execute({&in, &pad}, {10.0}, {2}); + auto res = op.evaluate({&in, &pad}, {10.0}, {2}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_TRUE(exp.equalsTo(res->at(0))); delete res; @@ -2128,7 +2128,7 @@ TEST_F(DeclarableOpsTests12, pad_tests33) { 11,10,9,9,10,11,12,12,11,10., 11,10,9,9,10,11,12,12,11,10., 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2.}); nd4j::ops::pad op; - auto res = op.execute({&in, &pad}, {10.0}, {2}); + auto res = op.evaluate({&in, &pad}, {10.0}, {2}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); ASSERT_TRUE(exp.equalsTo(res->at(0))); delete res; @@ -2163,7 +2163,7 @@ TEST_F(DeclarableOpsTests12, Pad_1) { auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {0}); + auto results = op.evaluate({&input, &paddings}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2190,7 +2190,7 @@ TEST_F(DeclarableOpsTests12, Pad_2) { auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {1}); + auto results = op.evaluate({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2217,7 +2217,7 @@ TEST_F(DeclarableOpsTests12, Pad_3) { auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {2}); + auto results = op.evaluate({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2244,7 +2244,7 @@ TEST_F(DeclarableOpsTests12, Pad_4) { auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {0}); + auto results = op.evaluate({&input, &paddings}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2271,7 +2271,7 @@ TEST_F(DeclarableOpsTests12, Pad_5) { auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {1}); + auto results = op.evaluate({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2298,7 +2298,7 @@ TEST_F(DeclarableOpsTests12, Pad_6) { auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {2}); + auto results = op.evaluate({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2324,7 +2324,7 @@ TEST_F(DeclarableOpsTests12, Pad_7) auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {0}); + auto results = op.evaluate({&input, &paddings}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, 
results->status()); @@ -2350,7 +2350,7 @@ TEST_F(DeclarableOpsTests12, Pad_8) auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {1}); + auto results = op.evaluate({&input, &paddings}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2376,7 +2376,7 @@ TEST_F(DeclarableOpsTests12, Pad_9) auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); nd4j::ops::pad op; - auto results = op.execute({&input, &paddings}, {}, {2}); + auto results = op.evaluate({&input, &paddings}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2395,7 +2395,7 @@ TEST_F(DeclarableOpsTests12, Test_Expose_1) { nd4j::ops::expose op; - auto result = op.execute({&input0, &input1}, {}, {}); + auto result = op.evaluate({&input0, &input1}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -2420,7 +2420,7 @@ TEST_F(DeclarableOpsTests12, Pad_SGO_Test_1) { nd4j::ops::pad op; - auto res = op.execute({&in, &pad}, {10.0}, {0}); + auto res = op.evaluate({&in, &pad}, {10.0}, {0}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); // res->at(0)->printIndexedBuffer("PAD_SGO"); // exp.printIndexedBuffer("PAD_EXP"); @@ -2436,7 +2436,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_1) { auto pExp = NDArrayFactory::create('c', {3}, {0, 1, 2}); nd4j::ops::lu op; - auto res = op.execute({&in}, {}, {}); + auto res = op.evaluate({&in}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); auto z = res->at(0); auto p = res->at(1); @@ -2457,7 +2457,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_2) { auto expP = NDArrayFactory::create({2, 0, 1}); nd4j::ops::lu op; - auto res = op.execute({&in}, {}, {}); + auto res = op.evaluate({&in}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); auto z = res->at(0); auto p = res->at(1); @@ -2480,7 +2480,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_3) { auto expP = NDArrayFactory::create({2, 1, 0}); nd4j::ops::lu op; - auto res = op.execute({&in}, {}, {}); + auto res = op.evaluate({&in}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); auto z = res->at(0); auto p = res->at(1); @@ -2522,7 +2522,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_4) { auto expP = NDArrayFactory::create({1, 2, 7, 3, 6, 8, 5, 4, 0, 9}); nd4j::ops::lu op; - auto res = op.execute({&in}, {}, {}); + auto res = op.evaluate({&in}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); auto z = res->at(0); auto p = res->at(1); @@ -2592,7 +2592,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_5) { }); nd4j::ops::lu op; - auto res = op.execute({&in}, {}, {}); + auto res = op.evaluate({&in}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); auto z = res->at(0); auto p = res->at(1); @@ -2613,7 +2613,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_1_2) { nd4j::ops::lu op; - auto res = op.execute({&in}, {}, {}); + auto res = op.evaluate({&in}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); auto z = res->at(0); auto p = res->at(1); @@ -2641,7 +2641,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_3_2) { auto expP = NDArrayFactory::create('c', {2,3}, {2, 1, 0, 2, 1, 0}); nd4j::ops::lu op; - auto res = op.execute({&in}, {}, {}); + auto res = op.evaluate({&in}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); auto z = res->at(0); auto p = res->at(1); @@ -2669,7 +2669,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_3_3) { auto expP = NDArrayFactory::create('c', {2,3}, {2, 1, 0, 0, 2, 1}); nd4j::ops::lu op; - auto res = op.execute({&in}, {}, {}); + auto res = op.evaluate({&in}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); auto z = res->at(0); auto p = res->at(1); @@ -2697,7 +2697,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_4_1) { auto expP = 
NDArrayFactory::create('c', {2,2}, {0, 1, 0, 1}); nd4j::ops::lu op; - auto res = op.execute({&in}, {}, {}); + auto res = op.evaluate({&in}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); auto z = res->at(0); auto p = res->at(1); @@ -2725,7 +2725,7 @@ TEST_F(DeclarableOpsTests12, LU_Test_4_2) { auto expP = NDArrayFactory::create('c', {2,2}, {0, 1, 0, 1}); nd4j::ops::lu op; - auto res = op.execute({&in}, {}, {nd4j::DataType::INT64}); + auto res = op.evaluate({&in}, {}, {nd4j::DataType::INT64}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); auto z = res->at(0); auto p = res->at(1); @@ -2750,7 +2750,7 @@ TEST_F(DeclarableOpsTests12, QR_Test_1) { auto expR = NDArrayFactory::create('c', {5,3}, { -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0. }); nd4j::ops::qr op; - auto res = op.execute({&in}, {}, {}, {true}); + auto res = op.evaluate({&in}, {}, {}, {true}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); auto q = res->at(0); @@ -2762,7 +2762,7 @@ TEST_F(DeclarableOpsTests12, QR_Test_1) { // q->printShapeInfo("Q shape"); // r->printShapeInfo("R shape"); nd4j::ops::matmul opMul; - auto res2 = opMul.execute({q, r}, {}, {}); //MmulHelper::matmul(q, r, &in, false, false); + auto res2 = opMul.evaluate({q, r}); //MmulHelper::matmul(q, r, &in, false, false); auto exp = res2->at(0);//->printIndexedBuffer("Result as result"); ASSERT_TRUE(exp->isSameShape(in)); // ASSERT_TRUE(q->isSameShape(expQ)); @@ -2797,7 +2797,7 @@ TEST_F(DeclarableOpsTests12, QR_Test_1_1) { -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0. }); nd4j::ops::qr op; - auto res = op.execute({&in}, {}, {}, {true}); + auto res = op.evaluate({&in}, {}, {}, {true}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); auto q = res->at(0); @@ -2809,7 +2809,7 @@ TEST_F(DeclarableOpsTests12, QR_Test_1_1) { // q->printShapeInfo("Q shape"); // r->printShapeInfo("R shape"); nd4j::ops::matmul opMul; - auto res2 = opMul.execute({q, r}, {}, {}); //MmulHelper::matmul(q, r, &in, false, false); + auto res2 = opMul.evaluate({q, r}); //MmulHelper::matmul(q, r, &in, false, false); auto exp = res2->at(0);//->printIndexedBuffer("Result as result"); ASSERT_TRUE(exp->isSameShape(in)); // ASSERT_TRUE(q->isSameShape(expQ)); @@ -2836,7 +2836,7 @@ TEST_F(DeclarableOpsTests12, QR_Test_2) { }); nd4j::ops::qr op; - auto res = op.execute({&in}, {}, {}, {false}); + auto res = op.evaluate({&in}, {}, {}, {false}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); auto q = res->at(0); @@ -2847,7 +2847,7 @@ TEST_F(DeclarableOpsTests12, QR_Test_2) { // r->printIndexedBuffer("Upper triangular 5x3"); nd4j::ops::matmul opMul; - auto res2 = opMul.execute({q, r}, {}, {}); //MmulHelper::matmul(q, r, &in, false, false); + auto res2 = opMul.evaluate({q, r}); //MmulHelper::matmul(q, r, &in, false, false); auto exp = res2->at(0);//->printIndexedBuffer("Result as result"); ASSERT_TRUE(exp->isSameShape(in)); ASSERT_TRUE(exp->equalsTo(in)); @@ -2874,7 +2874,7 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_1) { nd4j::ops::triangular_solve op; - auto res = op.execute({&a, &b}, {}, {}); + auto res = op.evaluate({&a, &b}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); auto z = res->at(0); @@ -2903,7 +2903,7 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_2) { nd4j::ops::triangular_solve op; - auto res = op.execute({&a, &b}, {}, {}); + auto res = op.evaluate({&a, &b}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); auto z = res->at(0); @@ -2940,7 +2940,7 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_3) { 
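The QR hunks also show the boolean-args overload: the trailing {true}/{false} toggles full vs reduced factorization (naming assumed from the usual QR convention). Because Q and R are only unique up to sign, the tests verify by reconstruction rather than against fixed matrices:

    nd4j::ops::qr op;
    auto res = op.evaluate({&in}, {}, {}, {true});   // bArgs: {full matrices}
    ASSERT_EQ(ND4J_STATUS_OK, res->status());
    auto q = res->at(0);
    auto r = res->at(1);

    nd4j::ops::matmul opMul;
    auto res2 = opMul.evaluate({q, r});              // Q * R should give back `in`
    ASSERT_TRUE(res2->at(0)->isSameShape(in));
    ASSERT_TRUE(res2->at(0)->equalsTo(in));
    delete res2;
    delete res;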
nd4j::ops::triangular_solve op; - auto res = op.execute({&a, &b}, {}, {}); + auto res = op.evaluate({&a, &b}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); auto z = res->at(0); @@ -2969,7 +2969,7 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_4) { nd4j::ops::triangular_solve op; - auto res = op.execute({&a, &b}, {}, {}, {false}); + auto res = op.evaluate({&a, &b}, {}, {}, {false}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); auto z = res->at(0); @@ -2999,7 +2999,7 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_5) { nd4j::ops::triangular_solve op; - auto res = op.execute({&a, &b}, {}, {}, {false, true}); + auto res = op.evaluate({&a, &b}, {}, {}, {false, true}); ASSERT_EQ(res->status(), ND4J_STATUS_OK); auto z = res->at(0); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index 18f58c2a1..a445666df 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -58,7 +58,7 @@ TEST_F(DeclarableOpsTests13, test_pow_1) { auto e = NDArrayFactory::create('c', {2, 2}, {8.f, 8.f, 8.f, 8.f}); nd4j::ops::Pow op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -73,7 +73,7 @@ TEST_F(DeclarableOpsTests13, test_empty_range_1) { auto limit = NDArrayFactory::create(0); nd4j::ops::range op; - auto result = op.execute({&start, &limit}, {}, {}); + auto result = op.evaluate({&start, &limit}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -85,7 +85,7 @@ TEST_F(DeclarableOpsTests13, test_empty_range_1) { TEST_F(DeclarableOpsTests13, test_empty_range_2) { nd4j::ops::range op; - auto result = op.execute({}, {1.0, 1.0}, {}); + auto result = op.evaluate({}, {1.0, 1.0}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -97,7 +97,7 @@ TEST_F(DeclarableOpsTests13, test_empty_range_2) { TEST_F(DeclarableOpsTests13, test_empty_range_3) { nd4j::ops::range op; - auto result = op.execute({}, {}, {1, 1}); + auto result = op.evaluate({}, {1, 1}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -143,7 +143,7 @@ TEST_F(DeclarableOpsTests13, test_listdiff_1) { auto oi = NDArrayFactory::create('c', {2}); nd4j::ops::listdiff op; - auto result = op.execute({&x, &y}, {&od, &oi}, {}, {}, {}); + auto result = op.execute({&x, &y}, std::vector{&od, &oi}, {}, {}, {}); ASSERT_EQ(Status::OK(), result); } @@ -152,7 +152,7 @@ TEST_F(DeclarableOpsTests13, test_greater_1) { auto y = NDArrayFactory::create('c', {1, 4}); nd4j::ops::greater op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}); ASSERT_EQ(Status::OK(), result->status()); delete result; @@ -165,7 +165,7 @@ TEST_F(DeclarableOpsTests13, test_eval_reduction_shape_1) { auto exp = NDArrayFactory::create('c', {2}, {1, 2}); nd4j::ops::evaluate_reduction_shape op; - auto result = op.execute({&x, &y}, {}, {}, {true}); + auto result = op.evaluate({&x, &y}, {true}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -218,7 +218,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_GainsTest_1) { auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); auto exp = NDArrayFactory::create('c', {2,3}, {1.2,2.2,3.2,4.2,5.2,6.2}); nd4j::ops::barnes_gains op; - auto result = op.execute({&x, &y, &eps}, {}, {}); + auto result = op.evaluate({&x, &y, &eps}); ASSERT_EQ(result->status(), Status::OK()); 
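Not every call site was converted: tests that write into caller-provided arrays (test_listdiff_1, TestDivideBP_2, TestMaximumBP_1 and friends) keep the status-returning execute() overload, now with the output list spelled as an explicit std::vector<NDArray*> so the two brace-list parameters cannot be confused. A sketch using listdiff's two outputs, with the output dtypes assumed:

    auto x  = NDArrayFactory::create<float>('c', {4}, {0.f, 1.f, 2.f, 3.f});
    auto y  = NDArrayFactory::create<float>('c', {2}, {3.f, 1.f});
    auto od = NDArrayFactory::create<float>('c', {2});  // values of x absent from y
    auto oi = NDArrayFactory::create<int>('c', {2});    // their indices in x

    nd4j::ops::listdiff op;
    Nd4jStatus status = op.execute({&x, &y}, std::vector<NDArray*>{&od, &oi},
                                   {}, {}, {});
    ASSERT_EQ(Status::OK(), status);
    // od == {0, 2}, oi == {0, 2}; outputs live in caller-owned arrays, nothing to delete.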
//result->at(0)->printBuffer("Gains out"); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -232,7 +232,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_GainsTest_2) { auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); auto exp = NDArrayFactory::create('c', {2,3}, {1.2, 0.01, 3.2, 0.01, 5.2, 0.01}); nd4j::ops::barnes_gains op; - auto result = op.execute({&x, &y, &eps}, {}, {}); + auto result = op.evaluate({&x, &y, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printBuffer("Gains out"); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -247,7 +247,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_GainsTest_3) { auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); auto exp = NDArrayFactory::create('c', {2,3}, {0.01, 2.2, 0.01, 4.2, 0.01, 6.2}); nd4j::ops::barnes_gains op; - auto result = op.execute({&x, &y, &eps}, {}, {}); + auto result = op.evaluate({&x, &y, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printBuffer("Gains out"); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -269,7 +269,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_1) { // auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); // auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); nd4j::ops::barnes_edge_forces op; - auto result = op.execute({&rows, &cols, &vals, &data}, {}, {1}); + auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {1}); ASSERT_EQ(result->status(), Status::OK()); @@ -293,7 +293,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_2) { // auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); // auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); nd4j::ops::barnes_edge_forces op; - auto result = op.execute({&rows, &cols, &vals, &data}, {}, {2}); + auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {2}); ASSERT_EQ(result->status(), Status::OK()); @@ -317,7 +317,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_3) { // auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); // auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); nd4j::ops::barnes_edge_forces op; - auto result = op.execute({&rows, &cols, &vals, &data}, {}, {11}); + auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {11}); //nd4j_printf("rows %lld, cols %lld, vals %lld, res full %lld\n", rows.lengthOf(), cols.lengthOf(), vals.lengthOf(), exp1.lengthOf()); ASSERT_EQ(result->status(), Status::OK()); @@ -340,7 +340,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_1) { // auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); // auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); nd4j::ops::barnes_symmetrized op; - auto result = op.execute({&rows, &cols, &vals}, {}, {1}); + auto result = op.evaluate({&rows, &cols, &vals}, {}, {1}); ASSERT_EQ(result->status(), Status::OK()); //result->at(2)->printBuffer("Symmetrized1"); ASSERT_TRUE(exp.equalsTo(result->at(2))); @@ -359,7 +359,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_2) { // auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); // auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); nd4j::ops::barnes_symmetrized op; - auto result = op.execute({&rows, &cols, &vals}, {}, {3}); + auto result = op.evaluate({&rows, &cols, &vals}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); 
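
Note that the barnes_symmetrized assertions here and just below index result->at(2) rather than at(0): the op produces several output arrays and the tests check the third one. A toy sketch of that multi-output access pattern; ToyResultSet and ToyArray are illustrative names, and the real nd4j ResultSet API may differ:

#include <cassert>
#include <cstddef>
#include <vector>

struct ToyArray { int tag; };

class ToyResultSet {
    std::vector<ToyArray*> outputs_;
public:
    void push(ToyArray* a) { outputs_.push_back(a); }
    ToyArray* at(std::size_t i) const { return outputs_.at(i); }
    ~ToyResultSet() { for (auto* a : outputs_) delete a; }
};

int main() {
    ToyResultSet res;
    res.push(new ToyArray{0});   // e.g. output rows
    res.push(new ToyArray{1});   // e.g. output cols
    res.push(new ToyArray{2});   // e.g. output vals, checked via res.at(2)
    assert(res.at(2)->tag == 2); // mirrors ASSERT_TRUE(exp.equalsTo(result->at(2)))
    return 0;
}
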
//result->at(2)->printBuffer("Symmetrized2"); // ASSERT_TRUE(exp[i]->equalsTo(result->at(i))); @@ -378,7 +378,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_3) { // auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); // auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); nd4j::ops::barnes_symmetrized op; - auto result = op.execute({&rows, &cols, &vals}, {}, {11}); + auto result = op.evaluate({&rows, &cols, &vals}, {}, {11}); ASSERT_EQ(result->status(), Status::OK()); //result->at(2)->printBuffer("Symmetrized3"); //exp.printBuffer("EXPect symm3"); @@ -402,7 +402,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_4) { // auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); // auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); nd4j::ops::barnes_symmetrized op; - auto result = op.execute({&rows, &cols, &vals}, {}, {11}); + auto result = op.evaluate({&rows, &cols, &vals}, {}, {11}); ASSERT_EQ(result->status(), Status::OK()); auto res = result->at(2); // res->printBuffer("Symmetrized4"); @@ -428,7 +428,7 @@ TEST_F(DeclarableOpsTests13, CellContains_test_1) { // auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); // auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); nd4j::ops::cell_contains op; - auto result = op.execute({&corners, &width, &point}, {}, {5}); + auto result = op.evaluate({&corners, &width, &point}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(result->at(0)->e(0)); //result->at(2)->printBuffer("Symmetrized3"); @@ -446,7 +446,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_1) { NDArray exp ('c', {2,2,3}, {100,0,44, 208,5,220, 177,230,97, 2,255,244}, nd4j::DataType::FLOAT32); nd4j::ops::adjust_hue op; - std::unique_ptr results (op.execute({&input, &factor}, {}, {2})); + std::unique_ptr results (op.evaluate({&input, &factor}, {}, {2})); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -467,7 +467,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_2) { nd4j::ops::adjust_hue op; - std::unique_ptr results(op.execute({&input}, {0.9}, {2})); + std::unique_ptr results(op.evaluate({&input}, {0.9}, {2})); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -487,7 +487,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_3) { NDArray exp ('c', {2,2,3}, {0.,84.,100., 5.,220.,122.0001, 229.8,97.,230., 255.,142.8002,2.}, nd4j::DataType::FLOAT32); nd4j::ops::adjust_hue op; - std::unique_ptr results(op.execute({&input}, {-0.9}, {2})); + std::unique_ptr results(op.evaluate({&input}, {-0.9}, {2})); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -506,7 +506,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_4) { NDArray exp ('c', {2,3,2}, {100,208, 0,5, 44,220, 177,2, 230,255, 97,244}, nd4j::DataType::FLOAT32); nd4j::ops::adjust_hue op; - std::unique_ptr results(op.execute({&input}, {0.5}, {1})); + std::unique_ptr results(op.evaluate({&input}, {0.5}, {1})); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -525,7 +525,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_5) { NDArray exp ('c', {3,2,2}, {100,208, 177,2, 0,5, 230,255, 44,220, 97,244}, nd4j::DataType::FLOAT32); nd4j::ops::adjust_hue op; - std::unique_ptr results(op.execute({&input}, {0.5}, {0})); + std::unique_ptr results(op.evaluate({&input}, {0.5}, {0})); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -545,7 +545,7 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_1) { NDArray exp ('c', {2,2,3}, {50,100,78, 118.5,220,112.5, 190,163.5,230, 255,128.5,134}, nd4j::DataType::FLOAT32); 
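
The adjustHue hunks above differ from most of this file in ownership style: they wrap the returned result in std::unique_ptr instead of calling delete by hand. Assuming evaluate() hands back a heap-allocated result set the caller must free, which the explicit deletes elsewhere in these tests suggest, the RAII form cannot leak even if an assertion aborts the test body early. A minimal sketch with stand-in types:

#include <memory>

struct ToyResultSet {
    int status() const { return 0; }
};

// Stand-in for an op returning a caller-owned result (illustrative only).
ToyResultSet* evaluate() { return new ToyResultSet(); }

int main() {
    // Manual ownership, as in most tests: needs an explicit delete,
    // which an early return from a failed ASSERT_* would skip.
    ToyResultSet* raw = evaluate();
    delete raw;

    // RAII ownership, as in the adjustHue tests: freed automatically.
    std::unique_ptr<ToyResultSet> guarded(evaluate());
    return guarded->status();
}
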
nd4j::ops::adjust_saturation op; - auto results = op.execute({&input, &factor}, {}, {2}); + auto results = op.evaluate({&input, &factor}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -564,7 +564,7 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_2) { NDArray exp ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, nd4j::DataType::DOUBLE); nd4j::ops::adjust_saturation op; - auto results = op.execute({&input}, {10}, {2}); + auto results = op.evaluate({&input}, {10}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -585,7 +585,7 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_3) { NDArray exp ('c', {2,2,3}, {100.,100.,100., 220.,220.,220., 230.,230.,230., 255., 255., 255.}, nd4j::DataType::FLOAT32); nd4j::ops::adjust_saturation op; - auto results = op.execute({&input}, {-10}, {2}); + auto results = op.evaluate({&input}, {-10}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -605,7 +605,7 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_4) { NDArray exp ('c', {2,3,2}, {50,118.5, 100,220, 78,112.5, 190,255, 163.5,128.5, 230,134}, nd4j::DataType::FLOAT32); nd4j::ops::adjust_saturation op; - auto results = op.execute({&input}, {0.5}, {1}); + auto results = op.evaluate({&input}, {0.5}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -625,7 +625,7 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_5) { NDArray exp ('c', {3,2,2}, {50,118.5, 190,255, 100,220, 163.5,128.5, 78,112.5, 230,134}, nd4j::DataType::FLOAT32); nd4j::ops::adjust_saturation op; - auto results = op.execute({&input}, {0.5}, {0}); + auto results = op.evaluate({&input}, {0.5}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -646,7 +646,7 @@ TEST_F(DeclarableOpsTests13, shift_bits_1) { e.assign(512); nd4j::ops::shift_bits op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -664,7 +664,7 @@ TEST_F(DeclarableOpsTests13, rshift_bits_1) { e.assign(32); nd4j::ops::rshift_bits op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -682,7 +682,7 @@ TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) { e.assign(512); nd4j::ops::cyclic_shift_bits op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -700,7 +700,7 @@ TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_1) { e.assign(32); nd4j::ops::cyclic_rshift_bits op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -719,7 +719,7 @@ TEST_F(DeclarableOpsTests13, shift_bits_2) { e.assign(512); nd4j::ops::shift_bits op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -738,7 +738,7 @@ TEST_F(DeclarableOpsTests13, rshift_bits_2) { e.assign(32); nd4j::ops::rshift_bits op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -757,7 +757,7 @@ TEST_F(DeclarableOpsTests13, cyclic_shift_bits_2) { e.assign(512); nd4j::ops::cyclic_shift_bits op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), 
result->status()); auto z = result->at(0); @@ -776,7 +776,7 @@ TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_2) { e.assign(32); nd4j::ops::cyclic_rshift_bits op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -794,7 +794,7 @@ TEST_F(DeclarableOpsTests13, shift_bits_3) { e.assign(512); nd4j::ops::shift_bits op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -817,7 +817,7 @@ TEST_F(DeclarableOpsTests13, space_to_batch_nd_1) { exp.linspace(1); nd4j::ops::space_to_batch_nd op; - auto result = op.execute({&x, &blockShape, &paddings}, {}, {}); + auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -844,7 +844,7 @@ TEST_F(DeclarableOpsTests13, space_to_batch_nd_2) { x.linspace(1); nd4j::ops::space_to_batch_nd op; - auto result = op.execute({&x, &blockShape, &paddings}, {}, {}); + auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -875,7 +875,7 @@ TEST_F(DeclarableOpsTests13, space_to_batch_nd_3) { x.linspace(1); nd4j::ops::space_to_batch_nd op; - auto result = op.execute({&x, &blockShape, &paddings}, {}, {}); + auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -901,7 +901,7 @@ TEST_F(DeclarableOpsTests13, batch_to_space_nd_1) { exp.linspace(1); nd4j::ops::batch_to_space_nd op; - auto result = op.execute({&x, &blockShape, &crop}, {}, {}); + auto result = op.evaluate({&x, &blockShape, &crop}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -924,7 +924,7 @@ TEST_F(DeclarableOpsTests13, batch_to_space_nd_2) { x.linspace(1); nd4j::ops::batch_to_space_nd op; - auto result = op.execute({&x, &blockShape, &crop}, {}, {}); + auto result = op.evaluate({&x, &blockShape, &crop}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -948,7 +948,7 @@ TEST_F(DeclarableOpsTests13, batch_to_space_nd_3) { x.linspace(1); nd4j::ops::batch_to_space_nd op; - auto result = op.execute({&x, &blockShape, &crop}, {}, {}); + auto result = op.evaluate({&x, &blockShape, &crop}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -974,7 +974,7 @@ TEST_F(DeclarableOpsTests13, mergemax_1) { nd4j::ops::mergemax op; - auto result = op.execute({&x1, &x2, &x3}, {}, {}); + auto result = op.evaluate({&x1, &x2, &x3}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -1040,9 +1040,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_1) { hI = 1.; cI = 2.; - std::initializer_list tArgs = {cellClip}; - std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; auto expH = NDArrayFactory::create('c', {sL, bS, nOut}, {0.57574f, 0.57574f, 0.57574f, 0.58006f, 0.58006f, 0.58006f, 0.58434f, 0.58434f, 0.58434f, 0.55114f, 0.55114f, 0.55114f, 0.55732f, 0.55732f, 0.55732f, 0.56338f, 0.56338f, 0.56338f, @@ -1053,7 +1053,7 
@@ TEST_F(DeclarableOpsTests13, lstmLayer_1) { auto expClast = NDArrayFactory::create('c', {bS, nOut}, {1.1589154f, 1.1589154f, 1.1589154f, 1.1892855f, 1.1892855f, 1.1892855f, 1.219861f, 1.219861f, 1.219861f}); nd4j::ops::lstmLayer op; - auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1110,9 +1110,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_2) { hI = 1.; cI = 2.; - std::initializer_list tArgs = {cellClip}; - std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; auto expH = NDArrayFactory::create('c', {bS, sL, nOut}, {0.575735f, 0.575735f, 0.575735f, 0.541562f, 0.541562f, 0.541562f, 0.514003f, 0.514003f, 0.514003f, 0.495597f, 0.495597f, 0.495597f, 0.485999f, 0.485999f, 0.485999f, 0.596965f, 0.596965f, 0.596965f, 0.571978f, 0.571978f, 0.571978f, 0.552888f, 0.552888f, 0.552888f, 0.540606f, 0.540606f, 0.540606f, 0.534764f, 0.534764f, 0.534764f, @@ -1121,7 +1121,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_2) { auto expClast = NDArrayFactory::create('c', {bS, nOut}, {0.996965f, 0.996965f, 0.996965f, 1.146756f, 1.146756f, 1.146756f, 1.301922f, 1.301922f, 1.301922f}); nd4j::ops::lstmLayer op; - auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1178,9 +1178,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_3) { hI = 1.; cI = 2.; - std::initializer_list tArgs = {cellClip}; - std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; NDArray expH('c', {sL, bS, nOut}, {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f, 0.534701f, 0.534701f, 0.534701f, 0.549139f, 0.549139f, 0.549139f, 0.571900f, 0.571900f, 0.571900f, 0.583561f, 0.583561f, 0.583561f, 0.605106f, 0.605106f, @@ -1190,7 +1190,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_3) { NDArray expCL('c', {bS, nOut}, {1.061274f, 1.061274f, 1.061274f, 1.115888f, 1.115888f, 1.115888f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; - auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1256,9 +1256,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_4) { cI({0,1, 0,0, 0,0}) = 2; cI({1,2, 0,0, 0,0}) = -2; - std::initializer_list tArgs = {cellClip}; - std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = 
{hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; NDArray expH('c', {sL, bS, 2 * nOut}, { 0.577661f, 0.577661f, 0.577661f, -0.107642f, -0.107642f, -0.107642f, 0.585289f, 0.585289f, 0.585289f, @@ -1275,7 +1275,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_4) { -0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; - auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1340,9 +1340,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_5) { cI({0,1, 0,0, 0,0}) = 2; cI({1,2, 0,0, 0,0}) = -2; - std::initializer_list tArgs = {cellClip}; - std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; NDArray expH('c', {bS, sL, 2*nOut}, { 0.577661f, 0.577661f, 0.577661f, -0.107659f, -0.107659f, -0.107659f, 0.548099f, 0.548099f, 0.548099f, -0.113406f, -0.113406f, -0.113406f, @@ -1357,7 +1357,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_5) { -0.295811f, -0.295811f, -0.295811f, -0.305394f, -0.305394f, -0.305394f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; - auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1426,9 +1426,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_6) { cI({0,1, 0,0, 0,0}) = 2; cI({1,2, 0,0, 0,0}) = -2; - std::initializer_list tArgs = {cellClip}; - std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; NDArray expH('c', {sL, bS, nOut}, { 0.470019f, 0.470019f, 0.470019f, 0.478352f, 0.478352f, 0.478352f, 0.444871f, 0.444871f, 0.444871f, 0.457060f, @@ -1443,7 +1443,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_6) { nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; - auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1519,7 +1519,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_7) { NDArray expCL('c', {bS, nOut}, {1.147089, 1.147089, 1.147089,1.197228, 1.197228, 1.197228}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; - auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1597,7 +1597,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_8) { NDArray expCL('c', {bS, nOut}, {0.879804f, 0.879804f, 0.879804f, 0.914666f, 0.914666f, 0.914666f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; - auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + auto results = 
op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1684,7 +1684,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_9) { -0.292174f, -0.292174f, -0.292174f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; - auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1769,7 +1769,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_10) { NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, 1.40183f, 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, 1.767702f, 1.767702f, 1.767702f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; - auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1849,7 +1849,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_11) { NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 2.092814f, 2.092814f, 2.092814f, 2.08832f, 2.08832f, 2.08832f, 2.009851f, 2.009851f, 2.009851f, 1.646034f, 1.646034f, 1.646034f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; - auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1940,7 +1940,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_12) { 0.f, 0.f, 0.f, -0.86636f, -0.86636f, -0.86636f, -0.470245f, -0.470245f, -0.470245f, -0.341856f, -0.341856f, -0.341856f, -0.294986f, -0.294986f, -0.294986f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; - auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1977,7 +1977,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test1) { nd4j::ops::batchnorm op; - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2010,7 +2010,7 @@ TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test2) { nd4j::ops::batchnorm op; - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2039,7 +2039,7 @@ TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test3) { nd4j::ops::batchnorm op; - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2067,7 +2067,7 @@ TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test4) { nd4j::ops::batchnorm op; - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,0,2}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,0,2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2095,7 +2095,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test5) { nd4j::ops::batchnorm op; - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1, 1}); + auto 
results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1, 1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2124,7 +2124,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test6) { nd4j::ops::batchnorm op; - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,3}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,3}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2193,7 +2193,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test8) { nd4j::ops::batchnorm op; - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1, 1,2,3}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1, 1,2,3}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2236,7 +2236,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test9) { nd4j::ops::batchnorm op; - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1, 1,2,3,4}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1, 1,2,3,4}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2271,7 +2271,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test1) { nd4j::ops::batchnorm_bp op; - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2314,7 +2314,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test2) { nd4j::ops::batchnorm_bp op; - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2356,7 +2356,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test3) { nd4j::ops::batchnorm_bp op; - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,0,2}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,0,2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2395,7 +2395,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test4) { nd4j::ops::batchnorm_bp op; - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2439,7 +2439,7 @@ return; nd4j::ops::batchnorm_bp op; - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2484,7 +2484,7 @@ return; nd4j::ops::batchnorm_bp op; - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2532,7 +2532,7 @@ return; nd4j::ops::batchnorm_bp op; - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,4}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,4}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2581,7 +2581,7 @@ return; nd4j::ops::batchnorm_bp op; - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); + auto results = 
op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2635,7 +2635,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test9) { nd4j::ops::batchnorm_bp op; - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2687,7 +2687,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test10) { nd4j::ops::batchnorm_bp op; - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2751,7 +2751,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test11) { nd4j::ops::batchnorm_bp op; - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1, 1,2,3}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1, 1,2,3}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp index 574da8993..1815e5336 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp @@ -44,7 +44,7 @@ TEST_F(DeclarableOpsTests14, Test_Validation_Edge_1) { exp.assign(4.0f); nd4j::ops::fill op; - auto result = op.execute({&x}, {4.0f},{}, {}); + auto result = op.evaluate({&x}, {4.0f}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -62,7 +62,7 @@ TEST_F(DeclarableOpsTests14, Test_Reshape_CF_1) { r.streamline('f'); nd4j::ops::reshape op; - auto result = op.execute({&x}, {}, {3, 2}, {}); + auto result = op.evaluate({&x}, {3, 2}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -96,7 +96,7 @@ TEST_F(DeclarableOpsTests14, Multiply_test) { e.assign(1.0); nd4j::ops::multiply op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}); auto f = result->at(0); NDArray r = *f; @@ -113,7 +113,7 @@ TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_1) { auto e = NDArrayFactory::create('c', {2}, {5, 4}); nd4j::ops::evaluate_reduction_shape op; - auto result = op.execute({&x, &y}, {}, {}, {false, false}); + auto result = op.evaluate({&x, &y}, {}, {}, {false, false}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -128,7 +128,7 @@ TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_2) { auto e = NDArrayFactory::create('c', {3}, {5, 1, 4}); nd4j::ops::evaluate_reduction_shape op; - auto result = op.execute({&x, &y}, {}, {}, {true, false}); + auto result = op.evaluate({&x, &y}, {}, {}, {true, false}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -183,7 +183,7 @@ TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_1) { nd4j::ops::add op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}); ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(e, *result->at(0)); @@ -200,7 +200,7 @@ TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_2) { nd4j::ops::subtract op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}); ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(e, *result->at(0)); @@ -213,7 +213,7 @@ TEST_F(DeclarableOpsTests14, test_empty_fill_1) { auto y = 
NDArrayFactory::create(1); nd4j::ops::fill op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -250,13 +250,13 @@ TEST_F(DeclarableOpsTests14, test_empty_stack_1) { auto e = NDArrayFactory::create('c', {1, 0}); nd4j::ops::stack op; - auto result = op.execute({&x}, {}, {0}); + auto result = op.evaluate({&x}, {}, {0}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); ASSERT_EQ(e, *z); nd4j::ops::reduce_min sumOp; - auto res2 = sumOp.execute({&e}, {1.}, {1}); + auto res2 = sumOp.evaluate({&e}, {1.}, {1}); ASSERT_EQ(res2->status(), Status::OK()); auto out = res2->at(0); @@ -270,7 +270,7 @@ TEST_F(DeclarableOpsTests14, test_empty_stack_2) { auto e = NDArrayFactory::create('c', {0}); nd4j::ops::stack op; - auto result = op.execute({&x}, {}, {0}); + auto result = op.evaluate({&x}, {}, {0}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -284,7 +284,7 @@ TEST_F(DeclarableOpsTests14, test_empty_stack_3) { auto e = NDArrayFactory::create('c', {2, 0}); nd4j::ops::stack op; - auto result = op.execute({&x, &x}, {}, {0}); + auto result = op.evaluate({&x, &x}, {}, {0}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -298,7 +298,7 @@ TEST_F(DeclarableOpsTests14, test_empty_stack_4) { auto e = NDArrayFactory::create('c', {2, 0}); nd4j::ops::stack op; - auto result = op.execute({&x, &x}, {}, {0}); + auto result = op.evaluate({&x, &x}, {}, {0}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -311,7 +311,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_min_1) { auto e = NDArrayFactory::create('c', {1, 0}); nd4j::ops::reduce_min sumOp; - auto res2 = sumOp.execute({&e}, {1.}, {1}); + auto res2 = sumOp.evaluate({&e}, {1.}, {1}); ASSERT_EQ(res2->status(), Status::OK()); auto out = res2->at(0); @@ -323,7 +323,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_max_1) { auto e = NDArrayFactory::create('c', {1, 0}); nd4j::ops::reduce_max sumOp; - auto res2 = sumOp.execute({&e}, {1.}, {1}); + auto res2 = sumOp.evaluate({&e}, {1.}, {1}); ASSERT_EQ(res2->status(), Status::OK()); auto out = res2->at(0); @@ -335,7 +335,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_sum_1) { auto e = NDArrayFactory::create('c', {1, 0}); nd4j::ops::reduce_sum sumOp; - auto res2 = sumOp.execute({&e}, {1.}, {1}); + auto res2 = sumOp.evaluate({&e}, {1.}, {1}); ASSERT_EQ(res2->status(), Status::OK()); auto out = res2->at(0); ASSERT_EQ(out->e(0), 0.f); @@ -346,7 +346,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_mean_1) { auto e = NDArrayFactory::create('c', {1, 0}); nd4j::ops::reduce_mean sumOp; - auto res2 = sumOp.execute({&e}, {1.}, {1}); + auto res2 = sumOp.evaluate({&e}, {1.}, {1}); ASSERT_EQ(res2->status(), Status::OK()); auto out = res2->at(0); // out->printShapeInfo("ReduceMean empty shape with keep dims"); @@ -366,7 +366,7 @@ TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_1) { matrix.linspace(1); nd4j::ops::strided_slice op; - auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 0}); + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 0}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -387,7 +387,7 @@ TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_2) { matrix.linspace(1); nd4j::ops::strided_slice op; - auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); ASSERT_EQ(Status::OK(), 
result->status()); auto z = result->at(0); @@ -405,7 +405,7 @@ TEST_F(DeclarableOpsTests14, test_empty_argmax_1) { nd4j::ops::argmax op; //nd4j::ops::reduce_max op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -432,7 +432,7 @@ TEST_F(DeclarableOpsTests14, test_empty_tanh_5) { auto x = NDArrayFactory::create('c', {32, 0}); nd4j::ops::tanh op; - auto result = op.execute({&x}, {}, {}); + auto result = op.evaluate({&x}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -450,7 +450,7 @@ TEST_F(DeclarableOpsTests14, repeat_1) { NDArray e('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6}); nd4j::ops::repeat op; - auto result = op.execute({&x}, {}, {2, 0}); + auto result = op.evaluate({&x}, {}, {2, 0}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -468,7 +468,7 @@ TEST_F(DeclarableOpsTests14, repeat_2) { NDArray e('c', {2, 6}, {1, 1, 2, 2, 3, 3,4, 4, 5, 5, 6, 6}); nd4j::ops::repeat op; - auto result = op.execute({&x}, {}, {2, 1}); + auto result = op.evaluate({&x}, {}, {2, 1}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -486,7 +486,7 @@ TEST_F(DeclarableOpsTests14, repeat_3) { NDArray e('c', {2, 6}, {1, 2, 2, 3, 3, 3,4, 5, 5, 6, 6, 6}); nd4j::ops::repeat op; - auto result = op.execute({&x}, {}, {1,2,3, 1}); + auto result = op.evaluate({&x}, {}, {1,2,3, 1}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -504,7 +504,7 @@ TEST_F(DeclarableOpsTests14, repeat_4) { NDArray e('c', {7, 3}, {1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 4, 5, 6}); nd4j::ops::repeat op; - auto result = op.execute({&x}, {}, {3,4, 0}); + auto result = op.evaluate({&x}, {}, {3,4, 0}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -522,7 +522,7 @@ TEST_F(DeclarableOpsTests14, repeat_5) { NDArray e('c', {2, 4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 17, 18, 19, 20, 21, 22, 23, 24}); nd4j::ops::repeat op; - auto result = op.execute({&x}, {}, {1,2,1, 1}); + auto result = op.evaluate({&x}, {}, {1,2,1, 1}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index 84dd5d732..cd7f84610 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -49,7 +49,7 @@ TEST_F(DeclarableOpsTests15, Test_NormalizeMoments_1) { auto z1 = NDArrayFactory::create('c', {10}); nd4j::ops::normalize_moments op; - auto result = op.execute({&w, &x, &y}, {&z0, &z1}, {1e-4}, {}, {}); + auto result = op.execute({&w, &x, &y}, std::vector{&z0, &z1}, {1e-4}, {}, {}); ASSERT_EQ(Status::OK(), result); } @@ -87,7 +87,7 @@ TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) { auto eps = NDArrayFactory::create('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f}); nd4j::ops::standardize_bp op; - auto result = op.execute({&x, &eps}, {}, {0}, {}); + auto result = op.evaluate({&x, &eps}, {0}); ASSERT_EQ(Status::OK(), result->status()); delete result; } @@ -103,7 +103,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) { x.linspace(1.); nd4j::ops::adjust_contrast op; - auto result = op.execute({&x, &factor}, {}, {}, {}); + auto result = op.evaluate({&x, &factor}, {}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto out = result->at(0); @@ -121,7 +121,7 @@ 
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) { }); x.linspace(1.); nd4j::ops::adjust_contrast op; - auto result = op.execute({&x}, {2.}, {}, {}); + auto result = op.evaluate({&x}, {2.}); ASSERT_EQ(Status::OK(), result->status()); auto out = result->at(0); // out->printIndexedBuffer("Adjusted Constrast"); @@ -139,7 +139,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_3) { }); x.linspace(1.); nd4j::ops::adjust_contrast_v2 op; - auto result = op.execute({&x}, {2.}, {}, {}); + auto result = op.evaluate({&x}, {2.}); ASSERT_EQ(Status::OK(), result->status()); auto out = result->at(0); // out->printIndexedBuffer("Adjusted Constrast"); @@ -157,7 +157,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) { }); x.linspace(1.); nd4j::ops::adjust_contrast_v2 op; - auto result = op.execute({&x}, {2.}, {}, {}); + auto result = op.evaluate({&x}, {2.}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto out = result->at(0); // out->printIndexedBuffer("Adjusted Constrast"); @@ -172,7 +172,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_5) { }); x.linspace(1.); nd4j::ops::adjust_contrast_v2 op; - auto result = op.execute({&x}, {2.}, {}, {}); + auto result = op.evaluate({&x}, {2.}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto out = result->at(0); // out->printIndexedBuffer("Adjusted Constrast"); @@ -302,7 +302,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_6) { }); nd4j::ops::adjust_contrast op; - auto result = op.execute({&x}, {2.}, {}, {}); + auto result = op.evaluate({&x}, {2.}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto out = result->at(0); // out->printBuffer("Adjusted Constrast6"); @@ -407,7 +407,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) { }); // x.linspace(1.); nd4j::ops::adjust_contrast_v2 op; - auto result = op.execute({&x}, {2.}, {}, {}); + auto result = op.evaluate({&x}, {2.}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto out = result->at(0); // out->printBuffer("Adjusted Constrast7"); @@ -423,7 +423,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_1) { auto e = NDArrayFactory::create('c', {2, 2}, {2., 512., 8192., 131072.032 }); x.linspace(1.); nd4j::ops::bitcast op; - auto result = op.execute({&x}, {}, {nd4j::DataType::DOUBLE}, {}); + auto result = op.evaluate({&x}, {(int) nd4j::DataType::DOUBLE}); ASSERT_EQ(Status::OK(), result->status()); auto out = result->at(0); // out->printIndexedBuffer("Casted result"); @@ -437,7 +437,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_2) { 0.f, 2.312f, 0.f, 2.375f, 0.f, 2.438f, 0.f, 2.5f}); x.linspace(1.); nd4j::ops::bitcast op; - auto result = op.execute({&x}, {}, {nd4j::DataType::HALF}, {}); + auto result = op.evaluate({&x}, {(int) nd4j::DataType::HALF}); ASSERT_EQ(Status::OK(), result->status()); auto out = result->at(0); ASSERT_TRUE(e.equalsTo(out)); @@ -450,7 +450,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_3) { x.linspace(1.); nd4j::ops::bitcast op; try { - auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {}); + auto result = op.evaluate({&x}, {(int) nd4j::DataType::INT64}); ASSERT_NE(Status::OK(), result->status()); delete result; } catch (std::exception& e) { @@ -478,7 +478,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_4_1) { x.linspace(1.); nd4j::ops::bitcast op; - auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {}); + auto result = op.evaluate({&x}, {}, {nd4j::DataType::INT64}, {}); ASSERT_EQ(Status::OK(), result->status()); // e.printIndexedBuffer("Double to int64"); auto res = result->at(0); @@ -497,7 +497,7 @@ 
TEST_F(DeclarableOpsTests15, Test_BitCast_5) { auto e = NDArrayFactory::create('c', {4}, {4260467851820808160LL, 3900173902914993008LL, 3566895990128523424LL, 3314989625590692528LL}); nd4j::ops::bitcast op; - auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {}); + auto result = op.evaluate({&x}, {}, {nd4j::DataType::INT64}, {}); ASSERT_EQ(Status::OK(), result->status()); auto res = result->at(0); // res->printIndexedBuffer("BITCAST5"); @@ -515,7 +515,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_6) { auto e = NDArrayFactory::create('c', {4}, {4899988963420290048LL, 5188224837230806272LL, 5332342774136064128LL, 5476460161268730496LL}); nd4j::ops::bitcast op; - auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {}); + auto result = op.evaluate({&x}, {}, {nd4j::DataType::INT64}, {}); ASSERT_EQ(Status::OK(), result->status()); auto res = result->at(0); // res->printIndexedBuffer("BITCAST6"); @@ -532,7 +532,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_7) { auto e = NDArrayFactory::create('c', {4}, { 4928700072476425318LL, 5202580391758873882LL, 5346698272827918477LL, 5483778673873668736LL}); nd4j::ops::bitcast op; - auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {}); + auto result = op.evaluate({&x}, {}, {nd4j::DataType::INT64}, {}); ASSERT_EQ(Status::OK(), result->status()); auto res = result->at(0); // res->printIndexedBuffer("BITCAST7"); @@ -549,7 +549,7 @@ TEST_F(DeclarableOpsTests15, test_matmul_bp_1) { auto gB = NDArrayFactory::create('c', {1, 4}); nd4j::ops::matmul_bp op; - auto status = op.execute({&a, &b, &gI}, {&gA, &gB}, {}, {1, 0, 0}, {}); + auto status = op.execute({&a, &b, &gI}, std::vector{&gA, &gB}, {}, {1, 0, 0}, {}); ASSERT_EQ(Status::OK(), status); } @@ -573,7 +573,7 @@ TEST_F(DeclarableOpsTests15, test_check_numeric_1) { auto y = NDArrayFactory::string("shouldn't ever trigger"); nd4j::ops::check_numerics op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -617,7 +617,7 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_1) { auto b = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); nd4j::ops::layer_norm op; - auto result = op.execute({&x, &g, &b}, {}, {0}, {false}); + auto result = op.evaluate({&x, &g, &b}, {}, {0}, {false}); ASSERT_EQ(Status::OK(), result->status()); delete result; } @@ -629,7 +629,7 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) { auto eps = NDArrayFactory::create('c', {1, 5}, {0.f, 0.f, 0.f, 0.f, 0.f}); nd4j::ops::layer_norm_bp op; - auto result = op.execute({&x, &g, &b, &eps}, {}, {0}, {false}); + auto result = op.evaluate({&x, &g, &b, &eps}, {}, {0}, {false}); ASSERT_EQ(Status::OK(), result->status()); delete result; } @@ -662,9 +662,9 @@ TEST_F(DeclarableOpsTests15, test_hashCode_1) { y.linspace(2.); nd4j::ops::hashcode op; - auto resultA0 = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::INT64); - auto resultA1 = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::INT64); - auto resultB0 = op.execute({&y}, {}, {}, {}, false, nd4j::DataType::INT64); + auto resultA0 = op.evaluate({&x}); + auto resultA1 = op.evaluate({&x}); + auto resultB0 = op.evaluate({&y}); // resultA0->at(0)->printIndexedBuffer("A0"); // resultA1->at(0)->printIndexedBuffer("A1"); // resultB0->at(0)->printIndexedBuffer("B0"); @@ -684,9 +684,9 @@ TEST_F(DeclarableOpsTests15, test_hashCode_2) { y.linspace(2.); nd4j::ops::hashcode op; - auto resultA0 = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::INT64); - auto 
resultA1 = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::INT64); - auto resultB0 = op.execute({&y}, {}, {}, {}, false, nd4j::DataType::INT64); + auto resultA0 = op.evaluate({&x}); + auto resultA1 = op.evaluate({&x}); + auto resultB0 = op.evaluate({&y}); // resultA0->at(0)->printIndexedBuffer("A0"); // resultA1->at(0)->printIndexedBuffer("A1"); @@ -705,7 +705,7 @@ TEST_F(DeclarableOpsTests15, test_reshape_to_scalar_1) { auto e = NDArrayFactory::create('c', {1, 1}, {119.f}); nd4j::ops::reshape op; - auto result = op.execute({&array}, {}, {1, 1}); + auto result = op.evaluate({&array}, {}, {1, 1}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -742,7 +742,7 @@ TEST_F(DeclarableOpsTests15, test_rank_2) { auto e = NDArrayFactory::create('c', {}, {2}); nd4j::ops::rank op; - auto result = op.execute({&array}, {}, {}); + auto result = op.evaluate({&array}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -764,7 +764,7 @@ TEST_F(DeclarableOpsTests15, test_lstmBlock_1) { auto x8 = NDArrayFactory::create('c', {12}); nd4j::ops::lstmBlock op; - auto result = op.execute({&x0, &x1, &x2, &x3, &x4, &x5, &x6, &x7, &x8}, {2.0, 0.3}, {0, 0}); + auto result = op.evaluate({&x0, &x1, &x2, &x3, &x4, &x5, &x6, &x7, &x8}, {2.0, 0.3}, {0, 0}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -790,7 +790,7 @@ TEST_F(DeclarableOpsTests15, test_lstmBlock_2) { auto x8 = NDArrayFactory::create('f', {4 * nIn}); nd4j::ops::lstmBlock op; - auto result = op.execute({&x0, &x1, &x2, &x3, &x4, &x5, &x6, &x7, &x8}, {1.0, 0.0}, {0, 1}); + auto result = op.evaluate({&x0, &x1, &x2, &x3, &x4, &x5, &x6, &x7, &x8}, {1.0, 0.0}, {0, 1}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -860,7 +860,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_1) { NDArray rgbs('c', { 3 }, { 10, 50, 200 }, nd4j::DataType::INT32); NDArray expected('c', { 1 }, { 55 }, nd4j::DataType::INT32); nd4j::ops::rgb_to_grs op; - auto result = op.execute({&rgbs}, {}, {}); + auto result = op.evaluate({&rgbs}, {}, {}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -876,7 +876,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_2) { auto rgbs = NDArrayFactory::create('f', { 3 }, { 1, 120, -25 }); auto expected = NDArrayFactory::create('f', { 1 }, { 67 }); nd4j::ops::rgb_to_grs op; - auto result = op.execute({ &rgbs }, {}, {}); + auto result = op.evaluate({ &rgbs }, {}, {}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -892,7 +892,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_3) { NDArray rgbs('c', { 4, 3 }, { -94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102 }, nd4j::DataType::INT32); NDArray expected('c', { 4, 1 }, { 41, 105, 101, 101 }, nd4j::DataType::INT32); nd4j::ops::rgb_to_grs op; - auto result = op.execute({ &rgbs }, {}, {}); + auto result = op.evaluate({ &rgbs }, {}, {}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -910,7 +910,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_4) { rgbs.permutei({1,0}); NDArray expected('c', { 2, 1 }, { 138, 58 }, nd4j::DataType::INT32); nd4j::ops::rgb_to_grs op; - auto result = op.execute({ &rgbs }, {}, {}); + auto result = op.evaluate({ &rgbs }, {}, {}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -926,7 +926,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_5) { NDArray rgbs('c', { 3, 4 }, { -94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102 }, nd4j::DataType::INT32); NDArray expected('c', { 
1, 4 }, { 50, 100, 105, 94 }, nd4j::DataType::INT32); nd4j::ops::rgb_to_grs op; - auto result = op.execute({ &rgbs }, {}, {0}); + auto result = op.evaluate({ &rgbs }, {}, {0}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -943,7 +943,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_6) { auto expected = NDArrayFactory::create('c', { 5,4,1 }, {-47.82958221f, 34.46305847f, 21.36137581f, -21.91625023f,2.49686432f, -43.59792709f, 9.64180183f, 23.04854202f,40.7946167f, 44.98754883f, -25.19047546f, 20.64586449f,-4.97033119f, 30.0226841f, 30.30688286f, 15.61459541f,43.36166f, 18.22480774f, 13.74833488f, 21.59387016f}); nd4j::ops::rgb_to_grs op; - auto result = op.execute({ &rgbs }, {}, {}); + auto result = op.evaluate({ &rgbs }, {}, {}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -960,7 +960,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_7) { auto expected = NDArrayFactory::create('c', { 5,1,4 }, { 36.626545f, 38.607746f, -40.614971f, 18.233341f, -51.545094f,2.234142f, 20.913160f, 8.783220f, 15.955761f, 55.273506f, 36.838833f, -29.751089f, 8.148357f, 13.676106f, 1.097548f, 68.766457f, 38.690712f, 27.176361f, -14.156269f, 7.157052f }); nd4j::ops::rgb_to_grs op; - auto result = op.execute({ &rgbs }, {}, {1}); + auto result = op.evaluate({ &rgbs }, {}, {1}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -976,7 +976,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_8) { auto rgbs = NDArrayFactory::create('c', { 3,5,4 }, {1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); try { nd4j::ops::rgb_to_grs op; - auto result = op.execute({ &rgbs }, {}, {}); + auto result = op.evaluate({ &rgbs }, {}, {}); ASSERT_EQ(Status::THROW(), result->status()); delete result; } catch (std::exception& e) { @@ -991,7 +991,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_9) { auto expected = NDArrayFactory::create('f', { 2,2,1 }, { 36.626545f, 38.607746f, -40.614971f, 18.233341f }); nd4j::ops::rgb_to_grs op; - auto result = op.execute({ &rgbs }, {}, {}); + auto result = op.evaluate({ &rgbs }, {}, {}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -1007,7 +1007,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_1) { NDArray rgbs('f', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32); NDArray expected('f', { 3 }, { 55.14 , 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32); nd4j::ops::rgb_to_yuv op; - auto result = op.execute({ &rgbs }, {}, {}); + auto result = op.evaluate({ &rgbs }, {}, {}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -1026,7 +1026,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_2) { NDArray expected('c', { 2, 3 }, { 138.691, -12.150713, -109.38929, 58.385, 70.18241, 35.63085 }, nd4j::DataType::FLOAT32); nd4j::ops::rgb_to_yuv op; - auto 
result = op.execute({ &rgbs }, {}, {}); + auto result = op.evaluate({ &rgbs }, {}, {}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -1043,7 +1043,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_3) { NDArray expected('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, nd4j::DataType::FLOAT32); nd4j::ops::rgb_to_yuv op; - auto result = op.execute({ &rgbs }, {}, { 0 }); + auto result = op.evaluate({ &rgbs }, {}, { 0 }); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); ASSERT_TRUE(expected.isSameShape(output)); @@ -1059,7 +1059,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_4) { NDArray expected('c', { 5,4,3 }, { 14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, -5.21515376, -9.41983935,-20.5835293, 24.61614501, -44.28390394, 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, 45.62757638, -11.550021, 36.44083018, -64.71012983,-10.435098, - 10.28950082, - 78.74044941, 22.1427147, 19.72198103, 14.40435988, 10.699559, 9.46744852, - 18.5778351 , -7.6957283, 39.31166179, 7.41657542, 7.245035, 28.48336771, - 26.88963173, 47.0880442, - 0.13584441, - 35.60035823, 43.2050762, - 18.47048906, - 31.11782117, 47.642019, - 18.83162118, - 21.50836396,-33.788558, 22.87507047, 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749 }, nd4j::DataType::FLOAT32); nd4j::ops::rgb_to_yuv op; - auto result = op.execute({ &rgbs }, {}, {}); + auto result = op.evaluate({ &rgbs }, {}, {}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -1076,7 +1076,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_5) { NDArray expected('c', { 5,3,4 }, { 36.628319, 38.600643,-40.624989, 18.231001, - 14.822637, - 2.479566, - 8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,- 9.972263,11.058715, 15.947105, 55.283543, 36.845627, -29.750486,0.887228, 6.534475, -21.794132,34.155693, -89.929497,39.562351, 27.276817,31.359871, 8.149521, 13.673355, 1.104303, 68.774300, 2.236881, 13.216944, - 3.555702,- 3.225931,3.063015, - 36.134724,58.302204, 8.477802, 38.695396,27.181587, - 14.157411,7.157054, 11.714512, 22.148155, 11.580557, - 27.204905,7.120562, 21.992094, 2.406748, - 6.265247, }, nd4j::DataType::FLOAT32); nd4j::ops::rgb_to_yuv op; - auto result = op.execute({ &rgbs }, {}, { 1 }); + auto result = op.evaluate({ &rgbs }, {}, { 1 }); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -1091,7 +1091,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_6) { NDArray rgbs('c', { 3,5,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 
5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, nd4j::DataType::FLOAT32); try { nd4j::ops::rgb_to_yuv op; - auto result = op.execute({ &rgbs }, {}, {}); + auto result = op.evaluate({ &rgbs }, {}, {}); ASSERT_EQ(Status::THROW(), result->status()); delete result; } @@ -1107,7 +1107,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_7) { NDArray expected('f', { 2,2,3 }, { 36.628319,38.600643, -40.624989,18.231001, -14.822637,-2.479566, -8.965780, 2.223851, -16.561626,- 96.205162,-52.255379, -36.527435 }, nd4j::DataType::FLOAT32); nd4j::ops::rgb_to_yuv op; - auto result = op.execute({ &rgbs }, {}, {}); + auto result = op.evaluate({ &rgbs }, {}, {}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -1123,7 +1123,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_1) { NDArray yuv('c', { 3 }, { 55.14 , 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32); NDArray expected('c', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32); nd4j::ops::yuv_to_rgb op; - auto result = op.execute({ &yuv }, {}, {}); + auto result = op.evaluate({ &yuv }, {}, {}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -1139,7 +1139,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_2) { NDArray yuv('f', { 3 }, { 55.14, 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32); NDArray expected('f', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32); nd4j::ops::yuv_to_rgb op; - auto result = op.execute({ &yuv }, {}, {}); + auto result = op.evaluate({ &yuv }, {}, {}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -1156,7 +1156,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_3) { NDArray yuv('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, nd4j::DataType::FLOAT32); nd4j::ops::yuv_to_rgb op; - auto result = op.execute({ &yuv }, {}, { 0 }); + auto result = op.evaluate({ &yuv }, {}, { 0 }); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); ASSERT_TRUE(expected.isSameShape(output)); @@ -1172,7 +1172,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_4) { NDArray yuv('c', { 5,4,3 }, { 14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, -5.21515376, -9.41983935,-20.5835293, 24.61614501, -44.28390394, 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, 45.62757638, -11.550021, 36.44083018, -64.71012983,-10.435098, -10.28950082, -78.74044941, 22.1427147, 19.72198103, 14.40435988, 10.699559, 9.46744852, -18.5778351 , -7.6957283, 39.31166179, 7.41657542, 7.245035, 28.48336771, -26.88963173, 47.0880442, -0.13584441, -35.60035823, 43.2050762, -18.47048906, -31.11782117, 47.642019, -18.83162118, -21.50836396,-33.788558, 22.87507047, 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749 }, nd4j::DataType::FLOAT32); nd4j::ops::yuv_to_rgb op; - auto result = op.execute({ &yuv }, {}, {}); + auto result = op.evaluate({ &yuv }, {}, {}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -1189,7 +1189,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_5) { NDArray yuv('c', { 5,3,4 }, { 36.628319, 38.600643,-40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,-9.972263,11.058715, 
     NDArray yuv('c', { 5,3,4 }, { 36.628319, 38.600643,-40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,-9.972263,11.058715, 15.947105, 55.283543, 36.845627, -29.750486,0.887228, 6.534475, -21.794132,34.155693, -89.929497,39.562351, 27.276817,31.359871, 8.149521, 13.673355, 1.104303, 68.774300, 2.236881, 13.216944, -3.555702,-3.225931,3.063015, -36.134724,58.302204, 8.477802, 38.695396,27.181587, -14.157411,7.157054, 11.714512, 22.148155, 11.580557, -27.204905,7.120562, 21.992094, 2.406748, -6.265247, }, nd4j::DataType::FLOAT32);
     nd4j::ops::yuv_to_rgb op;
-    auto result = op.execute({ &yuv }, {}, { 1 });
+    auto result = op.evaluate({ &yuv }, {}, { 1 });
     auto output = result->at(0);
     ASSERT_EQ(Status::OK(), result->status());
@@ -1204,7 +1204,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_6) {
     NDArray yuv('c', { 3,5,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, nd4j::DataType::FLOAT32);
     try {
         nd4j::ops::yuv_to_rgb op;
-        auto result = op.execute({ &yuv }, {}, {});
+        auto result = op.evaluate({ &yuv }, {}, {});
         ASSERT_EQ(Status::THROW(), result->status());
         delete result;
     }
@@ -1220,7 +1220,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_7) {
     NDArray yuv('f', { 2,2,3 }, { 36.628319, 38.600643, -40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626, -96.205162, -52.255379, -36.527435 }, nd4j::DataType::FLOAT32);
     nd4j::ops::yuv_to_rgb op;
-    auto result = op.execute({ &yuv }, {}, {});
+    auto result = op.evaluate({ &yuv }, {}, {});
     auto output = result->at(0);
     ASSERT_EQ(Status::OK(), result->status());
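The rgb_to_yuv / yuv_to_rgb pairs exercised above are inverses of one another, and the optional integer argument selects which dimension holds the three channels. A minimal round-trip sketch in the style of these tests, assuming only the evaluate() overloads and ResultSet accessors visible above:

    // Sketch (not part of the patch): round-trip one pixel rgb -> yuv -> rgb.
    NDArray rgb('c', { 3 }, { 10.f, 50.f, 200.f }, nd4j::DataType::FLOAT32);

    nd4j::ops::rgb_to_yuv toYuv;
    auto yuvResult = toYuv.evaluate({ &rgb }, {}, {});   // no iArg: channels on the last axis
    auto yuv = yuvResult->at(0);

    nd4j::ops::yuv_to_rgb toRgb;
    auto rgbResult = toRgb.evaluate({ yuv }, {}, {});
    // rgbResult->at(0) should come back to roughly { 10, 50, 200 },
    // matching the expected arrays of test_yuv_to_rgb_1 above.

    delete yuvResult;
    delete rgbResult;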
@@ -1246,7 +1246,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test1) {
     dLdz.assign(1.0);
     nd4j::ops::Pow_bp op;
-    auto results = op.execute({ &x, &y, &dLdz }, {}, {});
+    auto results = op.evaluate({ &x, &y, &dLdz }, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1275,7 +1275,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test2) {
     dLdz.linspace(0.1, 0.1);
     nd4j::ops::Pow_bp op;
-    auto results = op.execute({ &x, &y, &dLdz }, {}, {});
+    auto results = op.evaluate({ &x, &y, &dLdz }, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
     auto* dLdx = results->at(0);
@@ -1305,7 +1305,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test3) {
     dLdz.linspace(0.1, 0.1);
     nd4j::ops::Pow_bp op;
-    auto resultsY = op.execute({ &xY, &yY, &dLdz }, {}, {});
+    auto resultsY = op.evaluate({ &xY, &yY, &dLdz }, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, resultsY->status());
@@ -1337,7 +1337,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test4) {
     xX.assign(2.0);
     yX.assign(4.0);
-    auto resultsX = op.execute({ &xX, &yX, &dLdz }, {}, {});
+    auto resultsX = op.evaluate({ &xX, &yX, &dLdz }, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, resultsX->status());
@@ -1369,7 +1369,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test5) {
     dLdyExp.assign(pow(3, 4) * log(3));
     nd4j::ops::Pow_bp op;
-    auto results = op.execute({ &xConst, &yConst, &dLdz }, {}, {});
+    auto results = op.evaluate({ &xConst, &yConst, &dLdz }, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
     auto* dLdx = results->at(0);
@@ -1399,7 +1399,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test6) {
     NDArray dLdyExpXC('c', { 2, 2, 2 }, { 1.10904, 2.21807, 3.32711, 4.43614, 5.54518, 6.65421, 7.76325, 8.87228 }, nd4j::DataType::FLOAT32);
     nd4j::ops::Pow_bp op;
-    auto resultsXC = op.execute({ &xConst, &y, &dLdzC }, {}, {});
+    auto resultsXC = op.evaluate({ &xConst, &y, &dLdzC }, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, resultsXC->status());
     auto* dLdxXC = resultsXC->at(0);
@@ -1428,7 +1428,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test7) {
     auto dLdyExpYs = NDArrayFactory::create(79.85056f);
     nd4j::ops::Pow_bp op;
-    auto resultsYs = op.execute({ &x, &Y, &dLdzC }, {}, {});
+    auto resultsYs = op.evaluate({ &x, &Y, &dLdzC }, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, resultsYs->status());
     auto* dLdxY = resultsYs->at(0);
@@ -1454,7 +1454,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test8) {
     NDArray dLdyExp = NDArrayFactory::create(pow(4.f, 2.f) * log(4.f) * 0.1f);
     nd4j::ops::Pow_bp op;
-    auto results = op.execute({ &X, &Y, &dLdz }, {}, {});
+    auto results = op.evaluate({ &X, &Y, &dLdz }, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1484,7 +1484,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test9) {
     y.assign(2.0);
     dLdz.linspace(0.1, 0.1);
-    auto results = op.execute({ &x, &y, &dLdz }, {}, {});
+    auto results = op.evaluate({ &x, &y, &dLdz }, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
     auto* dLdx = results->at(0);
@@ -1513,7 +1513,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test10) {
     yB.assign(2.0);
     nd4j::ops::Pow_bp op;
-    auto resultsB = op.execute({ &xB, &yB, &dLdzB }, {}, {});
+    auto resultsB = op.evaluate({ &xB, &yB, &dLdzB }, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, resultsB->status());
@@ -1540,7 +1540,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test11) {
     NDArray dLdzB('c', { 3,2,3 }, { .1,.2,.3, .1,.2,.3, .1,.4,.1, .2,.1,.1, .3,.1,.5, .1, .7, .1 }, nd4j::DataType::FLOAT32);
     nd4j::ops::Pow_bp op;
-    auto resultsB = op.execute({ &xB, &yB, &dLdzB }, {}, {});
+    auto resultsB = op.evaluate({ &xB, &yB, &dLdzB }, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, resultsB->status());
     auto* dLdxB = resultsB->at(0);
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp
index f05b8f488..cff57b62d 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp
@@ -46,7 +46,7 @@ TEST_F(DeclarableOpsTests16, scatter_upd_1) {
     auto e = NDArrayFactory::create('c', { 3 }, { 3.f, 1.f, 1.f });
     nd4j::ops::scatter_upd op;
-    auto result = op.execute({ &x, &y, &w }, {}, {});
+    auto result = op.evaluate({ &x, &y, &w });
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -66,7 +66,7 @@ TEST_F(DeclarableOpsTests16, scatter_upd_2) {
     x.linspace(1);
     nd4j::ops::scatter_upd op;
-    auto result = op.execute({ &x, &indices, &updates }, {}, {});
+    auto result = op.evaluate({ &x, &indices, &updates });
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -135,7 +135,7 @@ TEST_F(DeclarableOpsTests16, test_hamming_distance_1) {
     auto e = NDArrayFactory::create(18);
     nd4j::ops::bits_hamming_distance op;
-    auto result = op.execute({ &x, &y }, {}, {});
+    auto result = op.evaluate({ &x, &y });
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -166,7 +166,7 @@ TEST_F(DeclarableOpsTests16, test_empty_cast_1) {
     auto e = NDArrayFactory::create('c', { 1, 0, 2 });
     nd4j::ops::cast op;
-    auto result = op.execute({ &x }, {}, { 10 });
+    auto result = op.evaluate({&x}, {10});
     ASSERT_EQ(Status::OK(), result->status());
     ASSERT_EQ(e, *result->at(0));
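Two patterns in the hunks above and below are worth noting. First, several call sites now drop the empty argument lists entirely (op.evaluate({ &x, &y, &w }) where op.execute({ &x, &y, &w }, {}, {}) was required before), and the cast test passes its integer argument directly as the second parameter (op.evaluate({&x}, {10})). Second, calls that previously ended in a bare boolean (see the cbow hunk and softmax_cross_entropy_loss_test11 later in this patch) gain an extra {} in front of it, which looks like a new DataType-argument list. A hypothetical sketch of the overload family these call sites imply; the real declarations live in DeclarableOp.h and may differ:

    // Inferred from the call sites in this patch only -- not the actual header.
    nd4j::ResultSet* evaluate(const std::vector<NDArray*>& inputs);
    nd4j::ResultSet* evaluate(const std::vector<NDArray*>& inputs,
                              const std::vector<int>& iArgs);
    nd4j::ResultSet* evaluate(const std::vector<NDArray*>& inputs,
                              const std::vector<double>& tArgs,
                              const std::vector<int>& iArgs,
                              const std::vector<bool>& bArgs = {},
                              const std::vector<nd4j::DataType>& dArgs = {},
                              bool isInplace = false);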
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp
index 543043ebd..3f200854d 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp
@@ -48,7 +48,7 @@ TEST_F(DeclarableOpsTests17, test_sparse_to_dense_1) {
     nd4j::ops::compat_sparse_to_dense op;
-    auto result = op.execute({&ranges, &shape, &values, &def}, {}, {});
+    auto result = op.evaluate({&ranges, &shape, &values, &def});
     ASSERT_EQ(Status::OK(), result->status());
     delete result;
@@ -63,7 +63,7 @@ TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) {
     nd4j::ops::compat_sparse_to_dense op;
-    auto result = op.execute({&ranges, &shape, &values, &def}, {}, {});
+    auto result = op.evaluate({&ranges, &shape, &values, &def});
     ASSERT_EQ(Status::OK(), result->status());
     delete result;
@@ -77,7 +77,7 @@ TEST_F(DeclarableOpsTests17, test_compat_string_split_1) {
     auto exp1 = NDArrayFactory::string('c', {3}, {"first", "string", "second"});
     nd4j::ops::compat_string_split op;
-    auto result = op.execute({&x, &delimiter}, {}, {});
+    auto result = op.evaluate({&x, &delimiter});
     ASSERT_EQ(Status::OK(), result->status());
     ASSERT_EQ(2, result->size());
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp
index e4d0db62c..118463e3e 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp
@@ -41,7 +41,7 @@ TEST_F(DeclarableOpsTests2, gather_1) {
     nd4j::ops::gather op;
-    auto result = op.execute({&input, &indices}, {}, {1});
+    auto result = op.evaluate({&input, &indices}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -62,7 +62,7 @@ TEST_F(DeclarableOpsTests2, gather_2) {
     nd4j::ops::gather op;
-    auto result = op.execute({&input}, {}, {1, 0,1, 2,2, 1,2}, {true});
+    auto result = op.evaluate({&input}, {}, {1, 0,1, 2,2, 1,2}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -84,7 +84,7 @@ TEST_F(DeclarableOpsTests2, gather_3) {
     nd4j::ops::gather op;
-    auto result = op.execute({&input, &indices}, {}, {1});
+    auto result = op.evaluate({&input, &indices}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -104,7 +104,7 @@ TEST_F(DeclarableOpsTests2, gather_4) {
     nd4j::ops::gather op;
-    auto result = op.execute({&input}, {}, {1, 2});
+    auto result = op.evaluate({&input}, {}, {1, 2});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -125,7 +125,7 @@ TEST_F(DeclarableOpsTests2, gather_5) {
     nd4j::ops::gather op;
-    auto result = op.execute({&input, &indices}, {}, {1}, {true});
+    auto result = op.evaluate({&input, &indices}, {}, {1}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -147,7 +147,7 @@ TEST_F(DeclarableOpsTests2, gather_6) {
     nd4j::ops::gather op;
-    auto result = op.execute({&input, &indices}, {}, {0});
+    auto result = op.evaluate({&input, &indices}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -169,7 +169,7 @@ TEST_F(DeclarableOpsTests2, gather_7) {
     nd4j::ops::gather op;
-    auto result = op.execute({&input, &indices}, {}, {2});
+    auto result = op.evaluate({&input, &indices}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -191,7 +191,7 @@ TEST_F(DeclarableOpsTests2, gather_8) {
     nd4j::ops::gather op;
-    auto result = op.execute({&input, &indices}, {}, {0});
+    auto result = op.evaluate({&input, &indices}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto* output = result->at(0);
     // output->printShapeInfo();
@@ -209,7 +209,7 @@ TEST_F(DeclarableOpsTests2, gather_9) {
     NDArray indices('c', {2}, {1, 0}, nd4j::DataType::INT32);
     nd4j::ops::gather op;
-    auto result = op.execute({&x, &indices}, {}, {-2});
+    auto result = op.evaluate({&x, &indices}, {}, {-2});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -223,7 +223,7 @@ TEST_F(DeclarableOpsTests2, gather_10) {
     NDArray e('c', {2, 2}, {3, 4, 1, 2});
     nd4j::ops::gather op;
-    auto result = op.execute({&x}, {}, {0, 1, 0});
+    auto result = op.evaluate({&x}, {}, {0, 1, 0});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -242,7 +242,7 @@ TEST_F(DeclarableOpsTests2, gather_11) {
     NDArray e('c', {2, 2}, {3, 4, 1, 2});
     nd4j::ops::gather op;
-    auto result = op.execute({&x, &indices}, {}, {0});
+    auto result = op.evaluate({&x, &indices}, {}, {0});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -261,7 +261,7 @@ TEST_F(DeclarableOpsTests2, gather_12) {
     NDArray exp('c', {2}, {2.f, 4.f});
     nd4j::ops::gather op;
-    auto result = op.execute({&input, &indices}, {}, {});
+    auto result = op.evaluate({&input, &indices}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -294,7 +294,7 @@ TEST_F(DeclarableOpsTests2, gather_13) {
     nd4j::ops::gather op;
-    auto result = op.execute({&input, &indices}, {}, {2}, {true});
+    auto result = op.evaluate({&input, &indices}, {}, {2}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
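The gather hunks above cover both calling conventions: indices supplied as a second input array, or inlined after the axis inside the integer arguments (and the axis may be negative, as gather_9's {-2} shows). A condensed sketch of the two forms, mirroring gather_11 and gather_10 and assuming the same NDArray helpers these tests use:

    // Sketch (not part of the patch): the two gather calling conventions.
    NDArray x('c', {2, 2}, {1, 2, 3, 4});
    NDArray indices('c', {2}, {1, 0}, nd4j::DataType::INT32);

    nd4j::ops::gather op;
    // Form 1: indices as an input array, axis as the single iArg.
    auto r1 = op.evaluate({&x, &indices}, {}, {0});   // rows reversed: {3, 4, 1, 2}
    // Form 2: no indices input; iArgs = {axis, index0, index1, ...}.
    auto r2 = op.evaluate({&x}, {}, {0, 1, 0});       // same result
    delete r1;
    delete r2;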
@@ -337,7 +337,7 @@ TEST_F(DeclarableOpsTests2, BroadcastGradientArgs_1) {
     nd4j::ops::broadcastgradientargs op;
-    auto result = op.execute({&input, &indices}, {}, {});
+    auto result = op.evaluate({&input, &indices}, {}, {});
     ASSERT_EQ(ND4J_STATUS_KERNEL_FAILURE, result->status());
@@ -375,7 +375,7 @@ TEST_F(DeclarableOpsTests2, NLP_Cbow_Test_1) {
     auto inferenceVector = NDArrayFactory::empty();
     nd4j::ops::cbow op;
-    auto result = op.execute({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, true);
+    auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true);
     ASSERT_EQ(Status::OK(), result->status());
     auto row_s0_0 = syn0({0,1, 0,0}, true);
@@ -407,7 +407,7 @@ TEST_F(DeclarableOpsTests2, YetAnotherMatmulTest_1) {
     nd4j::ops::matmul op;
-    auto result = op.execute({&A, &B}, {}, {});
+    auto result = op.evaluate({&A, &B}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -424,7 +424,7 @@ TEST_F(DeclarableOpsTests2, Test_Squeeze_1) {
     auto exp = x.reshape('c', {2, 3, 4});
     nd4j::ops::squeeze op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -443,7 +443,7 @@ TEST_F(DeclarableOpsTests2, Test_Squeeze_2) {
     auto exp = new NDArray(x.dup());
     nd4j::ops::squeeze op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -463,7 +463,7 @@ TEST_F(DeclarableOpsTests2, Test_FloorMod_1) {
     nd4j::ops::floormod op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     auto z = result->at(0);
@@ -481,7 +481,7 @@ TEST_F(DeclarableOpsTests2, Test_FloorDiv_1) {
     nd4j::ops::floordiv op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     auto z = result->at(0);
     // z->printShapeInfo("FloorDiv1 shape");
@@ -503,7 +503,7 @@ TEST_F(DeclarableOpsTests2, Test_FloorDiv_2) {
     nd4j::ops::floordiv_bp op;
-    auto result = op.execute({&x, &y, &eps}, {}, {});
+    auto result = op.evaluate({&x, &y, &eps}, {}, {});
     ASSERT_EQ(result->status(), Status::OK());
     auto z1 = result->at(0);
     auto z2 = result->at(1);
@@ -523,7 +523,7 @@ TEST_F(DeclarableOpsTests2, Test_CRelu_1) {
     nd4j::ops::crelu op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -541,7 +541,7 @@ TEST_F(DeclarableOpsTests2, Test_CRelu_BP_2) {
     auto exp = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, -2.f, 4.f});
     nd4j::ops::crelu_bp op;
-    auto result = op.execute({&x, &eps}, {}, {});
+    auto result = op.evaluate({&x, &eps});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(1, result->size());
@@ -561,7 +561,7 @@ TEST_F(DeclarableOpsTests2, Test_Concat_BP_1) {
     auto expEY = NDArrayFactory::create('c', {2, 2}, {0.f, 1.f, 0.f, 1.f});
     nd4j::ops::concat_bp op;
-    auto result = op.execute({&x, &y, &eps}, {}, {-1});
+    auto result = op.evaluate({&x, &y, &eps}, {}, {-1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(2, result->size());
@@ -586,7 +586,7 @@ TEST_F(DeclarableOpsTests2, TestTensorDot5) {
     auto expected = NDArrayFactory::create('c', {2,4,2,4}, {44,110,160, 66,132, 38, 88,154, 68,170,224,102,204, 82,136,238, 92,230,288,138,276,126,184,322, 116,290,352,174,348,170,232,406, 76,190,160,114,228,182,152,266, 100,250,224,150,300,226,200,350, 124,310,288,186,372,270,248,434, 148,370,352,222,444,314,296,518});
     nd4j::ops::tensormmul op;
-    auto results = op.execute({&x, &y}, {}, {1,1,1,2});
+    auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -608,7 +608,7 @@ TEST_F(DeclarableOpsTests2, TestTensorDot6) {
     auto expected = NDArrayFactory::create('c', {2,4,2,4}, {22, 66,110,154, 44, 88,132,176, 34,102,170,238, 68,136,204,272, 46,138,230,322, 92,184,276,368, 58,174,290,406,116,232,348,464, 38,114,190,266, 76,152,228,304, 50,150,250,350,100,200,300,400, 62,186,310,434,124,248,372,496, 74,222,370,518,148,296,444,592});
     nd4j::ops::tensormmul op;
-    auto results = op.execute({&x, &y}, {}, {1,1,1,2});
+    auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -629,7 +629,7 @@ TEST_F(DeclarableOpsTests2, TestTensorDot7) {
     auto expected = NDArrayFactory::create('c', {2,4,2,4}, {76,166,112,106,196, 62,136,226, 60,174,208, 98,212,230,136,250, 76,214,336,122,260,174,168,306, 124,286,240,178,340,150,232,394, 100,226,176,142,268,106,184,310, 84,234,272,134,284,274,184,334, 100,274,400,158,332,218,216,390, 148,346,304,214,412,194,280,478});
     nd4j::ops::tensormmul op;
-    auto results = op.execute({&x, &y}, {}, {1,1,1,2});
+    auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -650,7 +650,7 @@ TEST_F(DeclarableOpsTests2, TestTensorDot8) {
     auto expected = NDArrayFactory::create('c', {2,4,2,4}, {30, 90,150,210, 60,120,180,240, 38,114,190,266, 76,152,228,304, 46,138,230,322, 92,184,276,368, 54,162,270,378,108,216,324,432, 42,126,210,294, 84,168,252,336, 50,150,250,350,100,200,300,400, 58,174,290,406,116,232,348,464, 66,198,330,462,132,264,396,528});
     nd4j::ops::tensormmul op;
-    auto results = op.execute({&x, &y}, {}, {1,1,1,2});
+    auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -679,7 +679,7 @@ TEST_F(DeclarableOpsTests2, TestTensorDot9) {
     auto expected = NDArrayFactory::create('c', {3,4,4,3}, {14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422});
     nd4j::ops::tensormmul op;
-    auto results = op.execute({&x, &y}, {}, {1,0,1,0});
+    auto results = op.evaluate({&x, &y}, {}, {1,0,1,0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -700,7 +700,7 @@ TEST_F(DeclarableOpsTests2, TestTensorDot10) {
     auto expected = NDArrayFactory::create('c', {4,4}, {114,258,402,546, 138,314,490,666, 162,370,578,786, 186,426,666,906});
     nd4j::ops::tensormmul op;
-    auto results = op.execute({&x, &y}, {}, {2,0,1, 2,0,2});
+    auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -722,7 +722,7 @@ TEST_F(DeclarableOpsTests2, TestTensorDot11) {
     auto expected = NDArrayFactory::create('c', {4,4}, {98,218,338,458, 134,302,470,638, 170,386,602,818, 206,470,734,998});
     nd4j::ops::tensormmul op;
-    auto results = op.execute({&x, &y}, {}, {2,0,1, 2,0,2});
+    auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -743,7 +743,7 @@ TEST_F(DeclarableOpsTests2, TestTensorDot12) {
     auto expected = NDArrayFactory::create('c', {4,4}, {272,292,312,332, 368,396,424,452, 464,500,536,572, 560,604,648,692});
     nd4j::ops::tensormmul op;
-    auto results = op.execute({&x, &y}, {}, {2,0,1, 2,0,2});
+    auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -764,7 +764,7 @@ TEST_F(DeclarableOpsTests2, TestTensorDot13) {
     auto expected = NDArrayFactory::create('c', {3,3}, {640,560,640, 576,624,576, 640,560,640});
     nd4j::ops::tensormmul op;
-    auto results = op.execute({&x, &y}, {}, {2,0,2, 2,1,0});
+    auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -785,7 +785,7 @@ TEST_F(DeclarableOpsTests2, TestTensorDot14) {
    auto expected = NDArrayFactory::create('c', {3,3}, {648,600,520, 648,536,648, 520,600,648});
     nd4j::ops::tensormmul op;
-    auto results = op.execute({&x, &y}, {}, {2,0,2, 2,1,0});
+    auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -806,7 +806,7 @@ TEST_F(DeclarableOpsTests2, TestTensorDot15) {
     auto expected = NDArrayFactory::create('c', {3,3}, {624,624,624, 656,656,656, 624,624,624});
     nd4j::ops::tensormmul op;
-    auto results = op.execute({&x, &y}, {}, {2,0,2, 2,1,0});
+    auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
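Across TestTensorDot5-15 the integer arguments to tensormmul appear to encode the contraction as {nAxesA, axesA..., nAxesB, axesB...}; e.g. {2,0,1, 2,0,2} contracts axes (0,1) of the first input with axes (0,2) of the second, and {1,1,1,2} contracts axis 1 against axis 2. A small sketch under that reading (shapes chosen so the contraction is valid; not taken from the patch):

    // Sketch: contract axis 2 of a with axis 1 of b -> result shape {2, 3, 2, 3}.
    NDArray a('c', {2, 3, 4}, nd4j::DataType::FLOAT32);
    NDArray b('c', {2, 4, 3}, nd4j::DataType::FLOAT32);
    a.linspace(1);
    b.linspace(1);

    nd4j::ops::tensormmul op;
    auto results = op.evaluate({&a, &b}, {}, {1, 2, 1, 1});  // {nAxesA, 2, nAxesB, 1}
    delete results;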
@@ -833,7 +833,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_1) {
     expected.assign(0.5f);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -861,7 +861,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_2) {
     expected.assign(0.5f);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -889,7 +889,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_3) {
     expected.assign(0.5f);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -916,7 +916,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_4) {
     expected.assign(0.5f);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -943,7 +943,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_5) {
     expected.assign(0.5f);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -970,7 +970,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_6) {
     expected.assign(0.f);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -995,7 +995,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_7) {
     weights.assign(0.5f);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {1});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1020,7 +1020,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_8) {
     weights.assign(0.f);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {1});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1045,7 +1045,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_9) {
     weights.assign(0.5f);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {1});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1070,7 +1070,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_10) {
     weights.assign(0.5f);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {1});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1095,7 +1095,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_11) {
     weights.assign(0.5f);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1120,7 +1120,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_12) {
     weights.assign(0.f);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1145,7 +1145,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_13) {
     weights.assign(0.5f);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1172,7 +1172,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_14) {
     weights.p(2, 0.f);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1197,7 +1197,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_15) {
     weights.assign(0.5f);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1226,7 +1226,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_16) {
     predictions.p(3, 0.f);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1259,7 +1259,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_17) {
     labels.p(3, 0.f);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1292,7 +1292,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_18) {
     labels.p(3, 0.f);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1318,7 +1318,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_19) {
     weights.assign(0.5);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1343,7 +1343,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_20) {
     weights.assign(0.5);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1368,7 +1368,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_21) {
     weights.assign(0.5);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1393,7 +1393,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_22) {
     weights.assign(0.);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1430,7 +1430,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_23) {
     weights.p(40+3, 0.);
     nd4j::ops::absolute_difference_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
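All of the loss tests in this file pass the same trailing integer argument, which the names and expected values suggest is a reduction mode: 0 keeps the per-element weighted loss, while 1, 2 and 3 appear to reduce it (weighted sum, mean by total weight, and mean by nonzero-weight count, respectively). That reading is inferred from these tests, not from separate documentation. A sketch with numbers small enough to check by hand:

    // Sketch (assumed mode semantics): |p - l| = 0.5 everywhere, weights = 1.
    auto labels      = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
    auto predictions = NDArrayFactory::create('c', {2, 2}, {1.5f, 2.5f, 3.5f, 4.5f});
    auto weights     = NDArrayFactory::create('c', {2, 2}, {1.f, 1.f, 1.f, 1.f});

    nd4j::ops::absolute_difference_loss op;
    auto perElement = op.evaluate({&predictions, &weights, &labels}, {}, {0}); // {2,2} of 0.5
    auto summed     = op.evaluate({&predictions, &weights, &labels}, {}, {1}); // scalar 2.0
    delete perElement;
    delete summed;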
@@ -1456,7 +1456,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test1) {
     weights.assign(0.5);
     nd4j::ops::cosine_distance_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {0,0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0,0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1482,7 +1482,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test2) {
     predictions.assign(0.5);
     nd4j::ops::cosine_distance_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {0,1});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0,1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1509,7 +1509,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test3) {
     predictions.assign(0.5);
     nd4j::ops::cosine_distance_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {0,2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0,2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1534,7 +1534,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test4) {
     predictions.assign(0.5);
     nd4j::ops::cosine_distance_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {0,2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0,2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1559,7 +1559,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test5) {
     predictions.assign(0.5);
     nd4j::ops::cosine_distance_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {1,1});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1,1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1584,7 +1584,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test6) {
     predictions.assign(0.5);
     nd4j::ops::cosine_distance_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {1,1});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1,1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1609,7 +1609,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test7) {
     predictions.assign(0.5);
     nd4j::ops::cosine_distance_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {1,0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1,0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1634,7 +1634,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test8) {
     predictions.assign(0.5f);
     nd4j::ops::cosine_distance_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2,2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1659,7 +1659,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test9) {
     predictions.assign(0.5f);
     nd4j::ops::cosine_distance_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2,2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1686,7 +1686,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test10) {
     weights.p(1, 0.f);
     nd4j::ops::cosine_distance_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2,2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1711,7 +1711,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test1) {
     weights.assign(0.5);
     nd4j::ops::hinge_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {}, {0});
+    auto results = op.evaluate({&logits, &weights, &labels}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1736,7 +1736,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test2) {
     weights.assign(0.5);
     nd4j::ops::hinge_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {}, {0});
+    auto results = op.evaluate({&logits, &weights, &labels}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1761,7 +1761,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test3) {
     weights.assign(0.5);
     nd4j::ops::hinge_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {}, {0});
+    auto results = op.evaluate({&logits, &weights, &labels}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1786,7 +1786,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test4) {
     nd4j::ops::hinge_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {}, {1});
+    auto results = op.evaluate({&logits, &weights, &labels}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1810,7 +1810,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test5) {
     nd4j::ops::hinge_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {}, {1});
+    auto results = op.evaluate({&logits, &weights, &labels}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1834,7 +1834,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test6) {
     nd4j::ops::hinge_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {}, {1});
+    auto results = op.evaluate({&logits, &weights, &labels}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1858,7 +1858,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test7) {
     nd4j::ops::hinge_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {}, {2});
+    auto results = op.evaluate({&logits, &weights, &labels}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1882,7 +1882,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test8) {
     nd4j::ops::hinge_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {}, {2});
+    auto results = op.evaluate({&logits, &weights, &labels}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1906,7 +1906,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test9) {
     nd4j::ops::hinge_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {}, {2});
+    auto results = op.evaluate({&logits, &weights, &labels}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1930,7 +1930,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test10) {
     nd4j::ops::hinge_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {}, {3});
+    auto results = op.evaluate({&logits, &weights, &labels}, {}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1954,7 +1954,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test11) {
     nd4j::ops::hinge_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {}, {3});
+    auto results = op.evaluate({&logits, &weights, &labels}, {}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1982,7 +1982,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test12) {
     nd4j::ops::hinge_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {}, {3});
+    auto results = op.evaluate({&logits, &weights, &labels}, {}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2006,7 +2006,7 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test13) {
     nd4j::ops::hinge_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {}, {3});
+    auto results = op.evaluate({&logits, &weights, &labels}, {}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2031,7 +2031,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test1) {
     weights.assign(0.5);
     nd4j::ops::huber_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {0.1}, {0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2056,7 +2056,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test2) {
     weights.assign(0.5);
     nd4j::ops::huber_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {0.1}, {0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2081,7 +2081,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test3) {
     weights.assign(0.5);
     nd4j::ops::huber_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {0.1}, {0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2105,7 +2105,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test4) {
     weights.assign(0.5);
     nd4j::ops::huber_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {0.1}, {1});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2129,7 +2129,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test5) {
     weights.assign(0.5);
     nd4j::ops::huber_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {0.1}, {1});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2153,7 +2153,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test6) {
     weights.assign(0.5);
     nd4j::ops::huber_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {0.1}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2177,7 +2177,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test7) {
     weights.assign(0.5);
     nd4j::ops::huber_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {0.1}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2205,7 +2205,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test8) {
     weights.p(3, 0.);
     nd4j::ops::huber_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {0.1}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2229,7 +2229,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test9) {
     weights.assign(0.5);
     nd4j::ops::huber_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {0.1}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2253,7 +2253,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test10) {
     weights.assign(0.5);
     nd4j::ops::huber_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {0.1}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2281,7 +2281,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test11) {
     weights.p(3, 0.);
     nd4j::ops::huber_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {0.1}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2306,7 +2306,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test1) {
     weights.assign(0.5);
     nd4j::ops::log_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2331,7 +2331,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test2) {
     weights.assign(0.5);
     nd4j::ops::log_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2356,7 +2356,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test3) {
     weights.assign(0.5);
     nd4j::ops::log_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2380,7 +2380,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test4) {
     weights.assign(0.5);
     nd4j::ops::log_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {1});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2404,7 +2404,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test5) {
     weights.assign(0.5);
     nd4j::ops::log_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {1});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2428,7 +2428,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test6) {
     weights.assign(0.5);
     nd4j::ops::log_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {1});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2452,7 +2452,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test7) {
     weights.assign(0.5);
     nd4j::ops::log_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2476,7 +2476,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test8) {
     weights.assign(0.5);
     nd4j::ops::log_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2500,7 +2500,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test9) {
     weights.assign(0.5);
     nd4j::ops::log_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2528,7 +2528,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test10) {
     weights.p(3, 0.);
     nd4j::ops::log_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2552,7 +2552,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test11) {
     weights.assign(0.5);
     nd4j::ops::log_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2576,7 +2576,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test12) {
     weights.assign(0.5);
     nd4j::ops::log_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2604,7 +2604,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test13) {
     weights.p(3, 0.);
     nd4j::ops::log_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {1e-7}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2624,7 +2624,7 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test1) {
     auto expected = NDArrayFactory::create('c', {1,1}, {1.});
     nd4j::ops::mean_pairwssqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2644,7 +2644,7 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test2) {
     auto expected = NDArrayFactory::create('c', {10,1}, {1.9665822560405073, 3.806679563402927, 6.185624212589066, 20.237895345263905, 16.739700814450472, 13.655430201400929, 6.473256392322658, 3.9337379694106325, 22.509455553531062, 1.4741234749089487});
     nd4j::ops::mean_pairwssqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2664,7 +2664,7 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test3) {
     auto expected = NDArrayFactory::create('c', {10,1}, {0.0, 0.0, 21.748459867092496, 6.090581568657439, 7.51315897553838, 5.999534225166869, 22.58050883748054, 6.8600435676788605, 107.5976928688877, 191.56864939172544});
     nd4j::ops::mean_pairwssqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2683,7 +2683,7 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test4) {
     auto weights = NDArrayFactory::create('c', {1,1}, {1});
     nd4j::ops::mean_pairwssqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {1});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2702,7 +2702,7 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test5) {
     auto weights = NDArrayFactory::create('c', {1,1}, {1});
     nd4j::ops::mean_pairwssqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2721,7 +2721,7 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test6) {
     auto weights = NDArrayFactory::create('c', {1,1}, {1});
     nd4j::ops::mean_pairwssqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2740,7 +2740,7 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test7) {
     auto weights = NDArrayFactory::create('c', {10,1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0});
     nd4j::ops::mean_pairwssqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {1});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2759,7 +2759,7 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test8) {
     auto weights = NDArrayFactory::create('c', {10,1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0});
     nd4j::ops::mean_pairwssqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2778,7 +2778,7 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test9) {
     auto weights = NDArrayFactory::create('c', {10,1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0});
     nd4j::ops::mean_pairwssqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2802,7 +2802,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test1) {
     weights.assign(0.5);
     nd4j::ops::mean_sqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2827,7 +2827,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test2) {
     weights.assign(0.5);
     nd4j::ops::mean_sqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2852,7 +2852,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test3) {
     weights.assign(0.5);
     nd4j::ops::mean_sqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2881,7 +2881,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test4) {
     weights.p(3, 0.);
     nd4j::ops::mean_sqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {0});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2905,7 +2905,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test5) {
     weights.assign(0.5);
     nd4j::ops::mean_sqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {1});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2929,7 +2929,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test6) {
     weights.assign(0.5);
     nd4j::ops::mean_sqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {1});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2953,7 +2953,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test7) {
     weights.assign(0.5);
     nd4j::ops::mean_sqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {1});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2981,7 +2981,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test8) {
     weights.p(3, 0.);
     nd4j::ops::mean_sqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {1});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3005,7 +3005,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test9) {
     weights.assign(0.5);
     nd4j::ops::mean_sqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3029,7 +3029,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test10) {
     weights.assign(0.5);
     nd4j::ops::mean_sqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3053,7 +3053,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test11) {
     weights.assign(0.5);
     nd4j::ops::mean_sqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3080,7 +3080,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test12) {
     weights.p(2, 0.);
     nd4j::ops::mean_sqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {2});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3104,7 +3104,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test13) {
     weights.assign(0.5);
     nd4j::ops::mean_sqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3128,7 +3128,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test14) {
     weights.assign(0.5);
     nd4j::ops::mean_sqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3152,7 +3152,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test15) {
     weights.assign(0.5);
     nd4j::ops::mean_sqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3179,7 +3179,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test16) {
     weights.p(2, 0.);
     nd4j::ops::mean_sqerr_loss op;
-    auto results = op.execute({&predictions, &weights, &labels}, {}, {3});
+    auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3203,7 +3203,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test1) {
     weights.assign(0.5);
     nd4j::ops::sigm_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {0.}, {0});
+    auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3227,7 +3227,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test2) {
     weights.assign(0.5);
     nd4j::ops::sigm_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {0.}, {0});
+    auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3251,7 +3251,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test3) {
     weights.assign(0.5);
     nd4j::ops::sigm_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {0.}, {0});
+    auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3275,7 +3275,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test4) {
     weights.assign(0.5);
     nd4j::ops::sigm_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {5.}, {0});
+    auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3298,7 +3298,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test5) {
     weights.assign(0.5);
     nd4j::ops::sigm_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {0.}, {1});
+    auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3321,7 +3321,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test6) {
     weights.assign(0.5);
     nd4j::ops::sigm_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {0.}, {1});
+    auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3344,7 +3344,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test7) {
     weights.assign(0.5);
     nd4j::ops::sigm_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {0.}, {1});
+    auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3367,7 +3367,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test8) {
     weights.assign(0.5);
     nd4j::ops::sigm_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {5.}, {1});
+    auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3394,7 +3394,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test9) {
     weights.p(2, 0.);
     nd4j::ops::sigm_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {5.}, {1});
+    auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3417,7 +3417,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test10) {
     weights.assign(0.5);
     nd4j::ops::sigm_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {0.}, {2});
+    auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3440,7 +3440,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test11) {
     weights.assign(0.5);
     nd4j::ops::sigm_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {0.}, {2});
+    auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3463,7 +3463,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test12) {
     weights.assign(0.5);
     nd4j::ops::sigm_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {5.}, {2});
+    auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3489,7 +3489,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test13) {
     weights.p(2, 0.);
     nd4j::ops::sigm_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {5.}, {2});
+    auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3512,7 +3512,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test14) {
     weights.assign(0.5);
     nd4j::ops::sigm_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {0.}, {3});
+    auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3535,7 +3535,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test15) {
     weights.assign(0.5);
     nd4j::ops::sigm_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {0.}, {3});
+    auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3558,7 +3558,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test16) {
     weights.assign(0.5);
     nd4j::ops::sigm_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {5.}, {3});
+    auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3584,7 +3584,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test17) {
     weights.p(2, 0.);
     nd4j::ops::sigm_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {5.}, {3});
+    auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3608,7 +3608,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test1) {
     weights.assign(0.5);
     nd4j::ops::softmax_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {0.}, {0});
+    auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3631,7 +3631,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test2) {
     weights.assign(0.5);
     nd4j::ops::softmax_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {5.}, {0}, {});
+    auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3655,7 +3655,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test3) {
     weights.assign(0.5);
     nd4j::ops::softmax_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {5.}, {0});
+    auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3679,7 +3679,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test4) {
     weights.assign(0.5);
     nd4j::ops::softmax_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {5.}, {0});
+    auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3703,7 +3703,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test5) {
     weights.assign(0.5);
     nd4j::ops::softmax_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {5.}, {0});
+    auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3726,7 +3726,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test6) {
     weights.assign(0.5);
     nd4j::ops::softmax_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {0.}, {1});
+    auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3749,7 +3749,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test7) {
     weights.assign(0.5);
     nd4j::ops::softmax_cross_entropy_loss op;
-    auto results = op.execute({&logits, &weights, &labels}, {5.}, {1});
+    auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3772,7 +3772,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test8) {
     weights.assign(0.5);
     nd4j::ops::softmax_cross_entropy_loss op;
results = op.execute({&logits, &weights, &labels}, {5.}, {1}); + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3795,7 +3795,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test9) { weights.assign(0.5); nd4j::ops::softmax_cross_entropy_loss op; - auto results = op.execute({&logits, &weights, &labels}, {5.}, {1}); + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3818,7 +3818,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test10) { weights.assign(0.5); nd4j::ops::softmax_cross_entropy_loss op; - auto results = op.execute({&logits, &weights, &labels}, {5.}, {2}); + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3841,7 +3841,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test11) { weights.assign(0.5); nd4j::ops::softmax_cross_entropy_loss op; - auto results = op.execute({&logits, &weights, &labels}, {5.}, {3}, {}, false, nd4j::DataType::DOUBLE); + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {3}, {}, {}, false); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3864,7 +3864,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test12) { weights.assign(0.5); nd4j::ops::softmax_cross_entropy_loss op; - auto results = op.execute({&logits, &weights, &labels}, {5.}, {3}, {}, false, nd4j::DataType::DOUBLE); + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {3}, {}, {}, false); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3888,7 +3888,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test13) { weights.assign(0.5); nd4j::ops::softmax_cross_entropy_loss op; - auto results = op.execute({&logits, &weights, &labels}, {0.}, {0}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3914,7 +3914,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test14) { weights.assign(0.5); nd4j::ops::softmax_cross_entropy_loss op; - auto results = op.execute({&logits, &weights, &labels}, {5.}, {0}); + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3938,7 +3938,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test15) { weights.assign(0.5); nd4j::ops::softmax_cross_entropy_loss op; - auto results = op.execute({&logits, &weights, &labels}, {5.}, {0}); + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3980,7 +3980,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test1) { auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.99987108,3.99987108,3.99987108,3.99987108,3.99987108,3.99987108,3.99987108,3.99987108}); nd4j::ops::lstmCell op; - auto results = op.execute({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 1.}, {0, 0}); + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 1.}, {0, 0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -4025,7 +4025,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test2) { auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{1.93001527,1.93001527,1.93001527,1.93001527, 1.93001527,1.93001527,1.93001527,1.93001527}); nd4j::ops::lstmCell op; - auto results = op.execute({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., -10.5}, {0, 0}); + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 
-10.5}, {0, 0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -4070,7 +4070,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test3) { auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); nd4j::ops::lstmCell op; - auto results = op.execute({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0.4, 0., 1.5}, {0, 0}); + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0.4, 0., 1.5}, {0, 0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -4115,7 +4115,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test4) { auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); nd4j::ops::lstmCell op; - auto results = op.execute({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0.4, 0.3, 1.5}, {0, 0}); + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0.4, 0.3, 1.5}, {0, 0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -4160,7 +4160,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test5) { auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); nd4j::ops::lstmCell op; - auto results = op.execute({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0.4, 0.3, 1.5}, {0, 1}); + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0.4, 0.3, 1.5}, {0, 1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -4205,7 +4205,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test6) { auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.99972188,3.99972188,3.99972188,3.99972188,3.99972188,3.99972188,3.99972188,3.99972188}); nd4j::ops::lstmCell op; - auto results = op.execute({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 1.5}, {0, 1}); + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 1.5}, {0, 1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -4250,7 +4250,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test7) { auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); nd4j::ops::lstmCell op; - auto results = op.execute({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0.4, 0., 1.5}, {0, 1}); + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0.4, 0., 1.5}, {0, 1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -4296,7 +4296,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test8) { auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.99996277,3.99996277,3.99996277,3.99996277,3.99996277,3.99996277,3.99996277,3.99996277}); nd4j::ops::lstmCell op; - auto results = op.execute({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 10.5}, {1, 0}); + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 10.5}, {1, 0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -4341,7 +4341,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test9) { auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.,3.,3.,3.,3.,3.,3.,3.}); nd4j::ops::lstmCell op; - auto results = op.execute({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {3., 0., 10.5}, {1, 0}); + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {3., 0., 10.5}, {1, 0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -4386,7 +4386,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test10) { auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.99996277, 3.99996277, 3.99996277, 3.99996277,3.99996277, 3.99996277, 3.99996277, 3.99996277}); nd4j::ops::lstmCell op; - auto results = op.execute({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, 
&Wp, &b}, {0., 0., 10.5}, {1, 1}); + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 10.5}, {1, 1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -4431,7 +4431,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test11) { auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.,3.,3.,3.,3.,3.,3.,3.}); nd4j::ops::lstmCell op; - auto results = op.execute({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {3., 0., 10.5}, {1, 1}); + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {3., 0., 10.5}, {1, 1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -4476,7 +4476,7 @@ TEST_F(DeclarableOpsTests2, lstmCell_test12) { auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.,3.,3.,3.,3.,3.,3.,3.}); nd4j::ops::lstmCell op; - auto results = op.execute({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {3., 1.,-5.}, {1, 1}); + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {3., 1.,-5.}, {1, 1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index 2ef86710a..e39589270 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -43,7 +43,7 @@ TEST_F(DeclarableOpsTests3, Test_Tile_1) { auto exp = x.tile(reps); nd4j::ops::tile op; - auto result = op.execute({&x, &rep_vector}, {}, {}); + auto result = op.evaluate({&x, &rep_vector}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -61,7 +61,7 @@ TEST_F(DeclarableOpsTests3, Test_Tile_2) { auto exp = x.tile(reps); nd4j::ops::tile op; - auto result = op.execute({&x}, {}, {2, 2}); + auto result = op.evaluate({&x}, {}, {2, 2}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -77,7 +77,7 @@ TEST_F(DeclarableOpsTests3, Test_Permute_1) { auto exp= NDArrayFactory::create('c', {2, 4, 3}); nd4j::ops::permute op; - auto result = op.execute({&x, &permute}, {}, {}); + auto result = op.evaluate({&x, &permute}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -92,7 +92,7 @@ TEST_F(DeclarableOpsTests3, Test_Permute_2) { auto exp= NDArrayFactory::create('c', {4, 3, 2}); nd4j::ops::permute op; - auto result = op.execute({&x}, {}, {}); + auto result = op.evaluate({&x}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -110,7 +110,7 @@ TEST_F(DeclarableOpsTests3, Test_Unique_1) { // auto expI= NDArrayFactory::create('c', {3}, {0, 1, 4}); nd4j::ops::unique op; - auto result = op.execute({&x}, {}, {}); + auto result = op.evaluate({&x}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(2, result->size()); @@ -136,7 +136,7 @@ TEST_F(DeclarableOpsTests3, Test_Unique_2) { auto expC= NDArrayFactory::create('c', {3}, {2, 2, 1}); nd4j::ops::unique_with_counts op; - auto result = op.execute({&x}, {}, {}); + auto result = op.evaluate({&x}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(3, result->size()); @@ -169,7 +169,7 @@ TEST_F(DeclarableOpsTests3, Test_Rint_1) { auto exp= NDArrayFactory::create('c', {1, 7}, {-2.f, -2.f, -0.f, 0.f, 2.f, 2.f, 2.f}); nd4j::ops::rint op; - auto result = op.execute({&x}, {}, {}); + auto result = op.evaluate({&x}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -188,7 +188,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_1) { std::vector dims({1}); nd4j::ops::norm op; - auto result0 = op.execute({&x}, {0.}, {}); + auto result0 
= op.evaluate({&x}, {0.}, {}); auto z0 = result0->at(0); auto exp0 = x.reduceAlongDimension(reduce::NormFrobenius, empty, false, false); @@ -197,7 +197,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_1) { delete result0; - auto result1 = op.execute({&x}, {1.}, {1}); + auto result1 = op.evaluate({&x}, {1.}, {1}); ASSERT_EQ(result1->status(), ND4J_STATUS_OK); auto z1 = result1->at(0); // z1->printIndexedBuffer("Z1"); @@ -210,7 +210,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_1) { delete result1; - auto result4 = op.execute({&x}, {4.}, {1}); + auto result4 = op.evaluate({&x}, {4.}, {1}); auto z4 = result4->at(0); auto exp4= x.reduceAlongDimension(reduce::NormMax, dims, false, false); @@ -230,7 +230,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) { std::vector dims({1}); nd4j::ops::norm op; - auto result0 = op.execute({&x}, {0}, {}); + auto result0 = op.evaluate({&x}, {0}, {}); auto z0 = result0->at(0); auto exp0 = x.reduceAlongDimension(reduce::NormFrobenius, empty, false, false); @@ -239,7 +239,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) { delete result0; - auto result1 = op.execute({&x, &axis}, {1}, {}); + auto result1 = op.evaluate({&x, &axis}, {1}, {}); auto z1 = result1->at(0); auto exp1 = x.reduceAlongDimension(reduce::Norm2, dims, false, false); @@ -248,7 +248,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) { delete result1; - auto result4 = op.execute({&x, &axis}, {4}, {}); + auto result4 = op.evaluate({&x, &axis}, {4}, {}); auto z4 = result4->at(0); auto exp4= x.reduceAlongDimension(reduce::NormMax, dims, false, false); @@ -264,7 +264,7 @@ TEST_F(DeclarableOpsTests3, Test_ClipByAvgNorm_1) { auto exp = NDArrayFactory::create('c', {2, 3}, {-2.88, 0.0, 0.0, 3.84, 0.0, 0.0}); nd4j::ops::clipbyavgnorm op; - auto result = op.execute({&x}, {0.8}, {}, {}, false, nd4j::DataType::DOUBLE); + auto result = op.evaluate({&x}, {0.8}, {}); auto z = result->at(0); @@ -279,7 +279,7 @@ TEST_F(DeclarableOpsTests3, Test_ClipByAvgNorm_2) { auto exp= NDArrayFactory::create('c', {2, 3}, {-3.f, 0.0f, 0.0f, 4.f, 0.0f, 0.0f}); nd4j::ops::clipbyavgnorm op; - auto result = op.execute({&x}, {0.9}, {}); + auto result = op.evaluate({&x}, {0.9}, {}); auto z = result->at(0); @@ -295,7 +295,7 @@ TEST_F(DeclarableOpsTests3, Test_ClipByNorm_1) { auto exp= NDArrayFactory::create('c', {2, 3}, {-2.4, 0.0, 0.0, 3.2, 0.0, 0.0}); nd4j::ops::clipbynorm op; - auto result = op.execute({&x}, {4.0}, {}); + auto result = op.evaluate({&x}, {4.0}, {}); auto z = result->at(0); @@ -310,7 +310,7 @@ TEST_F(DeclarableOpsTests3, Test_ClipByNorm_2) { auto exp= NDArrayFactory::create('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f}); nd4j::ops::clipbynorm op; - auto result = op.execute({&x}, {6.0}, {}); + auto result = op.evaluate({&x}, {6.0}, {}); auto z = result->at(0); @@ -340,7 +340,7 @@ TEST_F(DeclarableOpsTests3, Test_ClipByNorm_3) { xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true); nd4j::ops::clipbynorm op; - auto result = op.execute({&x}, {1.0}, {1}, {}, false, nd4j::DataType::DOUBLE); + auto result = op.evaluate({&x}, {1.0}, {1}); auto z = result->at(0); auto zNorm1 = z->reduceAlongDimension(reduce::Norm2, {1}, true); @@ -360,7 +360,7 @@ TEST_F(DeclarableOpsTests3, Test_ListDiff_1) { auto exp1= NDArrayFactory::create('c', {3}, {1, 3, 5}); nd4j::ops::listdiff op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}); ASSERT_EQ(Status::OK(), result->status()); @@ -386,7 +386,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_1) { auto exp= NDArrayFactory::create('c', {17}, { 0.3f, -0.03f, -0.36f, -0.69f, 
-1.02f, -1.35f, -1.68f, -2.01f, -2.34f, -2.67f, -3.f, -3.33f, -3.66f, -3.99f, -4.32f, -4.65f, -4.98f}); nd4j::ops::range op; - auto result = op.execute({&start, &stop, &step}, {}, {}); + auto result = op.evaluate({&start, &stop, &step}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -406,7 +406,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_2) { auto exp= NDArrayFactory::create('c', {2}, {2.f, 1.f}); nd4j::ops::range op; - auto result = op.execute({&start, &stop, &step}, {}, {}); + auto result = op.evaluate({&start, &stop, &step}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -425,7 +425,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_3) { auto exp= NDArrayFactory::create('c', {2}, {0.f, 1.f}); nd4j::ops::range op; - auto result = op.execute({&start, &stop, &step}, {}, {}); + auto result = op.evaluate({&start, &stop, &step}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -442,7 +442,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_4) { auto exp= NDArrayFactory::create('c', {13}, {-10.f, -8.334f, -6.668f, -5.002f, -3.336f, -1.67f, -0.004f, 1.662f, 3.328f, 4.994f, 6.66f, 8.326f, 9.992f}); nd4j::ops::range op; - auto result = op.execute({}, {-10., 10., 1.666}, {}); + auto result = op.evaluate({}, {-10., 10., 1.666}, {}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -459,7 +459,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_5) { auto exp= NDArrayFactory::create('c', {2}, {2.f, 1.f}); nd4j::ops::range op; - auto result = op.execute({}, {2, 0, -1}, {}); + auto result = op.evaluate({}, {2, 0, -1}, {}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -475,7 +475,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_6) { auto exp= NDArrayFactory::create('c', {2}, {0.f, 1.f}); nd4j::ops::range op; - auto result = op.execute({}, {0, 2, 1}, {}); + auto result = op.evaluate({}, {0, 2, 1}, {}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -491,7 +491,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_7) { auto exp= NDArrayFactory::create('c', {10}, {10.f, 8.334f, 6.668f, 5.002f, 3.336f, 1.67f, 0.004f, -1.662f, -3.328f, -4.994f}); nd4j::ops::range op; - auto result = op.execute({}, {10,-5,-1.666}, {}); + auto result = op.evaluate({}, {10,-5,-1.666}, {}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -509,7 +509,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_8) { auto exp= NDArrayFactory::create('c', {2}, {2, 1}); nd4j::ops::range op; - auto result = op.execute({}, {}, {2, 0, -1}); + auto result = op.evaluate({}, {}, {2, 0, -1}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -525,7 +525,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_9) { auto exp= NDArrayFactory::create('c', {2}, {0, 1}); nd4j::ops::range op; - auto result = op.execute({}, {}, {0, 2, 1}); + auto result = op.evaluate({}, {}, {0, 2, 1}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -546,7 +546,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_1) { auto exp = MmulHelper::mmul(&x, &y); nd4j::ops::batched_gemm op; - auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 3, 3, 3, 3, 3, 3, 3}); + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 3, 3, 3, 3, 3, 3, 3}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(3, result->size()); @@ -574,7 +574,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_2) { auto exp = MmulHelper::mmul(&x, &y); nd4j::ops::batched_gemm op; - auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 3, 3, 3, 3, 3, 3, 3}); + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 3, 3, 3, 3, 3, 3, 3}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); 
     ASSERT_EQ(3, result->size());
@@ -602,7 +602,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_3) {
     auto exp = MmulHelper::mmul(&x, &y);
     nd4j::ops::batched_gemm op;
-    auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 111, 3, 3, 3, 3, 3, 3, 3});
+    auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 111, 3, 3, 3, 3, 3, 3, 3});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(3, result->size());
@@ -630,7 +630,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_4) {
     auto exp = MmulHelper::mmul(&x, &y);
     nd4j::ops::batched_gemm op;
-    auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 5, 4, 3, 5, 3, 5, 3});
+    auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 5, 4, 3, 5, 3, 5, 3});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(3, result->size());
@@ -658,7 +658,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_5) {
     auto exp = MmulHelper::mmul(&x, &y);
     nd4j::ops::batched_gemm op;
-    auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 5, 4, 3, 3, 4, 5, 3});
+    auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 5, 4, 3, 3, 4, 5, 3});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(3, result->size());
@@ -687,7 +687,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_6) {
     auto exp = MmulHelper::mmul(&x, &y);
     nd4j::ops::batched_gemm op;
-    auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 2, 3, 5, 2, 5, 2, 3});
+    auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 2, 3, 5, 2, 5, 2, 3});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(3, result->size());
@@ -717,7 +717,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) {
     // exp->printShapeInfo("exp shape");
     nd4j::ops::batched_gemm op;
-    auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3});
+    auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(3, result->size());
@@ -744,7 +744,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_1) {
     nd4j::ops::batched_gemm op;
     try {
-        auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3});
+        auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3});
         delete result;
         ASSERT_TRUE(false);
     } catch (std::invalid_argument &e) {
@@ -775,7 +775,7 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_1) {
     auto exp= NDArrayFactory::create('f', {4, 4}, {38.0, 44.0, 50.0, 56.0, 83.0, 98.0, 113.0, 128.0, 128.0, 152.0, 176.0, 200.0, 173.0, 206.0, 239.0, 272.0});
     nd4j::ops::matmul op;
-    auto result = op.execute({&x, &y}, {}, {1, 1});
+    auto result = op.evaluate({&x, &y}, {}, {1, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -794,7 +794,7 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_2) {
     auto exp= NDArrayFactory::create('f', {3, 3}, {70.0, 158.0, 246.0, 80.0, 184.0, 288.0, 90.0, 210.0, 330.0});
     nd4j::ops::matmul op;
-    auto result = op.execute({&x, &y}, {}, {0, 0});
+    auto result = op.evaluate({&x, &y}, {}, {0, 0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -813,7 +813,7 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_3) {
     auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0});
     nd4j::ops::matmul op;
-    auto result = op.execute({&x, &y}, {}, {1, 0});
+    auto result = op.evaluate({&x, &y}, {}, {1, 0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -832,7 +832,7 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_4) {
     auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0});
     nd4j::ops::matmul op;
-    auto result = op.execute({&x, &y}, {}, {0, 1});
+    auto result = op.evaluate({&x, &y}, {}, {0, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -851,7 +851,7 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_5) {
     auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0});
     nd4j::ops::matmul op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -870,7 +870,7 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_6) {
     auto exp= NDArrayFactory::create('f', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16});
     nd4j::ops::matmul op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -889,7 +889,7 @@ TEST_F(DeclarableOpsTests3, Test_ReverseDivide_1) {
     auto exp= NDArrayFactory::create('c', {1, 3}, {2, 3, 4});
     nd4j::ops::reversedivide op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -920,7 +920,7 @@ TEST_F(DeclarableOpsTests3, sruCell_test1) {
     auto expCt= NDArrayFactory::create('c', {batchSize, inSize}, {2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f});
     nd4j::ops::sruCell op;
-    auto results = op.execute({&xt, &ct_1, &w, &b}, {}, {});
+    auto results = op.evaluate({&xt, &ct_1, &w, &b});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -956,7 +956,7 @@ TEST_F(DeclarableOpsTests3, sruCell_test2) {
     auto expCt= NDArrayFactory::create('c', {batchSize, inSize}, {2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f});
     nd4j::ops::sruCell op;
-    auto results = op.execute({&xt, &ct_1, &w, &b}, {}, {});
+    auto results = op.evaluate({&xt, &ct_1, &w, &b});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -991,7 +991,7 @@ TEST_F(DeclarableOpsTests3, sruCell_test3) {
     auto expCt= NDArrayFactory::create('c', {batchSize, inSize}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f});
     nd4j::ops::sruCell op;
-    auto results = op.execute({&xt, &ct_1, &w, &b}, {}, {});
+    auto results = op.evaluate({&xt, &ct_1, &w, &b});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1030,7 +1030,7 @@ TEST_F(DeclarableOpsTests3, gruCell_test1) {
     auto expHt = NDArrayFactory::create('c', {batchSize, numUnits}, {1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f});
     nd4j::ops::gruCell op;
-    auto results = op.execute({&xt, &ht_1, &Wru, &Wc, &bru, &bc}, {}, {});
+    auto results = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1066,7 +1066,7 @@ TEST_F(DeclarableOpsTests3, gruCell_test2) {
     auto expHt= NDArrayFactory::create('c', {batchSize, numUnits}, {0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f});
     nd4j::ops::gruCell op;
-    auto results = op.execute({&xt, &ht_1, &Wru, &Wc, &bru, &bc}, {}, {});
+    auto results = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1102,7 +1102,7 @@ TEST_F(DeclarableOpsTests3, gruCell_test3) {
     auto expHt= NDArrayFactory::create('c', {batchSize, numUnits}, {0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f});
     nd4j::ops::gruCell op;
-    auto results = op.execute({&xt, &ht_1, &Wru, &Wc, &bru, &bc}, {}, {});
+    auto results = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1121,7 +1121,7 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test1) {
     auto expected= NDArrayFactory::create('c', {1, 8}, {7, 6, 1, 5, 3, 0, 4, 2});
     nd4j::ops::invert_permutation op;
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1140,7 +1140,7 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test2) {
     auto expected= NDArrayFactory::create('c', {1, 8}, {7, 6, 1, 5, 3, 0, 4, 2});
     nd4j::ops::invert_permutation op;
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1159,7 +1159,7 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test3) {
     auto expected= NDArrayFactory::create('c', {1, 8}, {2, 0, 1, 5, 3, 6, 4, 7});
     nd4j::ops::invert_permutation op;
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1180,7 +1180,7 @@ TEST_F(DeclarableOpsTests3, diag_test1) {
     auto expected= NDArrayFactory::create('c', {3,2,3,2}, {1,0,0,0,0,0, 0,2,0,0,0,0, 0,0,3,0,0,0, 0,0,0,4,0,0, 0,0,0,0,5,0, 0,0,0,0,0,6});
     nd4j::ops::diag op;
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1201,7 +1201,7 @@ TEST_F(DeclarableOpsTests3, diag_test2) {
     auto expected= NDArrayFactory::create('c', {2,3,2,3}, {1,0,0,0,0,0, 0,2,0,0,0,0, 0,0,3,0,0,0, 0,0,0,4,0,0, 0,0,0,0,5,0, 0,0,0,0,0,6});
     nd4j::ops::diag op;
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1222,7 +1222,7 @@ TEST_F(DeclarableOpsTests3, diag_test_vector) {
     auto expected= NDArrayFactory::create('c', {4,4}, {1,0,0,0, 0,2,0,0, 0,0,3,0,0,0,0,4});
     nd4j::ops::diag op;
-    auto results = op.execute({input}, {}, {});
+    auto results = op.evaluate({input});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1246,7 +1246,7 @@ TEST_F(DeclarableOpsTests3, diag_test_col_vector) {
     auto expected= NDArrayFactory::create('c', {4,4}, {1,0,0,0, 0,2,0,0, 0,0,3,0,0,0,0,4});
     nd4j::ops::diag op;
-    auto results = op.execute({input}, {}, {});
+    auto results = op.evaluate({input}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1267,7 +1267,7 @@ TEST_F(DeclarableOpsTests3, diag_test3) {
     auto expected= NDArrayFactory::create('c', {3,3}, {1,0,0, 0,2,0, 0,0,3});
     nd4j::ops::diag op;
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1288,7 +1288,7 @@ TEST_F(DeclarableOpsTests3, diag_test4) {
     auto expected= NDArrayFactory::create('c', {3,3}, {1,0,0, 0,2,0, 0,0,3});
     nd4j::ops::diag op;
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1309,7 +1309,7 @@ TEST_F(DeclarableOpsTests3, diag_test5) {
     auto expected= NDArrayFactory::create('c', {1,1}, {2});
     nd4j::ops::diag op;
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1330,7 +1330,7 @@ TEST_F(DeclarableOpsTests3, diag_test6) {
     auto expected= NDArrayFactory::create('c', {2,2,2,2,2,2}, {1,0,0,0, 0,0,0,0, 0,2,0,0, 0,0,0,0, 0,0,3,0, 0,0,0,0, 0,0,0,4, 0,0,0,0, 0,0,0,0, 5,0,0,0, 0,0,0,0, 0,6,0,0, 0,0,0,0, 0,0,7,0, 0,0,0,0, 0,0,0,8});
     nd4j::ops::diag op;
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1353,7 +1353,7 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test1) {
     auto expected= NDArrayFactory::create('c', {4,3,2}, {1,0,0,1,0,0, 1,0,0,1,0,0, 1,0,0,1,0,0, 1,0,0,1,0,0});
     nd4j::ops::matrix_set_diag op;
-    auto results = op.execute({&input, &diagonal}, {}, {});
+    auto results = op.evaluate({&input, &diagonal}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1376,7 +1376,7 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test2) {
     auto expected= NDArrayFactory::create('c', {1,1,2}, {1.f, 0.f});
     nd4j::ops::matrix_set_diag op;
-    auto results = op.execute({&input, &diagonal}, {}, {});
+    auto results = op.evaluate({&input, &diagonal}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1399,7 +1399,7 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test3) {
     auto expected= NDArrayFactory::create('c', {2,1,4}, {1,0,0,0,1,0,0,0});
     nd4j::ops::matrix_set_diag op;
-    auto results = op.execute({&input, &diagonal}, {}, {});
+    auto results = op.evaluate({&input, &diagonal}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1422,7 +1422,7 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test4) {
     auto expected= NDArrayFactory::create('c', {2,1,4,1}, {1,0,0,0,1,0,0,0});
     nd4j::ops::matrix_set_diag op;
-    auto results = op.execute({&input, &diagonal}, {}, {});
+    auto results = op.evaluate({&input, &diagonal}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1443,7 +1443,7 @@ TEST_F(DeclarableOpsTests3, diagPart_test1) {
     auto expected= NDArrayFactory::create('c', {2}, {1,4});
     nd4j::ops::diag_part op;
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1465,7 +1465,7 @@ TEST_F(DeclarableOpsTests3, diagPart_test2) {
     auto expected= NDArrayFactory::create('c', {2,2}, {1,6,11,16});
     nd4j::ops::diag_part op;
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1486,7 +1486,7 @@ TEST_F(DeclarableOpsTests3, diagPart_test3) {
     auto expected= NDArrayFactory::create('c', {2,2,2}, {1,10,19,28,37,46,55,64});
     nd4j::ops::diag_part op;
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1512,7 +1512,7 @@ TEST_F(DeclarableOpsTests3, betainc_test1) {
     auto expected = NDArrayFactory::create('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f});
     nd4j::ops::betainc op;
-    auto results = op.execute({&a, &b, &x}, {}, {});
+    auto results = op.evaluate({&a, &b, &x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1538,7 +1538,7 @@ TEST_F(DeclarableOpsTests3, betainc_test2) {
     auto expected= NDArrayFactory::create('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f});
     nd4j::ops::betainc op;
-    auto results = op.execute({&a, &b, &x}, {}, {});
+    auto results = op.evaluate({&a, &b, &x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1564,7 +1564,7 @@ TEST_F(DeclarableOpsTests3, betainc_test3) {
     auto expected= NDArrayFactory::create('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f});
     nd4j::ops::betainc op;
-    auto results = op.execute({&a, &b, &x}, {}, {});
+    auto results = op.evaluate({&a, &b, &x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1590,7 +1590,7 @@ TEST_F(DeclarableOpsTests3, betainc_test4) {
     auto expected= NDArrayFactory::create('c', {3,3}, {1.00000000e-01f, 2.80000000e-02f, 8.56000000e-03f, 2.72800000e-03f, 8.90920000e-04f, 2.95706080e-04f, 9.92854864e-05f, 3.36248880e-05f, 1.14644360e-05f});
     nd4j::ops::betainc op;
-    auto results = op.execute({&a, &b, &x}, {}, {});
+    auto results = op.evaluate({&a, &b, &x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1616,7 +1616,7 @@ TEST_F(DeclarableOpsTests3, betainc_test5) {
     auto expected= NDArrayFactory::create('c', {3,3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
     nd4j::ops::betainc op;
-    auto results = op.execute({&a, &b, &x}, {}, {});
+    auto results = op.evaluate({&a, &b, &x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1642,7 +1642,7 @@ TEST_F(DeclarableOpsTests3, betainc_test6) {
     auto expected= NDArrayFactory::create('c', {3,3}, {3.92988233e-06f, 1.35306497e-06f, 4.67576826e-07f, 1.62083416e-07f, 5.63356971e-08f, 1.96261318e-08f, 6.85120307e-09f, 2.39594668e-09f, 8.39227685e-10f});
     nd4j::ops::betainc op;
-    auto results = op.execute({&a, &b, &x}, {}, {});
+    auto results = op.evaluate({&a, &b, &x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1668,7 +1668,7 @@ TEST_F(DeclarableOpsTests3, betainc_test7) {
     auto expected= NDArrayFactory::create('c', {3,3}, {0.99999607f, 0.99999865f, 0.99999953f, 0.99999984f, 0.99999994f, 0.99999998f, 0.99999999f, 1.f, 1.f});
     nd4j::ops::betainc op;
-    auto results = op.execute({&a, &b, &x}, {}, {});
+    auto results = op.evaluate({&a, &b, &x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1694,7 +1694,7 @@ TEST_F(DeclarableOpsTests3, betainc_test8) {
     auto expected= NDArrayFactory::create('c', {3,3}, {1.f, 1.f, 1.f,1.f,1.f,1.f,1.f,1.f,1.f});
     nd4j::ops::betainc op;
-    auto results = op.execute({&a, &b, &x}, {}, {});
+    auto results = op.evaluate({&a, &b, &x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1720,7 +1720,7 @@ TEST_F(DeclarableOpsTests3, betainc_test9) {
     auto expected= NDArrayFactory::create('c', {3,3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
     nd4j::ops::betainc op;
-    auto results = op.execute({&a, &b, &x}, {}, {});
+    auto results = op.evaluate({&a, &b, &x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1746,7 +1746,7 @@ TEST_F(DeclarableOpsTests3, betainc_test10) {
     auto expected= NDArrayFactory::create('c', {3,3}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f});
     nd4j::ops::betainc op;
-    auto results = op.execute({&a, &b, &x}, {}, {});
+    auto results = op.evaluate({&a, &b, &x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1767,7 +1767,7 @@ TEST_F(DeclarableOpsTests3, betainc_test11) {
     NDArray expected('c', {4}, {0.912156, 0.634460, 0.898314, 0.624538}, nd4j::DataType::FLOAT32);
     nd4j::ops::betainc op;
-    auto results = op.execute({&a, &b, &x}, {}, {});
+    auto results = op.evaluate({&a, &b, &x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1789,7 +1789,7 @@ TEST_F(DeclarableOpsTests3, betainc_test12) {
     NDArray expected('c', {4}, {0.9999995 , 0.8594694 , 0.999988 , 0.49124345}, nd4j::DataType::FLOAT32);
     nd4j::ops::betainc op;
-    auto results = op.execute({&a, &b, &x}, {}, {});
+    auto results = op.evaluate({&a, &b, &x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1813,7 +1813,7 @@ TEST_F(DeclarableOpsTests3, zeta_test1) {
     auto expected= NDArrayFactory::create('c', {3,3}, {1.64493407f, 0.64493407f, 0.39493407f, 0.28382296f, 0.22132296f, 0.18132296f, 0.15354518f, 0.13313701f, 0.11751201f});
     nd4j::ops::zeta op;
-    auto results = op.execute({&x, &q}, {}, {});
+    auto results = op.evaluate({&x, &q}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1837,7 +1837,7 @@ TEST_F(DeclarableOpsTests3, zeta_test2) {
     auto expected= NDArrayFactory::create('c', {3,3}, {0.10516634f, 0.09516634f, 0.08690187f, 0.07995743f, 0.07404027f, 0.06893823f, 0.06449378f, 0.06058753f, 0.05712733f});
     nd4j::ops::zeta op;
-    auto results = op.execute({&x, &q}, {}, {});
+    auto results = op.evaluate({&x, &q}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1861,7 +1861,7 @@ TEST_F(DeclarableOpsTests3, zeta_test3) {
     auto expected= NDArrayFactory::create('c', {3,3}, {0.01005017f, 0.00995017f, 0.00985214f, 0.00975602f, 0.00966176f, 0.0095693f, 0.0094786f, 0.0093896f, 0.00930226f});
     nd4j::ops::zeta op;
-    auto results = op.execute({&x, &q}, {}, {});
+    auto results = op.evaluate({&x, &q}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1886,7 +1886,7 @@ TEST_F(DeclarableOpsTests3, zeta_test4) {
     auto expected= NDArrayFactory::create('c', {3,3}, {0.01005017f, 0.00995017f, 0.00985214f, 0.00975602f, 0.00966176f, 0.0095693f, 0.0094786f, 0.0093896f, 0.00930226f});
     nd4j::ops::zeta op;
-    auto results = op.execute({&x, &q}, {}, {});
+    auto results = op.evaluate({&x, &q}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1910,7 +1910,7 @@ TEST_F(DeclarableOpsTests3, zeta_test5) {
     auto expected= NDArrayFactory::create('c', {3,3}, {10.58444846f, 9.58444846f, 9.11793197f, 8.81927915f, 8.60164151f, 8.43137352f, 8.29204706f, 8.17445116f, 8.07291961f});
     nd4j::ops::zeta op;
-    auto results = op.execute({&x, &q}, {}, {});
+    auto results = op.evaluate({&x, &q}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1934,7 +1934,7 @@ TEST_F(DeclarableOpsTests3, zeta_test6) {
     auto expected= NDArrayFactory::create('c', {3,3}, {100.57794334f, 99.57794334f, 99.08139709f, 98.75170576f, 98.50514758f, 98.30834069f, 98.1446337f, 98.00452955f, 97.88210202f});
     nd4j::ops::zeta op;
-    auto results = op.execute({&x, &q}, {}, {});
+    auto results = op.evaluate({&x, &q}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1958,7 +1958,7 @@ TEST_F(DeclarableOpsTests3, zeta_test7) {
     auto expected= NDArrayFactory::create('c', {3,3}, {1.00099458e+00f, 9.94575128e-04f, 1.80126278e-05f, 1.07754001e-06f, 1.23865693e-07f, 2.14656932e-08f, 4.92752156e-09f, 1.38738839e-09f, 4.56065812e-10f});
     nd4j::ops::zeta op;
-    auto results = op.execute({&x, &q}, {}, {});
+    auto results = op.evaluate({&x, &q}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1982,7 +1982,7 @@ TEST_F(DeclarableOpsTests3, zeta_test8) {
     auto expected= NDArrayFactory::create('c', {3,4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398});
     nd4j::ops::zeta op;
-    auto results = op.execute({&x, &q}, {}, {});
+    auto results = op.evaluate({&x, &q}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2054,7 +2054,7 @@ TEST_F(DeclarableOpsTests3, Test_SplitV_Validation_1) {
     auto z1 = NDArrayFactory::create('c', {3, 7});
     nd4j::ops::split_v op;
-    auto status = op.execute({&x, &indices, &axis}, {&z0, &z1}, {}, {}, {});
+    auto status = op.execute({&x, &indices, &axis}, std::vector{&z0, &z1}, {}, {}, {});
     ASSERT_EQ(Status::OK(), status);
 }
@@ -2070,7 +2070,7 @@ TEST_F(DeclarableOpsTests3, polygamma_test1) {
     auto expected= NDArrayFactory::create('c', {3,3}, {4.934802, -16.828796, 97.409088, -771.474243, 7691.113770, -92203.460938, 1290440.250000, -20644900.000000, 3.71595e+08});
     nd4j::ops::polygamma op;
-    auto results = op.execute({&n, &x}, {}, {});
+    auto results = op.evaluate({&n, &x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2097,7 +2097,7 @@ TEST_F(DeclarableOpsTests3, polygamma_test2) {
     //ASSERT_FALSE(true);
     nd4j::ops::polygamma op;
-    auto results = op.execute({&n, &x}, {}, {});
+    auto results = op.evaluate({&n, &x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2120,7 +2120,7 @@ TEST_F(DeclarableOpsTests3, polygamma_test3) {
     auto expected= NDArrayFactory::create('c', {3,3}, {1.05166336e-01,-9.04983497e-03, 1.31009323e-03,-2.44459433e-04, 5.31593880e-05,-1.28049888e-05, 3.31755364e-06,-9.07408791e-07, 2.58758130e-07});
     nd4j::ops::polygamma op;
-    auto results = op.execute({&n, &x}, {}, {});
+    auto results = op.evaluate({&n, &x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2141,7 +2141,7 @@ TEST_F(DeclarableOpsTests3, polygamma_test4) {
     1.596005e+03, -4.876665e+03,4.510025e+04, -1.730340e+08, 6.110257e+05, -1.907087e+07}, nd4j::DataType::DOUBLE);
     nd4j::ops::polygamma op;
-    auto results = op.execute({&n, &x}, {}, {});
+    auto results = op.evaluate({&n, &x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2161,7 +2161,7 @@ TEST_F(DeclarableOpsTests3, digamma_1) {
     std::numeric_limits::infinity(),-5.28904,-0.577216, 0.03649, 0.544293, 1.549434,2.917892, 3.020524, 3.077401}, nd4j::DataType::DOUBLE);
     nd4j::ops::digamma op;
-    auto results = op.execute({&x}, {}, {});
+    auto results = op.evaluate({&x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2182,7 +2182,7 @@ TEST_F(DeclarableOpsTests3, svd_test1) {
     auto expV= NDArrayFactory::create('c', {6,6}, {-0.24577,-0.24512, 0.00401,-0.04585,-0.62058, 0.70162, 0.27937, 0.75961, 0.43885,-0.06857,-0.3839 , 0.01669,-0.35944,-0.09629, 0.44593, 0.78602,-0.09103,-0.19125, 0.53973, 0.07613,-0.10721, 0.49559, 0.35687, 0.56431,-0.6226 , 0.39742, 0.12785,-0.15716, 0.52372, 0.37297, 0.23113,-0.43578, 0.76204,-0.32414, 0.23996, 0.11543});
     nd4j::ops::svd op;
-    auto results = op.execute({&x}, {}, {1, 1, 16});
+    auto results = op.evaluate({&x}, {}, {1, 1, 16});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2219,7 +2219,7 @@ TEST_F(DeclarableOpsTests3, svd_test2) {
     auto expV= NDArrayFactory::create('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187});
     nd4j::ops::svd op;
-    auto results = op.execute({&x}, {}, {1, 1, 16});
+    auto results = op.evaluate({&x}, {}, {1, 1, 16});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2256,7 +2256,7 @@ TEST_F(DeclarableOpsTests3, svd_test3) {
     auto expV= NDArrayFactory::create('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187});
     nd4j::ops::svd op;
-    auto results = op.execute({&x}, {}, {0, 1, 16});
+    auto results = op.evaluate({&x}, {}, {0, 1, 16});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2293,7 +2293,7 @@ TEST_F(DeclarableOpsTests3, svd_test4) {
     auto expV= NDArrayFactory::create('c', {7,7}, {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, -0.16709, 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, -0.06862, 0.32179, 0.12812, -0.25812, 0.0691 , -0.12891, 0.26979, 0.84807,-0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.36692, 0.48118, 0.15876, -0.65132, -0.24602, 0.3963 , -0.16651, -0.27155,-0.31605, -0.46947, -0.50195, 0.0378 , -0.34937, -0.53062, 0.15069, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151 , 0.13065});
     nd4j::ops::svd op;
-    auto results = op.execute({&x}, {}, {1, 1, 16});
+    auto results = op.evaluate({&x}, {}, {1, 1, 16});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2330,7 +2330,7 @@ TEST_F(DeclarableOpsTests3, svd_test5) {
     auto expV= NDArrayFactory::create('c', {7,6}, {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, 0.32179, 0.12812, -0.25812, 0.0691 , -0.12891, 0.26979,-0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.48118, 0.15876, -0.65132, -0.24602, 0.3963 , -0.16651,-0.31605, -0.46947, -0.50195, 0.0378 , -0.34937, -0.53062, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151});
     nd4j::ops::svd op;
-    auto results = op.execute({&x}, {}, {0, 1, 16});
+    auto results = op.evaluate({&x}, {}, {0, 1, 16});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2385,7 +2385,7 @@ TEST_F(DeclarableOpsTests3, svd_test6) {
     -0.51827, -0.31837, -0.16732, 0.71378, -0.30425,-0.39314, 0.15266, 0.63693, -0.30945, -0.5663 ,-0.51981, 0.03325, 0.37603, 0.05147, 0.76462,-0.01282, 0.92491, -0.08042, 0.36977, -0.03428});
     nd4j::ops::svd op;
-    auto results = op.execute({&x}, {}, {1, 1, 16});
+    auto results = op.evaluate({&x}, {}, {1, 1, 16});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2423,7 +2423,7 @@ TEST_F(DeclarableOpsTests3, svd_test7) {
     39.34498, 32.54861, 17.52492, 7.03003, 2.2399,44.72126, 32.3164 , 16.60139, 6.88783, 0.78122});
     nd4j::ops::svd op;
-    auto results = op.execute({&x}, {}, {0, 0, 16});
+    auto results = op.evaluate({&x}, {}, {0, 0, 16});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2622,7 +2622,7 @@ TEST_F(DeclarableOpsTests3, svd_test9) {
     1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01,-4.39400000e-02, 2.17750000e-01,-6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01,-4.63400000e-01, -1.74620000e-01});
     nd4j::ops::svd op;
-    auto results = op.execute({&x}, {}, {1, 1, 16});
+    auto results = op.evaluate({&x}, {}, {1, 1, 16});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2681,7 +2681,7 @@ TEST_F(DeclarableOpsTests3, svd_test10) {
     -4.39400000e-02,-6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01,-4.63400000e-01});
     nd4j::ops::svd op;
-    auto results = op.execute({&x}, {}, {0, 1, 16});
+    auto results = op.evaluate({&x}, {}, {0, 1, 16});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2726,7 +2726,7 @@ TEST_F(DeclarableOpsTests3, svd_test11) {
     -0.43596, 0.83108, -0.34531});
     nd4j::ops::svd op;
-    auto results = op.execute({&x}, {}, {0, 1, 16});
+    auto results = op.evaluate({&x}, {}, {0, 1, 16});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2761,7 +2761,7 @@ TEST_F(DeclarableOpsTests3, svd_test12) {
     NDArray expS('c', {3}, {3.024703, 1.459483, 1.026371});
     nd4j::ops::svd op;
-    auto results = op.execute({&x}, {}, {1, 0, 16});
+    auto results = op.evaluate({&x}, {}, {1, 0, 16});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2780,7 +2780,7 @@ TEST_F(DeclarableOpsTests3, elu_test1) {
     auto exp = NDArrayFactory::create('c', {3,3}, {.1, .2, .3, 0.5*-0.32968, 0.5*-0.393469, 0.5*-0.451188, .7, .8, .9});
     nd4j::ops::elu op;
-    auto results = op.execute({&x}, {0.5}, {});
+    auto results = op.evaluate({&x}, {0.5}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2799,7 +2799,7 @@ TEST_F(DeclarableOpsTests3, elu_bp_test1) {
     auto exp = NDArrayFactory::create('c', {3, 3}, {2, 2, 2, 0.5*1.34064, 0.5*1.213061, 0.5*1.097623, 2, 2, 2});
     nd4j::ops::elu_bp op;
-    auto results = op.execute({ &x, &eps }, {0.5}, {});
+    auto results = op.evaluate({ &x, &eps }, {0.5}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2816,7 +2816,7 @@ TEST_F(DeclarableOpsTests3, lrelu_test1) {
     auto exp = NDArrayFactory::create('c', {3,3}, {1, 2, 3, -0.8, -1., -1.2, 7, 8, 9});
     nd4j::ops::lrelu op;
-    auto results = op.execute({&x}, {0.2}, {});
+    auto results = op.evaluate({&x}, {0.2}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2833,7 +2833,7 @@ TEST_F(DeclarableOpsTests3, lrelu_bp_test1) {
     auto exp = NDArrayFactory::create('c', {3,3}, {2, 2, 2, 0.4, 0.4, 0.4, 2, 2, 2});
     nd4j::ops::lrelu_bp op;
-    auto results = op.execute({&x, &eps}, {0.2}, {});
+    auto results = op.evaluate({&x, &eps}, {0.2}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2850,7 +2850,7 @@ TEST_F(DeclarableOpsTests3, selu_test1) {
     auto exp = NDArrayFactory::create('c', {3,3}, {1.050701, 2.101402, 3.152103, -1.725899, -1.746253, -1.753742, 7.354907, 8.405608, 9.456309});
     nd4j::ops::selu op;
-    auto results = op.execute({&x}, {}, {});
+    auto results = op.evaluate({&x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2868,7 +2868,7 @@ TEST_F(DeclarableOpsTests3, selu_test2) {
     auto exp = NDArrayFactory::create('c', {3,3}, {2.101401, 2.101402, 2.101402, 0.064401, 0.023692, 0.008716, 2.101402, 2.101402, 2.101402});
     nd4j::ops::selu_bp op;
-    auto results = op.execute({&x, &eps}, {0.2}, {});
+    auto results = op.evaluate({&x, &eps}, {0.2}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2888,7 +2888,7 @@ TEST_F(DeclarableOpsTests3, EQScalarTests_1) {
     auto scalar = NDArrayFactory::create(1.0f);
     nd4j::ops::eq_scalar op;
-    auto res = op.evaluate({&x, &scalar});
+    auto res = op.verify({&x, &scalar});
     ASSERT_TRUE(res);
 }
@@ -2900,7 +2900,7 @@ TEST_F(DeclarableOpsTests3, EQScalarTests_2) {
     auto scalar = NDArrayFactory::create(1.0f);
     nd4j::ops::eq_scalar op;
-    auto res = op.evaluate({&x, &scalar});
+    auto res = op.verify({&x, &scalar});
     ASSERT_FALSE(res);
 }
@@ -2911,7 +2911,7 @@ TEST_F(DeclarableOpsTests3, GTScalarTests_1) {
     auto scalar = NDArrayFactory::create(1.0f);
     nd4j::ops::gt_scalar op;
-    auto res = op.evaluate({&x, &scalar});
+    auto res = op.verify({&x, &scalar});
     ASSERT_FALSE(res);
 }
@@ -2922,7 +2922,7 @@ TEST_F(DeclarableOpsTests3, GTScalarTests_2) {
     auto scalar = NDArrayFactory::create(1.0f);
     nd4j::ops::gt_scalar op;
-    auto res = op.evaluate({&x, &scalar});
+    auto res = op.verify({&x, &scalar});
     ASSERT_TRUE(res);
 }
@@ -2933,7 +2933,7 @@ TEST_F(DeclarableOpsTests3, GTEScalarTests_1) {
     auto scalar = NDArrayFactory::create(1.0f);
     nd4j::ops::gte_scalar op;
-    auto res = op.evaluate({&x, &scalar});
+    auto res = op.verify({&x, &scalar});
ASSERT_TRUE(res); } @@ -2944,7 +2944,7 @@ TEST_F(DeclarableOpsTests3, GTEScalarTests_2) { auto scalar = NDArrayFactory::create(1.0f); nd4j::ops::gte_scalar op; - auto res = op.evaluate({&x, &scalar}); + auto res = op.verify({&x, &scalar}); ASSERT_TRUE(res); } @@ -2955,7 +2955,7 @@ TEST_F(DeclarableOpsTests3, GTEScalarTests_3) { auto scalar = NDArrayFactory::create(2.0f); nd4j::ops::gte_scalar op; - auto res = op.evaluate({&x, &scalar}); + auto res = op.verify({&x, &scalar}); ASSERT_FALSE(res); } @@ -2966,7 +2966,7 @@ TEST_F(DeclarableOpsTests3, LTEScalarTests_1) { auto scalar = NDArrayFactory::create(1.0f); nd4j::ops::lte_scalar op; - auto res = op.evaluate({&x, &scalar}); + auto res = op.verify({&x, &scalar}); ASSERT_TRUE(res); } @@ -2977,7 +2977,7 @@ TEST_F(DeclarableOpsTests3, LTEScalarTests_2) { auto scalar = NDArrayFactory::create(1.0f); nd4j::ops::lte_scalar op; - auto res = op.evaluate({&x, &scalar}); + auto res = op.verify({&x, &scalar}); ASSERT_FALSE(res); } @@ -2988,7 +2988,7 @@ TEST_F(DeclarableOpsTests3, LTEScalarTests_3) { auto scalar = NDArrayFactory::create(2.0f); nd4j::ops::lte_scalar op; - auto res = op.evaluate({&x, &scalar}); + auto res = op.verify({&x, &scalar}); ASSERT_TRUE(res); } @@ -2999,7 +2999,7 @@ TEST_F(DeclarableOpsTests3, NEQScalarTests_1) { auto scalar = NDArrayFactory::create(1.0f); nd4j::ops::neq_scalar op; - auto res = op.evaluate({&x, &scalar}); + auto res = op.verify({&x, &scalar}); ASSERT_FALSE(res); } @@ -3011,7 +3011,7 @@ TEST_F(DeclarableOpsTests3, NEQScalarTests_2) { auto scalar = NDArrayFactory::create(1.0f); nd4j::ops::neq_scalar op; - auto res = op.evaluate({&x, &scalar}); + auto res = op.verify({&x, &scalar}); ASSERT_TRUE(res); } @@ -3022,7 +3022,7 @@ TEST_F(DeclarableOpsTests3, NOOPTests_1) { auto scalar = NDArrayFactory::create(1.0f); nd4j::ops::noop op; - auto res = op.execute({&x, &scalar}, {}, {}); + auto res = op.evaluate({&x, &scalar}, {}, {}); ASSERT_TRUE(res->status() == nd4j::Status::OK()); delete res; } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp index 9460a053f..1e085d46c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp @@ -64,7 +64,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_1) { x.linspace(1); nd4j::ops::avgpool2d op; - auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 1, 1}); + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 1, 1}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -85,7 +85,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_2) { nd4j::ops::avgpool2d op; - auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1}); + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -106,7 +106,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_3) { nd4j::ops::avgpool2d op; - auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 0, 1}); + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 0, 1}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -126,7 +126,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_4) { x.linspace(1); nd4j::ops::avgpool2d op; - auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1}); + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -147,7 +147,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_5) { nd4j::ops::avgpool2d op; - 
auto result = op.execute({&x}, {}, {2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0}); + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -168,7 +168,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_6) { nd4j::ops::avgpool2d op; - auto result = op.execute({&x}, {}, {2, 2, 2, 2, 1, 1, 1, 1, 0, 1, 0}); + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 1, 1, 1, 1, 0, 1, 0}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -189,7 +189,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_7) { nd4j::ops::avgpool2d op; - auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 0, 0}); + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 0, 0}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -208,9 +208,8 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_8) { x.linspace(1); - nd4j::ops::avgpool2d op; - auto result = op.execute({&x}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 0}); + auto result = op.evaluate({&x}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 0}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -231,7 +230,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_9) { nd4j::ops::avgpool2d op; - auto result = op.execute({&x}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0, 0}); + auto result = op.evaluate({&x}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0, 0}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -263,7 +262,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_10) { auto exp = NDArrayFactory::create('c', {4, 4, 4, 3}, {7.97172260f, 0.06878620f, 2.27749538f, 7.29276514f, -0.14074677f, 0.65480286f, 5.70313978f, -0.06546132f, 0.35443667f, 3.70382833f, -0.84020567f, 0.63826996f, 8.60301399f, -0.38236514f, 1.55177069f, 7.37542057f, -0.99374938f, -0.29971302f, 8.84352493f, -0.67121059f, 0.43132120f, 4.78175592f, -1.25070143f, -1.91523600f, 6.03855371f, -0.00292124f, -1.11214364f, 7.90158176f, -0.57949901f, -0.96735370f, 7.81192017f, -0.53255427f, -0.48009714f, 3.16953635f, 0.08353355f, -1.54299748f, 3.74821687f, 1.69396687f, 0.72724354f, 5.42915201f, -1.13686812f, -0.71793109f, 5.78376389f, -0.72239977f, -0.60055625f, 2.53636408f, 0.56777251f, -2.07892323f, 6.08064651f, 0.68620735f, 2.54017019f, 5.65828180f, -0.68255502f, 1.47283304f, 6.10842514f, -0.39655915f, 0.28380761f, 1.96707797f, -1.98206317f, 0.94027776f, 4.71811438f, 0.32104525f, -0.92409706f, 8.34588146f, -1.05581069f, -0.55217457f, 9.58440876f, -0.96549922f, 0.45820439f, 5.65453672f, -2.50953507f, -0.71441835f, 8.03059578f, -0.21281289f, 0.92125505f, 9.26900673f, -0.35963219f, -0.70039093f, 8.59924412f, -1.22358346f, 0.81318003f, 3.85920119f, -0.01305223f, -1.09234154f, 6.33158875f, 1.28094780f, -1.48926139f, 4.94969177f, -0.77126902f, -1.97033751f, 5.64381838f, -0.16285487f, -1.31277227f, 2.39893222f, -1.32902908f, -1.39609122f, 6.47572327f, -0.45267010f, 1.55727172f, 6.70965624f, -1.68735468f, -0.05672536f, 7.25092363f, -0.64613032f, 0.67050058f, 3.60789680f, -2.05948973f, 2.22687531f, 8.15202713f, -0.70148355f, 1.28314006f, 8.14842319f, -1.88807654f, -1.04808438f, 8.45500565f, -0.76425624f, 0.94542569f, 4.56179953f, -0.28786001f, -2.04502511f, 8.46278095f, -0.31019822f, 0.07339200f, 9.34214592f, -0.61948007f, 0.52481830f, 8.32515621f, -1.52418160f, 0.49678251f, 5.11082315f, -1.09908783f, -0.52969611f, 5.27806664f, 0.88632923f, 0.66754371f, 4.75839233f, 0.48928693f, -0.68036932f, 6.56925392f, -0.02949905f, -2.99189186f, 4.46320581f, -0.64534980f, -0.29516968f, 8.60809517f, -1.13120568f, 3.41720533f, 5.84243155f, -1.24109328f, 0.89566326f, 5.99578333f, -0.42496428f, 2.07076764f, 
3.17812920f, -0.81566459f, -0.14363396f, 6.55184317f, 0.39633346f, -0.43852386f, 8.70214558f, -2.24613595f, 0.30708700f, 8.73882294f, -0.53545928f, 1.54409575f, 4.49452257f, -0.16509305f, 0.19028664f, 8.24897003f, 0.44750381f, 2.15448594f, 8.97640514f, -0.77728152f, 0.57272542f, 9.03467560f, 0.47173575f, -1.10807717f, 3.30056310f, -0.43268481f, -0.41470885f, 3.53798294f, -0.08546703f, -2.16840744f, 6.18733406f, -0.17871059f, -2.59837723f, 5.94218683f, -1.02990067f, -0.49760687f, 3.76938033f, 0.86383581f, -1.91504073f}); nd4j::ops::avgpool2d op; - auto result = op.execute({&input}, {}, {3,3, 3,3, 0,0, 1,1,1, 0,1}, {}); + auto result = op.evaluate({&input}, {3,3, 3,3, 0,0, 1,1,1, 0,1}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -287,7 +286,7 @@ TEST_F(DeclarableOpsTests4, avgpool2d_11) { x.linspace(1.0); nd4j::ops::avgpool2d op; - auto result = op.execute({&x}, {}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1}); + auto result = op.evaluate({&x}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -364,7 +363,7 @@ TEST_F(DeclarableOpsTests4, avgpool2d_12) { input.syncToDevice(); nd4j::ops::avgpool2d op; - auto results = op.execute({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat}); + auto results = op.evaluate({&input}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat}); auto output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -385,7 +384,7 @@ TEST_F(DeclarableOpsTests4, biasadd_1) { auto exp = NDArrayFactory::create('c', {2, 3, 3, 2}, {1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f}); nd4j::ops::biasadd op; - auto result = op.execute({&x, &bias}, {}, {}, {}, false, nd4j::DataType::DOUBLE); + auto result = op.evaluate({&x, &bias}, {}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -403,7 +402,7 @@ TEST_F(DeclarableOpsTests4, biasadd_2) { auto exp = NDArrayFactory::create('c', {2, 2, 3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2}); nd4j::ops::biasadd op; - auto result = op.execute({&x, &bias}, {}, {}, {true}, false, nd4j::DataType::DOUBLE); + auto result = op.evaluate({&x, &bias}, {}, {}, {true}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -421,7 +420,7 @@ TEST_F(DeclarableOpsTests4, biasadd_3) { auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); nd4j::ops::biasadd op; - auto result = op.execute({&x, &row}, {}, {}, {true}, false, nd4j::DataType::DOUBLE); + auto result = op.evaluate({&x, &row}, {}, {}, {true}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -444,7 +443,7 @@ TEST_F(DeclarableOpsTests4, biasadd_bp_1) { gradO.linspace(0.1, 0.1); nd4j::ops::biasadd_bp op; - auto result = op.execute({&x, &bias, &gradO}, {}, {}, {false}); // NHWC + auto result = op.evaluate({&x, &bias, &gradO}, {}, {}, {false}); // NHWC ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -472,7 +471,7 @@ TEST_F(DeclarableOpsTests4, biasadd_bp_2) { gradO.linspace(0.1, 0.1); nd4j::ops::biasadd_bp op; - auto result = op.execute({&x, &bias, &gradO}, {}, {}, {true}); // NCHW + auto result = op.evaluate({&x, &bias, &gradO}, {}, {}, {true}); // NCHW ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -511,7 +510,7 @@ TEST_F(DeclarableOpsTests4, Test_Fill_1) { exp.assign(2.0f); nd4j::ops::fill op; - auto result = op.execute({&x, &v}, {}, {}, {}, false, nd4j::DataType::DOUBLE); + auto result = 
+    auto result = op.evaluate({&x, &v});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -532,7 +531,7 @@ TEST_F(DeclarableOpsTests4, Test_FirasSparce_1) {
     x.p(60, 1);
     x.p(61, 0);
     nd4j::ops::firas_sparse op;
-    auto result = op.execute({&x}, {}, {0, 1});
+    auto result = op.evaluate({&x}, {0, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -552,7 +551,7 @@ TEST_F(DeclarableOpsTests4, Test_FlattenTests_1) {
     x.linspace(1);
     exp.linspace(1);
     nd4j::ops::flatten op;
-    auto result = op.execute({&x}, {}, {'c'});
+    auto result = op.evaluate({&x}, {}, {'c'});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -573,7 +572,7 @@ TEST_F(DeclarableOpsTests4, Test_FlattenTests_2) {
     y.linspace(82);
     exp.linspace(1);
     nd4j::ops::flatten op;
-    auto result = op.execute({&x, &y}, {}, {'c'});
+    auto result = op.evaluate({&x, &y}, {}, {'c'});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -593,7 +592,7 @@ TEST_F(DeclarableOpsTests4, Test_FlattenTests_3) {
     y.assign(x);
     nd4j::ops::flatten op;
-    auto result = op.execute({&x, &y}, {}, {'c'});
+    auto result = op.evaluate({&x, &y}, {}, {'c'});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -611,7 +610,7 @@ TEST_F(DeclarableOpsTests4, Test_FlattenTests_4) {
     y.assign(x);
     nd4j::ops::flatten op;
-    auto result = op.execute({&x, &y}, {}, {'f'});
+    auto result = op.evaluate({&x, &y}, {}, {'f'});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -627,7 +626,7 @@ TEST_F(DeclarableOpsTests4, Test_FloorTests_1) {
     exp.linspace(1);
     nd4j::ops::Floor op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -647,7 +646,7 @@ TEST_F(DeclarableOpsTests4, Test_Reshape_Again) {
     exp.linspace(1);
     nd4j::ops::reshape op;
-    auto result = op.execute({&x}, {}, {-99, 4, 3});
+    auto result = op.evaluate({&x}, {-99, 4, 3});
     auto z = result->at(0);
@@ -666,7 +665,7 @@ TEST_F(DeclarableOpsTests4, Test_Gemv_Transpose_1) {
     y.linspace(1);
     nd4j::ops::matmul op;
-    auto result = op.execute({&x, &y}, {}, {1, 0});
+    auto result = op.evaluate({&x, &y}, {}, {1, 0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -695,7 +694,7 @@ TEST_F(DeclarableOpsTests4, Test_Split_1) {
     nd4j::ops::split_v op;
-    auto result = op.execute({&x, &sizes}, {}, {1});
+    auto result = op.evaluate({&x, &sizes}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -738,7 +737,7 @@ TEST_F(DeclarableOpsTests4, Test_Split_2) {
     nd4j::ops::split op;
-    auto result = op.execute({&axis, &x}, {}, {4});
+    auto result = op.evaluate({&axis, &x}, {}, {4});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z0 = result->at(0);
@@ -777,7 +776,7 @@ TEST_F(DeclarableOpsTests4, Test_Split_3) {
     sub2.assign(2.0f);
     nd4j::ops::split op;
-    auto result = op.execute({&axis, &x}, {}, {3});
+    auto result = op.evaluate({&axis, &x}, {}, {3});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z0 = result->at(0);
@@ -802,7 +801,7 @@ TEST_F(DeclarableOpsTests4, Test_Stack_4) {
     auto exp = NDArrayFactory::create('c', {3, 2, 3, 5});
     nd4j::ops::stack op;
-    auto result = op.execute({&t, &u, &v}, {}, {-4});
+    auto result = op.evaluate({&t, &u, &v}, {}, {-4});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -818,7 +817,7 @@ TEST_F(DeclarableOpsTests4, Test_Squeeze_args_1) {
     auto exp = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4});
     nd4j::ops::squeeze op;
-    auto result = op.execute({&x}, {}, {1, 3});
+    auto result = op.evaluate({&x}, {}, {1, 3});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -835,7 +834,7 @@ TEST_F(DeclarableOpsTests4, Test_Squeeze_args_2) {
     auto exp = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4});
     nd4j::ops::squeeze op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -852,7 +851,7 @@ TEST_F(DeclarableOpsTests4, Test_Squeeze_args_3) {
     auto exp = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4});
     nd4j::ops::squeeze op;
-    auto result = op.execute({&x}, {}, {-2, -3});
+    auto result = op.evaluate({&x}, {}, {-2, -3});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -867,7 +866,7 @@ TEST_F(DeclarableOpsTests4, Test_1D_1) {
     auto x = NDArrayFactory::create('c', {2, 3});
     nd4j::ops::unstack op;
-    auto result = op.execute({&x}, {}, {1});
+    auto result = op.evaluate({&x}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -884,7 +883,7 @@ TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_1) {
     auto exp = NDArrayFactory::create('c', {1, 1, 1, 12}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
     nd4j::ops::space_to_depth op;
-    auto result = op.execute({&x}, {}, {2, 1});
+    auto result = op.evaluate({&x}, {}, {2, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -900,7 +899,7 @@ TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_2) {
     auto exp = NDArrayFactory::create('c', {1, 12, 1, 1}, {1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12});
     nd4j::ops::space_to_depth op;
-    auto result = op.execute({&x}, {}, {2, 0});
+    auto result = op.evaluate({&x}, {}, {2, 0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -917,7 +916,7 @@ TEST_F(DeclarableOpsTests4, Test_DepthToSpace_1) {
     auto exp = NDArrayFactory::create('c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
     nd4j::ops::depth_to_space op;
-    auto result = op.execute({&x}, {}, {2, 1});
+    auto result = op.evaluate({&x}, {}, {2, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -934,7 +933,7 @@ TEST_F(DeclarableOpsTests4, Test_DepthToSpace_2) {
     auto exp = NDArrayFactory::create('c', {1, 3, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
     nd4j::ops::depth_to_space op;
-    auto result = op.execute({&x}, {}, {2, 0});
+    auto result = op.evaluate({&x}, {}, {2, 0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -950,7 +949,7 @@ TEST_F(DeclarableOpsTests4, Test_DepthToSpace_3) {
     auto exp = NDArrayFactory::create('c', {4, 16, 64, 1});
     nd4j::ops::depth_to_space op;
-    auto result = op.execute({&x}, {}, {4, 1});
+    auto result = op.evaluate({&x}, {}, {4, 1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -967,7 +966,7 @@ TEST_F(DeclarableOpsTests4, Test_Cross_1) {
     auto exp = NDArrayFactory::create('c', {3}, {-5, 10, -5});
     nd4j::ops::cross op;
-    auto result = op.execute({&a, &b}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&a, &b});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -985,7 +984,7 @@ TEST_F(DeclarableOpsTests4, Test_Cross_2) {
     auto exp = NDArrayFactory::create('c', {2, 3}, {-5, 10, -5, -5, 10, -5});
     nd4j::ops::cross op;
-    auto result = op.execute({&a, &b}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&a, &b});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1003,7 +1002,7 @@ TEST_F(DeclarableOpsTests4, Test_Cross_3) {
     auto exp = NDArrayFactory::create('c', {3, 3}, { -1, 2, -1, -11, 22, -11, -11, 40, -27});
     nd4j::ops::cross op;
-    auto result = op.execute({&a, &b}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&a, &b});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1020,7 +1019,7 @@ TEST_F(DeclarableOpsTests4, Test_Matmul_YATS_1) {
     auto exp = NDArrayFactory::create('c', {3}, {30, 70, 110});
     nd4j::ops::matmul op;
-    auto result = op.execute({&a, &b}, {}, {});
+    auto result = op.evaluate({&a, &b});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1037,7 +1036,7 @@ TEST_F(DeclarableOpsTests4, Test_Matmul_YATS_2) {
     auto exp = NDArrayFactory::create('c', {3}, {70, 80, 90});
     nd4j::ops::matmul op;
-    auto result = op.execute({&a, &b}, {}, {});
+    auto result = op.evaluate({&a, &b});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1054,7 +1053,7 @@ TEST_F(DeclarableOpsTests4, Test_Matmul_YATS_3) {
     auto exp = NDArrayFactory::create('c', {1, 3}, {70, 80, 90});
     nd4j::ops::matmul op;
-    auto result = op.execute({&a, &b}, {}, {});
+    auto result = op.evaluate({&a, &b});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1071,7 +1070,7 @@ TEST_F(DeclarableOpsTests4, Test_Add_119) {
     auto exp = NDArrayFactory::create('c', {1, 4}, {2, 4, 6, 8});
     nd4j::ops::add op;
-    auto result = op.execute({&a, &b}, {}, {});
+    auto result = op.evaluate({&a, &b});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1091,7 +1090,7 @@ TEST_F(DeclarableOpsTests4, Test_Reshape_Negative_1) {
     auto exp = NDArrayFactory::create('c', {4, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
     nd4j::ops::reshape op;
-    auto result = op.execute({&x, &shape}, {}, {});
+    auto result = op.evaluate({&x, &shape});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1109,7 +1108,7 @@ TEST_F(DeclarableOpsTests4, Test_TileToShape_1) {
     x.linspace(1.f);
     nd4j::ops::tile_to_shape op;
-    auto result = op.execute({&x},{}, {2, 4, 3}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x},{}, {2, 4, 3});
     ASSERT_EQ(Status::OK(), result->status());
@@ -1128,7 +1127,7 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_1) {
     exp.linspace(1);
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&x}, {}, {0,0,0,1,0, -999,0,0,0, -999,3,4,5, -999,1,1,1});
+    auto result = op.evaluate({&x}, {}, {0,0,0,1,0, -999,0,0,0, -999,3,4,5, -999,1,1,1});
     ASSERT_EQ(Status::OK(), result->status());
@@ -1150,7 +1149,7 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_2) {
     exp.linspace(1);
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&x, &begin, &end, &stride}, {}, {0,0,0,1,0});
+    auto result = op.evaluate({&x, &begin, &end, &stride}, {}, {0,0,0,1,0});
     ASSERT_EQ(Status::OK(), result->status());
@@ -1175,7 +1174,7 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) {
     //exp.linspace(1);
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&x, &begin, &end, &stride}, {}, {1,0,0,0,0});
+    auto result = op.evaluate({&x, &begin, &end, &stride}, {}, {1,0,0,0,0});
     ASSERT_EQ(Status::OK(), result->status());
@@ -1194,7 +1193,7 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_4) {
     //exp.linspace(1);
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&x, &begin, &end, &stride}, {}, {1,0,1,0,2});
+    auto result = op.evaluate({&x, &begin, &end, &stride}, {}, {1,0,1,0,2});
     ASSERT_EQ(Status::OK(), result->status());
@@ -1218,7 +1217,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test1) {
     expected.linspace(1);
     nd4j::ops::parallel_stack op;
-    auto results = op.execute({&x1, &x2, &x3}, {}, {});
+    auto results = op.evaluate({&x1, &x2, &x3});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1238,7 +1237,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test2) {
     auto expected = NDArrayFactory::create('c', {3,1,2}, {1,2,3,4,5,6});
     nd4j::ops::parallel_stack op;
-    auto results = op.execute({&x1, &x2, &x3}, {}, {});
+    auto results = op.evaluate({&x1, &x2, &x3});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1259,7 +1258,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test3) {
     auto expected = NDArrayFactory::create('c', {3,2,1}, {1,2,3,4,5,6});
     nd4j::ops::parallel_stack op;
-    auto results = op.execute({&x1, &x2, &x3}, {}, {});
+    auto results = op.evaluate({&x1, &x2, &x3});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1279,7 +1278,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test4) {
     auto expected = NDArrayFactory::create('c', {3,2}, {1,2,3,4,5,6});
     nd4j::ops::parallel_stack op;
-    auto results = op.execute({&x1, &x2, &x3}, {}, {});
+    auto results = op.evaluate({&x1, &x2, &x3});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1299,7 +1298,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test5) {
     auto expected = NDArrayFactory::create('c', {3,1}, {1,3,5});
     nd4j::ops::parallel_stack op;
-    auto results = op.execute({&x1, &x2, &x3}, {}, {});
+    auto results = op.evaluate({&x1, &x2, &x3});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1319,7 +1318,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test6) {
     auto expected = NDArrayFactory::create('c', {3}, {1,3,5});
     nd4j::ops::parallel_stack op;
-    auto results = op.execute({&x1, &x2, &x3}, {}, {});
+    auto results = op.evaluate({&x1, &x2, &x3});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1336,7 +1335,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test7) {
     auto expected = NDArrayFactory::create('c', {1}, {1.});
     nd4j::ops::parallel_stack op;
-    auto results = op.execute({&x1}, {}, {});
+    auto results = op.evaluate({&x1});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1357,7 +1356,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test1) {
     auto exp2 = NDArrayFactory::create('c', {2,3,4}, {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400});
     nd4j::ops::meshgrid op;
-    auto results = op.execute({&in0, &in1, &in2}, {}, {0});
+    auto results = op.evaluate({&in0, &in1, &in2}, {}, {0});
     auto out0 = results->at(0);
     auto out1 = results->at(1);
     auto out2 = results->at(2);
@@ -1386,7 +1385,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test2) {
     auto exp2 = NDArrayFactory::create('c', {3,2,4}, {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400});
     nd4j::ops::meshgrid op;
-    auto results = op.execute({&in0, &in1, &in2}, {}, {});
+    auto results = op.evaluate({&in0, &in1, &in2});
     auto out0 = results->at(0);
     auto out1 = results->at(1);
     auto out2 = results->at(2);
@@ -1413,7 +1412,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test3) {
     auto exp2 = NDArrayFactory::create('c', {3,2,4}, {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400});
     nd4j::ops::meshgrid op;
-    auto results = op.execute({&in0, &in1, &in2}, {}, {});
+    auto results = op.evaluate({&in0, &in1, &in2});
     auto out0 = results->at(0);
     auto out1 = results->at(1);
     auto out2 = results->at(2);
@@ -1440,7 +1439,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test4) {
     auto exp2 = NDArrayFactory::create('c', {2,3,4}, {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400});
     nd4j::ops::meshgrid op;
-    auto results = op.execute({&in0, &in1, &in2}, {}, {0});
+    auto results = op.evaluate({&in0, &in1, &in2}, {}, {0});
     auto out0 = results->at(0);
     auto out1 = results->at(1);
     auto out2 = results->at(2);
@@ -1467,7 +1466,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test5) {
     auto exp2 = NDArrayFactory::create('c', {1,1,1}, {3});
     nd4j::ops::meshgrid op;
-    auto results = op.execute({&in0, &in1, &in2}, {}, {0});
+    auto results = op.evaluate({&in0, &in1, &in2}, {}, {0});
     auto out0 = results->at(0);
     auto out1 = results->at(1);
     auto out2 = results->at(2);
@@ -1494,7 +1493,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test6) {
     auto exp2 = NDArrayFactory::create('c', {4,1,1}, {6,6,6,6});
     nd4j::ops::meshgrid op;
-    auto results = op.execute({&in0, &in1, &in2}, {}, {0});
+    auto results = op.evaluate({&in0, &in1, &in2}, {}, {0});
     auto out0 = results->at(0);
     auto out1 = results->at(1);
     auto out2 = results->at(2);
@@ -1521,7 +1520,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test7) {
     auto exp2 = NDArrayFactory::create('c', {1,4,1}, {6,6,6,6});
     nd4j::ops::meshgrid op;
-    auto results = op.execute({&in0, &in1, &in2}, {}, {1});
+    auto results = op.evaluate({&in0, &in1, &in2}, {}, {1});
     auto out0 = results->at(0);
     auto out1 = results->at(1);
     auto out2 = results->at(2);
@@ -1544,7 +1543,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test8) {
     auto exp0 = NDArrayFactory::create('c', {1}, {5});
     nd4j::ops::meshgrid op;
-    auto results = op.execute({&in0}, {}, {0});
+    auto results = op.evaluate({&in0}, {}, {0});
     auto out0 = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1561,7 +1560,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test9) {
     auto exp0 = NDArrayFactory::create('c', {1}, {5});
     nd4j::ops::meshgrid op;
-    auto results = op.execute({&in0}, {}, {1});
+    auto results = op.evaluate({&in0}, {}, {1});
     auto out0 = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1588,7 +1587,7 @@ TEST_F(DeclarableOpsTests4, WeightedCrossEntropyWithLogits_1) {
     //Result {-159.50006, -191.1, -16.009075, -210., -24.001238, -15.03887}
     nd4j::ops::weighted_cross_entropy_with_logits op;
-    auto results = op.execute({&targets, &input, &weight}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&targets, &input, &weight});
     auto output = results->at(0);
     // output->printIndexedBuffer();
@@ -1610,7 +1609,7 @@ TEST_F(DeclarableOpsTests4, WeightedCrossEntropyWithLogits_2) {
     auto expected = NDArrayFactory::create('c', {2, 3}, {-159.5001f, -191.1f, -15.98185f, -210.f, -24.001238f, -14.951412f});
     nd4j::ops::weighted_cross_entropy_with_logits op;
-    auto results = op.execute({&targets, &input, &weights}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&targets, &input, &weights});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1657,7 +1656,7 @@ TEST_F(DeclarableOpsTests4, lstm_test1) {
     auto expClast = NDArrayFactory::create('c', {1, batchSize, numProj}, {1.1589154,1.1589154,1.1589154,1.1892855,1.1892855,1.1892855,1.219861 ,1.219861 ,1.219861});
     nd4j::ops::lstm op;
-    auto results = op.execute({&x, &h0, &c0, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 0.}, {0, 0});
+    auto results = op.evaluate({&x, &h0, &c0, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 0.}, {0, 0});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1681,7 +1680,7 @@ TEST_F(DeclarableOpsTests4, relu6_test1) {
     auto expected = NDArrayFactory::create('c', {2,4}, {0., 6., 0., 0.,2., 6., 6., 6.});
     nd4j::ops::relu6 op;
-    auto results = op.execute({&input}, {0.}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input}, {0.}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1703,7 +1702,7 @@ TEST_F(DeclarableOpsTests4, relu6_bp_test1) {
     auto expected = NDArrayFactory::create('c', {2,4}, {0., 0., 0., 0., 5., 0., 0., 8.});
     nd4j::ops::relu6_bp op;
-    auto results = op.execute({&input, &gradO}, {0.}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input, &gradO}, {0.});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1732,7 +1731,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_1) {
     );
     nd4j::ops::lrn op;
-    auto results = op.execute({&x}, {1.0, 1.0, 0.5}, {5}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {5});
     auto out = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1760,7 +1759,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_2) {
     0.7581754f, 0.58321184f, 0.86747235f, 0.4048204f});
     nd4j::ops::lrn op;
-    auto results = op.execute({&x}, {1.0, 1.0, 0.5}, {2}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2});
     auto out = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1799,7 +1798,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_3) {
     );
     nd4j::ops::lrn op;
-    auto results = op.execute({&x}, {1.0, 1.0, 0.5}, {2}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2});
     auto out = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1838,7 +1837,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_4) {
     );
     nd4j::ops::lrn op;
-    auto results = op.execute({&x}, {1.0, 1.0, 0.5}, {5}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {5});
     auto out = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1883,7 +1882,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_5) {
     auto exp = NDArrayFactory::create('c', {2, 2, 2, 4});
     nd4j::ops::lrn_bp op;
-    auto results = op.execute({&x, &eps}, {1.0, 1.0, 0.5}, {5}, {}, false, typeid(TypeParam) == typeid(float) ? nd4j::DataType::FLOAT32 : nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {5}, {}, {}, false);
     auto out = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1905,7 +1904,7 @@ TEST_F(DeclarableOpsTests4, tri_test1) {
     auto expected = NDArrayFactory::create('c', {rows, cols}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f});
     nd4j::ops::tri op;
-    auto results = op.execute({}, {}, {rows, cols});
+    auto results = op.evaluate({}, {}, {rows, cols});
     auto output = results->at(0);
     // output->printIndexedBuffer();
@@ -1928,7 +1927,7 @@ TEST_F(DeclarableOpsTests4, tri_test2) {
     auto expected = NDArrayFactory::create('c', {rows, cols}, {1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 1.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f});
     nd4j::ops::tri op;
-    auto results = op.execute({}, {}, {rows, cols, diag});
+    auto results = op.evaluate({}, {}, {rows, cols, diag});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1949,7 +1948,7 @@ TEST_F(DeclarableOpsTests4, tri_test3) {
     auto expected = NDArrayFactory::create('c', {rows, cols}, {0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f});
     nd4j::ops::tri op;
-    auto results = op.execute({}, {}, {rows, cols, diag});
+    auto results = op.evaluate({}, {}, {rows, cols, diag});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1970,7 +1969,7 @@ TEST_F(DeclarableOpsTests4, tri_test4) {
     auto expected = NDArrayFactory::create('c', {rows, cols}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f});
     nd4j::ops::tri op;
-    auto results = op.execute({}, {}, {rows, cols, diag});
+    auto results = op.evaluate({}, {}, {rows, cols, diag});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1989,7 +1988,7 @@ TEST_F(DeclarableOpsTests4, tri_test5) {
     auto expected = NDArrayFactory::create('c', {rows, rows}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 1.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f});
     nd4j::ops::tri op;
-    auto results = op.execute({}, {}, {rows});
+    auto results = op.evaluate({}, {}, {rows});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2010,7 +2009,7 @@ TEST_F(DeclarableOpsTests4, tri_test6) {
     auto expected = NDArrayFactory::create('c', {rows, cols}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
     nd4j::ops::tri op;
-    auto results = op.execute({}, {}, {rows, cols, diag});
+    auto results = op.evaluate({}, {}, {rows, cols, diag});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2031,7 +2030,7 @@ TEST_F(DeclarableOpsTests4, tri_test7) {
     auto expected = NDArrayFactory::create('c', {rows, cols}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f});
     nd4j::ops::tri op;
-    auto results = op.execute({}, {}, {rows, cols, diag});
+    auto results = op.evaluate({}, {}, {rows, cols, diag});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2049,7 +2048,7 @@ TEST_F(DeclarableOpsTests4, triu_test1) {
     auto expected = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 0, 5, 6, 0, 0, 9, 0, 0, 0});
     nd4j::ops::triu op;
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2068,7 +2067,7 @@ TEST_F(DeclarableOpsTests4, triu_test2) {
     auto expected = NDArrayFactory::create('c', {4, 3}, {1, 2, 3,4, 5, 6,0, 8, 9,0, 0, 12});
     nd4j::ops::triu op;
-    auto results = op.execute({&input}, {}, {-1});
+    auto results = op.evaluate({&input}, {}, {-1});
    auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2086,7 +2085,7 @@ TEST_F(DeclarableOpsTests4, triu_test3) {
     auto expected = NDArrayFactory::create('c', {2, 3, 2}, {1, 2,3, 4,0, 6,7, 8,9,10,0,12});
     nd4j::ops::triu op;
-    auto results = op.execute({&input}, {}, {-1});
+    auto results = op.evaluate({&input}, {}, {-1});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2104,7 +2103,7 @@ TEST_F(DeclarableOpsTests4, triu_test4) {
     auto expected = NDArrayFactory::create('c', {2, 3, 2}, {1, 2,0, 4,0, 0,7, 8,0, 10,0, 0});
     nd4j::ops::triu op;
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2122,7 +2121,7 @@ TEST_F(DeclarableOpsTests4, triu_test5) {
     auto expected = NDArrayFactory::create('c', {2, 3, 2}, {0, 2,0, 0,0, 0,0, 8,0, 0,0, 0});
     nd4j::ops::triu op;
-    auto results = op.execute({&input}, {}, {1});
+    auto results = op.evaluate({&input}, {}, {1});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2140,7 +2139,7 @@ TEST_F(DeclarableOpsTests4, triu_test6) {
     auto expected = NDArrayFactory::create('c', {2, 3, 2}, {0, 0,0, 0,0, 0,0, 0,0, 0,0, 0});
     nd4j::ops::triu op;
-    auto results = op.execute({&input}, {}, {10});
+    auto results = op.evaluate({&input}, {}, {10});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2158,7 +2157,7 @@ TEST_F(DeclarableOpsTests4, triu_test7) {
     auto expected = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
     nd4j::ops::triu op;
-    auto results = op.execute({&input}, {}, {-10});
+    auto results = op.evaluate({&input}, {}, {-10});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2176,7 +2175,7 @@ TEST_F(DeclarableOpsTests4, triu_test8) {
     auto expected = NDArrayFactory::create('c', {6, 6}, {1, 2, 3, 4, 5, 6,0, 2, 3, 4, 5, 6,0, 0, 3, 4, 5, 6,0, 0, 0, 4, 5, 6,0, 0, 0, 0, 5, 6,0, 0, 0, 0, 0, 6});
     nd4j::ops::triu op;
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2194,7 +2193,7 @@ TEST_F(DeclarableOpsTests4, triu_test9) {
     auto expected = NDArrayFactory::create('c', {6, 6}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 0, 2, 3, 4, 5, 6, 0, 0, 3, 4, 5, 6});
     nd4j::ops::triu op;
-    auto results = op.execute({&input}, {}, {-3});
+    auto results = op.evaluate({&input}, {}, {-3});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2212,7 +2211,7 @@ TEST_F(DeclarableOpsTests4, triu_test10) {
     auto expected = NDArrayFactory::create('c', {6, 6}, {0, 0, 0, 4, 5, 6, 0, 0, 0, 0, 5, 6, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
     nd4j::ops::triu op;
-    auto results = op.execute({&input}, {}, {3});
+    auto results = op.evaluate({&input}, {}, {3});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2230,7 +2229,7 @@ TEST_F(DeclarableOpsTests4, triu_test11) {
     auto expected = NDArrayFactory::create('c', {6, 6}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6});
     nd4j::ops::triu op;
-    auto results = op.execute({&input}, {}, {-58});
+    auto results = op.evaluate({&input}, {}, {-58});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2252,7 +2251,7 @@ TEST_F(DeclarableOpsTests4, triu_bp_test1) {
     auto expected = NDArrayFactory::create('c', {2, 3, 2}, {0.,0.5,0.,0. ,0.,0. ,0.,0.5,0.,0. ,0.,0.});
     nd4j::ops::triu_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {1});
+    auto results = op.evaluate({&input, &gradO}, {}, {1});
     auto gradI = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2273,7 +2272,7 @@ TEST_F(DeclarableOpsTests4, triu_bp_test2) {
     auto expected = NDArrayFactory::create('c', {2, 3, 2}, {0.5,0.5,0. ,0.5,0. ,0. ,0.5,0.5,0. ,0.5,0. ,0.});
     nd4j::ops::triu_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {});
+    auto results = op.evaluate({&input, &gradO}, {}, {});
     auto gradI = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2294,7 +2293,7 @@ TEST_F(DeclarableOpsTests4, triu_bp_test3) {
     auto expected = NDArrayFactory::create('c', {6,6}, {0.5, 0.5, 0.5, 0.5, 0.5, 0.5,0.5, 0.5, 0.5, 0.5, 0.5, 0.5,0.5, 0.5, 0.5, 0.5, 0.5, 0.5,0. , 0.5, 0.5, 0.5, 0.5, 0.5,0. , 0. , 0.5, 0.5, 0.5, 0.5,0. , 0. , 0. , 0.5, 0.5, 0.5});
     nd4j::ops::triu_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {-2});
+    auto results = op.evaluate({&input, &gradO}, {}, {-2});
     auto gradI = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2315,7 +2314,7 @@ TEST_F(DeclarableOpsTests4, triu_bp_test4) {
     auto expected = NDArrayFactory::create('c', {2,3}, {0., 0., 0., 0., 0., 0.});
     nd4j::ops::triu_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {10});
+    auto results = op.evaluate({&input, &gradO}, {}, {10});
     auto gradI = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
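[Editorial note, not part of the patch] Every hunk in these two test files applies the same mechanical substitution: calls to DeclarableOp::execute() that spelled out empty T-/I-/B-argument lists, an inplace flag, and an output DataType become calls to the evaluate() convenience overload, which defaults those trailing parameters. A minimal before/after sketch of the pattern, assuming the usual libnd4j test fixtures; the test name and input values here are hypothetical, and identity merely stands in for any declarable op:

    TEST_F(DeclarableOpsTests5, evaluate_migration_sketch) {
        auto x = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});   // hypothetical data
        nd4j::ops::identity op;

        // Before: trailing defaults had to be written out at every call site.
        // auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);

        // After: evaluate() supplies the trailing arguments.
        auto result = op.evaluate({&x});

        ASSERT_EQ(Status::OK(), result->status());
        auto z = result->at(0);
        ASSERT_TRUE(x.equalsTo(z));

        delete result;   // caller still owns the returned result set, as with execute()
    }

Because only call-site syntax changes, never the op under test or its arguments, the substitution is safe to apply wholesale across the suite.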
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp
index 6d85feec1..2a5697ce8 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp
@@ -46,7 +46,7 @@ TEST_F(DeclarableOpsTests5, Test_PermuteEquality_1) {
     x.reshapei('c', {3, 4, 5});
     nd4j::ops::permute op;
-    auto result = op.execute({&x}, {}, {0, 2, 1});
+    auto result = op.evaluate({&x}, {}, {0, 2, 1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -67,7 +67,7 @@ TEST_F(DeclarableOpsTests5, Test_PermuteEquality_0) {
     // x.printBuffer("{0, 1, 2} data");
     nd4j::ops::permute op;
-    auto result = op.execute({&x}, {}, {0, 1, 2});
+    auto result = op.evaluate({&x}, {}, {0, 1, 2});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -89,7 +89,7 @@ TEST_F(DeclarableOpsTests5, Test_PermuteEquality_2) {
     // x.printBuffer("{1, 0, 2} data");
     nd4j::ops::permute op;
-    auto result = op.execute({&x}, {}, {1, 0, 2});
+    auto result = op.evaluate({&x}, {}, {1, 0, 2});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -110,7 +110,7 @@ TEST_F(DeclarableOpsTests5, Test_PermuteEquality_3) {
     // x.printBuffer("{1, 2, 0} data");
     nd4j::ops::permute op;
-    auto result = op.execute({&x}, {}, {1, 2, 0});
+    auto result = op.evaluate({&x}, {}, {1, 2, 0});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -131,7 +131,7 @@ TEST_F(DeclarableOpsTests5, Test_PermuteEquality_4) {
     // x.printBuffer("{2, 0, 1} data");
     nd4j::ops::permute op;
-    auto result = op.execute({&x}, {}, {2, 0, 1});
+    auto result = op.evaluate({&x}, {}, {2, 0, 1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -152,7 +152,7 @@ TEST_F(DeclarableOpsTests5, Test_PermuteEquality_5) {
     // x.printBuffer("{2, 1, 0} data");
     nd4j::ops::permute op;
-    auto result = op.execute({&x}, {}, {2, 1, 0});
+    auto result = op.evaluate({&x}, {}, {2, 1, 0});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -172,7 +172,7 @@ TEST_F(DeclarableOpsTests5, Test_TTS_bp_1) {
     eps.linspace(1.f);
     nd4j::ops::tile_to_shape_bp op;
-    auto result = op.execute({&x, &eps}, {}, {2, 4, 3});
+    auto result = op.evaluate({&x, &eps}, {}, {2, 4, 3});
     ASSERT_EQ(Status::OK(), result->status());
@@ -194,14 +194,14 @@ TEST_F(DeclarableOpsTests5, Test_Rdiv_bp_1) {
     nd4j::ops::reversedivide op_ff;
-    auto result_ff = op_ff.execute({&x, &y}, {}, {});
+    auto result_ff = op_ff.evaluate({&x, &y}, {}, {});
     ASSERT_EQ(Status::OK(), result_ff->status());
     auto z_ff = result_ff->at(0);
     ASSERT_TRUE(eps.isSameShape(z_ff));
     nd4j::ops::reversedivide_bp op_bp;
-    auto result_bp = op_bp.execute({&x, &y, &eps}, {}, {});
+    auto result_bp = op_bp.evaluate({&x, &y, &eps}, {}, {});
     ASSERT_EQ(Status::OK(), result_bp->status());
     auto z_bp = result_bp->at(0);
@@ -217,7 +217,7 @@ TEST_F(DeclarableOpsTests5, Test_Boolean_diff_1) {
     auto y = NDArrayFactory::create(2.0f);
     nd4j::ops::less op;
-    auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL);
+    auto result = op.evaluate({&x, &y});
     ASSERT_EQ(Status::OK(), result->status());
     ASSERT_EQ(result->at(0)->t(0), true);
     delete result;
@@ -228,12 +228,12 @@ TEST_F(DeclarableOpsTests5, Test_SetSeed_1) {
     auto y = NDArrayFactory::create(5);
     nd4j::ops::set_seed op;
-    auto result = op.execute({&x, &y}, {}, {120, 5}, {}, false, nd4j::DataType::INT32);
+    auto result = op.evaluate({&x, &y}, {}, {120, 5});
     ASSERT_EQ(Status::OK(), result->status());
     // result->at(0)->printIndexedBuffer("RES SEED");
     nd4j::ops::get_seed getOp;
-    auto getRes = getOp.execute({}, {}, {});
+    auto getRes = getOp.evaluate({});
     ASSERT_EQ(Status::OK(), getRes->status());
     // getRes->at(0)->printIndexedBuffer("Output RES GET SEED");
     // ASSERT_EQ(result->at(0)->t(0), true);
@@ -248,7 +248,7 @@ TEST_F(DeclarableOpsTests5, scatterMul_test1) {
     auto exp = NDArrayFactory::create('c', {2, 2}, {10.f, 2.f, 3.f, 4.f});
     nd4j::ops::scatter_mul op;
-    auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {});
+    auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -266,7 +266,7 @@ TEST_F(DeclarableOpsTests5, scatterDiv_test1) {
     auto exp = NDArrayFactory::create('c', {2, 2}, {0.10f, 2.f, 3.f, 4.f});
     nd4j::ops::scatter_div op;
-    auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {});
+    auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -284,7 +284,7 @@ TEST_F(DeclarableOpsTests5, scatterSub_test1) {
     auto exp = NDArrayFactory::create('c', {2, 2}, {-9.f, 1.f, 3.f, 4.f});
     nd4j::ops::scatter_sub op;
-    auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {});
+    auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -300,7 +300,7 @@ TEST_F(DeclarableOpsTests5, hardsigmoid_test1) {
     auto exp = NDArrayFactory::create('c', {2, 2}, {0.7f, 0.9f, 1.f, 1.f});
     nd4j::ops::hardsigmoid op;
-    auto result = op.execute({&matrix}, {}, {}, {});
+    auto result = op.evaluate({&matrix}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -316,7 +316,7 @@ TEST_F(DeclarableOpsTests5, hardsigmoid_test2) {
     auto exp = NDArrayFactory::create('c', {2, 2}, {0.2f, 0.4f, 0.f, 0.f});
     nd4j::ops::hardsigmoid_bp op;
-    auto result = op.execute({&matrix, &eps}, {}, {}, {});
+    auto result = op.evaluate({&matrix, &eps}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -331,7 +331,7 @@ TEST_F(DeclarableOpsTests5, hardtanh_test1) {
     auto exp = NDArrayFactory::create('c', {3, 3}, {-1, -1, -1, -1, 0, 1, 1, 1, 1});
     nd4j::ops::hardtanh op;
-    auto result = op.execute({&matrix}, {}, {}, {});
+    auto result = op.evaluate({&matrix}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -347,7 +347,7 @@ TEST_F(DeclarableOpsTests5, hardtanh_test2) {
     auto exp = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 4, 5, 6, 0, 0, 0});
     nd4j::ops::hardtanh_bp op;
-    auto result = op.execute({&matrix, &eps}, {}, {}, {});
+    auto result = op.evaluate({&matrix, &eps}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -363,7 +363,7 @@ TEST_F(DeclarableOpsTests5, histogram_test1) {
     auto exp = NDArrayFactory::create('c', {3}, {3, 3, 3});
     nd4j::ops::histogram op;
-    auto result = op.execute({&matrix}, {}, {3}, {});
+    auto result = op.evaluate({&matrix}, {}, {3}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -378,7 +378,7 @@ TEST_F(DeclarableOpsTests5, histogram_test2) {
     auto exp = NDArrayFactory::create('c', {4}, {2, 0, 0, 1});
     nd4j::ops::histogram op;
-    auto result = op.execute({&matrix}, {}, {4}, {});
+    auto result = op.evaluate({&matrix}, {}, {4}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -393,7 +393,7 @@ TEST_F(DeclarableOpsTests5, Identity_test1) {
     // auto exp = NDArrayFactory::create('c', {3, 3}, {3, 3, 3});
     nd4j::ops::identity op;
-    auto result = op.execute({&matrix}, {}, {}, {});
+    auto result = op.evaluate({&matrix}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -408,7 +408,7 @@ TEST_F(DeclarableOpsTests5, Identity_test2) {
     auto eps = NDArrayFactory::create('c', {3, 3}, {1,2,3,4,5,6,7,8,9});
     // auto exp = NDArrayFactory::create('c', {3,3});
     nd4j::ops::identity_bp op;
-    auto result = op.execute({&matrix, &eps}, {}, {}, {});
+    auto result = op.evaluate({&matrix, &eps}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -424,7 +424,7 @@ TEST_F(DeclarableOpsTests5, Log1p_test1) {
     // auto exp = NDArrayFactory::create('c', {3,3});
     nd4j::ops::Log1p op;
     y.applyTransform(nd4j::transform::Log, y);
-    auto result = op.execute({&matrix}, {}, {}, {});
+    auto result = op.evaluate({&matrix}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -440,7 +440,7 @@ TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_1) {
     auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0});
     nd4j::ops::space_to_batch op;
-    auto result = op.execute({&x, &paddings}, {}, {2});
+    auto result = op.evaluate({&x, &paddings}, {}, {2});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -458,7 +458,7 @@ TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_2) {
     auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0});
     nd4j::ops::space_to_batch op;
-    auto result = op.execute({&x, &paddings}, {}, {2});
+    auto result = op.evaluate({&x, &paddings}, {}, {2});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -477,7 +477,7 @@ TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_3) {
     auto exp = NDArrayFactory::create('c', {8, 1, 3, 1}, {0, 1, 3, 0, 9, 11,0, 2, 4, 0, 10, 12,0, 5, 7, 0, 13, 15,0, 6, 8, 0, 14, 16});
     nd4j::ops::space_to_batch op;
-    auto result = op.execute({&x, &paddings}, {}, {2});
+    auto result = op.evaluate({&x, &paddings}, {}, {2});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -506,7 +506,7 @@ TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_4) {
     276, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0}, nd4j::DataType::FLOAT32);
     nd4j::ops::space_to_batch op;
-    auto result = op.execute({&x, &paddings}, {}, {blockSize});
+    auto result = op.evaluate({&x, &paddings}, {}, {blockSize});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -524,7 +524,7 @@ TEST_F(DeclarableOpsTests5, Test_BatchToSpace_1) {
     auto crops = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0});
     nd4j::ops::batch_to_space op;
-    auto result = op.execute({&x, &crops}, {}, {2});
+    auto result = op.evaluate({&x, &crops}, {}, {2});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -542,7 +542,7 @@ TEST_F(DeclarableOpsTests5, Test_BatchToSpace_2) {
     auto crops = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0});
     nd4j::ops::batch_to_space op;
-    auto result = op.execute({&x, &crops}, {}, {2});
+    auto result = op.evaluate({&x, &crops}, {}, {2});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -563,7 +563,7 @@ TEST_F(DeclarableOpsTests5, Test_BatchToSpace_3) {
     auto crops = NDArrayFactory::create('c', {2, 2}, {0, 0, 2, 0});
     nd4j::ops::batch_to_space op;
-    auto result = op.execute({&x, &crops}, {}, {2});
+    auto result = op.evaluate({&x, &crops}, {}, {2});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -585,7 +585,7 @@ TEST_F(DeclarableOpsTests5, Test_BatchToSpace_4) {
     NDArray exp('c', {3, 3*blockSize - 1 - 2, 4*blockSize - 2 - 3, 2}, {147, 148, 219, 220, 149, 150, 11, 12, 83, 84, 13, 14, 155, 156, 227, 228, 157, 158, 171, 172, 243, 244, 173, 174, 35, 36, 107, 108, 37, 38, 179, 180, 251, 252, 181, 182, 195, 196, 267, 268, 197, 198, 59, 60, 131, 132, 61, 62, 203, 204, 275, 276, 205, 206}, nd4j::DataType::FLOAT32);
     nd4j::ops::batch_to_space op;
-    auto result = op.execute({&x, &crops}, {}, {blockSize});
+    auto result = op.evaluate({&x, &crops}, {}, {blockSize});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -602,7 +602,7 @@ TEST_F(DeclarableOpsTests5, eye_test1) {
     auto expected = NDArrayFactory::create('c', {3, 3}, {1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f});
     nd4j::ops::eye op;
-    auto results = op.execute({}, {}, {-99, 3});
+    auto results = op.evaluate({}, {}, {-99, 3});
     auto output = results->at(0);
     // output->printIndexedBuffer();
@@ -619,7 +619,7 @@ TEST_F(DeclarableOpsTests5, eye_test2) {
     auto expected = NDArrayFactory::create('c', {3, 4}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f});
     nd4j::ops::eye op;
-    auto results = op.execute({}, {}, {-99, 3, 4});
+    auto results = op.evaluate({}, {}, {-99, 3, 4});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -635,7 +635,7 @@ TEST_F(DeclarableOpsTests5, eye_test3) {
     auto expected = NDArrayFactory::create('c', {2, 3, 4}, {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0});
     nd4j::ops::eye op;
-    auto results = op.execute({}, {9 /*int*/}, {-99, 3, 4, 2});
+    auto results = op.evaluate({}, {9 /*int*/}, {-99, 3, 4, 2});
     auto output = results->at(0);
     // output->printIndexedBuffer("Output eye");
@@ -652,7 +652,7 @@ TEST_F(DeclarableOpsTests5, eye_test4) {
     auto expected = NDArrayFactory::create('c', {2, 2, 3, 4}, {1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.});
     nd4j::ops::eye op;
-    auto results = op.execute({}, {6/*double*/}, {-99, 3, 4, 2, 2});
+    auto results = op.evaluate({}, {6/*double*/}, {-99, 3, 4, 2, 2});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -666,7 +666,7 @@ TEST_F(DeclarableOpsTests5, eye_test5) {
     nd4j::ops::eye op;
-    auto result = op.execute({},{},{3, 2});
+    auto result = op.evaluate({},{},{3, 2});
     auto z = result->at(0);
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -684,7 +684,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test1) {
     auto expected = NDArrayFactory::create('c', {2,2,3,2}, {19, 20, 21, 22, 23, 24, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 13, 14, 15, 16, 17, 18});
     nd4j::ops::gather_nd op;
-    auto results = op.execute({&input, &indices}, {}, {});
+    auto results = op.evaluate({&input, &indices}, {}, {});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -704,7 +704,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test2) {
     auto expected = NDArrayFactory::create('c', {2,2,2}, {23, 24, 11, 12, 3, 4, 3, 4});
     nd4j::ops::gather_nd op;
-    auto results = op.execute({&input, &indices}, {}, {}, {true});
+    auto results = op.evaluate({&input, &indices}, {}, {}, {true});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -723,7 +723,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test3) {
     auto expected = NDArrayFactory::create(24.);
     nd4j::ops::gather_nd op;
-    auto results = op.execute({&input, &indices}, {}, {});
+    auto results = op.evaluate({&input, &indices}, {}, {});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -742,7 +742,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test4) {
     auto expected = NDArrayFactory::create('c',{2}, {24., 6});
     nd4j::ops::gather_nd op;
-    auto results = op.execute({&input, &indices}, {}, {});
+    auto results = op.evaluate({&input, &indices}, {}, {});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -760,7 +760,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test5) {
     auto expected = NDArrayFactory::create('c',{5}, {4.,3,1,2,2});
     nd4j::ops::gather_nd op;
-    auto results = op.execute({&input, &indices}, {}, {});
+    auto results = op.evaluate({&input, &indices}, {}, {});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -779,7 +779,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test6) {
     auto expected = NDArrayFactory::create(3.);
     nd4j::ops::gather_nd op;
-    auto results = op.execute({&input, &indices}, {}, {});
+    auto results = op.evaluate({&input, &indices}, {}, {});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -798,7 +798,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test7) {
     auto expected = NDArrayFactory::create('c', {3,3}, {3,5,5,8,5,10,2,2,14});
     nd4j::ops::gather_nd op;
-    auto results = op.execute({&input, &indices}, {}, {}, {true});
+    auto results = op.evaluate({&input, &indices}, {}, {}, {true});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -815,7 +815,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test8) {
     auto e = NDArrayFactory::create('c', {2}, {1., 4.});
     nd4j::ops::gather_nd op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -832,7 +832,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test9) {
     x.linspace(1);
     nd4j::ops::gather_nd op;
-    auto result = op.execute({&x, &indices}, {}, {});
+    auto result = op.evaluate({&x, &indices}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -880,7 +880,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test1) {
     auto exp = NDArrayFactory::create('c', {3, 4, 5}, {4, 3, 2, 1, 5, 9, 8, 7, 6, 10, 14, 13, 12, 11, 15, 19, 18, 17, 16, 20, 24, 23, 22, 21, 25, 29, 28, 27, 26, 30, 34, 33, 32, 31, 35, 39, 38, 37, 36, 40, 44, 43, 42, 41, 45, 49, 48, 47, 46, 50, 54, 53, 52, 51, 55, 59, 58, 57, 56, 60});
     nd4j::ops::reverse_sequence op;
-    auto results = op.execute({&input, &seqLengths}, {}, {2, 1});
+    auto results = op.evaluate({&input, &seqLengths}, {}, {2, 1});
     ASSERT_EQ(Status::OK(), results->status());
     auto output = results->at(0);
@@ -900,7 +900,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test2) {
     auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 11, 13, 14, 15, 18, 17, 16, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 31, 33, 34, 35, 38, 37, 36, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 52, 51, 53, 54, 55, 58, 57, 56, 59, 60});
     nd4j::ops::reverse_sequence op;
-    auto results = op.execute({&input, &seqLengths}, {}, {2, 1});
+    auto results = op.evaluate({&input, &seqLengths}, {}, {2, 1});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -919,7 +919,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test3) {
     auto exp = NDArrayFactory::create('c', {3, 4, 5}, {2, 1, 3, 4, 5, 7, 6, 8, 9, 10, 12, 11, 13, 14, 15, 17, 16, 18, 19, 20, 23, 22, 21, 24, 25, 28, 27, 26, 29, 30, 33, 32, 31, 34, 35, 38, 37, 36, 39, 40, 44, 43, 42, 41, 45, 49, 48, 47, 46, 50, 54, 53, 52, 51, 55, 59, 58, 57, 56, 60});
     nd4j::ops::reverse_sequence op;
-    auto results = op.execute({&input, &seqLengths}, {}, {2, 0});
+    auto results = op.evaluate({&input, &seqLengths}, {}, {2, 0});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -938,7 +938,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test4) {
     auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1, 22, 3, 24, 45, 6, 27, 8, 29, 50, 11, 32, 13, 34, 55, 16, 37, 18, 39, 60, 21, 2, 23, 4, 25, 26, 7, 28, 9, 30, 31, 12, 33, 14, 35, 36, 17, 38, 19, 40, 41, 42, 43, 44, 5, 46, 47, 48, 49, 10, 51, 52, 53, 54, 15, 56, 57, 58, 59, 20});
     nd4j::ops::reverse_sequence op;
-    auto results = op.execute({&input, &seqLengths}, {}, {0, 2});
+    auto results = op.evaluate({&input, &seqLengths}, {}, {0, 2});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -957,7 +957,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test5) {
     auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1, 7, 18, 9, 15, 6, 2, 13, 4, 10, 11, 12, 8, 14, 5, 16, 17, 3, 19, 20, 21, 27, 38, 29, 35, 26, 22, 33, 24, 30, 31, 32, 28, 34, 25, 36, 37, 23, 39, 40, 41, 47, 58, 49, 55, 46, 42, 53, 44, 50, 51, 52, 48, 54, 45, 56, 57, 43, 59, 60});
     nd4j::ops::reverse_sequence op;
-    auto results = op.execute({&input, &seqLengths}, {}, {1, 2});
+    auto results = op.evaluate({&input, &seqLengths}, {}, {1, 2});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -976,7 +976,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test6) {
     auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1, 2, 3, 4, 5, 26, 27, 28, 29, 30, 51, 52, 53, 54, 55, 36, 37, 38, 39, 40, 21, 22, 23, 24, 25, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 16, 17, 18, 19, 20, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 11, 12, 13, 14, 15, 56, 57, 58, 59, 60});
     nd4j::ops::reverse_sequence op;
-    auto results = op.execute({&input, &seqLengths}, {}, {0, 1});
+    auto results = op.evaluate({&input, &seqLengths}, {}, {0, 1});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -996,7 +996,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test7) {
     auto exp = NDArrayFactory::create('c', {1, 5}, {3, 2, 1, 4, 5});
     nd4j::ops::reverse_sequence op;
-    auto results = op.execute({&input, &seqLengths}, {}, {1, 0});
+    auto results = op.evaluate({&input, &seqLengths}, {}, {1, 0});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1016,7 +1016,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test8) {
     auto exp = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5});
     nd4j::ops::reverse_sequence op;
-    auto results = op.execute({&input, &seqLengths}, {}, {0, 1});
+    auto results = op.evaluate({&input, &seqLengths}, {}, {0, 1});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1036,7 +1036,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test9) {
     auto exp = NDArrayFactory::create('c', {5, 1}, {1, 2, 3, 4, 5});
     nd4j::ops::reverse_sequence op;
-    auto results = op.execute({&input, &seqLengths}, {}, {1, 0});
+    auto results = op.evaluate({&input, &seqLengths}, {}, {1, 0});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1056,7 +1056,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test10) {
     auto exp = NDArrayFactory::create('c', {5, 1}, {3, 2, 1, 4, 5});
     nd4j::ops::reverse_sequence op;
-    auto results = op.execute({&input, &seqLengths}, {}, {0, 1});
+    auto results = op.evaluate({&input, &seqLengths}, {}, {0, 1});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1076,7 +1076,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test11) {
     auto exp = NDArrayFactory::create('c', {1, 1, 5, 1}, {1, 2, 3, 4, 5});
     nd4j::ops::reverse_sequence op;
-    auto results = op.execute({&input, &seqLengths}, {}, {1, 2});
+    auto results = op.evaluate({&input, &seqLengths}, {}, {1, 2});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1096,7 +1096,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test12) {
     auto exp = NDArrayFactory::create('c', {1, 1, 5, 1}, {3, 2, 1, 4, 5});
     nd4j::ops::reverse_sequence op;
-    auto results = op.execute({&input, &seqLengths}, {}, {2, 0});
+    auto results = op.evaluate({&input, &seqLengths}, {}, {2, 0});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1116,7 +1116,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test13) {
     auto exp = NDArrayFactory::create('c', {1, 1, 5, 1}, {1, 2, 3, 4, 5});
     nd4j::ops::reverse_sequence op;
-    auto results = op.execute({&input, &seqLengths}, {}, {3, 0});
+    auto results = op.evaluate({&input, &seqLengths}, {}, {3, 0});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1133,7 +1133,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test14) {
     auto e = NDArrayFactory::create('c', {8, 8, 3, 2}, {0.54193264, 0.05176904, 0.82555761, 0.71106697, 0.04416722, 0.07653656, 0.06478678, 0.68985848, 0.55216783, 0.55382648, 0.34652863, 0.17261296, 0.61523204, 0.64298760, 0.26848351, 0.75015615, 0.28683049, 0.70937606, 0.38700677, 0.68832738, 0.37292716, 0.94616004, 0.77735792, 0.60803430, 0.54555996, 0.23407607, 0.11372584, 0.49965927, 0.15210842, 0.53268608, 0.67197708, 0.80659380, 0.98274191, 0.63566073, 0.21592326, 0.54902743, 0.09753360, 0.76124972, 0.24693797, 0.13813169, 0.33144656, 0.08299957, 0.01034390, 0.99430482, 0.59944390, 0.17973880, 0.36437840, 0.86383673, 0.93630291, 0.67277404, 0.93899264, 0.52422773, 0.44892176, 0.03127759,
0.45025550, 0.97136977, 0.13565978, 0.71567448, 0.92094825, 0.93536442, 0.85910449, 0.18252879, 0.72830945, 0.96736828, 0.89831575, 0.83437150, 0.59050780, 0.36145925, 0.16483070, 0.44021176, 0.76018652, 0.44227383, 0.13052339, 0.18204235, 0.99743733, 0.26885190, 0.87726522, 0.16396056, 0.94943412, 0.40016700, 0.65267938, 0.71073267, 0.40094733, 0.91182634, 0.05391789, 0.49520416, 0.24963864, 0.34847086, 0.74088617, 0.36115701, 0.63074210, 0.97423085, 0.42216846, 0.06326975, 0.07858702, 0.20586622, 0.34755773, 0.63166554, 0.18849320, 0.34828456, 0.98477707, 0.75163124, 0.33309570, 0.67563176, 0.98343578, 0.95919930, 0.66994391, 0.89296165, 0.28752144, 0.38146961, 0.83518735, 0.08207577, 0.82083487, 0.81665728, 0.83306004, 0.14203056, 0.01497920, 0.85727447, 0.71194544, 0.85654019, 0.86160433, 0.79580411, 0.47710411, 0.09318029, 0.31369071, 0.64122249, 0.58399725, 0.26706597, 0.05655339, 0.91025211, 0.30330468, 0.33142930, 0.05668627, 0.02936449, 0.12613087, 0.09960114, 0.16218074, 0.15088139, 0.31239040, 0.55980062, 0.34804391, 0.34941538, 0.61370555, 0.07022964, 0.27274571, 0.83306066, 0.75830824, 0.25963478, 0.87137718, 0.24418835, 0.59371493, 0.74479056, 0.84699625, 0.51210368, 0.12489571, 0.23371067, 0.18361641, 0.48636240, 0.06052657, 0.04241913, 0.66710351, 0.07007925, 0.60553664, 0.07536713, 0.55971796, 0.38764845, 0.20737843, 0.37989120, 0.59757058, 0.31189846, 0.25215345, 0.52546591, 0.55744218, 0.59485650, 0.05032742, 0.52076188, 0.47762345, 0.89829370, 0.34417708, 0.84705151, 0.08203183, 0.10632956, 0.78431292, 0.86441722, 0.36487598, 0.09833603, 0.85863594, 0.11010505, 0.11659283, 0.42500288, 0.02747301, 0.12359903, 0.19736489, 0.44461885, 0.33341706, 0.22519571, 0.31528710, 0.14802902, 0.01753431, 0.41160932, 0.47245979, 0.08268172, 0.21580773, 0.75770279, 0.64171939, 0.52643769, 0.19261234, 0.98032835, 0.15401656, 0.85274458, 0.66408502, 0.23212704, 0.74630026, 0.05713613, 0.49025892, 0.48418810, 0.59541513, 0.09243053, 0.93919152, 0.95357019, 0.52377729, 0.65963871, 0.47934951, 0.49919534, 0.34369898, 0.78211256, 0.13908708, 0.95754117, 0.84107746, 0.09126213, 0.42979124, 0.10295325, 0.34631257, 0.69448345, 0.41720536, 0.15282440, 0.74329854, 0.45775009, 0.12786280, 0.39830299, 0.20386769, 0.59703523, 0.94077086, 0.42255597, 0.80453309, 0.79757204, 0.28653229, 0.60175909, 0.55859623, 0.34318230, 0.63002770, 0.36533324, 0.89689906, 0.73236186, 0.61491989, 0.83787947, 0.67939463, 0.72016694, 0.77499849, 0.72428343, 0.34571059, 0.23143007, 0.20099338, 0.85583142, 0.73174191, 0.54284092, 0.20264181, 0.53037061, 0.30493131, 0.82279766, 0.58542432, 0.72632070, 0.18394258, 0.00608118, 0.23808232, 0.17007573, 0.75245459, 0.84990616, 0.38827634, 0.33809538, 0.01080317, 0.27250145, 0.81769542, 0.15323253, 0.71668395, 0.99427044, 0.11355576, 0.50511923, 0.22952055, 0.78271870, 0.12833592, 0.88639055, 0.76398188, 0.49533508, 0.47939640, 0.73564612, 0.41465671, 0.10995635, 0.20271728, 0.00521771, 0.67265260, 0.11917707, 0.76574855, 0.43842117, 0.28530411, 0.79648090, 0.79433656, 0.12074559, 0.02325163, 0.10117917, 0.83559239, 0.67213900, 0.25247084, 0.47968157, 0.88649124, 0.33588961, 0.92338319, 0.18808573, 0.60248266, 0.36610154, 0.99123140, 0.10519719, 0.18754650, 0.43232584, 0.85447872, 0.15937568, 0.92947480, 0.62705964, 0.85960084, 0.13435660, 0.81845809, 0.60715133, 0.83030708, 0.83071910, 0.38883408, 0.92033237, 0.59820890, 0.75527947, 0.67683355, 0.21847023, 0.29395619, 0.50477953, 0.98977921, 0.96225332, 0.90143562, 0.19559914, 0.08978307, 0.09687492, 0.07381865, 0.22801110, 
0.26669388, 0.99691302, 0.12113623, 0.34373057, 0.46066239, 0.48806761, 0.50688779, 0.00654483, 0.32076493, 0.42367646, 0.07112842, 0.54090558, 0.68230725, 0.49713828, 0.41958965, 0.68013847, 0.47691765, 0.63269259, 0.94304095, 0.54587271, 0.72447569, 0.28913523, 0.75766936, 0.52965692, 0.96854824, 0.15589071, 0.84128672, 0.16337522, 0.05771034, 0.21556356, 0.12094140, 0.29721207, 0.00811008, 0.66184926});
     nd4j::ops::reverse_sequence op;
-    auto results = op.execute({&input, &lengths}, {}, {1, 0});
+    auto results = op.evaluate({&input, &lengths}, {}, {1, 0});
     ASSERT_EQ(Status::OK(), results->status());
     auto z = results->at(0);
@@ -1150,7 +1150,7 @@ TEST_F(DeclarableOpsTests5, Test_TopK_0) {
     auto expI = NDArrayFactory::create('c', {2, 1}, {4, 3});
     nd4j::ops::top_k op;
-    auto result = op.execute({&x}, {}, {1, 0}); // without sorting
+    auto result = op.evaluate({&x}, {}, {1, 0}); // without sorting
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(2, result->size());
@@ -1177,7 +1177,7 @@ TEST_F(DeclarableOpsTests5, Test_TopK_0) {
     ASSERT_TRUE(expI.equalsTo(i));
     // repeat res again
     for (int cases = 0; cases < 100; ++cases) {
-        op.execute({&x}, {v, i}, {}, {1, 0}, {}); // without sorting
+        op.execute({&x}, std::vector{v, i}, {}, {1, 0}, {}); // without sorting
     }
     delete result;
 }
@@ -1189,7 +1189,7 @@ TEST_F(DeclarableOpsTests5, Test_TopK_1) {
     auto expI = NDArrayFactory::create('c', {2, 1}, {1, 0});
     nd4j::ops::top_k op;
-    auto result = op.execute({&x}, {}, {1, 0}); // without sorting
+    auto result = op.evaluate({&x}, {}, {1, 0}); // without sorting
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(2, result->size());
@@ -1216,7 +1216,7 @@ TEST_F(DeclarableOpsTests5, Test_TopK_1) {
     ASSERT_TRUE(expI.equalsTo(i));
     // repeat res again
     for (int cases = 0; cases < 100; ++cases) {
-        op.execute({&x}, {v, i}, {}, {1, 0}, {}); // without sorting
+        op.execute({&x}, std::vector{v, i}, {}, {1, 0}, {}); // without sorting
     }
     delete result;
 }
@@ -1242,7 +1242,7 @@ TEST_F(DeclarableOpsTests5, Test_TopK_2) {
     auto expI = NDArrayFactory::create('c', {2, 3, 1 }, {2, 1, 0, 1, 2, 0});
     nd4j::ops::top_k op;
-    auto result = op.execute({&x}, {}, {1, 1});
+    auto result = op.evaluate({&x}, {}, {1, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(2, result->size());
@@ -1292,7 +1292,7 @@ TEST_F(DeclarableOpsTests5, Test_TopK_3) {
     auto expI = NDArrayFactory::create('c', {2, 3, 2 }, {2, 0, 1, 3, 0, 3, 1, 3, 2, 1, 0, 2});
     nd4j::ops::top_k op;
-    auto result = op.execute({&x}, {}, {2, 1});
+    auto result = op.evaluate({&x}, {}, {2, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(2, result->size());
@@ -1342,7 +1342,7 @@ TEST_F(DeclarableOpsTests5, Test_TopK_3_unsorted) {
     auto expI = NDArrayFactory::create('c', {2, 3, 2 }, {0, 2, 1, 3, 0, 3, 1, 3, 1, 2, 0, 2});
     nd4j::ops::top_k op;
-    auto result = op.execute({&x}, {}, {2}, {false});
+    auto result = op.evaluate({&x}, {}, {2}, {false});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(2, result->size());
@@ -1366,7 +1366,7 @@ TEST_F(DeclarableOpsTests5, Test_TopK_4) {
     auto expI = NDArrayFactory::create('c', {2, 2}, {1, 2, 0, 2});
     nd4j::ops::top_k op;
-    auto result = op.execute({&x}, {}, {2, 1});
+    auto result = op.evaluate({&x}, {}, {2, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(2, result->size());
@@ -1390,7 +1390,7 @@ TEST_F(DeclarableOpsTests5, Test_TopK_5) {
     auto expI = NDArrayFactory::create('f', {2, 2}, {2, 1, 1, 2});
     nd4j::ops::top_k op;
-    auto result = op.execute({&x}, {}, {2, 1});
+    auto result = op.evaluate({&x}, {}, {2, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(2, result->size());
@@ -1428,7 +1428,7 @@ TEST_F(DeclarableOpsTests5, Test_Moments_1) {
     float inf = 1.e-5f;
     nd4j::ops::moments op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(2, result->size());
@@ -1459,7 +1459,7 @@ TEST_F(DeclarableOpsTests5, Test_Moments_2) {
     NDArray expD('c', {4}, {28.472221, 12.888889, 23.951387, 11.555554});
     nd4j::ops::moments op;
-    auto result = op.execute({&x}, {}, {0, 1});
+    auto result = op.evaluate({&x}, {}, {0, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(2, result->size());
@@ -1493,7 +1493,7 @@ TEST_F(DeclarableOpsTests5, Test_Moments_3) {
     6.25f, 9.f, 0.0625f, 16.f});
     nd4j::ops::moments op;
-    auto result = op.execute({&x}, {}, {0});
+    auto result = op.evaluate({&x}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(2, result->size());
@@ -1520,7 +1520,7 @@ TEST_F(DeclarableOpsTests5, Test_Moments_4) {
     auto expD = NDArrayFactory::create('c', {3, 4}, { 6.25f, 9.f, 27.5625f, 1.f, 6.25f, 4.f, 27.5625f, 1.f, 6.25f, 9.f, 0.0625f, 16.f});
     nd4j::ops::moments op;
-    auto result = op.execute({&x}, {}, {0});
+    auto result = op.evaluate({&x}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(2, result->size());
@@ -1551,7 +1551,7 @@ TEST_F(DeclarableOpsTests5, trace_test1) {
     auto exp = NDArrayFactory::create('c', {3}, {40, 120, 200});
     NDArray matrix('c', {3, 3}, {1., 2., 3., 4., 5., 6., 7., 8., 9.});
     nd4j::ops::trace op;
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
     auto output = results->at(0);
     double traceM = matrix.getTrace();
     // nd4j_printf("Trace for matrix is %f\n", traceM);
@@ -1572,7 +1572,7 @@ TEST_F(DeclarableOpsTests5, trace_test2) {
     auto exp = NDArrayFactory::create(40.);
     nd4j::ops::trace op;
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1590,7 +1590,7 @@ TEST_F(DeclarableOpsTests5, trace_test3) {
     auto exp = NDArrayFactory::create(1.);
     nd4j::ops::trace op;
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1608,7 +1608,7 @@ TEST_F(DeclarableOpsTests5, trace_test4) {
     auto exp = NDArrayFactory::create(1.);
     nd4j::ops::trace op;
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1626,7 +1626,7 @@ TEST_F(DeclarableOpsTests5, trace_test5) {
     auto exp = NDArrayFactory::create('c', {3, 4}, {75, 225, 375, 525, 675, 825, 975, 1125, 1275, 1425, 1575, 1725});
     nd4j::ops::trace op;
-    auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1643,7 +1643,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test1) {
     input.linspace(1);
     nd4j::ops::random_shuffle op;
-    auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input});
     auto output = results->at(0);
     bool haveZeros = false;
@@ -1666,7 +1666,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test2) {
     input.linspace(1);
     nd4j::ops::random_shuffle op;
-    auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1683,7 +1683,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test3) {
     input.linspace(1);
     nd4j::ops::random_shuffle op;
-    auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input});
     auto output = results->at(0);
     bool haveZeros = false;
@@ -1705,7 +1705,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test04) {
     nd4j::ops::random_shuffle op;
     //NDArray* output;
-    auto results = op.execute({&input}, {}, {}, {}, true, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input}, {}, {}, {}, {}, true);
     ASSERT_EQ(Status::OK(), results->status());
     auto output = &input; //results->at(0);
     bool haveZeros = false;
@@ -1727,7 +1727,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test4) {
     nd4j::ops::random_shuffle op;
     //NDArray* output;
-    auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input});
     ASSERT_EQ(Status::OK(), results->status());
     auto output = results->at(0);
     bool haveZeros = false;
@@ -1749,7 +1749,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test5) {
     input.linspace(1);
     nd4j::ops::random_shuffle op;
-    auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input});
     auto output = results->at(0);
     bool haveZeros = false;
@@ -1772,7 +1772,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test6) {
     input.linspace(1);
     nd4j::ops::random_shuffle op;
-    auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input});
     auto output = results->at(0);
     bool haveZeros = false;
@@ -1796,7 +1796,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test7) {
     auto exp = NDArrayFactory::create('c', {1,4}, {1, 2, 3, 4});
     nd4j::ops::random_shuffle op;
-    auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1826,11 +1826,11 @@ TEST_F(DeclarableOpsTests5, EmbeddingLookup_1) {
     // y.printIndexedBuffer("y buffer");
     nd4j::ops::embedding_lookup op;
-    auto result = op.execute({&x, &y}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &y}, {}, {0});
     auto output = result->at(0);
     // x.printShapeInfo("Input");
-    // output->printShapeInfo("Output");
-    // exp.printShapeInfo("Expected");
+    output->printShapeInfo("Output");
+    exp.printShapeInfo("Expected");
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_TRUE(exp.isSameShape(output));
     //output->printIndexedBuffer("Output");
@@ -1862,7 +1862,7 @@ TEST_F(DeclarableOpsTests5, EmbeddingLookup_2) {
     // y.printIndexedBuffer("y buffer");
     nd4j::ops::embedding_lookup op;
-    auto result = op.execute({&x, &y}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &y}, {}, {0});
     auto output = result->at(0);
     // x.printShapeInfo("Input");
     // output->printShapeInfo("Output");
@@ -1903,7 +1903,7 @@ TEST_F(DeclarableOpsTests5, EmbeddingLookup_3) {
     // res = tf.nn.embedding_lookup((p1, p2, p3, p4, p5, p6, p7), ids, 'mod')
     nd4j::ops::embedding_lookup op;
-    auto result = op.execute({&p1, &p2, &p3, &p4, &p5, &p6, &p7, &p8, &y}, {}, {1}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&p1, &p2, &p3, &p4, &p5, &p6, &p7, &p8, &y}, {}, {1});
     auto output = result->at(0);
     // x.printShapeInfo("Input");
     // output->printIndexedBuffer("Output");
@@ -1946,7 +1946,7 @@ TEST_F(DeclarableOpsTests5, DynamicPartition_01) {
     NDArrayFactory::create('c', {1}, {1})});
     nd4j::ops::dynamic_partition op;
-    auto result = op.execute({&x, &y}, {}, {numPartition});
+    auto result = op.evaluate({&x, &y}, {}, {numPartition});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(result->size(), numPartition); // result has the same size as given param 4
@@ -1985,7 +1985,7 @@ TEST_F(DeclarableOpsTests5, DynamicPartition_1) {
     NDArrayFactory::create('c', {10}, {13, 23, 14, 24, 15, 25, 16, 26, 17, 27})});
     nd4j::ops::dynamic_partition op;
-    auto result = op.execute({&x, &y}, {}, {numPartition});
+    auto result = op.evaluate({&x, &y}, {}, {numPartition});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(result->size(), numPartition); // result has the same size as given param 4
@@ -2015,7 +2015,7 @@ TEST_F(DeclarableOpsTests5, DynamicPartition_2) {
     nd4j::ops::dynamic_partition op;
     int numPartition = 4;
-    auto result = op.execute({&x, &y}, {}, {numPartition});
+    auto result = op.evaluate({&x, &y}, {}, {numPartition});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(result->size(), numPartition); // result has the same size as given param 4
@@ -2043,7 +2043,7 @@ TEST_F(DeclarableOpsTests5, DynamicPartition_3) {
     nd4j::ops::dynamic_partition op;
     int numPartition = 4;
-    auto result = op.execute({&x, &y}, {}, {numPartition});
+    auto result = op.evaluate({&x, &y}, {}, {numPartition});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_EQ(result->size(), numPartition); // result has the same size as given param 4
@@ -2078,7 +2078,7 @@ TEST_F(DeclarableOpsTests5, DynamicStitch_empty_1) {
     auto d2 = NDArrayFactory::create('c', {2, 5}, {0.94414854,0.5956861,0.8668989,0.3502196,0.5100082,0.061725974,0.6621324,0.034165382,0.32576954,0.51917326});
     nd4j::ops::dynamic_stitch op;
-    auto result = op.execute({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {});
+    auto result = op.evaluate({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     delete result;
@@ -2094,7 +2094,7 @@ TEST_F(DeclarableOpsTests5, DynamicStitch_empty_2) {
     auto d2 = NDArrayFactory::create('c', {2, 5}, {0.94414854,0.5956861,0.8668989,0.3502196,0.5100082,0.061725974,0.6621324,0.034165382,0.32576954,0.51917326});
     nd4j::ops::dynamic_stitch op;
-    auto result = op.execute({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {});
+    auto result = op.evaluate({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     delete result;
@@ -2113,7 +2113,7 @@ TEST_F(DeclarableOpsTests5, DynamicStitch_1) {
     auto exp = NDArrayFactory::create({7.4f, 0.1f, -1.f, 5.2f, -1.f, 4.3f});
     nd4j::ops::dynamic_stitch op;
-    auto result = op.execute({&x1, &x2, &y1, &y2}, {}, {});
+    auto result = op.evaluate({&x1, &x2, &y1, &y2}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -2138,7 +2138,7 @@ TEST_F(DeclarableOpsTests5, DynamicStitch_2) {
     auto exp = NDArrayFactory::create({5.2f, -1.f, 4.3f, -1.f, 7.4f, 0.1f});
     nd4j::ops::dynamic_stitch op;
-    auto result = op.execute({&x1, &x2, &y1, &y2}, {}, {});
+    auto result = op.evaluate({&x1, &x2, &y1, &y2}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -2170,7 +2170,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test1) {
     nd4j::ops::fused_batch_norm op;
-    auto results = op.execute({&x, &scale, &offset}, {}, {0,1});
+    auto results = op.evaluate({&x, &scale, &offset}, {}, {0,1});
     auto y = results->at(0);
     auto batchMean = results->at(1);
     auto batchVar = results->at(2);
@@ -2199,7 +2199,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test2) {
     auto expBatchVar = NDArrayFactory::create('c', {4}, {208.00001526, 208.00001526, 208.00001526, 208.00001526});
     nd4j::ops::fused_batch_norm op;
-    auto results = op.execute({&x, &scale, &offset}, {0.05}, {0,1});
+    auto results = op.evaluate({&x, &scale, &offset}, {0.05}, {0,1});
     auto y = results->at(0);
     auto batchMean = results->at(1);
     auto batchVar = results->at(2);
@@ -2228,7 +2228,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test3) {
     auto expBatchVar = NDArrayFactory::create('c', {4}, {208.00001526, 208.00001526, 208.00001526, 208.00001526});
     nd4j::ops::fused_batch_norm op;
-    auto results = op.execute({&x, &scale, &offset}, {}, {1,1});
+    auto results = op.evaluate({&x, &scale, &offset}, {}, {1,1});
     auto y = results->at(0);
     auto batchMean = results->at(1);
     auto batchVar = results->at(2);
@@ -2263,7 +2263,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test4) {
     nd4j::ops::fused_batch_norm op;
-    auto results = op.execute({&x, &scale, &offset}, {}, {0,1});
+    auto results = op.evaluate({&x, &scale, &offset}, {}, {0,1});
     auto y = results->at(0);
     auto batchMean = results->at(1);
     auto batchVar = results->at(2);
@@ -2298,7 +2298,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test5) {
     nd4j::ops::fused_batch_norm op;
-    auto results = op.execute({&x, &scale, &offset}, {0.05}, {0,1});
+    auto results = op.evaluate({&x, &scale, &offset}, {0.05}, {0,1});
     auto y = results->at(0);
     auto batchMean = results->at(1);
     auto batchVar = results->at(2);
@@ -2319,7 +2319,7 @@ TEST_F(DeclarableOpsTests5, confusion_matrix_test1) {
     auto expected = NDArrayFactory::create('c', {5, 5}, {0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1});
     nd4j::ops::confusion_matrix op;
-    auto results = op.execute({&labels, &predictions}, {}, {});
+    auto results = op.evaluate({&labels, &predictions}, {}, {});
     ASSERT_EQ(Status::OK(), results->status());
     auto output = results->at(0);
@@ -2338,7 +2338,7 @@ TEST_F(DeclarableOpsTests5, confusion_matrix_test2) {
     auto expected = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 1, 0, 0, 0, 0, 1});
     nd4j::ops::confusion_matrix op;
-    auto results = op.execute({&labels, &predictions}, {}, {3});
+    auto results = op.evaluate({&labels, &predictions}, {}, {3});
     ASSERT_EQ(Status::OK(), results->status());
     auto output = results->at(0);
@@ -2358,7 +2358,7 @@ TEST_F(DeclarableOpsTests5, confusion_matrix_test3) {
     auto expected = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 100, 0, 0, 0, 0, 200});
     nd4j::ops::confusion_matrix op;
-    auto results = op.execute({&labels, &predictions, &weights}, {}, {3});
+    auto results = op.evaluate({&labels, &predictions, &weights}, {}, {3});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2377,7 +2377,7 @@ TEST_F(DeclarableOpsTests5, confusion_matrix_test4) {
     auto expected = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 100, 0, 0, 0, 0, 200});
     nd4j::ops::confusion_matrix op;
-    auto results = op.execute({&labels, &predictions, &weights}, {}, {3, nd4j::DataType::DOUBLE});
+    auto results = op.evaluate({&labels, &predictions, &weights}, {}, {3, nd4j::DataType::DOUBLE});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2396,7 +2396,7 @@ TEST_F(DeclarableOpsTests5, ZeroFraction_1) {
     19, 0, 21, 22, 23, 24});
     nd4j::ops::zero_fraction op;
-    auto res = op.execute({&x}, {}, {});
+    auto res = op.evaluate({&x}, {}, {});
     ASSERT_EQ(Status::OK(), res->status());
     ASSERT_TRUE(res->at(0)->isScalar());
@@ -2411,7 +2411,7 @@ TEST_F(DeclarableOpsTests5, ZeroFraction_2) {
     auto x = NDArrayFactory::create('c', {2, 2, 2}, {5.5, 0., 0.3, 5.5, 8.6, 0., 0., 0.4});
     nd4j::ops::zero_fraction op;
-    auto res = op.execute({&x}, {}, {});
+    auto res = op.evaluate({&x}, {}, {});
     ASSERT_EQ(Status::OK(), res->status());
     ASSERT_TRUE(res->at(0)->isScalar());
@@ -2426,7 +2426,7 @@ TEST_F(DeclarableOpsTests5, ZeroFraction_3) {
     auto x = NDArrayFactory::create('f', {2, 2, 2}, {5.5, 0., 0.3, 5.5, 8.6, 0., 0., 0.4});
     nd4j::ops::zero_fraction op;
-    auto res = op.execute({&x}, {}, {});
+    auto res = op.evaluate({&x}, {}, {});
     ASSERT_EQ(Status::OK(), res->status());
     ASSERT_TRUE(res->at(0)->isScalar());
@@ -2445,7 +2445,7 @@ TEST_F(DeclarableOpsTests5, XWPlusB_1) {
     auto exp = NDArrayFactory::create('c', {2,2}, {173.f, 264.f, 310.f, 279.f});
     nd4j::ops::xw_plus_b op;
-    auto result = op.execute({&x, &y, &b}, {}, {});
+    auto result = op.evaluate({&x, &y, &b}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -2463,7 +2463,7 @@ TEST_F(DeclarableOpsTests5, StopGradient_1) {
     auto x = NDArrayFactory::create('c', {2,3}, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f});
     nd4j::ops::stop_gradient op;
-    auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -2486,7 +2486,7 @@ TEST_F(DeclarableOpsTests5, StopGradient_2) {
     auto x = NDArrayFactory::create('f', {2,3}, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f});
     nd4j::ops::stop_gradient op;
-    auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -2510,7 +2510,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test1) {
     auto expOutput = NDArrayFactory::create('c', {3, 3, 3}, {-2.16985e+00,-1.69846e-01,-3.16985e+00, -1.31507e+00,-6.31507e+00,-3.15072e-01, -8.00046e+00,-4.58767e-04,-9.00046e+00, -1.31327e+00,-1.23133e+01,-3.13266e-01, -1.40000e+01,-1.13743e-06,-1.50000e+01, -1.31326e+00,-1.83133e+01,-3.13262e-01, -2.00000e+01,-2.81941e-09,-2.10000e+01, -1.31326e+00,-2.43133e+01,-3.13262e-01, -2.73133e+01,-1.31326e+00,-3.13262e-01});
     nd4j::ops::log_softmax op;
-    auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2527,7 +2527,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test2) {
     auto expOutput = NDArrayFactory::create('c', {3, 3, 3}, {-3.05095e+00,-3.04946e+00,-5.00705e+00, -5.09458e-02,-7.04946e+00,-7.04851e-03, -6.05095e+00,-4.94556e-02,-8.00705e+00, -3.04859e+00,-1.30000e+01,-3.04859e+00, -1.50486e+01,-2.37286e-06,-1.70486e+01, -4.85876e-02,-1.60000e+01,-4.85874e-02, -2.10000e+01,-3.04859e+00,-2.51269e+01, -7.96007e-10,-2.50486e+01,-2.12693e+00, -2.40000e+01,-4.85874e-02,-1.26928e-01});
     nd4j::ops::log_softmax op;
-    auto results = op.execute({&input}, {}, {1}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input}, {}, {1});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2544,7 +2544,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test3) {
     auto expOutput = NDArrayFactory::create('c', {3, 3, 3}, {-2.16985e+00,-1.69846e-01,-3.16985e+00, -1.31507e+00,-6.31507e+00,-3.15072e-01, -8.00046e+00,-4.58767e-04,-9.00046e+00, -1.31327e+00,-1.23133e+01,-3.13266e-01, -1.40000e+01,-1.13743e-06,-1.50000e+01, -1.31326e+00,-1.83133e+01,-3.13262e-01, -2.00000e+01,-2.81941e-09,-2.10000e+01, -1.31326e+00,-2.43133e+01,-3.13262e-01, -2.73133e+01,-1.31326e+00,-3.13262e-01});
     nd4j::ops::log_softmax op;
-    auto results = op.execute({&input}, {}, {2}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input}, {}, {2});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2562,7 +2562,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test5) {
     auto expOutput = NDArrayFactory::create('c', {3, 3}, {-2.16985, -0.16985, -3.16985, -1.31507, -6.31507, -0.31507, -9.31335, -1.31335, -0.31335});
     nd4j::ops::log_softmax op;
-    auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2579,7 +2579,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test6) {
     auto expOutput = NDArrayFactory::create('c', {3, 3}, {-3.05095,-3.04946,-7.12773, -0.05095,-7.04946,-2.12773, -6.05095,-0.04946,-0.12773});
     nd4j::ops::log_softmax op;
-    auto results = op.execute({&input}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input}, {}, {0});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2596,7 +2596,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test7) {
     auto expOutput = NDArrayFactory::create('c', {1, 5}, {-4.42414, -2.42414, -5.42414, -1.42414, -0.42414});
     nd4j::ops::log_softmax op;
-    auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2613,7 +2613,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test8) {
     auto expOutput = NDArrayFactory::create('c', {1, 5}, {0, 0, 0, 0, 0});
     nd4j::ops::log_softmax op;
-    auto results = op.execute({&input}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input}, {}, {0});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2630,7 +2630,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test9) {
     auto expOutput = NDArrayFactory::create('c', {5, 1}, {0, 0, 0, 0, 0});
     nd4j::ops::log_softmax op;
-    auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2647,7 +2647,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test10) {
     auto expOutput = NDArrayFactory::create('c', {5, 1}, {-4.42414, -2.42414, -5.42414, -1.42414, -0.42414});
     nd4j::ops::log_softmax op;
-    auto results = op.execute({&input}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input}, {}, {0});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2664,7 +2664,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test11) {
     auto expOutput = NDArrayFactory::create('c', {5}, {-4.42414, -2.42414, -5.42414, -1.42414, -0.42414});
     nd4j::ops::log_softmax op;
-    auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2683,7 +2683,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test12) {
     for (int i = 0; i < 10; ++i) {
     nd4j::ops::log_softmax op;
-    auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2702,7 +2702,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_bp_test1) {
     auto exp = NDArrayFactory::create('c', {2, 2}, {-0.07311,0.02689, -0.07311,0.02689});
     nd4j::ops::log_softmax_bp op;
-    auto results = op.execute({&input, &epsilon}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input, &epsilon});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2720,7 +2720,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_bp_test2) {
     auto exp = NDArrayFactory::create('c', {2, 2}, {-0.17616, -0.17616, 0.02384, 0.02384});
     nd4j::ops::log_softmax_bp op;
-    auto results = op.execute({&input, &epsilon}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input, &epsilon}, {}, {0});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2749,7 +2749,7 @@ TEST_F(DeclarableOpsTests5, L2_Loss_1) {
     double exp(9.605);
     nd4j::ops::l2_loss op;
-    auto results = op.execute({&input}, {}, {});
+    auto results = op.evaluate({&input}, {}, {});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2765,7 +2765,7 @@ TEST_F(DeclarableOpsTests5, L2_Loss_2) {
     auto e = NDArrayFactory::create(0.303254);
     nd4j::ops::l2_loss op;
-    auto results = op.execute({&x}, {}, {});
+    auto results = op.evaluate({&x}, {}, {});
     ASSERT_EQ(Status::OK(), results->status());
     auto z = results->at(0);
@@ -2797,7 +2797,7 @@ TEST_F(DeclarableOpsTests5, LogPoissonLoss_1) {
     auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1.3678794, 5.389056, 2.981689, 1.6465969, 1.7182817, 5.389056, 5.389056, 1.7182817});
     nd4j::ops::log_poisson_loss op;
-    auto results = op.execute({&input, &weights, &targets}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input, &weights, &targets}, {}, {0});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2818,7 +2818,7 @@ TEST_F(DeclarableOpsTests5, LogPoissonLoss_2) {
     auto exp = NDArrayFactory::create('c', {2, 2, 2}, {3.0196857, 4.0408626, 2.1334953, 3.6984034, 1.3700882, 4.0408626, 4.0408626, 1.3700882});
     nd4j::ops::log_poisson_loss op;
-    auto results = op.execute({&input, &weights, &targets}, {}, {0, 1}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&input, &weights, &targets}, {}, {0, 1});
     auto output = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -2864,7 +2864,7 @@ TEST_F(DeclarableOpsTests5, NormalizeMoments_1) {
     -51., -10.75, -33.8125, -3.75});
     nd4j::ops::normalize_moments op;
-    auto results = op.execute({&counts, &means, &deviance}, {0.0}, {});
+    auto results = op.evaluate({&counts, &means, &deviance}, {0.0}, {});
     ASSERT_EQ(Status::OK(), results->status());
     ASSERT_EQ(results->size(), 2);
@@ -2915,7 +2915,7 @@ TEST_F(DeclarableOpsTests5, NormalizeMoments_2) {
     0.38888884, 1.0208334, 0.6927084, 1.076389});
     nd4j::ops::normalize_moments op;
-    auto results = op.execute({&counts, &means, &deviance}, {0.0}, {});
+    auto results = op.evaluate({&counts, &means, &deviance}, {0.0}, {});
     ASSERT_EQ(Status::OK(), results->status());
     ASSERT_EQ(results->size(), 2);
@@ -2966,7 +2966,7 @@ TEST_F(DeclarableOpsTests5, NormalizeMoments_3) {
     0.38888884, 1.0208334, 0.6927084, 1.076389});
     nd4j::ops::normalize_moments op;
-    auto results = op.execute({&counts, &means, &deviance}, {shift}, {});
+    auto results = op.evaluate({&counts, &means, &deviance}, {shift}, {});
     ASSERT_EQ(Status::OK(), results->status());
     ASSERT_EQ(results->size(), 2);
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp
index c52191b8a..5be2eeebd 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp
@@ -50,7 +50,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_1) {
     matrix.linspace(1);
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
+    auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -71,7 +71,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_2) {
     matrix.linspace(1);
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
+    auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -92,7 +92,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_3) {
     //matrix.linspace(1);
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
+    auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -114,7 +114,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) {
     //matrix.linspace(1);
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
+    auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -133,7 +133,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_04) {
     auto s = NDArrayFactory::create_('c', {1}, {1});
     nd4j::ops::ones_as opOnes;
     //auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f});
-    auto onesRes = opOnes.execute({&matrix}, {}, {});
+    auto onesRes = opOnes.evaluate({&matrix});
     //matrix.linspace(1);
     ASSERT_EQ(onesRes->status(), Status::OK());
@@ -181,7 +181,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_5) {
     //matrix.linspace(1);
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
+    auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -201,7 +201,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_6) {
     //matrix.linspace(1);
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 2});
+    auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 2});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -222,7 +222,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_7) {
     //matrix.linspace(1);
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&matrix, &b, &e, &s}, {}, {1, 0, 0, 0, 0});
+    auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {1, 0, 0, 0, 0});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -244,7 +244,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) {
     grad.linspace(1);
     nd4j::ops::strided_slice_bp op;
-    auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1});
+    auto result = op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -266,7 +266,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) {
     //grad.linspace(1);
     nd4j::ops::strided_slice_bp op;
-    auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1});
+    auto result = op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -288,7 +288,7 @@ TEST_F(DeclarableOpsTests6,
     grad.linspace(1);
     nd4j::ops::strided_slice_bp op;
-    auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 0, 0, 0, 0, 256, 1, 1});
+    auto result = op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 0, 0, 0, 0, 256, 1, 1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -302,7 +302,7 @@ TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) {
     auto exp = NDArrayFactory::create('c', {1, 1}, {4.0f});
     nd4j::ops::test_scalar op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
@@ -321,7 +321,7 @@ TEST_F(DeclarableOpsTests6, Test_Order_1) {
     exp.linspace(1);
     nd4j::ops::order op;
-    auto result = op.execute({&x}, {}, {0});
+    auto result = op.evaluate({&x}, {}, {0});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -336,7 +336,7 @@ TEST_F(DeclarableOpsTests6, cumSum_1) {
     auto exp = NDArrayFactory::create('c', {1, 4}, {1.f, 3.f, 6.f, 10.f});
     nd4j::ops::cumsum op;
-    auto result = op.execute({&x}, {}, {0, 0});
+    auto result = op.evaluate({&x}, {}, {0, 0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -352,7 +352,7 @@ TEST_F(DeclarableOpsTests6, cumSum_2) {
     auto exp= NDArrayFactory::create('c', {2, 4}, {1.f, 3.f, 6.f, 10.f, 1.f, 3.f, 6.f, 10.f});
     nd4j::ops::cumsum op;
-    auto result = op.execute({&x}, {}, {0, 0, 1});
+    auto result = op.evaluate({&x}, {}, {0, 0, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -369,7 +369,7 @@ TEST_F(DeclarableOpsTests6, cumSum_3) {
     auto exp= NDArrayFactory::create('c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 2.f, 4.f, 6.f, 8.f});
     nd4j::ops::cumsum op;
-    auto result = op.execute({&x}, {}, {0, 0, 0});
+    auto result = op.evaluate({&x}, {}, {0, 0, 0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -385,7 +385,7 @@ TEST_F(DeclarableOpsTests6, cumSum_4) {
     auto exp = NDArrayFactory::create('c', {3, 3}, {12., 15., 18., 11., 13., 15., 7., 8., 9.});
     nd4j::ops::cumsum op;
-    auto result = op.execute({&x}, {}, {0, 1, 0}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {}, {0, 1, 0}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -401,7 +401,7 @@ TEST_F(DeclarableOpsTests6, cumSum_5) {
     auto exp = NDArrayFactory::create('c', {3, 3}, {6.f, 5.f, 3.f, 15.f, 11.f, 6.f, 24.f, 17.f, 9.f,});
     nd4j::ops::cumsum op;
-    auto result = op.execute({&x}, {}, {0, 1, 1}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {}, {0, 1, 1}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -416,7 +416,7 @@ TEST_F(DeclarableOpsTests6, cumSum_6) {
     auto exp = NDArrayFactory::create('c', {3, 3}, {11.f, 13.f, 15.f, 7.f, 8.f, 9.f, 0.f, 0.f, 0.f});
     nd4j::ops::cumsum op;
-    auto result = op.execute({&x}, {}, {1, 1, 0}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {}, {1, 1, 0}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -431,7 +431,7 @@ TEST_F(DeclarableOpsTests6, cumSum_7) {
     auto exp = NDArrayFactory::create('c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f});
     nd4j::ops::cumsum op;
-    auto result = op.execute({&x}, {}, {1, 1, 1}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {}, {1, 1, 1}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -447,7 +447,7 @@ TEST_F(DeclarableOpsTests6, cumSum_8) {
     auto exp = NDArrayFactory::create('c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f});
     nd4j::ops::cumsum op;
-    auto result = op.execute({&x, &axis}, {}, {1, 1}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &axis}, {}, {1, 1}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -475,7 +475,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
     exclusive = 0; reverse = 0;
     nd4j::ops::cumsum op;
-    auto result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
     ASSERT_TRUE(expFF.equalsTo(z));
@@ -484,7 +484,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
     //************************************//
     exclusive = 1; reverse = 0;
-    result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE);
+    result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});
     ASSERT_EQ(Status::OK(), result->status());
     z = result->at(0);
     ASSERT_TRUE(expTF.equalsTo(z));
@@ -493,7 +493,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
     //************************************//
     exclusive = 0; reverse = 1;
-    result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE);
+    result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});
     ASSERT_EQ(Status::OK(), result->status());
     z = result->at(0);
     ASSERT_TRUE(expFT.equalsTo(z));
@@ -502,7 +502,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
     //************************************//
     exclusive = 1; reverse = 1;
-    result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE);
+    result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});
     ASSERT_EQ(Status::OK(), result->status());
     z = result->at(0);
     ASSERT_TRUE(expTT.equalsTo(z));
@@ -516,7 +516,7 @@ TEST_F(DeclarableOpsTests6, cumSum_10) {
     auto y = NDArrayFactory::create(-3);
     nd4j::ops::cumsum op;
-    auto result = op.execute({&x, &y}, {}, {1, 1}, {});
+    auto result = op.evaluate({&x, &y}, {}, {1, 1});
     ASSERT_EQ(Status::OK(), result->status());
     delete result;
@@ -531,7 +531,7 @@ TEST_F(DeclarableOpsTests6, cumSum_11) {
     x.linspace(1);
     nd4j::ops::cumsum op;
-    auto result = op.execute({&x}, {}, {0, 1, 1}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {}, {0, 1, 1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -550,7 +550,7 @@ TEST_F(DeclarableOpsTests6, cumSum_12) {
     x.linspace(1);
     nd4j::ops::cumsum op;
-    auto result = op.execute({&x}, {}, {0, 0, 1}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {}, {0, 0, 1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -569,7 +569,7 @@ TEST_F(DeclarableOpsTests6, cumSum_13) {
     x.linspace(1);
     nd4j::ops::cumsum op;
-    auto result = op.execute({&x}, {}, {1, 1, 1}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {}, {1, 1, 1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -588,7 +588,7 @@ TEST_F(DeclarableOpsTests6, cumSum_14) {
     x.linspace(1);
     nd4j::ops::cumsum op;
-    auto result = op.execute({&x}, {}, {1, 1, 0}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {}, {1, 1, 0});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -607,7 +607,7 @@ TEST_F(DeclarableOpsTests6, cumSum_15) {
     x.linspace(1);
     nd4j::ops::cumsum op;
-    auto result = op.execute({&x}, {}, {0, 1, 2}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {}, {0, 1, 2});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -623,7 +623,7 @@ TEST_F(DeclarableOpsTests6, cumSum_16) {
     NDArray x('f', {3, 4}, nd4j::DataType::FLOAT32);
     nd4j::ops::cumsum op;
-    auto result = op.execute({&x}, {}, {0, 0, 1});
+    auto result = op.evaluate({&x}, {}, {0, 0, 1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -659,7 +659,7 @@ TEST_F(DeclarableOpsTests6, cumSum_17) {
     }
     nd4j::ops::cumsum op;
-    auto result = op.execute({&x}, {}, {0, 0, 1});
+    auto result = op.evaluate({&x}, {}, {0, 0, 1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -692,7 +692,7 @@ TEST_F(DeclarableOpsTests6, cumSum_18) {
     }
     nd4j::ops::cumsum op;
-    auto result = op.execute({&x}, {}, {1, 0, 1});
+    auto result = op.evaluate({&x}, {}, {1, 0, 1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -725,7 +725,7 @@ TEST_F(DeclarableOpsTests6, cumSum_19) {
     }
     nd4j::ops::cumsum op;
-    auto result = op.execute({&x}, {}, {0, 1, 1});
+    auto result = op.evaluate({&x}, {}, {0, 1, 1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -759,7 +759,7 @@ TEST_F(DeclarableOpsTests6, cumSum_20) {
     }
     nd4j::ops::cumsum op;
-    auto result = op.execute({&x}, {}, {1, 1, 1});
+    auto result = op.evaluate({&x}, {}, {1, 1, 1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -778,7 +778,7 @@ TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_1) {
     auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2});
     nd4j::ops::mergemaxindex op;
-    auto ress = op.execute({&x, &y, &z}, {}, {}, {});
+    auto ress = op.evaluate({&x, &y, &z}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, ress->status());
     // ress->at(0)->printIndexedBuffer("MergeMaxIndex Result is ");
@@ -797,7 +797,7 @@ TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_2) {
     auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2});
     nd4j::ops::mergemaxindex op;
-    auto ress = op.execute({&x, &y, &z}, {}, {nd4j::DataType::INT64}, {});
+    auto ress = op.evaluate({&x, &y, &z}, {}, {nd4j::DataType::INT64});
    ASSERT_EQ(ND4J_STATUS_OK, ress->status());
     // ress->at(0)->printIndexedBuffer("MergeMaxIndex2 Result is ");
@@ -814,7 +814,7 @@ TEST_F(DeclarableOpsTests6, TestDropout_1) {
     auto shape = NDArrayFactory::create({2, 2});
     nd4j::ops::dropout op;
-    auto ress = op.execute({&x, &shape}, {0.2f}, {113}, {}, false, nd4j::DataType::DOUBLE);
+    auto ress = op.evaluate({&x, &shape}, {0.2f}, {113});
     ASSERT_EQ(ND4J_STATUS_OK, ress->status());
     //ress->at(0)->printIndexedBuffer("Result is ");
@@ -830,7 +830,7 @@ TEST_F(DeclarableOpsTests6, TestMod_1) {
     auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 0, 3, 0, 5, 0, 7, 0});
     nd4j::ops::mod op;
-    auto ress = op.execute({&x, &y}, {}, {}, {});
+    auto ress = op.evaluate({&x, &y});
     ASSERT_EQ(ND4J_STATUS_OK, ress->status());
     // ress->at(0)->printIndexedBuffer("MOD Result is ");
@@ -848,7 +848,7 @@ TEST_F(DeclarableOpsTests6, TestMod_BP_1) {
     auto exp = NDArrayFactory::create('c', {2, 2, 2});
     nd4j::ops::mod_bp op;
-    auto ress = op.execute({&x, &y, &eps}, {}, {}, {});
+    auto ress = op.evaluate({&x, &y, &eps});
     ASSERT_EQ(ND4J_STATUS_OK, ress->status());
     // ress->at(0)->printIndexedBuffer("MOD_BP Result is ");
@@ -867,7 +867,7 @@ TEST_F(DeclarableOpsTests6, TestRank_1) {
     auto exp = NDArrayFactory::create(3);
     nd4j::ops::rank op;
-    auto ress = op.execute({&x}, {}, {}, {});
+    auto ress = op.evaluate({&x});
     ASSERT_EQ(ND4J_STATUS_OK, ress->status());
@@ -881,7 +881,7 @@ TEST_F(DeclarableOpsTests6, TestDropout_2) {
     nd4j::ops::dropout op;
-    auto ress = op.execute({&x}, {0.4f}, {113}, {}, false, nd4j::DataType::DOUBLE);
+    auto ress = op.evaluate({&x}, {0.4f}, {113});
     ASSERT_EQ(ND4J_STATUS_OK, ress->status());
@@ -896,7 +896,7 @@ TEST_F(DeclarableOpsTests6, TestDropout_3) {
     nd4j::ops::dropout op;
-    auto ress = op.execute({&x, &shape}, {0.4f}, {113}, {}, false, nd4j::DataType::DOUBLE);
+    auto ress = op.evaluate({&x, &shape}, {0.4f}, {113});
     ASSERT_EQ(ND4J_STATUS_OK, ress->status());
@@ -913,7 +913,7 @@ TEST_F(DeclarableOpsTests6, MaxPoolWithArgmax_1) {
     nd4j::ops::max_pool_with_argmax op;
-    auto ress = op.execute({&x}, {}, {1,1,1,1,1,1,1,1,1});
+    auto ress = op.evaluate({&x}, {}, {1,1,1,1,1,1,1,1,1});
     ASSERT_EQ(ND4J_STATUS_OK, ress->status());
     ASSERT_TRUE(expI.isSameShape(ress->at(0)));
@@ -942,7 +942,7 @@ TEST_F(DeclarableOpsTests6, SufficientStatistics_1) {
     nd4j::ops::sufficient_statistics op;
-    auto ress = op.execute({&x, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto ress = op.evaluate({&x, &axis});
     ASSERT_EQ(ND4J_STATUS_OK, ress->status());
     ASSERT_EQ(ress->at(0)->e(0), count);
@@ -974,7 +974,7 @@ TEST_F(DeclarableOpsTests6, SufficientStatistics_2) {
     nd4j::ops::sufficient_statistics op;
-    auto ress = op.execute({&x, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto ress = op.evaluate({&x, &axis});
     ASSERT_EQ(ND4J_STATUS_OK, ress->status());
     ASSERT_EQ(ress->at(0)->e(0), count);
@@ -996,7 +996,7 @@ TEST_F(DeclarableOpsTests6, BinCount_1) {
     nd4j::ops::bincount op;
-    auto res = op.execute({&x}, {}, {});
+    auto res = op.evaluate({&x});
     ASSERT_EQ(ND4J_STATUS_OK, res->status());
     ASSERT_TRUE(exp.equalsTo(res->at(0)));
@@ -1021,7 +1021,7 @@ TEST_F(DeclarableOpsTests6, BinCount_2) {
     nd4j::ops::bincount op;
-    auto res = op.execute({&x, &weights}, {}, {});
+    auto res = op.evaluate({&x, &weights});
     ASSERT_EQ(ND4J_STATUS_OK, res->status());
     ASSERT_TRUE(exp.equalsTo(res->at(0)));
@@ -1046,7 +1046,7 @@ TEST_F(DeclarableOpsTests6, BinCount_3) {
     nd4j::ops::bincount op;
-    auto res = op.execute({&x, &weights}, {}, {0, 2});
+    auto res = op.evaluate({&x, &weights}, {}, {0, 2});
     ASSERT_EQ(ND4J_STATUS_OK, res->status());
     ASSERT_TRUE(exp.equalsTo(res->at(0)));
@@ -1071,7 +1071,7 @@ TEST_F(DeclarableOpsTests6, BinCount_4) {
     nd4j::ops::bincount op;
-    auto res = op.execute({&x, &weights}, {}, {4, 4});
+    auto res = op.evaluate({&x, &weights}, {}, {4, 4});
     ASSERT_EQ(ND4J_STATUS_OK, res->status());
     ASSERT_TRUE(exp.equalsTo(res->at(0)));
@@ -1097,7 +1097,7 @@ TEST_F(DeclarableOpsTests6, BinCount_5) {
     nd4j::ops::bincount op;
-    auto res = op.execute({&x, &weights, &minV, &maxV}, {}, {});
+    auto res = op.evaluate({&x, &weights, &minV, &maxV});
     ASSERT_EQ(ND4J_STATUS_OK, res->status());
     // res->at(0)->printBuffer("BC out");
     ASSERT_TRUE(exp.equalsTo(res->at(0)));
@@ -1116,7 +1116,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_1) {
     nd4j::ops::broadcast_dynamic_shape op;
-    auto res = op.execute({&x, &y}, {}, {});
+    auto res = op.evaluate({&x, &y});
     ASSERT_EQ(ND4J_STATUS_OK, res->status());
     ASSERT_TRUE(exp.equalsTo(res->at(0)));
@@ -1135,7 +1135,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_2) {
     nd4j::ops::broadcast_dynamic_shape op;
-    auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
+    auto res = op.evaluate({&x, &y});
     ASSERT_EQ(ND4J_STATUS_OK, res->status());
     ASSERT_TRUE(exp.equalsTo(res->at(0)));
@@ -1153,7 +1153,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_3) {
     nd4j::ops::broadcast_dynamic_shape op;
-    auto res = op.execute({&x, &y}, {}, {}, {});
+    auto res = op.evaluate({&x, &y}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, res->status());
     ASSERT_TRUE(exp.equalsTo(res->at(0)));
@@ -1172,7 +1172,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_4) {
     nd4j::ops::broadcast_dynamic_shape op;
-    auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
+    auto res = op.evaluate({&x, &y});
     ASSERT_EQ(ND4J_STATUS_OK, res->status());
     //res->at(0)->printBuffer("Shape SGO 4");
@@ -1191,7 +1191,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_6) {
     auto exp = NDArrayFactory::create({2, 2, 4});
     nd4j::ops::broadcast_dynamic_shape op;
-    auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
+    auto res = op.evaluate({&x, &y});
     ASSERT_EQ(ND4J_STATUS_OK, res->status());
     ASSERT_TRUE(exp.equalsTo(res->at(0)));
@@ -1209,7 +1209,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_7) {
     auto exp = NDArrayFactory::create({2, 4, 3});
     nd4j::ops::broadcast_dynamic_shape op;
-    auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
+    auto res = op.evaluate({&x, &y});
     ASSERT_EQ(ND4J_STATUS_OK, res->status());
     ASSERT_TRUE(exp.equalsTo(res->at(0)));
@@ -1274,7 +1274,7 @@ TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_1) {
     // auto expNorm(8.660254);
     nd4j::ops::clip_by_global_norm op;
-    auto result = op.execute({&x}, {0.8}, {});
+    auto result = op.evaluate({&x}, {0.8}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1316,7 +1316,7 @@ TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_2) {
     );
     nd4j::ops::clip_by_global_norm op;
-    auto result = op.execute({&x, &a}, {1.8}, {});
+    auto result = op.evaluate({&x, &a}, {1.8}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1346,7 +1346,7 @@ TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_3) {
     );
     nd4j::ops::clip_by_global_norm op;
-    auto result = op.execute({&x, &a}, {0.8}, {});
+    auto result = op.evaluate({&x, &a}, {0.8}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1372,7 +1372,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_1) {
     auto exp = NDArrayFactory::create({36.0, -48.0});
     nd4j::ops::matrix_determinant op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1393,7 +1393,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_2) {
     auto exp = NDArrayFactory::create({-2.0, -2.0});
     nd4j::ops::matrix_determinant op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1414,7 +1414,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_3) {
     NDArray exp('c', {1}, {-54.0});
     nd4j::ops::matrix_determinant op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1435,7 +1435,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_4) {
     auto exp = NDArrayFactory::create('c', {1}, {189.0});
     nd4j::ops::matrix_determinant op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1459,7 +1459,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_5) {
     x.p(12, 12.0);
     nd4j::ops::matrix_determinant op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1483,7 +1483,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_6) {
     x.p(12, 12.0);
     nd4j::ops::matrix_determinant op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1505,7 +1505,7 @@ TEST_F(DeclarableOpsTests6, LogMatrixDeterminant_1) {
     auto exp = NDArrayFactory::create({3.58351893845611, 3.871201010907891});
     nd4j::ops::log_matrix_determinant op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1524,7 +1524,7 @@ TEST_F(DeclarableOpsTests6, LogDet_1) {
     auto exp = NDArrayFactory::create({ 3.5835189, 4.159008});
     nd4j::ops::logdet op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1542,7 +1542,7 @@ TEST_F(DeclarableOpsTests6, LogDet_2) {
     auto exp = NDArrayFactory::create('c', {1}, { 3.5835189});
     nd4j::ops::logdet op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1561,7 +1561,7 @@ TEST_F(DeclarableOpsTests6, LogDet_3) {
     auto exp = NDArrayFactory::create( 3.5835189);
     nd4j::ops::logdet op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1605,7 +1605,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
     });
     nd4j::ops::matrix_inverse op;
-    auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
+    auto result = op.evaluate({&x});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1624,7 +1624,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_010) {
     auto exp = NDArrayFactory::create('c', {1, 5, 5}, {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f});
     nd4j::ops::matrix_inverse op;
-    auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
+    auto result = op.evaluate({&x});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1643,7 +1643,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_01) {
     auto exp = NDArrayFactory::create('c', {1, 5, 5}, {0.5f, -2.0f, -13.0f, 54.0f, -6.75f, 0.0f, 1.0f, -1.0f, 1.0f, 0.0f, 0.f, 0.f, 0.5f, -2.0f, 0.25f, 0.f, 0.f, 0.f, 1.0f, -0.5f, 0.f, 0.f, 0.f, 0.f, 0.25f });
     nd4j::ops::matrix_inverse op;
-    auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
+    auto result = op.evaluate({&x});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1662,7 +1662,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_02) {
     auto exp = NDArrayFactory::create('c', {1, 5, 5}, {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f });
     nd4j::ops::matrix_inverse op;
-    auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
+    auto result = op.evaluate({&x});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1707,7 +1707,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_2) {
     });
     nd4j::ops::matrix_inverse op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1740,7 +1740,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_03) {
     });
     nd4j::ops::matrix_inverse op;
-    auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
+    auto result = op.evaluate({&x});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1774,7 +1774,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
     });
     nd4j::ops::matrix_inverse op;
-    auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
+    auto result = op.evaluate({&x});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1808,7 +1808,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
     });
     nd4j::ops::matrix_inverse op;
-    auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
+    auto result = op.evaluate({&x});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1842,7 +1842,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_04) {
     });
     nd4j::ops::matrix_inverse op;
-    auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
+    auto result = op.evaluate({&x});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1870,7 +1870,7 @@ TEST_F(DeclarableOpsTests6, ReluLayer_1) {
     26.2, 31.65, 60.7});
     nd4j::ops::relu_layer op;
-    auto result = op.execute({&x, &w, &b}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &w, &b});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -1923,7 +1923,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test1) {
     auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, 0.93751527, 0.93751527});
     nd4j::ops::static_rnn op;
-    auto results = op.execute({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});
+    auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -1966,7 +1966,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test2) {
     auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.98000654, 0.98000654, 0.98000654, 0.98000654,0.98112648, 0.98112648, 0.98112648, 0.98112648});
     nd4j::ops::static_rnn op;
-    auto results = op.execute({&x, &Wx, &Wh, &b, &h0}, {}, {});
+    auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2009,7 +2009,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test3) {
     auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.2 , 0.2 , 0.2 , 0.2});
     nd4j::ops::static_rnn op;
-    auto results = op.execute({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});
+    auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2051,7 +2051,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test4) {
     auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97688859, 0.97688859, 0.97688859, 0.97688859, 0.88400882, 0.88400882, 0.88400882, 0.88400882});
     nd4j::ops::static_rnn op;
-    auto results = op.execute({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {});
+    auto results = op.evaluate({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2093,7 +2093,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test5) {
     auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97997868, 0.97997868, 0.97997868, 0.97997868, 0.98110653, 0.98110653, 0.98110653, 0.98110653});
     nd4j::ops::static_rnn op;
-    auto results = op.execute({&x, &Wx, &Wh, &b}, {}, {});
+    auto results = op.evaluate({&x, &Wx, &Wh, &b}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2144,7 +2144,7 @@ TEST_F(DeclarableOpsTests6, static_bidir_rnn_test1) {
     auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, 0.55529176, 0.55529176, 0.55529176, 0.25, 0.25, 0.25});
     nd4j::ops::static_bidirectional_rnn op;
-    auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {});
+    auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2197,7 +2197,7 @@
     auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.86518273, 0.86518273, 0.86518273, 0.66617761, 0.66617761, 0.66617761, 0.31492203, 0.31492203, 0.31492203, 0. , 0. , 0.});
     nd4j::ops::static_bidirectional_rnn op;
-    auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &maxTimeStep}, {}, {});
+    auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &maxTimeStep}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2250,7 +2250,7 @@ TEST_F(DeclarableOpsTests6, static_bidir_rnn_test3) {
     auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.86841012, 0.86841012, 0.86841012, 0.88207531, 0.88207531, 0.88207531, 0.8941667 , 0.8941667 , 0.8941667 , 0.90489713, 0.90489713, 0.90489713});
     nd4j::ops::static_bidirectional_rnn op;
-    auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW}, {}, {});
+    auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2296,7 +2296,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test1) {
     auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, 0.93751527, 0.93751527});
     nd4j::ops::dynamic_rnn op;
-    auto results = op.execute({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {1});
+    auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2341,7 +2341,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test2) {
     auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97309129, 0.97309129, 0.97309129, 0.97309129, 0.98120782, 0.98120782, 0.98120782, 0.98120782});
     nd4j::ops::dynamic_rnn op;
-    auto results = op.execute({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});
+    auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2383,7 +2383,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test3) {
     auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97491207, 0.97491207, 0.97491207, 0.97491207, 0.98120782, 0.98120782, 0.98120782, 0.98120782});
     nd4j::ops::dynamic_rnn op;
-    auto results = op.execute({&x, &Wx, &Wh, &b, &h0}, {}, {});
+    auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2424,7 +2424,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test4) {
     auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.9724738 , 0.9724738 , 0.9724738 , 0.9724738 ,0.57368608, 0.57368608, 0.57368608, 0.57368608});
     nd4j::ops::dynamic_rnn op;
-    auto results = op.execute({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {});
+    auto results = op.evaluate({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2465,7 +2465,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test5) {
     auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97486307, 0.97486307, 0.97486307, 0.97486307,0.98119833, 0.98119833, 0.98119833, 0.98119833});
     nd4j::ops::dynamic_rnn op;
-    auto results = op.execute({&x, &Wx, &Wh, &b}, {}, {});
+    auto results = op.evaluate({&x, &Wx, &Wh, &b}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2521,7 +2521,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test1) {
     auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, 0.55529176, 0.55529176, 0.55529176, 0.25 , 0.25 , 0.25});
     nd4j::ops::dynamic_bidirectional_rnn op;
-    auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {1}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2581,7 +2581,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test2) {
     auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.84345207, 0.84345207, 0.84345207, 0.85615841, 0.85615841, 0.85615841, 0.76576202, 0.76576202, 0.76576202, 0.25 , 0.25 , 0.25});
     nd4j::ops::dynamic_bidirectional_rnn op;
-    auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {});
+    auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2637,7 +2637,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test3) {
     auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.82273707, 0.82273707, 0.82273707, 0.77843476, 0.77843476, 0.77843476, 0.61067683, 0.61067683, 0.61067683, 0. , 0. , 0.});
     nd4j::ops::dynamic_bidirectional_rnn op;
-    auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &maxTimeStep}, {}, {});
+    auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &maxTimeStep}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2696,7 +2696,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test4) {
     auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.85301722, 0.85301722, 0.85301722, 0.91888753, 0.91888753, 0.91888753, 0.95254269, 0.95254269, 0.95254269, 0.97154357, 0.97154357, 0.97154357});
     nd4j::ops::dynamic_bidirectional_rnn op;
-    auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW}, {}, {});
+    auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2749,7 +2749,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test5) {
     auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.84882345, 0.84882345, 0.84882345, 0.91865453, 0.91865453, 0.91865453, 0.95252666, 0.95252666, 0.95252666, 0.97154234, 0.97154234, 0.97154234});
     nd4j::ops::dynamic_bidirectional_rnn op;
-    auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW}, {}, {});
+    auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2776,7 +2776,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_1) {
     auto e = NDArrayFactory::create('c', {3, 3}, {0.15f, 0.0f, 0.0f, 0.0f, 0.25f, 0.0f, 0.0f, 0.0f, 0.35f});
     nd4j::ops::diag op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     ASSERT_EQ(e, *result->at(0));
@@ -2789,7 +2789,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_2) {
     auto e = NDArrayFactory::create('c', {1, 1}, {0.15f});
     nd4j::ops::diag op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     ASSERT_EQ(e, *result->at(0));
@@ -2802,7 +2802,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_3) {
     auto e = NDArrayFactory::create('c', {1, 1}, {0.15f});
     nd4j::ops::diag op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     ASSERT_EQ(e, *result->at(0));
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp
index ffb847dbd..39761ecb3 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp
@@ -61,7 +61,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_LARGE) {
     auto x = NDArrayFactory::create(inputData,'c',{1,149});
     nd4j::ops::choose op;
     //greater than test
-    auto result = op.execute({&x}, {0.0},{3});
+    auto result = op.evaluate({&x}, {0.0},{3});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(1);
@@ -84,7 +84,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_ZERO) {
     auto x = NDArrayFactory::create('c',{1,4},data);
     nd4j::ops::choose op;
     //greater than test
-    auto result = op.execute({&x}, {0.0},{3});
+    auto result = op.evaluate({&x}, {0.0},{3});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(1);
@@ -109,7 +109,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR) {
     auto scalar = NDArrayFactory::create('c',{1,1},{0.0});
     nd4j::ops::choose op;
     //greater than test
-    auto result = op.execute({&x,&scalar}, {1.0},{3});
+    auto result = op.evaluate({&x,&scalar}, {1.0},{3});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -133,7 +133,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_LEFT) {
     auto scalar = NDArrayFactory::create('c',{1,1},{0.0});
     nd4j::ops::choose op;
     //greater than test
-    auto result = op.execute({&scalar,&x}, {1.0},{3});
+    auto result = op.evaluate({&scalar,&x}, {1.0},{3});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -156,7 +156,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_ONLY_SCALAR) {
     auto x = NDArrayFactory::create('c',{1,4},data);
     nd4j::ops::choose op;
     //greater than test
-    auto result = op.execute({&x}, {1.0},{3});
+    auto result = op.evaluate({&x}, {1.0},{3});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -179,7 +179,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_ONLY_SCALAR_GTE) {
     auto x = NDArrayFactory::create('c',{1,4},data);
     nd4j::ops::choose op;
     //greater than test
-    auto result = op.execute({&x}, {1.0},{5});
+    auto result = op.evaluate({&x}, {1.0},{5});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -223,7 +223,7 @@ TEST_F(DeclarableOpsTests7, TEST_WHERE) {
     //greater than test
     // Nd4jStatus execute(std::initializer_list*> inputs, std::initializer_list*> outputs , std::initializer_list tArgs, std::initializer_list iArgs, bool isInplace = false);
-    auto result = op.execute({&maskArr,&x,&putArr},{&resultArr}, {},{3}, {}, false);
+    auto result = op.execute({&maskArr,&x,&putArr},{&resultArr}, {},{3}, {}, {}, false);
     ASSERT_EQ(Status::OK(), result);
     for(int i = 0; i < 4; i++) ASSERT_EQ(assertion[i],resultArr.e(i));
@@ -294,7 +294,7 @@ TEST_F(DeclarableOpsTests7, TEST_WHERE_SCALAR) {
     //greater than test
     // Nd4jStatus execute(std::initializer_list*> inputs, std::initializer_list*> outputs , std::initializer_list tArgs, std::initializer_list iArgs, bool isInplace = false);
-    auto result = op.execute({&maskArr,&x,&putArr},{&resultArr}, {},{3}, {}, false);
+    auto result = op.execute({&maskArr,&x,&putArr},{&resultArr}, {},{3}, {}, {}, false);
     // ASSERT_EQ(Status::OK(), result->status());
     for(int i = 0; i < 4; i++) ASSERT_EQ(assertion[i],resultArr.e(i));
@@ -314,7 +314,7 @@ TEST_F(DeclarableOpsTests7, TestMatrixDiagPart_1) {
     nd4j::ops::matrix_diag_part op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(result->status(), Status::OK());
     ASSERT_TRUE(z.equalsTo(result->at(0)));
@@ -330,7 +330,7 @@ TEST_F(DeclarableOpsTests7, TestMatrixDiagPart_2) {
     nd4j::ops::matrix_diag_part op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(result->status(), Status::OK());
     ASSERT_TRUE(z.equalsTo(result->at(0)));
@@ -346,7 +346,7 @@ TEST_F(DeclarableOpsTests7, TestMatrixDiag_1) {
     nd4j::ops::matrix_diag op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(result->status(), Status::OK());
     ASSERT_TRUE(z.equalsTo(result->at(0)));
@@ -361,7 +361,7 @@ TEST_F(DeclarableOpsTests7, TestMatrixDiag_2) {
     nd4j::ops::matrix_diag op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(result->status(), Status::OK());
     ASSERT_TRUE(z.equalsTo(result->at(0)));
@@ -375,7 +375,7 @@ TEST_F(DeclarableOpsTests7, TestRandomCrop_1) {
     auto shape = NDArrayFactory::create({1, 2, 3});
     nd4j::ops::random_crop op;
-    auto result = op.execute({&x, &shape}, {}, {});
+    auto result = op.evaluate({&x, &shape}, {}, {});
     ASSERT_EQ(result->status(), Status::OK());
     //result->at(0)->printIndexedBuffer("Output");
     // ASSERT_TRUE(z.equalsTo(result->at(0)));
@@ -389,7 +389,7 @@ TEST_F(DeclarableOpsTests7, TestRandomCrop_2) {
     auto shape = NDArrayFactory::create({2, 2, 2});
     nd4j::ops::random_crop op;
-    auto result = op.execute({&x, &shape}, {}, {});
+    auto result = op.evaluate({&x, &shape}, {}, {});
     ASSERT_EQ(result->status(), Status::OK());
     //result->at(0)->printIndexedBuffer("Output");
     // ASSERT_TRUE(z.equalsTo(result->at(0)));
@@ -426,7 +426,7 @@ TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_119) {
     53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f});
     nd4j::ops::dynamic_stitch op;
-    auto result = op.execute({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {});
+    auto result = op.evaluate({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     // result->at(0)->printIndexedBuffer("Output");
     // exp.printIndexedBuffer("Expect");
@@ -464,7 +464,7 @@ TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_Prof_1) {
     53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f});
     nd4j::ops::dynamic_stitch op;
-    auto result = op.execute({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {});
+    auto result = op.evaluate({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     // result->at(0)->printIndexedBuffer("Output");
     // exp.printIndexedBuffer("Expect");
@@ -567,7 +567,7 @@ TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_119_1) {
     data1.linspace(21);
     data2.linspace(141);
     nd4j::ops::dynamic_stitch op;
-    auto result = op.execute({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {});
+    auto result = op.evaluate({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -658,7 +658,7 @@ TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_119_2) {
     data1.linspace(41);
     data2.linspace(161);
     nd4j::ops::dynamic_stitch op;
-    auto result = op.execute({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {});
+    auto result = op.evaluate({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -676,7 +676,7 @@
TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119) { x.assign(1.f); e.assign(1.f); nd4j::ops::dynamic_partition op; - auto result = op.execute({&x, &y}, {}, {4}); + auto result = op.evaluate({&x, &y}, {}, {4}); ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(4, result->size()); auto z = result->at(0); @@ -695,7 +695,7 @@ TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119_1) { // x.assign(1.f); // e.assign(1.f); nd4j::ops::dynamic_partition op; - auto result = op.execute({&x, &y}, {}, {3}); + auto result = op.evaluate({&x, &y}, {}, {3}); ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(3, result->size()); auto z = result->at(0); @@ -738,7 +738,7 @@ TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119_2) { x.linspace(1.f); //.assign(1.f); nd4j::ops::dynamic_partition op; - auto result = op.execute({&x, &y}, {}, {4}); + auto result = op.evaluate({&x, &y}, {}, {4}); ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(4, result->size()); for (size_t i = 0; i < result->size(); i++) { @@ -768,7 +768,7 @@ TEST_F(DeclarableOpsTests7, Test_SequenceMask_1) { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }); nd4j::ops::sequence_mask op; - auto result = op.execute({&input}, {}, {}); + auto result = op.evaluate({&input}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -790,7 +790,7 @@ TEST_F(DeclarableOpsTests7, Test_SequenceMask_2) { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); nd4j::ops::sequence_mask op; - auto result = op.execute({&input}, {}, {}); + auto result = op.evaluate({&input}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -810,7 +810,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMax_1) { nd4j::ops::segment_max op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printBuffer("MaX1"); // exp.printBuffer("ExP1"); @@ -826,7 +826,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMax_01) { nd4j::ops::segment_max op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printBuffer("MaX01"); // exp.printBuffer("ExP01"); @@ -842,7 +842,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMaxBP_1) { auto eps = NDArrayFactory::create('c', {5}); nd4j::ops::segment_max_bp op; eps.linspace(1); - auto result = op.execute({&x, &idx, &eps}, {}, {}); + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("OutputMaxBP"); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -865,7 +865,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMax_2) { //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} nd4j::ops::segment_max op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); auto out = result->at(0); @@ -889,7 +889,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMaxBP_2) { nd4j::ops::segment_max_bp op; - auto result = op.execute({&x, &idx, &eps}, {}, {}); + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 2); //exp.printIndexedBuffer("BP Max Expect"); @@ -917,7 +917,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMax_3) { nd4j::ops::segment_max op; - auto result = op.execute({&x, &idx}, {}, 
{}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output3Max"); // result->at(0)->printShapeInfo("Out Shape 3 Max"); @@ -945,7 +945,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMax_4) { nd4j::ops::segment_max op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); @@ -965,7 +965,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMax_1) { nd4j::ops::unsorted_segment_max op; - auto result = op.execute({&x, &idx}, {}, {5}); + auto result = op.evaluate({&x, &idx}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -980,7 +980,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMaxBP_1) { auto eps = NDArrayFactory::create('c', {5}); nd4j::ops::segment_max_bp op; eps.linspace(1); - auto result = op.execute({&x, &idx, &eps}, {}, {5}); + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -995,7 +995,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMaxBP_2) { auto eps = NDArrayFactory::create('c', {5}); nd4j::ops::segment_max_bp op; eps.linspace(1); - auto result = op.execute({&x, &idx, &eps}, {}, {5}); + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); //exp.printIndexedBuffer("Expect"); @@ -1012,7 +1012,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMax_2) { nd4j::ops::unsorted_segment_max op; - auto result = op.execute({&x, &idx}, {}, {5}); + auto result = op.evaluate({&x, &idx}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("OutputUnsortedMax"); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -1030,7 +1030,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMax_3) { nd4j::ops::unsorted_segment_max op; - auto result = op.execute({&x, &idx}, {}, {3}); + auto result = op.evaluate({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); //exp.printIndexedBuffer("Expect"); @@ -1053,7 +1053,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMax_4) { nd4j::ops::unsorted_segment_max op; - auto result = op.execute({&x, &idx}, {}, {3}); + auto result = op.evaluate({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); //exp.printIndexedBuffer("Expect"); @@ -1070,7 +1070,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMin_1) { nd4j::ops::segment_min op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); auto out = result->at(0); @@ -1087,7 +1087,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMin_01) { nd4j::ops::segment_min op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); auto out = result->at(0); @@ -1103,7 +1103,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMin_02) { nd4j::ops::segment_min op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); auto out = result->at(0); @@ -1121,7 +1121,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMinBP_1) { eps.linspace(1); nd4j::ops::segment_min_bp op; - auto result = op.execute({&x, &idx, &eps}, 
{}, {}); + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -1138,7 +1138,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMinBP_1) { eps.linspace(1); nd4j::ops::unsorted_segment_min_bp op; - auto result = op.execute({&x, &idx, &eps}, {}, {5}); + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output1"); //exp.printIndexedBuffer("Expecte"); @@ -1157,7 +1157,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMinBP_2) { eps.linspace(1); nd4j::ops::unsorted_segment_min_bp op; - auto result = op.execute({&x, &idx, &eps}, {}, {5}); + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output1"); //exp.printIndexedBuffer("Expecte"); @@ -1177,7 +1177,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMin_2) { nd4j::ops::segment_min op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); // exp.printIndexedBuffer("Expect"); @@ -1198,7 +1198,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMinBP_2) { nd4j::ops::segment_min_bp op; - auto result = op.execute({&x, &idx, &eps}, {}, {}); + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 2); // exp.printIndexedBuffer("Expect"); @@ -1223,7 +1223,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMin_3) { nd4j::ops::segment_min op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); @@ -1253,7 +1253,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMin_4) { nd4j::ops::segment_min op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); @@ -1273,7 +1273,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_1) { nd4j::ops::unsorted_segment_min op; - auto result = op.execute({&x, &idx}, {}, {5}); + auto result = op.evaluate({&x, &idx}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -1287,7 +1287,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_01) { nd4j::ops::unsorted_segment_min op; - auto result = op.execute({&x, &idx}, {}, {5}); + auto result = op.evaluate({&x, &idx}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -1304,7 +1304,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_2) { nd4j::ops::unsorted_segment_min op; - auto result = op.execute({&x, &idx}, {}, {3}); + auto result = op.evaluate({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); // exp.printIndexedBuffer("Expect"); @@ -1328,7 +1328,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_3) { nd4j::ops::unsorted_segment_min op; - auto result = op.execute({&x, &idx}, {}, {3}); + auto result = op.evaluate({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); @@ -1369,7 +1369,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_4) { 
nd4j::ops::unsorted_segment_min op; - auto result = op.execute({&x, &idx}, {}, {8}); + auto result = op.evaluate({&x, &idx}, {}, {8}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); @@ -1389,7 +1389,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMean_1) { nd4j::ops::segment_mean op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -1403,7 +1403,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMean_2) { nd4j::ops::segment_mean op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); // exp.printIndexedBuffer("Expect"); @@ -1421,7 +1421,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMean_02) { nd4j::ops::segment_mean op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -1436,7 +1436,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMean_021) { nd4j::ops::segment_mean op; x.linspace(1.); - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -1452,7 +1452,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMean_022) { nd4j::ops::segment_mean op; x.linspace(1.); - auto result = op.execute({&x, &idx}, {&z}, {}, {}, {}, false, nd4j::DataType::FLOAT32); + auto result = op.execute({&x, &idx}, {&z}); ASSERT_EQ(result, Status::OK()); ASSERT_TRUE(exp.equalsTo(z)); @@ -1470,7 +1470,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMeanBP_2) { nd4j::ops::segment_mean_bp op; - auto result = op.execute({&x, &idx, &eps}, {}, {}); + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 2); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -1495,7 +1495,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMean_3) { nd4j::ops::segment_mean op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); @@ -1526,7 +1526,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMean_4) { nd4j::ops::segment_mean op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); @@ -1546,7 +1546,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMean_1) { nd4j::ops::unsorted_segment_mean op; - auto result = op.execute({&x, &idx}, {}, {5}); + auto result = op.evaluate({&x, &idx}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -1562,7 +1562,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMeanBP_1) { 5./6., 5./6., 5./6., 5./6., 5./6., 5./6.}); nd4j::ops::segment_mean_bp op; - auto result = op.execute({&x, &idx, &eps}, {}, {}); + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -1578,7 +1578,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMeanBP_1) { 5./6., 5./6., 5./6., 
5./6., 5./6., 5./6.}); nd4j::ops::unsorted_segment_mean_bp op; - auto result = op.execute({&x, &idx, &eps}, {}, {5}); + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -1594,7 +1594,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMeanBP_2) { 5./6., 5./6., 5./6., 5./6., 5./6., 5./6.}); nd4j::ops::unsorted_segment_mean_bp op; - auto result = op.execute({&x, &idx, &eps}, {}, {5}); + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -1609,7 +1609,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMean_2) { nd4j::ops::unsorted_segment_mean op; - auto result = op.execute({&x, &idx}, {}, {3}); + auto result = op.evaluate({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); // exp.printIndexedBuffer("Expect"); @@ -1637,7 +1637,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMean_3) { nd4j::ops::unsorted_segment_mean op; - auto result = op.execute({&x, &idx}, {}, {3}); + auto result = op.evaluate({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); @@ -1668,7 +1668,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMean_4) { nd4j::ops::unsorted_segment_mean op; - auto result = op.execute({&x, &idx}, {}, {8}); + auto result = op.evaluate({&x, &idx}, {}, {8}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); @@ -1688,7 +1688,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_1) { nd4j::ops::unsorted_segment_sqrt_n op; - auto result = op.execute({&x, &idx}, {}, {5}); + auto result = op.evaluate({&x, &idx}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -1704,7 +1704,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_BP_1) { auto exp = NDArrayFactory::create({3., 0.707107, 0.707107, 1., 1., 1., 1., 2.309401, 2.309401, 2.309401, 2.041241, 2.041241, 2.041241, 2.041241, 2.041241, 2.041241}); nd4j::ops::unsorted_segment_sqrt_n_bp op; - auto result = op.execute({&x, &idx, &eps}, {}, {5}); + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Hello Out:"); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -1723,7 +1723,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_2) { nd4j::ops::unsorted_segment_sqrt_n op; - auto result = op.execute({&x, &idx}, {}, {3}); + auto result = op.evaluate({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); // exp.printIndexedBuffer("Expect"); @@ -1751,7 +1751,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_3) { nd4j::ops::unsorted_segment_sqrt_n op; - auto result = op.execute({&x, &idx}, {}, {3}); + auto result = op.evaluate({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); @@ -1782,7 +1782,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_4) { nd4j::ops::unsorted_segment_sqrt_n op; - auto result = op.execute({&x, &idx}, {}, {8}); + auto result = op.evaluate({&x, &idx}, {}, {8}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); @@ -1802,7 +1802,7 @@ 
TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_5) { auto exp = NDArrayFactory::create({7.5055537, 2., 4.9497476, 2.828427}); nd4j::ops::unsorted_segment_sqrt_n op; - auto result = op.execute({&x, &idx}, {}, {4}); + auto result = op.evaluate({&x, &idx}, {}, {4}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); // exp.printIndexedBuffer("Expect"); @@ -1819,7 +1819,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentSum_1) { nd4j::ops::segment_sum op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -1835,7 +1835,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentSumBP_1) { auto exp = NDArrayFactory::create({ 1., 1., 2., 2., 2., 2., 3., 4., 4., 4., 5., 5., 5., 5., 5., 5.}); nd4j::ops::segment_sum_bp op; - auto result = op.execute({&x, &idx, &eps}, {}, {}); + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -1849,7 +1849,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSumBP_1) { auto exp = NDArrayFactory::create({ 1., 1., 2., 2., 2., 2., 3., 4., 4., 4., 5., 5., 5., 5., 5., 5.}); nd4j::ops::unsorted_segment_sum_bp op; - auto result = op.execute({&x, &idx, &eps}, {}, {5}); + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -1863,7 +1863,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSumBP_2) { auto exp = NDArrayFactory::create({ 3., 1., 1., 2., 2., 2., 2., 4., 4., 4., 5., 5., 5., 5., 5., 5.}); nd4j::ops::unsorted_segment_sum_bp op; - auto result = op.execute({&x, &idx, &eps}, {}, {5}); + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -1878,7 +1878,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentSum_2) { nd4j::ops::segment_sum op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); // exp.printIndexedBuffer("Expect"); @@ -1898,7 +1898,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentSumBP_2) { nd4j::ops::segment_sum_bp op; - auto result = op.execute({&x, &idx, &eps}, {}, {}); + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 2); // exp.printIndexedBuffer("Expect"); @@ -1925,7 +1925,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentSum_3) { nd4j::ops::segment_sum op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); @@ -1956,7 +1956,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentSum_4) { nd4j::ops::segment_sum op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); @@ -1976,7 +1976,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSum_1) { nd4j::ops::unsorted_segment_sum op; - auto result = op.execute({&x, &idx}, {}, {5}); + auto result = op.evaluate({&x, &idx}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -1991,7 +1991,7 @@ TEST_F(DeclarableOpsTests7, 
TestUnsortedSegmentSum_2) { nd4j::ops::unsorted_segment_sum op; - auto result = op.execute({&x, &idx}, {}, {3}); + auto result = op.evaluate({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -2016,7 +2016,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSum_3) { nd4j::ops::unsorted_segment_sum op; - auto result = op.execute({&x, &idx}, {}, {3}); + auto result = op.evaluate({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); @@ -2046,7 +2046,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSum_4) { nd4j::ops::unsorted_segment_sum op; - auto result = op.execute({&x, &idx}, {}, {8}); + auto result = op.evaluate({&x, &idx}, {}, {8}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); @@ -2066,7 +2066,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentProd_1) { nd4j::ops::segment_prod op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -2081,7 +2081,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentProdBP_1) { auto exp = NDArrayFactory::create({2.5, 1.8, 90.72, 40.32, 172.8, 151.2, 3., 17.64, 75.6, 75.6, 13.86, 97.02, 3.234, 2.31, 4.41, 9.702}); nd4j::ops::segment_prod_bp op; - auto result = op.execute({&x, &idx, &eps}, {}, {}); + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("ProdBP Output"); // exp.printIndexedBuffer("ProdBP Expect"); @@ -2099,7 +2099,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProdBP_1) { auto exp = NDArrayFactory::create({2.5, 1.8, 90.72, 40.32, 172.8, 151.2, 3., 17.64, 75.6, 75.6, 13.86, 97.02, 3.234, 2.31, 4.41, 9.702}); nd4j::ops::segment_prod_bp op; - auto result = op.execute({&x, &idx, &eps}, {}, {}); + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("ProdBP Output"); //exp.printIndexedBuffer("ProdBP Expect"); @@ -2118,7 +2118,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProdBP_2) { auto n = NDArrayFactory::create(5LL); nd4j::ops::unsorted_segment_prod_bp op; - auto result = op.execute({&x, &idx, &eps, &n}, {}, {5}); + auto result = op.evaluate({&x, &idx, &eps, &n}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Unsorted ProdBP Output"); //exp.printIndexedBuffer("Unsorted ProdBP Expect"); @@ -2139,7 +2139,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentProd_2) { nd4j::ops::segment_prod op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); // exp.printIndexedBuffer("Expect"); @@ -2163,7 +2163,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentProdBP_2) { eps.linspace(1); nd4j::ops::segment_prod_bp op; - auto result = op.execute({&x, &idx, &eps}, {}, {}); + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 2); // exp.printIndexedBuffer("Expect"); @@ -2190,7 +2190,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentProd_3) { nd4j::ops::segment_prod op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); 
ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); @@ -2212,7 +2212,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentProd_04) { nd4j::ops::segment_prod op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -2229,7 +2229,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentProd_05) { nd4j::ops::segment_prod op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); auto res = result->at(0); // res->printIndexedBuffer("Segment prod 05"); @@ -2248,7 +2248,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentProd_05_1) { nd4j::ops::segment_prod op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); auto res = result->at(0); // res->printIndexedBuffer("Segment prod 05_1"); @@ -2267,7 +2267,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentProd_06) { auto exp = NDArrayFactory::create({ 2, 3, 120, 56}); nd4j::ops::segment_prod op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -2284,7 +2284,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentProd_07) { auto exp = NDArrayFactory::create({ 2, 3, 120, 56}); nd4j::ops::segment_prod op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -2301,7 +2301,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentProd_08) { auto exp = NDArrayFactory::create({ 2, 1,360, 5040}); nd4j::ops::segment_prod op; - auto result = op.execute({&x, &idx}, {}, {}); + auto result = op.evaluate({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -2316,7 +2316,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_1) { nd4j::ops::unsorted_segment_prod op; - auto result = op.execute({&x, &idx}, {}, {5}); + auto result = op.evaluate({&x, &idx}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -2331,7 +2331,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_11) { nd4j::ops::unsorted_segment_prod op; - auto result = op.execute({&x, &idx}, {}, {5}); + auto result = op.evaluate({&x, &idx}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -2348,7 +2348,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_2) { nd4j::ops::unsorted_segment_prod op; - auto result = op.execute({&x, &idx}, {}, {3}); + auto result = op.evaluate({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); // exp.printIndexedBuffer("Expect"); @@ -2370,7 +2370,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_12) { nd4j::ops::unsorted_segment_prod op; - auto result = op.execute({&x, &idx}, {}, {3}); + auto result = op.evaluate({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); // exp.printIndexedBuffer("Expect"); @@ -2389,7 +2389,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_08) { auto exp = NDArrayFactory::create({ 2, 1,360, 5040}); nd4j::ops::unsorted_segment_prod op; - auto result = op.execute({&x, &idx}, {}, {4}); + auto result = 
op.evaluate({&x, &idx}, {}, {4}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -2413,7 +2413,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_3) { nd4j::ops::unsorted_segment_prod op; - auto result = op.execute({&x, &idx}, {}, {3}); + auto result = op.evaluate({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); @@ -2444,7 +2444,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_4) { nd4j::ops::unsorted_segment_prod op; - auto result = op.execute({&x, &idx}, {}, {3}); + auto result = op.evaluate({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); @@ -2479,7 +2479,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_5) { nd4j::ops::unsorted_segment_prod op; - auto result = op.execute({&x, &idx}, {}, {4}); + auto result = op.evaluate({&x, &idx}, {}, {4}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); @@ -2510,7 +2510,7 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProdBP_4) { nd4j::ops::unsorted_segment_prod_bp op; - auto result = op.execute({&x, &idx, &gradO}, {}, {4}); + auto result = op.evaluate({&x, &idx, &gradO}, {}, {4}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); @@ -2548,7 +2548,7 @@ TEST_F(DeclarableOpsTests7, TestExtractImagePatches_1) { nd4j::ops::extract_image_patches op; - auto result = op.execute({&x}, {}, {1,1,1,1,1,1,0}); + auto result = op.evaluate({&x}, {}, {1,1,1,1,1,1,0}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); @@ -2585,7 +2585,7 @@ auto exp = NDArrayFactory::create('c', {3, 1, 1, 12}, { // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; - auto result = op.execute({&x}, {}, {2,2, 3,3, 1,1,0}); + auto result = op.evaluate({&x}, {}, {2,2, 3,3, 1,1,0}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.isSameShape(result->at(0))); @@ -2619,7 +2619,7 @@ auto exp = NDArrayFactory::create('c', {3, 1, 2, 6}, { // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; - auto result = op.execute({&x}, {}, {2,1,3,2,2,2,0}); + auto result = op.evaluate({&x}, {}, {2,1,3,2,2,2,0}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.isSameShape(result->at(0))); @@ -2658,7 +2658,7 @@ auto exp = NDArrayFactory::create('c', {3, 3, 4, 3}, { // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; - auto result = op.execute({&x}, {}, {1,1,1,1,1,1,0}); + auto result = op.evaluate({&x}, {}, {1,1,1,1,1,1,0}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); @@ -2697,7 +2697,7 @@ auto exp = NDArrayFactory::create('c', {3, 1, 1, 18}, { // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; - auto result = op.execute({&x}, {}, {3,2,3,2,1,2,0}); + auto result = op.evaluate({&x}, {}, {3,2,3,2,1,2,0}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); @@ -2726,7 
+2726,7 @@ auto exp = NDArrayFactory::create('c', {2, 1, 4, 4}, { // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; - auto result = op.execute({&x}, {}, {2,1, 1,1, 1,1,0}); + auto result = op.evaluate({&x}, {}, {2,1, 1,1, 1,1,0}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.isSameShape(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -2748,7 +2748,7 @@ TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_7) { // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; - auto result = op.execute({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" + auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); // output->printBuffer("Output"); @@ -2779,7 +2779,7 @@ TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_8) { // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; - auto result = op.execute({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" + auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); // output->printBuffer("Output"); @@ -2843,7 +2843,7 @@ TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_9) { // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; - auto result = op.execute({&x}, {}, {3,3, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" + auto result = op.evaluate({&x}, {}, {3,3, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); // output->printBuffer("OutputSame"); @@ -2884,7 +2884,7 @@ TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_9_1) { // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; - auto result = op.execute({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" + auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); // output->printBuffer("OutputSame"); @@ -2931,7 +2931,7 @@ TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_10) { nd4j::ops::extract_image_patches op; //x.printIndexedBuffer("Images"); //x.printBuffer("Images linear"); - auto result = op.execute({&x}, {}, {3,3, 1,1, 1,1, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" + auto result = op.evaluate({&x}, {}, {3,3, 1,1, 1,1, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); // output->printBuffer("OutputValid"); @@ -2960,7 +2960,7 @@ TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_010) { nd4j::ops::extract_image_patches op; //x.printIndexedBuffer("Images"); //x.printBuffer("Images linear"); - auto result = op.execute({&x}, {}, {2,2, 1,1, 1,1, 0}); // 
equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" + auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); // output->printBuffer("OutputValid"); @@ -2990,7 +2990,7 @@ TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_010_1) { nd4j::ops::extract_image_patches op; //x.printIndexedBuffer("Images"); //x.printBuffer("Images linear"); - auto result = op.execute({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" + auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); // output->printBuffer("OutputSame"); @@ -3022,7 +3022,7 @@ TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_011) { nd4j::ops::extract_image_patches op; //x.printIndexedBuffer("Images"); //x.printBuffer("Images linear"); - auto result = op.execute({&x}, {}, {2,2, 1,1, 2,2, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" + auto result = op.evaluate({&x}, {}, {2,2, 1,1, 2,2, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); // output->printBuffer("OutputValid"); @@ -3058,7 +3058,7 @@ TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_11) { // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; - auto result = op.execute({&x}, {}, {2,2, 2,2, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" + auto result = op.evaluate({&x}, {}, {2,2, 2,2, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); // output->printBuffer("Output"); @@ -3115,7 +3115,7 @@ TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_12) { // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; - auto result = op.execute({&x}, {}, {2,2, 1,1, 2,2, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,2,2,1], padding="SAME" + auto result = op.evaluate({&x}, {}, {2,2, 1,1, 2,2, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,2,2,1], padding="SAME" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); //output->printShapeInfo("Output shape"); @@ -3145,7 +3145,7 @@ TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_13) { // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; - auto result = op.execute({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" + auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); @@ -3170,7 +3170,7 @@ auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { // ---------------------------------------------------------------- nd4j::ops::roll op; - auto result = op.execute({&x}, {}, {6}, {}, false, nd4j::DataType::DOUBLE); + auto result = op.evaluate({&x}, {}, {6}); ASSERT_EQ(result->status(), Status::OK()); 
ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -3192,7 +3192,7 @@ auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { // ---------------------------------------------------------------- nd4j::ops::roll op; - auto result = op.execute({&x}, {}, {-8}, {}, false, nd4j::DataType::DOUBLE); + auto result = op.evaluate({&x}, {}, {-8}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -3213,7 +3213,7 @@ auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { // ---------------------------------------------------------------- nd4j::ops::roll op; - auto result = op.execute({&x}, {}, {-40}, {}, false, nd4j::DataType::DOUBLE); + auto result = op.evaluate({&x}, {}, {-40}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -3236,7 +3236,7 @@ auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { // ---------------------------------------------------------------- nd4j::ops::roll op; - auto result = op.execute({&x}, {}, {38}, {}, false, nd4j::DataType::DOUBLE); + auto result = op.evaluate({&x}, {}, {38}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output 4"); //exp.printIndexedBuffer("Expect 4"); @@ -3261,7 +3261,7 @@ auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { // ---------------------------------------------------------------- nd4j::ops::roll op; NDArray* y = nullptr; - auto result = op.execute({&x}, {y}, {}, {38}, {}, true, nd4j::DataType::DOUBLE); + auto result = op.execute({&x}, {y}, {}, {38}, {}, {}, true); ASSERT_EQ(result, Status::OK()); //x.printIndexedBuffer("Output 4 inplace"); //exp.printIndexedBuffer("Expect 4 inplace"); @@ -3284,7 +3284,7 @@ auto exp = NDArrayFactory::create('c', {3, 4}, { // ---------------------------------------------------------------- nd4j::ops::roll op; - auto result = op.execute({&x}, {}, {2, 1}, {}, false, nd4j::DataType::DOUBLE); + auto result = op.evaluate({&x}, {}, {2, 1}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); @@ -3307,7 +3307,7 @@ auto exp = NDArrayFactory::create('c', {2, 3, 2}, { // ---------------------------------------------------------------- nd4j::ops::roll op; - auto result = op.execute({&x}, {}, {1, 2}, {}, false, nd4j::DataType::DOUBLE); + auto result = op.evaluate({&x}, {}, {1, 2}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); @@ -3330,7 +3330,7 @@ auto exp = NDArrayFactory::create('c', {2, 3, 2}, { // ---------------------------------------------------------------- nd4j::ops::roll op; - auto result = op.execute({&x}, {}, {1, 2, 1, 0}, {}, false, nd4j::DataType::DOUBLE); + auto result = op.evaluate({&x}, {}, {1, 2, 1, 0}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); @@ -3353,7 +3353,7 @@ auto exp = NDArrayFactory::create('c', {2, 3, 2}, { // ---------------------------------------------------------------- nd4j::ops::roll op; NDArray* y = nullptr; - auto result = op.execute({&x}, {y}, {}, {1, 2, 1, 0}, {}, true, nd4j::DataType::DOUBLE); + auto result = op.execute({&x}, {y}, {}, {1, 2, 1, 0}, {}, {}, true); ASSERT_EQ(result, Status::OK()); //x.printIndexedBuffer("Output"); @@ -3376,7 +3376,7 @@ auto exp = NDArrayFactory::create('c', {2, 3, 3}, { // ---------------------------------------------------------------- nd4j::ops::roll op; NDArray* y = nullptr; - auto result = op.execute({&x}, {y}, {}, {1, 1}, {}, true, nd4j::DataType::DOUBLE); + auto result = op.execute({&x}, {y}, {}, {1, 1}, {}, 
{}, true); ASSERT_EQ(result, Status::OK()); ASSERT_TRUE(exp.equalsTo(&x)); @@ -3395,7 +3395,7 @@ TEST_F(DeclarableOpsTests7, TestRoll_10) { }); // ---------------------------------------------------------------- nd4j::ops::roll op; - auto result = op.execute({&x}, {}, {3, 1}, {}, false, nd4j::DataType::DOUBLE); + auto result = op.evaluate({&x}, {}, {3, 1}); ASSERT_EQ(result->status(), Status::OK()); auto out = result->at(0); @@ -3420,7 +3420,7 @@ TEST_F(DeclarableOpsTests7, TestRoll_11) { // ---------------------------------------------------------------- nd4j::ops::roll op; NDArray* y = nullptr; - auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE); + auto result = op.evaluate({&x, &shift, &axis}); ASSERT_EQ(result->status(), Status::OK()); auto out = result->at(0); @@ -3446,7 +3446,7 @@ TEST_F(DeclarableOpsTests7, TestRoll_12) { // ---------------------------------------------------------------- nd4j::ops::roll op; NDArray* y = nullptr; - auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE); + auto result = op.evaluate({&x, &shift, &axis}); ASSERT_EQ(result->status(), Status::OK()); auto out = result->at(0); @@ -3469,7 +3469,7 @@ TEST_F(DeclarableOpsTests7, TestRoll_13) { // ---------------------------------------------------------------- nd4j::ops::roll op; NDArray* y = nullptr; - auto result = op.execute({&x}, {}, {3,2}, {}, false, nd4j::DataType::DOUBLE); + auto result = op.evaluate({&x}, {}, {3,2}); ASSERT_EQ(result->status(), Status::OK()); auto out = result->at(0); @@ -3492,7 +3492,7 @@ TEST_F(DeclarableOpsTests7, TestRoll_14) { // ---------------------------------------------------------------- nd4j::ops::roll op; - auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE); + auto result = op.evaluate({&x, &shift, &axis}); ASSERT_EQ(result->status(), Status::OK()); auto out = result->at(0); // out->printIndexedBuffer("Output"); @@ -3513,7 +3513,7 @@ TEST_F(DeclarableOpsTests7, TestRoll_15) { // ---------------------------------------------------------------- nd4j::ops::roll op; - auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::FLOAT32); + auto result = op.evaluate({&x, &shift, &axis}); ASSERT_EQ(result->status(), Status::OK()); auto out = result->at(0); // out->printIndexedBuffer("Output 15"); @@ -3538,7 +3538,7 @@ TEST_F(DeclarableOpsTests7, percentile_test1) { nd4j::ops::percentile op; - auto result = op.execute({&input}, {50.}, {}); + auto result = op.evaluate({&input}, {50.}, {}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); @@ -3562,7 +3562,7 @@ TEST_F(DeclarableOpsTests7, percentile_test2) { nd4j::ops::percentile op; //q, interpolation, keepDims - auto result = op.execute({&input}, {10, 2, 1}, {}); + auto result = op.evaluate({&input}, {10, 2, 1}, {}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); @@ -3586,7 +3586,7 @@ TEST_F(DeclarableOpsTests7, percentile_test3) { nd4j::ops::percentile op; //q, interpolation, keepDims - auto result = op.execute({&input}, {10, 0, 1}, {}); + auto result = op.evaluate({&input}, {10, 0, 1}, {}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); @@ -3610,7 +3610,7 @@ TEST_F(DeclarableOpsTests7, percentile_test4) { nd4j::ops::percentile op; //q, interpolation, keepDims - auto result = op.execute({&input}, {10, 1, 1}, {}); + auto result = op.evaluate({&input}, {10, 1, 1}, {}); auto output = result->at(0); 
ASSERT_TRUE(expected.isSameShape(output)); @@ -3634,7 +3634,7 @@ TEST_F(DeclarableOpsTests7, percentile_test5) { nd4j::ops::percentile op; //q, interpolation, keepDims - auto result = op.execute({&input}, {10, 0, 1}, {0,1}); + auto result = op.evaluate({&input}, {10, 0, 1}, {0,1}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); @@ -3658,7 +3658,7 @@ TEST_F(DeclarableOpsTests7, percentile_test6) { nd4j::ops::percentile op; //q, interpolation, keepDims - auto result = op.execute({&input}, {10, 1, 1}, {0,1}); + auto result = op.evaluate({&input}, {10, 1, 1}, {0,1}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); @@ -3682,7 +3682,7 @@ TEST_F(DeclarableOpsTests7, percentile_test7) { nd4j::ops::percentile op; //q, interpolation, keepDims - auto result = op.execute({&input}, {10, 2, 1}, {0,1}); + auto result = op.evaluate({&input}, {10, 2, 1}, {0,1}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); @@ -3706,7 +3706,7 @@ TEST_F(DeclarableOpsTests7, percentile_test8) { nd4j::ops::percentile op; //q, interpolation, keepDims - auto result = op.execute({&input}, {10, 2, 0}, {0,1}); + auto result = op.evaluate({&input}, {10, 2, 0}, {0,1}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); @@ -3730,7 +3730,7 @@ TEST_F(DeclarableOpsTests7, percentile_test9) { nd4j::ops::percentile op; //q, interpolation, keepDims - auto result = op.execute({&input}, {10, 2, 0}, {0}); + auto result = op.evaluate({&input}, {10, 2, 0}, {0}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); @@ -3754,7 +3754,7 @@ TEST_F(DeclarableOpsTests7, percentile_test10) { nd4j::ops::percentile op; //q, interpolation, keepDims - auto result = op.execute({&input}, {10, 2, 1}, {0}); + auto result = op.evaluate({&input}, {10, 2, 1}, {0}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); @@ -3774,7 +3774,7 @@ TEST_F(DeclarableOpsTests7, percentile_test11) { nd4j::ops::percentile op; //q, interpolation, keepDims - auto result = op.execute({&input}, {10, 2, 1}, {0}); + auto result = op.evaluate({&input}, {10, 2, 1}, {0}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); @@ -3794,7 +3794,7 @@ TEST_F(DeclarableOpsTests7, percentile_test12) { nd4j::ops::percentile op; //q, interpolation, keepDims - auto result = op.execute({&input}, {10, 2, 0}, {}); + auto result = op.evaluate({&input}, {10, 2, 0}, {}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); @@ -3810,7 +3810,7 @@ TEST_F(DeclarableOpsTests7, transpose_test3) { auto exp = NDArrayFactory::create('c', {3, 5}, {1.f, 4.f, 7.f, 10.f, 13.f, 2.f, 5.f, 8.f, 11.f, 14.f, 3.f, 6.f, 9.f, 12.f, 15.f}); nd4j::ops::transpose op; - auto result = op.execute({&input}, {}, {}); + auto result = op.evaluate({&input}, {}, {}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); @@ -3826,7 +3826,7 @@ TEST_F(DeclarableOpsTests7, rationaltanh_test1) { NDArray exp = NDArrayFactory::create({0.000000, 0.998222, 1.516093, 1.658054, 1.695077, 1.706884, 1.711427, 1.713446}); nd4j::ops::rationaltanh op; - auto result = op.execute({&input}, {}, {}); + auto result = op.evaluate({&input}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Output rationaltanh"); ASSERT_TRUE(exp.isSameShape(output)); @@ -3842,7 +3842,7 @@ TEST_F(DeclarableOpsTests7, rationaltanh_test2) { NDArray exp = NDArrayFactory::create('c', {2,2,2}, {0.000000, 0.998222, 1.516093, 1.658054, 1.695077, 1.706884, 1.711427, 1.713446}); 
nd4j::ops::rationaltanh op; - auto result = op.execute({&input}, {}, {}); + auto result = op.evaluate({&input}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Output rationaltanh"); ASSERT_TRUE(exp.isSameShape(output)); @@ -3859,7 +3859,7 @@ TEST_F(DeclarableOpsTests7, rationaltanh_test3) { NDArray exp = NDArrayFactory::create('c', {2,2,2}, {1.143933, 1.605747, 0.795557, 0.261710, 0.095832, 0.041218, 0.020221, 0.010971}); nd4j::ops::rationaltanh_bp op; - auto result = op.execute({&input, &eps}, {}, {}); + auto result = op.evaluate({&input, &eps}, {}, {}); auto output = result->at(0); // output->printBuffer("Output rationaltanh BP"); ASSERT_TRUE(exp.isSameShape(output)); @@ -3875,7 +3875,7 @@ TEST_F(DeclarableOpsTests7, rectifiedtanh_test1) { NDArray exp = NDArrayFactory::create('c', {2,2,2}, {0.000000, 0.761594, 0.964028, 0.995055, 0.999329, 0.999909, 0.999988, 0.999998}); nd4j::ops::rectifiedtanh op; - auto result = op.execute({&input}, {}, {}); + auto result = op.evaluate({&input}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Output rectifiedtanh"); ASSERT_TRUE(exp.isSameShape(output)); @@ -3892,7 +3892,7 @@ TEST_F(DeclarableOpsTests7, rectifiedtanh_test2) { NDArray exp = NDArrayFactory::create('c', {2,2,2}, {0.000000, 0.839949, 0.211952, 0.039464, 0.006705, 0.001089, 0.000172, 0.000027}); nd4j::ops::rectifiedtanh_bp op; - auto result = op.execute({&input, &eps}, {}, {}); + auto result = op.evaluate({&input, &eps}, {}, {}); auto output = result->at(0); // output->printBuffer("Output rectifiedtanh BP"); ASSERT_TRUE(exp.isSameShape(output)); @@ -3908,7 +3908,7 @@ TEST_F(DeclarableOpsTests7, RealDiv_1) { NDArray e = NDArrayFactory::create('c', {1, 2, 2}, {2.f, 1.f, 4.f, 2.f}); nd4j::ops::realdiv op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); @@ -3930,7 +3930,7 @@ TEST_F(DeclarableOpsTests7, RealDiv_BP_1) { NDArray eps = NDArrayFactory::create('c', {1, 2, 2}, {1.f, 2.f, 3.f, 4.f}); nd4j::ops::realdiv_bp op; - auto result = op.execute({&x, &y, &eps}, {}, {}); + auto result = op.evaluate({&x, &y, &eps}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); @@ -3955,7 +3955,7 @@ TEST_F(DeclarableOpsTests7, ShapesOf_1) { NDArray e = NDArrayFactory::create({1, 2, 1}); nd4j::ops::shapes_of op; - auto result = op.execute({&x}, {}, {}); + auto result = op.evaluate({&x}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); @@ -3976,7 +3976,7 @@ TEST_F(DeclarableOpsTests7, ShapesOf_2) { NDArray e1 = NDArrayFactory::create({1, 2}); nd4j::ops::shapes_of op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); @@ -3998,7 +3998,7 @@ TEST_F(DeclarableOpsTests7, Size_1) { NDArray e = NDArrayFactory::create(2); nd4j::ops::size op; - auto result = op.execute({&x}, {}, {}); + auto result = op.evaluate({&x}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); @@ -4017,7 +4017,7 @@ TEST_F(DeclarableOpsTests7, Size_2) { NDArray e = NDArrayFactory::create(10); nd4j::ops::size op; - auto result = op.execute({&y}, {}, {}); + auto result = op.evaluate({&y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); @@ -4035,7 +4035,7 @@ TEST_F(DeclarableOpsTests7, Softplus_1) { NDArray e = NDArrayFactory::create('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016}); nd4j::ops::softplus op; - auto result = op.execute({&x}, {}, {}); + auto 
result = op.evaluate({&x}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); @@ -4075,7 +4075,7 @@ TEST_F(DeclarableOpsTests7, Softsign_1) { NDArray e = NDArrayFactory::create('c', {5, 2}, {0.5, 0.6666667, 0.75, 0.8, 0.8333333, 0.875, 0.9, 0.90909094, 0.90909094, 0.9166667}); nd4j::ops::softsign op; - auto result = op.execute({&x}, {}, {}); + auto result = op.evaluate({&x}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); @@ -4110,7 +4110,7 @@ TEST_F(DeclarableOpsTests7, fill_test2) { auto exp = NDArrayFactory::create('c', {2, 2},{42.f, 42.f, 42.f, 42.f}); nd4j::ops::fill op; - auto result = op.execute({&x, &v}, {}, {}); + auto result = op.evaluate({&x, &v}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -4130,7 +4130,7 @@ TEST_F(DeclarableOpsTests7, fill_test3) { auto exp = NDArrayFactory::create('c', {2, 2}, {42.f, 42.f, 42.f, 42.f}); nd4j::ops::fill op; - auto result = op.execute({&x, &v}, {}, {}); + auto result = op.evaluate({&x, &v}, {}, {}); auto output = result->at(0); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -4148,7 +4148,7 @@ TEST_F(DeclarableOpsTests7, ToggleBits_test1) { auto exp = NDArrayFactory::create('c', {2}, {-3, -3}); nd4j::ops::toggle_bits op; - auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::INT32); + auto result = op.evaluate({&x}); auto output = result->at(0); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -4168,7 +4168,7 @@ TEST_F(DeclarableOpsTests7, ToggleBits_test2) { auto exp1 = NDArrayFactory::create('c', {2}, {-2, -2}); nd4j::ops::toggle_bits op; - auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32); + auto result = op.evaluate({&x, &y}); auto output = result->at(0); auto z = result->at(1); @@ -4189,7 +4189,7 @@ TEST_F(DeclarableOpsTests7, Truncatediv_test1) { NDArray exp = NDArrayFactory::create('c', {5, 2}, {0.5, 1., 1.5, 2., 2.5, 3.5, 4.5, 5., 5., 5.5}); nd4j::ops::truncatediv op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto output = result->at(0); // output->printIndexedBuffer("Toggled"); @@ -4205,7 +4205,7 @@ TEST_F(DeclarableOpsTests7, Truncatediv_test2) { NDArray exp = NDArrayFactory::create('c', {5, 2}, {0.5, 1., 1.5, 2., 2.5, 3.5, 4.5, 5., 5., 5.5}); nd4j::ops::truncatediv op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto output = result->at(0); // output->printIndexedBuffer("Toggled"); @@ -4224,8 +4224,8 @@ TEST_F(DeclarableOpsTests7, TypesConversion_test1) { nd4j::ops::to_int32 op32; nd4j::ops::to_int64 op64; - auto result32 = op32.execute({&x}, {}, {}); - auto result64 = op64.execute({&x}, {}, {}); + auto result32 = op32.evaluate({&x}, {}, {}); + auto result64 = op64.evaluate({&x}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result32->status()); ASSERT_EQ(ND4J_STATUS_OK, result64->status()); @@ -4250,8 +4250,8 @@ TEST_F(DeclarableOpsTests7, TypesConversion_test2) { nd4j::ops::to_float32 op32; nd4j::ops::to_float16 op16; - auto result32 = op32.execute({&x}, {}, {}); - auto result16 = op16.execute({&x}, {}, {}); + auto result32 = op32.evaluate({&x}, {}, {}); + auto result16 = op16.evaluate({&x}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result32->status()); ASSERT_EQ(ND4J_STATUS_OK, result16->status()); @@ -4276,8 +4276,8 @@ TEST_F(DeclarableOpsTests7, TypesConversion_test3) { nd4j::ops::to_uint32 op32; nd4j::ops::to_uint64 op64; - auto result32 = op32.execute({&x}, {}, {}); - auto result64 = op64.execute({&x}, 
{}, {}); + auto result32 = op32.evaluate({&x}, {}, {}); + auto result64 = op64.evaluate({&x}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result32->status()); ASSERT_EQ(ND4J_STATUS_OK, result64->status()); @@ -4302,8 +4302,8 @@ TEST_F(DeclarableOpsTests7, TypesConversion_test4) { nd4j::ops::to_float32 op32; nd4j::ops::to_double op64; - auto result32 = op32.execute({&x}, {}, {}); - auto result64 = op64.execute({&x}, {}, {}); + auto result32 = op32.evaluate({&x}, {}, {}); + auto result64 = op64.evaluate({&x}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result32->status()); ASSERT_EQ(ND4J_STATUS_OK, result64->status()); @@ -4326,7 +4326,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test1) { auto exp = NDArrayFactory::create('c', {4, 7}, {2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5}); nd4j::ops::mirror_pad op; - auto result = op.execute({&input, &paddings}, {}, {1}); + auto result = op.evaluate({&input, &paddings}, {}, {1}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); @@ -4344,7 +4344,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test2) { auto exp = NDArrayFactory::create('c', {4, 7}, {6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}); nd4j::ops::mirror_pad op; - auto result = op.execute({&input, &paddings}, {}, {0}); + auto result = op.evaluate({&input, &paddings}, {}, {0}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); @@ -4362,7 +4362,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test3) { auto exp = NDArrayFactory::create('c', {7}, {2, 1, 1, 2, 3, 3, 2}); nd4j::ops::mirror_pad op; - auto result = op.execute({&input, &paddings}, {}, {1}); + auto result = op.evaluate({&input, &paddings}, {}, {1}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); @@ -4380,7 +4380,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test4) { auto exp = NDArrayFactory::create('c', {8}, {2, 1, 1, 2, 3, 3, 2, 1}); nd4j::ops::mirror_pad op; - auto result = op.execute({&input, &paddings}, {}, {1}); + auto result = op.evaluate({&input, &paddings}, {}, {1}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); @@ -4398,7 +4398,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test5) { auto exp = NDArrayFactory::create('c', {7}, {3, 2, 1, 2, 3, 2, 1}); nd4j::ops::mirror_pad op; - auto result = op.execute({&input, &paddings}, {}, {0}); + auto result = op.evaluate({&input, &paddings}, {}, {0}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -4415,7 +4415,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test6) { auto exp = NDArrayFactory::create('c', {3}, {1,1,1}); nd4j::ops::mirror_pad op; - auto result = op.execute({&input, &paddings}, {}, {1}); + auto result = op.evaluate({&input, &paddings}, {}, {1}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); @@ -4433,7 +4433,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test7) { auto exp = NDArrayFactory::create('c', {3}, {1,1,1}); nd4j::ops::mirror_pad op; - auto result = op.execute({&input, &paddings}, {}, {1}); + auto result = op.evaluate({&input, &paddings}, {}, {1}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); @@ -4451,7 +4451,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test8) { auto exp = NDArrayFactory::create('c', {3,9}, {3, 2, 1, 1, 2, 3, 3, 2, 1, 3, 2, 1, 1, 2, 3, 3, 2, 1, 3, 2, 1, 1, 2, 3, 3, 2, 1}); nd4j::ops::mirror_pad op; - auto result = op.execute({&input, &paddings}, {}, {1}); + auto result = op.evaluate({&input, &paddings}, {}, {1}); ASSERT_EQ(result->status(), 
Status::OK()); auto output = result->at(0); @@ -4470,7 +4470,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test9) { auto exp = NDArrayFactory::create('c', {6, 9}, {6, 5, 4, 4, 5, 6, 6, 5, 4, 3, 2, 1, 1, 2, 3, 3, 2, 1, 3, 2, 1, 1, 2, 3, 3, 2, 1, 6, 5, 4, 4, 5, 6, 6, 5, 4, 6, 5, 4, 4, 5, 6, 6, 5, 4, 3, 2, 1, 1, 2, 3, 3, 2, 1}); nd4j::ops::mirror_pad op; - auto result = op.execute({&input, &paddings}, {}, {1}); + auto result = op.evaluate({&input, &paddings}, {}, {1}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); @@ -4488,7 +4488,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test10) { auto exp = NDArrayFactory::create('c', {1,3}, {1., 2., 3.}); nd4j::ops::mirror_pad op; - auto result = op.execute({&input, &paddings}, {}, {1}); + auto result = op.evaluate({&input, &paddings}, {}, {1}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); @@ -4506,7 +4506,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test11) { auto exp = NDArrayFactory::create('c', {1,3}, {1., 2., 3.}); nd4j::ops::mirror_pad op; - auto result = op.execute({&input, &paddings}, {}, {0}); + auto result = op.evaluate({&input, &paddings}, {}, {0}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); @@ -4524,7 +4524,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test12) { auto exp = NDArrayFactory::create('c', {3}, {1., 2., 3.}); nd4j::ops::mirror_pad op; - auto result = op.execute({&input, &paddings}, {}, {0}); + auto result = op.evaluate({&input, &paddings}, {}, {0}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); @@ -4542,7 +4542,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test13) { auto exp = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); nd4j::ops::mirror_pad op; - auto result = op.execute({&input, &paddings}, {}, {0}); + auto result = op.evaluate({&input, &paddings}, {}, {0}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); @@ -4560,7 +4560,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test14) { auto exp = NDArrayFactory::create('c', {3, 4}, {4, 5, 6, 5, 1, 2, 3, 2, 4, 5, 6, 5}); nd4j::ops::mirror_pad op; - auto result = op.execute({&input, &paddings}, {}, {0}); + auto result = op.evaluate({&input, &paddings}, {}, {0}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); @@ -4578,7 +4578,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test15) { auto exp = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6}); nd4j::ops::mirror_pad op; - auto result = op.execute({&input, &paddings}, {}, {1}); + auto result = op.evaluate({&input, &paddings}, {}, {1}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); @@ -4601,7 +4601,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test16) { input.linspace(1.); nd4j::ops::mirror_pad op; - auto result = op.execute({&input, &paddings}, {}, {0}); + auto result = op.evaluate({&input, &paddings}, {}, {0}); ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); //output->printBuffer("VVV"); @@ -4621,7 +4621,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_1) { //************************************// nd4j::ops::reduce_sum op; - auto result = op.execute({&input}, {}, {}); + auto result = op.evaluate({&input}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -4638,7 +4638,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_2) { //************************************// nd4j::ops::reduce_sum op; - auto result = op.execute({&input}, {}, {1}); + auto result = op.evaluate({&input}, {}, {1}); ASSERT_EQ(Status::OK(), result->status()); 
auto z = result->at(0); @@ -4655,7 +4655,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_1) { //************************************// nd4j::ops::reduce_prod op; - auto result = op.execute({&input}, {}, {}); + auto result = op.evaluate({&input}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -4672,7 +4672,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_2) { //************************************// nd4j::ops::reduce_prod op; - auto result = op.execute({&input}, {}, {1}); + auto result = op.evaluate({&input}, {}, {1}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -4689,7 +4689,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_01) { x.linspace(1); nd4j::ops::reduce_sum op; - auto result = op.execute({&x}, {}, {0,1}); + auto result = op.evaluate({&x}, {}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -4708,7 +4708,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_02) { x.linspace(1); nd4j::ops::reduce_sum op; - auto result = op.execute({&x}, {1.}, {0, 1}); + auto result = op.evaluate({&x}, {1.}, {0, 1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -4728,7 +4728,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_3) { x.linspace(1); nd4j::ops::reduce_sum op; - auto result = op.execute({&x}, {}, {0, 2}); + auto result = op.evaluate({&x}, {}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -4748,7 +4748,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_4) { x.linspace(1); nd4j::ops::reduce_sum op; - auto result = op.execute({&x}, {1.}, {0, 2}); + auto result = op.evaluate({&x}, {1.}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -4768,7 +4768,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_5) { x.linspace(1); nd4j::ops::reduce_sum op; - auto result = op.execute({&x}, {}, {}); + auto result = op.evaluate({&x}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -4788,7 +4788,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_6) { x.linspace(1); nd4j::ops::reduce_sum op; - auto result = op.execute({&x}, {}, {0,1,2}); + auto result = op.evaluate({&x}, {}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -4808,7 +4808,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_7) { x.linspace(1); // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); nd4j::ops::reduce_sum op; - auto result = op.execute({&x}, {1.}, {0,1,2}); + auto result = op.evaluate({&x}, {1.}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -4828,7 +4828,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_01) { x.linspace(1); nd4j::ops::reduce_prod op; - auto result = op.execute({&x}, {}, {0,1}); + auto result = op.evaluate({&x}, {}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -4847,7 +4847,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_02) { x.linspace(1); nd4j::ops::reduce_prod op; - auto result = op.execute({&x}, {1.}, {0, 1}); + auto result = op.evaluate({&x}, {1.}, {0, 1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -4867,7 +4867,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_3) { x.linspace(1); nd4j::ops::reduce_prod op; - auto result = op.execute({&x}, {}, {0, 2}); + auto result = op.evaluate({&x}, {}, {0, 2}); auto output = result->at(0); // 
output->printIndexedBuffer("Result is"); @@ -4887,7 +4887,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_4) { x.linspace(1); nd4j::ops::reduce_prod op; - auto result = op.execute({&x}, {1.}, {0, 2}); + auto result = op.evaluate({&x}, {1.}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -4907,7 +4907,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_5) { x.linspace(1); nd4j::ops::reduce_prod op; - auto result = op.execute({&x}, {}, {}); + auto result = op.evaluate({&x}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -4927,7 +4927,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_6) { x.linspace(1); nd4j::ops::reduce_prod op; - auto result = op.execute({&x}, {}, {0,1,2}); + auto result = op.evaluate({&x}, {}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -4947,7 +4947,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_7) { x.linspace(1); // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); nd4j::ops::reduce_prod op; - auto result = op.execute({&x}, {1.}, {0,1,2}); + auto result = op.evaluate({&x}, {1.}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -4965,7 +4965,7 @@ TEST_F(DeclarableOpsTests7, Test_Matmul_Once_Again) { auto exp = NDArrayFactory::create('c', {1, 1}, {8.0f}); nd4j::ops::matmul op; - auto result = op.execute({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(exp, *result->at(0)); @@ -4978,7 +4978,7 @@ TYPED_TEST(TypedDeclarableOpsTests7, Test_Pnorm_Once_Again) { auto exp = NDArrayFactory::create('c', {1, 1, 5, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f}); nd4j::ops::pnormpool2d op; - auto result = op.execute({&input}, {}, {1,1, 1,1, 0,0, 1,1,1, 3, 0}); + auto result = op.evaluate({&input}, {}, {1,1, 1,1, 0,0, 1,1,1, 3, 0}); ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(exp, *result->at(0)); @@ -4994,7 +4994,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Min_1) { x.linspace(1); nd4j::ops::reduce_min op; - auto result = op.execute({&x}, {}, {0, 1}); + auto result = op.evaluate({&x}, {}, {0, 1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -5013,7 +5013,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Min_2) { x.linspace(1); nd4j::ops::reduce_min op; - auto result = op.execute({&x}, {1.}, {0, 1}); + auto result = op.evaluate({&x}, {1.}, {0, 1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5033,7 +5033,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Min_3) { x.linspace(1); nd4j::ops::reduce_min op; - auto result = op.execute({&x}, {}, {0, 2}); + auto result = op.evaluate({&x}, {}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5053,7 +5053,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Min_4) { x.linspace(1); nd4j::ops::reduce_min op; - auto result = op.execute({&x}, {1.}, {0, 2}); + auto result = op.evaluate({&x}, {1.}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5073,7 +5073,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Min_5) { x.linspace(1); nd4j::ops::reduce_min op; - auto result = op.execute({&x}, {}, {}); + auto result = op.evaluate({&x}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5093,7 +5093,7 @@ 
TEST_F(DeclarableOpsTests7, Test_Reduce_Min_6) { x.linspace(1); nd4j::ops::reduce_min op; - auto result = op.execute({&x}, {}, {0,1,2}); + auto result = op.evaluate({&x}, {}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5113,7 +5113,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Min_7) { x.linspace(1); // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); nd4j::ops::reduce_min op; - auto result = op.execute({&x}, {1.}, {0,1,2}); + auto result = op.evaluate({&x}, {1.}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5133,7 +5133,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Max_1) { x.linspace(1); nd4j::ops::reduce_max op; - auto result = op.execute({&x}, {}, {0,1}); + auto result = op.evaluate({&x}, {}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); // output->printShapeInfo("Output shape"); @@ -5153,7 +5153,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Max_2) { x.linspace(1); nd4j::ops::reduce_max op; - auto result = op.execute({&x}, {1.}, {0, 1}); + auto result = op.evaluate({&x}, {1.}, {0, 1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5173,7 +5173,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Max_3) { x.linspace(1); nd4j::ops::reduce_max op; - auto result = op.execute({&x}, {}, {0, 2}); + auto result = op.evaluate({&x}, {}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5193,7 +5193,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Max_4) { x.linspace(1); nd4j::ops::reduce_max op; - auto result = op.execute({&x}, {1.}, {0, 2}); + auto result = op.evaluate({&x}, {1.}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5213,7 +5213,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Max_5) { x.linspace(1); nd4j::ops::reduce_max op; - auto result = op.execute({&x}, {}, {}); + auto result = op.evaluate({&x}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5233,7 +5233,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Max_6) { x.linspace(1); nd4j::ops::reduce_max op; - auto result = op.execute({&x}, {}, {0,1,2}); + auto result = op.evaluate({&x}, {}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5253,7 +5253,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Max_7) { x.linspace(1); // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); nd4j::ops::reduce_max op; - auto result = op.execute({&x}, {1.}, {0,1,2}); + auto result = op.evaluate({&x}, {1.}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5273,7 +5273,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_1) { x.linspace(1); nd4j::ops::reduce_norm1 op; - auto result = op.execute({&x}, {}, {0,1}); + auto result = op.evaluate({&x}, {}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -5292,7 +5292,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_2) { x.linspace(1); nd4j::ops::reduce_norm1 op; - auto result = op.execute({&x}, {1.}, {0, 1}); + auto result = op.evaluate({&x}, {1.}, {0, 1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5312,7 +5312,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_3) { x.linspace(1); nd4j::ops::reduce_norm1 op; - auto result = op.execute({&x}, {}, {0, 2}); + auto result = op.evaluate({&x}, {}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ 
-5332,7 +5332,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_4) { x.linspace(1); nd4j::ops::reduce_norm1 op; - auto result = op.execute({&x}, {1.}, {0, 2}); + auto result = op.evaluate({&x}, {1.}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5352,7 +5352,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_5) { x.linspace(1); nd4j::ops::reduce_norm1 op; - auto result = op.execute({&x}, {}, {}); + auto result = op.evaluate({&x}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5372,7 +5372,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_6) { x.linspace(1); nd4j::ops::reduce_norm1 op; - auto result = op.execute({&x}, {}, {0,1,2}); + auto result = op.evaluate({&x}, {}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5392,7 +5392,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_7) { x.linspace(1); // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); nd4j::ops::reduce_norm1 op; - auto result = op.execute({&x}, {1.}, {0,1,2}); + auto result = op.evaluate({&x}, {1.}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5411,7 +5411,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_1) { x.linspace(1); nd4j::ops::reduce_norm2 op; - auto result = op.execute({&x}, {}, {0,1}); + auto result = op.evaluate({&x}, {}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -5430,7 +5430,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_2) { x.linspace(1); nd4j::ops::reduce_norm2 op; - auto result = op.execute({&x}, {1.}, {0, 1}); + auto result = op.evaluate({&x}, {1.}, {0, 1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5450,7 +5450,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_3) { x.linspace(1); nd4j::ops::reduce_norm2 op; - auto result = op.execute({&x}, {}, {0, 2}); + auto result = op.evaluate({&x}, {}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5470,7 +5470,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_4) { x.linspace(1); nd4j::ops::reduce_norm2 op; - auto result = op.execute({&x}, {1.}, {0, 2}); + auto result = op.evaluate({&x}, {1.}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5490,7 +5490,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_5) { x.linspace(1); nd4j::ops::reduce_norm2 op; - auto result = op.execute({&x}, {}, {}); + auto result = op.evaluate({&x}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5510,7 +5510,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_6) { x.linspace(1); nd4j::ops::reduce_norm2 op; - auto result = op.execute({&x}, {}, {0,1,2}); + auto result = op.evaluate({&x}, {}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5530,7 +5530,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_7) { x.linspace(1); // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); nd4j::ops::reduce_norm2 op; - auto result = op.execute({&x}, {1.}, {0,1,2}); + auto result = op.evaluate({&x}, {1.}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5550,7 +5550,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_1) { x.linspace(1); nd4j::ops::reduce_norm_max op; - auto result = op.execute({&x}, {}, {0,1}); + auto result = op.evaluate({&x}, {}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); 
ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -5569,7 +5569,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_2) { x.linspace(1); nd4j::ops::reduce_norm_max op; - auto result = op.execute({&x}, {1.f}, {0,1}); + auto result = op.evaluate({&x}, {1.f}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -5588,7 +5588,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_3) { x.linspace(1); nd4j::ops::reduce_norm_max op; - auto result = op.execute({&x}, {}, {0,2}); + auto result = op.evaluate({&x}, {}, {0,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -5607,7 +5607,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_4) { x.linspace(1); nd4j::ops::reduce_norm_max op; - auto result = op.execute({&x}, {1.f}, {0,2}); + auto result = op.evaluate({&x}, {1.f}, {0,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -5626,7 +5626,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_5) { x.linspace(1); nd4j::ops::reduce_norm_max op; - auto result = op.execute({&x}, {}, {}); + auto result = op.evaluate({&x}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5646,7 +5646,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_6) { x.linspace(1); nd4j::ops::reduce_norm_max op; - auto result = op.execute({&x}, {}, {0, 1, 2}); + auto result = op.evaluate({&x}, {}, {0, 1, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5666,7 +5666,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_7) { x.linspace(1); nd4j::ops::reduce_norm_max op; - auto result = op.execute({&x}, {1.f}, {}); + auto result = op.evaluate({&x}, {1.f}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5686,7 +5686,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_1) { x.linspace(1); nd4j::ops::reduce_sqnorm op; - auto result = op.execute({&x}, {}, {0,1}); + auto result = op.evaluate({&x}, {}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -5705,7 +5705,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_2) { x.linspace(1); nd4j::ops::reduce_sqnorm op; - auto result = op.execute({&x}, {1.f}, {0,1}); + auto result = op.evaluate({&x}, {1.f}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -5724,7 +5724,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_3) { x.linspace(1); nd4j::ops::reduce_sqnorm op; - auto result = op.execute({&x}, {}, {0,2}); + auto result = op.evaluate({&x}, {}, {0,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -5743,7 +5743,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_4) { x.linspace(1); nd4j::ops::reduce_sqnorm op; - auto result = op.execute({&x}, {1.f}, {0,2}); + auto result = op.evaluate({&x}, {1.f}, {0,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -5762,7 +5762,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_5) { x.linspace(1); nd4j::ops::reduce_sqnorm op; - auto result = op.execute({&x}, {}, {}); + auto result = op.evaluate({&x}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5782,7 +5782,7 @@ 
TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_6) { x.linspace(1); nd4j::ops::reduce_sqnorm op; - auto result = op.execute({&x}, {}, {0, 1, 2}); + auto result = op.evaluate({&x}, {}, {0, 1, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5802,7 +5802,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_7) { x.linspace(1); nd4j::ops::reduce_sqnorm op; - auto result = op.execute({&x}, {1.f}, {}); + auto result = op.evaluate({&x}, {1.f}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -5823,7 +5823,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_BP_1) { //************************************// nd4j::ops::reduce_sum_bp op; - auto result = op.execute({&input, &eps}, {}, {}); + auto result = op.evaluate({&input, &eps}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -5844,7 +5844,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_BP_2) { //************************************// nd4j::ops::reduce_sum_bp op; - auto result = op.execute({&input, &eps}, {1.f}, {}); + auto result = op.evaluate({&input, &eps}, {1.f}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -5865,7 +5865,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_BP_3) { //************************************// nd4j::ops::reduce_sum_bp op; - auto result = op.execute({&input, &eps}, {}, {0}); + auto result = op.evaluate({&input, &eps}, {}, {0}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -5886,7 +5886,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_BP_4) { //************************************// nd4j::ops::reduce_sum_bp op; - auto result = op.execute({&input, &eps}, {1.f}, {0}); + auto result = op.evaluate({&input, &eps}, {1.f}, {0}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -5911,7 +5911,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_1) { 131539399526781282156544.f, 122143728775382565912576.f, 114000815325130245799936.f}); nd4j::ops::reduce_prod_bp op; - auto result = op.execute({&input, &eps}, {}, {}); + auto result = op.evaluate({&input, &eps}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -5933,8 +5933,8 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_2) { nd4j::ops::reduce_prod_bp op; nd4j::ops::reduce_prod op_exp; - auto res = op_exp.execute({&input}, {}, {}); - auto result = op.execute({&input, &eps}, {}, {}); + auto res = op_exp.evaluate({&input}); + auto result = op.evaluate({&input, &eps}, {}, {}); exp.assign(res->at(0)->e(0)); exp /= input; exp *= eps.e(0); @@ -5958,7 +5958,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_3) { nd4j::ops::reduce_prod_bp op; //nd4j::ops::reduce_prod op_exp; - auto result = op.execute({&input, &eps}, {1.f}, {0}); + auto result = op.evaluate({&input, &eps}, {1.f}, {0}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -5979,7 +5979,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_03) { auto axis = NDArrayFactory::create('c', {1}, {ax}); nd4j::ops::reduce_prod_bp op; //nd4j::ops::reduce_prod op_exp; - auto result = op.execute({&input, &eps, &axis}, {}, {}, {true}); + auto result = op.evaluate({&input, &eps, &axis}, {}, {}, {true}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -6001,7 +6001,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_4) { nd4j::ops::reduce_prod_bp op; nd4j::ops::reduce_prod op_exp; // auto res = op_exp.execute({&input}, {}, {}); - auto result = op.execute({&input, &eps}, {0.f}, {0}); + auto 
result = op.evaluate({&input, &eps}, {0.f}, {0}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -6024,7 +6024,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_5) { nd4j::ops::reduce_prod_bp op; nd4j::ops::reduce_prod op_exp; // auto res = op_exp.execute({&input}, {}, {}); - auto result = op.execute({&input, &eps}, {0.f}, {1}); + auto result = op.evaluate({&input, &eps}, {0.f}, {1}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -6050,7 +6050,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_1) { // x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_min_bp op; - auto result = op.execute({&x, &eps}, {}, {0, 1}); + auto result = op.evaluate({&x, &eps}, {}, {0, 1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6075,7 +6075,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_2) { // x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_min_bp op; - auto result = op.execute({&x, &eps}, {1.f}, {0, 1}); + auto result = op.evaluate({&x, &eps}, {1.f}, {0, 1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6100,7 +6100,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_02) { // x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_min_bp op; - auto result = op.execute({&x, &eps, &axes}, {}, {}, {true}); + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6123,7 +6123,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_3) { //x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_min_bp op; - auto result = op.execute({&x, &eps}, {1.f}, {}); + auto result = op.evaluate({&x, &eps}, {1.f}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6146,7 +6146,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_4) { // x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_min_bp op; - auto result = op.execute({&x, &eps}, {}, {}); + auto result = op.evaluate({&x, &eps}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6176,7 +6176,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_5) { // x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_min_bp op; - auto result = op.execute({&x, &eps}, {}, {0}); + auto result = op.evaluate({&x, &eps}, {}, {0}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6206,7 +6206,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_6) { // x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_min_bp op; - auto result = op.execute({&x, &eps}, {1.f}, {0}); + auto result = op.evaluate({&x, &eps}, {1.f}, {0}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6231,7 +6231,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_1) { // x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_max_bp op; - auto result = op.execute({&x, &eps}, {}, {0, 1}); + auto 
result = op.evaluate({&x, &eps}, {}, {0, 1}); auto output = result->at(0); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6255,7 +6255,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_2) { // x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_max_bp op; - auto result = op.execute({&x, &eps}, {1.f}, {0, 1}); + auto result = op.evaluate({&x, &eps}, {1.f}, {0, 1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6281,7 +6281,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_02) { // x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_max_bp op; - auto result = op.execute({&x, &eps, &axes}, {}, {}, {true}); + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6310,7 +6310,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_3) { // x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_max_bp op; - auto result = op.execute({&x, &eps}, {}, {0}); + auto result = op.evaluate({&x, &eps}, {}, {0}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6340,7 +6340,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_4) { // x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_max_bp op; - auto result = op.execute({&x, &eps}, {1.f}, {0}); + auto result = op.evaluate({&x, &eps}, {1.f}, {0}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6364,7 +6364,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_BP_1) { exp.p(12, -exp.e(12)); exp.p(20, -exp.e(20)); nd4j::ops::reduce_norm1_bp op; - auto result = op.execute({&x, &eps}, {}, {}); + auto result = op.evaluate({&x, &eps}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6383,7 +6383,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_BP_2) { x.linspace(1); auto exp = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f}); nd4j::ops::reduce_norm1_bp op; - auto result = op.execute({&x, &eps}, {}, {0,1}); + auto result = op.evaluate({&x, &eps}, {}, {0,1}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -6403,7 +6403,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_BP_02) { auto exp = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f}); auto axes = NDArrayFactory::create({0,1}); nd4j::ops::reduce_norm1_bp op; - auto result = op.execute({&x, &eps, &axes}, {}, {}, {false}); + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {false}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); @@ -6420,7 +6420,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_BP_3) { x.linspace(1); auto exp = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f}); nd4j::ops::reduce_norm1_bp op; - auto result = op.execute({&x, &eps}, {1.f}, {0,1}); + auto 
result = op.evaluate({&x, &eps}, {1.f}, {0,1}); auto output = result->at(0); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6438,7 +6438,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_1) { x.linspace(1); nd4j::ops::reduce_norm2_bp op; - auto result = op.execute({&x, &eps}, {}, {0,1}); + auto result = op.evaluate({&x, &eps}, {}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6457,7 +6457,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_2) { x.linspace(1); nd4j::ops::reduce_norm2_bp op; - auto result = op.execute({&x, &eps}, {1.f}, {0,1}); + auto result = op.evaluate({&x, &eps}, {1.f}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6477,7 +6477,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_02) { x.linspace(1); nd4j::ops::reduce_norm2_bp op; - auto result = op.execute({&x, &eps, &axes}, {}, {}, {true}); + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6496,7 +6496,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_3) { x.linspace(1); nd4j::ops::reduce_norm2_bp op; - auto result = op.execute({&x, &eps}, {}, {0, 2}); + auto result = op.evaluate({&x, &eps}, {}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -6516,7 +6516,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_4) { x.linspace(1); nd4j::ops::reduce_norm2_bp op; - auto result = op.execute({&x, &eps}, {1.f}, {0, 2}); + auto result = op.evaluate({&x, &eps}, {1.f}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -6542,7 +6542,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_BP_1) { x.linspace(1); nd4j::ops::reduce_sqnorm_bp op; - auto result = op.execute({&x, &eps}, {}, {0,1}); + auto result = op.evaluate({&x, &eps}, {}, {0,1}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto output = result->at(0); @@ -6568,7 +6568,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_BP_01) { x.linspace(1); nd4j::ops::reduce_sqnorm_bp op; - auto result = op.execute({&x, &eps, &axes}, {}, {}, {false}); + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {false}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto output = result->at(0); @@ -6592,7 +6592,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_1) { exp.p(23, 4.f); nd4j::ops::reduce_norm_max_bp op; - auto result = op.execute({&x, &eps}, {}, {0,1}); + auto result = op.evaluate({&x, &eps}, {}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6616,7 +6616,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_2) { exp.p(23, 4.f); nd4j::ops::reduce_norm_max_bp op; - auto result = op.execute({&x, &eps}, {1.f}, {0,1}); + auto result = op.evaluate({&x, &eps}, {1.f}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6641,7 +6641,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_02) { exp.p(23, 4.f); nd4j::ops::reduce_norm_max_bp op; - auto result = op.execute({&x, &eps, &axes}, {}, {}, {true}); + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6664,7 +6664,7 @@ TEST_F(DeclarableOpsTests7, 
Test_Reduce_NormMax_BP_3) { exp.p(23, 3.f); nd4j::ops::reduce_norm_max_bp op; - auto result = op.execute({&x, &eps}, {}, {0,2}); + auto result = op.evaluate({&x, &eps}, {}, {0,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6686,7 +6686,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_4) { exp.p(19, 2.f); exp.p(23, 3.f); nd4j::ops::reduce_norm_max_bp op; - auto result = op.execute({&x, &eps}, {1.f}, {0,2}); + auto result = op.evaluate({&x, &eps}, {1.f}, {0,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6706,7 +6706,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_5) { x.linspace(1); exp.p(23, 1.f); nd4j::ops::reduce_norm_max_bp op; - auto result = op.execute({&x, &eps}, {}, {}); + auto result = op.evaluate({&x, &eps}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -6728,7 +6728,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_6) { exp.p(23, 1.f); nd4j::ops::reduce_norm_max_bp op; - auto result = op.execute({&x, &eps}, {}, {0, 1, 2}); + auto result = op.evaluate({&x, &eps}, {}, {0, 1, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -6749,7 +6749,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_7) { x.linspace(1); exp.p(23, 1.f); nd4j::ops::reduce_norm_max_bp op; - auto result = op.execute({&x, &eps}, {1.f}, {}); + auto result = op.evaluate({&x, &eps}, {1.f}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); @@ -6774,7 +6774,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Dot_BP_1) { nd4j::ops::reduce_dot_bp op; - auto result = op.execute({&x, &y, &eps}, {}, {}); + auto result = op.evaluate({&x, &y, &eps}, {}, {}); auto output = result->at(0); auto outputX = result->at(1); //tput->printIndexedBuffer("Result is"); @@ -6805,7 +6805,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Dot_BP_2) { eps.linspace(1); y.assign(2.f); nd4j::ops::reduce_dot_bp op; - auto result = op.execute({&x, &y, &eps}, {}, {1}); + auto result = op.evaluate({&x, &y, &eps}, {}, {1}); ASSERT_EQ(result->status(), ND4J_STATUS_OK); ASSERT_EQ(result->size(), 2); auto outputX = result->at(0); @@ -6837,7 +6837,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Dot_BP_02) { eps.linspace(1); y.assign(2.f); nd4j::ops::reduce_dot_bp op; - auto result = op.execute({&x, &y, &eps, &axis}, {}, {}, {false}); + auto result = op.evaluate({&x, &y, &eps, &axis}, {}, {}, {false}); ASSERT_EQ(result->status(), ND4J_STATUS_OK); ASSERT_EQ(result->size(), 2); auto outputX = result->at(0); @@ -6864,7 +6864,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Dot_BP_3) { y.assign(2.f); nd4j::ops::reduce_dot_bp op; - auto result = op.execute({&x,&y, &eps}, {}, {1}); + auto result = op.evaluate({&x,&y, &eps}, {}, {1}); auto outputX = result->at(0); auto outputY = result->at(1); @@ -6886,7 +6886,7 @@ TEST_F(DeclarableOpsTests7, cumsum_bp_1) { eps.assign(1.f); nd4j::ops::cumsum_bp op; - auto result = op.execute({&x, &eps}, {}, {0,0}); + auto result = op.evaluate({&x, &eps}, {}, {0,0}); auto output = result->at(0); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6908,7 +6908,7 @@ TEST_F(DeclarableOpsTests7, cumsum_bp_2) { nd4j::ops::cumsum_bp op; - auto result = op.execute({&x, &eps}, {}, {1,0}); + auto result = op.evaluate({&x, &eps}, {}, {1,0}); auto output = result->at(0); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -6930,7 +6930,7 @@ TEST_F(DeclarableOpsTests7, cumsum_bp_3) { 
eps.assign(1.f); nd4j::ops::cumsum_bp op; - auto result = op.execute({&x, &eps}, {}, {1,1}); + auto result = op.evaluate({&x, &eps}, {}, {1,1}); auto output = result->at(0); ASSERT_EQ(ND4J_STATUS_OK, result->status()); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp index 3fb90b480..05c21a8f0 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp @@ -57,7 +57,7 @@ TEST_F(DeclarableOpsTests8, reduceVariance_test1) { auto exp = NDArrayFactory::create('c', {4}, {602.2222f, 727.13885f, 993.5555f, 755.8889f}); nd4j::ops::reduce_variance op; - auto result = op.execute({&x}, {}, {0,1}); + auto result = op.evaluate({&x}, {}, {0,1}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -75,7 +75,7 @@ TEST_F(DeclarableOpsTests8, reduceVariance_test2) { auto exp = NDArrayFactory::create('c', {1,1,4}, {602.2222f, 727.13885f, 993.5555f, 755.8889f}); nd4j::ops::reduce_variance op; - auto result = op.execute({&x}, {1.}, {0,1}); + auto result = op.evaluate({&x}, {1.}, {0,1}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -93,7 +93,7 @@ TEST_F(DeclarableOpsTests8, reduceVariance_test3) { auto exp = NDArrayFactory::create('c', {3}, {900.9375f, 969.8594f, 424.1875f}); nd4j::ops::reduce_variance op; - auto result = op.execute({&x}, {}, {0,2}); + auto result = op.evaluate({&x}, {}, {0,2}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -111,7 +111,7 @@ TEST_F(DeclarableOpsTests8, reduceVariance_test4) { auto exp = NDArrayFactory::create('c', {1,3,1}, {900.9375f, 969.8594f, 424.1875f}); nd4j::ops::reduce_variance op; - auto result = op.execute({&x}, {1.}, {0,2}); + auto result = op.evaluate({&x}, {1.}, {0,2}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -129,7 +129,7 @@ TEST_F(DeclarableOpsTests8, reduceVariance_test5) { auto exp = NDArrayFactory::create(788.6927f); nd4j::ops::reduce_variance op; - auto result = op.execute({&x}, {}, {}); + auto result = op.evaluate({&x}, {}, {}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -147,7 +147,7 @@ TEST_F(DeclarableOpsTests8, reduceVariance_test6) { auto exp = NDArrayFactory::create(788.6927f); nd4j::ops::reduce_variance op; - auto result = op.execute({&x}, {}, {0,1,2}); + auto result = op.evaluate({&x}, {}, {0,1,2}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -165,7 +165,7 @@ TEST_F(DeclarableOpsTests8, reduceVariance_test7) { auto exp = NDArrayFactory::create('c', {1,1,1}, {788.6927f}); nd4j::ops::reduce_variance op; - auto result = op.execute({&x}, {1.}, {0,1,2}); + auto result = op.evaluate({&x}, {1.}, {0,1,2}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -183,7 +183,7 @@ TEST_F(DeclarableOpsTests8, reduceVariance_test8) { auto exp = NDArrayFactory::create('c', {1,1,1}, {788.6927f}); auto axes = NDArrayFactory::create({0, 1, 2}); nd4j::ops::reduce_variance op; - auto result = op.execute({&x, &axes}, {}, {}, {true}); + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -201,7 +201,7 @@ TEST_F(DeclarableOpsTests8, reduceStDev_test1) { auto exp = NDArrayFactory::create('c', {4}, {24.54022f, 26.96551f, 31.52072f, 27.49343f}); nd4j::ops::reduce_stdev op; - auto result = op.execute({&x}, {}, {0,1}); + auto result = op.evaluate({&x}, {}, {0,1}); 
auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -219,7 +219,7 @@ TEST_F(DeclarableOpsTests8, reduceStDev_test2) { auto exp = NDArrayFactory::create('c', {1,1,4}, {24.54022f, 26.96551f, 31.52072f, 27.49343f}); nd4j::ops::reduce_stdev op; - auto result = op.execute({&x}, {1.}, {0,1}); + auto result = op.evaluate({&x}, {1.}, {0,1}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -237,7 +237,7 @@ TEST_F(DeclarableOpsTests8, reduceStDev_test3) { auto exp = NDArrayFactory::create('c', {3}, {30.01562f, 31.14257f, 20.59581f}); nd4j::ops::reduce_stdev op; - auto result = op.execute({&x}, {}, {0,2}); + auto result = op.evaluate({&x}, {}, {0,2}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -255,7 +255,7 @@ TEST_F(DeclarableOpsTests8, reduceStDev_test4) { auto exp = NDArrayFactory::create('c', {1,3,1}, {30.01562f, 31.14257f, 20.59581f}); nd4j::ops::reduce_stdev op; - auto result = op.execute({&x}, {1.}, {0,2}); + auto result = op.evaluate({&x}, {1.}, {0,2}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -273,7 +273,7 @@ TEST_F(DeclarableOpsTests8, reduceStDev_test5) { auto exp = NDArrayFactory::create(28.08367f); nd4j::ops::reduce_stdev op; - auto result = op.execute({&x}, {}, {}); + auto result = op.evaluate({&x}, {}, {}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -291,7 +291,7 @@ TEST_F(DeclarableOpsTests8, reduceStDev_test6) { auto exp = NDArrayFactory::create(28.08367f); nd4j::ops::reduce_stdev op; - auto result = op.execute({&x}, {}, {0,1,2}); + auto result = op.evaluate({&x}, {}, {0,1,2}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -309,7 +309,7 @@ TEST_F(DeclarableOpsTests8, reduceStDev_test7) { auto exp = NDArrayFactory::create('c', {1,1,1}, {28.08367f}); nd4j::ops::reduce_stdev op; - auto result = op.execute({&x}, {1.f}, {0,1,2}); + auto result = op.evaluate({&x}, {1.f}, {0,1,2}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -327,7 +327,7 @@ TEST_F(DeclarableOpsTests8, reduceStDev_test8) { auto exp = NDArrayFactory::create('c', {4}, {26.88246f, 29.53924f, 34.52921f, 30.11755f}); nd4j::ops::reduce_stdev op; - auto result = op.execute({&x}, {0.f,1.f}, {0,1}); + auto result = op.evaluate({&x}, {0.f,1.f}, {0,1}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -345,7 +345,7 @@ TEST_F(DeclarableOpsTests8, reduceStDev_test08) { auto exp = NDArrayFactory::create('c', {4}, {26.88246f, 29.53924f, 34.52921f, 30.11755f}); auto axes = NDArrayFactory::create({0,1}); nd4j::ops::reduce_stdev op; - auto result = op.execute({&x, &axes}, {}, {}, {false, true}); + auto result = op.evaluate({&x, &axes}, {}, {}, {false, true}); auto output = result->at(0); ASSERT_EQ(Status::OK(), result->status()); @@ -369,28 +369,28 @@ TEST_F(DeclarableOpsTests8, reduceVarianceBP_test1) { nd4j::ops::reduce_variance_bp op; - auto result = op.execute({&x, &gradO2}, {0,1}, {}); + auto result = op.evaluate({&x, &gradO2}, {0,1}, {}); ASSERT_EQ(Status::OK(), result->status()); auto output = result->at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); delete result; - result = op.execute({&x, &gradO1}, {1,1}, {}); + result = op.evaluate({&x, &gradO1}, {1,1}, {}); ASSERT_EQ(Status::OK(), result->status()); output = result->at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); delete result; - result = op.execute({&x, &gradO2}, {0,0}, {}); + 
result = op.evaluate({&x, &gradO2}, {0,0}, {}); ASSERT_EQ(Status::OK(), result->status()); output = result->at(0); ASSERT_TRUE(exp34.isSameShape(output)); ASSERT_TRUE(exp34.equalsTo(output)); delete result; - result = op.execute({&x, &gradO1}, {1,0}, {}); + result = op.evaluate({&x, &gradO1}, {1,0}, {}); ASSERT_EQ(Status::OK(), result->status()); output = result->at(0); ASSERT_TRUE(exp34.isSameShape(output)); @@ -412,28 +412,28 @@ TEST_F(DeclarableOpsTests8, reduceVarianceBP_test2) { nd4j::ops::reduce_variance_bp op; - auto result = op.execute({&x, &gradO2}, {0,0}, {0}); + auto result = op.evaluate({&x, &gradO2}, {0,0}, {0}); ASSERT_EQ(Status::OK(), result->status()); auto output = result->at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); delete result; - result = op.execute({&x, &gradO1}, {1,0}, {0}); + result = op.evaluate({&x, &gradO1}, {1,0}, {0}); ASSERT_EQ(Status::OK(), result->status()); output = result->at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); delete result; - result = op.execute({&x, &gradO2}, {0,1}, {0}); + result = op.evaluate({&x, &gradO2}, {0,1}, {0}); ASSERT_EQ(Status::OK(), result->status()); output = result->at(0); ASSERT_TRUE(exp34.isSameShape(output)); ASSERT_TRUE(exp34.equalsTo(output)); delete result; - result = op.execute({&x, &gradO1}, {1,1}, {0}); + result = op.evaluate({&x, &gradO1}, {1,1}, {0}); ASSERT_EQ(Status::OK(), result->status()); output = result->at(0); ASSERT_TRUE(exp34.isSameShape(output)); @@ -455,28 +455,28 @@ TEST_F(DeclarableOpsTests8, reduceVarianceBP_test02) { nd4j::ops::reduce_variance_bp op; - auto result = op.execute({&x, &gradO2, &axes}, {}, {}, {false, false}); + auto result = op.evaluate({&x, &gradO2, &axes}, {}, {}, {false, false}); ASSERT_EQ(Status::OK(), result->status()); auto output = result->at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); delete result; - result = op.execute({&x, &gradO1, &axes}, {}, {}, {true, false}); + result = op.evaluate({&x, &gradO1, &axes}, {}, {}, {true, false}); ASSERT_EQ(Status::OK(), result->status()); output = result->at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); delete result; - result = op.execute({&x, &gradO2, &axes}, {}, {}, {false, true}); + result = op.evaluate({&x, &gradO2, &axes}, {}, {}, {false, true}); ASSERT_EQ(Status::OK(), result->status()); output = result->at(0); ASSERT_TRUE(exp34.isSameShape(output)); ASSERT_TRUE(exp34.equalsTo(output)); delete result; - result = op.execute({&x, &gradO1, &axes}, {}, {}, {true, true}); + result = op.evaluate({&x, &gradO1, &axes}, {}, {}, {true, true}); ASSERT_EQ(Status::OK(), result->status()); output = result->at(0); ASSERT_TRUE(exp34.isSameShape(output)); @@ -502,28 +502,28 @@ TEST_F(DeclarableOpsTests8, reduceVarianceBP_test3) { nd4j::ops::reduce_variance_bp op; - auto result = op.execute({&x, &gradO2}, {0, 0}, {1}); + auto result = op.evaluate({&x, &gradO2}, {0, 0}, {1}); ASSERT_EQ(Status::OK(), result->status()); auto output = result->at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); delete result; - result = op.execute({&x, &gradO1}, {1, 0}, {1}); + result = op.evaluate({&x, &gradO1}, {1, 0}, {1}); ASSERT_EQ(Status::OK(), result->status()); output = result->at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); delete result; - result = op.execute({&x, &gradO2}, {0, 1}, {1}); + result = op.evaluate({&x, &gradO2}, {0, 1}, {1}); 
ASSERT_EQ(Status::OK(), result->status()); output = result->at(0); ASSERT_TRUE(exp34.isSameShape(output)); ASSERT_TRUE(exp34.equalsTo(output)); delete result; - result = op.execute({&x, &gradO1}, {1, 1}, {1}); + result = op.evaluate({&x, &gradO1}, {1, 1}, {1}); ASSERT_EQ(Status::OK(), result->status()); output = result->at(0); ASSERT_TRUE(exp34.isSameShape(output)); @@ -544,7 +544,7 @@ TEST_F(DeclarableOpsTests8, reduceStDevBP_test1) { nd4j::ops::reduce_stdev_bp op; - auto result = op.execute({&x, &gradO2}, {0,1}, {}); + auto result = op.evaluate({&x, &gradO2}, {0,1}, {}); ASSERT_EQ(Status::OK(), result->status()); auto output = result->at(0); // output->printIndexedBuffer(); @@ -552,21 +552,21 @@ TEST_F(DeclarableOpsTests8, reduceStDevBP_test1) { ASSERT_TRUE(exp12.equalsTo(output)); delete result; - result = op.execute({&x, &gradO1}, {1,1}, {}); + result = op.evaluate({&x, &gradO1}, {1,1}, {}); ASSERT_EQ(Status::OK(), result->status()); output = result->at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); delete result; - result = op.execute({&x, &gradO2}, {0,0}, {}); + result = op.evaluate({&x, &gradO2}, {0,0}, {}); ASSERT_EQ(Status::OK(), result->status()); output = result->at(0); ASSERT_TRUE(exp34.isSameShape(output)); ASSERT_TRUE(exp34.equalsTo(output)); delete result; - result = op.execute({&x, &gradO1}, {1,0}, {}); + result = op.evaluate({&x, &gradO1}, {1,0}, {}); ASSERT_EQ(Status::OK(), result->status()); output = result->at(0); ASSERT_TRUE(exp34.isSameShape(output)); @@ -587,28 +587,28 @@ TEST_F(DeclarableOpsTests8, reduceStDevBP_test2) { nd4j::ops::reduce_stdev_bp op; - auto result = op.execute({&x, &gradO2}, {0,0}, {0}); + auto result = op.evaluate({&x, &gradO2}, {0,0}, {0}); ASSERT_EQ(Status::OK(), result->status()); auto output = result->at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); delete result; - result = op.execute({&x, &gradO1}, {1,0}, {0}); + result = op.evaluate({&x, &gradO1}, {1,0}, {0}); ASSERT_EQ(Status::OK(), result->status()); output = result->at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); delete result; - result = op.execute({&x, &gradO2}, {0,1}, {0}); + result = op.evaluate({&x, &gradO2}, {0,1}, {0}); ASSERT_EQ(Status::OK(), result->status()); output = result->at(0); ASSERT_TRUE(exp34.isSameShape(output)); ASSERT_TRUE(exp34.equalsTo(output)); delete result; - result = op.execute({&x, &gradO1}, {1,1}, {0}); + result = op.evaluate({&x, &gradO1}, {1,1}, {0}); ASSERT_EQ(Status::OK(), result->status()); output = result->at(0); ASSERT_TRUE(exp34.isSameShape(output)); @@ -630,28 +630,28 @@ TEST_F(DeclarableOpsTests8, reduceStDevBP_test02) { nd4j::ops::reduce_stdev_bp op; - auto result = op.execute({&x, &gradO2, &axis}, {}, {}, {false, false}); + auto result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {false, false}); ASSERT_EQ(Status::OK(), result->status()); auto output = result->at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); delete result; - result = op.execute({&x, &gradO1, &axis}, {}, {}, {true, false}); + result = op.evaluate({&x, &gradO1, &axis}, {}, {}, {true, false}); ASSERT_EQ(Status::OK(), result->status()); output = result->at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); delete result; - result = op.execute({&x, &gradO2, &axis}, {}, {}, {false, true}); + result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {false, true}); ASSERT_EQ(Status::OK(), result->status()); output = 
    output = result->at(0);
    ASSERT_TRUE(exp34.isSameShape(output));
    ASSERT_TRUE(exp34.equalsTo(output));
    delete result;
-    result = op.execute({&x, &gradO1, &axis}, {}, {}, {true, true});
+    result = op.evaluate({&x, &gradO1, &axis}, {}, {}, {true, true});
    ASSERT_EQ(Status::OK(), result->status());
    output = result->at(0);
    ASSERT_TRUE(exp34.isSameShape(output));
@@ -672,28 +672,28 @@ TEST_F(DeclarableOpsTests8, reduceStDevBP_test3) {
    nd4j::ops::reduce_stdev_bp op;
-    auto result = op.execute({&x, &gradO2}, {0,0}, {1});
+    auto result = op.evaluate({&x, &gradO2}, {0,0}, {1});
    ASSERT_EQ(Status::OK(), result->status());
    auto output = result->at(0);
    ASSERT_TRUE(exp12.isSameShape(output));
    ASSERT_TRUE(exp12.equalsTo(output));
    delete result;
-    result = op.execute({&x, &gradO1}, {1,0}, {1});
+    result = op.evaluate({&x, &gradO1}, {1,0}, {1});
    ASSERT_EQ(Status::OK(), result->status());
    output = result->at(0);
    ASSERT_TRUE(exp12.isSameShape(output));
    ASSERT_TRUE(exp12.equalsTo(output));
    delete result;
-    result = op.execute({&x, &gradO2}, {0,1}, {1});
+    result = op.evaluate({&x, &gradO2}, {0,1}, {1});
    ASSERT_EQ(Status::OK(), result->status());
    output = result->at(0);
    ASSERT_TRUE(exp34.isSameShape(output));
    ASSERT_TRUE(exp34.equalsTo(output));
    delete result;
-    result = op.execute({&x, &gradO1}, {1,1}, {1});
+    result = op.evaluate({&x, &gradO1}, {1,1}, {1});
    ASSERT_EQ(Status::OK(), result->status());
    output = result->at(0);
    ASSERT_TRUE(exp34.isSameShape(output));
@@ -711,7 +711,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_1) {
    //************************************//
    nd4j::ops::reduce_sum op;
-    auto result = op.execute({&input}, {}, {});
+    auto result = op.evaluate({&input}, {}, {});
    ASSERT_EQ(Status::OK(), result->status());
    auto z = result->at(0);
@@ -728,7 +728,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_2) {
    //************************************//
    nd4j::ops::reduce_sum op;
-    auto result = op.execute({&input}, {}, {1});
+    auto result = op.evaluate({&input}, {}, {1});
    ASSERT_EQ(Status::OK(), result->status());
    auto z = result->at(0);
@@ -746,7 +746,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_03) {
    //************************************//
    nd4j::ops::reduce_sum op;
-    auto result = op.execute({&input, &axis}, {}, {}, {false});
+    auto result = op.evaluate({&input, &axis}, {}, {}, {false});
    ASSERT_EQ(Status::OK(), result->status());
    auto z = result->at(0);
@@ -763,7 +763,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_1) {
    //************************************//
    nd4j::ops::reduce_prod op;
-    auto result = op.execute({&input}, {}, {});
+    auto result = op.evaluate({&input}, {}, {});
    ASSERT_EQ(Status::OK(), result->status());
    auto z = result->at(0);
@@ -780,7 +780,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_2) {
    //************************************//
    nd4j::ops::reduce_prod op;
-    auto result = op.execute({&input}, {}, {1});
+    auto result = op.evaluate({&input}, {}, {1});
    ASSERT_EQ(Status::OK(), result->status());
    auto z = result->at(0);
@@ -797,7 +797,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_01) {
    x.linspace(1);
    nd4j::ops::reduce_sum op;
-    auto result = op.execute({&x}, {}, {0,1});
+    auto result = op.evaluate({&x}, {}, {0,1});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
    ASSERT_EQ(Status::OK(), result->status());
@@ -816,7 +816,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_02) {
    x.linspace(1);
    nd4j::ops::reduce_sum op;
-    auto result = op.execute({&x}, {1.}, {0, 1});
+    auto result = op.evaluate({&x}, {1.}, {0, 1});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -836,7 +836,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_3) {
    x.linspace(1);
    nd4j::ops::reduce_sum op;
-    auto result = op.execute({&x}, {}, {0, 2});
+    auto result = op.evaluate({&x}, {}, {0, 2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -856,7 +856,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_4) {
    x.linspace(1);
    nd4j::ops::reduce_sum op;
-    auto result = op.execute({&x}, {1.}, {0, 2});
+    auto result = op.evaluate({&x}, {1.}, {0, 2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -876,7 +876,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_5) {
    x.linspace(1);
    nd4j::ops::reduce_sum op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -896,7 +896,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_6) {
    x.linspace(1);
    nd4j::ops::reduce_sum op;
-    auto result = op.execute({&x}, {}, {0,1,2});
+    auto result = op.evaluate({&x}, {}, {0,1,2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -916,7 +916,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_7) {
    x.linspace(1);
    // x.printIndexedBuffer("Input with shape (2, 3, 4) is");
    nd4j::ops::reduce_sum op;
-    auto result = op.execute({&x}, {1.}, {0,1,2});
+    auto result = op.evaluate({&x}, {1.}, {0,1,2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -936,7 +936,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_01) {
    x.linspace(1);
    nd4j::ops::reduce_prod op;
-    auto result = op.execute({&x}, {}, {0,1});
+    auto result = op.evaluate({&x}, {}, {0,1});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
    ASSERT_EQ(Status::OK(), result->status());
@@ -955,7 +955,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_02) {
    x.linspace(1);
    nd4j::ops::reduce_prod op;
-    auto result = op.execute({&x}, {1.}, {0, 1});
+    auto result = op.evaluate({&x}, {1.}, {0, 1});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -975,7 +975,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_3) {
    x.linspace(1);
    nd4j::ops::reduce_prod op;
-    auto result = op.execute({&x}, {}, {0, 2});
+    auto result = op.evaluate({&x}, {}, {0, 2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -995,7 +995,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_4) {
    x.linspace(1);
    nd4j::ops::reduce_prod op;
-    auto result = op.execute({&x}, {1.}, {0, 2});
+    auto result = op.evaluate({&x}, {1.}, {0, 2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1016,7 +1016,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_04) {
    x.linspace(1);
    nd4j::ops::reduce_prod op;
-    auto result = op.execute({&x, &axes}, {}, {}, {true});
+    auto result = op.evaluate({&x, &axes}, {}, {}, {true});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1036,7 +1036,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_5) {
    x.linspace(1);
    nd4j::ops::reduce_prod op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1056,7 +1056,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_6) {
    x.linspace(1);
    nd4j::ops::reduce_prod op;
-    auto result = op.execute({&x}, {}, {0,1,2});
+    auto result = op.evaluate({&x}, {}, {0,1,2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1076,7 +1076,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_7) {
    x.linspace(1);
    // x.printIndexedBuffer("Input with shape (2, 3, 4) is");
    nd4j::ops::reduce_prod op;
-    auto result = op.execute({&x}, {1.}, {0,1,2});
+    auto result = op.evaluate({&x}, {1.}, {0,1,2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1096,7 +1096,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Min_1) {
    x.linspace(1);
    nd4j::ops::reduce_min op;
-    auto result = op.execute({&x}, {}, {0, 1});
+    auto result = op.evaluate({&x}, {}, {0, 1});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
    ASSERT_EQ(Status::OK(), result->status());
@@ -1115,7 +1115,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Min_2) {
    x.linspace(1);
    nd4j::ops::reduce_min op;
-    auto result = op.execute({&x}, {1.}, {0, 1});
+    auto result = op.evaluate({&x}, {1.}, {0, 1});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1135,7 +1135,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Min_3) {
    x.linspace(1);
    nd4j::ops::reduce_min op;
-    auto result = op.execute({&x}, {}, {0, 2});
+    auto result = op.evaluate({&x}, {}, {0, 2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1155,7 +1155,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Min_4) {
    x.linspace(1);
    nd4j::ops::reduce_min op;
-    auto result = op.execute({&x}, {1.}, {0, 2});
+    auto result = op.evaluate({&x}, {1.}, {0, 2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1176,7 +1176,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Min_04) {
    x.linspace(1);
    nd4j::ops::reduce_min op;
-    auto result = op.execute({&x, &axes}, {}, {}, {true});
+    auto result = op.evaluate({&x, &axes}, {}, {}, {true});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1196,7 +1196,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Min_5) {
    x.linspace(1);
    nd4j::ops::reduce_min op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1216,7 +1216,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Min_6) {
    x.linspace(1);
    nd4j::ops::reduce_min op;
-    auto result = op.execute({&x}, {}, {0,1,2});
+    auto result = op.evaluate({&x}, {}, {0,1,2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1236,7 +1236,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Min_7) {
    x.linspace(1);
    // x.printIndexedBuffer("Input with shape (2, 3, 4) is");
    nd4j::ops::reduce_min op;
-    auto result = op.execute({&x}, {1.}, {0,1,2});
+    auto result = op.evaluate({&x}, {1.}, {0,1,2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1256,7 +1256,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Max_1) {
    x.linspace(1);
    nd4j::ops::reduce_max op;
-    auto result = op.execute({&x}, {}, {0,1});
+    auto result = op.evaluate({&x}, {}, {0,1});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
    // output->printShapeInfo("Output shape");
    ASSERT_EQ(Status::OK(), result->status());
@@ -1276,7 +1276,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Max_2) {
    x.linspace(1);
    nd4j::ops::reduce_max op;
-    auto result = op.execute({&x}, {1.}, {0, 1});
+    auto result = op.evaluate({&x}, {1.}, {0, 1});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1296,7 +1296,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Max_3) {
    x.linspace(1);
    nd4j::ops::reduce_max op;
-    auto result = op.execute({&x}, {}, {0, 2});
+    auto result = op.evaluate({&x}, {}, {0, 2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1316,7 +1316,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Max_4) {
    x.linspace(1);
    nd4j::ops::reduce_max op;
-    auto result = op.execute({&x}, {1.}, {0, 2});
+    auto result = op.evaluate({&x}, {1.}, {0, 2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1337,7 +1337,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Max_04) {
    x.linspace(1);
    nd4j::ops::reduce_max op;
-    auto result = op.execute({&x, &axes}, {}, {}, {true});
+    auto result = op.evaluate({&x, &axes}, {}, {}, {true});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1357,7 +1357,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Max_5) {
    x.linspace(1);
    nd4j::ops::reduce_max op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1377,7 +1377,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Max_6) {
    x.linspace(1);
    nd4j::ops::reduce_max op;
-    auto result = op.execute({&x}, {}, {0,1,2});
+    auto result = op.evaluate({&x}, {}, {0,1,2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1397,7 +1397,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Max_7) {
    x.linspace(1);
    // x.printIndexedBuffer("Input with shape (2, 3, 4) is");
    nd4j::ops::reduce_max op;
-    auto result = op.execute({&x}, {1.}, {0,1,2});
+    auto result = op.evaluate({&x}, {1.}, {0,1,2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1416,7 +1416,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_1) {
    x.linspace(1);
    nd4j::ops::reduce_norm1 op;
-    auto result = op.execute({&x}, {}, {0,1});
+    auto result = op.evaluate({&x}, {}, {0,1});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
    ASSERT_EQ(Status::OK(), result->status());
@@ -1435,7 +1435,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_2) {
    x.linspace(1);
    nd4j::ops::reduce_norm1 op;
-    auto result = op.execute({&x}, {1.}, {0, 1});
+    auto result = op.evaluate({&x}, {1.}, {0, 1});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1455,7 +1455,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_3) {
    x.linspace(1);
    nd4j::ops::reduce_norm1 op;
-    auto result = op.execute({&x}, {}, {0, 2});
+    auto result = op.evaluate({&x}, {}, {0, 2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1475,7 +1475,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_4) {
    x.linspace(1);
    nd4j::ops::reduce_norm1 op;
-    auto result = op.execute({&x}, {1.}, {0, 2});
+    auto result = op.evaluate({&x}, {1.}, {0, 2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1496,7 +1496,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_04) {
    x.linspace(1);
    nd4j::ops::reduce_norm1 op;
-    auto result = op.execute({&x, &axes}, {}, {}, {true});
+    auto result = op.evaluate({&x, &axes}, {}, {}, {true});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1516,7 +1516,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_5) {
    x.linspace(1);
    nd4j::ops::reduce_norm1 op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1536,7 +1536,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_6) {
    x.linspace(1);
    nd4j::ops::reduce_norm1 op;
-    auto result = op.execute({&x}, {}, {0,1,2});
+    auto result = op.evaluate({&x}, {}, {0,1,2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1556,7 +1556,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_7) {
    x.linspace(1);
    // x.printIndexedBuffer("Input with shape (2, 3, 4) is");
    nd4j::ops::reduce_norm1 op;
-    auto result = op.execute({&x}, {1.}, {0,1,2});
+    auto result = op.evaluate({&x}, {1.}, {0,1,2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1575,7 +1575,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_1) {
    x.linspace(1);
    nd4j::ops::reduce_norm2 op;
-    auto result = op.execute({&x}, {}, {0,1});
+    auto result = op.evaluate({&x}, {}, {0,1});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
    ASSERT_EQ(Status::OK(), result->status());
@@ -1594,7 +1594,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_2) {
    x.linspace(1);
    nd4j::ops::reduce_norm2 op;
-    auto result = op.execute({&x}, {1.}, {0, 1});
+    auto result = op.evaluate({&x}, {1.}, {0, 1});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1614,7 +1614,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_3) {
    x.linspace(1);
    nd4j::ops::reduce_norm2 op;
-    auto result = op.execute({&x}, {}, {0, 2});
+    auto result = op.evaluate({&x}, {}, {0, 2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1634,7 +1634,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_4) {
    x.linspace(1);
    nd4j::ops::reduce_norm2 op;
-    auto result = op.execute({&x}, {1.}, {0, 2});
+    auto result = op.evaluate({&x}, {1.}, {0, 2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1655,7 +1655,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_04) {
    x.linspace(1);
    nd4j::ops::reduce_norm2 op;
-    auto result = op.execute({&x, &axes}, {}, {}, {true});
+    auto result = op.evaluate({&x, &axes}, {}, {}, {true});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1675,7 +1675,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_5) {
    x.linspace(1);
    nd4j::ops::reduce_norm2 op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1695,7 +1695,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_6) {
    x.linspace(1);
    nd4j::ops::reduce_norm2 op;
-    auto result = op.execute({&x}, {}, {0,1,2});
+    auto result = op.evaluate({&x}, {}, {0,1,2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1715,7 +1715,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_7) {
    x.linspace(1);
    // x.printIndexedBuffer("Input with shape (2, 3, 4) is");
    nd4j::ops::reduce_norm2 op;
-    auto result = op.execute({&x}, {1.}, {0,1,2});
+    auto result = op.evaluate({&x}, {1.}, {0,1,2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1735,7 +1735,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_1) {
    x.linspace(1);
    nd4j::ops::reduce_norm_max op;
-    auto result = op.execute({&x}, {}, {0,1});
+    auto result = op.evaluate({&x}, {}, {0,1});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
    ASSERT_EQ(Status::OK(), result->status());
@@ -1754,7 +1754,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_2) {
    x.linspace(1);
    nd4j::ops::reduce_norm_max op;
-    auto result = op.execute({&x}, {1.f}, {0,1});
+    auto result = op.evaluate({&x}, {1.f}, {0,1});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
    ASSERT_EQ(Status::OK(), result->status());
@@ -1773,7 +1773,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_3) {
    x.linspace(1);
    nd4j::ops::reduce_norm_max op;
-    auto result = op.execute({&x}, {}, {0,2});
+    auto result = op.evaluate({&x}, {}, {0,2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
    ASSERT_EQ(Status::OK(), result->status());
@@ -1792,7 +1792,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_4) {
    x.linspace(1);
    nd4j::ops::reduce_norm_max op;
-    auto result = op.execute({&x}, {1.f}, {0,2});
+    auto result = op.evaluate({&x}, {1.f}, {0,2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
    ASSERT_EQ(Status::OK(), result->status());
@@ -1812,7 +1812,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_04) {
    x.linspace(1);
    nd4j::ops::reduce_norm_max op;
-    auto result = op.execute({&x, &axes}, {}, {}, {true});
+    auto result = op.evaluate({&x, &axes}, {}, {}, {true});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
    ASSERT_EQ(Status::OK(), result->status());
@@ -1831,7 +1831,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_5) {
    x.linspace(1);
    nd4j::ops::reduce_norm_max op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1851,7 +1851,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_6) {
    x.linspace(1);
    nd4j::ops::reduce_norm_max op;
-    auto result = op.execute({&x}, {}, {0, 1, 2});
+    auto result = op.evaluate({&x}, {}, {0, 1, 2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1871,7 +1871,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_7) {
    x.linspace(1);
    nd4j::ops::reduce_norm_max op;
-    auto result = op.execute({&x}, {1.f}, {});
+    auto result = op.evaluate({&x}, {1.f}, {});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -1891,7 +1891,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_1) {
    x.linspace(1);
    nd4j::ops::reduce_sqnorm op;
-    auto result = op.execute({&x}, {}, {0,1});
+    auto result = op.evaluate({&x}, {}, {0,1});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
    ASSERT_EQ(Status::OK(), result->status());
@@ -1910,7 +1910,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_2) {
    x.linspace(1);
    nd4j::ops::reduce_sqnorm op;
-    auto result = op.execute({&x}, {1.f}, {0,1});
+    auto result = op.evaluate({&x}, {1.f}, {0,1});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
    ASSERT_EQ(Status::OK(), result->status());
@@ -1929,7 +1929,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_3) {
    x.linspace(1);
    nd4j::ops::reduce_sqnorm op;
-    auto result = op.execute({&x}, {}, {0,2});
+    auto result = op.evaluate({&x}, {}, {0,2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
    ASSERT_EQ(Status::OK(), result->status());
@@ -1948,7 +1948,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_4) {
    x.linspace(1);
    nd4j::ops::reduce_sqnorm op;
-    auto result = op.execute({&x}, {1.f}, {0,2});
+    auto result = op.evaluate({&x}, {1.f}, {0,2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
    ASSERT_EQ(Status::OK(), result->status());
@@ -1968,7 +1968,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_04) {
    x.linspace(1);
    nd4j::ops::reduce_sqnorm op;
-    auto result = op.execute({&x, &axes}, {}, {}, {true});
+    auto result = op.evaluate({&x, &axes}, {}, {}, {true});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
    ASSERT_EQ(Status::OK(), result->status());
@@ -1987,7 +1987,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_5) {
    x.linspace(1);
    nd4j::ops::reduce_sqnorm op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -2007,7 +2007,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_6) {
    x.linspace(1);
    nd4j::ops::reduce_sqnorm op;
-    auto result = op.execute({&x}, {}, {0, 1, 2});
+    auto result = op.evaluate({&x}, {}, {0, 1, 2});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -2027,7 +2027,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_7) {
    x.linspace(1);
    nd4j::ops::reduce_sqnorm op;
-    auto result = op.execute({&x}, {1.f}, {});
+    auto result = op.evaluate({&x}, {1.f}, {});
    auto output = result->at(0);
    // output->printIndexedBuffer("Result is");
@@ -2048,7 +2048,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_1) {
    //************************************//
    nd4j::ops::reduce_sum_bp op;
-    auto result = op.execute({&input, &eps}, {}, {});
+    auto result = op.evaluate({&input, &eps}, {}, {});
    ASSERT_EQ(Status::OK(), result->status());
    auto z = result->at(0);
@@ -2069,7 +2069,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_2) {
    //************************************//
    nd4j::ops::reduce_sum_bp op;
-    auto result = op.execute({&input, &eps}, {1.f}, {});
+    auto result = op.evaluate({&input, &eps}, {1.f}, {});
    ASSERT_EQ(Status::OK(), result->status());
    auto z = result->at(0);
@@ -2090,7 +2090,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_3) {
    //************************************//
    nd4j::ops::reduce_sum_bp op;
-    auto result = op.execute({&input, &eps}, {}, {0});
+    auto result = op.evaluate({&input, &eps}, {}, {0});
    ASSERT_EQ(Status::OK(), result->status());
    auto z = result->at(0);
@@ -2111,7 +2111,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_4) {
    //************************************//
    nd4j::ops::reduce_sum_bp op;
-    auto result = op.execute({&input, &eps}, {1.f}, {0});
+    auto result = op.evaluate({&input, &eps}, {1.f}, {0});
    ASSERT_EQ(Status::OK(), result->status());
    auto z = result->at(0);
@@ -2134,7 +2134,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_04) {
    //************************************//
    nd4j::ops::reduce_sum_bp op;
-    auto result = op.execute({&input, &eps, &axis}, {}, {}, {true});
+    auto result = op.evaluate({&input, &eps, &axis}, {}, {}, {true});
    ASSERT_EQ(Status::OK(), result->status());
    auto z = result->at(0);
@@ -2159,7 +2159,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_BP_1) {
    131539399526781282156544.f, 122143728775382565912576.f, 114000815325130245799936.f});
    nd4j::ops::reduce_prod_bp op;
-    auto result = op.execute({&input, &eps}, {}, {});
+    auto result = op.evaluate({&input, &eps}, {}, {});
    ASSERT_EQ(Status::OK(), result->status());
    auto z = result->at(0);
@@ -2178,7 +2178,7 @@ TEST_F(DeclarableOpsTests8, reduceMean_test1) {
    nd4j::ops::reduce_mean op;
-    auto result = op.execute({&x}, {}, {0,1});
+    auto result = op.evaluate({&x}, {}, {0,1});
    auto output = result->at(0);
    ASSERT_EQ(Status::OK(), result->status());
@@ -2198,7 +2198,7 @@ TEST_F(DeclarableOpsTests8, reduceMean_test2) {
    nd4j::ops::reduce_mean op;
-    auto result = op.execute({&x}, {1.}, {0,1});
+    auto result = op.evaluate({&x}, {1.}, {0,1});
    auto output = result->at(0);
    ASSERT_EQ(Status::OK(), result->status());
@@ -2218,7 +2218,7 @@ TEST_F(DeclarableOpsTests8, reduceMean_test3) {
    nd4j::ops::reduce_mean op;
-    auto result = op.execute({&x}, {}, {0,2});
+    auto result = op.evaluate({&x}, {}, {0,2});
    auto output = result->at(0);
    ASSERT_EQ(Status::OK(), result->status());
@@ -2238,7 +2238,7 @@ TEST_F(DeclarableOpsTests8, reduceMean_test4) {
    nd4j::ops::reduce_mean op;
-    auto result = op.execute({&x}, {1.f}, {0,2});
+    auto result = op.evaluate({&x}, {1.f}, {0,2});
    auto output = result->at(0);
    ASSERT_EQ(Status::OK(), result->status());
@@ -2258,7 +2258,7 @@ TEST_F(DeclarableOpsTests8, reduceMean_test5) {
    nd4j::ops::reduce_mean op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
    auto output = result->at(0);
    ASSERT_EQ(Status::OK(), result->status());
@@ -2277,7 +2277,7 @@ TEST_F(DeclarableOpsTests8, reduceMean_test6) {
    x.linspace(1);
    nd4j::ops::reduce_mean op;
-    auto result = op.execute({&x}, {}, {0,1,2});
+    auto result = op.evaluate({&x}, {}, {0,1,2});
    auto output = result->at(0);
    ASSERT_EQ(Status::OK(), result->status());
@@ -2296,7 +2296,7 @@ TEST_F(DeclarableOpsTests8, reduceMean_test7) {
    x.linspace(1);
    nd4j::ops::reduce_mean op;
-    auto result = op.execute({&x}, {1.}, {0,1,2});
+    auto result = op.evaluate({&x}, {1.}, {0,1,2});
    auto output = result->at(0);
    ASSERT_EQ(Status::OK(), result->status());
@@ -2316,7 +2316,7 @@ TEST_F(DeclarableOpsTests8, reduceMean_test8) {
    x.linspace(1);
    nd4j::ops::reduce_mean op;
-    auto result = op.execute({&x, &axes}, {}, {}, {true});
+    auto result = op.evaluate({&x, &axes}, {}, {}, {true});
    auto output = result->at(0);
    ASSERT_EQ(Status::OK(), result->status());
@@ -2339,7 +2339,7 @@ TEST_F(DeclarableOpsTests8, reduceMeanBP_test1) {
    nd4j::ops::reduce_mean_bp op;
-    auto result = op.execute({&x, &gradO1}, {0}, {});
+    auto result = op.evaluate({&x, &gradO1}, {0}, {});
    ASSERT_EQ(Status::OK(), result->status());
    auto output = result->at(0);
@@ -2349,7 +2349,7 @@ TEST_F(DeclarableOpsTests8, reduceMeanBP_test1) {
    ASSERT_TRUE(exp.equalsTo(output));
    delete result;
-    result = op.execute({&x, &gradO2}, {1}, {});
+    result = op.evaluate({&x, &gradO2}, {1}, {});
    ASSERT_EQ(Status::OK(), result->status());
    output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
@@ -2370,14 +2370,14 @@ TEST_F(DeclarableOpsTests8, reduceMeanBP_test2) {
    nd4j::ops::reduce_mean_bp op;
-    auto result = op.execute({&x, &gradO1}, {0}, {0});
+    auto result = op.evaluate({&x, &gradO1}, {0}, {0});
    ASSERT_EQ(Status::OK(), result->status());
    auto output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));
    delete result;
-    result = op.execute({&x, &gradO2}, {1}, {0});
+    result = op.evaluate({&x, &gradO2}, {1}, {0});
    ASSERT_EQ(Status::OK(), result->status());
    output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
@@ -2398,14 +2398,14 @@ TEST_F(DeclarableOpsTests8, reduceMeanBP_test02) {
    nd4j::ops::reduce_mean_bp op;
-    auto result = op.execute({&x, &gradO1, &axis}, {}, {}, {false});
+    auto result = op.evaluate({&x, &gradO1, &axis}, {}, {}, {false});
    ASSERT_EQ(Status::OK(), result->status());
    auto output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));
    delete result;
-    result = op.execute({&x, &gradO2, &axis}, {}, {}, {true});
+    result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {true});
    ASSERT_EQ(Status::OK(), result->status());
    output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
@@ -2425,14 +2425,14 @@ TEST_F(DeclarableOpsTests8, reduceMeanBP_test3) {
    nd4j::ops::reduce_mean_bp op;
-    auto result = op.execute({&x, &gradO1}, {0}, {1});
+    auto result = op.evaluate({&x, &gradO1}, {0}, {1});
    ASSERT_EQ(Status::OK(), result->status());
    auto output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));
    delete result;
-    result = op.execute({&x, &gradO2}, {1}, {1});
+    result = op.evaluate({&x, &gradO2}, {1}, {1});
    ASSERT_EQ(Status::OK(), result->status());
    output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
@@ -2449,7 +2449,7 @@ TEST_F(DeclarableOpsTests8, reduceStDevBP_test4) {
    nd4j::ops::reduce_stdev_bp op;
-    auto result = op.execute({&x, &gradO}, {0,1}, {});
+    auto result = op.evaluate({&x, &gradO}, {0,1}, {});
    ASSERT_EQ(Status::OK(), result->status());
    auto output = result->at(0);
@@ -2469,7 +2469,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test1) {
    logits.linspace(0.1, 0.1);
    nd4j::ops::softmax_cross_entropy_loss_with_logits op;
-    auto results = op.execute({&logits, &labels}, {}, {});
+    auto results = op.evaluate({&logits, &labels}, {}, {});
    ASSERT_EQ(Status::OK(), results->status());
@@ -2491,7 +2491,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test2) {
    logits.linspace(0.1, 0.1);
    nd4j::ops::softmax_cross_entropy_loss_with_logits op;
-    auto results = op.execute({&logits, &labels}, {}, {0});
+    auto results = op.evaluate({&logits, &labels}, {}, {0});
    ASSERT_EQ(Status::OK(), results->status());
@@ -2513,7 +2513,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test3) {
    logits.linspace(0.1, 0.1);
    nd4j::ops::softmax_cross_entropy_loss_with_logits op;
-    auto results = op.execute({&logits, &labels}, {}, {1});
+    auto results = op.evaluate({&logits, &labels}, {}, {1});
    ASSERT_EQ(Status::OK(), results->status());
@@ -2535,7 +2535,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test4) {
    logits.linspace(0.1, 0.1);
    nd4j::ops::softmax_cross_entropy_loss_with_logits op;
-    auto results = op.execute({&logits, &labels}, {}, {});
+    auto results = op.evaluate({&logits, &labels}, {}, {});
    ASSERT_EQ(Status::OK(), results->status());
@@ -2557,7 +2557,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test5) {
    logits.linspace(0.1, 0.1);
    nd4j::ops::softmax_cross_entropy_loss_with_logits op;
-    auto results = op.execute({&logits, &labels}, {}, {0});
+    auto results = op.evaluate({&logits, &labels}, {}, {0});
    ASSERT_EQ(Status::OK(), results->status());
@@ -2579,7 +2579,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test6) {
    logits.linspace(0.1, 0.1);
    nd4j::ops::softmax_cross_entropy_loss_with_logits op;
-    auto results = op.execute({&logits, &labels}, {}, {0});
+    auto results = op.evaluate({&logits, &labels}, {}, {0});
    ASSERT_EQ(Status::OK(), results->status());
@@ -2601,7 +2601,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test7) {
    logits.linspace(0.1, 0.1);
    nd4j::ops::softmax_cross_entropy_loss_with_logits op;
-    auto results = op.execute({&logits, &labels}, {}, {1});
+    auto results = op.evaluate({&logits, &labels}, {}, {1});
    ASSERT_EQ(Status::OK(), results->status());
@@ -2623,7 +2623,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test8) {
    logits.linspace(0.1, 0.1);
    nd4j::ops::softmax_cross_entropy_loss_with_logits op;
-    auto results = op.execute({&logits, &labels}, {}, {});
+    auto results = op.evaluate({&logits, &labels}, {}, {});
    ASSERT_EQ(Status::OK(), results->status());
@@ -2643,7 +2643,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test9) {
    auto expected = NDArrayFactory::create(0.);
    nd4j::ops::softmax_cross_entropy_loss_with_logits op;
-    auto results = op.execute({&logits, &labels}, {}, {});
+    auto results = op.evaluate({&logits, &labels}, {}, {});
    ASSERT_EQ(Status::OK(), results->status());
@@ -2665,7 +2665,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test10) {
    logits.linspace(0.1, 0.1);
    nd4j::ops::softmax_cross_entropy_loss_with_logits op;
-    auto results = op.execute({&logits, &labels}, {}, {0});
+    auto results = op.evaluate({&logits, &labels}, {}, {0});
    ASSERT_EQ(Status::OK(), results->status());
@@ -2684,7 +2684,7 @@ TEST_F(DeclarableOpsTests8, clipbynorm_test4) {
    auto exp = NDArrayFactory::create('c', {3, 5}, {0.405392, 0.319980, 0.091113, 0.001079, 0.354444, 0.225846, 0.426676, 0.237501, 0.138259, 0.150149, 0.268965, 0.010723, 0.049078, 0.304615, 0.317105});
    nd4j::ops::clipbynorm op;
-    auto result = op.execute({&x}, {1.f}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {1.f}, {});
    auto output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
@@ -2704,7 +2704,7 @@ TEST_F(DeclarableOpsTests8, clipbynorm_test5) {
    x.linspace(1);
    nd4j::ops::clipbynorm op;
-    auto result = op.execute({&x}, {15.f}, {0}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {15.f}, {0});
    auto output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
@@ -2722,7 +2722,7 @@ TEST_F(DeclarableOpsTests8, clipbynorm_test6) {
    x.linspace(1);
    nd4j::ops::clipbynorm op;
-    auto result = op.execute({&x}, {15.f}, {1}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {15.f}, {1});
    auto output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
@@ -2740,7 +2740,7 @@ TEST_F(DeclarableOpsTests8, clipbynorm_test7) {
    x.linspace(1);
    nd4j::ops::clipbynorm op;
-    auto result = op.execute({&x}, {15.f}, {0,1}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {15.f}, {0,1});
    auto output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
@@ -2758,7 +2758,7 @@ TEST_F(DeclarableOpsTests8, clipbynorm_test8) {
    x.linspace(1);
    nd4j::ops::clipbynorm op;
-    auto result = op.execute({&x}, {15.}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {15.}, {});
    auto output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
@@ -2774,7 +2774,7 @@ TEST_F(DeclarableOpsTests8, clipbynorm_test9) {
    auto exp = NDArrayFactory::create('c', {2}, {2.4, 3.2});
    nd4j::ops::clipbynorm op;
-    auto result = op.execute({&x}, {4.}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {4.}, {});
    auto output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
@@ -2790,7 +2790,7 @@ TEST_F(DeclarableOpsTests8, clipbynorm_test10) {
    auto exp = NDArrayFactory::create(5.);
    nd4j::ops::clipbynorm op;
-    auto result = op.execute({&x}, {5.}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {5.}, {});
    auto output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
@@ -2809,7 +2809,7 @@ TEST_F(DeclarableOpsTests8, clipbynorm_test11) {
    x.linspace(1);
    nd4j::ops::clipbynorm op;
-    auto result = op.execute({&x}, {35.}, {0, 2}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {35.}, {0, 2});
    auto output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
@@ -2824,7 +2824,7 @@ TEST_F(DeclarableOpsTests8, clipbynorm_test_tf_119_1) {
    auto e = NDArrayFactory::create('c', {3, 3}, {0.03198684, 0.06397368, 0.09596053, 0.12794736, 0.15993419, 0.19192106, 0.22390789, 0.25589472, 0.28788155});
    nd4j::ops::clipbynorm op;
-    auto result = op.execute({&x}, {0.54}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {0.54}, {});
    ASSERT_EQ(e, *result->at(0));
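[note] The hunks above do more than rename execute() to evaluate(): for ops such as clipbynorm and ones_as they also drop the trailing empty b-args, the in-place flag and the explicit output data type. Below is a minimal side-by-side sketch of the two call shapes, assuming both overloads stay visible after this patch (which the patch itself does not guarantee); the header paths and test name are illustrative, not taken from this series:

#include <gtest/gtest.h>
#include <NDArrayFactory.h>                    // assumed header path
#include <ops/declarable/CustomOperations.h>   // assumed header path

using namespace nd4j;

// Hypothetical test; the real suite builds its input fixtures per case.
TEST(EvaluateMigrationSketch, ClipByNorm) {
    auto x = NDArrayFactory::create<double>('c', {2, 3}, {1., 2., 3., 4., 5., 6.});

    nd4j::ops::clipbynorm op;

    // Old overload, as removed above: inputs, t-args, i-args, b-args,
    // an in-place flag and an explicit output data type.
    auto before = op.execute({&x}, {15.}, {0}, {}, false, nd4j::DataType::DOUBLE);

    // New call, as added above: inputs, t-args and i-args only; the
    // remaining parameters fall back to defaults.
    auto after = op.evaluate({&x}, {15.}, {0});

    ASSERT_EQ(Status::OK(), after->status());
    ASSERT_TRUE(before->at(0)->equalsTo(after->at(0)));

    // Both calls heap-allocate their result set; the caller frees it,
    // which is why every hunk in this patch keeps its `delete result;`.
    delete before;
    delete after;
}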
@@ -2841,14 +2841,14 @@ TEST_F(DeclarableOpsTests8, reduceMeanBP_test4) {
    nd4j::ops::reduce_mean_bp op;
-    auto result = op.execute({&x, &gradO1}, {0}, {0});
+    auto result = op.evaluate({&x, &gradO1}, {0}, {0});
    ASSERT_EQ(Status::OK(), result->status());
    auto output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));
    delete result;
-    result = op.execute({&x, &gradO2}, {1}, {0});
+    result = op.evaluate({&x, &gradO2}, {1}, {0});
    ASSERT_EQ(Status::OK(), result->status());
    output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
@@ -2867,14 +2867,14 @@ TEST_F(DeclarableOpsTests8, reduceMeanBP_test5) {
    nd4j::ops::reduce_mean_bp op;
-    auto result = op.execute({&x, &gradO1}, {0}, {1});
+    auto result = op.evaluate({&x, &gradO1}, {0}, {1});
    ASSERT_EQ(Status::OK(), result->status());
    auto output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));
    delete result;
-    result = op.execute({&x, &gradO2}, {1}, {1});
+    result = op.evaluate({&x, &gradO2}, {1}, {1});
    ASSERT_EQ(Status::OK(), result->status());
    output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
@@ -2893,14 +2893,14 @@ TEST_F(DeclarableOpsTests8, reduceStDevBP_test5) {
    nd4j::ops::reduce_stdev_bp op;
-    auto result = op.execute({&x, &gradO1}, {0}, {0});
+    auto result = op.evaluate({&x, &gradO1}, {0}, {0});
    ASSERT_EQ(Status::OK(), result->status());
    auto output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));
    delete result;
-    result = op.execute({&x, &gradO2}, {1}, {0});
+    result = op.evaluate({&x, &gradO2}, {1}, {0});
    ASSERT_EQ(Status::OK(), result->status());
    output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
@@ -2934,7 +2934,7 @@ TEST_F(DeclarableOpsTests8, zeros_as_test2) {
    nd4j::ops::zeros_as op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
    ASSERT_EQ(Status::OK(), result->status());
    auto y = result->at(0);
@@ -2952,7 +2952,7 @@ TEST_F(DeclarableOpsTests8, ones_as_test1) {
    nd4j::ops::ones_as op;
-    Nd4jStatus status = op.execute({&x}, {&y}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    Nd4jStatus status = op.execute({&x}, {&y});
    ASSERT_EQ(Status::OK(), status);
    ASSERT_TRUE(y.isSameShape(exp));
@@ -2969,7 +2969,7 @@ TEST_F(DeclarableOpsTests8, ones_as_test2) {
    nd4j::ops::ones_as op;
-    auto results = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&x});
    ASSERT_EQ(Status::OK(), results->status());
    auto y = results->at(0);
    ASSERT_TRUE(y->isSameShape(exp));
@@ -2998,10 +2998,10 @@ TEST_F(DeclarableOpsTests8, NormalizeMoments_SGO_1) {
    // ssSquared->printBuffer("Sum squared");
    // squared.printBuffer("Squared");
    nd4j::ops::normalize_moments op;
-    auto results = op.execute({&counts, &means, &ssSquared}, {0.0}, {0});
+    auto results = op.evaluate({&counts, &means, &ssSquared}, {0.0}, {0});
    means /= counts;
    // nd4j::ops::normalize_moments op;
-//    auto results = op.execute({&counts, means, deviance}, {0.0}, {});
+//    auto results = op.evaluate({&counts, means, deviance}, {0.0}, {});
    ASSERT_EQ(Status::OK(), results->status());
    ASSERT_EQ(results->size(), 2);
@@ -3035,7 +3035,7 @@ TEST_F(DeclarableOpsTests8, Test_Moments_1) {
    x.linspace(1);
    nd4j::ops::moments op;
-    auto result = op.execute({&x}, {}, {0, 1});
+    auto result = op.evaluate({&x}, {}, {0, 1});
    ASSERT_EQ(Status::OK(), result->status());
@@ -3066,7 +3066,7 @@ TEST_F(DeclarableOpsTests8, Test_Moments_2) {
    x.linspace(1);
    nd4j::ops::moments op;
-    auto result = op.execute({&x}, {1.}, {0, 1});
+    auto result = op.evaluate({&x}, {1.}, {0, 1});
    ASSERT_EQ(Status::OK(), result->status());
    auto outputMeans = result->at(0);
@@ -3095,7 +3095,7 @@ TEST_F(DeclarableOpsTests8, Test_Moments_3) {
    x.linspace(1);
    nd4j::ops::moments op;
-    auto result = op.execute({&x}, {}, {0, 2});
+    auto result = op.evaluate({&x}, {}, {0, 2});
    ASSERT_EQ(Status::OK(), result->status());
    auto outputMeans = result->at(0);
@@ -3124,7 +3124,7 @@ TEST_F(DeclarableOpsTests8, Test_Moments_4) {
    x.linspace(1);
    nd4j::ops::moments op;
-    auto result = op.execute({&x}, {1.}, {0, 2});
+    auto result = op.evaluate({&x}, {1.}, {0, 2});
    ASSERT_EQ(Status::OK(), result->status());
    auto outputMeans = result->at(0);
@@ -3153,7 +3153,7 @@ TEST_F(DeclarableOpsTests8, Test_Moments_6) {
    x.linspace(1);
    nd4j::ops::moments op;
-    auto result = op.execute({&x}, {}, {0,1,2});
+    auto result = op.evaluate({&x}, {}, {0,1,2});
    ASSERT_EQ(Status::OK(), result->status());
    auto outputMeans = result->at(0);
@@ -3182,7 +3182,7 @@ TEST_F(DeclarableOpsTests8, Test_Moments_7) {
    x.linspace(1);
    // x.printIndexedBuffer("Input with shape (2, 3, 4) is");
    nd4j::ops::moments op;
-    auto result = op.execute({&x}, {1.}, {0,1,2});
+    auto result = op.evaluate({&x}, {1.}, {0,1,2});
    ASSERT_EQ(Status::OK(), result->status());
    auto outputMeans = result->at(0);
@@ -3211,7 +3211,7 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_01) {
    );
    nd4j::ops::lrn op;
-    auto results = op.execute({&x}, {1.0, 1.0, 0.5}, {2}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2});
    auto out = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -3233,7 +3233,7 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_02) {
    );
    nd4j::ops::lrn op;
-    auto results = op.execute({&x}, {1.0, 1.0, 0.5}, {2}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2});
    auto out = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -3252,7 +3252,7 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_03) {
    auto exp = NDArrayFactory::create('c', {1, 1, 1, 10}, {0.10425719f, 0.16843036f, 0.2095291f, 0.23652494f, 0.25449327f, 0.3053919f, 0.35675305f, 0.4098524f, 0.46662825f, 0.52999896f});
    nd4j::ops::lrn op;
-    auto results = op.execute({&x}, {1.0, 1.0, 0.5}, {5}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {5});
    auto out = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -3281,7 +3281,7 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_1) {
    );
    nd4j::ops::lrn op;
-    auto results = op.execute({&x}, {1.0, 1.0, 0.5}, {2}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2});
    auto out = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -3358,7 +3358,7 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_2) {
    );
    //
    nd4j::ops::lrn op;
-    auto results = op.execute({&x}, {1.0, 1.0, 0.5}, {2}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2});
    auto out = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -3435,7 +3435,7 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_3) {
    );
    //
    nd4j::ops::lrn op;
-    auto results = op.execute({&x}, {1.0, 1.0, 0.5}, {2}, {}, false, nd4j::DataType::FLOAT32);
+    auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2});
    auto out = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -3455,7 +3455,7 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_4) {
    x.linspace(1);
    nd4j::ops::lrn op;
-    auto results = op.execute({&x}, {1.0, 1.0, 0.5}, {2}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2});
    auto out = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -3478,12 +3478,12 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_4_119) {
    nd4j::ops::lrn op;
-    op.execute({&x}, {&z}, {1.0, 1.0, 0.5}, {2}, {}, false, nd4j::DataType::DOUBLE);
+    op.execute({&x}, {&z}, {1.0, 1.0, 0.5}, {2});
    auto timeStart = std::chrono::system_clock::now();
    for (int e = 0; e < iterations; e++)
-        op.execute({&x}, {&z}, {1.0, 1.0, 0.5}, {2}, {}, false, nd4j::DataType::DOUBLE);
+        op.execute({&x}, {&z}, {1.0, 1.0, 0.5}, {2});
    auto timeEnd = std::chrono::system_clock::now();
    auto spanTime = std::chrono::duration_cast ((timeEnd - timeStart) / iterations).count();
@@ -3501,7 +3501,7 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_5) {
    x.linspace(1);
    nd4j::ops::lrn op;
-    auto results = op.execute({&x}, {1.0, 1.0, 0.5}, {2}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2});
    auto out = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -3524,7 +3524,7 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_01) {
    // );
    ///
    nd4j::ops::lrn_bp op;
-    auto results = op.execute({&x, &eps}, {1.0, 1.0, 0.5}, {5}, {}, false, nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {5});
    auto out = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -3580,7 +3580,7 @@ auto exp = NDArrayFactory::create('c', {3,3,5,5}, {
    );
    ///
    nd4j::ops::lrn_bp op;
-    auto results = op.execute({&x, &eps}, {1.0, 1.0, 0.5}, {2}, {}, false, typeid(TypeParam) == typeid(float) ? nd4j::DataType::FLOAT32 : nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {2}, {}, {}, false);
    auto out = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -3660,7 +3660,7 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_2) {
    );
    nd4j::ops::lrn_bp op;
-    auto results = op.execute({&x, &eps}, {1.0, 1.0, 0.5}, {2}, {}, false, typeid(TypeParam) == typeid(float) ? nd4j::DataType::FLOAT32 : nd4j::DataType::DOUBLE);
+    auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {2}, {}, {}, false);
    auto out = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
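[note] The same mechanical rewrite continues below in DeclarableOpsTests9.cpp. One property worth keeping in mind while reading these hunks: evaluate(), like the overload it replaces, returns an owning pointer to a result set, so each test checks status(), compares at(0) against an expected array, and then deletes the result. A minimal sketch of that recurring pattern for concat, with hypothetical fixture arrays standing in for the per-test data (header paths assumed, not taken from this series):

#include <gtest/gtest.h>
#include <NDArrayFactory.h>                    // assumed header path
#include <ops/declarable/CustomOperations.h>   // assumed header path

using namespace nd4j;

TEST(ConcatEvaluateSketch, Axis0) {
    // Stand-in inputs: two 1x2 rows concatenated along axis 0 -> 2x2.
    auto x0  = NDArrayFactory::create<float>('c', {1, 2}, {1.f, 2.f});
    auto x1  = NDArrayFactory::create<float>('c', {1, 2}, {3.f, 4.f});
    auto exp = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});

    nd4j::ops::concat op;
    auto result = op.evaluate({&x0, &x1}, {}, {0});   // i-arg 0 = concat axis

    // Status check first, then shape and value checks, as in the tests below.
    ASSERT_EQ(Status::OK(), result->status());
    auto z = result->at(0);
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;   // the result set is heap-allocated and caller-owned
}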
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp
index caceaa1cd..2c4655b31 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp
@@ -51,7 +51,7 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test3) {
    nd4j::ops::reduce_stdev_bp op;
-    auto result = op.execute({&x, &gradO2}, {0,0}, {1});
+    auto result = op.evaluate({&x, &gradO2}, {0,0}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto output = result->at(0);
    // output->printIndexedBuffer();
@@ -59,7 +59,7 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test3) {
    ASSERT_TRUE(exp.equalsTo(output));
    delete result;
-    result = op.execute({&x, &gradO1}, {1,0}, {1});
+    result = op.evaluate({&x, &gradO1}, {1,0}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
@@ -80,7 +80,7 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test03) {
    nd4j::ops::reduce_stdev_bp op;
-    auto result = op.execute({&x, &gradO2, &axis}, {}, {}, {false, false});
+    auto result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {false, false});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto output = result->at(0);
    // output->printIndexedBuffer();
@@ -88,7 +88,7 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test03) {
    ASSERT_TRUE(exp.equalsTo(output));
    delete result;
-    result = op.execute({&x, &gradO1}, {1,0}, {1});
+    result = op.evaluate({&x, &gradO1}, {1,0}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    output = result->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
@@ -248,7 +248,7 @@ TEST_F(DeclarableOpsTests9, concat_test1) {
    nd4j::ops::concat op;
-    auto result = op.execute({&x0, &x1, &x2}, {}, {1});
+    auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto output = result->at(0);
@@ -272,7 +272,7 @@ TEST_F(DeclarableOpsTests9, concat_test2) {
    nd4j::ops::concat op;
-    auto result = op.execute({&x0, &x1, &x2}, {}, {1});
+    auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto output = result->at(0);
@@ -296,7 +296,7 @@ TEST_F(DeclarableOpsTests9, concat_test3) {
    nd4j::ops::concat op;
-    auto result = op.execute({&x0, &x1, &x2}, {}, {0});
+    auto result = op.evaluate({&x0, &x1, &x2}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto output = result->at(0);
@@ -316,7 +316,7 @@ TEST_F(DeclarableOpsTests9, concat_test4) {
    nd4j::ops::concat op;
-    auto result = op.execute({&x0, &x1, &x2}, {}, {1});
+    auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto output = result->at(0);
@@ -336,7 +336,7 @@ TEST_F(DeclarableOpsTests9, concat_test5) {
    nd4j::ops::concat op;
-    auto result = op.execute({&x0, &x1, &x2}, {}, {0});
+    auto result = op.evaluate({&x0, &x1, &x2}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto output = result->at(0);
@@ -356,7 +356,7 @@ TEST_F(DeclarableOpsTests9, concat_test6) {
    nd4j::ops::concat op;
-    auto result = op.execute({&x0, &x1, &x2}, {}, {0});
+    auto result = op.evaluate({&x0, &x1, &x2}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto output = result->at(0);
@@ -376,7 +376,7 @@ TEST_F(DeclarableOpsTests9, concat_test7) {
    nd4j::ops::concat op;
-    auto result = op.execute({&x0, &x1, &x2}, {}, {0});
+    auto result = op.evaluate({&x0, &x1, &x2}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto output = result->at(0);
@@ -394,7 +394,7 @@ TEST_F(DeclarableOpsTests9, concat_test8) {
    nd4j::ops::concat op;
-    auto result = op.execute({&x0}, {}, {0});
+    auto result = op.evaluate({&x0}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto output = result->at(0);
@@ -412,7 +412,7 @@ TEST_F(DeclarableOpsTests9, concat_test9) {
    nd4j::ops::concat op;
-    auto result = op.execute({&x0}, {}, {0});
+    auto result = op.evaluate({&x0}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto output = result->at(0);
@@ -437,7 +437,7 @@ TEST_F(DeclarableOpsTests9, concat_test10) {
    nd4j::ops::concat op;
-    auto result = op.execute({&x0, &x1, &x2}, {}, {1});
+    auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto output = result->at(0);
@@ -462,7 +462,7 @@ TEST_F(DeclarableOpsTests9, concat_test11) {
    nd4j::ops::concat op;
-    auto result = op.execute({&x0, &x1, &x2}, {}, {1});
+    auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto output = result->at(0);
@@ -487,7 +487,7 @@ TEST_F(DeclarableOpsTests9, concat_test12) {
    nd4j::ops::concat op;
-    auto result = op.execute({&x0, &x1, &x2}, {}, {1});
+    auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto output = result->at(0);
@@ -512,7 +512,7 @@ TEST_F(DeclarableOpsTests9, concat_test13) {
    nd4j::ops::concat op;
-    auto result = op.execute({&x0, &x1, &x2}, {}, {1});
+    auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto output = result->at(0);
@@ -532,7 +532,7 @@ TEST_F(DeclarableOpsTests9, concat_test14) {
    x1 = 2.;
    nd4j::ops::concat op;
-    auto result = op.execute({&x0, &x1}, {}, {0}, {});
+    auto result = op.evaluate({&x0, &x1}, {}, {0}, {});
    ASSERT_EQ(Status::OK(), result->status());
    auto z = result->at(0);
@@ -555,7 +555,7 @@ TEST_F(DeclarableOpsTests9, concat_test15) {
    auto exp = NDArrayFactory::create('c', {3}, {1, 0, 3});
    nd4j::ops::concat op;
-    auto result = op.execute({&x, &y}, {}, {0});
+    auto result = op.evaluate({&x, &y}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);
@@ -574,7 +574,7 @@ TEST_F(DeclarableOpsTests9, concat_test16) {
    auto exp = NDArrayFactory::create('c', {0,2,3});
    nd4j::ops::concat op;
-    auto result = op.execute({&x, &y}, {}, {0});
+    auto result = op.evaluate({&x, &y}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);
@@ -594,7 +594,7 @@ TEST_F(DeclarableOpsTests9, concat_test17) {
    x1 = 2.;
    nd4j::ops::concat op;
-    auto result = op.execute({&x0, &x1}, {}, {0}, {});
+    auto result = op.evaluate({&x0, &x1}, {}, {0}, {});
    ASSERT_EQ(Status::OK(), result->status());
    auto z = result->at(0);
@@ -675,7 +675,7 @@ TEST_F(DeclarableOpsTests9, concat_test20) {
    x3.assign(4.0);
    nd4j::ops::concat op;
-    auto result = op.execute({&x0, &x1, &x2, &x3}, {}, {0}, {});
+    auto result = op.evaluate({&x0, &x1, &x2, &x3}, {}, {0}, {});
    ASSERT_EQ(Status::OK(), result->status());
    auto z = result->at(0);
@@ -763,7 +763,7 @@ TEST_F(DeclarableOpsTests9, concat_test25) {
    nd4j::ops::concat op;
-    auto result = op.execute({&x0, &x1, &axis}, {}, {}, {true});
+    auto result = op.evaluate({&x0, &x1, &axis}, {}, {}, {true});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto output = result->at(0);
@@ -784,7 +784,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test1) {
    gradO.linspace(0.01, 0.01);
    nd4j::ops::tile_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {2, 3});
+    auto results = op.evaluate({&input, &gradO}, {}, {2, 3});
    auto gradI = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -804,7 +804,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test2) {
    gradO.linspace(0.01, 0.01);
    nd4j::ops::tile_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {1, 3});
+    auto results = op.evaluate({&input, &gradO}, {}, {1, 3});
    auto gradI = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
    ASSERT_TRUE(gradIExp.isSameShape(gradI));
@@ -823,7 +823,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test3) {
    gradO.linspace(0.01, 0.01);
    nd4j::ops::tile_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {1, 1});
+    auto results = op.evaluate({&input, &gradO}, {}, {1, 1});
    auto gradI = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -843,7 +843,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test4) {
    gradO.linspace(0.01, 0.01);
    nd4j::ops::tile_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {2});
+    auto results = op.evaluate({&input, &gradO}, {}, {2});
    auto gradI = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -863,7 +863,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test5) {
    gradO.linspace(0.01, 0.01);
    nd4j::ops::tile_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {1});
+    auto results = op.evaluate({&input, &gradO}, {}, {1});
    auto gradI = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -883,7 +883,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test6) {
    gradO.linspace(0.01, 0.01);
    nd4j::ops::tile_bp op;
-    auto results = op.execute({&input, &gradO}, {}, {1, 3, 2});
+    auto results = op.evaluate({&input, &gradO}, {}, {1, 3, 2});
    auto gradI = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -904,7 +904,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test7) {
    gradO.linspace(0.01, 0.01);
    nd4j::ops::tile_bp op;
-    auto results = op.execute({&input, &reps, &gradO}, {}, {});
+    auto results = op.evaluate({&input, &reps, &gradO}, {}, {});
    auto gradI = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -922,7 +922,7 @@ TEST_F(DeclarableOpsTests9, tile_test1) {
    auto expOut = NDArrayFactory::create('c', {2, 6,}, {1.,2.,3.,4.,5.,6., 1.,2.,3.,4.,5.,6.});
    nd4j::ops::tile op;
-    auto results = op.execute({&input, &reps}, {}, {});
+    auto results = op.evaluate({&input, &reps}, {}, {});
    auto out = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -944,7 +944,7 @@ TEST_F(DeclarableOpsTests9, matmul_test1) {
    y.linspace(0.5, 0.5);
    nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {});
+    auto results = op.evaluate({&x, &y}, {}, {});
    auto z = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -966,7 +966,7 @@ TEST_F(DeclarableOpsTests9, matmul_test2) {
    y.linspace(0.5, 0.5);
    nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {});
+    auto results = op.evaluate({&x, &y}, {}, {});
    auto z = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -987,7 +987,7 @@ TEST_F(DeclarableOpsTests9, matmul_test3) {
    y.linspace(0.5, 0.5);
    nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {});
+    auto results = op.evaluate({&x, &y}, {}, {});
    auto z = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -1009,7 +1009,7 @@ TEST_F(DeclarableOpsTests9, matmul_test4) {
    y.linspace(0.5, 0.5);
    nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {});
+    auto results = op.evaluate({&x, &y}, {}, {});
    auto z = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -1031,7 +1031,7 @@ TEST_F(DeclarableOpsTests9, matmul_test5) {
    y.linspace(0.5, 0.5);
    nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {1});
+    auto results = op.evaluate({&x, &y}, {}, {1});
    auto z = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -1052,7 +1052,7 @@ TEST_F(DeclarableOpsTests9, matmul_test6) {
    y.linspace(0.5, 0.5);
    nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {1, 1});
+    auto results = op.evaluate({&x, &y}, {}, {1, 1});
    auto z = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -1075,7 +1075,7 @@ TEST_F(DeclarableOpsTests9, matmul_test7) {
    y.linspace(0.1, 0.1);
    nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {0, 1});
+    auto results = op.evaluate({&x, &y}, {}, {0, 1});
    auto z = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -1100,7 +1100,7 @@ TEST_F(DeclarableOpsTests9, matmul_test8) {
    y.linspace(0.1, 0.1);
    nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {0, 1});
+    auto results = op.evaluate({&x, &y}, {}, {0, 1});
    auto z = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -1125,7 +1125,7 @@ TEST_F(DeclarableOpsTests9, matmul_test9) {
    y.linspace(0.1, 0.1);
    nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {1, 1});
+    auto results = op.evaluate({&x, &y}, {}, {1, 1});
    auto z = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -1142,7 +1142,7 @@ TEST_F(DeclarableOpsTests9, TestDropout_BP_1) {
    NDArray shape('c', {2}, {2, 2});
    nd4j::ops::dropout_bp op;
-    auto ress = op.execute({&x, &errs, &shape}, {0.2f}, {113});
+    auto ress = op.evaluate({&x, &errs, &shape}, {0.2f}, {113});
    ASSERT_EQ(ND4J_STATUS_OK, ress->status());
    //ress->at(0)->printIndexedBuffer("Result is ");
@@ -1159,7 +1159,7 @@ TEST_F(DeclarableOpsTests9, TestDropout_1) {
    //NDArray shape({2.f, 2.f});
    nd4j::ops::dropout op;
    x.linspace(1);
-    auto ress = op.execute({&x}, {0.2f}, {113});
+    auto ress = op.evaluate({&x}, {0.2f}, {113});
    ASSERT_EQ(ND4J_STATUS_OK, ress->status());
    NDArray* res = ress->at(0); //->printIndexedBuffer("Result is ");
@@ -1167,7 +1167,7 @@ TEST_F(DeclarableOpsTests9, TestDropout_1) {
    //res->printIndexedBuffer("Result for Dropout_1");
    auto countZero = res->reduceNumber(reduce::CountZero);
    ASSERT_NEAR(countZero.e(0), 80, 5);
-    auto ress2 = op.execute({&x}, {0.2f}, {113});
+    auto ress2 = op.evaluate({&x}, {0.2f}, {113});
    ASSERT_EQ(ND4J_STATUS_OK, ress2->status());
    NDArray* res2 = ress2->at(0);
@@ -1214,7 +1214,7 @@ TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) {
    */
    nd4j::ops::dropout op;
-    auto ress = op.execute({&x1}, {0.5f}, {119});
+    auto ress = op.evaluate({&x1}, {0.5f}, {119});
    ASSERT_EQ(ND4J_STATUS_OK, ress->status());
    //ress->at(0)->printIndexedBuffer("01Dropout result is ");
@@ -1225,11 +1225,11 @@ TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) {
    //NDArray exp('c', {10,10}, {4.f, 0.f, 12.f, 0.f, 20.f, 24.f, 0.f, 32.f, 0.f, 0.f, 0.f, 0.f, 52.f, 56.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 84.f, 88.f, 0.f, 0.f, 0.f, 0.f, 108.f, 0.f, 0.f, 120.f, 0.f, 0.f, 132.f, 0.f, 0.f, 0.f, 0.f, 0.f, 156.f, 0.f, 164.f, 168.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 200.f, 204.f, 0.f, 0.f, 0.f, 220.f, 0.f, 0.f, 232.f, 236.f, 240.f, 0.f, 248.f, 0.f, 0.f, 260.f, 0.f, 0.f, 0.f, 276.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 316.f, 0.f, 324.f, 0.f, 0.f, 336.f, 0.f, 0.f, 0.f, 0.f, 356.f, 0.f, 0.f, 368.f, 0.f, 0.f, 0.f, 384.f, 388.f, 0.f, 0.f, 400.f});
    //02Dropout result is [4.000000, 0.000000, 12.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 36.000000, 0.000000, 0.000000, 0.000000, 0.000000, 56.000000, 60.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 88.000000, 0.000000, 96.000000, 0.000000, 0.000000, 108.000000, 0.000000, 0.000000, 120.000000, 0.000000, 128.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 156.000000, 0.000000, 164.000000, 0.000000, 0.000000, 0.000000, 0.000000, 184.000000, 0.000000, 0.000000, 0.000000, 200.000000, 0.000000, 0.000000, 0.000000, 216.000000, 0.000000, 0.000000, 0.000000, 232.000000, 0.000000, 240.000000, 0.000000, 248.000000, 0.000000, 0.000000, 260.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 308.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 348.000000, 0.000000, 356.000000, 0.000000, 0.000000, 0.000000, 0.000000, 376.000000, 0.000000, 384.000000, 0.000000, 0.000000, 0.000000, 400.000000]
-    auto ressX = op2.execute({&x1, &x1}, {0.5f}, {119}); // , false, nd4j::DataType::FLOAT32); // skipped due given by default
+    auto ressX = op2.evaluate({&x1, &x1}, {0.5f}, {119}); // , false, nd4j::DataType::FLOAT32); // skipped due given by default
    //x0.printIndexedBuffer("X0");
    //x1.printIndexedBuffer("X1");
    ASSERT_EQ(ND4J_STATUS_OK, ressX->status());
-    auto ressY = op2.execute({&x1, &x0}, {0.5f}, {119});
+    auto ressY = op2.evaluate({&x1, &x0}, {0.5f}, {119});
    ASSERT_EQ(ND4J_STATUS_OK, ressY->status());
    //ressY->at(0)->printIndexedBuffer("BP");
    //ress->at(0)->printIndexedBuffer("FF");
@@ -1264,17 +1264,17 @@ TEST_F(DeclarableOpsTests9, Test_Dropout_BP_2) {
    nd4j::ops::dropout op;
-    auto ress = op.execute({&x}, {0.5f}, {119});
+    auto ress = op.evaluate({&x}, {0.5f}, {119});
    ASSERT_EQ(ND4J_STATUS_OK, ress->status());
    // ress->at(0)->printIndexedBuffer("01Dropout result is ");
    nd4j::ops::dropout_bp op2;
-    auto ressX = op2.execute({&x, &x}, {0.5f}, {119});
+    auto ressX = op2.evaluate({&x, &x}, {0.5f}, {119});
    ASSERT_EQ(ND4J_STATUS_OK, ressX->status());
-    auto ressY = op2.execute({&x, &x}, {0.5f}, {119});
+    auto ressY = op2.evaluate({&x, &x}, {0.5f}, {119});
    ASSERT_EQ(ND4J_STATUS_OK, ressY->status());
    //ress->at(0)->printIndexedBuffer("FF Dropout result is ");
@@ -1307,12 +1307,12 @@ TEST_F(DeclarableOpsTests9, Test_AlphaDropout_BP_1) {
    nd4j::ops::alpha_dropout_bp op;
-    auto ress = op.execute({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119});
+    auto ress = op.evaluate({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119});
    ASSERT_EQ(ND4J_STATUS_OK, ress->status());
    NDArray* res = ress->at(0);
-    auto ress2 = op.execute({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119});
+    auto ress2 = op.evaluate({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119});
    ASSERT_EQ(ND4J_STATUS_OK, ress2->status());
    NDArray* res2 = ress2->at(0);
@@ -1336,7 +1336,7 @@ TEST_F(DeclarableOpsTests9, matmul_test10) {
    y.linspace(0.5, 0.5);
    nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {1, 1});
+    auto results = op.evaluate({&x, &y}, {}, {1, 1});
    auto z = results->at(0);
    ASSERT_EQ(Status::OK(), results->status());
@@ -1356,7 +1356,7 @@ TEST_F(DeclarableOpsTests9, matmul_test11) {
    x.linspace(1.);
    y.linspace(0.5, 0.5);
    nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {1, 1});
+    auto results = op.evaluate({&x, &y}, {}, {1, 1});
    ASSERT_EQ(Status::OK(), results->status());
    auto z = results->at(0);
@@ -1377,7 +1377,7 @@ TEST_F(DeclarableOpsTests9, matmul_test12) {
    y.linspace(0.5, 0.5);
    nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {1, 1});
+    auto results = op.evaluate({&x, &y},
     ASSERT_EQ(Status::OK(), results->status());
     auto z = results->at(0);
@@ -1398,7 +1398,7 @@ TEST_F(DeclarableOpsTests9, matmul_test13) {
     y.linspace(0.5, 0.5);
     nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {0, 0, 1});
+    auto results = op.evaluate({&x, &y}, {}, {0, 0, 1});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1419,7 +1419,7 @@ TEST_F(DeclarableOpsTests9, matmul_test14) {
     y.linspace(0.5, 0.5);
     nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {1, 0, 1});
+    auto results = op.evaluate({&x, &y}, {}, {1, 0, 1});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1440,7 +1440,7 @@ TEST_F(DeclarableOpsTests9, matmul_test15) {
     y.linspace(0.5, 0.5);
     nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {1, 0, 1});
+    auto results = op.evaluate({&x, &y}, {}, {1, 0, 1});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1464,7 +1464,7 @@ TEST_F(DeclarableOpsTests9, matmul_test16) {
     y.linspace(0.1, 0.1);
     nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {1, 1, 1});
+    auto results = op.evaluate({&x, &y}, {}, {1, 1, 1});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1485,7 +1485,7 @@ TEST_F(DeclarableOpsTests9, matmul_test17) {
     y.linspace(0.1, 0.1);
     nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {1, 0});
+    auto results = op.evaluate({&x, &y}, {}, {1, 0});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1506,7 +1506,7 @@ TEST_F(DeclarableOpsTests9, matmul_test18) {
     y.linspace(0.1, 0.1);
     nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {0, 1});
+    auto results = op.evaluate({&x, &y}, {}, {0, 1});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1527,7 +1527,7 @@ TEST_F(DeclarableOpsTests9, matmul_test19) {
     y.linspace(0.1, 0.1);
     nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {});
+    auto results = op.evaluate({&x, &y}, {}, {});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1549,7 +1549,7 @@ TEST_F(DeclarableOpsTests9, matmul_test20) {
     y.linspace(0.1, 0.1);
     nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {1,1,1});
+    auto results = op.evaluate({&x, &y}, {}, {1,1,1});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1571,7 +1571,7 @@ TEST_F(DeclarableOpsTests9, matmul_test21) {
     y.linspace(0.1, 0.1);
     nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {});
+    auto results = op.evaluate({&x, &y}, {}, {});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1593,7 +1593,7 @@ TEST_F(DeclarableOpsTests9, matmul_test22) {
     y.linspace(0.1, 0.1);
     nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {1});
+    auto results = op.evaluate({&x, &y}, {}, {1});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1615,7 +1615,7 @@ TEST_F(DeclarableOpsTests9, matmul_test23) {
     y.linspace(0.1, 0.1);
     nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {1, 1});
+    auto results = op.evaluate({&x, &y}, {}, {1, 1});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1634,7 +1634,7 @@ TEST_F(DeclarableOpsTests9, matmul_test24) {
     auto exp = NDArrayFactory::create(6.);
     nd4j::ops::matmul op;
-    auto results = op.execute({&x, &y}, {}, {1, 1});
+    auto results = op.evaluate({&x, &y}, {}, {1, 1});
     auto z = results->at(0);
     ASSERT_EQ(Status::OK(), results->status());
@@ -1650,7 +1650,7 @@ TEST_F(DeclarableOpsTests9, test_range_int_1) {
     auto x2 = NDArrayFactory::create(1);
     nd4j::ops::range op;
-    auto result = op.execute({&x0, &x1, &x2}, {}, {});
+    auto result = op.evaluate({&x0, &x1, &x2}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -1664,7 +1664,7 @@ TEST_F(DeclarableOpsTests9, test_range_empty_1) {
     auto x2 = NDArrayFactory::create(1);
     nd4j::ops::range op;
-    auto result = op.execute({&x0, &x1, &x2}, {}, {});
+    auto result = op.evaluate({&x0, &x1, &x2}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -1703,7 +1703,7 @@ TEST_F(DeclarableOpsTests9, test_unstack_1) {
     x.linspace(1.0);
     nd4j::ops::unstack op;
-    auto result = op.execute({&x}, {}, {0});
+    auto result = op.evaluate({&x}, {}, {0});
     ASSERT_EQ(Status::OK(), result->status());
     ASSERT_EQ(5, result->size());
@@ -1721,7 +1721,7 @@ TEST_F(DeclarableOpsTests9, test_unstack_SGO_1) {
     auto z5 = NDArrayFactory::create(5);
     std::vector z({&z1, &z2, &z3, &z4, &z5});
     nd4j::ops::unstack op;
-    auto result = op.execute({&x}, {}, {0});
+    auto result = op.evaluate({&x}, {}, {0});
     ASSERT_EQ(Status::OK(), result->status());
     ASSERT_EQ(5, result->size());
     for (size_t i = 0; i < result->size(); i++) {
@@ -1758,7 +1758,7 @@ TEST_F(DeclarableOpsTests9, clipbynorm_test12) {
     }
     nd4j::ops::clipbynorm op;
-    auto result = op.execute({&y}, {clip}, {axis}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&y}, {clip}, {axis});
     auto outFF = result->at(0);
     ASSERT_TRUE(expect.isSameShape(outFF));
@@ -1852,7 +1852,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) {
     exclusive = 0; reverse = 0;
     nd4j::ops::cumprod op;
-    auto result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
     ASSERT_TRUE(expFF.equalsTo(z));
@@ -1861,7 +1861,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) {
     //************************************//
     exclusive = 1; reverse = 0;
-    result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE);
+    result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});
     ASSERT_EQ(Status::OK(), result->status());
     z = result->at(0);
     ASSERT_TRUE(expTF.equalsTo(z));
@@ -1870,7 +1870,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) {
     //************************************//
     exclusive = 0; reverse = 1;
-    result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE);
+    result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});
     ASSERT_EQ(Status::OK(), result->status());
     z = result->at(0);
     ASSERT_TRUE(expFT.equalsTo(z));
@@ -1879,7 +1879,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) {
     //************************************//
     exclusive = 1; reverse = 1;
-    result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE);
+    result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});
     ASSERT_EQ(Status::OK(), result->status());
     z = result->at(0);
     ASSERT_TRUE(expTT.equalsTo(z));
@@ -1910,7 +1910,7 @@ TEST_F(DeclarableOpsTests9, cumprod_2) {
     }
     nd4j::ops::cumprod op;
-    auto result = op.execute({&x}, {}, {0, 0, 1});
+    auto result = op.evaluate({&x}, {}, {0, 0, 1});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -2111,7 +2111,7 @@ TEST_F(DeclarableOpsTests9, prelu_test1) {
     nd4j::ops::prelu op;
-    auto result = op.execute({&x, &alpha}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto output = result->at(0);
@@ -2129,7 +2129,7 @@ TEST_F(DeclarableOpsTests9, prelu_test2) {
     auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
     nd4j::ops::prelu op;
-    auto result = op.execute({&x, &alpha}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto output = result->at(0);
@@ -2147,7 +2147,7 @@ TEST_F(DeclarableOpsTests9, prelu_test3) {
     auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
     nd4j::ops::prelu op;
-    auto result = op.execute({&x, &alpha}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto output = result->at(0);
@@ -2165,7 +2165,7 @@ TEST_F(DeclarableOpsTests9, prelu_test4) {
     auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
     nd4j::ops::prelu op;
-    auto result = op.execute({&x, &alpha}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto output = result->at(0);
@@ -2183,7 +2183,7 @@ TEST_F(DeclarableOpsTests9, prelu_test5) {
     auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, -22.f, -40.f, 9.f, 4.8f, -14.f, -24.f, 5.f, 2.4f, -6.f, -8.f, 1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
     nd4j::ops::prelu op;
-    auto result = op.execute({&x, &alpha}, {}, {1}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto output = result->at(0);
@@ -2201,7 +2201,7 @@ TEST_F(DeclarableOpsTests9, prelu_test6) {
     auto exp = NDArrayFactory::create('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
     nd4j::ops::prelu op;
-    auto result = op.execute({&x, &alpha}, {}, {1,0}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {1,0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto output = result->at(0);
@@ -2220,7 +2220,7 @@ TEST_F(DeclarableOpsTests9, prelu_test7) {
     auto exp = NDArrayFactory::create('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
     nd4j::ops::prelu op;
-    auto result = op.execute({&x, &alpha}, {}, {1,0}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {1,0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto output = result->at(0);
@@ -2238,7 +2238,7 @@ TEST_F(DeclarableOpsTests9, prelu_test8) {
     auto exp = NDArrayFactory::create('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
     nd4j::ops::prelu op;
-    auto result = op.execute({&x, &alpha}, {}, {1,0,1,0,1,0}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {1,0,1,0,1,0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto output = result->at(0);
@@ -2256,7 +2256,7 @@ TEST_F(DeclarableOpsTests9, prelu_test9) {
     auto exp = NDArrayFactory::create('c', {2, 4}, {8.f, 6.f, 4.f, 2.f,0.f, 1.f, 2.f, 3.f});
     nd4j::ops::prelu op;
-    auto result = op.execute({&x, &alpha}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto output = result->at(0);
@@ -2274,7 +2274,7 @@ TEST_F(DeclarableOpsTests9, prelu_test10) {
     auto exp = NDArrayFactory::create('c', {2, 4}, {8.f, 6.f, 4.f, 2.f,0.f, 1.f, 2.f, 3.f});
     nd4j::ops::prelu op;
-    auto result = op.execute({&x, &alpha}, {}, {1}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto output = result->at(0);
@@ -2299,7 +2299,7 @@ TEST_F(DeclarableOpsTests9, prelu_test11) {
             62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f});
     nd4j::ops::prelu op;
-    auto result = op.execute({&x, &alpha}, {}, {1,3}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {1,3});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto output = result->at(0);
@@ -2323,7 +2323,7 @@ TEST_F(DeclarableOpsTests9, prelu_test12) {
            53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f});
     nd4j::ops::prelu op;
-    auto result = op.execute({&x, &alpha}, {}, {-1, 2}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {-1, 2});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto output = result->at(0);
@@ -2347,7 +2347,7 @@ TEST_F(DeclarableOpsTests9, prelu_test13) {
            53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f});
     nd4j::ops::prelu op;
-    auto result = op.execute({&x, &alpha}, {}, {-1, 2}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {-1, 2});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto output = result->at(0);
@@ -2372,7 +2372,7 @@ TEST_F(DeclarableOpsTests9, prelu_test14) {
            55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f});
     nd4j::ops::prelu op;
-    auto result = op.execute({&x, &alpha}, {}, {-2}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x, &alpha}, {}, {-2});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto output = result->at(0);
@@ -2391,7 +2391,7 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) {
     nd4j::ops::thresholdedrelu op;
-    auto result = op.execute({&x}, {theta}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {theta});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto output = result->at(0);
@@ -2411,7 +2411,7 @@ TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) {
     nd4j::ops::compare_and_bitpack op;
-    auto result = op.execute({&x, &threshold}, {}, {}, {});
+    auto result = op.evaluate({&x, &threshold}, {}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto output = result->at(0);
     // output->printIndexedBuffer("Packed to uint8");
@@ -2429,7 +2429,7 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) {
     nd4j::ops::thresholdedrelu op;
-    auto result = op.execute({&x}, {theta}, {}, {}, false, nd4j::DataType::DOUBLE);
+    auto result = op.evaluate({&x}, {theta});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto output = result->at(0);
@@ -2544,7 +2544,7 @@ TEST_F(DeclarableOpsTests9, multiply_test1) {
     y.linspace(0.1f, 0.1f);
     nd4j::ops::multiply op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -2564,7 +2564,7 @@ TEST_F(DeclarableOpsTests9, multiply_test2) {
     // y.linspace(0.1f, 0.1f);
     nd4j::ops::multiply op;
-    auto result = op.execute({&y, &x}, {}, {});
+    auto result = op.evaluate({&y, &x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -2584,7 +2584,7 @@ TEST_F(DeclarableOpsTests9, multiply_test3) {
     y.linspace(0.1f, 0.1f);
     nd4j::ops::multiply op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -2603,7 +2603,7 @@ TEST_F(DeclarableOpsTests9, multiply_test4) {
     x.linspace(1.f);
     nd4j::ops::multiply op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -2621,7 +2621,7 @@ TEST_F(DeclarableOpsTests9, multiply_test5) {
     auto exp = NDArrayFactory::create(0.1f);
     nd4j::ops::multiply op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -2643,8 +2643,8 @@ TEST_F(DeclarableOpsTests9, multiply_bp_test1) {
     nd4j::ops::multiply opFF;
     nd4j::ops::multiply_bp opBP;
-    auto resFF = opFF.execute({&x, &y}, {}, {});
-    auto resBP = opBP.execute({&x, &y, &dLdz}, {}, {});
+    auto resFF = opFF.evaluate({&x, &y}, {}, {});
+    auto resBP = opBP.evaluate({&x, &y, &dLdz}, {}, {});
     // resFF->at(0)->printIndexedBuffer("Multiply 1x1");
     // resBP->at(0)->printIndexedBuffer("Multiply BP 1x1 x");
     // resBP->at(1)->printIndexedBuffer("Multyply BP 1x1 y");*/
@@ -2800,7 +2800,7 @@ TEST_F(DeclarableOpsTests9, Floormod_BP_Test_2) {
     // resFF->at(0)->printIndexedBuffer("FF floormod");
     // delete resFF;
     nd4j::ops::floormod_bp opBP;
-    auto resBP = opBP.execute({&x, &y, &dLdz}, {}, {});
+    auto resBP = opBP.evaluate({&x, &y, &dLdz}, {}, {});
     ASSERT_TRUE(resBP->status() == ND4J_STATUS_OK);
     // resBP->at(0)->printIndexedBuffer("BP floormod /dx");
@@ -2832,10 +2832,10 @@ TEST_F(DeclarableOpsTests9, Dynamic_Partition_BP_1) {
     dLdzZ.assign(3);
     nd4j::ops::dynamic_partition op1;
-    auto res1 = op1.execute({&x, &y}, {}, {3});
+    auto res1 = op1.evaluate({&x, &y}, {}, {3});
     nd4j::ops::dynamic_partition_bp op2;
-    auto res2 = op2.execute({&x, &y, &dLdzX, &dLdzY, &dLdzZ}, {}, {3});
+    auto res2 = op2.evaluate({&x, &y, &dLdzX, &dLdzY, &dLdzZ}, {}, {3});
     ASSERT_TRUE(res2->status() == ND4J_STATUS_OK);
     ASSERT_TRUE(res2->size() == 2);
     // printf("How many: %ul\n", res2->size());
@@ -2879,7 +2879,7 @@ TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) {
     eps.assign(1.f);
     nd4j::ops::floormod_bp op;
-    auto result = op.execute({&x, &y, &eps}, {}, {});
+    auto result = op.evaluate({&x, &y, &eps}, {}, {});
     ASSERT_TRUE(result->size() == 2);
     auto gradX = result->at(0);
@@ -2924,7 +2924,7 @@ TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) {
     const OpArgsHolder argsHolderFF({&x, &hi, &W, &Wc, &b, &bc}, {}, {});
     nd4j::ops::gruCell op;
-    auto results = op.execute(argsHolderFF);
+    auto results = op.evaluate(argsHolderFF);
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2964,7 +2964,7 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_1) {
     nd4j::ops::cholesky op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(result->status(), ND4J_STATUS_OK);
     auto res = result->at(0);
     // res->printIndexedBuffer("Output for Cholesky1");
@@ -2980,7 +2980,7 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_2) {
     nd4j::ops::cholesky op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(result->status(), ND4J_STATUS_OK);
     auto res = result->at(0);
     // res->printIndexedBuffer("Output for Cholesky 2");
@@ -2996,7 +2996,7 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_3) {
     nd4j::ops::cholesky op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(result->status(), ND4J_STATUS_OK);
     auto res = result->at(0);
     // res->printIndexedBuffer("Output for Cholesky 3");
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu
index f88cddde5..b7907ce1d 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu
@@ -50,7 +50,7 @@ TEST_F(DeclarableOpsTestsCuda1, Test_CHOOSE_SCALAR_LARGE) {
     nd4j::ops::choose op;
     //greater than test
-    auto result = op.execute({&x}, {0.0},{3});
+    auto result = op.evaluate({&x}, {0.0},{3});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(1);
diff --git a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp
index 12069c67e..3717c488b 100644
--- a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp
@@ -66,7 +66,7 @@ TEST_F(EmptyTests, Test_Concat_1) {
     ASSERT_TRUE(empty->isEmpty());
     nd4j::ops::concat op;
-    auto result = op.execute({empty, vector}, {}, {0});
+    auto result = op.evaluate({empty, vector}, {}, {0});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -91,7 +91,7 @@ TEST_F(EmptyTests, Test_Concat_2) {
     ASSERT_TRUE(empty->isEmpty());
     nd4j::ops::concat op;
-    auto result = op.execute({empty, scalar1, scalar2}, {}, {0});
+    auto result = op.evaluate({empty, scalar1, scalar2}, {}, {0});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -116,7 +116,7 @@ TEST_F(EmptyTests, Test_Concat_3) {
     ASSERT_TRUE(empty.isEmpty());
     nd4j::ops::concat op;
-    auto result = op.execute({&empty, &scalar1, &scalar2}, {}, {0});
+    auto result = op.evaluate({&empty, &scalar1, &scalar2}, {}, {0});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -135,7 +135,7 @@ TEST_F(EmptyTests, Test_Concat_4) {
     ASSERT_TRUE(empty.isEmpty());
     nd4j::ops::concat op;
-    auto result = op.execute({&scalar1, &empty, &scalar2}, {}, {0});
+    auto result = op.evaluate({&scalar1, &empty, &scalar2}, {}, {0});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -151,7 +151,7 @@ TEST_F(EmptyTests, Test_Reshape_1) {
     auto empty = NDArrayFactory::empty_();
     nd4j::ops::reshape op;
-    auto result = op.execute({&vector, empty}, {}, {});
+    auto result = op.evaluate({&vector, empty}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
@@ -167,7 +167,7 @@ TEST_F(EmptyTests, Test_Reshape_2) {
     auto empty = NDArrayFactory::empty_();
     nd4j::ops::reshape op;
-    auto result = op.execute({&vector, empty}, {}, {}, {}, true);
+    auto result = op.evaluate({&vector, empty}, {}, {}, {}, {}, true);
     ASSERT_EQ(Status::OK(), result->status());
@@ -184,7 +184,7 @@ TEST_F(EmptyTests, Test_Reshape_3) {
     auto e = NDArrayFactory::create('c', {10, 0});
     nd4j::ops::reshape op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -213,7 +213,7 @@ TEST_F(EmptyTests, test_empty_scatter_1) {
     x.linspace(1.0f);
     nd4j::ops::scatter_upd op;
-    auto result = op.execute({&x, &indices, &updates}, {}, {}, {true});
+    auto result = op.evaluate({&x, &indices, &updates}, {}, {}, {true});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -311,12 +311,12 @@ TEST_F(EmptyTests, test_empty_reshape_1) {
     auto e1 = NDArrayFactory::create('c', {0, 1});
     nd4j::ops::reshape op;
-    auto result0 = op.execute({&x0, &shape0}, {}, {});
+    auto result0 = op.evaluate({&x0, &shape0}, {}, {});
     ASSERT_EQ(Status::OK(), result0->status());
     auto z0 = result0->at(0);
     ASSERT_EQ(e0, *z0);
-    auto result1 = op.execute({&x1, &shape1}, {}, {});
+    auto result1 = op.evaluate({&x1, &shape1}, {}, {});
     ASSERT_EQ(Status::OK(), result1->status());
     auto z1 = result1->at(0);
     ASSERT_EQ(e1, *z1);
@@ -332,7 +332,7 @@ TEST_F(EmptyTests, test_empty_matmul_1) {
     auto e = NDArrayFactory::create('c', {0, 0});
     nd4j::ops::matmul op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
@@ -347,7 +347,7 @@ TEST_F(EmptyTests, test_empty_matmul_2) {
     auto e = NDArrayFactory::create('c', {1, 0, 0});
     nd4j::ops::matmul op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
diff --git a/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp b/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp
index 790279f74..d26fbd122 100644
--- a/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp
@@ -46,7 +46,7 @@ TEST_F(IndexingTests, StridedSlice_1) {
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&x, &begin, &end, &strides}, {}, {0,0,0,0,0}); //, 2,2,0, 3,3,3, 1,1,1});
+    auto result = op.evaluate({&x, &begin, &end, &strides}, {}, {0,0,0,0,0}); //, 2,2,0, 3,3,3, 1,1,1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -65,7 +65,7 @@ TEST_F(IndexingTests, StridedSlice_2) {
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,1});
+    auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -85,7 +85,7 @@ TEST_F(IndexingTests, StridedSlice_3) {
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,2});
+    auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,2});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -108,7 +108,7 @@ TEST_F(IndexingTests, SimpleSlice_1) {
     nd4j::ops::slice op;
-    auto result = op.execute({&input}, {}, {1,0,0, 1,1,3});
+    auto result = op.evaluate({&input}, {}, {1,0,0, 1,1,3});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -134,7 +134,7 @@ TEST_F(IndexingTests, SimpleSlice_2) {
     nd4j::ops::slice op;
-    auto result = op.execute({&input}, {}, {1,0,0, 1,2,3});
+    auto result = op.evaluate({&input}, {}, {1,0,0, 1,2,3});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -159,7 +159,7 @@ TEST_F(IndexingTests, SimpleSlice_3) {
     nd4j::ops::slice op;
-    auto result = op.execute({&input}, {}, {1,0,0, 2,1,3});
+    auto result = op.evaluate({&input}, {}, {1,0,0, 2,1,3});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -179,7 +179,7 @@ TEST_F(IndexingTests, SimpleSlice_4) {
     nd4j::ops::slice op;
-    auto result = op.execute({&input, &start, &stop}, {}, {});
+    auto result = op.evaluate({&input, &start, &stop});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -202,7 +202,7 @@ TEST_F(IndexingTests, MaskedSlice_0) {
     exp.assign(2.0f);
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&matrix}, {}, {0,0,0,0,0, 1, 2, 1});
+    auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 2, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -228,7 +228,7 @@ TEST_F(IndexingTests, MaskedSlice_00) {
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&matrix}, {}, {0,0,0,0,0, 1, 1, 2, 3, 1, 1});
+    auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 1, 2, 3, 1, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -252,7 +252,7 @@ TEST_F(IndexingTests, MaskedSlice_1) {
     exp.assign(2.0f);
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&matrix}, {}, {0,0,0,0,1, 1, 2, 1});
+    auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 2, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -273,7 +273,7 @@ TEST_F(IndexingTests, MaskedSlice_2) {
     // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5)
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&matrix}, {}, {0,0,0,0,1, 1, 0, 0, 3, 3, 3, 1, 1, 1});
+    auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 0, 0, 3, 3, 3, 1, 1, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -293,7 +293,7 @@ TEST_F(IndexingTests, MaskedSlice_3) {
     // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5)
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3, 1, 1, 1});
+    auto result = op.evaluate({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3, 1, 1, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -313,7 +313,7 @@ TEST_F(IndexingTests, MaskedSlice_4) {
     // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5)
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&matrix}, {}, {0,0,0,0, 3, 1, 0, 0, 3, 3, 3, 1, 1, 1});
+    auto result = op.evaluate({&matrix}, {}, {0,0,0,0, 3, 1, 0, 0, 3, 3, 3, 1, 1, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -336,7 +336,7 @@ TEST_F(IndexingTests, Live_Slice_1) {
     // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5)
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&matrix, &begin, &end, &stride}, {}, {0,0,0,0,3});
+    auto result = op.evaluate({&matrix, &begin, &end, &stride}, {}, {0,0,0,0,3});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -359,7 +359,7 @@ TEST_F(IndexingTests, Test_StridedSlice_1) {
     auto exp = NDArrayFactory::create({5.0f, 2});
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
+    auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -379,7 +379,7 @@ TEST_F(IndexingTests, Test_StridedSlice_2) {
     auto exp = NDArrayFactory::create('c', {1}, {5.0});
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
+    auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -402,7 +402,7 @@ TEST_F(IndexingTests, Test_StridedSlice_3) {
     auto exp = NDArrayFactory::create('c', {1}, {6.0});
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
+    auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -423,7 +423,7 @@ TEST_F(IndexingTests, Test_StridedSlice_4) {
     auto exp = NDArrayFactory::create({5.0f, 2});
     nd4j::ops::strided_slice op;
-    auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
+    auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
     // auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1, 0, 1, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
diff --git a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp
index f0b7628ee..35f46d739 100644
--- a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp
@@ -62,7 +62,7 @@ TEST_F(LegacyOpsTests, TransformTests_2) {
     exp.assign(-1.0);
     nd4j::ops::LegacyTransformSameOp op(transform::Neg); // Neg
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(1, result->size());
@@ -119,7 +119,7 @@ TEST_F(LegacyOpsTests, PWT_Tests_2) {
     exp.assign(6.0);
     nd4j::ops::LegacyPairwiseTransformOp op(pairwise::Multiply); // Multiply
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     auto z = result->at(0);
@@ -152,7 +152,7 @@ TEST_F(LegacyOpsTests, Scalar_Test_2) {
     auto y = NDArrayFactory::create(5.0f);
     nd4j::ops::LegacyScalarOp op(scalar::Add, y);
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     auto z = result->at(0);
     ASSERT_TRUE(exp.equalsTo(z));
@@ -167,7 +167,7 @@ TEST_F(LegacyOpsTests, ReduceTests_1) {
     int opNum = reduce::Sum;
     nd4j::ops::LegacyReduceSameOp op(opNum);
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(1, result->size());
@@ -186,7 +186,7 @@ TEST_F(LegacyOpsTests, ReduceTests_2) {
     nd4j::ops::LegacyReduceSameOp op(reduce::Sum);
     auto axis = NDArrayFactory::create('c', {1}, {1});
-    auto result = op.execute({&x, &axis}, {}, {});
+    auto result = op.evaluate({&x, &axis}, {}, {});
     ASSERT_EQ(1, result->size());
@@ -208,7 +208,7 @@ TEST_F(LegacyOpsTests, ReduceTests_3) {
     nd4j::ops::LegacyReduceSameOp op(reduce::Sum);
-    auto result = op.execute({&x, &indices}, {}, {});
+    auto result = op.evaluate({&x, &indices}, {}, {});
     auto z = result->at(0);
     auto exp = x.reduceAlongDimension(reduce::Sum,{1});
@@ -228,7 +228,7 @@ TEST_F(LegacyOpsTests, ReduceTests_4) {
     nd4j::ops::LegacyReduceSameOp op(reduce::Sum);
-    auto result = op.execute({&x, &indices}, {}, {}, {true});
+    auto result = op.evaluate({&x, &indices}, {}, {}, {true});
     auto z = result->at(0);
     auto exp = x.reduceAlongDimension(reduce::Sum, {1}, true);
     // indices.printShapeInfo("Indices shape");
@@ -247,7 +247,7 @@ TEST_F(LegacyOpsTests, ReduceTests_5) {
     int opNum = reduce::Mean;
     nd4j::ops::LegacyReduceFloatOp op(opNum);
-    ResultSet* result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
+    auto result = op.evaluate({&x});
     ASSERT_EQ(1, result->size());
@@ -266,7 +266,7 @@ TEST_F(LegacyOpsTests, ReduceTests_6) {
     auto axis = NDArrayFactory::create('c', {1}, {1});
     nd4j::ops::LegacyReduceFloatOp op(reduce::Mean);
-    auto result = op.execute({&x, &axis}, {}, {});
+    auto result = op.evaluate({&x, &axis}, {}, {});
     ASSERT_EQ(1, result->size());
@@ -288,7 +288,7 @@ TEST_F(LegacyOpsTests, ReduceTests_7) {
     nd4j::ops::LegacyReduceFloatOp op(reduce::Mean);
-    auto result = op.execute({&x, &indices}, {}, {});
+    auto result = op.evaluate({&x, &indices}, {}, {});
     auto z = result->at(0);
     auto exp = x.reduceAlongDimension(reduce::Mean,{1});
@@ -308,7 +308,7 @@ TEST_F(LegacyOpsTests, ReduceTests_8) {
     nd4j::ops::LegacyReduceFloatOp op(reduce::Mean);
-    auto result = op.execute({&x, &indices}, {}, {}, {true});
+    auto result = op.evaluate({&x, &indices}, {}, {}, {true});
     auto z = result->at(0);
     auto exp = x.reduceAlongDimension(reduce::Mean, {1}, true);
@@ -329,7 +329,7 @@ TEST_F(LegacyOpsTests, IndexReduceTests_1) {
     nd4j::ops::LegacyIndexReduceOp op(indexreduce::IndexMax);
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(1, result->size());
@@ -349,7 +349,7 @@ TEST_F(LegacyOpsTests, IndexReduceTests_2) {
     auto exp = NDArrayFactory::create({4,4,4,4,4});
     nd4j::ops::LegacyIndexReduceOp op(indexreduce::IndexMax);
-    auto result = op.execute({&x, &indices}, {}, {});
+    auto result = op.evaluate({&x, &indices}, {}, {});
     ASSERT_EQ(1, result->size());
diff --git a/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp b/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp
index 9afc34267..aae4493ab 100644
--- a/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp
@@ -133,7 +133,7 @@ TEST_F(MultiDataTypeTests, Basic_Test_7) {
     auto e = NDArrayFactory::create('c', {2, 3}, {0.f, 2.f, 4.f, 6.f, 8.f, 10.f});
     nd4j::ops::add op;
-    auto result = op.execute({&x, &y},{}, {});
+    auto result = op.evaluate({&x, &y});
     ASSERT_EQ(Status::OK(), result->status());
     auto z = result->at(0);
diff --git a/libnd4j/tests_cpu/layers_tests/NlpTests.cpp b/libnd4j/tests_cpu/layers_tests/NlpTests.cpp
index 9c6f4a981..83eeee48b 100644
--- a/libnd4j/tests_cpu/layers_tests/NlpTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/NlpTests.cpp
@@ -65,7 +65,7 @@ TEST_F(NlpTests, basic_sg_hs_test_1) {
     auto inferenceVector = NDArrayFactory::empty();
     nd4j::ops::skipgram op;
-    auto result = op.execute({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, true);
+    auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
     ASSERT_EQ(Status::OK(), result->status());
     auto row0 = syn0({0,1, 0,0}, true);
@@ -106,7 +106,7 @@ TEST_F(NlpTests, basic_sg_hs_test_2) {
     auto inferenceVector = NDArrayFactory::empty();
     nd4j::ops::skipgram op;
-    auto result = op.execute({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, true);
+    auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
     ASSERT_EQ(Status::OK(), result->status());
     auto row0 = syn0({0,1, 0,0}, true);
@@ -157,8 +157,8 @@ TEST_F(NlpTests, basic_sg_hs_test_3) {
     auto inferenceVector = NDArrayFactory::empty();
     nd4j::ops::skipgram op;
-    auto result0 = op.execute({&target, &ngStarter, &indices0, &codes00, &syn00, &syn10, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, true);
-    auto result1 = op.execute({&target, &ngStarter, &indices1, &codes01, &syn01, &syn11, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, true);
+    auto result0 = op.evaluate({&target, &ngStarter, &indices0, &codes00, &syn00, &syn10, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
+    auto result1 = op.evaluate({&target, &ngStarter, &indices1, &codes01, &syn01, &syn11, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
     ASSERT_EQ(Status::OK(), result0->status());
     auto row00 = syn00({0,1, 0,0}, true);
@@ -191,7 +191,7 @@ TEST_F(NlpTests, basic_sg_hs_ns_test_1) {
     auto inferenceVector = NDArrayFactory::empty();
     nd4j::ops::skipgram op;
-    auto result = op.execute({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {3}, {false}, true);
+    auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {3}, {false}, {}, true);
     ASSERT_EQ(Status::OK(), result->status());
     delete result;
@@ -226,7 +226,7 @@ TEST_F(NlpTests, basic_sg_ns_test_1) {
     auto inferenceVector = NDArrayFactory::empty();
     nd4j::ops::skipgram op;
-    auto result = op.execute({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {1, 1}, {false}, true);
+    auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {1, 1}, {false}, {}, true);
     ASSERT_EQ(Status::OK(), result->status());
     auto row0 = syn0({1,2, 0,0}, true);
@@ -268,7 +268,7 @@ TEST_F(NlpTests, basic_cb_hs_test_1) {
     auto inferenceVector = NDArrayFactory::empty();
     nd4j::ops::cbow op;
-    auto result = op.execute({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, true);
+    auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true);
     ASSERT_EQ(Status::OK(), result->status());
     auto row_s0_0 = syn0({0,1, 0,0}, true);
@@ -322,7 +322,7 @@ TEST_F(NlpTests, basic_cb_ns_test_1) {
     auto inferenceVector = NDArrayFactory::empty();
     nd4j::ops::cbow op;
-    auto result = op.execute({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {1, 2, 0}, {true}, true);
+    auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {1, 2, 0}, {true}, {}, true);
     ASSERT_EQ(Status::OK(), result->status());
     auto row_s0_0 = syn0({0,1, 0,0}, true);
@@ -371,7 +371,7 @@ TEST_F(NlpTests, test_sg_hs_batch_1) {
     expTable.assign(0.5);
     nd4j::ops::skipgram op;
-    auto result = op.execute({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false, true}, true);
+    auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false, true}, {}, true);
     ASSERT_EQ(Status::OK(), result->status());
     auto row0 = syn0({0,1, 0,0}, true);
@@ -415,7 +415,7 @@ TEST_F(NlpTests, test_sg_ns_batch_1) {
     negTable.linspace(0.0);
     nd4j::ops::skipgram op;
-    auto result = op.execute({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {4, 5}, {false, true}, true);
+    auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {4, 5}, {false, true}, {}, true);
     ASSERT_EQ(Status::OK(), result->status());
     auto row0 = syn0({0,0, 0,0}, true);
@@ -452,7 +452,7 @@ TEST_F(NlpTests, test_cbow_hs_batch_1) {
     auto inferenceVector = NDArrayFactory::empty();
     nd4j::ops::cbow op;
-    auto result = op.execute({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, true);
+    auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true);
     ASSERT_EQ(Status::OK(), result->status());
     auto exp0 = NDArrayFactory::create('c', {1, 10});
diff --git a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp
index b62cbceea..987817136 100644
--- a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp
@@ -41,7 +41,7 @@ TEST_F(ParityOpsTests, TestZeroAs1) {
     nd4j::ops::zeros_as op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     auto z = result->at(0);
@@ -60,7 +60,7 @@ TEST_F(ParityOpsTests, TestMaximum1) {
     nd4j::ops::maximum op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     auto z = result->at(0);
@@ -80,7 +80,7 @@ TEST_F(ParityOpsTests, TestMinimum1) {
     nd4j::ops::minimum op;
-    auto result = op.execute({&x, &y}, {}, {});
+    auto result = op.evaluate({&x, &y}, {}, {});
     auto z = result->at(0);
@@ -99,7 +99,7 @@ TEST_F(ParityOpsTests, TestTear1) {
     nd4j::ops::tear op;
-    auto result = op.execute({&input}, {}, {1});
+    auto result = op.evaluate({&input}, {}, {1});
     ASSERT_EQ(10, result->size());
@@ -119,7 +119,7 @@ TEST_F(ParityOpsTests, TestUnstack1) {
     nd4j::ops::unstack op;
-    auto result = op.execute({&input}, {}, {0});
+    auto result = op.evaluate({&input}, {}, {0});
     ASSERT_EQ(10, result->size());
@@ -141,7 +141,7 @@ TEST_F(ParityOpsTests, TestUnstack2) {
     nd4j::ops::unstack op;
-    auto result = op.execute({&input}, {}, {2});
+    auto result = op.evaluate({&input}, {}, {2});
     ASSERT_EQ(6, result->size());
@@ -158,7 +158,7 @@ TEST_F(ParityOpsTests, TestUnstack3) {
     nd4j::ops::unstack op;
-    auto result = op.execute({&input}, {}, {2});
+    auto result = op.evaluate({&input}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -177,7 +177,7 @@ TEST_F(ParityOpsTests, TestUnstack4) {
     nd4j::ops::unstack op;
-    auto result = op.execute({&input}, {}, {1});
+    auto result = op.evaluate({&input}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -195,7 +195,7 @@ TEST_F(ParityOpsTests, TestUnstack5) {
     nd4j::ops::unstack op;
-    auto result = op.execute({&input}, {}, {0});
+    auto result = op.evaluate({&input}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -213,7 +213,7 @@ TEST_F(ParityOpsTests, TestUnstack6) {
     nd4j::ops::unstack op;
-    auto result = op.execute({&input}, {}, {0});
+    auto result = op.evaluate({&input}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -231,7 +231,7 @@ TEST_F(ParityOpsTests, TestUnstack7) {
     nd4j::ops::unstack op;
-    auto result = op.execute({&input}, {}, {1});
+    auto result = op.evaluate({&input}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -249,7 +249,7 @@ TEST_F(ParityOpsTests, TestUnstack8) {
     nd4j::ops::unstack op;
-    auto result = op.execute({&input}, {}, {0});
+    auto result = op.evaluate({&input}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -267,7 +267,7 @@ TEST_F(ParityOpsTests, TestUnstack9) {
     nd4j::ops::unstack op;
-    auto result = op.execute({&input}, {}, {1});
+    auto result = op.evaluate({&input}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -286,7 +286,7 @@ TEST_F(ParityOpsTests, TestUnstack10) {
     nd4j::ops::unstack op;
-    auto result = op.execute({&input}, {}, {0});
+    auto result = op.evaluate({&input}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_TRUE(exp.isSameShape(result->at(0)));
@@ -304,7 +304,7 @@ TEST_F(ParityOpsTests, TestUnstack11) {
     nd4j::ops::unstack op;
-    auto result = op.execute({&input}, {}, {2});
+    auto result = op.evaluate({&input}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_TRUE(exp.isSameShape(result->at(0)));
@@ -320,7 +320,7 @@ TEST_F(ParityOpsTests, TestUnstack12) {
     nd4j::ops::unstack op;
-    auto result = op.execute({&input}, {}, {1});
+    auto result = op.evaluate({&input}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     ASSERT_TRUE(result->size() == 0);
@@ -334,7 +334,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest1) {
     auto reshaped = input.reshape('c', {5, 1, 5});
     nd4j::ops::expand_dims op;
-    auto result = op.execute({&input}, {}, {1});
+    auto result = op.evaluate({&input}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -353,7 +353,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest2) {
    auto reshaped = input.reshape('c', {1, 3, 4});
    nd4j::ops::expand_dims op;
-    auto result = op.execute({&input}, {}, {0});
+    auto result = op.evaluate({&input}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -372,7 +372,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest3) {
     auto reshaped = input.reshape('c', {3, 1, 4});
     nd4j::ops::expand_dims op;
-    auto result = op.execute({&input}, {}, {-2});
+    auto result = op.evaluate({&input}, {}, {-2});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -390,7 +390,7 @@ TEST_F(ParityOpsTests, ExpandDimsTest4) {
     auto reshaped = input.reshape('c', {1, 3, 4});
     nd4j::ops::expand_dims op;
-    auto result = op.execute({&input}, {}, {-3});
+    auto result = op.evaluate({&input}, {}, {-3});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -408,7 +408,7 @@ TEST_F(ParityOpsTests, Test_Shape_1) {
     auto exp = NDArrayFactory::create('c', {4}, {3, 4, 5, 6});
     nd4j::ops::shape_of op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -426,7 +426,7 @@ TEST_F(ParityOpsTests, Test_Equals_1) {
     auto exp = NDArrayFactory::create('c', {1, 5}, {1, 0, 1, 0, 1});
     nd4j::ops::equals op;
-    auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL);
+    auto result = op.evaluate({&x, &y});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -444,7 +444,7 @@ TEST_F(ParityOpsTests, Test_NotEquals_1) {
     auto exp = NDArrayFactory::create('c', {1, 5}, {0, 1, 0, 1, 0});
     nd4j::ops::not_equals op;
-    auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL);
+    auto result = op.evaluate({&x, &y});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -461,7 +461,7 @@ TEST_F(ParityOpsTests, Test_Less_1) {
     auto exp = NDArrayFactory::create('c', {1, 5}, {1, 1, 0, 0, 0});
     nd4j::ops::less op;
-    auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL);
+    auto result = op.evaluate({&x, &y});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -478,7 +478,7 @@ TEST_F(ParityOpsTests, Test_LessEquals_1) {
     auto exp = NDArrayFactory::create('c', {1, 5}, {1, 1, 1, 0, 0});
     nd4j::ops::less_equal op;
-    auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL);
+    auto result = op.evaluate({&x, &y});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -495,7 +495,7 @@ TEST_F(ParityOpsTests, Test_GreaterEquals_1) {
     auto exp = NDArrayFactory::create('c', {1, 5}, {0, 0, 1, 1, 1});
     nd4j::ops::greater_equal op;
-    auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL);
+    auto result = op.evaluate({&x, &y});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -512,7 +512,7 @@ TEST_F(ParityOpsTests, Test_GreaterEquals_2) {
     auto exp = NDArrayFactory::create('c', {1, 5}, {0, 0, 1, 1, 1});
     nd4j::ops::greater_equal op;
-    auto result = op.execute({&x, &y}, {}, {}, {}, false);
+    auto result = op.evaluate({&x, &y}, {}, {}, {}, {}, false);
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -529,7 +529,7 @@ TEST_F(ParityOpsTests, Test_Greater_1) {
     auto exp = NDArrayFactory::create('c', {1, 5}, {0, 0, 0, 1, 1});
     nd4j::ops::greater op;
-    auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL);
+    auto result = op.evaluate({&x, &y});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -547,7 +547,7 @@ TEST_F(ParityOpsTests, Test_Where_1) {
     auto exp = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 6, 5, 4, 7, 8, 9});
     nd4j::ops::Where op;
-    auto result = op.execute({&mask, &x, &y}, {}, {});
+    auto result = op.evaluate({&mask, &x, &y}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -567,7 +567,7 @@ TEST_F(ParityOpsTests, Test_Where_2) {
     auto exp = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 6, 5, 4, 3, 2, 1});
     nd4j::ops::Where op;
-    auto result = op.execute({&mask, &x, &y}, {}, {});
+    auto result = op.evaluate({&mask, &x, &y}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -584,7 +584,7 @@ TEST_F(ParityOpsTests, Test_Where_3) {
     auto exp = NDArrayFactory::create('c', {5, 3}, {0, 0, 1, 0, 0, 2, 0, 1, 1, 1, 0, 0, 1, 1, 2});
     nd4j::ops::Where op;
-    auto result = op.execute({&mask}, {}, {});
+    auto result = op.evaluate({&mask}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -604,7 +604,7 @@ TEST_F(ParityOpsTests, Test_Select_1) {
     auto exp = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 6, 5, 4, 3, 2, 1});
     nd4j::ops::select op;
-    auto result = op.execute({&mask, &x, &y}, {}, {});
+    auto result = op.evaluate({&mask, &x, &y}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -622,7 +622,7 @@ TEST_F(ParityOpsTests, Test_Select_2) {
     auto exp = NDArrayFactory::create('c', {2, 2}, {1, 8, 3, 6});
     nd4j::ops::select op;
-    auto result = op.execute({&mask, &x, &y}, {}, {});
+    auto result = op.evaluate({&mask, &x, &y}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -641,7 +641,7 @@ TEST_F(ParityOpsTests, Test_Select_3) {
     auto exp = NDArrayFactory::create('c', {1, 1}, {2});
     nd4j::ops::select op;
-    auto result = op.execute({&mask, &x, &y}, {}, {});
+    auto result = op.evaluate({&mask, &x, &y}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -660,7 +660,7 @@ TEST_F(ParityOpsTests, Test_Reshape_TF_1) {
     nd4j::ops::reshape op;
-    auto result = op.execute({&x, &shape}, {}, {});
+    auto result = op.evaluate({&x, &shape}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -677,7 +677,7 @@ TEST_F(ParityOpsTests, Test_Bias_Add_1) {
     auto bias = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5});
     nd4j::ops::biasadd op;
-    auto result = op.execute({&x, &bias}, {}, {});
+    auto result = op.evaluate({&x, &bias}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -697,7 +697,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_1) {
     auto exp = NDArrayFactory::create('c', {2, 2}, {2, 3, 3, 4});
     nd4j::ops::scatter_add op;
-    auto result = op.execute({&matrix, &idc, &updates}, {}, {});
+    auto result = op.evaluate({&matrix, &idc, &updates}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -715,7 +715,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_2) {
     auto exp = NDArrayFactory::create('c', {1, 4}, {2, 3, 4, 5});
     nd4j::ops::scatter_add op;
-    auto result = op.execute({&vec, &idc, &updates}, {}, {});
+    auto result = op.evaluate({&vec, &idc, &updates}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -732,7 +732,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_3) {
     auto exp = NDArrayFactory::create('c', {2, 2, 2}, {2, 3, 4, 5, 5, 6, 7, 8});
     nd4j::ops::scatter_add op;
-    auto result = op.execute({&matrix, &idc, &updates}, {}, {});
+    auto result = op.evaluate({&matrix, &idc, &updates}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -749,7 +749,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_4) {
     auto exp = NDArrayFactory::create('c', {2, 2, 2}, {3, 4, 5, 6, 5, 6, 7, 8});
     nd4j::ops::scatter_add op;
-    auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true, true});
+    auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true, true});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -766,7 +766,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_5) {
     auto exp = NDArrayFactory::create('c', {2, 2, 3}, {9., 11., 13.,15., 17., 19., 9., 11., 13.,15., 17., 19.});
     nd4j::ops::scatter_add op;
-    auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true});
+    auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -784,7 +784,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_6) {
     auto exp = NDArrayFactory::create('c', {2, 2, 2}, {7, 9, 11, 13, 7, 9, 11, 13});
     nd4j::ops::scatter_add op;
-    auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true, true});
+    auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true, true});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -801,7 +801,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_7) {
     auto exp = NDArrayFactory::create('c', {10, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f,11.f,12.f, 13.f,14.f,15.f, 26.f,37.f,48.f, 19.f,20.f,21.f, 22.f,23.f,24.f, 25.f,26.f,27.f, 28.f,29.f,30.f});
     nd4j::ops::scatter_add op;
-    auto result = op.execute({&matrix, &idc, &updates}, {}, {});
+    auto result = op.evaluate({&matrix, &idc, &updates}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -850,7 +850,7 @@ TEST_F(ParityOpsTests, scatterMax_test1) {
     auto exp = NDArrayFactory::create('c', {2, 2}, {10, 2, 3, 4});
     nd4j::ops::scatter_max op;
-    auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true});
+    auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -867,7 +867,7 @@ TEST_F(ParityOpsTests, scatterMax_test2) {
     auto exp = NDArrayFactory::create('c', {1, 4}, {10, 2, 30, 4});
     nd4j::ops::scatter_max op;
-    auto result = op.execute({&vec, &idc, &updates}, {}, {}, {true});
+    auto result = op.evaluate({&vec, &idc, &updates}, {}, {}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -884,7 +884,7 @@ TEST_F(ParityOpsTests, scatterMax_test3) {
     auto exp = NDArrayFactory::create('c', {2, 2, 2}, {10, 2, 30, 4, 5, 6, 7, 8});
     nd4j::ops::scatter_max op;
-    auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true});
+    auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -901,7 +901,7 @@ TEST_F(ParityOpsTests, scatterMax_test4) {
     auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 10, 10, 10, 5, 6, 7, 8});
     nd4j::ops::scatter_max op;
-    auto result = op.execute({&matrix, &idc, &updates}, {}, {true}, {true});
+    auto result = op.evaluate({&matrix, &idc, &updates}, {}, {true}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -918,7 +918,7 @@ TEST_F(ParityOpsTests, scatterMax_test5) {
     auto exp = NDArrayFactory::create('c', {2, 2, 3}, {10, 2, 10, 2, 10, 2, 2, 10, 2, 10, 2, 10});
     nd4j::ops::scatter_max op;
-    auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true});
+    auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -935,7 +935,7 @@ TEST_F(ParityOpsTests, scatterMax_test6) {
     auto exp = NDArrayFactory::create('c', {2, 2, 2}, {2, 1, 2, 1, 1, 2, 1, 2});
     nd4j::ops::scatter_max op;
-    auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true});
+    auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -953,7 +953,7 @@ TEST_F(ParityOpsTests, scatterMin_test1) {
     auto exp = NDArrayFactory::create('c', {2, 2}, {-1, 1, 3, 4});
     nd4j::ops::scatter_min op;
-    auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true});
+    auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -970,7 +970,7 @@ TEST_F(ParityOpsTests, scatterMin_test2) {
     auto exp = NDArrayFactory::create('c', {1, 4}, {1, 1, 3, 1});
     nd4j::ops::scatter_min op;
-    auto result = op.execute({&vec, &idc, &updates}, {}, {}, {true});
+    auto result = op.evaluate({&vec, &idc, &updates}, {}, {}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -987,7 +987,7 @@ TEST_F(ParityOpsTests, scatterMin_test3) {
     auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 3, 2, 5, 6, 7, 8});
     nd4j::ops::scatter_min op;
-    auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true});
+    auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1004,7 +1004,7 @@ TEST_F(ParityOpsTests, scatterMin_test4) {
     auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 1, 1, 5, 6, 7, 8});
     nd4j::ops::scatter_min op;
-    auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true});
+    auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1036,7 +1036,7 @@ TEST_F(ParityOpsTests, scatterND_test1) {
     auto exp = NDArrayFactory::create('c', {3, 4}, {50.f, 60.f, 70.f, 80.f, 10.f, 20.f, 30.f, 40.f, 0.f, 0.f, 0.f, 0.f});
     nd4j::ops::scatter_nd op;
-    auto result = op.execute({&indices, &updates, &shape}, {}, {false, true});
+    auto result = op.evaluate({&indices, &updates, &shape}, {}, {false, true});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1058,7 +1058,7 @@ TEST_F(ParityOpsTests, scatterND_test2) {
     updates.linspace(1.f);
     nd4j::ops::scatter_nd op;
-    auto result = op.execute({&indices, &updates, &shape}, {}, {});
+    auto result = op.evaluate({&indices, &updates, &shape}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1083,7 +1083,7 @@ TEST_F(ParityOpsTests, scatterND_test3) {
     updates.linspace(1.f);
     nd4j::ops::scatter_nd op;
-    auto result = op.execute({&indices, &updates, &shape}, {}, {false, true});
+    auto result = op.evaluate({&indices, &updates, &shape}, {}, {false, true});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1103,7 +1103,7 @@ TEST_F(ParityOpsTests, scatterND_test4) {
     auto exp = NDArrayFactory::create('c', {8}, {0.f, 11.f, 0.f, 10.f, 9.f, 0.f, 0.f, 12.f});
     nd4j::ops::scatter_nd op;
-    auto result = op.execute({&indices, &updates, &shape}, {}, {});
+    auto result = op.evaluate({&indices, &updates, &shape}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1123,7 +1123,7 @@ TEST_F(ParityOpsTests, scatterND_test5) {
     auto exp = NDArrayFactory::create('c', {8}, {0.f, 10.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
     nd4j::ops::scatter_nd op;
-    auto result = op.execute({&indices, &updates, &shape}, {}, {}, {true});
+    auto result = op.evaluate({&indices, &updates, &shape}, {}, {}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1150,7 +1150,7 @@ TEST_F(ParityOpsTests, scatterND_test6) {
     updates.linspace(1);
     nd4j::ops::scatter_nd op;
-    auto result = op.execute({&indices, &updates, &shape}, {}, {});
+    auto result = op.evaluate({&indices, &updates, &shape}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1177,7 +1177,7 @@ TEST_F(ParityOpsTests, scatterND_test7) {
     updates.linspace(1);
     nd4j::ops::scatter_nd op;
-    auto result = op.execute({&indices, &updates, &shape}, {}, {}, {true, true});
+    auto result = op.evaluate({&indices, &updates, &shape}, {}, {}, {true, true});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1198,7 +1198,7 @@ TEST_F(ParityOpsTests, scatterND_test8) {
     auto exp = NDArrayFactory::create('c', {6,4}, {1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
     nd4j::ops::scatter_nd op;
-    auto result = op.execute({&indices, &updates, &shape}, {}, {true});
+    auto result = op.evaluate({&indices, &updates, &shape}, {}, {true});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1233,7 +1233,7 @@ TEST_F(ParityOpsTests, scatterND_add_test1) {
     auto exp = NDArrayFactory::create('c', {8}, {1.f, 13.f, 3.f, 14.f, 14.f, 6.f, 7.f, 20.f});
     nd4j::ops::scatter_nd_add op;
-    auto result = op.execute({&input, &indices, &updates}, {}, {});
+    auto result = op.evaluate({&input, &indices, &updates}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1256,7 +1256,7 @@ TEST_F(ParityOpsTests, scatterND_add_test2) {
     updates.linspace(1.f);
     nd4j::ops::scatter_nd_add op;
-    auto result = op.execute({&input, &indices, &updates}, {}, {});
+    auto result = op.evaluate({&input, &indices, &updates}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1280,7 +1280,7 @@ TEST_F(ParityOpsTests, scatterND_add_test3) {
     updates.linspace(1.f);
     nd4j::ops::scatter_nd_add op;
-    auto result = op.execute({&input, &indices, &updates}, {}, {});
+    auto result = op.evaluate({&input, &indices, &updates}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1307,7 +1307,7 @@ TEST_F(ParityOpsTests, scatterND_add_test4) {
     updates.linspace(1.f);
     nd4j::ops::scatter_nd_add op;
-    auto result = op.execute({&input, &indices, &updates}, {}, {});
+    auto result = op.evaluate({&input, &indices, &updates}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1343,7 +1343,7 @@ TEST_F(ParityOpsTests, scatterND_add_test5) {
     updates.linspace(1.f);
     nd4j::ops::scatter_nd_add op;
-    auto result = op.execute({&input, &indices, &updates}, {}, {});
+    auto result = op.evaluate({&input, &indices, &updates}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1376,7 +1376,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test1) {
     auto exp = NDArrayFactory::create('c', {8}, {1.f, -9.f, 3.f, -6.f, -4.f, 6.f, 7.f, -4.f});
     nd4j::ops::scatter_nd_sub op;
-    auto result = op.execute({&input, &indices, &updates}, {}, {});
+    auto result = op.evaluate({&input, &indices, &updates}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1399,7 +1399,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test2) {
     updates.linspace(1.f);
     nd4j::ops::scatter_nd_sub op;
-    auto result = op.execute({&input, &indices, &updates}, {}, {});
+    auto result = op.evaluate({&input, &indices, &updates}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1424,7 +1424,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test3) {
     updates.linspace(1.f);
     nd4j::ops::scatter_nd_sub op;
-    auto result = op.execute({&input, &indices, &updates}, {}, {});
+    auto result = op.evaluate({&input, &indices, &updates}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1451,7 +1451,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test4) {
     updates.linspace(1.f);
     nd4j::ops::scatter_nd_sub op;
-    auto result = op.execute({&input, &indices, &updates}, {}, {});
+    auto result = op.evaluate({&input, &indices, &updates}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1487,7 +1487,7 @@ TEST_F(ParityOpsTests, scatterND_sub_test5) {
     updates.linspace(1.f);
     nd4j::ops::scatter_nd_sub op;
-    auto result = op.execute({&input, &indices, &updates}, {}, {});
+    auto result = op.evaluate({&input, &indices, &updates}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1508,7 +1508,7 @@ TEST_F(ParityOpsTests, scatterND_update_test1) {
     auto exp = NDArrayFactory::create('c', {8}, {1.f, 11.f, 3.f, 10.f, 9.f, 6.f, 7.f, 12.f});
     nd4j::ops::scatter_nd_update op;
-    auto result = op.execute({&input, &indices, &updates}, {}, {});
+    auto result = op.evaluate({&input, &indices, &updates}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1531,7 +1531,7 @@ TEST_F(ParityOpsTests, scatterND_update_test2) {
     updates.linspace(1.f);
     nd4j::ops::scatter_nd_update op;
-    auto result = op.execute({&input, &indices, &updates}, {}, {});
+    auto result = op.evaluate({&input, &indices, &updates}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1555,7 +1555,7 @@ TEST_F(ParityOpsTests, scatterND_update_test3) {
     updates.linspace(1.f);
     nd4j::ops::scatter_nd_update op;
-    auto result = op.execute({&input, &indices, &updates}, {}, {});
+    auto result = op.evaluate({&input, &indices, &updates}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());
     auto z = result->at(0);
@@ -1583,7 +1583,7 @@
TEST_F(ParityOpsTests, scatterND_update_test4) { updates.linspace(1.f); nd4j::ops::scatter_nd_update op; - auto result = op.execute({&input, &indices, &updates}, {}, {}); + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -1619,7 +1619,7 @@ TEST_F(ParityOpsTests, scatterND_update_test5) { updates.linspace(1.f); nd4j::ops::scatter_nd_update op; - auto result = op.execute({&input, &indices, &updates}, {}, {}); + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -1652,7 +1652,7 @@ TEST_F(ParityOpsTests, scatter_update_1) { NDArray exp('c', {2,2}, {30,40,10,20}, nd4j::DataType::INT32); nd4j::ops::scatter_update op; - auto results = op.execute({&x, &updates}, {}, {6, 1,1, 2,1,0}); + auto results = op.evaluate({&x, &updates}, {}, {6, 1,1, 2,1,0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); // x.printBuffer(); @@ -1672,7 +1672,7 @@ TEST_F(ParityOpsTests, scatter_update_2) { NDArray exp('c', {2,2}, {20,10,40,30}, nd4j::DataType::INT32); nd4j::ops::scatter_update op; - auto results = op.execute({&x, &updates}, {}, {6, 1,0, 2,1,0}); + auto results = op.evaluate({&x, &updates}, {}, {6, 1,0, 2,1,0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1691,7 +1691,7 @@ TEST_F(ParityOpsTests, scatter_update_3) { NDArray exp('c', {2,2,2}, {50,60,70,80,10,20,30,40}, nd4j::DataType::INT32); nd4j::ops::scatter_update op; - auto results = op.execute({&x, &updates}, {}, {6, 2,1,2, 2,1,0}); + auto results = op.evaluate({&x, &updates}, {}, {6, 2,1,2, 2,1,0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1710,7 +1710,7 @@ TEST_F(ParityOpsTests, scatter_update_4) { NDArray exp('c', {2,2,2}, {20,2,3,10,60,6,7,50}, nd4j::DataType::INT32); nd4j::ops::scatter_update op; - auto results = op.execute({&x, &updates}, {}, {6, 1,0, 2,3,0}); + auto results = op.evaluate({&x, &updates}, {}, {6, 1,0, 2,3,0}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); diff --git a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp index 0d5572ec6..e6cf01521 100644 --- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp @@ -257,7 +257,7 @@ TEST_F(RNGTests, Test_Gaussian_21) { ASSERT_FALSE(x0.equalsTo(nexp1)); ASSERT_FALSE(x0.equalsTo(nexp2)); nd4j::ops::moments op; - auto result = op.execute({&x0}, {}, {}); + auto result = op.evaluate({&x0}, {}, {}); //x0.printIndexedBuffer("X0 Normal"); //x1.printIndexedBuffer("X1 Normal"); ASSERT_TRUE(result->status() == Status::OK()); @@ -289,7 +289,7 @@ TEST_F(RNGTests, Test_Gaussian_22) { ASSERT_FALSE(x0.equalsTo(nexp1)); ASSERT_FALSE(x0.equalsTo(nexp2)); nd4j::ops::moments op; - auto result = op.execute({&x0}, {}, {}); + auto result = op.evaluate({&x0}, {}, {}); //x0.printIndexedBuffer("X0 Normal"); //x1.printIndexedBuffer("X1 Normal"); ASSERT_TRUE(result->status() == Status::OK()); @@ -412,14 +412,14 @@ TEST_F(RNGTests, Test_Truncated_21) { ASSERT_NEAR(mean.e(0), 1.f, 0.002); ASSERT_NEAR(deviation.e(0), 2.f, 0.5); nd4j::ops::moments op; - auto result = op.execute({&x0}, {}, {}, {}, false, nd4j::DataType::FLOAT32); + auto result = op.evaluate({&x0}, {}, {}, {}, {}, false); // result->at(0)->printBuffer("MEAN"); // result->at(1)->printBuffer("VARIANCE"); delete result; nd4j::ops::reduce_min minOp; nd4j::ops::reduce_max maxOp; - auto minRes = minOp.execute({&x1}, {}, {}, {}); - auto maxRes = maxOp.execute({&x0}, {}, {}, {}); + auto minRes = 
minOp.evaluate({&x1}, {}, {}, {}); + auto maxRes = maxOp.evaluate({&x0}, {}, {}, {}); // minRes->at(0)->printBuffer("MIN for Truncated"); // maxRes->at(0)->printBuffer("MAX for Truncated"); @@ -459,14 +459,14 @@ TEST_F(RNGTests, Test_Truncated_22) { ASSERT_NEAR(mean.e(0), 2.f, 0.01); ASSERT_NEAR(deviation.e(0), 4.f, 0.52); nd4j::ops::moments op; - auto result = op.execute({&x0}, {}, {}, {}, false, nd4j::DataType::FLOAT32); + auto result = op.evaluate({&x0}, {}, {}, {}, {}, false); // result->at(0)->printBuffer("MEAN"); // result->at(1)->printBuffer("VARIANCE"); delete result; nd4j::ops::reduce_min minOp; nd4j::ops::reduce_max maxOp; - auto minRes = minOp.execute({&x1}, {}, {}, {}); - auto maxRes = maxOp.execute({&x0}, {}, {}, {}); + auto minRes = minOp.evaluate({&x1}, {}, {}, {}); + auto maxRes = maxOp.evaluate({&x0}, {}, {}, {}); // minRes->at(0)->printBuffer("MIN for Truncated2"); // maxRes->at(0)->printBuffer("MAX for Truncated2"); @@ -506,14 +506,14 @@ TEST_F(RNGTests, Test_Truncated_23) { ASSERT_NEAR(mean.e(0), 0.f, 0.01); ASSERT_NEAR(deviation.e(0), 1.f, 0.5); nd4j::ops::moments op; - auto result = op.execute({&x0}, {}, {}, {}, false, nd4j::DataType::FLOAT32); + auto result = op.evaluate({&x0}); // result->at(0)->printBuffer("MEAN"); // result->at(1)->printBuffer("VARIANCE"); delete result; nd4j::ops::reduce_min minOp; nd4j::ops::reduce_max maxOp; - auto minRes = minOp.execute({&x1}, {}, {}, {}); - auto maxRes = maxOp.execute({&x0}, {}, {}, {}); + auto minRes = minOp.evaluate({&x1}, {}, {}, {}); + auto maxRes = maxOp.evaluate({&x0}, {}, {}, {}); // minRes->at(0)->printBuffer("MIN for Truncated3"); // maxRes->at(0)->printBuffer("MAX for Truncated3"); @@ -686,7 +686,7 @@ TEST_F(RNGTests, Test_GaussianDistribution_1) { nd4j::ops::random_normal op; - auto result = op.execute({&x}, {0.0, 1.0f}, {}); + auto result = op.evaluate({&x}, {0.0, 1.0f}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -707,7 +707,7 @@ TEST_F(RNGTests, Test_BernoulliDistribution_1) { nd4j::ops::random_bernoulli op; - auto result = op.execute({&x}, {0.5f}, {}); + auto result = op.evaluate({&x}, {0.5f}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -728,7 +728,7 @@ TEST_F(RNGTests, Test_ExponentialDistribution_1) { nd4j::ops::random_exponential op; - auto result = op.execute({&x}, {0.25f}, {0}); + auto result = op.evaluate({&x}, {0.25f}, {0}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -752,7 +752,7 @@ TEST_F(RNGTests, Test_ExponentialDistribution_2) { nd4j::ops::random_exponential op; - auto result = op.execute({&x, &y}, {0.25f}, {0}); + auto result = op.evaluate({&x, &y}, {0.25f}, {0}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -776,7 +776,7 @@ TEST_F(RNGTests, Test_PoissonDistribution_1) { nd4j::ops::random_poisson op; - auto result = op.execute({&x, &la}, {}, {}); + auto result = op.evaluate({&x, &la}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -796,7 +796,7 @@ TEST_F(RNGTests, Test_GammaDistribution_1) { nd4j::ops::random_gamma op; - auto result = op.execute({&x, &al}, {}, {}); + auto result = op.evaluate({&x, &al}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -817,7 +817,7 @@ TEST_F(RNGTests, Test_GammaDistribution_2) { be.assign(1.0); nd4j::ops::random_gamma op; - auto result = op.execute({&x, &al, &be}, {}, {}); + auto result = op.evaluate({&x, &al, &be}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); 
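Every hunk in these test files applies the same mechanical migration: the ResultSet-returning overloads of op.execute(...) become op.evaluate(...), and the old trailing (isInplace, DataType) pair is replaced by explicit bArgs and dArgs vectors followed by isInplace. A minimal sketch of the new call shape, assuming the usual libnd4j GTest scaffolding (illustrative only, not a hunk from this patch):

    #include "testlayers.h"
    #include <ops/declarable/CustomOperations.h>

    TEST_F(RNGTests, evaluate_call_shape_sketch) {
        auto x0 = NDArrayFactory::create<float>('c', {10, 10});

        nd4j::ops::moments op;
        // evaluate(inputs, tArgs, iArgs, bArgs, dArgs, isInplace) replaces the old
        // execute(inputs, tArgs, iArgs, bArgs, isInplace, DataType) overload and
        // returns a heap-allocated ResultSet that the caller owns, as in the tests above.
        auto result = op.evaluate({&x0}, {}, {}, {}, {}, false);
        ASSERT_EQ(Status::OK(), result->status());

        auto mean = result->at(0);      // moments produces mean ...
        auto variance = result->at(1);  // ... and variance, matching the printBuffer calls above
        delete result;
    }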
@@ -838,7 +838,7 @@ TEST_F(RNGTests, Test_GammaDistribution_3) {
     be.assign(2.0);

     nd4j::ops::random_gamma op;
-    auto result = op.execute({&x, &al, &be}, {}, {});
+    auto result = op.evaluate({&x, &al, &be}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());

     auto z = result->at(0);
@@ -857,7 +857,7 @@ TEST_F(RNGTests, Test_UniformDistribution_04) {

     nd4j::ops::randomuniform op;

-    auto result = op.execute({&x, &al, &be}, {}, {DataType::INT32});
+    auto result = op.evaluate({&x, &al, &be}, {}, {DataType::INT32});
     ASSERT_EQ(Status::OK(), result->status());

     auto z = result->at(0);
@@ -878,7 +878,7 @@ namespace nd4j {
     auto min = NDArrayFactory::create(0.0);
     auto max = NDArrayFactory::create(1.0);
     nd4j::ops::randomuniform op;
-    op.execute(*rng, {&arrayI, &min, &max}, {arrayR}, {}, {DataType::DOUBLE}, {}, false);
+    op.execute(*rng, {&arrayI, &min, &max}, {arrayR}, {}, {DataType::DOUBLE}, {}, {}, false);

     list.emplace_back(arrayR);
 }
@@ -1013,14 +1013,14 @@ TEST_F(RNGTests, test_multinomial_1) {

     nd4j::ops::random_multinomial op;
     RandomGenerator rng(1234, 1234);
-    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64}, {}, false) );
+    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64}, {}, {}, false) );
     ASSERT_TRUE(expected.isSameShape(output));
     ASSERT_TRUE(expected.equalsTo(output));

     NDArray probsZ('c', { 1, 3 }, { 0.3, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
     NDArray expectedZ('c', { 3, 3 }, { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, nd4j::DataType::INT64);

-    auto result = op.execute({ &probsZ, &samples }, { }, { 1, INT64 });
+    auto result = op.evaluate({ &probsZ, &samples }, { }, { 1, INT64 });
     auto outputZ = result->at(0);

     ASSERT_EQ(Status::OK(), result->status());
@@ -1038,7 +1038,7 @@ TEST_F(RNGTests, test_multinomial_2) {

     nd4j::ops::random_multinomial op;
     RandomGenerator rng(1234, 1234);
-    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, false));
+    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false));
     ASSERT_TRUE(expected.isSameShape(output));
     ASSERT_TRUE(expected.equalsTo(output));
@@ -1047,7 +1047,7 @@ TEST_F(RNGTests, test_multinomial_2) {
     NDArray output2('c', { 20, 3 }, nd4j::DataType::INT64);

     rng.setStates(1234, 1234);
-    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs2, &samples }, { &output2 }, {}, { 1, INT64 }, {}, false));
+    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs2, &samples }, { &output2 }, {}, { 1, INT64 }, {}, {}, false));
     ASSERT_TRUE(expected2.isSameShape(output2));
     ASSERT_TRUE(expected2.equalsTo(output2));
 }
@@ -1061,10 +1061,10 @@ TEST_F(RNGTests, test_multinomial_3) {

     RandomGenerator rng(1234, 1234);
     nd4j::ops::random_multinomial op;
-    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 0, INT64 }, {}, false));
+    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 0, INT64 }, {}, {}, false));

     rng.setStates(1234, 1234);
-    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, false));
+    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false));
     ASSERT_TRUE(expected.isSameShape(output));
     ASSERT_TRUE(expected.equalsTo(output));
 }
@@ -1078,10 +1078,10 @@ TEST_F(RNGTests, test_multinomial_4) {

     RandomGenerator rng(1234, 1234);
     nd4j::ops::random_multinomial op;
-    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 1, INT64 }, {}, false));
+    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 1, INT64 }, {}, {}, false));

     rng.setStates(1234, 1234);
-    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1, INT64 }, {}, false));
+    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1, INT64 }, {}, {}, false));
     ASSERT_TRUE(expected.isSameShape(output));
     ASSERT_TRUE(expected.equalsTo(output));
 }
@@ -1101,7 +1101,7 @@ TEST_F(RNGTests, test_multinomial_5) {
     NDArray output('c', { Samples, batchValue }, nd4j::DataType::INT64);

     RandomGenerator rng(1234, 1234);
-    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1 }, {}, false));
+    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1 }, {}, {}, false));

     auto deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false);
     auto mean = output.meanNumber();
@@ -1115,7 +1115,7 @@ TEST_F(RNGTests, test_multinomial_5) {
         ASSERT_TRUE(value >= 0 && value < ClassValue);
     }

-    auto resultR = op.execute({ &probs, &samples }, { }, { 1 });
+    auto resultR = op.evaluate({ &probs, &samples }, { }, { 1 });
     auto outputR = resultR->at(0);
     ASSERT_EQ(Status::OK(), resultR->status());
@@ -1148,7 +1148,7 @@ TEST_F(RNGTests, test_multinomial_6) {
     // without seed
     NDArray probsR('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, nd4j::DataType::FLOAT32);

-    auto resultR = op.execute({ &probsR, &samples }, { }, { 0 });
+    auto resultR = op.evaluate({ &probsR, &samples }, { }, { 0 });
     auto outputR = resultR->at(0);
     ASSERT_EQ(Status::OK(), resultR->status());
@@ -1180,7 +1180,7 @@ TEST_F(RNGTests, test_multinomial_6) {
     NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, nd4j::DataType::FLOAT32);
     NDArray output('c', { batchValue, Samples }, nd4j::DataType::INT64);

-    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, false));
+    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false));

     NDArray counts('c', { ClassValue }, { 0, 0, 0, 0, 0 }, nd4j::DataType::DOUBLE);
diff --git a/libnd4j/tests_cpu/layers_tests/ScalarTests.cpp b/libnd4j/tests_cpu/layers_tests/ScalarTests.cpp
index 310ad95ad..881a33c2e 100644
--- a/libnd4j/tests_cpu/layers_tests/ScalarTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/ScalarTests.cpp
@@ -94,7 +94,7 @@ TEST_F(ScalarTests, Test_Concat_1) {
     auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3});

     nd4j::ops::concat op;
-    auto result = op.execute({&t, &u, &v}, {}, {0});
+    auto result = op.evaluate({&t, &u, &v}, {}, {0});

     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -114,7 +114,7 @@ TEST_F(ScalarTests, Test_Concat_2) {
     auto exp = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5});

     nd4j::ops::concat op;
-    auto result = op.execute({&t, &u, &v}, {}, {0});
+    auto result = op.evaluate({&t, &u, &v}, {}, {0});

     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -135,7 +135,7 @@ TEST_F(ScalarTests, Test_Concat_3) {
     auto exp = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5});

     nd4j::ops::concat op;
-    auto result = op.execute({&t, &u, &v}, {}, {0});
+    auto result = op.evaluate({&t, &u, &v}, {}, {0});

     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -154,7 +154,7 @@ TEST_F(ScalarTests, Test_ExpandDims_1) {
     auto exp = NDArrayFactory::create('c', {1}, {2.0f});

     nd4j::ops::expand_dims op;
-    auto result = op.execute({&x}, {}, {0});
+    auto result = op.evaluate({&x}, {}, {0});

     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -171,7 +171,7 @@ TEST_F(ScalarTests, Test_Squeeze_1) {
     auto exp = NDArrayFactory::create(2.0f);

     nd4j::ops::squeeze op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());

     auto z = result->at(0);
@@ -188,7 +188,7 @@ TEST_F(ScalarTests, Test_Reshape_1) {
     auto exp = NDArrayFactory::create('c', {1, 1, 1}, {2.0f});

     nd4j::ops::reshape op;
-    auto result = op.execute({&x}, {}, {-99, 1, 1, 1});
+    auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());

     auto z = result->at(0);
@@ -205,7 +205,7 @@ TEST_F(ScalarTests, Test_Permute_1) {
     auto exp = NDArrayFactory::create(3.0f);

     nd4j::ops::permute op;
-    auto result = op.execute({&x}, {}, {0});
+    auto result = op.evaluate({&x}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());

     auto z = result->at(0);
@@ -224,7 +224,7 @@ TEST_F(ScalarTests, Test_Stack_1) {
     auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3});

     nd4j::ops::stack op;
-    auto result = op.execute({&t, &u, &v}, {}, {0});
+    auto result = op.evaluate({&t, &u, &v}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());

     auto z = result->at(0);
@@ -243,7 +243,7 @@ TEST_F(ScalarTests, Test_Stack_2) {
     auto exp = NDArrayFactory::create('c', {4, 1, 1}, {1, 2, 3, 4});

     nd4j::ops::stack op;
-    auto result = op.execute({&t, &u, &v, &w}, {}, {0});
+    auto result = op.evaluate({&t, &u, &v, &w}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());

     auto z = result->at(0);
@@ -265,7 +265,7 @@ TEST_F(ScalarTests, Test_Concat_Scalar_1) {
     auto exp = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4});

     nd4j::ops::concat op;
-    auto result = op.execute({&t, &u, &v, &w}, {}, {0});
+    auto result = op.evaluate({&t, &u, &v, &w}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());

     auto z = result->at(0);
@@ -285,7 +285,7 @@ TEST_F(ScalarTests, Test_Concat_Scalar_2) {
     auto exp = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4});

     nd4j::ops::concat op;
-    auto result = op.execute({&t, &u, &v, &w}, {}, {1});
+    auto result = op.evaluate({&t, &u, &v, &w}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());

     auto z = result->at(0);
diff --git a/libnd4j/tests_cpu/layers_tests/ShapeTests.cpp b/libnd4j/tests_cpu/layers_tests/ShapeTests.cpp
index 071c33fab..003474fab 100644
--- a/libnd4j/tests_cpu/layers_tests/ShapeTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/ShapeTests.cpp
@@ -307,7 +307,7 @@ TEST_F(ShapeTests, Tests_Transpose_119_2) {
     auto exp = x.transpose();

     nd4j::ops::transpose op;
-    auto result = op.execute({&x},{}, {});
+    auto result = op.evaluate({&x});
     ASSERT_EQ(Status::OK(), result->status());

     auto z = result->at(0);
diff --git a/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp b/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp
index ab94eeb6f..c2d8bce04 100644
--- a/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp
@@ -68,7 +68,7 @@ TEST_F(SingleDimTests, Test_Concat_1) {
     auto exp = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6});

     nd4j::ops::concat op;
-    auto result = op.execute({&x, &y}, {}, {0});
+    auto result = op.evaluate({&x, &y}, {}, {0});

     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -102,7 +102,7 @@ TEST_F(SingleDimTests, Test_ExpandDims_1) {
     auto exp = NDArrayFactory::create('c', {1, 3}, {1, 2, 3});

     nd4j::ops::expand_dims op;
-    auto result = op.execute({&x}, {}, {0});
+    auto result = op.evaluate({&x}, {}, {0});

     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -120,7 +120,7 @@ TEST_F(SingleDimTests, Test_ExpandDims_2) {
     auto exp = NDArrayFactory::create('c', {3, 1}, {1, 2, 3});

     nd4j::ops::expand_dims op;
-    auto result = op.execute({&x}, {}, {1});
+    auto result = op.evaluate({&x}, {}, {1});

     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -140,7 +140,7 @@ TEST_F(SingleDimTests, Test_Squeeze_1) {
     auto exp = NDArrayFactory::create(3.0f);

     nd4j::ops::squeeze op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});

     ASSERT_EQ(ND4J_STATUS_OK, result->status());
@@ -157,7 +157,7 @@ TEST_F(SingleDimTests, Test_Squeeze_2) {
     auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3});

     nd4j::ops::squeeze op;
-    auto result = op.execute({&x}, {}, {});
+    auto result = op.evaluate({&x}, {}, {});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());

     auto z = result->at(0);
@@ -173,7 +173,7 @@ TEST_F(SingleDimTests, Test_Reshape_1) {
     auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3});

     nd4j::ops::reshape op;
-    auto result = op.execute({&x}, {}, {-99, 3});
+    auto result = op.evaluate({&x}, {}, {-99, 3});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());

     auto z = result->at(0);
@@ -189,7 +189,7 @@ TEST_F(SingleDimTests, Test_Reshape_2) {
     auto exp = NDArrayFactory::create('c', {1, 3}, {1, 2, 3});

     nd4j::ops::reshape op;
-    auto result = op.execute({&x}, {}, {-99, 1, 3});
+    auto result = op.evaluate({&x}, {}, {-99, 1, 3});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());

     auto z = result->at(0);
@@ -206,7 +206,7 @@ TEST_F(SingleDimTests, Test_Permute_1) {
     auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3});

     nd4j::ops::permute op;
-    auto result = op.execute({&x}, {}, {0});
+    auto result = op.evaluate({&x}, {}, {0});
     ASSERT_EQ(ND4J_STATUS_OK, result->status());

     auto z = result->at(0);
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/DType.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/DType.java
index 2617ce8f6..8a316515a 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/DType.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/DType.java
@@ -23,8 +23,10 @@ public final class DType {
   public static final byte QINT16 = 16;
   public static final byte BFLOAT16 = 17;
   public static final byte UTF8 = 50;
+  public static final byte UTF16 = 51;
+  public static final byte UTF32 = 52;

-  public static final String[] names = { "INHERIT", "BOOL", "FLOAT8", "HALF", "HALF2", "FLOAT", "DOUBLE", "INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16", "UINT32", "UINT64", "QINT8", "QINT16", "BFLOAT16", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "UTF8", };
+  public static final String[] names = { "INHERIT", "BOOL", "FLOAT8", "HALF", "HALF2", "FLOAT", "DOUBLE", "INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16", "UINT32", "UINT64", "QINT8", "QINT16", "BFLOAT16", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "UTF8", "UTF16", "UTF32", };

  public static String name(int e) { return names[e]; }
 }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatNode.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatNode.java
index ef116b97b..ca411435d 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatNode.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatNode.java
@@ -75,28 +75,28 @@ public final class FlatNode extends Table {
   public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; }

   public static int createFlatNode(FlatBufferBuilder builder,
-      int id,
-      int nameOffset,
-      byte opType,
-      long opNum,
-      int propertiesOffset,
-      int inputOffset,
-      int inputPairedOffset,
-      int outputOffset,
-      int extraParamsOffset,
-      int extraIntegerOffset,
-      int extraBoolsOffset,
-      int dimensionsOffset,
-      int device,
-      int scope_id,
-      int scope_nameOffset,
-      int outputNamesOffset,
-      int opNameOffset,
-      int outputTypesOffset,
-      int scalarOffset,
-      int controlDepsOffset,
-      int varControlDepsOffset,
-      int controlDepForOffset) {
+          int id,
+          int nameOffset,
+          byte opType,
+          long opNum,
+          int propertiesOffset,
+          int inputOffset,
+          int inputPairedOffset,
+          int outputOffset,
+          int extraParamsOffset,
+          int extraIntegerOffset,
+          int extraBoolsOffset,
+          int dimensionsOffset,
+          int device,
+          int scope_id,
+          int scope_nameOffset,
+          int outputNamesOffset,
+          int opNameOffset,
+          int outputTypesOffset,
+          int scalarOffset,
+          int controlDepsOffset,
+          int varControlDepsOffset,
+          int controlDepForOffset) {
     builder.startObject(22);
     FlatNode.addOpNum(builder, opNum);
     FlatNode.addControlDepFor(builder, controlDepForOffset);
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatVariable.java
index 4845f7320..3c8f151b2 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatVariable.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatVariable.java
@@ -37,16 +37,16 @@ public final class FlatVariable extends Table {
   public int controlDepsForVarLength() { int o = __offset(22); return o != 0 ? __vector_len(o) : 0; }

   public static int createFlatVariable(FlatBufferBuilder builder,
-      int idOffset,
-      int nameOffset,
-      byte dtype,
-      int shapeOffset,
-      int ndarrayOffset,
-      int device,
-      byte variabletype,
-      int controlDepsOffset,
-      int controlDepForOpOffset,
-      int controlDepsForVarOffset) {
+          int idOffset,
+          int nameOffset,
+          byte dtype,
+          int shapeOffset,
+          int ndarrayOffset,
+          int device,
+          byte variabletype,
+          int controlDepsOffset,
+          int controlDepForOpOffset,
+          int controlDepsForVarOffset) {
     builder.startObject(10);
     FlatVariable.addControlDepsForVar(builder, controlDepsForVarOffset);
     FlatVariable.addControlDepForOp(builder, controlDepForOpOffset);
@@ -88,4 +88,3 @@ public final class FlatVariable extends Table {
   public static void finishSizePrefixedFlatVariableBuffer(FlatBufferBuilder builder, int offset) { builder.finishSizePrefixed(offset); }
 }
-
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/UIEvent.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/UIEvent.java
index f22d3c6f8..1ea69d48a 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/UIEvent.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/UIEvent.java
@@ -27,15 +27,15 @@ public final class UIEvent extends Table {
   public int plugin() { int o = __offset(20); return o != 0 ? bb.getShort(o + bb_pos) & 0xFFFF : 0; }

   public static int createUIEvent(FlatBufferBuilder builder,
-      byte eventType,
-      byte eventSubType,
-      int nameIdx,
-      long timestamp,
-      int iteration,
-      int epoch,
-      short variableId,
-      int frameIterOffset,
-      int plugin) {
+          byte eventType,
+          byte eventSubType,
+          int nameIdx,
+          long timestamp,
+          int iteration,
+          int epoch,
+          short variableId,
+          int frameIterOffset,
+          int plugin) {
     builder.startObject(9);
     UIEvent.addTimestamp(builder, timestamp);
     UIEvent.addFrameIter(builder, frameIterOffset);
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/UIOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/UIOp.java
index e86fa3f3e..874aa1a8a 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/UIOp.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/UIOp.java
@@ -32,12 +32,12 @@ public final class UIOp extends Table {
   public ByteBuffer uiLabelExtraInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 14, 1); }

   public static int createUIOp(FlatBufferBuilder builder,
-      int nameOffset,
-      int opNameOffset,
-      int inputsOffset,
-      int outputsOffset,
-      int controlDepsOffset,
-      int uiLabelExtraOffset) {
+          int nameOffset,
+          int opNameOffset,
+          int inputsOffset,
+          int outputsOffset,
+          int controlDepsOffset,
+          int uiLabelExtraOffset) {
     builder.startObject(6);
     UIOp.addUiLabelExtra(builder, uiLabelExtraOffset);
     UIOp.addControlDeps(builder, controlDepsOffset);
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/UIVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/UIVariable.java
index 8a474ff0e..b56acf8a4 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/UIVariable.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/UIVariable.java
@@ -47,19 +47,19 @@ public final class UIVariable extends Table {
   public FlatArray constantValue(FlatArray obj) { int o = __offset(28); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; }

   public static int createUIVariable(FlatBufferBuilder builder,
-      int idOffset,
-      int nameOffset,
-      byte type,
-      byte datatype,
-      int shapeOffset,
-      int controlDepsOffset,
-      int outputOfOpOffset,
-      int inputsForOpOffset,
-      int controlDepsForOpOffset,
-      int controlDepsForVarOffset,
-      int gradientVariableOffset,
-      int uiLabelExtraOffset,
-      int constantValueOffset) {
+          int idOffset,
+          int nameOffset,
+          byte type,
+          byte datatype,
+          int shapeOffset,
+          int controlDepsOffset,
+          int outputOfOpOffset,
+          int inputsForOpOffset,
+          int controlDepsForOpOffset,
+          int controlDepsForVarOffset,
+          int gradientVariableOffset,
+          int uiLabelExtraOffset,
+          int constantValueOffset) {
     builder.startObject(13);
     UIVariable.addConstantValue(builder, constantValueOffset);
     UIVariable.addUiLabelExtra(builder, uiLabelExtraOffset);
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
index 2618b2d22..d1840ab63 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
@@ -3096,6 +3096,9 @@ public native void setGraphContextInputArray(OpaqueContext ptr, int index, Point
 public native void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
 public native void setGraphContextInputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo);
 public native void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo);
+public native void setGraphContextDArguments(OpaqueContext ptr, IntPointer arguments, int numberOfArguments);
+public native void setGraphContextDArguments(OpaqueContext ptr, IntBuffer arguments, int numberOfArguments);
+public native void setGraphContextDArguments(OpaqueContext ptr, int[] arguments, int numberOfArguments);
 public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
 public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, int numberOfArguments);
 public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, int numberOfArguments);
@@ -6435,6 +6438,9 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
     public native void setIArguments(@Cast("Nd4jLong*") long[] arguments, int numberOfArguments);
     public native void setBArguments(@Cast("bool*") BooleanPointer arguments, int numberOfArguments);
     public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments);
+    public native void setDArguments(@Cast("nd4j::DataType*") IntPointer arguments, int numberOfArguments);
+    public native void setDArguments(@Cast("nd4j::DataType*") IntBuffer arguments, int numberOfArguments);
+    public native void setDArguments(@Cast("nd4j::DataType*") int[] arguments, int numberOfArguments);

     public native void setTArguments(@StdVector DoublePointer tArgs);
     public native void setTArguments(@StdVector DoubleBuffer tArgs);
@@ -6444,6 +6450,9 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
     public native void setIArguments(@Cast("Nd4jLong*") @StdVector long[] tArgs);
     public native void setBArguments(@Cast("bool*") @StdVector BooleanPointer tArgs);
     public native void setBArguments(@Cast("bool*") @StdVector boolean[] tArgs);
+    public native void setDArguments(@Cast("nd4j::DataType*") @StdVector IntPointer dArgs);
+    public native void setDArguments(@Cast("nd4j::DataType*") @StdVector IntBuffer dArgs);
+    public native void setDArguments(@Cast("nd4j::DataType*") @StdVector int[] dArgs);

     public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer);
@@ -6547,6 +6556,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
     public native @StdVector DoublePointer getTArguments();
     public native @StdVector IntPointer getIArguments();
    public native @Cast("bool*") @StdVector BooleanPointer getBArguments();
+    public native @Cast("nd4j::DataType*") @StdVector IntPointer getDArguments();
     public native @StdVector IntPointer getAxis();

     public native @Cast("samediff::Engine") int engine();
@@ -6554,6 +6564,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
     public native @Cast("size_t") long numT();
     public native @Cast("size_t") long numI();
     public native @Cast("size_t") long numB();
+    public native @Cast("size_t") long numD();

     public native IntIntPair input(int idx);
@@ -9418,39 +9429,43 @@ public static final int PREALLOC_SIZE = 33554432;
          */
         public native @Cast("Nd4jStatus") int execute(Context block);

-        public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
-        public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
-        public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
-        public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
-        public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
-        public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
-        public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
-        public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
-        public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
-        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
-        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
-        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
-        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
-        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
-        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
+        public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs);
+
+        public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/);
+        public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
+        public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/);
+        public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
+        public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/);
+        public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
+        public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/);
+        public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/);
+        public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/);
+
+
+        public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs);
+
+        public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/);
+        public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
+        public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/);
+        public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
+        public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/);
+        public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
+        public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/);
+        public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/);
+        public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/);
+
+        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
+        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
+        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
+        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
+        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
+        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
+        public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs);

         public native ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/);
         public native ResultSet execute(@Const @ByRef OpArgsHolder holder);
@@ -9649,8 +9664,9 @@
     /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
     public BooleanOp(Pointer p) { super(p); }

-        public native @Cast("bool") boolean evaluate(@ByRef NDArrayVector args);
-        public native @Cast("bool") boolean evaluate(@ByRef Context block);
+
+        public native @Cast("bool") boolean verify(@Const @ByRef NDArrayVector args);
+        public native @Cast("bool") boolean verify(@ByRef Context block);

         public native @Cast("Nd4jStatus") int execute(Context block);
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
index c8a3b693e..40a2f5236 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
@@ -3099,6 +3099,9 @@ public native void setGraphContextInputArray(OpaqueContext ptr, int index, Point
 public native void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
 public native void setGraphContextInputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo);
 public native void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo);
+public native void setGraphContextDArguments(OpaqueContext ptr, IntPointer arguments, int numberOfArguments);
+public native void setGraphContextDArguments(OpaqueContext ptr, IntBuffer arguments, int numberOfArguments);
+public native void setGraphContextDArguments(OpaqueContext ptr, int[] arguments, int numberOfArguments);
 public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
 public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, int numberOfArguments);
 public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, int numberOfArguments);
@@ -6438,6 +6441,9 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
     public native void setIArguments(@Cast("Nd4jLong*") long[] arguments, int numberOfArguments);
     public native void setBArguments(@Cast("bool*") BooleanPointer arguments, int numberOfArguments);
     public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments);
+    public native void setDArguments(@Cast("nd4j::DataType*") IntPointer arguments, int numberOfArguments);
+    public native void setDArguments(@Cast("nd4j::DataType*") IntBuffer arguments, int numberOfArguments);
+    public native void setDArguments(@Cast("nd4j::DataType*") int[] arguments, int numberOfArguments);

     public native void setTArguments(@StdVector DoublePointer tArgs);
     public native void setTArguments(@StdVector DoubleBuffer tArgs);
@@ -6447,6 +6453,9 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
     public native void setIArguments(@Cast("Nd4jLong*") @StdVector long[] tArgs);
     public native void setBArguments(@Cast("bool*") @StdVector BooleanPointer tArgs);
     public native void setBArguments(@Cast("bool*") @StdVector boolean[] tArgs);
+    public native void setDArguments(@Cast("nd4j::DataType*") @StdVector IntPointer dArgs);
+    public native void setDArguments(@Cast("nd4j::DataType*") @StdVector IntBuffer dArgs);
+    public native void setDArguments(@Cast("nd4j::DataType*") @StdVector int[] dArgs);

     public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer);
@@ -6550,6 +6559,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
     public native @StdVector DoublePointer getTArguments();
     public native @StdVector IntPointer getIArguments();
     public native @Cast("bool*") @StdVector BooleanPointer getBArguments();
+    public native @Cast("nd4j::DataType*") @StdVector IntPointer getDArguments();
     public native @StdVector IntPointer getAxis();

     public native @Cast("samediff::Engine") int engine();
@@ -6557,6 +6567,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
     public native @Cast("size_t") long numT();
     public native @Cast("size_t") long numI();
     public native @Cast("size_t") long numB();
+    public native @Cast("size_t") long numD();

     public native IntIntPair input(int idx);
@@ -11130,7 +11141,9 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 // #define INPUT_LIST(INDEX) reinterpret_cast(block.getVariable(INDEX)->getNDArrayList())

+// #define D_ARG(INDEX) block.getDArguments()->at(INDEX)
 // #define INT_ARG(INDEX) block.getIArguments()->at(INDEX)
+// #define I_ARG(INDEX) INT_ARG(INDEX)
 // #define T_ARG(INDEX) block.getTArguments()->at(INDEX)
 // #define B_ARG(INDEX) block.getBArguments()->at(INDEX)
@@ -11629,39 +11642,43 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
          */
         public native @Cast("Nd4jStatus") int execute(Context block);

-        public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
-        public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
-        public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
-        public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/);
-        public native ResultSet execute(@Const @ByRef NDArrayVector inputs, @StdVector
double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); - public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs); - public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); - public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs); - public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); - public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") 
@StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); - 
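The overload churn in this region is mechanical: the regenerated bindings retire the old ResultSet-returning execute() family shown being removed here, and thread a new dArgs vector of nd4j::DataType values through every evaluate()/execute() variant. A minimal sketch of how those D-arguments are meant to be supplied from user code, assuming the Java-side additions this series makes (addDArgument on DynamicCustomOp, dtype-aware ones_as); nothing in this snippet is verbatim from the bindings around it:

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.DynamicCustomOp;
    import org.nd4j.linalg.factory.Nd4j;

    INDArray source = Nd4j.create(DataType.FLOAT, 2, 2);
    DynamicCustomOp op = DynamicCustomOp.builder("ones_as")
            .addInputs(source)
            .build();
    op.addDArgument(DataType.INT32);    // requested output type rides along as a D-argument
    INDArray result = Nd4j.exec(op)[0]; // expected: ones, but with INT32 data type
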
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @ByRef NDArrayVector inputs, @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs); + public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs); + + public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs); + public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs); + public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs); + public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + + + public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs); + + public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector 
IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs); + public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs); + public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs); + public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + + public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); + public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); + public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); + public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer 
iArgs, @Cast("bool*") @StdVector boolean[] bArgs); + public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); + public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); + public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("nd4j::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); + public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs); + public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("nd4j::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); + public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); + public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("nd4j::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("nd4j::DataType") int type/*=nd4j::DataType::FLOAT32*/); + public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs); public native ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/); public native ResultSet execute(@Const @ByRef OpArgsHolder holder); @@ -11860,8 +11877,9 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. 
*/ public BooleanOp(Pointer p) { super(p); } - public native @Cast("bool") boolean evaluate(@ByRef NDArrayVector args); - public native @Cast("bool") boolean evaluate(@ByRef Context block); + + public native @Cast("bool") boolean verify(@Const @ByRef NDArrayVector args); + public native @Cast("bool") boolean verify(@ByRef Context block); public native @Cast("Nd4jStatus") int execute(Context block); diff --git a/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/graph/grpc/GraphInferenceServerGrpc.java b/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/graph/grpc/GraphInferenceServerGrpc.java index 61ddd7c49..20cffe496 100644 --- a/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/graph/grpc/GraphInferenceServerGrpc.java +++ b/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/graph/grpc/GraphInferenceServerGrpc.java @@ -1,4 +1,4 @@ -//Generated by flatc compiler (version 1.9.0) +//Generated by flatc compiler (version 1.10.0) //If you make any local changes, they will be lost //source: graph.fbs @@ -31,17 +31,17 @@ public final class GraphInferenceServerGrpc { private GraphInferenceServerGrpc() {} - public static final String SERVICE_NAME = "nd4j.graph.GraphInferenceServer"; + public static final String SERVICE_NAME = "org.nd4j.graph.GraphInferenceServer"; // Static method descriptors that strictly reflect the proto. @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901") @Deprecated // Use {@link #getRegisterGraphMethod()} instead. public static final io.grpc.MethodDescriptor METHOD_REGISTER_GRAPH = getRegisterGraphMethod(); - + private static volatile io.grpc.MethodDescriptor getRegisterGraphMethod; - + private static volatile FlatbuffersUtils.FBExtactor extractorOfFlatGraph; private static FlatbuffersUtils.FBExtactor getExtractorOfFlatGraph() { if (extractorOfFlatGraph != null) return extractorOfFlatGraph; @@ -55,7 +55,7 @@ public final class GraphInferenceServerGrpc { return extractorOfFlatGraph; } } - + private static volatile FlatbuffersUtils.FBExtactor extractorOfFlatResponse; private static FlatbuffersUtils.FBExtactor getExtractorOfFlatResponse() { if (extractorOfFlatResponse != null) return extractorOfFlatResponse; @@ -69,7 +69,7 @@ public final class GraphInferenceServerGrpc { return extractorOfFlatResponse; } } - + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901") public static io.grpc.MethodDescriptor getRegisterGraphMethod() { @@ -77,11 +77,11 @@ public final class GraphInferenceServerGrpc { if ((getRegisterGraphMethod = GraphInferenceServerGrpc.getRegisterGraphMethod) == null) { synchronized (GraphInferenceServerGrpc.class) { if ((getRegisterGraphMethod = GraphInferenceServerGrpc.getRegisterGraphMethod) == null) { - GraphInferenceServerGrpc.getRegisterGraphMethod = getRegisterGraphMethod = + GraphInferenceServerGrpc.getRegisterGraphMethod = getRegisterGraphMethod = io.grpc.MethodDescriptor.newBuilder() .setType(io.grpc.MethodDescriptor.MethodType.UNARY) .setFullMethodName(generateFullMethodName( - "nd4j.graph.GraphInferenceServer", "RegisterGraph")) + "org.nd4j.graph.GraphInferenceServer", "RegisterGraph")) .setSampledToLocalTracing(true) .setRequestMarshaller(FlatbuffersUtils.marshaller( org.nd4j.graph.FlatGraph.class, getExtractorOfFlatGraph())) @@ -94,15 +94,15 @@ public final class GraphInferenceServerGrpc { } return getRegisterGraphMethod; } - + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901") - @Deprecated // Use {@link #getForgetGraphMethod()} instead. 
+ @Deprecated // Use {@link #getForgetGraphMethod()} instead. public static final io.grpc.MethodDescriptor METHOD_FORGET_GRAPH = getForgetGraphMethod(); - + private static volatile io.grpc.MethodDescriptor getForgetGraphMethod; - + private static volatile FlatbuffersUtils.FBExtactor extractorOfFlatDropRequest; private static FlatbuffersUtils.FBExtactor getExtractorOfFlatDropRequest() { if (extractorOfFlatDropRequest != null) return extractorOfFlatDropRequest; @@ -116,7 +116,7 @@ public final class GraphInferenceServerGrpc { return extractorOfFlatDropRequest; } } - + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901") public static io.grpc.MethodDescriptor getForgetGraphMethod() { @@ -124,11 +124,11 @@ public final class GraphInferenceServerGrpc { if ((getForgetGraphMethod = GraphInferenceServerGrpc.getForgetGraphMethod) == null) { synchronized (GraphInferenceServerGrpc.class) { if ((getForgetGraphMethod = GraphInferenceServerGrpc.getForgetGraphMethod) == null) { - GraphInferenceServerGrpc.getForgetGraphMethod = getForgetGraphMethod = + GraphInferenceServerGrpc.getForgetGraphMethod = getForgetGraphMethod = io.grpc.MethodDescriptor.newBuilder() .setType(io.grpc.MethodDescriptor.MethodType.UNARY) .setFullMethodName(generateFullMethodName( - "nd4j.graph.GraphInferenceServer", "ForgetGraph")) + "org.nd4j.graph.GraphInferenceServer", "ForgetGraph")) .setSampledToLocalTracing(true) .setRequestMarshaller(FlatbuffersUtils.marshaller( org.nd4j.graph.FlatDropRequest.class, getExtractorOfFlatDropRequest())) @@ -141,15 +141,48 @@ public final class GraphInferenceServerGrpc { } return getForgetGraphMethod; } - + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901") - @Deprecated // Use {@link #getInferenceRequestMethod()} instead. + @Deprecated // Use {@link #getReplaceGraphMethod()} instead. + public static final io.grpc.MethodDescriptor METHOD_REPLACE_GRAPH = getReplaceGraphMethod(); + + private static volatile io.grpc.MethodDescriptor getReplaceGraphMethod; + + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901") + public static io.grpc.MethodDescriptor getReplaceGraphMethod() { + io.grpc.MethodDescriptor getReplaceGraphMethod; + if ((getReplaceGraphMethod = GraphInferenceServerGrpc.getReplaceGraphMethod) == null) { + synchronized (GraphInferenceServerGrpc.class) { + if ((getReplaceGraphMethod = GraphInferenceServerGrpc.getReplaceGraphMethod) == null) { + GraphInferenceServerGrpc.getReplaceGraphMethod = getReplaceGraphMethod = + io.grpc.MethodDescriptor.newBuilder() + .setType(io.grpc.MethodDescriptor.MethodType.UNARY) + .setFullMethodName(generateFullMethodName( + "org.nd4j.graph.GraphInferenceServer", "ReplaceGraph")) + .setSampledToLocalTracing(true) + .setRequestMarshaller(FlatbuffersUtils.marshaller( + org.nd4j.graph.FlatGraph.class, getExtractorOfFlatGraph())) + .setResponseMarshaller(FlatbuffersUtils.marshaller( + org.nd4j.graph.FlatResponse.class, getExtractorOfFlatResponse())) + .setSchemaDescriptor(null) + .build(); + } + } + } + return getReplaceGraphMethod; + } + + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901") + @Deprecated // Use {@link #getInferenceRequestMethod()} instead. 
public static final io.grpc.MethodDescriptor METHOD_INFERENCE_REQUEST = getInferenceRequestMethod(); - + private static volatile io.grpc.MethodDescriptor getInferenceRequestMethod; - + private static volatile FlatbuffersUtils.FBExtactor extractorOfFlatInferenceRequest; private static FlatbuffersUtils.FBExtactor getExtractorOfFlatInferenceRequest() { if (extractorOfFlatInferenceRequest != null) return extractorOfFlatInferenceRequest; @@ -163,7 +196,7 @@ public final class GraphInferenceServerGrpc { return extractorOfFlatInferenceRequest; } } - + private static volatile FlatbuffersUtils.FBExtactor extractorOfFlatResult; private static FlatbuffersUtils.FBExtactor getExtractorOfFlatResult() { if (extractorOfFlatResult != null) return extractorOfFlatResult; @@ -177,7 +210,7 @@ public final class GraphInferenceServerGrpc { return extractorOfFlatResult; } } - + @io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901") public static io.grpc.MethodDescriptor getInferenceRequestMethod() { @@ -185,11 +218,11 @@ public final class GraphInferenceServerGrpc { if ((getInferenceRequestMethod = GraphInferenceServerGrpc.getInferenceRequestMethod) == null) { synchronized (GraphInferenceServerGrpc.class) { if ((getInferenceRequestMethod = GraphInferenceServerGrpc.getInferenceRequestMethod) == null) { - GraphInferenceServerGrpc.getInferenceRequestMethod = getInferenceRequestMethod = + GraphInferenceServerGrpc.getInferenceRequestMethod = getInferenceRequestMethod = io.grpc.MethodDescriptor.newBuilder() .setType(io.grpc.MethodDescriptor.MethodType.UNARY) .setFullMethodName(generateFullMethodName( - "nd4j.graph.GraphInferenceServer", "InferenceRequest")) + "org.nd4j.graph.GraphInferenceServer", "InferenceRequest")) .setSampledToLocalTracing(true) .setRequestMarshaller(FlatbuffersUtils.marshaller( org.nd4j.graph.FlatInferenceRequest.class, getExtractorOfFlatInferenceRequest())) @@ -202,14 +235,14 @@ public final class GraphInferenceServerGrpc { } return getInferenceRequestMethod; } - + /** * Creates a new async stub that supports all call types for the service */ public static GraphInferenceServerStub newStub(io.grpc.Channel channel) { return new GraphInferenceServerStub(channel); } - + /** * Creates a new blocking-style stub that supports unary and streaming output calls on the service */ @@ -217,7 +250,7 @@ public final class GraphInferenceServerGrpc { io.grpc.Channel channel) { return new GraphInferenceServerBlockingStub(channel); } - + /** * Creates a new ListenableFuture-style stub that supports unary calls on the service */ @@ -225,32 +258,39 @@ public final class GraphInferenceServerGrpc { io.grpc.Channel channel) { return new GraphInferenceServerFutureStub(channel); } - + /** */ public static abstract class GraphInferenceServerImplBase implements io.grpc.BindableService { - + /** */ public void registerGraph(org.nd4j.graph.FlatGraph request, io.grpc.stub.StreamObserver responseObserver) { asyncUnimplementedUnaryCall(getRegisterGraphMethod(), responseObserver); } - + /** */ public void forgetGraph(org.nd4j.graph.FlatDropRequest request, io.grpc.stub.StreamObserver responseObserver) { asyncUnimplementedUnaryCall(getForgetGraphMethod(), responseObserver); } - + + /** + */ + public void replaceGraph(org.nd4j.graph.FlatGraph request, + io.grpc.stub.StreamObserver responseObserver) { + asyncUnimplementedUnaryCall(getReplaceGraphMethod(), responseObserver); + } + /** */ public void inferenceRequest(org.nd4j.graph.FlatInferenceRequest request, io.grpc.stub.StreamObserver responseObserver) { 
asyncUnimplementedUnaryCall(getInferenceRequestMethod(), responseObserver); } - + @Override public final io.grpc.ServerServiceDefinition bindService() { return io.grpc.ServerServiceDefinition.builder(getServiceDescriptor()) .addMethod( @@ -267,6 +307,13 @@ public final class GraphInferenceServerGrpc { org.nd4j.graph.FlatDropRequest, org.nd4j.graph.FlatResponse>( this, METHODID_FORGET_GRAPH))) + .addMethod( + getReplaceGraphMethod(), + asyncUnaryCall( + new MethodHandlers< + org.nd4j.graph.FlatGraph, + org.nd4j.graph.FlatResponse>( + this, METHODID_REPLACE_GRAPH))) .addMethod( getInferenceRequestMethod(), asyncUnaryCall( @@ -277,25 +324,25 @@ public final class GraphInferenceServerGrpc { .build(); } } - + /** */ public static final class GraphInferenceServerStub extends io.grpc.stub.AbstractStub { private GraphInferenceServerStub(io.grpc.Channel channel) { super(channel); } - + private GraphInferenceServerStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { super(channel, callOptions); } - + @Override protected GraphInferenceServerStub build(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { return new GraphInferenceServerStub(channel, callOptions); } - + /** */ public void registerGraph(org.nd4j.graph.FlatGraph request, @@ -303,7 +350,7 @@ public final class GraphInferenceServerGrpc { asyncUnaryCall( getChannel().newCall(getRegisterGraphMethod(), getCallOptions()), request, responseObserver); } - + /** */ public void forgetGraph(org.nd4j.graph.FlatDropRequest request, @@ -311,7 +358,15 @@ public final class GraphInferenceServerGrpc { asyncUnaryCall( getChannel().newCall(getForgetGraphMethod(), getCallOptions()), request, responseObserver); } - + + /** + */ + public void replaceGraph(org.nd4j.graph.FlatGraph request, + io.grpc.stub.StreamObserver responseObserver) { + asyncUnaryCall( + getChannel().newCall(getReplaceGraphMethod(), getCallOptions()), request, responseObserver); + } + /** */ public void inferenceRequest(org.nd4j.graph.FlatInferenceRequest request, @@ -320,39 +375,46 @@ public final class GraphInferenceServerGrpc { getChannel().newCall(getInferenceRequestMethod(), getCallOptions()), request, responseObserver); } } - + /** */ public static final class GraphInferenceServerBlockingStub extends io.grpc.stub.AbstractStub { private GraphInferenceServerBlockingStub(io.grpc.Channel channel) { super(channel); } - + private GraphInferenceServerBlockingStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { super(channel, callOptions); } - + @Override protected GraphInferenceServerBlockingStub build(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { return new GraphInferenceServerBlockingStub(channel, callOptions); } - + /** */ public org.nd4j.graph.FlatResponse registerGraph(org.nd4j.graph.FlatGraph request) { return blockingUnaryCall( getChannel(), getRegisterGraphMethod(), getCallOptions(), request); } - + /** */ public org.nd4j.graph.FlatResponse forgetGraph(org.nd4j.graph.FlatDropRequest request) { return blockingUnaryCall( getChannel(), getForgetGraphMethod(), getCallOptions(), request); } - + + /** + */ + public org.nd4j.graph.FlatResponse replaceGraph(org.nd4j.graph.FlatGraph request) { + return blockingUnaryCall( + getChannel(), getReplaceGraphMethod(), getCallOptions(), request); + } + /** */ public org.nd4j.graph.FlatResult inferenceRequest(org.nd4j.graph.FlatInferenceRequest request) { @@ -360,25 +422,25 @@ public final class GraphInferenceServerGrpc { getChannel(), getInferenceRequestMethod(), getCallOptions(), request); } } - + /** */ 
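The regenerated stubs add ReplaceGraph as a third unary call alongside RegisterGraph and ForgetGraph, reusing the FlatGraph/FlatResponse marshallers. A hypothetical client invocation through the blocking stub; the channel setup, host and port are placeholders, and flatGraph stands for an already-built org.nd4j.graph.FlatGraph:

    import io.grpc.ManagedChannel;
    import io.grpc.ManagedChannelBuilder;
    import org.nd4j.graph.FlatResponse;

    ManagedChannel channel = ManagedChannelBuilder
            .forAddress("localhost", 40123)  // placeholder endpoint
            .usePlaintext()
            .build();
    GraphInferenceServerGrpc.GraphInferenceServerBlockingStub stub =
            GraphInferenceServerGrpc.newBlockingStub(channel);
    FlatResponse ack = stub.replaceGraph(flatGraph); // swaps a registered graph in one round trip
    channel.shutdown();
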
public static final class GraphInferenceServerFutureStub extends io.grpc.stub.AbstractStub { private GraphInferenceServerFutureStub(io.grpc.Channel channel) { super(channel); } - + private GraphInferenceServerFutureStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { super(channel, callOptions); } - + @Override protected GraphInferenceServerFutureStub build(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { return new GraphInferenceServerFutureStub(channel, callOptions); } - + /** */ public com.google.common.util.concurrent.ListenableFuture registerGraph( @@ -386,7 +448,7 @@ public final class GraphInferenceServerGrpc { return futureUnaryCall( getChannel().newCall(getRegisterGraphMethod(), getCallOptions()), request); } - + /** */ public com.google.common.util.concurrent.ListenableFuture forgetGraph( @@ -394,7 +456,15 @@ public final class GraphInferenceServerGrpc { return futureUnaryCall( getChannel().newCall(getForgetGraphMethod(), getCallOptions()), request); } - + + /** + */ + public com.google.common.util.concurrent.ListenableFuture replaceGraph( + org.nd4j.graph.FlatGraph request) { + return futureUnaryCall( + getChannel().newCall(getReplaceGraphMethod(), getCallOptions()), request); + } + /** */ public com.google.common.util.concurrent.ListenableFuture inferenceRequest( @@ -403,11 +473,12 @@ public final class GraphInferenceServerGrpc { getChannel().newCall(getInferenceRequestMethod(), getCallOptions()), request); } } - + private static final int METHODID_REGISTER_GRAPH = 0; private static final int METHODID_FORGET_GRAPH = 1; - private static final int METHODID_INFERENCE_REQUEST = 2; - + private static final int METHODID_REPLACE_GRAPH = 2; + private static final int METHODID_INFERENCE_REQUEST = 3; + private static final class MethodHandlers implements io.grpc.stub.ServerCalls.UnaryMethod, io.grpc.stub.ServerCalls.ServerStreamingMethod, @@ -415,12 +486,12 @@ public final class GraphInferenceServerGrpc { io.grpc.stub.ServerCalls.BidiStreamingMethod { private final GraphInferenceServerImplBase serviceImpl; private final int methodId; - + MethodHandlers(GraphInferenceServerImplBase serviceImpl, int methodId) { this.serviceImpl = serviceImpl; this.methodId = methodId; } - + @Override @SuppressWarnings("unchecked") public void invoke(Req request, io.grpc.stub.StreamObserver responseObserver) { @@ -433,6 +504,10 @@ public final class GraphInferenceServerGrpc { serviceImpl.forgetGraph((org.nd4j.graph.FlatDropRequest) request, (io.grpc.stub.StreamObserver) responseObserver); break; + case METHODID_REPLACE_GRAPH: + serviceImpl.replaceGraph((org.nd4j.graph.FlatGraph) request, + (io.grpc.stub.StreamObserver) responseObserver); + break; case METHODID_INFERENCE_REQUEST: serviceImpl.inferenceRequest((org.nd4j.graph.FlatInferenceRequest) request, (io.grpc.stub.StreamObserver) responseObserver); @@ -465,6 +540,7 @@ public final class GraphInferenceServerGrpc { .setSchemaDescriptor(null) .addMethod(getRegisterGraphMethod()) .addMethod(getForgetGraphMethod()) + .addMethod(getReplaceGraphMethod()) .addMethod(getInferenceRequestMethod()) .build(); } From 5d98cfcf4765fe5554a06c173934c16f6ab7defd Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 30 Jan 2020 18:46:12 +0300 Subject: [PATCH 07/17] Configurable DataType for ops (#201) * initial commit Signed-off-by: raver119 * - one more test for OneHot with dtype - one more signature in Nd4j Signed-off-by: raver119 * ones_as/zeros_as now accept dtype Signed-off-by: raver119 * one more test Signed-off-by: raver119 * - more updates for 
configurable data types - ones_as/zeros_as java side + tests Signed-off-by: raver119 * few c++ tests fixed Signed-off-by: raver119 * few more changes around DArgs Signed-off-by: raver119 --- libnd4j/blas/NativeOps.h | 2 +- libnd4j/blas/cpu/NativeOps.cpp | 9 ++- libnd4j/blas/cuda/NativeOps.cu | 9 ++- .../graph/generated/nd4j/graph/FlatNode.cs | 20 ++++++- .../graph/generated/nd4j/graph/FlatNode.java | 15 ++++- .../graph/generated/nd4j/graph/FlatNode.py | 26 ++++++++- .../include/graph/generated/node_generated.h | 21 +++++-- .../include/graph/generated/node_generated.js | 56 ++++++++++++++++++- libnd4j/include/graph/impl/Node.cpp | 18 +++--- libnd4j/include/graph/scheme/node.fbs | 4 +- .../declarable/generic/parity_ops/ones_as.cpp | 14 ++++- .../generic/parity_ops/zeros_as.cpp | 13 ++++- .../ops/declarable/headers/parity_ops.h | 4 +- .../layers_tests/DeclarableOpsTests16.cpp | 2 +- .../layers_tests/DeclarableOpsTests8.cpp | 18 ++++++ .../layers_tests/JavaInteropTests.cpp | 4 +- .../tests_cpu/layers_tests/NativeOpsTests.cpp | 2 +- .../org/nd4j/autodiff/samediff/SameDiff.java | 2 +- .../samediff/serde/FlatBuffersMapper.java | 22 +++++++- .../main/java/org/nd4j/graph/FlatNode.java | 15 ++++- .../DifferentialFunctionClassHolder.java | 1 + .../nd4j/linalg/api/ops/BaseOpContext.java | 22 ++++++-- .../org/nd4j/linalg/api/ops/CustomOp.java | 11 +++- .../nd4j/linalg/api/ops/DynamicCustomOp.java | 30 ++++++++++ .../org/nd4j/linalg/api/ops/OpContext.java | 9 ++- .../linalg/api/ops/custom/ScatterUpdate.java | 16 ++++++ .../linalg/api/ops/impl/shape/OneHot.java | 3 + .../linalg/api/ops/impl/shape/OnesLike.java | 20 +++++++ .../java/org/nd4j/linalg/factory/Nd4j.java | 10 ++++ .../java/org/nd4j/nativeblas/NativeOps.java | 3 +- .../ops/executioner/CudaExecutioner.java | 10 +++- .../ops/executioner/CudaOpContext.java | 18 ++++-- .../java/org/nd4j/nativeblas/Nd4jCuda.java | 12 ++-- .../cpu/nativecpu/ops/CpuOpContext.java | 19 +++++-- .../nativecpu/ops/NativeOpExecutioner.java | 11 +++- .../java/org/nd4j/nativeblas/Nd4jCpu.java | 32 +++++------ .../opvalidation/MiscOpValidation.java | 22 +++++++- .../nd4j/linalg/custom/CustomOpsTests.java | 10 ++++ 38 files changed, 448 insertions(+), 87 deletions(-) diff --git a/libnd4j/blas/NativeOps.h b/libnd4j/blas/NativeOps.h index 141ecb6ec..01b656861 100755 --- a/libnd4j/blas/NativeOps.h +++ b/libnd4j/blas/NativeOps.h @@ -1518,7 +1518,7 @@ ND4J_EXPORT int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPoi typedef nd4j::ShapeList OpaqueShapeList; ND4J_EXPORT OpaqueShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs); -ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs); +ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs); ND4J_EXPORT Nd4jLong getShapeListSize(OpaqueShapeList* list); ND4J_EXPORT Nd4jLong* getShape(OpaqueShapeList* list, Nd4jLong i); diff --git a/libnd4j/blas/cpu/NativeOps.cpp b/libnd4j/blas/cpu/NativeOps.cpp index 3ba971aa5..0410c833b 100644 --- a/libnd4j/blas/cpu/NativeOps.cpp +++ 
b/libnd4j/blas/cpu/NativeOps.cpp @@ -1974,7 +1974,7 @@ void deleteShapeList(Nd4jPointer shapeList) { delete list; } -nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) { +nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) { nd4j::graph::VariableSpace varSpace; Context block(2, &varSpace); nd4j::ShapeList inShapes; @@ -1988,6 +1988,9 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D for (int e = 0; e < numBArgs; e++) block.getBArguments()->push_back(bArgs[e]); + for (int e = 0; e < numDArgs; e++) + block.getDArguments()->push_back((nd4j::DataType) dArgs[e]); + for (int e = 0; e < numInputShapes; e++) { auto shape_ = reinterpret_cast<Nd4jLong *>(inputShapes[e]); @@ -2015,11 +2018,11 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D return shapeList; } -nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) { +nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) { try { auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash); - return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs); + return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index 45de82b32..d65dcaed5 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -2684,7 +2684,7 @@ const char* getAllCustomOps() { } -nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) { +nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) { nd4j::graph::VariableSpace varSpace; Context block(2, &varSpace); nd4j::ShapeList inShapes; @@ -2698,6 +2698,9 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D for (int e = 0; e < numBArgs; e++) block.getBArguments()->push_back(bArgs[e]); + for (int e = 0; e < numDArgs; e++) + block.getDArguments()->push_back((nd4j::DataType) dArgs[e]); + for (int e = 0; e < numInputShapes; e++) {
auto shape_ = reinterpret_cast<Nd4jLong *>(inputShapes[e]); @@ -2722,12 +2725,12 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D return shapeList; } -nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) { +nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) { try { auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash); return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, - iArgs, numIArgs, bArgs, numBArgs); + iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs); } catch (std::exception &e) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs index 7fa9722db..c94e0fcc4 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.cs @@ -112,6 +112,14 @@ public struct FlatNode : IFlatbufferObject public int VarControlDepsLength { get { int o = __p.__offset(44); return o != 0 ? __p.__vector_len(o) : 0; } } public string ControlDepFor(int j) { int o = __p.__offset(46); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; } public int ControlDepForLength { get { int o = __p.__offset(46); return o != 0 ? __p.__vector_len(o) : 0; } } + public DType ExtraTypes(int j) { int o = __p.__offset(48); return o != 0 ? (DType)__p.bb.GetSbyte(__p.__vector(o) + j * 1) : (DType)0; } + public int ExtraTypesLength { get { int o = __p.__offset(48); return o != 0 ? __p.__vector_len(o) : 0; } } +#if ENABLE_SPAN_T + public Span<byte> GetExtraTypesBytes() { return __p.__vector_as_span(48); } +#else + public ArraySegment<byte>?
GetExtraTypesBytes() { return __p.__vector_as_arraysegment(48); } +#endif + public DType[] GetExtraTypesArray() { return __p.__vector_as_array<DType>(48); } public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder, int id = 0, @@ -135,9 +143,11 @@ public struct FlatNode : IFlatbufferObject Offset<FlatArray> scalarOffset = default(Offset<FlatArray>), VectorOffset controlDepsOffset = default(VectorOffset), VectorOffset varControlDepsOffset = default(VectorOffset), - VectorOffset controlDepForOffset = default(VectorOffset)) { - builder.StartObject(22); + VectorOffset controlDepForOffset = default(VectorOffset), + VectorOffset extraTypesOffset = default(VectorOffset)) { + builder.StartObject(23); FlatNode.AddOpNum(builder, opNum); + FlatNode.AddExtraTypes(builder, extraTypesOffset); FlatNode.AddControlDepFor(builder, controlDepForOffset); FlatNode.AddVarControlDeps(builder, varControlDepsOffset); FlatNode.AddControlDeps(builder, controlDepsOffset); @@ -162,7 +172,7 @@ public struct FlatNode : IFlatbufferObject return FlatNode.EndFlatNode(builder); } - public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(22); } + public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(23); } public static void AddId(FlatBufferBuilder builder, int id) { builder.AddInt(0, id, 0); } public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); } public static void AddOpType(FlatBufferBuilder builder, OpType opType) { builder.AddSbyte(2, (sbyte)opType, 0); } @@ -224,6 +234,10 @@ public struct FlatNode : IFlatbufferObject public static VectorOffset CreateControlDepForVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); } public static VectorOffset CreateControlDepForVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); } public static void StartControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); } + public static void AddExtraTypes(FlatBufferBuilder builder, VectorOffset extraTypesOffset) { builder.AddOffset(22, extraTypesOffset.Value, 0); } + public static VectorOffset CreateExtraTypesVector(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); for (int i = data.Length - 1; i >= 0; i--) builder.AddSbyte((sbyte)data[i]); return builder.EndVector(); } + public static VectorOffset CreateExtraTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); } + public static void StartExtraTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); } public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) { int o = builder.EndObject(); return new Offset<FlatNode>(o); diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.java b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.java index 8a72cc00a..2fe0a0ee9 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.java +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.java @@ -72,6 +72,10 @@ public final class FlatNode extends Table { public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; } public String controlDepFor(int j) { int o = __offset(46); return o != 0 ?
__string(__vector(o) + j * 4) : null; } public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; } + public byte extraTypes(int j) { int o = __offset(48); return o != 0 ? bb.get(__vector(o) + j * 1) : 0; } + public int extraTypesLength() { int o = __offset(48); return o != 0 ? __vector_len(o) : 0; } + public ByteBuffer extraTypesAsByteBuffer() { return __vector_as_bytebuffer(48, 1); } + public ByteBuffer extraTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 48, 1); } public static int createFlatNode(FlatBufferBuilder builder, int id, @@ -95,9 +99,11 @@ public final class FlatNode extends Table { int scalarOffset, int controlDepsOffset, int varControlDepsOffset, - int controlDepForOffset) { - builder.startObject(22); + int controlDepForOffset, + int extraTypesOffset) { + builder.startObject(23); FlatNode.addOpNum(builder, opNum); + FlatNode.addExtraTypes(builder, extraTypesOffset); FlatNode.addControlDepFor(builder, controlDepForOffset); FlatNode.addVarControlDeps(builder, varControlDepsOffset); FlatNode.addControlDeps(builder, controlDepsOffset); @@ -122,7 +128,7 @@ public final class FlatNode extends Table { return FlatNode.endFlatNode(builder); } - public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(22); } + public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(23); } public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); } public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); } public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); } @@ -171,6 +177,9 @@ public final class FlatNode extends Table { public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); } public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addExtraTypes(FlatBufferBuilder builder, int extraTypesOffset) { builder.addOffset(22, extraTypesOffset, 0); } + public static int createExtraTypesVector(FlatBufferBuilder builder, byte[] data) { builder.startVector(1, data.length, 1); for (int i = data.length - 1; i >= 0; i--) builder.addByte(data[i]); return builder.endVector(); } + public static void startExtraTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); } public static int endFlatNode(FlatBufferBuilder builder) { int o = builder.endObject(); return o; diff --git a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.py b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.py index 889eca62f..d5104efb6 100644 --- a/libnd4j/include/graph/generated/nd4j/graph/FlatNode.py +++ b/libnd4j/include/graph/generated/nd4j/graph/FlatNode.py @@ -339,7 +339,29 @@ class FlatNode(object): return self._tab.VectorLen(o) return 0 -def FlatNodeStart(builder): builder.StartObject(22) + # FlatNode + def ExtraTypes(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(48)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Int8Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1)) + return 0 + + # 
FlatNode + def ExtraTypesAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(48)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int8Flags, o) + return 0 + + # FlatNode + def ExtraTypesLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(48)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + +def FlatNodeStart(builder): builder.StartObject(23) def FlatNodeAddId(builder, id): builder.PrependInt32Slot(0, id, 0) def FlatNodeAddName(builder, name): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0) def FlatNodeAddOpType(builder, opType): builder.PrependInt8Slot(2, opType, 0) @@ -375,4 +397,6 @@ def FlatNodeAddVarControlDeps(builder, varControlDeps): builder.PrependUOffsetTR def FlatNodeStartVarControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4) def FlatNodeAddControlDepFor(builder, controlDepFor): builder.PrependUOffsetTRelativeSlot(21, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepFor), 0) def FlatNodeStartControlDepForVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def FlatNodeAddExtraTypes(builder, extraTypes): builder.PrependUOffsetTRelativeSlot(22, flatbuffers.number_types.UOffsetTFlags.py_type(extraTypes), 0) +def FlatNodeStartExtraTypesVector(builder, numElems): return builder.StartVector(1, numElems, 1) def FlatNodeEnd(builder): return builder.EndObject() diff --git a/libnd4j/include/graph/generated/node_generated.h b/libnd4j/include/graph/generated/node_generated.h index 6ca85f7b0..92f4ab126 100644 --- a/libnd4j/include/graph/generated/node_generated.h +++ b/libnd4j/include/graph/generated/node_generated.h @@ -38,7 +38,8 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_SCALAR = 40, VT_CONTROLDEPS = 42, VT_VARCONTROLDEPS = 44, - VT_CONTROLDEPFOR = 46 + VT_CONTROLDEPFOR = 46, + VT_EXTRATYPES = 48 }; int32_t id() const { return GetField<int32_t>(VT_ID, 0); @@ -106,6 +107,9 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor() const { return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_CONTROLDEPFOR); } + const flatbuffers::Vector<int8_t> *extraTypes() const { + return GetPointer<const flatbuffers::Vector<int8_t> *>(VT_EXTRATYPES); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_ID) && @@ -153,6 +157,8 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, VT_CONTROLDEPFOR) && verifier.VerifyVector(controlDepFor()) && verifier.VerifyVectorOfStrings(controlDepFor()) && + VerifyOffset(verifier, VT_EXTRATYPES) && + verifier.VerifyVector(extraTypes()) && verifier.EndTable(); } }; @@ -226,6 +232,9 @@ struct FlatNodeBuilder { void add_controlDepFor(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor) { fbb_.AddOffset(FlatNode::VT_CONTROLDEPFOR, controlDepFor); } + void add_extraTypes(flatbuffers::Offset<flatbuffers::Vector<int8_t>> extraTypes) { + fbb_.AddOffset(FlatNode::VT_EXTRATYPES, extraTypes); + } explicit FlatNodeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -261,9 +270,11 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNode( flatbuffers::Offset<FlatArray> scalar = 0, flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps = 0, flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> varControlDeps = 0, - flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor = 0) { + flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor = 0, + flatbuffers::Offset<flatbuffers::Vector<int8_t>> extraTypes = 0) { FlatNodeBuilder builder_(_fbb); builder_.add_opNum(opNum); + builder_.add_extraTypes(extraTypes); 
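// Field-add order inside a FlatBuffers table builder does not change the meaning of the
// serialized table (the vtable records each field's offset independently), so extraTypes
// can be written right after opNum here. A hypothetical writer (vector name and DType
// constant are assumptions, not taken from this patch) would build the vector first:
//   std::vector<int8_t> dtypes = { (int8_t) nd4j::graph::DType_INT32 };
//   auto extraTypes = _fbb.CreateVector(dtypes);
// and then pass it as the new trailing parameter of CreateFlatNode(...).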
builder_.add_controlDepFor(controlDepFor); builder_.add_varControlDeps(varControlDeps); builder_.add_controlDeps(controlDeps); @@ -311,7 +322,8 @@ inline flatbuffers::Offset CreateFlatNodeDirect( flatbuffers::Offset scalar = 0, const std::vector> *controlDeps = nullptr, const std::vector> *varControlDeps = nullptr, - const std::vector> *controlDepFor = nullptr) { + const std::vector> *controlDepFor = nullptr, + const std::vector *extraTypes = nullptr) { return nd4j::graph::CreateFlatNode( _fbb, id, @@ -335,7 +347,8 @@ inline flatbuffers::Offset CreateFlatNodeDirect( scalar, controlDeps ? _fbb.CreateVector>(*controlDeps) : 0, varControlDeps ? _fbb.CreateVector>(*varControlDeps) : 0, - controlDepFor ? _fbb.CreateVector>(*controlDepFor) : 0); + controlDepFor ? _fbb.CreateVector>(*controlDepFor) : 0, + extraTypes ? _fbb.CreateVector(*extraTypes) : 0); } inline const nd4j::graph::FlatNode *GetFlatNode(const void *buf) { diff --git a/libnd4j/include/graph/generated/node_generated.js b/libnd4j/include/graph/generated/node_generated.js index dd83c4356..3f831a582 100644 --- a/libnd4j/include/graph/generated/node_generated.js +++ b/libnd4j/include/graph/generated/node_generated.js @@ -398,11 +398,36 @@ nd4j.graph.FlatNode.prototype.controlDepForLength = function() { return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0; }; +/** + * @param {number} index + * @returns {nd4j.graph.DType} + */ +nd4j.graph.FlatNode.prototype.extraTypes = function(index) { + var offset = this.bb.__offset(this.bb_pos, 48); + return offset ? /** @type {nd4j.graph.DType} */ (this.bb.readInt8(this.bb.__vector(this.bb_pos + offset) + index)) : /** @type {nd4j.graph.DType} */ (0); +}; + +/** + * @returns {number} + */ +nd4j.graph.FlatNode.prototype.extraTypesLength = function() { + var offset = this.bb.__offset(this.bb_pos, 48); + return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0; +}; + +/** + * @returns {Int8Array} + */ +nd4j.graph.FlatNode.prototype.extraTypesArray = function() { + var offset = this.bb.__offset(this.bb_pos, 48); + return offset ? 
new Int8Array(this.bb.bytes().buffer, this.bb.bytes().byteOffset + this.bb.__vector(this.bb_pos + offset), this.bb.__vector_len(this.bb_pos + offset)) : null; +}; + /** * @param {flatbuffers.Builder} builder */ nd4j.graph.FlatNode.startFlatNode = function(builder) { - builder.startObject(22); + builder.startObject(23); }; /** @@ -854,6 +879,35 @@ nd4j.graph.FlatNode.startControlDepForVector = function(builder, numElems) { builder.startVector(4, numElems, 4); }; +/** + * @param {flatbuffers.Builder} builder + * @param {flatbuffers.Offset} extraTypesOffset + */ +nd4j.graph.FlatNode.addExtraTypes = function(builder, extraTypesOffset) { + builder.addFieldOffset(22, extraTypesOffset, 0); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {Array.} data + * @returns {flatbuffers.Offset} + */ +nd4j.graph.FlatNode.createExtraTypesVector = function(builder, data) { + builder.startVector(1, data.length, 1); + for (var i = data.length - 1; i >= 0; i--) { + builder.addInt8(data[i]); + } + return builder.endVector(); +}; + +/** + * @param {flatbuffers.Builder} builder + * @param {number} numElems + */ +nd4j.graph.FlatNode.startExtraTypesVector = function(builder, numElems) { + builder.startVector(1, numElems, 1); +}; + /** * @param {flatbuffers.Builder} builder * @returns {flatbuffers.Offset} diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 47c31cdf7..4c79ccb3e 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -587,9 +587,9 @@ namespace nd4j { block->getIArguments()->emplace_back(node->extraInteger()->Get(e)); } - if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) { - for (int e = 0; e < (int) node->outputTypes()->size(); e++) { - block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e)); + if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { + for (int e = 0; e < (int) node->extraTypes()->size(); e++) { + block->getDArguments()->emplace_back((nd4j::DataType) node->extraTypes()->Get(e)); } } @@ -624,9 +624,9 @@ namespace nd4j { block->getIArguments()->emplace_back(node->extraInteger()->Get(e)); } - if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) { - for (int e = 0; e < (int) node->outputTypes()->size(); e++) { - block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e)); + if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { + for (int e = 0; e < (int) node->extraTypes()->size(); e++) { + block->getDArguments()->emplace_back((nd4j::DataType) node->extraTypes()->Get(e)); } } @@ -664,9 +664,9 @@ namespace nd4j { block->getBArguments()->push_back(node->extraBools()->Get(e)); } - if (node->outputTypes() != nullptr && node->outputTypes()->size() > 0) { - for (int e = 0; e < (int) node->outputTypes()->size(); e++) { - block->getDArguments()->emplace_back((nd4j::DataType) node->outputTypes()->Get(e)); + if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { + for (int e = 0; e < (int) node->extraTypes()->size(); e++) { + block->getDArguments()->emplace_back((nd4j::DataType) node->extraTypes()->Get(e)); } } diff --git a/libnd4j/include/graph/scheme/node.fbs b/libnd4j/include/graph/scheme/node.fbs index 92975e216..8e63186f5 100644 --- a/libnd4j/include/graph/scheme/node.fbs +++ b/libnd4j/include/graph/scheme/node.fbs @@ -57,7 +57,9 @@ table FlatNode { controlDeps:[string]; varControlDeps:[string]; controlDepFor:[string]; - + + // DArgs + extraTypes:[DType]; } root_type 
FlatNode; \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/ones_as.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/ones_as.cpp index a20c0110b..702aa6711 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/ones_as.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/ones_as.cpp @@ -25,7 +25,7 @@ namespace nd4j { namespace ops { - OP_IMPL(ones_as, 1, 1, false) { + CUSTOM_OP_IMPL(ones_as, 1, 1, false, 0, 0) { auto output = OUTPUT_VARIABLE(0); output->assign(1); @@ -33,11 +33,19 @@ namespace nd4j { return Status::OK(); } + DECLARE_SHAPE_FN(ones_as) { + auto in = inputShape->at(0); + auto dtype = block.numD() ? D_ARG(0) : ArrayOptions::dataType(in); + auto shape = nd4j::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, in); + + return SHAPELIST(shape); + } + DECLARE_TYPES(ones_as) { getOpDescriptor() ->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedOutputTypes(nd4j::DataType::ANY) - ->setSameMode(true); + ->setSameMode(false); } } } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/zeros_as.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/zeros_as.cpp index 6b461043a..56b4264d0 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/zeros_as.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/zeros_as.cpp @@ -25,7 +25,7 @@ namespace nd4j { namespace ops { - OP_IMPL(zeros_as, 1, 1, false) { + CUSTOM_OP_IMPL(zeros_as, 1, 1, false, 0, 0) { auto out = OUTPUT_VARIABLE(0); out->assign(0); // output is filled by zero by default @@ -35,11 +35,20 @@ namespace nd4j { DECLARE_SYN(zeroslike, zeros_as); DECLARE_SYN(zeros_like, zeros_as); + + DECLARE_SHAPE_FN(zeros_as) { + auto in = inputShape->at(0); + auto dtype = block.numD() ? 
D_ARG(0) : ArrayOptions::dataType(in); + auto shape = nd4j::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, in); + + return SHAPELIST(shape); + } + DECLARE_TYPES(zeros_as) { getOpDescriptor() ->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedOutputTypes(nd4j::DataType::ANY) - ->setSameMode(true); + ->setSameMode(false); } } } diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index 791027baa..c5d0ff207 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -487,7 +487,7 @@ namespace nd4j { * */ #if NOT_EXCLUDED(OP_zeros_as) - DECLARE_OP(zeros_as, 1, 1, false); + DECLARE_CUSTOM_OP(zeros_as, 1, 1, false, 0, 0); #endif /** @@ -497,7 +497,7 @@ namespace nd4j { * */ #if NOT_EXCLUDED(OP_ones_as) - DECLARE_OP(ones_as, 1, 1, false); + DECLARE_CUSTOM_OP(ones_as, 1, 1, false, 0, 0); #endif /** diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp index cff57b62d..a85772cec 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp @@ -191,7 +191,7 @@ TEST_F(DeclarableOpsTests16, test_range_2) { double tArgs[] = { -1.0, 1.0, 0.01 }; - auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0); + auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0, nullptr, 0); shape::printShapeInfoLinear("Result", shapes->at(0)); ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0))); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp index 05c21a8f0..002a31d6e 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp @@ -2978,6 +2978,24 @@ TEST_F(DeclarableOpsTests8, ones_as_test2) { delete results; } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, ones_as_test3) { + + auto x = NDArrayFactory::create(10.); + //auto y = NDArrayFactory::create(100.); + auto exp = NDArrayFactory::create(1.); + + nd4j::ops::ones_as op; + + auto results = op.evaluate({&x}, {}, {}, {}, {nd4j::DataType::INT32}); + ASSERT_EQ(Status::OK(), results->status()); + auto y = results->at(0); + ASSERT_TRUE(y->isSameShape(exp)); + ASSERT_TRUE(y->equalsTo(exp)); + + delete results; +} + //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, NormalizeMoments_SGO_1) { diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp index f058d9112..ee828a6e2 100644 --- a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp @@ -112,7 +112,7 @@ TEST_F(JavaInteropTests, TestShapeExposure3) { Nd4jLong iArgs[] = {1}; auto hash = op.getOpHash(); - auto shapeList = calculateOutputShapes2(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0); + auto shapeList = calculateOutputShapes2(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0, nullptr, 0); ASSERT_EQ(3, shapeList->size()); @@ -1065,7 +1065,7 @@ TEST_F(JavaInteropTests, Test_Boolean_Broadcastables_1) { NDArray::prepareSpecialUse({}, {&arrayX, &arrayY}); 
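// With this change every calculateOutputShapes2 call site carries a trailing D-arg pair:
// an int* of nd4j::DataType ordinals plus their count. Passing (nullptr, 0), as the
// updated call below does, requests no dtype override. A hypothetical caller wanting an
// INT32 result could instead pass (names assumed for illustration):
//   int dArgs[] = { (int) nd4j::DataType::INT32 };
//   ::calculateOutputShapes2(..., bArgs, numBArgs, dArgs, 1);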
nd4j::ops::greater_equal op; - auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0); + auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0); NDArray::registerSpecialUse({}, {&arrayX, &arrayY}); delete shapeList; } diff --git a/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp index 42eb50be0..0306fb555 100644 --- a/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp @@ -1579,7 +1579,7 @@ TEST_F(NativeOpsTests, CalculateOutputShapeTests_2) { #endif auto shapeList = ::calculateOutputShapes2(nullptr, op.getOpHash(), dataPtrs, shapePtrs, 2, const_cast(tArgs.data()), tArgs.size(), - const_cast(iArgs.data()), iArgs.size(), nullptr, bArgsF.size()); + const_cast(iArgs.data()), iArgs.size(), nullptr, bArgsF.size(), nullptr, 0); // Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs ASSERT_EQ(1, shapeList->size()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index f3e0510cb..de421b297 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -4704,7 +4704,7 @@ public class SameDiff extends SDBaseOps { 0, 0, -1, - 0, 0, 0, 0, 0, 0, 0, 0, 0); + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); return flatNode; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java index a88a9c84f..d87a890ff 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java @@ -17,6 +17,7 @@ package org.nd4j.autodiff.samediff.serde; import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.shade.guava.primitives.Ints; import com.google.flatbuffers.FlatBufferBuilder; import java.nio.ByteOrder; @@ -361,6 +362,11 @@ public class FlatBuffersMapper { for (int i = 0; i < extraBools.length; i++) { extraBools[i] = fn.extraBools(i); } + DataType[] extraDTypes = new DataType[fn.extraTypesLength()]; + for (int i = 0; i < extraDTypes.length; i++) { + extraDTypes[i] = DataType.fromInt(fn.extraTypes(i)); + } + int[] dimensions = new int[fn.dimensionsLength()]; for (int i = 0; i < dimensions.length; i++) { dimensions[i] = fn.dimensions(i); @@ -401,6 +407,7 @@ public class FlatBuffersMapper { ((CustomOp) op).addIArgument(extraInteger); ((CustomOp) op).addTArgument(extraParams); ((CustomOp) op).addBArgument(extraBools); + ((CustomOp) op).addDArgument(extraDTypes); op.setPropertiesForFunction(props); return op; @@ -714,11 +721,20 @@ public class FlatBuffersMapper { } boolean[] boolArgs = null; + byte[] dtypeArgs = null; long[] extraBits = null; if (node.opType() == Op.Type.CUSTOM) { - 
DynamicCustomOp dynamicCustomOp = (DynamicCustomOp) node; + val dynamicCustomOp = (DynamicCustomOp) node; extraBits = dynamicCustomOp.iArgs(); boolArgs = dynamicCustomOp.bArgs(); + + if (dynamicCustomOp.numDArguments() > 0) { + dtypeArgs = new byte[dynamicCustomOp.numDArguments()]; + val d = dynamicCustomOp.dArgs(); + for (int e = 0; e < dtypeArgs.length; e++) { + dtypeArgs[e] = (byte) d[e].toInt(); + } + } } else if (node instanceof Enter) { // in case of Enter node we'll be storing unique frame reference val frameName = ((Enter) node).getFrameName(); @@ -817,6 +833,7 @@ public class FlatBuffersMapper { int extraz = FlatNode.createExtraParamsVector(bufferBuilder, extras); int integerArgs = FlatNode.createExtraIntegerVector(bufferBuilder, extraBits); int bArgs = FlatNode.createExtraBoolsVector(bufferBuilder, boolArgs != null ? boolArgs : new boolean[0]); + int dArgs = FlatNode.createExtraTypesVector(bufferBuilder, dtypeArgs != null ? dtypeArgs : new byte[0]); int dimensions = FlatNode.createDimensionsVector(bufferBuilder, dims); int fname = bufferBuilder.createString(node.getOwnName()); int scopeName = bufferBuilder.createString(""); @@ -896,7 +913,8 @@ public class FlatBuffersMapper { scalar, opCds, varCds, - cdsFor + cdsFor, + dArgs ); return flatNode; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatNode.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatNode.java index ca411435d..fdda379b7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatNode.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/graph/FlatNode.java @@ -73,6 +73,10 @@ public final class FlatNode extends Table { public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; } public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; } public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; } + public byte extraTypes(int j) { int o = __offset(48); return o != 0 ? bb.get(__vector(o) + j * 1) : 0; } + public int extraTypesLength() { int o = __offset(48); return o != 0 ? 
__vector_len(o) : 0; } + public ByteBuffer extraTypesAsByteBuffer() { return __vector_as_bytebuffer(48, 1); } + public ByteBuffer extraTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 48, 1); } public static int createFlatNode(FlatBufferBuilder builder, int id, @@ -96,9 +100,11 @@ public final class FlatNode extends Table { int scalarOffset, int controlDepsOffset, int varControlDepsOffset, - int controlDepForOffset) { - builder.startObject(22); + int controlDepForOffset, + int extraTypesOffset) { + builder.startObject(23); FlatNode.addOpNum(builder, opNum); + FlatNode.addExtraTypes(builder, extraTypesOffset); FlatNode.addControlDepFor(builder, controlDepForOffset); FlatNode.addVarControlDeps(builder, varControlDepsOffset); FlatNode.addControlDeps(builder, controlDepsOffset); @@ -123,7 +129,7 @@ public final class FlatNode extends Table { return FlatNode.endFlatNode(builder); } - public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(22); } + public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(23); } public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); } public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); } public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); } @@ -172,6 +178,9 @@ public final class FlatNode extends Table { public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); } public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addExtraTypes(FlatBufferBuilder builder, int extraTypesOffset) { builder.addOffset(22, extraTypesOffset, 0); } + public static int createExtraTypesVector(FlatBufferBuilder builder, byte[] data) { builder.startVector(1, data.length, 1); for (int i = data.length - 1; i >= 0; i--) builder.addByte(data[i]); return builder.endVector(); } + public static void startExtraTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); } public static int endFlatNode(FlatBufferBuilder builder) { int o = builder.endObject(); return o; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java index 05ac2495c..cf82510c0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java @@ -61,6 +61,7 @@ public class DifferentialFunctionClassHolder { add("tArguments"); add("iArguments"); add("bArguments"); + add("dArguments"); add("hash"); add("opName"); add("sameDiff"); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java index 050868b36..4a56e2a88 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java @@ -20,6 +20,7 @@ import lombok.Getter; import lombok.NonNull; import lombok.Setter; import lombok.val; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.*; @@ -33,9 +34,10 @@ public abstract class BaseOpContext implements OpContext { protected Map fastpath_in = new HashMap<>(); protected Map fastpath_out = new HashMap<>(); - protected List fastpath_d = new ArrayList<>(); + protected List fastpath_t = new ArrayList<>(); protected List fastpath_b = new ArrayList<>(); protected List fastpath_i = new ArrayList<>(); + protected List fastpath_d = new ArrayList<>(); @Setter() @Getter @@ -55,14 +57,14 @@ public abstract class BaseOpContext implements OpContext { @Override public void setTArguments(double... arguments) { - fastpath_d.clear(); + fastpath_t.clear(); for (val v:arguments) - fastpath_d.add(v); + fastpath_t.add(v); } @Override public List getTArguments(){ - return fastpath_d; + return fastpath_t; } @Override @@ -77,6 +79,18 @@ public abstract class BaseOpContext implements OpContext { return fastpath_b; } + @Override + public void setDArguments(DataType... arguments) { + fastpath_d.clear(); + for (val v:arguments) + fastpath_d.add(v); + } + + @Override + public List getDArguments() { + return fastpath_d; + } + @Override public void setInputArray(int index, @NonNull INDArray array) { fastpath_in.put(index, array); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java index cfa4f9b75..befdfb605 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/CustomOp.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.LongShapeDescriptor; @@ -57,12 +58,18 @@ public interface CustomOp { boolean[] bArgs(); + DataType[] dArgs(); + + void addTArgument(double... arg); + void addIArgument(int... arg); void addIArgument(long... arg); void addBArgument(boolean... arg); + void addDArgument(DataType... arg); + void removeIArgument(Integer arg); Boolean getBArgument(int index); @@ -71,8 +78,6 @@ public interface CustomOp { int numIArguments(); - void addTArgument(double... arg); - void removeTArgument(Double arg); Double getTArgument(int index); @@ -81,6 +86,8 @@ public interface CustomOp { int numBArguments(); + int numDArguments(); + void addInputArgument(INDArray... 
arg); void removeInputArgument(INDArray arg); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java index e46dfab4b..f4116ba3e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.shade.guava.collect.Lists; import org.nd4j.shade.guava.primitives.Doubles; import org.nd4j.shade.guava.primitives.Longs; @@ -62,6 +63,9 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { @Builder.Default protected List bArguments = new ArrayList<>(); + @Builder.Default + protected List dArguments = new ArrayList<>(); + @Builder.Default protected List axis = new ArrayList<>(); @@ -77,6 +81,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { iArguments = new ArrayList<>(); tArguments = new ArrayList<>(); bArguments = new ArrayList<>(); + dArguments = new ArrayList<>(); } public DynamicCustomOp(SameDiff sameDiff, SDVariable arg) { @@ -93,6 +98,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { iArguments = new ArrayList<>(); tArguments = new ArrayList<>(); bArguments = new ArrayList<>(); + dArguments = new ArrayList<>(); } public DynamicCustomOp(String opName, INDArray input, INDArray output, List tArguments, int[] iArguments) { @@ -132,6 +138,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { this.iArguments.add((Long) a.longValue()); } bArguments = new ArrayList<>(); + dArguments = new ArrayList<>(); } /** @@ -173,6 +180,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { iArguments = new ArrayList<>(); tArguments = new ArrayList<>(); bArguments = new ArrayList<>(); + dArguments = new ArrayList<>(); this.inplaceCall = inPlace; } @@ -185,6 +193,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { iArguments = new ArrayList<>(); tArguments = new ArrayList<>(); bArguments = new ArrayList<>(); + dArguments = new ArrayList<>(); } @@ -260,6 +269,11 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { return hash; } + @Override + public int numDArguments() { + return dArguments.size(); + } + @Override public List outputArguments() { return outputArguments; @@ -280,6 +294,11 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { return Doubles.toArray(tArguments); } + @Override + public DataType[] dArgs() { + return dArguments.toArray(new DataType[dArguments.size()]); + } + @Override public void addIArgument(int... arg) { for (long a: arg) @@ -323,6 +342,15 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { addTArgument(Doubles.asList(arg).toArray(new Double[arg.length])); } + @Override + public void addDArgument(DataType... arg) { + if (dArguments == null) + dArguments = new ArrayList<>(); + + if (arg != null) + dArguments.addAll(Arrays.asList(arg)); + } + private void addTArgument(Double... 
arg) { tArguments.addAll(Arrays.asList(arg)); } @@ -650,6 +678,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { private List outputArguments = new ArrayList<>(); private List tArguments = new ArrayList<>(); private List iArguments = new ArrayList<>(); + private List dArguments = new ArrayList<>(); private List bArguments = new ArrayList<>(); protected DynamicCustomOpsBuilder(String opName, long hash, int numInputs, int numOutputs, boolean inplaceAllowed, int numTArguments, int numIArguments) { @@ -870,6 +899,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { result.iArguments = iArguments; result.tArguments = tArguments; result.bArguments = bArguments; + result.dArguments = dArguments; result.inplaceCall = inplaceCall; result.hash = opHash; result.outputShapes = outputShapes; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java index 3deefe7c0..4063746b3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops; import org.bytedeco.javacpp.Pointer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; @@ -43,9 +44,15 @@ public interface OpContext extends AutoCloseable { * @param arguments */ void setTArguments(double... arguments); - List getTArguments(); + /** + * This method sets data type arguments required for operation + * @param arguments + */ + void setDArguments(DataType... arguments); + List getDArguments(); + /** * This method sets boolean arguments required for operation * @param arguments diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java index 83020cb57..313b7ccb4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.custom; import lombok.NonNull; import lombok.val; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.CustomOpDescriptor; @@ -246,6 +247,21 @@ public class ScatterUpdate implements CustomOp { } + @Override + public DataType[] dArgs() { + return new DataType[0]; + } + + @Override + public void addDArgument(DataType... 
arg) { + + } + + @Override + public int numDArguments() { + return 0; + } + @Override public void clearArrays() { op.clearArrays(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java index d442e4623..beb9d09b9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java @@ -83,6 +83,9 @@ public class OneHot extends DynamicCustomOp { addIArgument(depth); addTArgument(on); addTArgument(off); + + if (outputType != null) + addDArgument(outputType); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OnesLike.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OnesLike.java index a8f49bdf2..4b4b3e578 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OnesLike.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OnesLike.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -23,6 +24,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -53,6 +55,22 @@ public class OnesLike extends DynamicCustomOp { public OnesLike(String name, SameDiff sameDiff, SDVariable input, DataType dataType) { super(name, sameDiff, new SDVariable[]{input}, false); this.outputType = dataType; + addArgs(); + } + + public OnesLike(@NonNull INDArray input, DataType dataType) { + this.addInputArgument(input); + this.outputType = dataType; + addArgs(); + } + + public OnesLike(@NonNull INDArray input) { + this(input, input.dataType()); + } + + public void addArgs() { + if (outputType != null) + addDArgument(outputType); } @@ -78,6 +96,8 @@ public class OnesLike extends DynamicCustomOp { if(attributesForNode.containsKey("T")) { outputType = TFGraphMapper.convertType(attributesForNode.get("T").getType()); } + + addArgs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index f0bf4bc5e..d32aff5b1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -3438,6 +3438,16 @@ public class Nd4j { return create(ArrayUtil.flatten(data), data.length, data[0].length, data[0][0].length); } + /** + * Create a 2D int array based on a Java 2D int array. + * + * @param data the data to use + * @return the created ndarray. + */ + public static INDArray create(int[][] data) { + return createFromArray(data); + } + /** * create 3D int array based on 3D java int array. * @param data java 3D i array. 
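Between the op wrappers above and the native plumbing below, the patch gives custom ops a data-type argument channel parallel to the existing T/I/B args. A minimal, hedged sketch of the resulting user-facing behaviour (the APIs are taken from this patch; the printed dtype is an expectation based on the new shape functions, not a verified output):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.shape.OnesLike;
import org.nd4j.linalg.factory.Nd4j;

public class DArgsSketch {
    public static void main(String[] args) {
        INDArray x = Nd4j.create(DataType.FLOAT, 3, 4);

        // OnesLike(INDArray, DataType) records the requested dtype via addDArgument(),
        // so the native shape function resolves INT32 instead of inheriting FLOAT.
        INDArray z = Nd4j.exec(new OnesLike(x, DataType.INT32))[0];

        System.out.println(z.dataType()); // expected: INT32
    }
}

Routing dtypes through a dedicated channel keeps them out of iArgs, whose integer slots already carry op-specific meanings.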
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java index 95c97068e..d284974eb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java @@ -1056,7 +1056,7 @@ public interface NativeOps { OpaqueShapeList calculateOutputShapes(PointerPointer extraPointers, long hash, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs); - OpaqueShapeList calculateOutputShapes2(PointerPointer extraPointers, long hash, PointerPointer inputBunffers, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs); + OpaqueShapeList calculateOutputShapes2(PointerPointer extraPointers, long hash, PointerPointer inputBunffers, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs, @Cast("int *") IntPointer dArgs, int numDArgs); long getShapeListSize(OpaqueShapeList list); LongPointer getShape(OpaqueShapeList list, long i); @@ -1156,6 +1156,7 @@ public interface NativeOps { void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer databuffer, Pointer shapeInfo, Pointer specialShapeInfo); void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments); void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments); + void setGraphContextDArguments(OpaqueContext ptr, IntPointer arguments, int numberOfArguments); void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments); void ctxAllowHelpers(OpaqueContext ptr, boolean reallyAllow); void ctxSetExecutionMode(OpaqueContext ptr, int execMode); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 04b86dc02..f18bd1459 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -1928,6 +1928,8 @@ public class CudaExecutioner extends DefaultOpExecutioner { val bArgs = op.bArgs().length > 0 ? new BooleanPointer(op.bArgs().length) : null; + val dArgs = op.numDArguments() > 0 ? 
new IntPointer(op.numDArguments()) : null; + cnt = 0; for (val b: op.bArgs()) bArgs.put(cnt++, b); @@ -1936,7 +1938,12 @@ public class CudaExecutioner extends DefaultOpExecutioner { for (val t: op.tArgs()) tArgs.put(cnt++, t); - OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().size(), tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments()); + cnt = 0; + val dArgs1 = op.dArgs(); + for (val d: dArgs1) + dArgs.put(cnt++, d.toInt()); + + OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().size(), tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments(), dArgs, op.numDArguments()); if (nativeOps.lastErrorCode() != 0) throw new RuntimeException(nativeOps.lastErrorMessage()); @@ -2003,6 +2010,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { context.setBArguments(op.bArgs()); context.setIArguments(op.iArgs()); context.setTArguments(op.tArgs()); + context.setDArguments(op.dArgs()); val result = exec(op, context); val states = context.getRngStates(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java index 487f38232..01127e891 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java @@ -18,12 +18,10 @@ package org.nd4j.linalg.jcublas.ops.executioner; import lombok.NonNull; import lombok.val; -import org.bytedeco.javacpp.BooleanPointer; -import org.bytedeco.javacpp.DoublePointer; -import org.bytedeco.javacpp.LongPointer; -import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.*; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseOpContext; @@ -76,6 +74,18 @@ public class CudaOpContext extends BaseOpContext implements OpContext { } } + @Override + public void setDArguments(DataType... 
arguments) { + if (arguments.length > 0) { + super.setDArguments(arguments); + val args = new int[arguments.length]; + for (int e = 0; e < arguments.length; e++) + args[e] = arguments[e].toInt(); + + nativeOps.setGraphContextDArguments(context, new IntPointer(args), arguments.length); + }; + } + @Override public void setRngStates(long rootState, long nodeState) { nativeOps.setRandomGeneratorStates(nativeOps.getGraphContextRandomGenerator(context), rootState, nodeState); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index d1840ab63..8d0029bc3 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -2984,12 +2984,12 @@ public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointe public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs); public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs); public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, 
@Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntPointer dArgs, int numDArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntBuffer dArgs, int numDArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, int[] dArgs, int numDArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntPointer dArgs, int numDArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntBuffer dArgs, int numDArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, int[] dArgs, int numDArgs); public native @Cast("Nd4jLong") long getShapeListSize(OpaqueShapeList list); public native @Cast("Nd4jLong*") LongPointer getShape(OpaqueShapeList list, @Cast("Nd4jLong") long i); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java 
b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java index 1863d6c1c..461646311 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java @@ -17,10 +17,9 @@ package org.nd4j.linalg.cpu.nativecpu.ops; import lombok.NonNull; -import org.bytedeco.javacpp.BooleanPointer; -import org.bytedeco.javacpp.DoublePointer; -import org.bytedeco.javacpp.LongPointer; -import org.bytedeco.javacpp.Pointer; +import lombok.val; +import org.bytedeco.javacpp.*; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseOpContext; import org.nd4j.linalg.api.ops.ExecutionMode; @@ -73,6 +72,18 @@ public class CpuOpContext extends BaseOpContext implements OpContext { }; } + @Override + public void setDArguments(DataType... arguments) { + if (arguments.length > 0) { + super.setDArguments(arguments); + val args = new int[arguments.length]; + for (int e = 0; e < arguments.length; e++) + args[e] = arguments[e].toInt(); + + nativeOps.setGraphContextDArguments(context, new IntPointer(args), arguments.length); + }; + } + @Override public void setRngStates(long rootState, long nodeState) { nativeOps.setRandomGeneratorStates(nativeOps.getGraphContextRandomGenerator(context), rootState, nodeState); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index ebeab58f4..cc3d17b5f 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -1636,6 +1636,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { context.setBArguments(op.bArgs()); context.setIArguments(op.iArgs()); context.setTArguments(op.tArgs()); + context.setDArguments(op.dArgs()); val result = exec(op, context); val states = context.getRngStates(); @@ -1712,6 +1713,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { val bArgs = op.numBArguments() > 0 ? new BooleanPointer(op.numBArguments()) : null; + val dArgs = op.numDArguments() > 0 ? 
new IntPointer(op.numDArguments()) : null; + cnt = 0; val bArgs1 = op.bArgs(); for (val b: bArgs1) @@ -1722,11 +1725,17 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { for (val t: tArgs1) tArgs.put(cnt++, t); + cnt = 0; + val dArgs1 = op.dArgs(); + for (val d: dArgs1) + dArgs.put(cnt++, d.toInt()); + + OpaqueShapeList ptrptr; try { ptrptr = loop.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.numInputArguments(), tArgs, - op.numTArguments(), iArgs, op.numIArguments(), bArgs, op.numBArguments()); + op.numTArguments(), iArgs, op.numIArguments(), bArgs, op.numBArguments(), dArgs, op.numDArguments()); if (loop.lastErrorCode() != 0) throw new RuntimeException(loop.lastErrorMessage()); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 40a2f5236..93fbb71d7 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -2987,12 +2987,12 @@ public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointe public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs); public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs); public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int 
numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntPointer dArgs, int numDArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntBuffer dArgs, int numDArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, int[] dArgs, int numDArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntPointer dArgs, int numDArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntBuffer dArgs, int numDArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, int[] dArgs, int numDArgs); public native @Cast("Nd4jLong") long getShapeListSize(OpaqueShapeList list); public native @Cast("Nd4jLong*") 
LongPointer getShape(OpaqueShapeList list, @Cast("Nd4jLong") long i); @@ -17951,7 +17951,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * */ // #if NOT_EXCLUDED(OP_zeros_as) - @Namespace("nd4j::ops") public static class zeros_as extends DeclarableOp { + @Namespace("nd4j::ops") public static class zeros_as extends DeclarableCustomOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public zeros_as(Pointer p) { super(p); } @@ -17962,10 +17962,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); return (zeros_as)super.position(position); } - public zeros_as() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } + public zeros_as() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } // #endif /** @@ -17975,7 +17975,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * */ // #if NOT_EXCLUDED(OP_ones_as) - @Namespace("nd4j::ops") public static class ones_as extends DeclarableOp { + @Namespace("nd4j::ops") public static class ones_as extends DeclarableCustomOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public ones_as(Pointer p) { super(p); } @@ -17986,10 +17986,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); return (ones_as)super.position(position); } - public ones_as() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } + public ones_as() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } // #endif /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java index e02e4b91d..59932d670 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java @@ -1169,6 +1169,26 @@ public class MiscOpValidation extends BaseOpValidation { assertNull(err); } + @Test + public void testOneHot4() { + + INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1); + + SameDiff sd = SameDiff.create(); + SDVariable indices = sd.constant("indices", indicesArr); + int depth = 3; + int axis = -1; + SDVariable oneHot = sd.oneHot("oneHot", indices, depth, axis, 5.0, 0.0, DataType.INT32); + + INDArray exp = Nd4j.create(new int[][]{{5, 0, 0}, {0,0,5}, {0,0,0}, {0, 5, 0}}); + + String err = OpValidation.validate(new TestCase(sd) + .expected(oneHot, exp) + .gradientCheck(false)); + + assertNull(err); + } + @Test public void testOneHot3() { //https://github.com/deeplearning4j/deeplearning4j/issues/6872 @@ -1204,8 +1224,6 @@ public class MiscOpValidation extends BaseOpValidation { assertNull(err); } - - @Test public void testLinspace(){ SameDiff sd = SameDiff.create(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index e7160a1d8..e9d2979c6 100644 --- 
a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -36,6 +36,7 @@ import org.nd4j.linalg.api.ops.impl.image.ResizeArea; import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear; import org.nd4j.linalg.api.ops.impl.reduce.MmulBp; import org.nd4j.linalg.api.ops.impl.shape.Create; +import org.nd4j.linalg.api.ops.impl.shape.OnesLike; import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp; @@ -1673,4 +1674,13 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, ret[0]); } + + @Test + public void testOnesLike_1() { + val x = Nd4j.create(DataType.FLOAT, 3, 4, 5); + val e = Nd4j.ones(DataType.INT32, 3, 4, 5); + + val z = Nd4j.exec(new OnesLike(x, DataType.INT32))[0]; + assertEquals(e, z); + } } From d39ca6d488695903ba47c565b92abf98b454e5e6 Mon Sep 17 00:00:00 2001 From: Abdelrauf Date: Fri, 31 Jan 2020 09:57:31 +0400 Subject: [PATCH 08/17] ElementWiseStride==1 cases for legacy random ops (#202) Signed-off-by: AbdelRauf --- libnd4j/include/loops/cpu/random.cpp | 99 ++++++++++++++++++++-------- 1 file changed, 70 insertions(+), 29 deletions(-) diff --git a/libnd4j/include/loops/cpu/random.cpp b/libnd4j/include/loops/cpu/random.cpp index d4c808719..6fccc6376 100644 --- a/libnd4j/include/loops/cpu/random.cpp +++ b/libnd4j/include/loops/cpu/random.cpp @@ -29,6 +29,7 @@ using namespace randomOps; namespace functions { namespace random { + template template void RandomFunction::execTransform(Nd4jPointer state, @@ -56,18 +57,32 @@ namespace functions { if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i += increment) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments); - } - }; + if(shape::elementWiseStride(zShapeInfo) == 1 && shape::elementWiseStride(xShapeInfo) == 1 && shape::elementWiseStride(yShapeInfo) == 1 && + shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(yShapeInfo) ){ + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + z[i] = OpClass::op(x[i], y[i], i, length, rng, extraArguments); + } + }; + samediff::Threads::parallel_for(func, 0, length, 1); + } + else{ + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - samediff::Threads::parallel_for(func, 0, length, 1); + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i += increment) { + auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments); + } + }; + + samediff::Threads::parallel_for(func, 0, length, 1); + } } else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { @@ -169,15 +184,27 @@ namespace functions { if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (uint64_t i = start; i < stop; i += 
increment) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpClass::op(x[offset], i, length, rng, extraArguments); - } - }; + if(shape::elementWiseStride(zShapeInfo) == 1 && shape::elementWiseStride(xShapeInfo) == 1 && shape::order(xShapeInfo) == shape::order(zShapeInfo)){ + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + z[i] = OpClass::op(x[i], i, length, rng, extraArguments); + } + }; + samediff::Threads::parallel_for(func, 0, length, 1); + } + else{ + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (uint64_t i = start; i < stop; i += increment) { + auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpClass::op(x[offset], i, length, rng, extraArguments); + } + }; - samediff::Threads::parallel_for(func, 0, length, 1); + samediff::Threads::parallel_for(func, 0, length, 1); + } } else { @@ -208,20 +235,34 @@ namespace functions { auto length = shape::length(zShapeInfo); nd4j::graph::RandomGenerator* rng = reinterpret_cast(state); - nd4j::OmpLaunchHelper info(length); - uint zShapeInfoCast[MAX_RANK]; - const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + if(shape::elementWiseStride(zShapeInfo) == 1){ - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (uint64_t i = start; i < stop; i += increment) { - auto offset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[offset] = OpClass::op(i, length, rng, extraArguments); - } - }; + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + z[i] = OpClass::op( i, length, rng, extraArguments); + } + }; - samediff::Threads::parallel_for(func, 0, length, 1); + samediff::Threads::parallel_for(func, 0, length, 1); + } + else{ + nd4j::OmpLaunchHelper info(length); + + uint zShapeInfoCast[MAX_RANK]; + const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (uint64_t i = start; i < stop; i += increment) { + auto offset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[offset] = OpClass::op(i, length, rng, extraArguments); + } + }; + + samediff::Threads::parallel_for(func, 0, length, 1); + } } template From 1ab86d1306b2b30fdbd549bda43edf600689e461 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 31 Jan 2020 10:45:41 +0300 Subject: [PATCH 09/17] Range op data type (#204) * - range op now accepts dargs - dargs now can be in signature Signed-off-by: raver119 * range dtype java side Signed-off-by: raver119 * linspace fix Signed-off-by: raver119 * lin_space fix for scalar outputs Signed-off-by: raver119 --- .../generic/parity_ops/lin_space.cpp | 5 +++ .../declarable/generic/parity_ops/range.cpp | 31 ++++++++++------ .../ops/declarable/impl/DeclarableOp.cpp | 36 ++++++++++--------- .../layers_tests/DeclarableOpsTests3.cpp | 20 +++++++++++ .../linalg/api/ops/random/impl/Range.java | 4 +++ .../java/org/nd4j/linalg/factory/Nd4j.java | 13 +++++++ .../nd4j/linalg/custom/CustomOpsTests.java | 8 +++++ 7 files changed, 90 insertions(+), 27 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/lin_space.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/lin_space.cpp index f932a1274..8d30185b1 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/lin_space.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/lin_space.cpp @@ -31,6 +31,11 @@ namespace ops { auto start = INPUT_VARIABLE(0); auto 
finish = INPUT_VARIABLE(1); auto numOfElements = INPUT_VARIABLE(2); + + if (numOfElements->e(0) == 1) { + output->assign(start); + return Status::OK(); + } output->linspace(start->e(0), (finish->e(0) - start->e(0)) / (numOfElements->e(0) - 1.)); return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/range.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/range.cpp index 04b5b48d6..7faf82b08 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/range.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/range.cpp @@ -130,7 +130,7 @@ DECLARE_SHAPE_FN(range) { const int numIArgs = block.getIArguments()->size(); Nd4jLong steps = 0; - nd4j::DataType dataType = nd4j::DataType::INHERIT; + nd4j::DataType dataType = block.numD() ? D_ARG(0) : nd4j::DataType::INHERIT; if (numInArrs > 0) { auto isR = INPUT_VARIABLE(0)->isR(); @@ -159,7 +159,9 @@ DECLARE_SHAPE_FN(range) { REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); steps = static_cast((limit - start) / delta); - dataType = INPUT_VARIABLE(0)->dataType(); + + if (!block.numD()) + dataType = INPUT_VARIABLE(0)->dataType(); if(math::nd4j_abs(start + steps * delta) < math::nd4j_abs(limit)) ++steps; @@ -187,7 +189,9 @@ DECLARE_SHAPE_FN(range) { REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); steps = static_cast((limit - start) / delta); - dataType = INPUT_VARIABLE(0)->dataType(); + + if (!block.numD()) + dataType = INPUT_VARIABLE(0)->dataType(); if(math::nd4j_abs(start + steps * delta) < math::nd4j_abs(limit)) ++steps; @@ -214,10 +218,12 @@ DECLARE_SHAPE_FN(range) { REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); - if (limit > DataTypeUtils::max()) - dataType = nd4j::DataType::INT64; - else - dataType = nd4j::DataType::INT32; + if (!block.numD()) { + if (limit > DataTypeUtils::max()) + dataType = nd4j::DataType::INT64; + else + dataType = nd4j::DataType::INT32; + } steps = (limit - start) / delta; @@ -248,10 +254,13 @@ DECLARE_SHAPE_FN(range) { REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); steps = static_cast((limit - start) / delta); - if (Environment::getInstance()->precisionBoostAllowed()) - dataType = nd4j::DataType::DOUBLE; - else - dataType = Environment::getInstance()->defaultFloatDataType(); + + if (!block.numD()) { + if (Environment::getInstance()->precisionBoostAllowed()) + dataType = nd4j::DataType::DOUBLE; + else + dataType = Environment::getInstance()->defaultFloatDataType(); + } if(math::nd4j_abs(start + steps * delta) < math::nd4j_abs(limit)) ++steps; diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index 6f26c1095..7c4138d36 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -830,8 +830,12 @@ namespace nd4j { template <> Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list tArgs) { - std::vector realArgs(tArgs); - return execute(inputs, outputs, realArgs, std::vector(), std::vector(), std::vector());; + return execute(inputs, outputs, tArgs, std::vector(), std::vector(), std::vector()); + } + + template <> + Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list dArgs) { + return execute(inputs, outputs, std::vector(), std::vector(), std::vector(), dArgs); } template <> @@ -840,13 
+844,12 @@ namespace nd4j { for (auto v:tArgs) realArgs.emplace_back(v); - return execute(inputs, outputs, realArgs, std::vector(), std::vector(), std::vector());; + return execute(inputs, outputs, realArgs, std::vector(), std::vector(), std::vector()); } template <> Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list iArgs) { - std::vector realArgs(iArgs); - return execute(inputs, outputs, std::vector(), realArgs, std::vector(), std::vector());; + return execute(inputs, outputs, std::vector(), iArgs, std::vector(), std::vector()); } template <> @@ -855,13 +858,12 @@ namespace nd4j { for (auto v:iArgs) realArgs.emplace_back(v); - return execute(inputs, outputs, std::vector(), realArgs, std::vector(), std::vector());; + return execute(inputs, outputs, std::vector(), realArgs, std::vector(), std::vector()); } template <> Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list bArgs) { - std::vector realArgs(bArgs); - return execute(inputs, outputs, std::vector(), std::vector(), realArgs, std::vector());; + return execute(inputs, outputs, std::vector(), std::vector(), bArgs, std::vector()); } Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, const std::vector &tArgs, const std::vector &iArgs, const std::vector &bArgs, const std::vector &dArgs, bool isInplace) { @@ -903,13 +905,12 @@ namespace nd4j { for (auto v:iArgs) realArgs.emplace_back(v); - return evaluate(inputs, std::vector(), realArgs, std::vector(), std::vector());; + return evaluate(inputs, std::vector(), realArgs, std::vector(), std::vector()); } template <> nd4j::ResultSet *DeclarableOp::evaluate(const std::vector &inputs, std::initializer_list iArgs) { - std::vector realArgs(iArgs); - return evaluate(inputs, std::vector(), realArgs, std::vector(), std::vector());; + return evaluate(inputs, std::vector(), iArgs, std::vector(), std::vector()); } template <> @@ -918,19 +919,22 @@ namespace nd4j { for (auto v:tArgs) realArgs.emplace_back(v); - return evaluate(inputs, realArgs, std::vector(), std::vector(), std::vector());; + return evaluate(inputs, realArgs, std::vector(), std::vector(), std::vector()); } template <> nd4j::ResultSet *DeclarableOp::evaluate(const std::vector &inputs, std::initializer_list tArgs) { - std::vector realArgs(tArgs); - return evaluate(inputs, realArgs, std::vector(), std::vector(), std::vector());; + return evaluate(inputs, tArgs, std::vector(), std::vector(), std::vector()); } template <> nd4j::ResultSet *DeclarableOp::evaluate(const std::vector &inputs, std::initializer_list bArgs) { - std::vector realArgs(bArgs); - return evaluate(inputs, std::vector(), std::vector(), realArgs, std::vector());; + return evaluate(inputs, std::vector(), std::vector(), bArgs, std::vector()); + } + + template <> + nd4j::ResultSet *DeclarableOp::evaluate(const std::vector &inputs, std::initializer_list bArgs) { + return evaluate(inputs, std::vector(), std::vector(), std::vector(), bArgs); } nd4j::ResultSet *DeclarableOp::evaluate(const std::vector &inputs, const std::vector &tArgs, const std::vector &iArgs, const std::vector &bArgs, const std::vector &dArgs, bool isInplace) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index e39589270..04816b2b2 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -438,6 +438,26 @@ 
TEST_F(DeclarableOpsTests3, Test_Range_3) {
 }
 
+TEST_F(DeclarableOpsTests3, Test_Range_10) {
+    auto start= NDArrayFactory::create<float>('c', {1, 1}, {0.f});
+    auto stop= NDArrayFactory::create<float>('c', {1, 1}, {2.f});
+    auto step= NDArrayFactory::create<float>('c', {1, 1}, {1.f});
+    auto exp= NDArrayFactory::create<float>('c', {2}, {0.f, 1.f});
+
+    nd4j::ops::range op;
+    auto result = op.evaluate({&start, &stop, &step}, {nd4j::DataType::DOUBLE});
+
+    ASSERT_EQ(ND4J_STATUS_OK, result->status());
+
+    auto z = result->at(0);
+
+    ASSERT_TRUE(exp.isSameShape(z));
+    ASSERT_TRUE(exp.equalsTo(z));
+
+    delete result;
+}
+
+
 TEST_F(DeclarableOpsTests3, Test_Range_4) {
     auto exp= NDArrayFactory::create<float>('c', {13}, {-10.f, -8.334f, -6.668f, -5.002f, -3.336f, -1.67f, -0.004f, 1.662f, 3.328f, 4.994f, 6.66f, 8.326f, 9.992f});
 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java
index d211fbe25..ade01281c 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java
@@ -51,6 +51,7 @@ public class Range extends DynamicCustomOp {
     public Range(SameDiff sd, double from, double to, double step, DataType dataType){
         super(null, sd, new SDVariable[0]);
         addTArgument(from, to, step);
+        addDArgument(dataType);
         this.from = from;
         this.to = to;
         this.delta = step;
@@ -63,11 +64,13 @@ public class Range extends DynamicCustomOp {
         this.to = to;
         this.delta = step;
         this.dataType = dataType;
+        addDArgument(dataType);
     }
 
     public Range(SameDiff sd, SDVariable from, SDVariable to, SDVariable step, DataType dataType){
         super(null, sd, new SDVariable[]{from, to, step});
         this.dataType = dataType;
+        addDArgument(dataType);
     }
 
@@ -99,6 +102,7 @@ public class Range extends DynamicCustomOp {
         if(attributesForNode.containsKey("Tidx")){
             dataType = TFGraphMapper.convertType(attributesForNode.get("Tidx").getType());
         }
+        addDArgument(dataType);
     }
 
     @Override
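
The Range.java change above pairs with the C++ shape function: the DataType now travels through dArgs, so a caller can pin range's output type explicitly. A minimal C++ sketch of that path, mirroring Test_Range_10 above (values illustrative; it relies on the evaluate overload taking an initializer_list of DataType added in DeclarableOp.cpp):

    nd4j::ops::range op;
    auto start = NDArrayFactory::create<float>(0.f);
    auto limit = NDArrayFactory::create<float>(3.f);
    auto delta = NDArrayFactory::create<float>(1.f);
    // the DataType argument is routed through dArgs, so the shape function
    // picks INT32 instead of inheriting the float type of the inputs
    auto result = op.evaluate({&start, &limit, &delta}, {nd4j::DataType::INT32});
    auto z = result->at(0);   // [0, 1, 2], with z->dataType() == INT32
    delete result;
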
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java
index d32aff5b1..8e638f373 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java
@@ -1944,6 +1944,9 @@ public class Nd4j {
         if(lower == upper && num == 1) {
             return Nd4j.scalar(dtype, lower);
         }
+        if (num == 1) {
+            return Nd4j.scalar(dtype, lower);
+        }
         if (dtype.isIntType()) {
             return linspaceWithCustomOp(lower, upper, (int)num, dtype);
         } else if (dtype.isFPType()) {
@@ -1964,6 +1967,9 @@
      */
     public static INDArray linspace(@NonNull DataType dataType, double lower, double step, long num) {
         Preconditions.checkState(dataType.isFPType());
+        if (num == 1)
+            return Nd4j.scalar(dataType, lower);
+
         return Nd4j.getExecutioner().exec(new Linspace(lower, num, step, dataType));
     }
 
@@ -1977,10 +1983,15 @@
      */
     public static INDArray linspace( double lower, double upper, long num, @NonNull DataType dataType) {
         Preconditions.checkState(dataType.isFPType());
+        if (num == 1)
+            return Nd4j.scalar(dataType, lower);
+
         return Nd4j.getExecutioner().exec(new Linspace(lower, upper, num, dataType));
     }
 
     private static INDArray linspaceWithCustomOp(long lower, long upper, int num, DataType dataType) {
+        if (num == 1)
+            return Nd4j.scalar(dataType, lower);
         INDArray result = Nd4j.createUninitialized(dataType, new long[] {num}, Nd4j.order());
 
@@ -1994,6 +2005,8 @@ public class Nd4j {
     }
 
     private static INDArray linspaceWithCustomOpByRange(long lower, long upper, long num, long step, DataType dataType) {
+        if (num == 1)
+            return Nd4j.scalar(dataType, lower);
         INDArray result = Nd4j.createUninitialized(dataType, new long[] {num}, Nd4j.order());
 
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
index e9d2979c6..49ff345e7 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
@@ -1683,4 +1683,12 @@ public class CustomOpsTests extends BaseNd4jTest {
         val z = Nd4j.exec(new OnesLike(x, DataType.INT32))[0];
         assertEquals(e, z);
     }
+
+    @Test
+    public void testLinSpaceEdge_1() {
+        val x = Nd4j.linspace(1,10,1, DataType.FLOAT);
+        val e = Nd4j.scalar(1.0f);
+
+        assertEquals(e, x);
+    }
 }
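
The num == 1 guards above all exist for the same reason: every linspace path derives its step as (upper - lower) / (num - 1), which divides by zero for a single-element request, so the guards short-circuit to a scalar instead. A minimal sketch of the C++ side of the same edge case (assuming a defaulted evaluate(inputs) overload is available; values illustrative):

    nd4j::ops::lin_space op;
    auto start  = NDArrayFactory::create<float>(1.f);
    auto finish = NDArrayFactory::create<float>(10.f);
    auto num    = NDArrayFactory::create<int>(1);
    // with the guard added to lin_space.cpp above, a one-element request
    // returns start directly instead of evaluating (finish - start) / 0
    auto result = op.evaluate({&start, &finish, &num});
    auto z = result->at(0);   // single element equal to start, i.e. 1.0f
    delete result;
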
From 00cd61f32dfdf97d498ce4e42b7ce20b52237e0d Mon Sep 17 00:00:00 2001
From: raver119
Date: Fri, 31 Jan 2020 15:57:55 +0300
Subject: [PATCH 10/17] roll back flatbuffers version

Signed-off-by: raver119
---
 libnd4j/CMakeLists.txt.in | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/libnd4j/CMakeLists.txt.in b/libnd4j/CMakeLists.txt.in
index f351bf1b7..8e8741c86 100644
--- a/libnd4j/CMakeLists.txt.in
+++ b/libnd4j/CMakeLists.txt.in
@@ -5,7 +5,7 @@ project(flatbuffers-download NONE)
 include(ExternalProject)
 ExternalProject_Add(flatbuffers
     GIT_REPOSITORY https://github.com/google/flatbuffers.git
-    GIT_TAG v1.11.0
+    GIT_TAG v1.10.0
     SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-src"
     BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-build"
     CONFIGURE_COMMAND ""

From d52e67209e3de6af2d67867836eba9456a901dfc Mon Sep 17 00:00:00 2001
From: Oleh
Date: Fri, 31 Jan 2020 15:30:49 +0200
Subject: [PATCH 11/17] Oleh convert (#200)

* StringUtils for utf convertor: raw implementation of all possible combinations; still needs a per-symbol byte counter for each type and an API to call the convertors and store the data

Signed-off-by: Oleg

* StringUtils for utf convertor: more corrections to support the convertors

Signed-off-by: Oleg

* StringUtils for utf convertor: some corrections and bug fixes; needs review to discuss how to add multi-threading

Signed-off-by: Oleg

* StringUtils for utf convertor #8613: some corrections to move to multi-threading; added one test; the input/output array presentation and the multi-threading approach still need discussion

* StringUtils for utf convertor #8613: tests added, some corrections to optimize the build

Signed-off-by: Oleg

* StringUtils for utf convertor #8613: some corrections and code clean-up

Signed-off-by: Oleg

* StringUtils for utf convertor #8613: code clean-up and usage optimization; the ndarray factory needs updating before std usage can be replaced

Signed-off-by: Oleg

* StringUtils for utf convertor #8613: some work to integrate the converters into NDArrayFactory, update tests and add some functionality

Signed-off-by: Oleg

* StringUtils for utf convertor #8613: minor corrections and a bug fix before discussion

* StringUtils for utf convertor #8613: some fixes and tests

* StringUtils for utf convertor #8613: some more work to support the different unicode encodings

Signed-off-by: Oleg

* StringUtils for utf convertor #8613: fixed a linking bug

* StringUtils for utf convertor #8613: corrected several tests, as the defaults for string ndarrays changed

* StringUtils for utf convertor #8613: replaced some incorrect implementation, reverted some test changes; needs a sync before testing

* StringUtils for utf convertor #8613: fixed several things that were badly implemented yesterday; still needs optimization and testing (which first requires support for u32 and u16 buffer visualization)

* StringUtils for utf convertor #8613: fixed to support u16 and u32 and the convertor in ndarray; fixed buffer printing, etc

Signed-off-by: Oleg

* StringUtils for utf convertor #8613: merged master and synced with the server

Signed-off-by: Oleg

* StringUtils for utf convertor #8613: some corrections for the string cast; the print check still needs work, only ascii is supported

Signed-off-by: Oleg

* StringUtils for utf convertor #8613: merged master, removed copies and added cast; needs testing, refactoring according to review, and clean-up

* StringUtils for utf convertor #8613: fixed cast and copy issues

Signed-off-by: Oleg

* StringUtils for utf convertor #8613: fixed cuda and updated tests

* StringUtils for utf convertor #8613: integration into NDArray; fixed several tests to get the build passing; refactoring, etc

* avoid ambiguity of NDArray constructor overloading in some tests

Signed-off-by: Yurii

* StringUtils for utf convertor #8613: NDArray string constructors added, NDArrayFactory updated, unicode helpers and tests refactored, etc

Signed-off-by: Oleg

* StringUtils for utf convertor #8613: fixed the cuda build and tests; refactoring, and void* added to some functions

Signed-off-by: Oleg

* StringUtils for utf convertor #8613: void* integration, removed a copy operation, refactoring, added tests for the NDArray string constructors, etc

Signed-off-by: Oleg

* StringUtils for utf convertor #8613: several more fixes, improvements and updates

Signed-off-by: Oleg

* StringUtils for utf convertor #8613: master merge, code clean-up and optimization before review

Signed-off-by: Oleg

* StringUtils for utf convertor #8613: minor fixes to the string element size define

Signed-off-by: Oleg

* StringUtils for utf convertor #8613: reverted the last changes, committed by mistake

Signed-off-by: Oleg

* StringUtils for utf convertor #8613: fixed an NDArray constructor build problem, removed order from the string factory, fixed order use for the factory across the project, added a catch for incorrect sync when casting arrays to data types, fixed the e method for strings, etc

Signed-off-by: Oleg

* StringUtils for utf convertor #8613: windows build fix, as "string" was not being treated as utf8

Signed-off-by: Oleg

Co-authored-by: Yurii Shyrma
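
In short, the patch below gives NDArray native UTF-8/UTF-16/UTF-32 string storage plus converters between the encodings. A minimal sketch of what that enables, using only constructors and methods introduced in this diff (encoding choice and values illustrative):

    // scalar string; the const char* constructor defaults to UTF8
    NDArray utf8("alpha");
    // cast() now dispatches to asS<...>() for string types, re-encoding the buffer
    auto utf16 = utf8.cast(nd4j::DataType::UTF16);
    // e<std::string>() converts a UTF16 (or UTF32) element back to a utf8 std::string
    auto roundTrip = utf16.e<std::string>(0);   // "alpha"
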
---
 libnd4j/blas/NDArray.h                        |  56 +-
 libnd4j/blas/NDArray.hpp                      | 905 +++++++++++++++++-
 libnd4j/blas/NDArrayFactory.h                 |  77 +-
 libnd4j/blas/cpu/NDArrayFactory.cpp           | 300 +++---
 libnd4j/include/array/DataTypeUtils.h         |  12 +-
 libnd4j/include/graph/impl/FlatUtils.cpp      |   9 +-
 libnd4j/include/helpers/ShapeUtils.h          |   5 +-
 libnd4j/include/helpers/StringUtils.h         |  52 +
 libnd4j/include/helpers/impl/ShapeUtils.cpp   |   9 -
 libnd4j/include/helpers/impl/StringUtils.cpp  |  94 +-
 libnd4j/include/helpers/impl/unicode.cpp      | 456 +++++++++
 libnd4j/include/helpers/unicode.h             | 189 ++++
 .../declarable/generic/broadcastable/pow.cpp  |   2 +-
 .../generic/compat/compat_string_split.cpp    |   2 +-
 libnd4j/include/types/types.h                 |   5 +
 .../layers_tests/BroadcastableOpsTests.cpp    |   2 +-
 .../layers_tests/ConvolutionTests2.cpp        |   2 +-
 .../layers_tests/CudaBasicsTests1.cu          |  74 +-
 .../layers_tests/DeclarableOpsTests10.cpp     |  10 +-
 .../layers_tests/DeclarableOpsTests11.cpp     |  52 +-
 .../layers_tests/DeclarableOpsTests12.cpp     |  36 +-
 .../layers_tests/DeclarableOpsTests13.cpp     |  16 +-
 .../layers_tests/DeclarableOpsTests15.cpp     |   4 +-
 .../layers_tests/DeclarableOpsTests17.cpp     |   8 +-
 .../layers_tests/DeclarableOpsTests2.cpp      |   8 +-
 .../layers_tests/DeclarableOpsTests5.cpp      |   6 +-
 .../layers_tests/DeclarableOpsTests6.cpp      |   4 +-
 .../tests_cpu/layers_tests/FlatUtilsTests.cpp |   2 +-
 .../layers_tests/JavaInteropTests.cpp         |   8 +-
 .../layers_tests/MultiDataTypeTests.cpp       | 156 +--
 .../layers_tests/NDArrayConstructorsTests.cu  |   2 +-
 .../layers_tests/NDArrayCudaBasicsTests.cu    |  40 +-
 .../tests_cpu/layers_tests/ParityOpsTests.cpp |  24 +-
 libnd4j/tests_cpu/layers_tests/RNGTests.cpp   |  53 +-
 .../tests_cpu/layers_tests/StringTests.cpp    | 776 ++++++++++++++-
 35 files changed, 2992 insertions(+), 464 deletions(-)
 create mode 100644 libnd4j/include/helpers/impl/unicode.cpp
 create mode 100644 libnd4j/include/helpers/unicode.h

diff --git a/libnd4j/blas/NDArray.h b/libnd4j/blas/NDArray.h
index 671f72a57..fe5f90bc3 100644
--- a/libnd4j/blas/NDArray.h
+++ b/libnd4j/blas/NDArray.h
@@ -195,6 +195,56 @@ namespace nd4j {
         NDArray(std::shared_ptr<DataBuffer> buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
 
+        /**
+         * These constructors create a scalar array containing a utf8 string
+         *
+         */
+        NDArray(const char* str, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext())
+            : NDArray(std::string(str), dtype, context) {
+        }
+        NDArray(const std::string& string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+
+        /**
+         * These constructors create a scalar array containing a utf16 string
+         *
+         */
+        NDArray(const char16_t* u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext())
+            : NDArray(std::u16string(u16string), dtype, context) {
+        }
+
+        NDArray(const std::u16string& u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+
+        /**
+         * These constructors create a scalar array containing a utf32 string
+         *
+         */
+        NDArray(const char32_t* u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext())
+            : NDArray(std::u32string(u32string), dtype, context) {
+        }
+
+        NDArray(const std::u32string& u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+
+        /**
+         * These constructors create an array from a vector of utf8 strings
+         *
+         */
+        NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::string>& string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+
+        /**
+         * These constructors create an array from a vector of utf16 strings
+         *
+         */
+        NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+        NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context =
nd4j::LaunchContext::defaultContext()); + + /** + * This contructors create array from vector of utf32 strings + * + */ + NDArray(const std::vector& shape, const std::vector& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + NDArray(const std::vector& shape, const std::vector& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + #endif /** @@ -250,7 +300,6 @@ namespace nd4j { */ NDArray(void *buffer, const char order, const std::vector &shape, nd4j::DataType dtype, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const bool isBuffAlloc = false); - /** * This method returns new array with the same shape & data type * @return @@ -1148,6 +1197,9 @@ namespace nd4j { template NDArray asT() const; + template + NDArray asS() const; + NDArray asT(DataType dtype) const; @@ -1441,7 +1493,7 @@ namespace nd4j { * @return */ bool isS() const; - + template std::vector asVectorT(); diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index 42b29cf78..79137ac3a 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -34,11 +34,15 @@ template <> ND4J_EXPORT utf8string NDArray::e(const Nd4jLong i) const; template <> ND4J_EXPORT std::string NDArray::e(const Nd4jLong i) const; +template <> +ND4J_EXPORT std::u16string NDArray::e(const Nd4jLong i) const; +template <> +ND4J_EXPORT std::u32string NDArray::e(const Nd4jLong i) const; //////////////////////////////////////////////////////////////////////// // copy constructor NDArray::NDArray(const NDArray& other) { - + _context = other._context; _offset = 0; @@ -293,6 +297,560 @@ NDArray::NDArray(std::shared_ptr buffer, const char order, const std _isView = _length * DataTypeUtils::sizeOf(_dataType) < buffer->getLenInBytes(); } +///////////////////////////////////////////////////////////////////////// +// u16 string constructors +NDArray::NDArray(const std::u16string& u16string, nd4j::DataType dtype, nd4j::LaunchContext* context) { + + if (!DataTypeUtils::isS(dtype)) { + throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); + } + + if (!unicode::isStringValidU16(u16string.data(), u16string.data() + u16string.size())) { + throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); + } + + // one word that is why used 1 + Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(1); + + Nd4jLong dataLength = [&] { + if (dtype == DataType::UTF16) { + return static_cast(u16string.size() * sizeof(uint16_t)); + } + if (dtype == DataType::UTF32) { + return unicode::offsetUtf16StringInUtf32(u16string.data(), u16string.size()); + } + return unicode::offsetUtf16StringInUtf8(u16string.data(), u16string.size()); + }(); + + Nd4jLong offsets[2] = { 0 , dataLength }; + + _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); + + _context = context; + _isAttached = getContext()->getWorkspace() != nullptr; + _offset = 0; + + setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); + + memcpy(bufferAsT(), &offsets[0], 2 * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + if (dtype == DataType::UTF8) { + unicode::utf16to8(u16string.data(), data, u16string.size()); + } + else if (dtype == DataType::UTF16) { + memcpy(data, u16string.data(), dataLength); + } + else { + unicode::utf16to32(u16string.data(), data, u16string.size()); + } + + 
tickWriteHost(); + syncToDevice(); +} + +///////////////////////////////////////////////////////////////////////// +// u32 string constructors +NDArray::NDArray(const std::u32string& u32string, nd4j::DataType dtype, nd4j::LaunchContext* context) { + + if (!DataTypeUtils::isS(dtype)) { + throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); + } + + if (!unicode::isStringValidU32(u32string.data(), u32string.data() + u32string.size())) { + throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); + } + // one word that is why used 1 + Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(1); + + Nd4jLong dataLength = [&] { + if (dtype == DataType::UTF16) { + return unicode::offsetUtf32StringInUtf16(u32string.data(), u32string.size()); + } + if (dtype == DataType::UTF32) { + return static_cast(sizeof(uint32_t) * u32string.size()); + } + return unicode::offsetUtf32StringInUtf8(u32string.data(), u32string.size()); + }(); + + Nd4jLong offsets[2] = { 0 , dataLength }; + + _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); + + _context = context; + _isAttached = getContext()->getWorkspace() != nullptr; + _offset = 0; + + setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); + + memcpy(bufferAsT(), &offsets[0], 2 * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + if (dtype == DataType::UTF8) { + unicode::utf32to8(u32string.data(), data, u32string.size()); + } + else if (dtype == DataType::UTF16) { + unicode::utf32to16(u32string.data(), data, u32string.size()); + } + else { + memcpy(data, u32string.data(), u32string.size() * sizeof(uint32_t)); + } + + tickWriteHost(); + syncToDevice(); +} + +///////////////////////////////////////////////////////////////////////// +// u8 string constructors +///////////////////////////////////////////////////////////////////////// +NDArray::NDArray(const std::string& str, nd4j::DataType dtype, nd4j::LaunchContext* context) { + + if (!DataTypeUtils::isS(dtype)) { + throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); + } + + if (!unicode::isStringValidU8(str.data(), str.data() + str.size())) { + throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); + } + + // one word that is why used 1 + auto headerLength = ShapeUtils::stringBufferHeaderRequirements(1); + + Nd4jLong dataLength = [&] { + if (dtype == DataType::UTF16) { + return unicode::offsetUtf8StringInUtf16(str.data(), str.size()); + } + if (dtype == DataType::UTF32) { + return unicode::offsetUtf8StringInUtf32(str.data(), str.size()); + } + return static_cast(str.size()); + }(); + + Nd4jLong offsets[2] = { 0 , dataLength }; + + _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); + + _context = context; + _isAttached = getContext()->getWorkspace() != nullptr; + _offset = 0; + + setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); + + memcpy(bufferAsT(), &offsets[0], 2 * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + + if (dtype == DataType::UTF8) { + memcpy(data, str.data(), str.size()); + } + else if (dtype == DataType::UTF16) { + unicode::utf8to16(str.data(), data, str.size()); + } + else { + unicode::utf8to32(str.data(), data, str.size()); + } + + tickWriteHost(); + syncToDevice(); +} +///////////////////////////////////////////////////////////////////////// +// constructors for vector 
of strings +NDArray::NDArray(const std::vector& shape, const std::vector& string, const nd4j::DataType dataType, nd4j::LaunchContext* context) { + + if (!DataTypeUtils::isS(dataType)) + throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); + + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); + + for (const auto& str : string) { + if (!unicode::isStringValidU8(str, str + std::char_traits::length(str)) ) { + throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); + } + } + + Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); + + std::vector offsets(string.size() + 1); + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dataType == DataType::UTF16) + return unicode::offsetUtf8StringInUtf16(string[e], std::char_traits::length(string[e])); + if (dataType == DataType::UTF32) + return unicode::offsetUtf8StringInUtf32(string[e], std::char_traits::length(string[e])); + return static_cast(std::char_traits::length(string[e])); + }(); + } + offsets[string.size()] = dataLength; + + _buffer = std::make_shared(headerLength + dataLength, dataType, context->getWorkspace(), true); + + _context = context; + _offset = 0; + + setShapeInfo(ShapeDescriptor(dataType, 'c', shape)); + + _isView = false; + + setAttached(context->getWorkspace() != nullptr); + + memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + + auto func = PRAGMA_THREADS_FOR{ + for (auto e = start; e < stop; e += increment) { + auto cdata = data + offsets[e]; + if (dataType == DataType::UTF16) { + unicode::utf8to16(string[e], cdata, std::char_traits::length(string[e])); + } + else if (dataType == DataType::UTF32) { + unicode::utf8to32(string[e], cdata, std::char_traits::length(string[e])); + } + else { + memcpy(cdata, string[e], std::char_traits::length(string[e])); + } + } + }; + + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + tickWriteHost(); + syncToDevice(); +} +///////////////////////////////////////////////////////////////////////// +NDArray::NDArray(const std::vector& shape, const std::vector& string, const nd4j::DataType dataType, nd4j::LaunchContext* context) { + + if (!DataTypeUtils::isS(dataType)) + throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); + + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); + + for (const auto& str : string) { + if (!unicode::isStringValidU8(str.data(), str.data() + str.size())) { + throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); + } + } + + Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); + + std::vector offsets(string.size() + 1); + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dataType == DataType::UTF16) + return unicode::offsetUtf8StringInUtf16(string[e].data(), string[e].size()); + if (dataType == DataType::UTF32) + return unicode::offsetUtf8StringInUtf32(string[e].data(), string[e].size()); + return static_cast(string[e].size()); + }(); + } + + offsets[string.size()] = dataLength; + + _buffer = 
std::make_shared(headerLength + dataLength, dataType, context->getWorkspace(), true); + + _context = context; + _offset = 0; + + setShapeInfo(ShapeDescriptor(dataType, 'c', shape)); + + _isView = false; + + setAttached(context->getWorkspace() != nullptr); + + memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + + auto func = PRAGMA_THREADS_FOR{ + for (auto e = start; e < stop; e += increment) { + auto cdata = data + offsets[e]; + if (dataType == DataType::UTF16) { + unicode::utf8to16(string[e].data(), cdata, string[e].size()); + } + else if (dataType == DataType::UTF32) { + unicode::utf8to32(string[e].data(), cdata, string[e].size()); + } + else { + memcpy(cdata, string[e].data(), string[e].size()); + } + } + }; + + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + + tickWriteHost(); + syncToDevice(); +} +///////////////////////////////////////////////////////////////////////// +NDArray::NDArray(const std::vector& shape, const std::vector& string, nd4j::DataType dtype, nd4j::LaunchContext* context) { + + if (!DataTypeUtils::isS(dtype)) + throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); + + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); + + for (const auto& str : string) { + if (!unicode::isStringValidU16(str.data(), str.data() + str.size())) { + throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); + } + } + + Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); + + std::vector offsets(string.size() + 1); + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dtype == DataType::UTF16) + return static_cast(sizeof(uint16_t) * string[e].size()); + if (dtype == DataType::UTF32) + return unicode::offsetUtf16StringInUtf32(string[e].data(), string[e].size()); + return unicode::offsetUtf16StringInUtf8(string[e].data(), string[e].size()); + }(); + } + offsets[string.size()] = dataLength; + + _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); + + _context = context; + _offset = 0; + + setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); + + _isView = false; + + setAttached(context->getWorkspace() != nullptr); + + memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + + auto func = PRAGMA_THREADS_FOR{ + for (auto e = start; e < stop; e += increment) { + auto cdata = data + offsets[e]; + if (dtype == DataType::UTF16) { + memcpy(cdata, string[e].data(), string[e].size() * sizeof(uint16_t)); + } + else if (dtype == DataType::UTF32) { + unicode::utf16to32(string[e].data(), cdata, string[e].size()); + } + else { + unicode::utf16to8(string[e].data(), cdata, string[e].size()); + } + } + }; + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + tickWriteHost(); + syncToDevice(); +} +///////////////////////////////////////////////////////////////////////// +NDArray::NDArray(const std::vector& shape, const std::vector& string, nd4j::DataType dtype, nd4j::LaunchContext* context) { + + if (!DataTypeUtils::isS(dtype)) + throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); + + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + 
throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); + + for (const auto& str : string) { + if (!unicode::isStringValidU16(str, str + std::char_traits::length(str))) { + throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); + } + } + + Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); + + std::vector offsets(string.size() + 1); + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dtype == DataType::UTF16) + return static_cast(sizeof(uint16_t) * std::char_traits::length(string[e])); + if (dtype == DataType::UTF32) + return unicode::offsetUtf16StringInUtf32(string[e], std::char_traits::length(string[e])); + return unicode::offsetUtf16StringInUtf8(string[e], std::char_traits::length(string[e])); + }(); + } + offsets[string.size()] = dataLength; + + _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); + + _context = context; + _offset = 0; + + setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); + + _isView = false; + + setAttached(context->getWorkspace() != nullptr); + + memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + + + auto func = PRAGMA_THREADS_FOR{ + for (auto e = start; e < stop; e += increment) { + auto cdata = data + offsets[e]; + if (dtype == DataType::UTF16) { + memcpy(cdata, string[e], std::char_traits::length(string[e]) * sizeof(uint16_t)); + } + else if (dtype == DataType::UTF32) { + unicode::utf16to32(string[e], cdata, std::char_traits::length(string[e])); + } + else { + unicode::utf16to8(string[e], cdata, std::char_traits::length(string[e])); + } + } + }; + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + tickWriteHost(); + syncToDevice(); +} +///////////////////////////////////////////////////////////////////////// +NDArray::NDArray(const std::vector& shape, const std::vector& string, nd4j::DataType dtype, nd4j::LaunchContext* context) { + + if (!DataTypeUtils::isS(dtype)) + throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); + + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); + + for (auto str : string) { + if (!unicode::isStringValidU32(str.data(), str.data() + str.size())) { + throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); + } + } + + Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); + + std::vector offsets(string.size() + 1); + + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dtype == DataType::UTF16) + return unicode::offsetUtf32StringInUtf16(string[e].data(), string[e].size()); + if (dtype == DataType::UTF32) + return static_cast(sizeof(uint32_t) * string[e].size()); + return unicode::offsetUtf32StringInUtf16(string[e].data(), string[e].size()); + }(); + } + offsets[string.size()] = dataLength; + + _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); + + _context = context; + _offset = 0; + + setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); + + _isView = false; + + setAttached(context->getWorkspace() != nullptr); + + memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); + + auto data = 
reinterpret_cast(bufferAsT() + headerLength); + + auto func = PRAGMA_THREADS_FOR{ + for (auto e = start; e < stop; e += increment) { + auto cdata = data + offsets[e]; + if (dtype == DataType::UTF16) { + unicode::utf32to16(string[e].data(), cdata, string[e].size()); + } + else if (dtype == DataType::UTF32) { + memcpy(cdata, string[e].data(), string[e].size() * sizeof(uint32_t)); + } + else { + unicode::utf32to8(string[e].data(), cdata, string[e].size()); + } + } + }; + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + tickWriteHost(); + syncToDevice(); +} +///////////////////////////////////////////////////////////////////////// +NDArray::NDArray(const std::vector& shape, const std::vector& string, nd4j::DataType dtype, nd4j::LaunchContext* context) { + + if (!DataTypeUtils::isS(dtype)) + throw std::invalid_argument("NDArray::NDArray: invalid DataType used"); + + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); + + for (const auto& str : string) { + if (!unicode::isStringValidU32(str, str + std::char_traits::length(str))) { + throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); + } + } + + Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); + + std::vector offsets(string.size() + 1); + + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dtype == DataType::UTF16) + return unicode::offsetUtf32StringInUtf16(string[e], std::char_traits::length(string[e])); + if (dtype == DataType::UTF32) + return static_cast(sizeof(uint32_t) * std::char_traits::length(string[e])); + return unicode::offsetUtf32StringInUtf16(string[e], std::char_traits::length(string[e])); + }(); + } + offsets[string.size()] = dataLength; + + _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); + + _context = context; + _offset = 0; + + setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); + + _isView = _length * DataTypeUtils::sizeOf(_dataType) < _buffer->getLenInBytes(); + + setAttached(context->getWorkspace() != nullptr); + + memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + + auto func = PRAGMA_THREADS_FOR{ + for (auto e = start; e < stop; e += increment) { + auto cdata = data + offsets[e]; + if (dtype == DataType::UTF16) { + unicode::utf32to16(string[e], cdata, std::char_traits::length(string[e])); + } + else if (dtype == DataType::UTF32) { + memcpy(cdata, string[e], std::char_traits::length(string[e]) * sizeof(uint32_t)); + } + else { + unicode::utf32to8(string[e], cdata, std::char_traits::length(string[e])); + } + } + }; + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + tickWriteHost(); + syncToDevice(); +} //////////////////////////////////////////////////////////////////////// // assignment operator @@ -329,7 +887,9 @@ bool NDArray::isC() const { ////////////////////////////////////////////////////////////////////////// bool NDArray::isS() const { - return dataType() == DataType::UTF8; + return (dataType() == DataType::UTF8 || + dataType() == DataType::UTF16 || + dataType() == DataType::UTF32); } ////////////////////////////////////////////////////////////////////////// @@ -407,7 +967,7 @@ std::string NDArray::asString(Nd4jLong limit) { os << toStringValue(this->e(e)); else if (this->isB()) os << toStringValue(this->e(e)); - else 
if (this->isS()) + else if (this->isS()) // todo add utf16 and utf32 os << this->e(e); if (e < limit - 1) os << ", "; @@ -477,8 +1037,6 @@ std::vector NDArray::getShapeInfoAsVector() { //////////////////////////////////////////////////////////////////////// std::vector NDArray::asByteVector() { - - if (isS()) { // string data type requires special treatment syncToHost(); @@ -1066,8 +1624,17 @@ void NDArray::printBuffer(const char* msg, Nd4jLong limit, const bool sync) cons if (e < limit - 1) printf(", "); } - } + } else if (this->isS()) { + // todo do we need this print offsets + /* + for (Nd4jLong e = 0; e < limit; e++) { + printf("\"%lld\"", this->getOffset(e)); + if (e < limit - 1) + printf(", "); + } + printf("]\n["); + */ for (Nd4jLong e = 0; e < limit; e++) { printf("\"%s\"", this->e(e).c_str()); if (e < limit - 1) @@ -1123,8 +1690,9 @@ static void printFormatted(NDArray const* arr, int depth, int limit) { printf("%lld, ", arr->e(i)); else if (arr->isB()) printf("%s, ", arr->e(i)?"true":"false"); - else if (arr->isS()) + else if (arr->isS()) { printf("\"%s\", ", arr->e(i).c_str()); + } } printf("]\n"); } @@ -1149,8 +1717,9 @@ static void printFormatted(NDArray const* arr, int depth, int limit) { printf("%lld", arr->e(row, col)); else if (arr->isB()) printf("%s", arr->e(row, col)?"true":"false"); - else if (arr->isS()) + else if (arr->isS()) { printf("\"%s\"", arr->e(row * cols + col).c_str()); + } } if (row < rows - 1) printf("]\n"); @@ -1204,6 +1773,8 @@ void NDArray::printIndexedBuffer(const char* msg, Nd4jLong limit) const { printf("%s\n", this->e(0)?"true":"false"); } else if (this->isS()) { + // todo do we need this + // printf("\"%lld\"\n", this->getOffset(e)); printf("\"%s\"\n", this->e(0).c_str()); } } @@ -1708,9 +2279,8 @@ NDArray NDArray::subarray(const Intervals& idx) const { template NDArray NDArray::asT() const{ - auto result = isScalar() ? NDArray('c', {}, {0.}, DataTypeUtils::fromT(), this->getContext()) : NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT(), this->getContext()); - auto l = this->lengthOf(); - + auto result = isScalar() ? 
NDArray('c', {}, std::vector<double>{0.}, DataTypeUtils::fromT<T>(), this->getContext()) : NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT<T>(), this->getContext());
+
     NDArray::prepareSpecialUse({&result}, {this});
     NativeOpExecutioner::execTransformAny(getContext(), transform::AnyOps::Assign, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.getSpecialBuffer(), result.getSpecialShapeInfo(), nullptr, nullptr, nullptr);
     NDArray::registerSpecialUse({&result}, {this});
@@ -1719,20 +2289,145 @@ NDArray NDArray::asT() const{
 }
 BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArray::asT, () const, LIBND4J_TYPES);

+//////////////////////////////////////////////////////////////////////////
+template <typename T>
+NDArray NDArray::asS() const {
+
+    if (!isS())
+        throw std::runtime_error("NDArray::asS: you can use this method only for String arrays!");
+
+    auto dtype = DataTypeUtils::fromT<T>();
+
+    if (!(DataTypeUtils::isS(dtype)))
+        throw std::invalid_argument("NDArray::asS: invalid DataType used");
+
+    if (dtype == dataType()) {
+
+        // same encoding requested: plain copy of the offsets header and the character data
+        Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf());
+        const auto nInputoffsets = bufferAsT<Nd4jLong>();
+        std::shared_ptr<DataBuffer> pBuffer = std::make_shared<DataBuffer>(offsetsLength + nInputoffsets[lengthOf()], dtype, getContext()->getWorkspace(), true);
+
+        NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), getContext());
+        res.setAttached(getContext()->getWorkspace() != nullptr);
+
+        preparePrimaryUse({ &res }, { this });
+        memcpy(res.bufferAsT<int8_t>(), nInputoffsets, offsetsLength);
+        auto data = res.bufferAsT<int8_t>() + offsetsLength;
+        const auto inData = bufferAsT<int8_t>() + offsetsLength;
+        memcpy(data, inData, nInputoffsets[lengthOf()]);
+
+        registerPrimaryUse({ &res }, { this });
+        return res;
+    }
+
+    Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf());
+
+    std::vector<Nd4jLong> offsets(lengthOf() + 1);
+
+    // the header stores cumulative byte offsets, so the length of element e is nInputoffsets[e + 1] - nInputoffsets[e]
+    const auto nInputoffsets = bufferAsT<Nd4jLong>();
+
+    Nd4jLong start = 0, stop = 0;
+    Nd4jLong dataLength = 0;
+
+    auto data = bufferAsT<int8_t>() + offsetsLength;
+    for (int e = 0; e < lengthOf(); e++) {
+        offsets[e] = dataLength;
+        start = nInputoffsets[e];
+        stop = nInputoffsets[e + 1];
+        if (dataType() == DataType::UTF8) {
+            dataLength += (dtype == DataType::UTF16) ? unicode::offsetUtf8StringInUtf16(data + start, stop - start)
+                                                     : unicode::offsetUtf8StringInUtf32(data + start, stop - start);
+        }
+        else if (dataType() == DataType::UTF16) {
+            dataLength += (dtype == DataType::UTF32) ? unicode::offsetUtf16StringInUtf32(data + start, ((stop - start) / sizeof(char16_t)))
+                                                     : unicode::offsetUtf16StringInUtf8(data + start, ((stop - start) / sizeof(char16_t)));
+        }
+        else {
+            dataLength += (dtype == DataType::UTF16) ? unicode::offsetUtf32StringInUtf16(data + start, ((stop - start) / sizeof(char32_t)))
+                                                     : unicode::offsetUtf32StringInUtf8(data + start, ((stop - start) / sizeof(char32_t)));
+        }
+    }
+    offsets[lengthOf()] = dataLength;
+
+    std::shared_ptr<DataBuffer> pBuffer = std::make_shared<DataBuffer>(offsetsLength + dataLength, dtype, getContext()->getWorkspace(), true);
+
+    NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), getContext());
+    res.setAttached(getContext()->getWorkspace() != nullptr);
+
+    preparePrimaryUse({ &res }, { this });
+
+    memcpy(res.bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
+
+    auto outData = res.bufferAsT<int8_t>() + offsetsLength;
+    const auto inData = bufferAsT<int8_t>() + offsetsLength;
+
+    auto func = PRAGMA_THREADS_FOR {
+        for (int e = start; e < stop; e += increment) {
+            // per-element output position and input slice length
+            auto cdata = outData + offsets[e];
+            auto idata = inData + nInputoffsets[e];
+            auto len = nInputoffsets[e + 1] - nInputoffsets[e];
+            if (dtype == DataType::UTF16) {
+                if (dataType() == DataType::UTF8) {
+                    unicode::utf8to16(idata, cdata, len);
+                }
+                else {
+                    unicode::utf32to16(idata, cdata, (len / sizeof(char32_t)));
+                }
+            }
+            else if (dtype == DataType::UTF32) {
+                if (dataType() == DataType::UTF8) {
+                    unicode::utf8to32(idata, cdata, len);
+                }
+                else {
+                    unicode::utf16to32(idata, cdata, (len / sizeof(char16_t)));
+                }
+            }
+            else {
+                if (dataType() == DataType::UTF16) {
+                    unicode::utf16to8(idata, cdata, (len / sizeof(char16_t)));
+                }
+                else {
+                    unicode::utf32to8(idata, cdata, (len / sizeof(char32_t)));
+                }
+            }
+        }
+    };
+
+    samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
+
+    registerPrimaryUse({ &res }, { this });
+
+    return res;
+}
+BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArray::asS, () const, LIBND4J_STRINGTYPES);
+
 ////////////////////////////////////////////////////////////////////////
 NDArray NDArray::asT(DataType dtype) const {
-    if (isS())
-        throw std::runtime_error("NDArray::asT: you can't use this method on String array!");
+
+    if (isS() && !DataTypeUtils::isS(dtype))
+        throw std::runtime_error("NDArray::asT: you can't use this method on a String array with a non-string DataType!");

-    BUILD_SINGLE_SELECTOR(dtype, return asT, (), LIBND4J_TYPES);
+    if (!isS() && DataTypeUtils::isS(dtype))
+        throw std::runtime_error("NDArray::asT: you can't use this method on a non-String array with a string DataType!");
+
+    if (isS()) {
+        BUILD_SINGLE_SELECTOR(dtype, return asS, (), LIBND4J_STRINGTYPES);
+    } else {
+        BUILD_SINGLE_SELECTOR(dtype, return asT, (), LIBND4J_TYPES);
+    }

     return NDArray();
 }

 ////////////////////////////////////////////////////////////////////////
 NDArray NDArray::cast(DataType dtype) const {
-    if (isS())
-        throw std::runtime_error("NDArray::cast: you can't use this method on String array!");
+
+    if (isS() && !DataTypeUtils::isS(dtype))
+        throw std::runtime_error("NDArray::cast: you can't use this method on a String array with a non-string DataType!");
+
+    if (!isS() && DataTypeUtils::isS(dtype))
+        throw std::runtime_error("NDArray::cast: you can't use this method on a non-String array with a string DataType!");
+
     return this->asT(dtype);
 }
@@ -2765,14 +3460,44 @@ NDArray NDArray::dup(const char newOrder) const {
    char order = newOrder == 'a' ?
ordering() : newOrder; // for now string arrays require special treatment - if (dataType() == DataType::UTF8) { + if (isS()) { + if (dataType() == DataType::UTF8) { + std::vector strings(lengthOf()); + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i += increment) { + strings[i] = std::move(this->e(i)); + } + }; - std::vector strings(lengthOf()); - for (int e = 0; e < lengthOf(); e++) - strings[e] = this->e(e); + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - auto result = NDArrayFactory::string(order, getShapeAsVector(), strings, getContext()); - return result; + return NDArray(getShapeAsVector(), strings, dataType(), getContext()); + } + if (dataType() == DataType::UTF16) { + std::vector strings(lengthOf()); + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i += increment) { + strings[i] = std::move(this->e(i)); + } + }; + + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + return NDArray(getShapeAsVector(), strings, dataType(), getContext()); + } + + std::vector strings(lengthOf()); + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i += increment) { + strings[i] = std::move(this->e(i)); + } + }; + + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + return NDArray(getShapeAsVector(), strings, dataType(), getContext()); } NDArray result(order, isScalar() ? std::vector({0}) : getShapeAsVector(), dataType(), getContext()); @@ -2796,12 +3521,33 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const { if (isS()) { // string is special case, we'll compare them one by one, considering both arrays are guaranteed to have the same length - for (int e = 0; e < this->lengthOf(); e++) { - auto s1 = this->e(e); - auto s2 = other->e(e); + + if (dataType() == DataType::UTF8) { + for (int e = 0; e < this->lengthOf(); e++) { + auto s1 = this->e(e); + auto s2 = other->e(e); - if (s1 != s2) - return false; + if (s1 != s2) + return false; + } + } + else if (dataType() == DataType::UTF16) { + for (int e = 0; e < this->lengthOf(); e++) { + auto s1 = this->e(e); + auto s2 = other->e(e); + + if (s1 != s2) + return false; + } + } + else { + for (int e = 0; e < this->lengthOf(); e++) { + auto s1 = this->e(e); + auto s2 = other->e(e); + + if (s1 != s2) + return false; + } } return true; @@ -2836,15 +3582,31 @@ std::string NDArray::e(const Nd4jLong i) const { if (!isS()) throw std::runtime_error("Can't get std::string out of non-string array"); + if (i == lengthOf()) + throw std::runtime_error("Can't get std::string for index out of range"); + + + if (this->dataType() == DataType::UTF16) { + auto u16 = this->e(i); + std::string s; + StringUtils::u16StringToU8String(u16, s); + return s; + } + + if (this->dataType() == DataType::UTF32) { + auto u32 = this->e(i); + std::string s; + StringUtils::u32StringToU8String(u32, s); + return s; + } + NDArray::preparePrimaryUse({}, {this}); - // getting "virtual" offset. 
it's not real though,since it doesn't take lengths into account - auto offset = getOffset(i); - auto offsets = reinterpret_cast(getBuffer()); + auto offsets = bufferAsT(); auto offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); - auto start = offsets[offset]; - auto end = offsets[offset + 1]; - auto data = static_cast(getBuffer()) + offsetsLength + start; + auto start = offsets[i]; + auto end = offsets[i + 1]; + auto data = bufferAsT() + offsetsLength + start; std::string r(reinterpret_cast(data), (end - start)); @@ -2853,6 +3615,83 @@ std::string NDArray::e(const Nd4jLong i) const { return r; } +template <> +std::u16string NDArray::e(const Nd4jLong i) const { + + if (!isS()) + throw std::runtime_error("Can't get std::u16string out of non-string array"); + + if(i == lengthOf()) + throw std::runtime_error("Can't get std::u16string for index out of range"); + + if (this->dataType() == DataType::UTF8) { + auto u = this->e(i); + std::u16string s; + StringUtils::u8StringToU16String(u, s); + return s; + } + + if (this->dataType() == DataType::UTF32) { + auto u32 = this->e(i); + std::u16string s; + StringUtils::u32StringToU16String(u32, s); + return s; + } + + NDArray::preparePrimaryUse({}, { this }); + + auto offsets = bufferAsT(); + Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); + Nd4jLong start = offsets[i]; + Nd4jLong end = offsets[i + 1]; + auto data = bufferAsT() + offsetsLength + start; + + std::u16string r(reinterpret_cast(data), (end - start) / sizeof(char16_t)); + + registerPrimaryUse({}, { this }); + + return r; +} + +template <> +std::u32string NDArray::e(const Nd4jLong i) const { + + if (!isS()) + throw std::runtime_error("Can't get std::u32string out of non-string array"); + + if (i == lengthOf()) + throw std::runtime_error("Can't get std::u32string for index out of range"); + + if (this->dataType() == DataType::UTF8) { + auto u = this->e(i); + std::u32string s; + StringUtils::u8StringToU32String(u, s); + return s; + } + + if (this->dataType() == DataType::UTF16) { + auto u16 = this->e(i); + std::u32string s; + StringUtils::u16StringToU32String(u16, s); + return s; + } + + NDArray::preparePrimaryUse({}, { this }); + + auto offsets = bufferAsT(); + Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); + Nd4jLong start = offsets[i]; + Nd4jLong end = offsets[i + 1]; + + auto data = bufferAsT() + offsetsLength + start; + + std::u32string r(reinterpret_cast(data), (end - start) / sizeof(char32_t)); + + registerPrimaryUse({}, { this }); + + return r; +} + ////////////////////////////////////////////////////////////////////////// template <> utf8string NDArray::e(const Nd4jLong i) const { diff --git a/libnd4j/blas/NDArrayFactory.h b/libnd4j/blas/NDArrayFactory.h index cdd8d9f9f..5e979f1d8 100644 --- a/libnd4j/blas/NDArrayFactory.h +++ b/libnd4j/blas/NDArrayFactory.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,6 +17,7 @@ // // Created by raver119 on 2018-09-16. 
+// @author Oleg Semeniv // #ifndef DEV_TESTS_NDARRAYFACTORY_H @@ -106,25 +108,72 @@ namespace nd4j { template static NDArray create(char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); - static NDArray string(const char *string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); + /** + * This factory create array from utf8 string + * @return NDArray default dataType UTF8 + */ + static NDArray string(const char *string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); + static NDArray* string_(const char *string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); + static NDArray* string_(const std::string &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); + static NDArray string(const std::string& string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); - static NDArray* string_(const char *string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); + /** + * This factory create array from utf16 string + * @return NDArray default dataType UTF16 + */ + static NDArray string(const char16_t* u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + static NDArray* string_(const char16_t* u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + static NDArray* string_(const std::u16string& u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + static NDArray string(const std::u16string& u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + + /** + * This factory create array from utf32 string + * @return NDArray default dataType UTF32 + */ + static NDArray string(const char32_t* u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + static NDArray* string_(const char32_t* u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + static NDArray* string_(const std::u32string& u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + static NDArray string(const std::u32string& u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); - static NDArray string(const std::string &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); + /** + * This factory create array from vector of utf8 strings + * @return NDArray default dataType UTF8 + */ + static NDArray string( const std::vector &shape, const std::initializer_list &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); + static NDArray string( const std::vector &shape, const std::initializer_list &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); + static NDArray string( const std::vector 
&shape, const std::vector &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); + static NDArray string( const std::vector &shape, const std::vector &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); + static NDArray* string_( const std::vector &shape, const std::initializer_list &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); + static NDArray* string_( const std::vector &shape, const std::initializer_list &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); + static NDArray* string_( const std::vector &shape, const std::vector &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); + static NDArray* string_( const std::vector &shape, const std::vector &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); - static NDArray* string_(const std::string &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); + /** + * This factory create array from vector of utf16 strings + * @return NDArray default dataType UTF16 + */ + static NDArray string( const std::vector& shape, const std::initializer_list& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + static NDArray string( const std::vector& shape, const std::initializer_list& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + static NDArray string( const std::vector& shape, const std::vector& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + static NDArray string( const std::vector& shape, const std::vector& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + static NDArray* string_( const std::vector& shape, const std::initializer_list& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + static NDArray* string_( const std::vector& shape, const std::initializer_list& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + static NDArray* string_( const std::vector& shape, const std::vector& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + static NDArray* string_( const std::vector& shape, const std::vector& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); - static NDArray string(char order, const std::vector &shape, const std::initializer_list &strings, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); - static NDArray string(char order, const std::vector &shape, const std::initializer_list &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); + /** + * This factory create array from vector of utf32 strings + * @return NDArray default dataType UTF32 + */ + static NDArray string( const std::vector& shape, const std::initializer_list& 
strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + static NDArray string( const std::vector& shape, const std::initializer_list& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + static NDArray string( const std::vector& shape, const std::vector& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + static NDArray string( const std::vector& shape, const std::vector& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + static NDArray* string_( const std::vector& shape, const std::initializer_list& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + static NDArray* string_( const std::vector& shape, const std::initializer_list& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + static NDArray* string_( const std::vector& shape, const std::vector& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); + static NDArray* string_( const std::vector& shape, const std::vector& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); - static NDArray string(char order, const std::vector &shape, const std::vector &strings, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); - static NDArray string(char order, const std::vector &shape, const std::vector &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); - - static NDArray* string_(char order, const std::vector &shape, const std::initializer_list &strings, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); - static NDArray* string_(char order, const std::vector &shape, const std::initializer_list &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); - - static NDArray* string_(char order, const std::vector &shape, const std::vector &strings, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); - static NDArray* string_(char order, const std::vector &shape, const std::vector &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); static ResultSet createSetOfArrs(const Nd4jLong numOfArrs, const void* buffer, const Nd4jLong* shapeInfo, const Nd4jLong* offsets, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); diff --git a/libnd4j/blas/cpu/NDArrayFactory.cpp b/libnd4j/blas/cpu/NDArrayFactory.cpp index 54cc6bba8..738dccdbe 100644 --- a/libnd4j/blas/cpu/NDArrayFactory.cpp +++ b/libnd4j/blas/cpu/NDArrayFactory.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,6 +17,7 @@ // // Created by GS on 2018-12-20. 
+// @author Oleg Semeniv // #include @@ -25,6 +27,9 @@ #include #include + +#include + namespace nd4j { //////////////////////////////////////////////////////////////////////// @@ -85,45 +90,6 @@ namespace nd4j { template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context); template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context); - NDArray NDArrayFactory::string(const char *str, nd4j::LaunchContext * context) { - std::string s(str); - return string(s, context); - } - - NDArray* NDArrayFactory::string_(const char *str, nd4j::LaunchContext * context) { - return string_(std::string(str), context); - } - - NDArray NDArrayFactory::string(const std::string &str, nd4j::LaunchContext * context) { - - auto headerLength = ShapeUtils::stringBufferHeaderRequirements(1); - - std::shared_ptr pBuffer = std::make_shared(headerLength + str.length(), DataType::UTF8, context->getWorkspace(), true); - - NDArray res(pBuffer, ShapeDescriptor::scalarDescriptor(DataType::UTF8), context); - - int8_t* buffer = reinterpret_cast(res.getBuffer()); - - auto offsets = reinterpret_cast(buffer); - offsets[0] = 0; - offsets[1] = str.length(); - - auto data = buffer + headerLength; - - memcpy(data, str.c_str(), str.length()); - - res.tickWriteHost(); - res.syncToDevice(); - - return res; - } - - NDArray* NDArrayFactory::string_(const std::string &str, nd4j::LaunchContext * context) { - auto res = new NDArray(); - *res = NDArrayFactory::string(str, context); - return res; - } - //////////////////////////////////////////////////////////////////////// template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, nd4j::LaunchContext * context) { @@ -551,91 +517,175 @@ template ND4J_EXPORT NDArray NDArrayFactory::create(uint8_t * buffer, const char template ND4J_EXPORT NDArray NDArrayFactory::create(int8_t* buffer, const char order, const std::initializer_list& shape, nd4j::LaunchContext * context); template ND4J_EXPORT NDArray NDArrayFactory::create(int16_t* buffer, const char order, const std::initializer_list& shape, nd4j::LaunchContext * context); - - NDArray NDArrayFactory::string(char order, const std::vector &shape, const std::initializer_list &strings, nd4j::LaunchContext * context) { - std::vector vec(strings); - return NDArrayFactory::string(order, shape, vec, context); - } - - NDArray NDArrayFactory::string(char order, const std::vector &shape, const std::vector &strings, nd4j::LaunchContext * context) { - std::vector vec(strings.size()); - int cnt = 0; - for (auto s:strings) - vec[cnt++] = std::string(s); - - return NDArrayFactory::string(order, shape, vec, context); - } - - - NDArray NDArrayFactory::string(char order, const std::vector &shape, const std::initializer_list &string, nd4j::LaunchContext * context) { - std::vector vec(string); - return NDArrayFactory::string(order, shape, vec, context); - } - - NDArray* NDArrayFactory::string_(char order, const std::vector &shape, const std::initializer_list &strings, nd4j::LaunchContext * context) { - std::vector vec(strings); - return NDArrayFactory::string_(order, shape, vec, context); - } - - NDArray* NDArrayFactory::string_(char order, const std::vector &shape, const std::vector &strings, nd4j::LaunchContext * context) { - std::vector vec(strings.size()); - int cnt = 0; - for (auto s:strings) - vec[cnt++] = std::string(s); - - return 
NDArrayFactory::string_(order, shape, vec, context); - } - - - NDArray* NDArrayFactory::string_(char order, const std::vector &shape, const std::initializer_list &string, nd4j::LaunchContext * context) { - std::vector vec(string); - return NDArrayFactory::string_(order, shape, vec, context); - } - - NDArray NDArrayFactory::string(char order, const std::vector &shape, const std::vector &string, nd4j::LaunchContext * context) { - - if (context == nullptr) - context = nd4j::LaunchContext ::defaultContext(); - - auto headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); - - std::vector offsets(string.size() + 1); - Nd4jLong dataLength = 0; - for (int e = 0; e < string.size(); e++) { - offsets[e] = dataLength; - dataLength += string[e].length(); - } - offsets[string.size()] = dataLength; - - std::shared_ptr pBuffer = std::make_shared(headerLength + dataLength, DataType::UTF8, context->getWorkspace(), true); - - NDArray res(pBuffer, ShapeDescriptor(DataType::UTF8, order, shape), context); - res.setAttached(context->getWorkspace() != nullptr); - - if (res.lengthOf() != string.size()) - throw std::invalid_argument("Number of strings should match length of array"); - - memcpy(res.buffer(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); - - auto data = static_cast(res.buffer()) + headerLength; - int resLen = res.lengthOf(); - for (int e = 0; e < resLen; e++) { - auto length = offsets[e+1] - offsets[e]; - auto cdata = data + offsets[e]; - memcpy(cdata, string[e].c_str(), string[e].length()); - } - - res.tickWriteHost(); - res.syncToDevice(); - - return res; - } - - NDArray* NDArrayFactory::string_(char order, const std::vector &shape, const std::vector &string, nd4j::LaunchContext * context) { - auto res = new NDArray(); - *res = NDArrayFactory::string(order, shape, string, context); - return res; - } - + ///////////////////////////////////////////////////////////////////////////////////// + NDArray NDArrayFactory::string(const char16_t* u16string, nd4j::DataType dtype, nd4j::LaunchContext* context) { + return NDArray(u16string, dtype, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray* NDArrayFactory::string_(const char16_t* u16string, nd4j::DataType dtype, nd4j::LaunchContext* context) { + return string_(std::u16string(u16string), dtype, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray* NDArrayFactory::string_(const std::u16string& u16string, nd4j::DataType dtype, nd4j::LaunchContext* context) { + auto res = new NDArray(); + *res = NDArray(u16string, dtype, context); + return res; + } + ///////////////////////////////////////////////////////////////////////// + NDArray NDArrayFactory::string(const std::u16string& u16string, nd4j::DataType dtype, nd4j::LaunchContext* context) { + return NDArray(u16string, dtype, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray NDArrayFactory::string(const char32_t* u32string, nd4j::DataType dtype, nd4j::LaunchContext* context) { + return NDArray(u32string, dtype, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray* NDArrayFactory::string_(const char32_t* u32string, nd4j::DataType dtype, nd4j::LaunchContext* context) { + return string_(std::u32string(u32string), dtype, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray* NDArrayFactory::string_(const std::u32string& u32string, nd4j::DataType dtype, 
nd4j::LaunchContext* context) { + auto res = new NDArray(); + *res = NDArray(u32string, dtype, context); + return res; + } + ///////////////////////////////////////////////////////////////////////// + NDArray NDArrayFactory::string(const std::u32string& u32string, nd4j::DataType dtype, nd4j::LaunchContext* context) { + return NDArray(u32string, dtype, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray NDArrayFactory::string(const char* str, nd4j::DataType dtype, nd4j::LaunchContext* context) { + return NDArray(str, dtype, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray* NDArrayFactory::string_(const char* str, nd4j::DataType dtype, nd4j::LaunchContext* context) { + return string_(std::string(str), dtype, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray* NDArrayFactory::string_(const std::string& str, nd4j::DataType dtype, nd4j::LaunchContext* context) { + auto res = new NDArray(); + *res = NDArray(str, dtype, context); + return res; + } + ///////////////////////////////////////////////////////////////////////// + NDArray NDArrayFactory::string(const std::string& str, nd4j::DataType dtype, nd4j::LaunchContext* context) { + return NDArray(str, dtype, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray NDArrayFactory::string(const std::vector &shape, const std::initializer_list &strings, nd4j::DataType dataType, nd4j::LaunchContext * context) { + return NDArray(shape, std::vector(strings), dataType, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray NDArrayFactory::string( const std::vector &shape, const std::vector &strings, nd4j::DataType dataType, nd4j::LaunchContext * context) { + return NDArray( shape, strings, dataType, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray NDArrayFactory::string( const std::vector &shape, const std::initializer_list &string, nd4j::DataType dataType, nd4j::LaunchContext * context) { + return NDArray( shape, std::vector(string), dataType, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray* NDArrayFactory::string_( const std::vector &shape, const std::initializer_list &strings, nd4j::DataType dataType, nd4j::LaunchContext * context) { + return NDArrayFactory::string_( shape, std::vector(strings), dataType, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray* NDArrayFactory::string_( const std::vector &shape, const std::vector &strings, nd4j::DataType dataType, nd4j::LaunchContext * context) { + std::vector vec(strings.size()); + int cnt = 0; + for (auto s:strings) + vec[cnt++] = std::string(s); + + return NDArrayFactory::string_( shape, vec, dataType, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray* NDArrayFactory::string_( const std::vector &shape, const std::initializer_list &string, nd4j::DataType dataType, nd4j::LaunchContext * context) { + return NDArrayFactory::string_( shape, std::vector(string), dataType, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray NDArrayFactory::string( const std::vector &shape, const std::vector &string, nd4j::DataType dataType, nd4j::LaunchContext * context) { + return NDArray(shape, string, dataType, context); + } + 
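To make the reworked factory surface concrete, here is a minimal usage sketch. It is a hypothetical driver, not part of the patch, and it assumes the overloads declared in NDArrayFactory.h above; note that the old order argument is gone, since string arrays are always 'c'-ordered now.

    // hypothetical usage sketch for the encoding-aware string factories
    #include <NDArrayFactory.h>

    int main() {
        // rank-1 array of two UTF-8 strings; dtype defaults to UTF8 for this overload
        auto u8 = nd4j::NDArrayFactory::string(std::vector<Nd4jLong>{2}, std::vector<std::string>{"alpha", "beta"});

        // the same content stored as UTF-16 code units
        auto u16 = nd4j::NDArrayFactory::string(std::vector<Nd4jLong>{2}, std::vector<std::u16string>{u"alpha", u"beta"}, nd4j::DataType::UTF16);

        // element reads convert between encodings as needed (see the NDArray::e specializations above)
        auto s  = u8.e<std::string>(0);      // "alpha"
        auto ws = u16.e<std::u16string>(1);  // u"beta"

        return (s == "alpha" && ws == u"beta") ? 0 : 1;
    }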
///////////////////////////////////////////////////////////////////////// + NDArray* NDArrayFactory::string_(const std::vector &shape, const std::vector &string, nd4j::DataType dataType, nd4j::LaunchContext * context) { + auto res = new NDArray(); + *res = NDArray( shape, string, dataType, context); + return res; + } + ///////////////////////////////////////////////////////////////////////// + NDArray NDArrayFactory::string(const std::vector& shape, const std::initializer_list& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) { + return NDArray( shape, std::vector(strings), dataType, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray NDArrayFactory::string( const std::vector& shape, const std::vector& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) { + return NDArray( shape, strings, dataType, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray NDArrayFactory::string( const std::vector& shape, const std::initializer_list& string, nd4j::DataType dataType, nd4j::LaunchContext* context) { + return NDArray( shape, std::vector(string), dataType, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray* NDArrayFactory::string_( const std::vector& shape, const std::initializer_list& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) { + return NDArrayFactory::string_( shape, std::vector(strings), dataType, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray* NDArrayFactory::string_( const std::vector& shape, const std::vector& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) { + std::vector vec(strings.size()); + int cnt = 0; + for (auto s : strings) + vec[cnt++] = std::u16string(s); + + return NDArrayFactory::string_( shape, vec, dataType, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray* NDArrayFactory::string_( const std::vector& shape, const std::initializer_list& string, nd4j::DataType dataType, nd4j::LaunchContext* context) { + return NDArrayFactory::string_( shape, std::vector(string), dataType, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray* NDArrayFactory::string_( const std::vector& shape, const std::vector& string, nd4j::DataType dataType, nd4j::LaunchContext* context) { + auto res = new NDArray(); + *res = NDArray( shape, string, dataType, context); + return res; + } + ///////////////////////////////////////////////////////////////////////// + NDArray NDArrayFactory::string( const std::vector& shape, const std::vector& string, nd4j::DataType dtype, nd4j::LaunchContext* context) { + return NDArray( shape, string, dtype, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray NDArrayFactory::string( const std::vector& shape, const std::initializer_list& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) { + return NDArray( shape, std::vector(strings), dataType, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray NDArrayFactory::string( const std::vector& shape, const std::vector& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) { + return NDArray( shape, strings, dataType, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray NDArrayFactory::string( const std::vector& shape, 
const std::initializer_list& string, nd4j::DataType dataType, nd4j::LaunchContext* context) { + return NDArray(shape, std::vector(string), dataType, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray* NDArrayFactory::string_( const std::vector& shape, const std::initializer_list& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) { + return NDArrayFactory::string_( shape, std::vector(strings), dataType, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray* NDArrayFactory::string_( const std::vector& shape, const std::vector& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) { + std::vector vec(strings.size()); + int cnt = 0; + for (auto s : strings) + vec[cnt++] = std::u32string(s); + return NDArrayFactory::string_( shape, vec, dataType, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray* NDArrayFactory::string_( const std::vector& shape, const std::initializer_list& string, nd4j::DataType dataType, nd4j::LaunchContext* context) { + return NDArrayFactory::string_( shape, std::vector(string), dataType, context); + } + ///////////////////////////////////////////////////////////////////////// + NDArray* NDArrayFactory::string_( const std::vector& shape, const std::vector& string, nd4j::DataType dataType, nd4j::LaunchContext* context) { + auto res = new NDArray(); + *res = NDArray( shape, string, dataType, context); + return res; + } + ///////////////////////////////////////////////////////////////////////// + NDArray NDArrayFactory::string(const std::vector& shape, const std::vector& string, nd4j::DataType dtype, nd4j::LaunchContext* context) { + return NDArray( shape, string, dtype, context); + } } diff --git a/libnd4j/include/array/DataTypeUtils.h b/libnd4j/include/array/DataTypeUtils.h index 5d17c28b0..c307ecd4e 100644 --- a/libnd4j/include/array/DataTypeUtils.h +++ b/libnd4j/include/array/DataTypeUtils.h @@ -122,7 +122,7 @@ namespace nd4j { } FORCEINLINE bool DataTypeUtils::isS(nd4j::DataType dataType) { - return dataType == nd4j::DataType::UTF8; + return dataType == nd4j::DataType::UTF8 || dataType == nd4j::DataType::UTF16 || dataType == nd4j::DataType::UTF32; } FORCEINLINE bool DataTypeUtils::isZ(nd4j::DataType dataType) { @@ -370,6 +370,10 @@ FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) { return std::string("UINT64"); case UTF8: return std::string("UTF8"); + case UTF16: + return std::string("UTF16"); + case UTF32: + return std::string("UTF32"); default: throw std::runtime_error("Unknown data type used"); } @@ -431,6 +435,8 @@ FORCEINLINE _CUDA_HD T DataTypeUtils::eps() { case nd4j::DataType::UINT16: return (size_t) 2; case nd4j::DataType::UTF8: + case nd4j::DataType::UTF16: + case nd4j::DataType::UTF32: case nd4j::DataType::INT32: case nd4j::DataType::UINT32: case nd4j::DataType::HALF2: @@ -455,6 +461,10 @@ FORCEINLINE _CUDA_HD T DataTypeUtils::eps() { return nd4j::DataType::BOOL; } else if (std::is_same::value) { return nd4j::DataType::UTF8; + } else if (std::is_same::value) { + return nd4j::DataType::UTF16; + } else if (std::is_same::value) { + return nd4j::DataType::UTF32; } else if (std::is_same::value) { return nd4j::DataType::FLOAT32; } else if (std::is_same::value) { diff --git a/libnd4j/include/graph/impl/FlatUtils.cpp b/libnd4j/include/graph/impl/FlatUtils.cpp index ec76cb4d2..f582220da 100644 --- a/libnd4j/include/graph/impl/FlatUtils.cpp +++ b/libnd4j/include/graph/impl/FlatUtils.cpp 
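Before the FlatUtils hunk below, a quick sanity sketch of what the DataTypeUtils changes above guarantee. This is a hypothetical check, not part of the patch; it only exercises isS(), fromT<>() and asString(), which the diff extends.

    // hypothetical trait-level checks for the new string data types
    #include <array/DataTypeUtils.h>
    #include <cassert>

    static void checkStringTraits() {
        using namespace nd4j;

        // isS() now recognizes all three string encodings
        assert(DataTypeUtils::isS(DataType::UTF8));
        assert(DataTypeUtils::isS(DataType::UTF16));
        assert(DataTypeUtils::isS(DataType::UTF32));
        assert(!DataTypeUtils::isS(DataType::INT32));

        // fromT<>() maps the wide character types onto the new enum values
        assert(DataTypeUtils::fromT<char16_t>() == DataType::UTF16);
        assert(DataTypeUtils::fromT<char32_t>() == DataType::UTF32);

        // asString() knows the new names
        assert(DataTypeUtils::asString(DataType::UTF16) == "UTF16");
        assert(DataTypeUtils::asString(DataType::UTF32) == "UTF32");
    }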
@@ -49,12 +49,11 @@ namespace nd4j { delete[] newShape; return NDArrayFactory::empty_(dtype, nullptr); } - + // TODO fix UTF16 and UTF32 if (dtype == UTF8) { bool isBe = BitwiseUtils::isBE(); bool canKeep = (isBe && flatArray->byteOrder() == nd4j::graph::ByteOrder_BE) || (!isBe && flatArray->byteOrder() == nd4j::graph::ByteOrder_LE); - auto order = shape::order(newShape); - + std::vector substrings(length); std::vector shapeVector(rank); for (int e = 0; e < rank; e++) @@ -88,8 +87,8 @@ namespace nd4j { delete[] offsets; delete[] newShape; - - return NDArrayFactory::string_(order, shapeVector, substrings); + // string order always 'c' + return NDArrayFactory::string_(shapeVector, substrings); } diff --git a/libnd4j/include/helpers/ShapeUtils.h b/libnd4j/include/helpers/ShapeUtils.h index 5f76c11b5..c99a0b0de 100644 --- a/libnd4j/include/helpers/ShapeUtils.h +++ b/libnd4j/include/helpers/ShapeUtils.h @@ -171,7 +171,10 @@ namespace nd4j { * @param numStrings * @return */ - static Nd4jLong stringBufferHeaderRequirements(Nd4jLong numStrings); + static FORCEINLINE Nd4jLong stringBufferHeaderRequirements(Nd4jLong numStrings) { + // we store +1 offset + return (numStrings + 1) * sizeof(Nd4jLong); + } /* * check whether arr1/arr2 is sub-array of arr2/arr1, diff --git a/libnd4j/include/helpers/StringUtils.h b/libnd4j/include/helpers/StringUtils.h index 2a562de4b..7a0e2a960 100644 --- a/libnd4j/include/helpers/StringUtils.h +++ b/libnd4j/include/helpers/StringUtils.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,6 +17,7 @@ // // Created by raver119 on 20/04/18. 
+// @author Oleg Semeniv
+//

 #ifndef LIBND4J_STRINGUTILS_H
@@ -27,6 +29,7 @@
 #include
 #include
 #include
+#include

 namespace nd4j {
     class ND4J_EXPORT StringUtils {
@@ -85,6 +88,55 @@ namespace nd4j {
          * @return
          */
         static std::vector<std::string> split(const std::string &haystack, const std::string &delimiter);
+
+        /**
+         * This method converts a u8 string to a u16 string
+         * @param u8 const reference to the input string
+         * @param u16 reference to the output u16string
+         * @return boolean status
+         */
+        static bool u8StringToU16String(const std::string& u8, std::u16string& u16);
+
+        /**
+         * This method converts a u8 string to a u32 string
+         * @param u8 const reference to the input string
+         * @param u32 reference to the output u32string
+         * @return boolean status
+         */
+        static bool u8StringToU32String(const std::string& u8, std::u32string& u32);
+
+        /**
+         * This method converts a u16 string to a u32 string
+         * @param u16 const reference to the input u16string
+         * @param u32 reference to the output u32string
+         * @return boolean status
+         */
+        static bool u16StringToU32String(const std::u16string& u16, std::u32string& u32);
+
+        /**
+         * This method converts a u16 string to a u8 string
+         * @param u16 const reference to the input u16string
+         * @param u8 reference to the output string
+         * @return boolean status
+         */
+        static bool u16StringToU8String(const std::u16string& u16, std::string& u8);
+
+        /**
+         * This method converts a u32 string to a u16 string
+         * @param u32 const reference to the input u32string
+         * @param u16 reference to the output u16string
+         * @return boolean status
+         */
+        static bool u32StringToU16String(const std::u32string& u32, std::u16string& u16);
+
+        /**
+         * This method converts a u32 string to a u8 string
+         * @param u32 const reference to the input u32string
+         * @param u8 reference to the output string
+         * @return boolean status
+         */
+        static bool u32StringToU8String(const std::u32string& u32, std::string& u8);
     };
 }
diff --git a/libnd4j/include/helpers/impl/ShapeUtils.cpp b/libnd4j/include/helpers/impl/ShapeUtils.cpp
index b040eb73c..165ed5ffd 100644
--- a/libnd4j/include/helpers/impl/ShapeUtils.cpp
+++ b/libnd4j/include/helpers/impl/ShapeUtils.cpp
@@ -1019,15 +1019,6 @@ std::vector ShapeUtils::tadAxesForSimpleBroadcast(const NDArray& max, const
     return numOfMinTads == 1 ? maxTadDims : std::vector();
 }

-
-Nd4jLong ShapeUtils::stringBufferHeaderRequirements(Nd4jLong numStrings) {
-    // we store +1 offset
-    auto base = numStrings + 1;
-
-    // since we return number of bytes...
-    return base * sizeof(Nd4jLong);
-}
-
 ////////////////////////////////////////////////////////////////////////////////
 /*
 bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector& sameDims) {
diff --git a/libnd4j/include/helpers/impl/StringUtils.cpp b/libnd4j/include/helpers/impl/StringUtils.cpp
index faace2c63..045dcea73 100644
--- a/libnd4j/include/helpers/impl/StringUtils.cpp
+++ b/libnd4j/include/helpers/impl/StringUtils.cpp
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019-2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -16,6 +17,7 @@
 //
 // Created by raver119 on 20/04/18.
+// @author Oleg Semeniv // #include @@ -49,13 +51,8 @@ namespace nd4j { if (!array.isS()) throw nd4j::datatype_exception::build("StringUtils::byteLength expects one of String types;", array.dataType()); - uint64_t result = 0; - - // our buffer stores offsets, and the last value is basically number of bytes used auto buffer = array.bufferAsT(); - result = buffer[array.lengthOf()]; - - return result; + return buffer[array.lengthOf()]; } std::vector StringUtils::split(const std::string &haystack, const std::string &delimiter) { @@ -73,4 +70,89 @@ namespace nd4j { return output; } + + bool StringUtils::u8StringToU16String(const std::string& u8, std::u16string& u16) { + + if (u8.empty()) + return false; + + u16.resize(unicode::offsetUtf8StringInUtf16(u8.data(), u8.size()) / sizeof(char16_t)); + if (u8.size() == u16.size()) + u16.assign(u8.begin(), u8.end()); + else + return unicode::utf8to16(u8.data(), &u16[0], u8.size()); + + return true; + } + + bool StringUtils::u8StringToU32String(const std::string& u8, std::u32string& u32) { + + if (u8.empty()) + return false; + + u32.resize( unicode::offsetUtf8StringInUtf32(u8.data(), u8.size()) / sizeof(char32_t) ); + if (u8.size() == u32.size()) + u32.assign(u8.begin(), u8.end()); + else + return unicode::utf8to32(u8.data(), &u32[0], u8.size()); + + return true; + } + + bool StringUtils::u16StringToU32String(const std::u16string& u16, std::u32string& u32) { + + if (u16.empty()) + return false; + + u32.resize(unicode::offsetUtf16StringInUtf32(u16.data(), u16.size()) / sizeof(char32_t)); + if (u16.size() == u32.size()) + u32.assign(u16.begin(), u16.end()); + else + return unicode::utf16to32(u16.data(), &u32[0], u16.size()); + + return true; + } + + bool StringUtils::u16StringToU8String(const std::u16string& u16, std::string& u8) { + + if (u16.empty()) + return false; + + u8.resize(unicode::offsetUtf16StringInUtf8(u16.data(), u16.size())); + if (u16.size() == u8.size()) + u8.assign(u16.begin(), u16.end()); + else + return unicode::utf16to8(u16.data(), &u8[0], u16.size()); + + return true; + } + + bool StringUtils::u32StringToU16String(const std::u32string& u32, std::u16string& u16) { + + if (u32.empty()) + return false; + + u16.resize(unicode::offsetUtf32StringInUtf16(u32.data(), u32.size()) / sizeof(char16_t)); + if (u32.size() == u16.size()) + u16.assign(u32.begin(), u32.end()); + else + return unicode::utf32to16(u32.data(), &u16[0], u32.size()); + + return true; + } + + bool StringUtils::u32StringToU8String(const std::u32string& u32, std::string& u8) { + + if (u32.empty()) + return false; + + u8.resize(unicode::offsetUtf32StringInUtf8(u32.data(), u32.size())); + if (u32.size() == u8.size()) + u8.assign(u32.begin(), u32.end()); + else + return unicode::utf32to8(u32.data(), &u8[0], u32.size()); + + return true; + } + } diff --git a/libnd4j/include/helpers/impl/unicode.cpp b/libnd4j/include/helpers/impl/unicode.cpp new file mode 100644 index 000000000..2e49faf3e --- /dev/null +++ b/libnd4j/include/helpers/impl/unicode.cpp @@ -0,0 +1,456 @@ +/******************************************************************************* + * Copyright (c) 2015-2020 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleg Semeniv +// + +#include + +namespace nd4j { +namespace unicode { + + constexpr uint32_t ONEBYTEBOUND = 0x00000080; + constexpr uint32_t TWOBYTEBOUND = 0x00000800; + constexpr uint32_t THREEBYTEBOUND = 0x00010000; + constexpr uint16_t HIGHBYTEMIN = 0xd800u; + constexpr uint16_t HIGHBYTEMAX = 0xdbffu; + constexpr uint16_t TRAILBYTEMIN = 0xdc00u; + constexpr uint16_t TRAILBYTEMAX = 0xdfffu; + constexpr uint16_t HIGHBYTEOFFSET = HIGHBYTEMIN - (0x10000 >> 10); + constexpr uint32_t BYTEOFFSET = 0x10000u - (HIGHBYTEMIN << 10) - TRAILBYTEMIN; + // Maximum valid value for a Unicode code point + constexpr uint32_t CODEPOINTMAX = 0x0010ffffu; + + template + FORCEINLINE uint8_t castToU8(const T cp) { + return static_cast(0xff & cp); + } + + template + FORCEINLINE uint16_t castToU16(const T cp) { + return static_cast(0xffff & cp); + } + + template + FORCEINLINE uint32_t castToU32(const T cp) { + return static_cast(0xffffff & cp); + } + + template + FORCEINLINE bool isTrail(const T cp) { + return ((castToU8(cp) >> 6) == 0x2); + } + + template + FORCEINLINE bool isHighSurrogate(const T cp) { + return (cp & 0xfffffc00) == 0xd800; + } + + template + bool isLowSurrogate(const T cp) { + return (cp & 0xfffffc00) == 0xdc00; + } + + template + FORCEINLINE bool isLeadSurrogate(const T cp) { + return (cp >= HIGHBYTEMIN && cp <= HIGHBYTEMAX); + } + + template + FORCEINLINE bool isTrailSurrogate(const T cp) { + return (cp >= TRAILBYTEMIN && cp <= TRAILBYTEMAX); + } + + template + FORCEINLINE bool isSurrogateU8(const T cp) { + return (cp >= HIGHBYTEMIN && cp <= TRAILBYTEMAX); + } + + template + FORCEINLINE bool isSurrogateU16(const T cp) { + return ((cp - 0xd800u) < 2048u); + } + + template + FORCEINLINE bool isSymbolU8Valid(const T cp) { + return (cp <= CODEPOINTMAX && !isSurrogateU8(cp)); + } + + template + FORCEINLINE bool isSymbolValid(const T cp) { + return (cp <= CODEPOINTMAX); + } + + template + FORCEINLINE uint32_t surrogateU32(const T& high, const T& low) { + return (high << 10) + low - 0x35fdc00; + } + + template + Nd4jLong symbolLength(const T* it) { + uint8_t lead = castToU8(*it); + if (lead < 0x80) + return 1; + else if ((lead >> 5) == 0x6) + return 2; + else if ((lead >> 4) == 0xe) + return 3; + else if ((lead >> 3) == 0x1e) + return 4; + else + return 0; + } + + template + Nd4jLong symbolLength32(const T* it) { + auto lead = castToU32(*it); + if (lead < ONEBYTEBOUND) + return 1; + else if (lead < TWOBYTEBOUND) + return 2; + else if (lead < THREEBYTEBOUND) + return 3; + else if (lead <= CODEPOINTMAX) + return 4; + else + return 0; + } + + template + Nd4jLong symbolLength16(const T* it) { + + uint32_t lead = castToU16(*it); + if (!isLeadSurrogate(lead)) { + if (lead < ONEBYTEBOUND) + return 1; + else if (lead < TWOBYTEBOUND) + return 2; + else if (lead < THREEBYTEBOUND) + return 3; + else + return 0; + } + else { + return 4; + } + } + + Nd4jLong offsetUtf8StringInUtf32(const void* start, const void* end) { + + Nd4jLong count = 0; + for (auto it = static_cast(start); it != end; it++) { + auto length = symbolLength(it); + it += (length 
> 0) ? (length - 1) : 0; + count += 1; + } + return static_cast(count * sizeof(char32_t)); + } + + Nd4jLong offsetUtf16StringInUtf32(const void* start, const void* end) { + + Nd4jLong count = 0; + for (auto it = static_cast(start); it != end;) { + auto length = symbolLength16(it); + it += (4 == length) ? 2 : 1; + count += 1; + } + return static_cast(count*sizeof(char32_t)); + } + + Nd4jLong offsetUtf8StringInUtf16(const void* start, const void* end) { + + Nd4jLong count = 0; + for (auto it = static_cast(start); it != end; it++) { + auto length = symbolLength(it); + auto step = ((length > 0) ? (length - 1) : 0); + it += step; + count += (4 == length) ? 2 : 1; + } + return static_cast(count*sizeof(char16_t)); + } + + Nd4jLong offsetUtf16StringInUtf8(const void* start, const void* end) { + + Nd4jLong count = 0; + for (auto it = static_cast(start); it != end;) { + auto length = symbolLength16(it); + it += (4 == length) ? 2 : 1; + count += length; + } + return static_cast(count); + } + + Nd4jLong offsetUtf32StringInUtf16(const void* start, const void* end) { + + Nd4jLong count = 0; + for (auto it = static_cast(start); it != end; it++) { + auto length = symbolLength32(it); + count += (4 == length) ? 2 : 1;; + } + return static_cast(count*sizeof(char16_t)); + } + + Nd4jLong offsetUtf32StringInUtf8(const void* start, const void* end) { + + Nd4jLong count = 0; + for (auto it = static_cast(start); it != end; it++) { + count += symbolLength32(it); + } + return count; + } + + bool isStringValidU8(const void* start, const void* stop) { + for (auto it = static_cast(start); it != stop; it++) { + if (!isSymbolU8Valid( castToU8(*it) )) { + return false; + } + } + return true; + } + + bool isStringValidU16(const void* start, const void* stop) { + for (auto it = static_cast(start); it != stop; it++) { + if (!isSymbolValid( castToU32(*it) )) { + return false; + } + } + return true; + } + + bool isStringValidU32(const void* start, const void* stop) { + for (auto it = static_cast(start); it != stop; it++) { + if (!isSymbolValid( castToU32(*it) )) { + return false; + } + } + return true; + } + + void* utf16to8Ptr(const void* start, const void* end, void* res) { + + auto result = static_cast(res); + // result have to be pre-allocated + for (auto it = static_cast(start); it != end;) { + uint32_t cp = castToU16(*it++); + if (!isLeadSurrogate(cp)) { + if (cp < 0x80) { // for one byte + *(result++) = static_cast(cp); + } + else if (cp < 0x800) { // for two bytes + *(result++) = static_cast((cp >> 6) | 0xc0); + *(result++) = static_cast((cp & 0x3f) | 0x80); + } + else{ // for three bytes + *(result++) = static_cast((cp >> 12) | 0xe0); + *(result++) = static_cast(((cp >> 6) & 0x3f) | 0x80); + *(result++) = static_cast((cp & 0x3f) | 0x80); + } + } + else { + if (it != end) { + uint32_t trail_surrogate = castToU16(*it++); + if (isTrailSurrogate(trail_surrogate)) + cp = (cp << 10) + trail_surrogate + BYTEOFFSET; + } + // for four bytes + *(result++) = static_cast((cp >> 18) | 0xf0); + *(result++) = static_cast(((cp >> 12) & 0x3f) | 0x80); + *(result++) = static_cast(((cp >> 6) & 0x3f) | 0x80); + *(result++) = static_cast((cp & 0x3f) | 0x80); + } + } + return result; + } + + void* utf8to16Ptr(const void* start, const void* end, void* res) { + + auto result = static_cast(res); + // result have to be pre-allocated + for (auto it = static_cast(start); it != end;) { + + auto nLength = symbolLength(it); + uint32_t cp = castToU8(*it++); + if (4 != nLength) { + if (2 == nLength) { + cp = ((cp << 6) & 0x7ff) + ((*it++) & 0x3f); + 
} + else if (3 == nLength) { + cp = ((cp << 12) & 0xffff) + ((castToU8(*it++) << 6) & 0xfff); + cp += (*it++) & 0x3f; + } + *(result++) = static_cast(cp); + } + else { + cp = ((cp << 18) & 0x1fffff) + ((castToU8(*it++) << 12) & 0x3ffff); + cp += (castToU8(*it++) << 6) & 0xfff; + cp += (*it++) & 0x3f; + //make a surrogate pair + *(result++) = static_cast((cp >> 10) + HIGHBYTEOFFSET); + *(result++) = static_cast((cp & 0x3ff) + TRAILBYTEMIN); + } + } + return result; + } + + void* utf32to8Ptr( const void* start, const void* end, void* result) { + + auto res = static_cast(result); + // result have to be pre-allocated + for (auto it = static_cast(start); it != end; it++) { + + if (*it < 0x80) // for one byte + *(res++) = static_cast(*it); + else if (*it < 0x800) { // for two bytes + *(res++) = static_cast((*it >> 6) | 0xc0); + *(res++) = static_cast((*it & 0x3f) | 0x80); + } + else if (*it < 0x10000) { // for three bytes + *(res++) = static_cast((*it >> 12) | 0xe0); + *(res++) = static_cast(((*it >> 6) & 0x3f) | 0x80); + *(res++) = static_cast((*it & 0x3f) | 0x80); + } + else { // for four bytes + *(res++) = static_cast((*it >> 18) | 0xf0); + *(res++) = static_cast(((*it >> 12) & 0x3f) | 0x80); + *(res++) = static_cast(((*it >> 6) & 0x3f) | 0x80); + *(res++) = static_cast((*it & 0x3f) | 0x80); + } + } + return result; + } + + void* utf8to32Ptr(const void* start, const void* end, void* res) { + + auto result = static_cast(res); + // result have to be pre-allocated + for (auto it = static_cast(start); it != end;) { + + auto nLength = symbolLength(it); + uint32_t cp = castToU8(*it++); + if (2 == nLength) { + cp = ((cp << 6) & 0x7ff) + ((*it++) & 0x3f); + } + else if (3 == nLength) { + cp = ((cp << 12) & 0xffff) + ((castToU8(*it++) << 6) & 0xfff); + cp += (*it++) & 0x3f; + } + else if (4 == nLength) { + cp = ((cp << 18) & 0x1fffff) + ((castToU8(*it++) << 12) & 0x3ffff); + cp += (castToU8(*it++) << 6) & 0xfff; + cp += (*it++) & 0x3f; + } + (*result++) = cp; + } + return result; + } + + void* utf16to32Ptr(const void* start, const void* end, void* res) { + + auto result = static_cast(res); + // result have to be pre-allocated + for (auto it = static_cast(start); it != end; it++) { + + uint32_t cpHigh = castToU32(*it); + if (!isSurrogateU16(cpHigh)) { + *result++ = cpHigh; + } + else { + it++; + uint32_t cpLow = castToU32(*it); + if (isHighSurrogate(cpHigh) && it != end && isLowSurrogate(cpLow)) { + *result++ = surrogateU32(cpHigh, cpLow); + } + } + } + return result; + } + + void* utf32to16Ptr(const void* start, const void* end, void* res) { + + auto result = static_cast(res); + // result have to be pre-allocate + for (auto it = static_cast(start); it != end; it++) { + + uint32_t cpHigh = castToU32(*it); + // todo check do we need this as we have pre-validation, if yes find out how to check u16 + if (cpHigh < 0 || cpHigh > 0x10FFFF || (cpHigh >= 0xD800 && cpHigh <= 0xDFFF)) { + // Invalid code point. Replace with sentinel, per Unicode standard: + *result++ = u'\uFFFD'; + } + else if (cpHigh < 0x10000UL) { // In the BMP. 
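+                // a single UTF-16 code unit is enough for BMP code points;
+                // code points above 0xFFFF take the surrogate-pair branch below,
+                // which splits (cp - 0x10000) into a lead unit (0xD800 + high 10 bits)
+                // and a trail unit (0xDC00 + low 10 bits)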
+ *result++ = static_cast(cpHigh); + } + else { + *result++ = static_cast(((cpHigh - 0x10000UL) / 0x400U) + 0xD800U); + *result++ = static_cast(((cpHigh - 0x10000UL) % 0x400U) + 0xDC00U); + } + } + return result; + } + + Nd4jLong offsetUtf8StringInUtf32(const void* input, uint32_t nInputSize) { + return offsetUtf8StringInUtf32(input, static_cast(input) + nInputSize); + } + + Nd4jLong offsetUtf16StringInUtf32(const void* input, uint32_t nInputSize) { + return offsetUtf16StringInUtf32(input, static_cast(input) + nInputSize); + } + + Nd4jLong offsetUtf8StringInUtf16(const void* input, uint32_t nInputSize) { + return offsetUtf8StringInUtf16(input, static_cast(input) + nInputSize); + } + + Nd4jLong offsetUtf16StringInUtf8(const void* input, uint32_t nInputSize) { + return offsetUtf16StringInUtf8(input, static_cast(input) + nInputSize); + } + + Nd4jLong offsetUtf32StringInUtf8(const void* input, uint32_t nInputSize) { + return offsetUtf32StringInUtf8(input, static_cast(input) + nInputSize); + } + + Nd4jLong offsetUtf32StringInUtf16(const void* input, const uint32_t nInputSize) { + return offsetUtf32StringInUtf16(input, static_cast(input) + nInputSize); + } + + bool utf8to16(const void* input, void* output, uint32_t nInputSize) { + return utf8to16Ptr(input, static_cast(input) + nInputSize, output); + } + + bool utf8to32(const void* input, void* output, uint32_t nInputSize) { + return utf8to32Ptr(input, static_cast(input) + nInputSize, output); + } + + bool utf16to32(const void* input, void* output, uint32_t nInputSize) { + return utf16to32Ptr(input, static_cast(input) + nInputSize, output); + } + + bool utf16to8(const void* input, void* output, uint32_t nInputSize) { + return utf16to8Ptr(input, static_cast(input) + nInputSize, output); + } + + bool utf32to16(const void* input, void* output, uint32_t nInputSize) { + return utf32to16Ptr(input, static_cast(input) + nInputSize, output); + } + + bool utf32to8(const void* input, void* output, const Nd4jLong nInputSize) { + return utf32to8Ptr(input, static_cast(input) + nInputSize, output); + } + + } + +} + diff --git a/libnd4j/include/helpers/unicode.h b/libnd4j/include/helpers/unicode.h new file mode 100644 index 000000000..239b71201 --- /dev/null +++ b/libnd4j/include/helpers/unicode.h @@ -0,0 +1,189 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
diff --git a/libnd4j/include/helpers/unicode.h b/libnd4j/include/helpers/unicode.h
new file mode 100644
index 000000000..239b71201
--- /dev/null
+++ b/libnd4j/include/helpers/unicode.h
@@ -0,0 +1,189 @@
+/*******************************************************************************
+ * Copyright (c) 2019-2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Oleg Semeniv
+//
+
+#ifndef LIBND4J_UNICODE_H
+#define LIBND4J_UNICODE_H
+
+#include
+
+namespace nd4j {
+namespace unicode {
+
+    /**
+     * This method calculates the utf16 offset for a given utf8 string
+     * @param const pointer to the utf8 string start point
+     * @param const pointer to the utf8 string end point
+     * @return offset of utf16
+     */
+    Nd4jLong offsetUtf8StringInUtf16(const void* start, const void* end);
+
+    /**
+     * This method calculates the utf8 offset for a given utf16 string
+     * @param const pointer to the utf16 string start point
+     * @param const pointer to the utf16 string end point
+     * @return offset of utf8
+     */
+    Nd4jLong offsetUtf16StringInUtf8(const void* start, const void* end);
+
+    /**
+     * This method calculates the utf16 offset for a given utf32 string
+     * @param const pointer to the utf32 string start point
+     * @param const pointer to the utf32 string end point
+     * @return offset of utf16
+     */
+    Nd4jLong offsetUtf32StringInUtf16(const void* start, const void* end);
+
+    /**
+     * This method calculates the utf8 offset for a given utf32 string
+     * @param const pointer to the utf32 string start point
+     * @param const pointer to the utf32 string end point
+     * @return offset of utf8
+     */
+    Nd4jLong offsetUtf32StringInUtf8(const void* start, const void* end);
+
+    /*
+     * This function checks that the given u8 string contains only valid characters
+     */
+    bool isStringValidU8(const void* start, const void* stop);
+
+    /*
+     * This function checks that the given u16 string contains only valid characters
+     */
+    bool isStringValidU16(const void* start, const void* stop);
+
+    /*
+     * This function checks that the given u32 string contains only valid characters
+     */
+    bool isStringValidU32(const void* start, const void* stop);
+
+    /**
+     * This method counts the utf32 offset for a given utf8 string
+     * @param const pointer to the utf8 string start point
+     * @param size of the string
+     * @return offset
+     */
+    Nd4jLong offsetUtf8StringInUtf32(const void* input, uint32_t nInputSize);
+
+    /**
+     * This method counts the utf32 offset for a given utf8 string
+     * @param const pointer to the utf8 string start point
+     * @param const end pointer to the utf8 string
+     * @return offset
+     */
+    Nd4jLong offsetUtf8StringInUtf32(const void* input, const void* stop);
+
+    /**
+     * This method counts the utf32 offset for a given utf16 string
+     * @param const pointer to the utf16 string start point
+     * @param size of the string
+     * @return offset
+     */
+    Nd4jLong offsetUtf16StringInUtf32(const void* input, uint32_t nInputSize);
+
+    /**
+     * This method calculates the utf16 offset for a given utf8 string
+     * @param const pointer to the utf8 string start point
+     * @param size of the string
+     * @return offset of utf16
+     */
+    Nd4jLong offsetUtf8StringInUtf16(const void* input, uint32_t nInputSize);
+
+    /**
+     * This method calculates the utf8 offset for a given utf16 string
+     * @param const pointer to the utf16 string start point
+     * @param size of the string
+     * @return offset of utf8
+     */
+    Nd4jLong offsetUtf16StringInUtf8(const void* input, uint32_t nInputSize);
+
+    /**
+     * This method calculates the utf8 offset for a given utf32 string
+     * @param const pointer to the utf32 string start point
+     * @param size of the string
+     * @return offset of utf8
+     */
+    Nd4jLong offsetUtf32StringInUtf8(const void* input, uint32_t nInputSize);
+
+    /**
+     * This method calculates the utf16 offset for a given utf32 string
+     * @param const pointer to the utf32 string start point
+     * @param size of the string
+     * @return offset of utf16
+     */
+    Nd4jLong offsetUtf32StringInUtf16(const void* input, const uint32_t nInputSize);
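
Editorial note, not part of the patch: each offset declaration above is a dry-run length computation — one pass over the source adding up destination code units. A standalone sketch of the utf8 -> utf16 case (hypothetical helper; byte-count semantics assumed):

#include <cstddef>
#include <cstdint>

size_t utf16BytesForUtf8(const uint8_t* s, size_t nBytes) {
    size_t units = 0;
    for (size_t i = 0; i < nBytes;) {
        uint8_t b = s[i];
        // sequence length from the utf8 lead byte
        size_t len = (b < 0x80) ? 1 : ((b >> 5) == 0x6) ? 2 : ((b >> 4) == 0xE) ? 3 : 4;
        units += (len == 4) ? 2 : 1;  // 4-byte sequences become a surrogate pair
        i += len;
    }
    return units * sizeof(char16_t);
}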
+
+    /**
+     * This method converts a utf8 string to a utf16 string
+     * @param const pointer to the utf8 input start point
+     * @param pointer to the utf16 output start point
+     * @param size of the input utf8 string
+     * @return status of the conversion
+     */
+    bool utf8to16(const void* input, void* output, uint32_t nInputSize);
+
+    /**
+     * This method converts a utf8 string to a utf32 string
+     * @param const pointer to the utf8 input start point
+     * @param pointer to the utf32 output start point
+     * @param size of the input utf8 string
+     * @return status of the conversion
+     */
+    bool utf8to32(const void* input, void* output, uint32_t nInputSize);
+
+    /**
+     * This method converts a utf16 string to a utf32 string
+     * @param const pointer to the utf16 input start point
+     * @param pointer to the utf32 output start point
+     * @param size of the input utf16 string
+     * @return status of the conversion
+     */
+    bool utf16to32(const void* input, void* output, uint32_t nInputSize);
+
+    /**
+     * This method converts a utf16 string to a utf8 string
+     * @param const pointer to the utf16 input start point
+     * @param pointer to the utf8 output start point
+     * @param size of the input utf16 string
+     * @return status of the conversion
+     */
+    bool utf16to8(const void* input, void* output, uint32_t nInputSize);
+
+    /**
+     * This method converts a utf32 string to a utf16 string
+     * @param const pointer to the utf32 input start point
+     * @param pointer to the utf16 output start point
+     * @param size of the input utf32 string
+     * @return status of the conversion
+     */
+    bool utf32to16(const void* input, void* output, uint32_t nInputSize);
+
+    /**
+     * This method converts a utf32 string to a utf8 string
+     * @param const pointer to the utf32 input start point
+     * @param pointer to the utf8 output start point
+     * @param size of the input utf32 string
+     * @return status of the conversion
+     */
+    bool utf32to8(const void* input, void* output, const Nd4jLong nInputSize);
+}
+}
+
+
+#endif //LIBND4J_UNICODE_H
diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp
index 56f77737d..1fe3e359b 100644
--- a/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp
+++ b/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp
@@ -118,7 +118,7 @@ namespace ops {
     DECLARE_TYPES(Pow_bp) {
         getOpDescriptor()
             ->setAllowedInputTypes({ ALL_FLOATS, ALL_INTS })
-            ->setAllowedOutputTypes({ ALL_FLOATS }); // TODO maybe wourth to add ALL_INTS
+            ->setAllowedOutputTypes({ ALL_FLOATS });
     }
 }
diff --git a/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp b/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp
index 9d7b57ee4..ac88a4a60 100644
--- a/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp
+++ b/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp
@@ -81,7 +81,7 @@ namespace nd4j {
             }
 
             // now once we have all strings in single vector time to fill
-            auto tmp = NDArrayFactory::string('c', {(Nd4jLong) strings.size()}, strings);
+            auto tmp = NDArrayFactory::string({(Nd4jLong) strings.size()}, strings);
             auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size());
 
             // for CUDA mostly
diff --git a/libnd4j/include/types/types.h b/libnd4j/include/types/types.h
index 9c8dcb273..92fada8d3 100644
--- a/libnd4j/include/types/types.h
+++ b/libnd4j/include/types/types.h
@@ -33,6 +33,11 @@
 
 #include
 
+#define LIBND4J_STRINGTYPES \
+        (nd4j::DataType::UTF8, std::string),\
+        (nd4j::DataType::UTF16, std::u16string), \
+        (nd4j::DataType::UTF32, std::u32string)
+
 #define LIBND4J_TYPES \
         (nd4j::DataType::BFLOAT16, bfloat16),\
         (nd4j::DataType::HALF, float16), \
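
Editorial note, not part of the patch: the test-only hunks below all make one mechanical change — brace initializers such as {0} become explicit std::vector<double>{0}, or initializer lists gain a trailing `.` so the literal type is unambiguous. The likely reason, sketched with a hypothetical stand-in, is that a braced list cannot deduce the template parameter of a templated data-vector constructor:

#include <cstdint>
#include <vector>

struct Array {  // hypothetical stand-in for the NDArray constructor family
    template <typename T>
    Array(char order, const std::vector<int64_t>& shape, const std::vector<T>& data) {}
};

int main() {
    // Array a('c', {}, {0});                  // error: T is not deducible from a braced list
    Array b('c', {}, std::vector<double>{0});  // OK: T = double; values are cast to the target dtype
    return 0;
}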
diff --git a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp
index 0b1daa3af..1f6000f06 100644
--- a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp
@@ -599,7 +599,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_2) {
 TEST_F(BroadcastableOpsTests, broadcast_empty_3) {
 
     NDArray x = NDArrayFactory::create<float>('c', {1, 0, 2});
-    NDArray y('c', {}, {0.1}, nd4j::DataType::FLOAT32);
+    NDArray y('c', {}, std::vector<double>{0.1}, nd4j::DataType::FLOAT32);
     NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
 
     nd4j::ops::maximum op;
diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp
index 376234019..0c63f527e 100644
--- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp
+++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp
@@ -626,7 +626,7 @@ TEST_F(ConvolutionTests2, deconv3d_bp_test1) {
 
     NDArray expGradI('c', {bS, oD, oH, oW, oC}, {62., 67.6, 68.4, 74.8, 81.2, 89.2, 87.6, 96.4, 119.6, 132.4, 126., 139.6, 138.8, 154., 145.2, 161.2}, nd4j::DataType::FLOAT32);
     NDArray expGradW('c', {kD, kH, kW, iC, oC}, {28., 28., 32., 32., 40., 40., 44., 44., 64, 64., 68., 68., 76., 76., 80., 80.}, nd4j::DataType::FLOAT32);
-    NDArray expGradB('c', {iC}, {364.5}, nd4j::DataType::FLOAT32);
+    NDArray expGradB('c', {iC}, std::vector<double>{364.5}, nd4j::DataType::FLOAT32);
 
     input = 0.5;
     weights.linspace(0.1, 0.1);
diff --git a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu
index 593d47bb5..0a22272d8 100644
--- a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu
+++ b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu
@@ -132,11 +132,11 @@ TEST_F(CudaBasicsTests1, execIndexReduceScalar_1) {
     NDArray x2('c', {2,2}, {0.5, 1.5, -4.5, 3.5}, nd4j::DataType::BFLOAT16);
     NDArray x3('c', {2,2}, {0, -1, 0, 1}, nd4j::DataType::BOOL);
 
-    NDArray scalar('c', {}, {0}, nd4j::DataType::INT64);
+    NDArray scalar('c', {}, std::vector<double>{0}, nd4j::DataType::INT64);
 
-    NDArray exp1('c', {}, {3}, nd4j::DataType::INT64);
-    NDArray exp2('c', {}, {2}, nd4j::DataType::INT64);
-    NDArray exp3('c', {}, {1}, nd4j::DataType::INT64);
+    NDArray exp1('c', {}, std::vector<double>{3}, nd4j::DataType::INT64);
+    NDArray exp2('c', {}, std::vector<double>{2}, nd4j::DataType::INT64);
+    NDArray exp3('c', {}, std::vector<double>{1}, nd4j::DataType::INT64);
 
     void *dX1, *dX2, *dX3, *dZ;
     Nd4jLong *dX1ShapeInfo, *dX2ShapeInfo, *dX3ShapeInfo, *dZShapeInfo;
@@ -262,11 +262,11 @@ TEST_F(CudaBasicsTests1, execReduce3Scalar_1) {
     NDArray x3('c', {2,2}, {1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE);
     NDArray x4('c', {2,2}, {1,2,3,4}, nd4j::DataType::DOUBLE);
 
-    NDArray exp1('c', {}, {-30.f}, nd4j::DataType::FLOAT32);
-    NDArray exp2('c', {}, {15.}, nd4j::DataType::DOUBLE);
+    NDArray exp1('c', {}, std::vector<float>{-30.f}, nd4j::DataType::FLOAT32);
+    NDArray exp2('c', {}, std::vector<double>{15.}, nd4j::DataType::DOUBLE);
 
-    NDArray scalar1('c', {}, {100.f}, nd4j::DataType::FLOAT32);
-    NDArray scalar2('c', {}, {100.}, nd4j::DataType::DOUBLE);
+    NDArray scalar1('c', {}, std::vector<float>{100.f}, nd4j::DataType::FLOAT32);
+    NDArray scalar2('c', {}, std::vector<double>{100.}, nd4j::DataType::DOUBLE);
 
     void *dX1, *dX2, *dX3, *dX4, *dZ1, *dZ2;
     Nd4jLong *dX1ShapeInfo, *dX3ShapeInfo, *dZ1ShapeInfo, *dZ2ShapeInfo;
@@ -363,8 +363,8 @@ TEST_F(CudaBasicsTests1, execReduce3_1) {
 
     NDArray x('c', {2,2}, {1,2,3,4}, nd4j::DataType::INT32);
     NDArray y('c', {2,2}, {-1,-2,-3,-4}, nd4j::DataType::INT32);
 
-
NDArray exp('c', {}, {-30.f}, nd4j::DataType::FLOAT32); - NDArray z('c', {}, {100.f}, nd4j::DataType::FLOAT32); + NDArray exp('c', {}, std::vector{-30.f}, nd4j::DataType::FLOAT32); + NDArray z('c', {}, std::vector{100.f}, nd4j::DataType::FLOAT32); std::vector dimensions = {0, 1}; @@ -415,8 +415,8 @@ TEST_F(CudaBasicsTests1, execReduce3_2) { NDArray x('c', {2,2}, {1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE); NDArray y('c', {2,2}, {1,2,3,4}, nd4j::DataType::DOUBLE); - NDArray exp('c', {}, {15.}, nd4j::DataType::DOUBLE); - NDArray z('c', {}, {100.}, nd4j::DataType::DOUBLE); + NDArray exp('c', {}, std::vector{15.}, nd4j::DataType::DOUBLE); + NDArray z('c', {}, std::vector{100.}, nd4j::DataType::DOUBLE); std::vector dimensions = {0, 1}; @@ -975,7 +975,7 @@ TEST_F(CudaBasicsTests1, execScalar_1) { NDArray x('c', {2,3}, {0,1,2,3,4,5}, nd4j::DataType::INT64); NDArray exp('c',{2,3}, {0,0,1,1,2,2}, nd4j::DataType::INT64); - NDArray scalar('c',{}, {2.f}, nd4j::DataType::FLOAT32); + NDArray scalar('c',{}, std::vector{2.f}, nd4j::DataType::FLOAT32); NDArray z('c', {2,3}, {100,100,100,100,100,100}, nd4j::DataType::INT64); // create cuda stream and LaunchContext @@ -1010,7 +1010,7 @@ TEST_F(CudaBasicsTests1, execScalar_2) { NDArray x('c', {2,3}, {-1,-2,-3,-4,-5,-6}, nd4j::DataType::INT64); NDArray exp('c',{2,3}, {10,10,10,10,10,10}, nd4j::DataType::FLOAT32); - NDArray scalar('c',{}, {10.f}, nd4j::DataType::FLOAT32); + NDArray scalar('c',{}, std::vector{10.f}, nd4j::DataType::FLOAT32); NDArray z('c', {2,3}, {100,100,100,100,100,100}, nd4j::DataType::FLOAT32); // create cuda stream and LaunchContext @@ -1103,7 +1103,7 @@ TEST_F(CudaBasicsTests1, execScalar_3) { TEST_F(CudaBasicsTests1, execScalarBool_1) { NDArray x('c', {2,3}, {-1,-2,0,1,2,3}, nd4j::DataType::BFLOAT16); - NDArray scalar('c',{}, {0}, nd4j::DataType::BFLOAT16); + NDArray scalar('c',{}, std::vector{0}, nd4j::DataType::BFLOAT16); NDArray exp('c',{2,3}, {0,0,0,1,1,1}, nd4j::DataType::BOOL); NDArray z('c', {2,3}, {100,100,100,100,100,100,}, nd4j::DataType::BOOL); @@ -2245,8 +2245,8 @@ TEST_F(CudaBasicsTests1, execReduceLong_2) { TEST_F(CudaBasicsTests1, execReduceFloatScalar_1) { NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, nd4j::DataType::INT32); - NDArray z('c', {}, {100}, nd4j::DataType::FLOAT32); - NDArray exp('c', {}, {6.5}, nd4j::DataType::FLOAT32); + NDArray z('c', {}, std::vector{100}, nd4j::DataType::FLOAT32); + NDArray exp('c', {}, std::vector{6.5}, nd4j::DataType::FLOAT32); x.permutei({2,1,0}); // create cuda stream and LaunchContext @@ -2282,8 +2282,8 @@ TEST_F(CudaBasicsTests1, execReduceFloatScalar_1) { TEST_F(CudaBasicsTests1, execReduceFloatScalar_2) { NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, nd4j::DataType::INT32); - NDArray z('c', {}, {100}, nd4j::DataType::DOUBLE); - NDArray exp('c', {}, {6.5}, nd4j::DataType::DOUBLE); + NDArray z('c', {}, std::vector{100}, nd4j::DataType::DOUBLE); + NDArray exp('c', {}, std::vector{6.5}, nd4j::DataType::DOUBLE); // create cuda stream and LaunchContext cudaError_t cudaResult; @@ -2318,8 +2318,8 @@ TEST_F(CudaBasicsTests1, execReduceFloatScalar_2) { TEST_F(CudaBasicsTests1, execReduceSameScalar_1) { NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, nd4j::DataType::INT32); - NDArray z('c', {}, {100}, nd4j::DataType::INT32); - NDArray exp('c', {}, {156}, nd4j::DataType::INT32); + NDArray z('c', {}, std::vector{100}, nd4j::DataType::INT32); + NDArray exp('c', {}, 
std::vector{156}, nd4j::DataType::INT32); x.permutei({2,1,0}); // create cuda stream and LaunchContext @@ -2355,8 +2355,8 @@ TEST_F(CudaBasicsTests1, execReduceSameScalar_1) { TEST_F(CudaBasicsTests1, execReduceSameScalar_2) { NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, nd4j::DataType::DOUBLE); - NDArray z('c', {}, {100}, nd4j::DataType::DOUBLE); - NDArray exp('c', {}, {156}, nd4j::DataType::DOUBLE); + NDArray z('c', {}, std::vector{100}, nd4j::DataType::DOUBLE); + NDArray exp('c', {}, std::vector{156}, nd4j::DataType::DOUBLE); // create cuda stream and LaunchContext cudaError_t cudaResult; @@ -2391,8 +2391,8 @@ TEST_F(CudaBasicsTests1, execReduceSameScalar_2) { TEST_F(CudaBasicsTests1, execReduceBoolScalar_1) { NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,-7,-8,-9,-10,-11,-12,-13,-14,-15,-16,-17,-18}, nd4j::DataType::INT32); - NDArray z('c', {}, {100}, nd4j::DataType::BOOL); - NDArray exp('c', {}, {1}, nd4j::DataType::BOOL); + NDArray z('c', {}, std::vector{100}, nd4j::DataType::BOOL); + NDArray exp('c', {}, std::vector{1}, nd4j::DataType::BOOL); x.permutei({2,1,0}); x.syncShape(); @@ -2429,8 +2429,8 @@ TEST_F(CudaBasicsTests1, execReduceBoolScalar_1) { TEST_F(CudaBasicsTests1, execReduceBoolScalar_2) { NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,-7,-8,-9,-10,-11,-12,-13,-14,-15,-16,-17,-18}, nd4j::DataType::DOUBLE); - NDArray z('c', {}, {100}, nd4j::DataType::BOOL); - NDArray exp('c', {}, {1}, nd4j::DataType::BOOL); + NDArray z('c', {}, std::vector{100}, nd4j::DataType::BOOL); + NDArray exp('c', {}, std::vector{1}, nd4j::DataType::BOOL); // create cuda stream and LaunchContext cudaError_t cudaResult; @@ -2465,8 +2465,8 @@ TEST_F(CudaBasicsTests1, execReduceBoolScalar_2) { TEST_F(CudaBasicsTests1, execReduceLongScalar_1) { NDArray x('c', {2,3,4}, {-5,0,-3,0,-1,0,1,2,3,4,5,6,7,0,9,10,11,0,13,14,0,16,0,18}, nd4j::DataType::INT32); - NDArray z('c', {}, {100}, nd4j::DataType::INT64); - NDArray exp('c', {}, {17}, nd4j::DataType::INT64); + NDArray z('c', {}, std::vector{100}, nd4j::DataType::INT64); + NDArray exp('c', {}, std::vector{17}, nd4j::DataType::INT64); x.permutei({2,1,0}); x.syncShape(); @@ -2503,8 +2503,8 @@ TEST_F(CudaBasicsTests1, execReduceLongScalar_1) { TEST_F(CudaBasicsTests1, execReduceLongScalar_2) { NDArray x('c', {2,3,4}, {-5,0,-3,0,-1,0,1,2,3,4,5,6,7,0,9,10,11,0,13,14,0,16,0,18}, nd4j::DataType::DOUBLE); - NDArray z('c', {}, {100}, nd4j::DataType::INT64); - NDArray exp('c', {}, {17}, nd4j::DataType::INT64); + NDArray z('c', {}, std::vector{100}, nd4j::DataType::INT64); + NDArray exp('c', {}, std::vector{17}, nd4j::DataType::INT64); // create cuda stream and LaunchContext cudaError_t cudaResult; @@ -2685,8 +2685,8 @@ TEST_F(CudaBasicsTests1, execReduce3TAD_4) { NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, nd4j::DataType::DOUBLE); NDArray y('c', {2,2,3}, {10,20,30,40,50,60,70,80,90,100,110,120}, nd4j::DataType::DOUBLE); - NDArray exp('c', {}, {1820}, nd4j::DataType::FLOAT32); - NDArray z('c', {}, {100}, nd4j::DataType::FLOAT32); + NDArray exp('c', {}, std::vector{1820}, nd4j::DataType::FLOAT32); + NDArray z('c', {}, std::vector{100}, nd4j::DataType::FLOAT32); std::vector dimensions = {0,1,2}; @@ -2739,8 +2739,8 @@ TEST_F(CudaBasicsTests1, execReduce3TAD_4) { TEST_F(CudaBasicsTests1, execSummaryStats_1) { NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, nd4j::DataType::INT64); - NDArray exp('c', {}, {3.605551}, nd4j::DataType::FLOAT32); - NDArray z('c', {}, {100}, nd4j::DataType::FLOAT32); + 
NDArray exp('c', {}, std::vector{3.605551}, nd4j::DataType::FLOAT32); + NDArray z('c', {}, std::vector{100}, nd4j::DataType::FLOAT32); // create cuda stream and LaunchContext cudaError_t cudaResult; @@ -2881,8 +2881,8 @@ TEST_F(CudaBasicsTests1, execSummaryStats_3) { TEST_F(CudaBasicsTests1, execSummaryStatsScalar_1) { NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, nd4j::DataType::INT64); - NDArray exp('c', {}, {3.605551}, nd4j::DataType::FLOAT32); - NDArray z('c', {}, {100}, nd4j::DataType::FLOAT32); + NDArray exp('c', {}, std::vector{3.605551}, nd4j::DataType::FLOAT32); + NDArray z('c', {}, std::vector{100}, nd4j::DataType::FLOAT32); // create cuda stream and LaunchContext cudaError_t cudaResult; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 4875ce8c5..a0722f9d0 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -775,7 +775,7 @@ TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test2 /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test3) { - NDArray labels('c', {1}, {0}, nd4j::DataType::INT32); + NDArray labels('c', {1}, std::vector{0}, nd4j::DataType::INT32); auto logits = NDArrayFactory::create('c', {1,3}); auto expected = NDArrayFactory::create('c', {1}, {1.20194}); @@ -2735,7 +2735,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) { NDArray images ('c', {1,2,2,1}, {1,2,3,4}, nd4j::DataType::FLOAT32); NDArray boxes('c', {1,4}, {0,0,1,1}, nd4j::DataType::FLOAT32); - NDArray boxI('c', {1}, {0}, nd4j::DataType::INT64); + NDArray boxI('c', {1}, std::vector{0}, nd4j::DataType::INT64); NDArray cropSize = NDArrayFactory::create({3, 3}); //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); @@ -2759,7 +2759,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) { NDArray images('c', {1,2,2,1}, {1, 2, 3, 4}, nd4j::DataType::FLOAT32); NDArray boxes('c', {1,4}, {0,0,1,1}, nd4j::DataType::FLOAT32); - NDArray boxI('c', {1}, {0}, nd4j::DataType::INT32); + NDArray boxI('c', {1}, std::vector({0.}), nd4j::DataType::INT32); NDArray cropSize = NDArrayFactory::create({3, 3}); //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); @@ -2933,8 +2933,8 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) { NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, nd4j::DataType::FLOAT32); NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.5f, 0.f, 0.f}, nd4j::DataType::FLOAT32); - NDArray min('c', {}, {-63.65f}, nd4j::DataType::FLOAT32); - NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32); + NDArray min('c', {}, std::vector{-63.65f}, nd4j::DataType::FLOAT32); + NDArray max('c', {}, std::vector{0.1f}, nd4j::DataType::FLOAT32); nd4j::ops::fake_quant_with_min_max_vars op; auto results = op.evaluate({&x, &min, &max}, {}, {}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp index 27f742316..71ebdc7e6 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp @@ -121,7 +121,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test3) { NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. 
, -21.42857, -23.07692, -24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911}); - NDArray dLdwExp('c', {}, {-227.77286}); + NDArray dLdwExp('c', {}, std::vector{-227.77286}); NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002, -0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903}); @@ -246,7 +246,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test7) { NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray weights(nd4j::DataType::DOUBLE); - NDArray dLdwExp('c', {}, {0.}); + NDArray dLdwExp('c', {}, std::vector{0.}); predictions.linspace(0.04, 0.04); labels.linspace(1); @@ -350,7 +350,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test10) { NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE); - NDArray dLdwExp('c', {1,1}, {-9.49054}); + NDArray dLdwExp('c', {1,1}, std::vector{-9.49054}); predictions.linspace(0.04, 0.04); labels.linspace(1); @@ -1611,7 +1611,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test3) { NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52, -12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04}); - NDArray dLdwExp('c', {}, {4515.84}); + NDArray dLdwExp('c', {}, std::vector{4515.84}); predictions.linspace(0.04, 0.04); labels.linspace(1); @@ -1730,7 +1730,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test7) { NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray weights(nd4j::DataType::DOUBLE); - NDArray dLdwExp('c', {}, {0.}); + NDArray dLdwExp('c', {}, std::vector{0.}); predictions.linspace(0.04, 0.04); labels.linspace(1); @@ -1830,7 +1830,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test10) { NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE); - NDArray dLdwExp('c', {1,1}, {188.16}); + NDArray dLdwExp('c', {1,1}, std::vector{188.16}); predictions.linspace(0.04, 0.04); labels.linspace(1); @@ -2056,7 +2056,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test3) { NDArray dLdpExp('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5, -0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5}); - NDArray dLdwExp('c', {}, {288.}); + NDArray dLdwExp('c', {}, std::vector{288.}); predictions.linspace(0.04, 0.04); labels.linspace(1); @@ -2175,7 +2175,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test7) { NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray weights(nd4j::DataType::DOUBLE); - NDArray dLdwExp('c', {}, {0.}); + NDArray dLdwExp('c', {}, std::vector{0.}); predictions.linspace(0.04, 0.04); labels.linspace(1); @@ -2275,7 +2275,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test10) { NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE); - NDArray dLdwExp('c', {1,1}, {12.}); + NDArray dLdwExp('c', {1,1}, std::vector{12.}); predictions.linspace(0.04, 0.04); labels.linspace(1); @@ -2541,7 +2541,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test3) { NDArray dLdpExp('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048, 
-4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577}); - NDArray dLdwExp('c', {}, {-91.52109}); + NDArray dLdwExp('c', {}, std::vector{-91.52109}); NDArray dLdlExp('c', {2,3,4}, {0.028, 0.014, -0., -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126, -0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294}); @@ -2664,7 +2664,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test7) { NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray weights(nd4j::DataType::DOUBLE); - NDArray dLdwExp('c', {}, {0.}); + NDArray dLdwExp('c', {}, std::vector{0.}); logits.linspace(-0.08, 0.04); labels.linspace(1); @@ -2766,7 +2766,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test10) { NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE); - NDArray dLdwExp('c', {1,1}, {-3.81338}); + NDArray dLdwExp('c', {1,1}, std::vector{-3.81338}); logits.linspace(-0.08, 0.04); labels.linspace(1); @@ -2992,7 +2992,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test2) { NDArray weights('c', {1}, nd4j::DataType::DOUBLE); NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125}); - NDArray dLdwExp('c', {1}, {1.38629}); + NDArray dLdwExp('c', {1}, std::vector{1.38629}); logits = 2.; weights.assign(0.5); @@ -3020,10 +3020,10 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test3) { NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32); NDArray logits('c', {4}, nd4j::DataType::DOUBLE); - NDArray weights('c', {}, {0}, nd4j::DataType::DOUBLE); + NDArray weights('c', {}, std::vector{0}, nd4j::DataType::DOUBLE); NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125}); - NDArray dLdwExp('c', {}, {1.38629}); + NDArray dLdwExp('c', {}, std::vector{1.38629}); logits = 2.; weights.assign(0.5); @@ -3051,10 +3051,10 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test4) { NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32); NDArray logits('c', {4}, nd4j::DataType::DOUBLE); - NDArray weights('c', {}, {0}, nd4j::DataType::DOUBLE); + NDArray weights('c', {}, std::vector{0}, nd4j::DataType::DOUBLE); NDArray dLdpExp('c', {4}, {0.23521, 0.2448 , -0.7452 , 0.26519}); - NDArray dLdwExp('c', {}, {0.}); + NDArray dLdwExp('c', {}, std::vector{0.}); logits.linspace(-0.08, 0.04); weights = 0.5; @@ -3085,7 +3085,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test5) { NDArray weights('c', {1}, nd4j::DataType::DOUBLE); NDArray dLdpExp('c', {4}, {0.1176, 0.1224, -0.3726, 0.1326}); - NDArray dLdwExp('c', {1}, {1.36729}); + NDArray dLdwExp('c', {1}, std::vector{1.36729}); logits.linspace(-0.08, 0.04); weights = 0.5; @@ -3321,7 +3321,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test4) { ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test5) { - NDArray labels('c', {2,1}, {1,0}); + NDArray labels('c', {2,1}, std::vector{1,0}); NDArray logits('c', {2,1}, {-0.04, 0.04}); NDArray dLdpExp('c', {2,1}, {-0.51999, 0.51999}); @@ -3343,10 +3343,10 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test5) { ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test6) { - NDArray labels('c', {1,2}, {1,1}); + NDArray labels('c', {1,2}, {1,1.}); NDArray logits('c', {1,2}, {-0.04, 0.04}); - NDArray 
dLdpExp('c', {1,2}, {0, 0}); + NDArray dLdpExp('c', {1,2}, {0, 0.}); nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op; @@ -3387,10 +3387,10 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test7) { ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test8) { - NDArray labels('c', {1}, {1}); - NDArray logits('c', {1}, {0.04}); + NDArray labels('c', {1}, std::vector{1}); + NDArray logits('c', {1}, std::vector{0.04}); - NDArray dLdpExp('c', {1}, {0}); + NDArray dLdpExp('c', {1}, std::vector{0}); nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op; @@ -3483,7 +3483,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test2) { ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) { - NDArray labels('c', {}, {1}, nd4j::DataType::INT64); + NDArray labels('c', {}, std::vector{1}, nd4j::DataType::INT64); NDArray logits('c', {2}, {-0.2, 0.3}); NDArray dLdpExp('c', {2}, {0.37754, -0.37754}); @@ -3529,7 +3529,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test4) { ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test5) { - NDArray labels('c', {1,1}, {0}, nd4j::DataType::INT64); + NDArray labels('c', {1,1}, std::vector({0}), nd4j::DataType::INT64); NDArray logits('c', {1,1,2}, {-0.3,0.2}); NDArray dLdpExp('c', {1,1,2}, {-0.62246, 0.62246}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index 80a9d67a4..142a3dbd4 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -127,7 +127,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test3) { NDArray weights('c', {1}, nd4j::DataType::DOUBLE); NDArray dLdpExp('c', {4}, {0.05, -0.15, -1., 0.7}); - NDArray dLdwExp('c', {1}, {1.3}); + NDArray dLdwExp('c', {1}, std::vector{1.3}); NDArray dLdlExp('c', {4}, {0.2, 0.1, -0. , -0.1}); predictions.linspace(-0.4, 0.2); @@ -158,10 +158,10 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test4) { NDArray labels('c', {1,4}, {-0.1, 0.3, 2, -1.4}); NDArray predictions('c', {1,4}, nd4j::DataType::DOUBLE); - NDArray weights('c', {}, {0.}, nd4j::DataType::DOUBLE); + NDArray weights('c', {}, std::vector{0.}, nd4j::DataType::DOUBLE); NDArray dLdpExp('c', {1,4}, {0.05, -0.15, -1., 0.7}); - NDArray dLdwExp('c', {}, {1.3}); + NDArray dLdwExp('c', {}, std::vector{1.3}); NDArray dLdlExp('c', {1,4}, {0.2, 0.1, -0. , -0.1}); predictions.linspace(-0.4, 0.2); @@ -196,7 +196,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test5) { NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE); NDArray dLdpExp('c', {4}, {0.1, -0.3, -2. , 1.4}); - NDArray dLdwExp('c', {1,1}, {0.}); + NDArray dLdwExp('c', {1,1}, std::vector{0.}); NDArray dLdlExp('c', {4}, {0.4, 0.2, -0. 
, -0.2}); predictions.linspace(-0.4, 0.2); @@ -369,10 +369,10 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test9) { TEST_F(DeclarableOpsTests12, hinge_loss_14) { NDArray logits('c', {3,4}, nd4j::DataType::DOUBLE); - NDArray weights('c', {}, {1.}); + NDArray weights('c', {}, std::vector{1.}); NDArray labels('c', {3,4}, {0,1,1,0,1,0,1,0,1,0,1,0}); - NDArray output('c', {}, {0.}, nd4j::DataType::DOUBLE); + NDArray output('c', {}, std::vector{0.}, nd4j::DataType::DOUBLE); logits.linspace(1.); weights.assign(1.); @@ -576,7 +576,7 @@ TEST_F(DeclarableOpsTests12, TestMinimumBP_1) { TEST_F(DeclarableOpsTests12, reverse_test15) { NDArray x('c', {5}, {1,2,3,4,5}, nd4j::DataType::DOUBLE); - NDArray axis('c', {}, {0}, nd4j::DataType::INT32); + NDArray axis('c', {}, std::vector{0}, nd4j::DataType::INT32); NDArray z('c', {5}, nd4j::DataType::DOUBLE); NDArray exp('c', {5}, {5,4,3,2,1}, nd4j::DataType::DOUBLE); @@ -711,7 +711,7 @@ TEST_F(DeclarableOpsTests12, multiUnique_2) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, tensormmul_6) { - NDArray x('c', {1}, {2}, nd4j::DataType::FLOAT32); + NDArray x('c', {1}, std::vector{2}, nd4j::DataType::FLOAT32); NDArray y('c', {2,1,2}, {1,2,3,4}, nd4j::DataType::FLOAT32); NDArray exp('c', {2,2}, {2,4,6,8}, nd4j::DataType::FLOAT32); @@ -1140,9 +1140,9 @@ TEST_F(DeclarableOpsTests12, lrn_bp_9) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_10) { - NDArray input('c', {1,1,1,1}, {1}); - NDArray gradO('c', {1,1,1,1}, {1}); - NDArray exp('c', {1,1,1,1}, {0.19245008}); + NDArray input('c', {1,1,1,1}, std::vector{1}); + NDArray gradO('c', {1,1,1,1}, std::vector{1}); + NDArray exp('c', {1,1,1,1}, std::vector{0.19245008}); nd4j::ops::lrn_bp op; @@ -1193,8 +1193,8 @@ TEST_F(DeclarableOpsTests12, lrn_2) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_3) { - NDArray input('c', {1,1,1,1}, {1.}); - NDArray exp('c', {1,1,1,1}, {0.69006556}); + NDArray input('c', {1,1,1,1}, std::vector{1.}); + NDArray exp('c', {1,1,1,1}, std::vector{0.69006556}); nd4j::ops::lrn op; @@ -1208,8 +1208,8 @@ TEST_F(DeclarableOpsTests12, lrn_3) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_4) { - NDArray input('c', {1,1,1,1}, {1.}); - NDArray exp('c', {1,1,1,1}, {0.69006556}); + NDArray input('c', {1,1,1,1}, std::vector{1.}); + NDArray exp('c', {1,1,1,1}, std::vector{0.69006556}); nd4j::ops::lrn op; @@ -1239,10 +1239,10 @@ TEST_F(DeclarableOpsTests12, lrn_5) { TEST_F(DeclarableOpsTests12, inTopK_1) { NDArray x('c', {4, 5}, {11.0, 14.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 15.0, 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 5.0, 16.0, 9.0, 13.5, 7.0}); - NDArray y('c', {4}, {0, 0, 0, 0}, nd4j::DataType::INT64); - NDArray z('c', {4}, {1, 1, 1, 1}, nd4j::DataType::BOOL); + NDArray y('c', {4}, {0., 0, 0, 0}, nd4j::DataType::INT64); + NDArray z('c', {4}, {1., 1, 1, 1}, nd4j::DataType::BOOL); - NDArray expV('c', {4}, {1, 0, 0, 0}, nd4j::DataType::BOOL); + NDArray expV('c', {4}, {1., 0, 0, 0}, nd4j::DataType::BOOL); nd4j::ops::in_top_k op; Nd4jStatus status = op.execute({&x, &y, }, {&z}, {}, {2}, {}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index a445666df..e964d397d 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -809,7 
+809,7 @@ TEST_F(DeclarableOpsTests13, space_to_batch_nd_1) { NDArray x('c', {1, 2, 2, 2, 3}, nd4j::DataType::FLOAT32); NDArray blockShape('c', {3}, {2, 2, 2} , nd4j::DataType::INT32); // three spatial dimensions - NDArray paddings('c', {3, 2}, {0, 0, 0, 0, 0, 0} , nd4j::DataType::INT32); + NDArray paddings('c', {3, 2}, std::vector{0, 0, 0, 0, 0, 0} , nd4j::DataType::INT32); NDArray exp('c', {8, 1, 1, 1, 3}, nd4j::DataType::FLOAT32); @@ -892,8 +892,8 @@ TEST_F(DeclarableOpsTests13, batch_to_space_nd_1) { NDArray x('c', {8, 1, 1, 1, 3}, nd4j::DataType::FLOAT32); - NDArray blockShape('c', {3}, {2, 2, 2} , nd4j::DataType::INT32); // three spatial dimensions - NDArray crop('c', {3, 2}, {0, 0, 0, 0, 0, 0} , nd4j::DataType::INT32); + NDArray blockShape('c', {3}, {2., 2, 2} , nd4j::DataType::INT32); // three spatial dimensions + NDArray crop('c', {3, 2}, {0., 0, 0, 0, 0, 0} , nd4j::DataType::INT32); NDArray exp('c', {1, 2, 2, 2, 3}, nd4j::DataType::FLOAT32); @@ -990,7 +990,7 @@ TEST_F(DeclarableOpsTests13, mergemax_1) { TEST_F(DeclarableOpsTests13, mergemax_2) { NDArray x1('c', {1, 3}, {0., 1, 2}, nd4j::DataType::FLOAT32); - NDArray x2('c', {1, 1}, {1.}, nd4j::DataType::FLOAT32); + NDArray x2('c', {1, 1}, std::vector{1.}, nd4j::DataType::FLOAT32); NDArray out('c', {1, 3}, {-1., -1, -1}, nd4j::DataType::FLOAT32); nd4j::ops::mergemax op; @@ -2143,10 +2143,10 @@ TEST_F(DeclarableOpsTests13, batchnorm_test7) { NDArray input2('c', {3,15,15,3}, nd4j::DataType::FLOAT32); input2.permutei({0,3,1,2}); - NDArray mean ('c', {3}, {0, 0, 0}, nd4j::DataType::FLOAT32); - NDArray variance('c', {3}, {1, 1, 1}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {3}, {1, 1, 1}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {3}, {0, 0, 0}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {3}, {0., 0, 0}, nd4j::DataType::FLOAT32); + NDArray variance('c', {3}, {1., 1, 1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {3}, {1., 1, 1}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {3}, {0., 0, 0}, nd4j::DataType::FLOAT32); NDArray out1('c', {3,3,15,15}, nd4j::DataType::FLOAT32); NDArray out2('c', {3,3,15,15}, nd4j::DataType::FLOAT32); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index cd7f84610..d154039f3 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -858,7 +858,7 @@ TEST_F(DeclarableOpsTests15, test_empty_decreasing_1) { TEST_F(DeclarableOpsTests15, test_rgb_to_grs_1) { // rank 1 NDArray rgbs('c', { 3 }, { 10, 50, 200 }, nd4j::DataType::INT32); - NDArray expected('c', { 1 }, { 55 }, nd4j::DataType::INT32); + NDArray expected('c', { 1 }, std::vector{ 55 }, nd4j::DataType::INT32); nd4j::ops::rgb_to_grs op; auto result = op.evaluate({&rgbs}, {}, {}); auto output = result->at(0); @@ -1395,7 +1395,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test6) { y.assign(4.0); dLdzC.linspace(0.1, 0.1); - NDArray dLdxExpXC('c', { 1 }, { 115.2 }, nd4j::DataType::FLOAT32); + NDArray dLdxExpXC('c', { 1 }, std::vector{ 115.2 }, nd4j::DataType::FLOAT32); NDArray dLdyExpXC('c', { 2, 2, 2 }, { 1.10904, 2.21807, 3.32711, 4.43614, 5.54518, 6.65421, 7.76325, 8.87228 }, nd4j::DataType::FLOAT32); nd4j::ops::Pow_bp op; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp index 3f200854d..497475262 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp +++ 
b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp @@ -55,11 +55,11 @@ TEST_F(DeclarableOpsTests17, test_sparse_to_dense_1) { } TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) { - auto values = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"}); + auto values = NDArrayFactory::string({3}, {"alpha", "beta", "gamma"}); auto shape = NDArrayFactory::create({3, 3}); auto ranges = NDArrayFactory::create({0,0, 1,1, 2,2}); auto def = NDArrayFactory::string("d"); - auto exp = NDArrayFactory::string('c', {3, 3}, {"alpha","d","d", "d","beta","d", "d","d","gamma"}); + auto exp = NDArrayFactory::string( {3, 3}, {"alpha","d","d", "d","beta","d", "d","d","gamma"}); nd4j::ops::compat_sparse_to_dense op; @@ -70,11 +70,11 @@ TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) { } TEST_F(DeclarableOpsTests17, test_compat_string_split_1) { - auto x = NDArrayFactory::string('c', {2}, {"first string", "second"}); + auto x = NDArrayFactory::string( {2}, {"first string", "second"}); auto delimiter = NDArrayFactory::string(" "); auto exp0 = NDArrayFactory::create({0,0, 0,1, 1,0}); - auto exp1 = NDArrayFactory::string('c', {3}, {"first", "string", "second"}); + auto exp1 = NDArrayFactory::string( {3}, {"first", "string", "second"}); nd4j::ops::compat_string_split op; auto result = op.evaluate({&x, &delimiter}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp index 118463e3e..fa129b1af 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp @@ -79,7 +79,7 @@ TEST_F(DeclarableOpsTests2, gather_2) { TEST_F(DeclarableOpsTests2, gather_3) { NDArray input ('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); - NDArray indices ('c', {1,1}, {2}, nd4j::DataType::INT32); + NDArray indices ('c', {1,1}, std::vector{2}, nd4j::DataType::INT32); NDArray expected('c', {2,1,1,4}, {9,10,11,12,21,22,23,24}); nd4j::ops::gather op; @@ -186,7 +186,7 @@ TEST_F(DeclarableOpsTests2, gather_7) { TEST_F(DeclarableOpsTests2, gather_8) { NDArray input('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, nd4j::DataType::FLOAT32); - NDArray indices('c', {1}, {2}, nd4j::DataType::INT32); + NDArray indices('c', {1}, std::vector{2}, nd4j::DataType::INT32); NDArray expected('c', {1,5}, {11, 12, 13, 14, 15.}, nd4j::DataType::FLOAT32); nd4j::ops::gather op; @@ -206,7 +206,7 @@ TEST_F(DeclarableOpsTests2, gather_8) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_9) { NDArray x('c', {2, 4, 3, 2}, nd4j::DataType::FLOAT32); - NDArray indices('c', {2}, {1, 0}, nd4j::DataType::INT32); + NDArray indices('c', {2}, std::vector{1, 0}, nd4j::DataType::INT32); nd4j::ops::gather op; auto result = op.evaluate({&x, &indices}, {}, {-2}); @@ -238,7 +238,7 @@ TEST_F(DeclarableOpsTests2, gather_10) { TEST_F(DeclarableOpsTests2, gather_11) { NDArray x('c', {2, 2}, {1, 2, 3, 4}); - NDArray indices('c', {2}, {1, 0}, nd4j::DataType::INT64); + NDArray indices('c', {2}, std::vector{1, 0}, nd4j::DataType::INT64); NDArray e('c', {2, 2}, {3, 4, 1, 2}); nd4j::ops::gather op; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp index 2a5697ce8..62868f67f 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp @@ -243,7 +243,7 @@ TEST_F(DeclarableOpsTests5, Test_SetSeed_1) { 
//////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, scatterMul_test1) { auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); - NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64); + NDArray idc('c', {1}, std::vector({0LL}), nd4j::DataType::INT64); auto updates = NDArrayFactory::create('c', {1, 2}, {10.f, 1.f}); auto exp = NDArrayFactory::create('c', {2, 2}, {10.f, 2.f, 3.f, 4.f}); @@ -261,7 +261,7 @@ TEST_F(DeclarableOpsTests5, scatterMul_test1) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, scatterDiv_test1) { auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); - NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64); + NDArray idc('c', {1}, std::vector({0LL}), nd4j::DataType::INT64); auto updates = NDArrayFactory::create('c', {1, 2}, {10.f, 1.f}); auto exp = NDArrayFactory::create('c', {2, 2}, {0.10f, 2.f, 3.f, 4.f}); @@ -279,7 +279,7 @@ TEST_F(DeclarableOpsTests5, scatterDiv_test1) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, scatterSub_test1) { auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); - NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64); + NDArray idc('c', {1}, std::vector({0LL}), nd4j::DataType::INT64); auto updates = NDArrayFactory::create('c', {1, 2}, {10.f, 1.f}); auto exp = NDArrayFactory::create('c', {2, 2}, {-9.f, 1.f, 3.f, 4.f}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp index 5be2eeebd..5a919d132 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp @@ -1411,7 +1411,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_2) { TEST_F(DeclarableOpsTests6, MatrixDeterminant_3) { auto x = NDArrayFactory::create('c', {1, 3, 3}, {3.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 3.0}); - NDArray exp('c', {1}, {-54.0}); + NDArray exp('c', {1}, std::vector{-54.0}); nd4j::ops::matrix_determinant op; auto result = op.evaluate({&x}, {}, {}); @@ -1453,7 +1453,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_4) { TEST_F(DeclarableOpsTests6, MatrixDeterminant_5) { auto x = NDArrayFactory::create('c', {1, 4, 4}); - NDArray exp('c', {1}, {-16.0}); + NDArray exp('c', {1}, std::vector{-16.0}); x.linspace(1); x.p(5, 4.0); x.p(12, 12.0); diff --git a/libnd4j/tests_cpu/layers_tests/FlatUtilsTests.cpp b/libnd4j/tests_cpu/layers_tests/FlatUtilsTests.cpp index bf428b833..31aa997c6 100644 --- a/libnd4j/tests_cpu/layers_tests/FlatUtilsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/FlatUtilsTests.cpp @@ -83,7 +83,7 @@ TEST_F(FlatUtilsTests, flat_bool_serde_1) { } TEST_F(FlatUtilsTests, flat_string_serde_1) { - auto array = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"}); + auto array = NDArrayFactory::string( {3}, {"alpha", "beta", "gamma"}); flatbuffers::FlatBufferBuilder builder(1024); auto flatArray = FlatUtils::toFlatArray(builder, array); diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp index ee828a6e2..2f3f93d56 100644 --- a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp @@ -1277,14 +1277,14 @@ TEST_F(JavaInteropTests, test_size_dtype_1) { } TEST_F(JavaInteropTests, test_expandable_array_op_1) { - auto x = NDArrayFactory::string('c', {2}, {"first string", "second"}); - auto d = 
NDArrayFactory::string(" "); + auto x = NDArrayFactory::string( {2}, {"first string", "second"}); + auto d = NDArrayFactory::string(" ", nd4j::DataType::UTF8); auto z0 = NDArrayFactory::create('c', {6}); - auto z1 = NDArrayFactory::string('c', {3}, {"", "", ""}); + auto z1 = NDArrayFactory::string( {3}, {"", "", ""}); auto exp0 = NDArrayFactory::create({0,0, 0,1, 1,0}); - auto exp1 = NDArrayFactory::string('c', {3}, {"first", "string", "second"}); + auto exp1 = NDArrayFactory::string( {3}, {"first", "string", "second"}); InteropDataBuffer iz0(z0.dataBuffer()); InteropDataBuffer iz1(z1.dataBuffer()); diff --git a/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp b/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp index aae4493ab..127b3c7d3 100644 --- a/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp @@ -204,7 +204,7 @@ TEST_F(MultiDataTypeTests, ndarray_repeat_test1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_bufferAsT_test1) { NDArray x('f', {2}, {1.5, 3.5}, nd4j::DataType::FLOAT32); - NDArray y('c', {}, {1.5}, nd4j::DataType::FLOAT32); + NDArray y('c', {}, std::vector{1.5}, nd4j::DataType::FLOAT32); const int* buffX = x.bufferAsT(); const int* buffY = y.bufferAsT(); @@ -217,8 +217,8 @@ TEST_F(MultiDataTypeTests, ndarray_assign_test1) { NDArray x('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::UINT8); NDArray exp('c', {2,2}, {10, 10, 20, 20}, nd4j::DataType::UINT8); - NDArray scalar1('c', {}, {10.5}, nd4j::DataType::FLOAT32); - NDArray scalar2('c', {}, {20.8}, nd4j::DataType::DOUBLE); + NDArray scalar1('c', {}, std::vector{10.5}, nd4j::DataType::FLOAT32); + NDArray scalar2('c', {}, std::vector{20.8}, nd4j::DataType::DOUBLE); x(0,{0}).assign(scalar1); x(1,{0}).assign(scalar2); @@ -233,9 +233,9 @@ TEST_F(MultiDataTypeTests, ndarray_assign_test1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test1) { NDArray x('f', {2,2}, {0, 1.5, 2.5, 3.5}, nd4j::DataType::HALF); - NDArray exp1('c', {}, {3}, nd4j::DataType::INT64); - NDArray exp2('c', {1,1}, {1}, nd4j::DataType::INT64); - NDArray exp3('c', {2}, {1,2}, nd4j::DataType::INT64); + NDArray exp1('c', {}, std::vector{3}, nd4j::DataType::INT64); + NDArray exp2('c', {1,1}, std::vector{1}, nd4j::DataType::INT64); + NDArray exp3('c', {2}, std::vector{1,2}, nd4j::DataType::INT64); auto scalar1 = x.reduceAlongDimension(nd4j::reduce::CountNonZero, {}/*whole range*/); ASSERT_EQ(scalar1, exp1); @@ -250,7 +250,7 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test2) { NDArray x('c', {2, 2}, {0, 1, 2, 3}, nd4j::DataType::INT32); - NDArray exp1('c', {}, {1.5}, nd4j::DataType::FLOAT32); + NDArray exp1('c', {}, std::vector{1.5}, nd4j::DataType::FLOAT32); NDArray exp2('c', {2}, {0.5,2.5}, nd4j::DataType::FLOAT32); auto scalar1 = x.reduceAlongDimension(nd4j::reduce::Mean, {}/*whole range*/); @@ -265,7 +265,7 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test2) { //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test3) { NDArray x('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, nd4j::DataType::HALF); - NDArray exp1('c', {}, {8.}, nd4j::DataType::HALF); + NDArray exp1('c', {}, std::vector{8.}, 
nd4j::DataType::HALF); NDArray exp2('c', {2}, {2.,6.}, nd4j::DataType::HALF); auto scalar1 = x.reduceAlongDimension(nd4j::reduce::Sum, {}/*whole range*/); @@ -278,8 +278,8 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test3) { //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test4) { NDArray x('c', {2, 2}, {10.5, 1.5, -2.5, -3.5}, nd4j::DataType::HALF); - NDArray exp1('c', {}, {1}, nd4j::DataType::BOOL); - NDArray exp2('c', {2}, {1,0}, nd4j::DataType::BOOL); + NDArray exp1('c', {}, std::vector{1}, nd4j::DataType::BOOL); + NDArray exp2('c', {2}, std::vector{1, 0}, nd4j::DataType::BOOL); auto scalar1 = x.reduceAlongDimension(nd4j::reduce::IsPositive, {}/*whole range*/); ASSERT_EQ(scalar1, exp1); @@ -291,8 +291,8 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test4) { //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_varianceNumber_test1) { NDArray x('f', {2, 2}, {0, 1, 2, 3}, nd4j::DataType::INT64); - NDArray exp1('c', {}, {1.666666667}, nd4j::DataType::FLOAT32); - NDArray exp2('c', {}, {1.118033989}, nd4j::DataType::FLOAT32); + NDArray exp1('c', {}, std::vector{1.666666667}, nd4j::DataType::FLOAT32); + NDArray exp2('c', {}, std::vector{1.118033989}, nd4j::DataType::FLOAT32); auto scalar1 = x.varianceNumber(variance::SummaryStatsVariance); ASSERT_EQ(scalar1, exp1); @@ -475,8 +475,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorPlusEqual_test1) { if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray scalar1('c', {0}, {4}, nd4j::DataType::INT32); - NDArray scalar2('c', {0}, {1.5}, nd4j::DataType::HALF); + NDArray scalar1('c', {0}, std::vector{4}, nd4j::DataType::INT32); + NDArray scalar2('c', {0}, std::vector{1.5}, nd4j::DataType::HALF); NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, nd4j::DataType::FLOAT32); NDArray x2('c', {3,2}, {10, 20, 30, 40, 50, 60}, nd4j::DataType::INT64); @@ -485,8 +485,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorPlusEqual_test1) { NDArray x5('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::HALF); NDArray x6('c', {2}, {0.4, 0.5}, nd4j::DataType::FLOAT32); - NDArray exp1('c', {0}, {5}, nd4j::DataType::INT32); - NDArray exp2('c', {0}, {6.5}, nd4j::DataType::HALF); + NDArray exp1('c', {0}, std::vector{5}, nd4j::DataType::INT32); + NDArray exp2('c', {0}, std::vector{6.5}, nd4j::DataType::HALF); NDArray exp3('c', {3,2}, {11, 22, 33, 44, 55, 66}, nd4j::DataType::INT64); NDArray exp4('c', {2,3}, {12.5, 24.5, 36.5, 48.5, 60.5, 72.5}, nd4j::DataType::FLOAT32); NDArray exp5('c', {2,2}, {0.4, 1.5, 2.4, 3.5}, nd4j::DataType::HALF); @@ -553,8 +553,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorMinusEqual_test1) { if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray scalar1('c', {0}, {4}, nd4j::DataType::INT32); - NDArray scalar2('c', {0}, {1.5}, nd4j::DataType::HALF); + NDArray scalar1('c', {0}, std::vector{4}, nd4j::DataType::INT32); + NDArray scalar2('c', {0}, std::vector{1.5}, nd4j::DataType::HALF); NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, nd4j::DataType::FLOAT32); NDArray x2('c', {3,2}, {10, 20, 30, 40, 50, 60}, nd4j::DataType::INT64); @@ -563,8 +563,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorMinusEqual_test1) { NDArray x5('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::HALF); NDArray x6('c', {2}, {0.4, 0.5}, nd4j::DataType::FLOAT32); - NDArray exp1('c', {0}, {2}, nd4j::DataType::INT32); - NDArray exp2('c', {0}, {-0.5}, nd4j::DataType::HALF); 
+ NDArray exp1('c', {0}, std::vector{2}, nd4j::DataType::INT32); + NDArray exp2('c', {0}, std::vector{-0.5}, nd4j::DataType::HALF); NDArray exp3('c', {3,2}, {8, 17, 26, 35, 44, 53}, nd4j::DataType::INT64); NDArray exp4('c', {2,3}, {-6.5, -14.5, -22.5, -30.5, -38.5, -46.5}, nd4j::DataType::FLOAT32); NDArray exp5('c', {2,2}, {0.4, -0.5, -1.6, -2.5}, nd4j::DataType::HALF); @@ -631,8 +631,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorMultiplyEqual_test1) { if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray scalar1('c', {0}, {3}, nd4j::DataType::INT32); - NDArray scalar2('c', {0}, {2.5}, nd4j::DataType::HALF); + NDArray scalar1('c', {0}, std::vector{3}, nd4j::DataType::INT32); + NDArray scalar2('c', {0}, std::vector{2.5}, nd4j::DataType::HALF); NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, nd4j::DataType::FLOAT32); NDArray x2('c', {3,2}, {1, 2, 3, 4, 5, 6}, nd4j::DataType::INT64); @@ -641,8 +641,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorMultiplyEqual_test1) { NDArray x5('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::HALF); NDArray x6('c', {2}, {0.4, 0.5}, nd4j::DataType::FLOAT32); - NDArray exp1('c', {0}, {7}, nd4j::DataType::INT32); - NDArray exp2('c', {0}, {17.5}, nd4j::DataType::HALF); + NDArray exp1('c', {0}, std::vector{7}, nd4j::DataType::INT32); + NDArray exp2('c', {0}, std::vector{17.5}, nd4j::DataType::HALF); NDArray exp3('c', {3,2}, {1, 5, 10, 18, 27, 39}, nd4j::DataType::INT64); NDArray exp4('c', {2,3}, {1.5, 12.5, 35, 81, 148.5, 253.5}, nd4j::DataType::FLOAT32); NDArray exp5('c', {2,2}, {0., 0.5, 0.8, 1.5}, nd4j::DataType::HALF); @@ -709,8 +709,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorDivideEqual_test1) { if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray scalar1('c', {0}, {3}, nd4j::DataType::INT32); - NDArray scalar2('c', {0}, {2.5}, nd4j::DataType::HALF); + NDArray scalar1('c', {0}, std::vector{3}, nd4j::DataType::INT32); + NDArray scalar2('c', {0}, std::vector{2.5}, nd4j::DataType::HALF); NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, nd4j::DataType::FLOAT32); NDArray x2('c', {3,2}, {10, 20, 30, 40, 50, 60}, nd4j::DataType::INT64); @@ -719,8 +719,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorDivideEqual_test1) { NDArray x5('c', {2,2}, {1, 2, 3, 4}, nd4j::DataType::HALF); NDArray x6('c', {2}, {0.4, 0.5}, nd4j::DataType::FLOAT32); - NDArray exp1('c', {0}, {1}, nd4j::DataType::INT32); - NDArray exp2('c', {0}, {2.5}, nd4j::DataType::HALF); + NDArray exp1('c', {0}, std::vector{1}, nd4j::DataType::INT32); + NDArray exp2('c', {0}, std::vector{2.5}, nd4j::DataType::HALF); NDArray exp3('c', {3,2}, {6, 8, 8, 8, 9, 9}, nd4j::DataType::INT64); NDArray exp4('c', {2,3}, {0.25, 0.3125, 0.4375, 0.5625, 0.611111111, 0.722222222}, nd4j::DataType::FLOAT32); NDArray exp5('c', {2,2}, {0.4, 0.25, 0.1333333, 0.125}, nd4j::DataType::HALF); @@ -792,10 +792,10 @@ TEST_F(MultiDataTypeTests, ndarray_reduceNumberFloat_test1) { NDArray x3('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, nd4j::DataType::DOUBLE); NDArray x4('c', {2,2}, {0, 1, 0, 1}, nd4j::DataType::BOOL); - NDArray exp1('c', {0}, {1.5}, nd4j::DataType::FLOAT32); - NDArray exp2('c', {0}, {2}, nd4j::DataType::HALF); - NDArray exp3('c', {0}, {2}, nd4j::DataType::DOUBLE); - NDArray exp4('c', {0}, {0.25},nd4j::DataType::FLOAT32); + NDArray exp1('c', {0}, std::vector{1.5}, nd4j::DataType::FLOAT32); + NDArray exp2('c', {0}, std::vector{2}, nd4j::DataType::HALF); + NDArray exp3('c', {0}, std::vector{2}, nd4j::DataType::DOUBLE); + NDArray exp4('c', {0}, 
std::vector{0.25},nd4j::DataType::FLOAT32); NDArray scalar = x1.reduceNumber(reduce::Mean); @@ -829,10 +829,10 @@ TEST_F(MultiDataTypeTests, ndarray_reduceNumberSame_test1) { NDArray x3('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, nd4j::DataType::DOUBLE); NDArray x4('c', {2,2}, {0, 1, 0, 1}, nd4j::DataType::BOOL); - NDArray exp1('c', {0}, {6}, nd4j::DataType::INT64); - NDArray exp2('c', {0}, {8}, nd4j::DataType::HALF); - NDArray exp3('c', {0}, {8}, nd4j::DataType::DOUBLE); - NDArray exp4('c', {0}, {1}, nd4j::DataType::BOOL); + NDArray exp1('c', {0}, std::vector{6}, nd4j::DataType::INT64); + NDArray exp2('c', {0}, std::vector{8}, nd4j::DataType::HALF); + NDArray exp3('c', {0}, std::vector{8}, nd4j::DataType::DOUBLE); + NDArray exp4('c', {0}, std::vector{1}, nd4j::DataType::BOOL); NDArray scalar = x1.reduceNumber(reduce::Sum); @@ -866,7 +866,7 @@ TEST_F(MultiDataTypeTests, ndarray_reduceNumberBool_test1) { NDArray x3('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, nd4j::DataType::DOUBLE); NDArray x4('c', {2,2}, {-2, -1, 0, 1}, nd4j::DataType::BOOL); - NDArray exp1('c', {0}, {1}, nd4j::DataType::BOOL); + NDArray exp1('c', {0}, std::vector{1}, nd4j::DataType::BOOL); NDArray scalar = x1.reduceNumber(reduce::IsFinite); ASSERT_EQ(scalar, exp1); @@ -899,10 +899,10 @@ TEST_F(MultiDataTypeTests, ndarray_reduceNumberLong_test1) { NDArray x3('c', {2,2}, {0.5, -1.5, 0, 3.5}, nd4j::DataType::DOUBLE); NDArray x4('c', {2,2}, {0, 1, 0, 1}, nd4j::DataType::BOOL); - NDArray exp1('c', {0}, {3}, nd4j::DataType::INT64); - NDArray exp2('c', {0}, {4}, nd4j::DataType::INT64); - NDArray exp3('c', {0}, {3}, nd4j::DataType::INT64); - NDArray exp4('c', {0}, {2}, nd4j::DataType::INT64); + NDArray exp1('c', {0}, std::vector{3}, nd4j::DataType::INT64); + NDArray exp2('c', {0}, std::vector{4}, nd4j::DataType::INT64); + NDArray exp3('c', {0}, std::vector{3}, nd4j::DataType::INT64); + NDArray exp4('c', {0}, std::vector{2}, nd4j::DataType::INT64); NDArray scalar = x1.reduceNumber(reduce::CountNonZero); ASSERT_EQ(scalar, exp1); @@ -934,9 +934,9 @@ TEST_F(MultiDataTypeTests, ndarray_indexReduceNumber_test1) { NDArray x2('c', {2,2}, {0.5, 1.5, -4.5, 3.5}, nd4j::DataType::HALF); NDArray x3('c', {2,2}, {0, -1, 0, 1}, nd4j::DataType::BOOL); - NDArray exp1('c', {0}, {3}, nd4j::DataType::INT64); - NDArray exp2('c', {0}, {2}, nd4j::DataType::INT64); - NDArray exp3('c', {0}, {1}, nd4j::DataType::INT64); + NDArray exp1('c', {0}, std::vector{3}, nd4j::DataType::INT64); + NDArray exp2('c', {0}, std::vector{2}, nd4j::DataType::INT64); + NDArray exp3('c', {0}, std::vector{1}, nd4j::DataType::INT64); NDArray scalar = x1.indexReduceNumber(nd4j::indexreduce::IndexAbsoluteMax); ASSERT_EQ(scalar, exp1); @@ -1238,15 +1238,15 @@ TEST_F(MultiDataTypeTests, ndarray_applyTrueBroadcast_test1) { NDArray x7('c', {2}, {1, 2}, nd4j::DataType::INT64); NDArray x8('c', {2,2}, nd4j::DataType::BOOL); - NDArray x13('c', {0}, {3}, nd4j::DataType::INT64); - NDArray x14('c', {0}, {1.5}, nd4j::DataType::DOUBLE); + NDArray x13('c', {0}, std::vector{3}, nd4j::DataType::INT64); + NDArray x14('c', {0}, std::vector{1.5}, nd4j::DataType::DOUBLE); NDArray x15(nd4j::DataType::DOUBLE); NDArray x16('c', {2,2}, nd4j::DataType::DOUBLE); NDArray exp1('c', {2,2}, {11, 22, 31, 42}, nd4j::DataType::HALF); NDArray exp2('c', {2,2}, {11, 22, 31, 42}, nd4j::DataType::INT32); NDArray exp3('c', {2,2}, {1, 1, 1, 1}, nd4j::DataType::BOOL); - NDArray exp4('c', {0}, {4.5}, nd4j::DataType::DOUBLE); + NDArray exp4('c', {0}, std::vector{4.5}, nd4j::DataType::DOUBLE); NDArray exp5('c', {2,2}, {11.5, 21.5, 31.5, 41.5}, 
nd4j::DataType::DOUBLE); x1.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), x2, x3); @@ -1289,13 +1289,13 @@ TEST_F(MultiDataTypeTests, ndarray_applyTrueBroadcast_test2) { NDArray x1('c', {2,2}, {10, 20, 30, 40}, nd4j::DataType::HALF); NDArray x2('c', {2}, {10, 40}, nd4j::DataType::HALF); NDArray x3('c', {2,2}, nd4j::DataType::BOOL); - NDArray x4('c', {0}, {10}, nd4j::DataType::HALF); - NDArray x5('c', {0}, {20}, nd4j::DataType::HALF); + NDArray x4('c', {0}, std::vector{10}, nd4j::DataType::HALF); + NDArray x5('c', {0}, std::vector{20}, nd4j::DataType::HALF); NDArray x6(nd4j::DataType::BOOL); NDArray exp1('c', {2,2}, {1, 0, 0, 1}, nd4j::DataType::BOOL); NDArray exp2('c', {2,2}, {1, 0, 0, 0}, nd4j::DataType::BOOL); - NDArray exp3('c', {0}, {0}, nd4j::DataType::BOOL); + NDArray exp3('c', {0}, std::vector{0}, nd4j::DataType::BOOL); x1.applyTrueBroadcast(BroadcastBoolOpsTuple(nd4j::scalar::EqualTo, nd4j::pairwise::EqualTo, nd4j::broadcast::EqualTo), x2, x3); ASSERT_EQ(x3, exp1); @@ -1459,16 +1459,16 @@ TEST_F(MultiDataTypeTests, ndarray_applyIndexedLambda_test1) { ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyPairwiseLambda_test1) { - NDArray x1('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::DOUBLE); - NDArray x2('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::INT64); - NDArray x3('c', {2,2}, {0, 1.5, 2.5, 3.5}, nd4j::DataType::FLOAT32); + NDArray x1('c', {2,2}, {0., 1, 2, 3}, nd4j::DataType::DOUBLE); + NDArray x2('c', {2,2}, {0., 1, 2, 3}, nd4j::DataType::INT64); + NDArray x3('c', {2,2}, {0., 1.5, 2.5, 3.5}, nd4j::DataType::FLOAT32); NDArray x4('c', {2,2}, nd4j::DataType::DOUBLE); NDArray x5('c', {2,2}, {0, 1.5, 2.5, 3.5}, nd4j::DataType::FLOAT32); NDArray x6('c', {2,2}, {0.1, -1, -1, 0.1}, nd4j::DataType::BOOL); NDArray x7('c', {2,2}, nd4j::DataType::BOOL); NDArray other1('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, nd4j::DataType::FLOAT32); NDArray other2('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, nd4j::DataType::DOUBLE); - NDArray other3('c', {2,2}, {0, -1, -2, -3}, nd4j::DataType::INT64); + NDArray other3('c', {2,2}, {0., -1, -2, -3}, nd4j::DataType::INT64); NDArray other4('c', {2,2}, {1, 0, 0.1, 0}, nd4j::DataType::BOOL); auto func1 = [](float elem1, float elem2) { return elem1 + elem2; }; @@ -1478,10 +1478,10 @@ TEST_F(MultiDataTypeTests, ndarray_applyPairwiseLambda_test1) { auto func5 = [](float elem1, int elem2) { return elem1 - elem2; }; NDArray exp1('c', {2,2}, {0.1, 1.1, 2.1, 3.1}, nd4j::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {0, 0, 0, 0}, nd4j::DataType::INT64); + NDArray exp2('c', {2,2}, {0., 0, 0, 0}, nd4j::DataType::INT64); NDArray exp3('c', {2,2}, {0.1, 1.1, 2.1, 3.1}, nd4j::DataType::FLOAT32); NDArray exp4('c', {2,2}, {0.1, 1.6, 2.6, 3.6}, nd4j::DataType::FLOAT32); - NDArray exp5('c', {2,2}, {0, 1, 0, 1}, nd4j::DataType::BOOL); + NDArray exp5('c', {2,2}, {0., 1, 0, 1}, nd4j::DataType::BOOL); x1.applyPairwiseLambda(other2, func1, x4); ASSERT_EQ(x4, exp1); @@ -1505,16 +1505,16 @@ TEST_F(MultiDataTypeTests, ndarray_applyPairwiseLambda_test1) { ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyIndexedPairwiseLambda_test1) { - NDArray x1('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::DOUBLE); - NDArray x2('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::INT64); - NDArray x3('c', {2,2}, {0, 1.5, 2.5, 3.5}, nd4j::DataType::FLOAT32); + NDArray x1('c', {2,2}, {0., 1, 2, 3}, nd4j::DataType::DOUBLE); + NDArray x2('c', {2,2}, {0., 1, 2, 3}, nd4j::DataType::INT64); + NDArray 
x3('c', {2,2}, {0., 1.5, 2.5, 3.5}, nd4j::DataType::FLOAT32); NDArray x4('c', {2,2}, nd4j::DataType::DOUBLE); NDArray x5('c', {2,2}, {0, 1.5, 2.5, 3.5}, nd4j::DataType::FLOAT32); NDArray x6('c', {2,2}, {0.1, -1, -1, 0.1}, nd4j::DataType::BOOL); NDArray x7('c', {2,2}, nd4j::DataType::BOOL); NDArray other1('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, nd4j::DataType::FLOAT32); NDArray other2('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, nd4j::DataType::DOUBLE); - NDArray other3('c', {2,2}, {0, -1, -2, -3}, nd4j::DataType::INT64); + NDArray other3('c', {2,2}, {0., -1, -2, -3}, nd4j::DataType::INT64); NDArray other4('c', {2,2}, {1, 0, 0.1, 0}, nd4j::DataType::BOOL); auto func1 = [](Nd4jLong idx, float elem1, float elem2) { return elem1 + elem2 + idx; }; @@ -1524,10 +1524,10 @@ TEST_F(MultiDataTypeTests, ndarray_applyIndexedPairwiseLambda_test1) { auto func5 = [](Nd4jLong idx, float elem1, int elem2) { return elem1 - elem2 + idx; }; NDArray exp1('c', {2,2}, {0.1, 2.1, 4.1, 6.1}, nd4j::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::INT64); + NDArray exp2('c', {2,2}, {0., 1, 2, 3}, nd4j::DataType::INT64); NDArray exp3('c', {2,2}, {0.1, 2.1, 4.1, 6.1}, nd4j::DataType::FLOAT32); NDArray exp4('c', {2,2}, {0.1, 2.6, 4.6, 6.6}, nd4j::DataType::FLOAT32); - NDArray exp5('c', {2,2}, {0, 1, 1, 1}, nd4j::DataType::BOOL); + NDArray exp5('c', {2,2}, {0., 1, 1, 1}, nd4j::DataType::BOOL); x1.applyIndexedPairwiseLambda(other2, func1, x4); ASSERT_EQ(x4, exp1); @@ -1551,25 +1551,25 @@ TEST_F(MultiDataTypeTests, ndarray_applyIndexedPairwiseLambda_test1) { ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyTriplewiseLambda_test1) { - NDArray x1('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::DOUBLE); - NDArray x2('c', {2,2}, {0, -1, -2, -3}, nd4j::DataType::DOUBLE); + NDArray x1('c', {2,2}, {0., 1, 2, 3}, nd4j::DataType::DOUBLE); + NDArray x2('c', {2,2}, {0., -1, -2, -3}, nd4j::DataType::DOUBLE); NDArray x3('c', {2,2}, {0, -1.5, -2.5, -3.5}, nd4j::DataType::DOUBLE); NDArray x4('c', {2,2}, nd4j::DataType::DOUBLE); - NDArray x5('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::INT32); - NDArray x6('c', {2,2}, {0, -1, -2, -3}, nd4j::DataType::INT32); - NDArray x7('c', {2,2}, {0, 10, 20, 30}, nd4j::DataType::INT32); + NDArray x5('c', {2,2}, {0., 1, 2, 3}, nd4j::DataType::INT32); + NDArray x6('c', {2,2}, {0., -1, -2, -3}, nd4j::DataType::INT32); + NDArray x7('c', {2,2}, {0., 10, 20, 30}, nd4j::DataType::INT32); - NDArray x8('c', {2,2}, {0, 1, 0, 1}, nd4j::DataType::BOOL); - NDArray x9('c', {2,2}, {1, 1, 0, 1}, nd4j::DataType::BOOL); - NDArray x10('c', {2,2}, {0, 0, 0, 0}, nd4j::DataType::BOOL); + NDArray x8('c', {2,2}, {0., 1, 0, 1}, nd4j::DataType::BOOL); + NDArray x9('c', {2,2}, {1., 1, 0, 1}, nd4j::DataType::BOOL); + NDArray x10('c', {2,2}, {0., 0, 0, 0}, nd4j::DataType::BOOL); auto func1 = [](double elem1, float elem2, int elem3) { return elem1 + elem2 + elem3; }; auto func2 = [](float elem1, float elem2, float elem3) { return elem1 + elem2 + elem3; }; auto func3 = [](int elem1, int elem2, int elem3) { return elem1 + elem2 + elem3; }; auto func4 = [](bool elem1, bool elem2, bool elem3) { return elem1 + elem2 + elem3; }; - NDArray exp('c', {2,2}, {1, 1, 0, 1}, nd4j::DataType::BOOL); + NDArray exp('c', {2,2}, {1., 1, 0, 1}, nd4j::DataType::BOOL); x1.applyTriplewiseLambda(x2, x3, func1, x4); ASSERT_EQ(x4, x2); @@ -1590,7 +1590,7 @@ TEST_F(MultiDataTypeTests, ndarray_applyTriplewiseLambda_test1) { TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test1) { 
NDArray x1('c', {2,3}, {0, 1, 2, 3, 4, 5}, nd4j::DataType::DOUBLE); - NDArray exp1('c', {}, {5}, nd4j::DataType::INT64); + NDArray exp1('c', {}, std::vector{5}, nd4j::DataType::INT64); NDArray exp2('c', {2}, {2,2}, nd4j::DataType::INT64); NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::INT64); @@ -1608,10 +1608,10 @@ TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test1) { TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test2) { NDArray x1('c', {2,3}, {0, 1, 2, 3, 4, 5}, nd4j::DataType::DOUBLE); - NDArray scalar('c', {}, {5}, nd4j::DataType::INT64); + NDArray scalar('c', {}, std::vector{5}, nd4j::DataType::INT64); NDArray vec1('c', {2}, {2,2}, nd4j::DataType::INT64); NDArray vec2('c', {3}, {1,1,1}, nd4j::DataType::INT64); - NDArray exp1('c', {}, {5}, nd4j::DataType::INT64); + NDArray exp1('c', {}, std::vector{5}, nd4j::DataType::INT64); NDArray exp2('c', {2}, {2,2}, nd4j::DataType::INT64); NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::INT64); @@ -1632,8 +1632,8 @@ TEST_F(MultiDataTypeTests, applyReduce3_test1) { NDArray x2('c', {2,2}, {-1,-2,-3,-4}, nd4j::DataType::INT32); NDArray x3('c', {2,2}, {1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE); NDArray x4('c', {2,2}, {1,2,3,4}, nd4j::DataType::DOUBLE); - NDArray exp1('c', {}, {-30}, nd4j::DataType::FLOAT32); - NDArray exp2('c', {}, {15}, nd4j::DataType::DOUBLE); + NDArray exp1('c', {}, std::vector{-30}, nd4j::DataType::FLOAT32); + NDArray exp2('c', {}, std::vector{15}, nd4j::DataType::DOUBLE); auto result = x1.applyReduce3(reduce3::Dot, x2); ASSERT_EQ(result, exp1); @@ -1654,8 +1654,8 @@ TEST_F(MultiDataTypeTests, applyReduce3_test2) { NDArray x7('c', {2,3}, {1.5,1.5,1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE); NDArray x8('c', {2,3}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE); - NDArray exp1('c', {}, {-30}, nd4j::DataType::FLOAT32); - NDArray exp2('c', {}, {15}, nd4j::DataType::DOUBLE); + NDArray exp1('c', {}, std::vector{-30}, nd4j::DataType::FLOAT32); + NDArray exp2('c', {}, std::vector{15}, nd4j::DataType::DOUBLE); NDArray exp3('c', {3}, {-18,-20,-18}, nd4j::DataType::FLOAT32); NDArray exp4('c', {2}, {-28,-28}, nd4j::DataType::FLOAT32); NDArray exp5('c', {3}, {7.5,10.5,13.5}, nd4j::DataType::DOUBLE); diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayConstructorsTests.cu b/libnd4j/tests_cpu/layers_tests/NDArrayConstructorsTests.cu index 33384d4d8..0c0c102ac 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayConstructorsTests.cu +++ b/libnd4j/tests_cpu/layers_tests/NDArrayConstructorsTests.cu @@ -184,7 +184,7 @@ TEST_F(NDArrayConstructorsTests, test_linspace_1) { TEST_F(NDArrayConstructorsTests, test_constructor_10) { NDArray scalar1(nd4j::DataType::DOUBLE); // scalar1 = 0 - NDArray scalar2('c', {}, {0}); + NDArray scalar2('c', {}, std::vector{0}); ASSERT_TRUE(scalar1.isActualOnDeviceSide()); ASSERT_TRUE(!scalar1.isActualOnHostSide()); diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu b/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu index c6c0a1bd8..46f962dda 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu +++ b/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu @@ -1226,8 +1226,8 @@ TEST_F(NDArrayCudaBasicsTests, applyReduce3_3) { NDArray x3('c', {3,2}, {1.5,1.5,1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE); NDArray x4('c', {3,2}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE); - NDArray exp1('c', {}, {-204}, nd4j::DataType::FLOAT32); - NDArray exp2('c', {}, {31.5}, nd4j::DataType::DOUBLE); + NDArray exp1('c', {}, std::vector{-204}, nd4j::DataType::FLOAT32); + NDArray exp2('c', {}, 
std::vector{31.5}, nd4j::DataType::DOUBLE); auto z = x1.applyReduce3(reduce3::Dot, x2); @@ -1260,7 +1260,7 @@ TEST_F(NDArrayCudaBasicsTests, applyAllReduce3_1) { NDArray exp2('c', {6,4}, {-36.f, -44.f, -52.f, -60.f,-42.f, -52.f, -62.f, -72.f, 2.f, 0.f, -2.f, -4.f, 6.f, 4.f, 2.f, 0.f, 10.f, 8.f, 6.f, 4.f, 14.f, 12.f, 10.f, 8.f}, nd4j::DataType::FLOAT32); - NDArray exp3('c', {1,1}, {31.5}, nd4j::DataType::DOUBLE); + NDArray exp3('c', {1,1}, std::vector{31.5}, nd4j::DataType::DOUBLE); NDArray exp4('c', {3,3}, {4.5, 10.5, 16.5,4.5, 10.5, 16.5,4.5, 10.5, 16.5}, nd4j::DataType::DOUBLE); auto z = x1.applyAllReduce3(reduce3::Dot, x2, {0,2}); @@ -1292,15 +1292,15 @@ TEST_F(NDArrayCudaBasicsTests, applyIndexReduce_test1) { NDArray x('c', {2,3}, {0, 10, 1, 2, 2.5,-4}, nd4j::DataType::DOUBLE); - NDArray scalar('c', {}, {100}, nd4j::DataType::INT64); + NDArray scalar('c', {}, std::vector{100}, nd4j::DataType::INT64); NDArray vec1('c', {2}, {100,100}, nd4j::DataType::INT64); NDArray vec2('c', {3}, {100,100,100}, nd4j::DataType::INT64); - NDArray exp1('c', {}, {1}, nd4j::DataType::INT64); + NDArray exp1('c', {}, std::vector{1}, nd4j::DataType::INT64); NDArray exp2('c', {2}, {1,1}, nd4j::DataType::INT64); NDArray exp3('c', {3}, {1,0,0}, nd4j::DataType::INT64); - NDArray exp4('c', {}, {2}, nd4j::DataType::INT64); + NDArray exp4('c', {}, std::vector{2}, nd4j::DataType::INT64); NDArray exp5('c', {2}, {1,1}, nd4j::DataType::INT64); NDArray exp6('c', {3}, {1,0,0}, nd4j::DataType::INT64); @@ -1331,11 +1331,11 @@ TEST_F(NDArrayCudaBasicsTests, applyIndexReduce_test2) { NDArray x('c', {2,3}, {0, 10, 1, 2, 2.5,-4}, nd4j::DataType::DOUBLE); - NDArray exp1('c', {}, {1}, nd4j::DataType::INT64); + NDArray exp1('c', {}, std::vector{1}, nd4j::DataType::INT64); NDArray exp2('c', {2}, {1,1}, nd4j::DataType::INT64); NDArray exp3('c', {3}, {1,0,0}, nd4j::DataType::INT64); - NDArray exp4('c', {}, {2}, nd4j::DataType::INT64); + NDArray exp4('c', {}, std::vector{2}, nd4j::DataType::INT64); NDArray exp5('c', {2}, {1,1}, nd4j::DataType::INT64); NDArray exp6('c', {3}, {1,0,0}, nd4j::DataType::INT64); @@ -1365,13 +1365,13 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test1) { NDArray x('c', {2,3,2}, {1,2,3,4,5,6,7,8,-1,-2,-3,-4,}, nd4j::DataType::INT32); - NDArray z1('c', {}, {100}, nd4j::DataType::DOUBLE); + NDArray z1('c', {}, std::vector{100}, nd4j::DataType::DOUBLE); NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::FLOAT32); NDArray z3('c', {3}, {100,100,100}, nd4j::DataType::DOUBLE); NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::FLOAT32); NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32); - NDArray exp1('c', {}, {2.166667}, nd4j::DataType::DOUBLE); + NDArray exp1('c', {}, std::vector{2.166667}, nd4j::DataType::DOUBLE); NDArray exp2('c', {2,2}, {3.f,4.f,1.f,0.666667f}, nd4j::DataType::FLOAT32); NDArray exp3('c', {3}, {4.5,1,1}, nd4j::DataType::DOUBLE); NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, nd4j::DataType::FLOAT32); @@ -1403,7 +1403,7 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test2) { NDArray x('c', {2,3,2}, {1,2,3,4,5,6,7,8,-1,-2,-3,-4,}, nd4j::DataType::DOUBLE); - NDArray exp1('c', {}, {2.166667}, nd4j::DataType::DOUBLE); + NDArray exp1('c', {}, std::vector{2.166667}, nd4j::DataType::DOUBLE); NDArray exp2('c', {2,2}, {3,4,1,0.666667}, nd4j::DataType::DOUBLE); NDArray exp3('c', {3}, {4.5,1,1}, nd4j::DataType::DOUBLE); NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, nd4j::DataType::DOUBLE); @@ -1477,13 +1477,13 @@ TEST_F(NDArrayCudaBasicsTests, 
reduceAlongDimension_same_test1) { NDArray x('c', {2,3,2}, {1.5f,2.f,3.f,4.f,5.f,6.f,7.5f,8.f,-1.f,-2.f,-3.5f,-4.f}, nd4j::DataType::FLOAT32); - NDArray z1('c', {}, {100}, nd4j::DataType::FLOAT32); + NDArray z1('c', {}, std::vector{100}, nd4j::DataType::FLOAT32); NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::FLOAT32); NDArray z3('c', {3}, {100,100,100}, nd4j::DataType::FLOAT32); NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::FLOAT32); NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32); - NDArray exp1('c', {}, {26.5f}, nd4j::DataType::FLOAT32); + NDArray exp1('c', {}, std::vector{26.5f}, nd4j::DataType::FLOAT32); NDArray exp2('c', {2,2}, {9.5f,12.f,3.f,2.f}, nd4j::DataType::FLOAT32); NDArray exp3('c', {3}, {19.f,4.f,3.5f}, nd4j::DataType::FLOAT32); NDArray exp4('c', {3,2}, {9.f,10.f,2.f,2.f,1.5f,2.f}, nd4j::DataType::FLOAT32); @@ -1515,7 +1515,7 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test2) { NDArray x('c', {2,3,2}, {1.5,2,3,4,5,6,7.5,8,-1,-2,-3.5,-4,}, nd4j::DataType::INT64); - NDArray exp1('c', {}, {26}, nd4j::DataType::INT64); + NDArray exp1('c', {}, std::vector{26}, nd4j::DataType::INT64); NDArray exp2('c', {2,2}, {9,12,3,2}, nd4j::DataType::INT64); NDArray exp3('c', {3}, {18,4,4}, nd4j::DataType::INT64); NDArray exp4('c', {3,2}, {8,10,2,2,2,2}, nd4j::DataType::INT64); @@ -1547,13 +1547,13 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test1) { NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, nd4j::DataType::DOUBLE); - NDArray z1('c', {}, {true}, nd4j::DataType::BOOL); + NDArray z1('c', {}, std::vector{true}, nd4j::DataType::BOOL); NDArray z2('c', {2,2}, {true,true,true,true}, nd4j::DataType::BOOL); NDArray z3('c', {3}, {true,true,true}, nd4j::DataType::BOOL); NDArray z4('c', {3,2}, {true,true,true,true,true,true}, nd4j::DataType::BOOL); NDArray z5('c', {2}, {true,true}, nd4j::DataType::BOOL); - NDArray exp1('c', {}, {true}, nd4j::DataType::BOOL); + NDArray exp1('c', {}, std::vector{true}, nd4j::DataType::BOOL); NDArray exp2('c', {2,2}, {true,true,false,true}, nd4j::DataType::BOOL); NDArray exp3('c', {3}, {true,true,true}, nd4j::DataType::BOOL); NDArray exp4('c', {3,2}, {true,true,true,false,true,true}, nd4j::DataType::BOOL); @@ -1585,7 +1585,7 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test2) { NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, nd4j::DataType::INT32); - NDArray exp1('c', {}, {1}, nd4j::DataType::BOOL); + NDArray exp1('c', {}, std::vector{1}, nd4j::DataType::BOOL); NDArray exp2('c', {2,2}, {1,1,0,1}, nd4j::DataType::BOOL); NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::BOOL); NDArray exp4('c', {3,2}, {0,1,1,0,1,1}, nd4j::DataType::BOOL); @@ -1617,13 +1617,13 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test1) { NDArray x('c', {2,3,2}, {0.5f,2.f,3.f,-0.f,5.f,6.f,-7.5f,0.f,-1.f,-0.5f,-3.5f,4.f}, nd4j::DataType::FLOAT32); - NDArray z1('c', {}, {100}, nd4j::DataType::INT64); + NDArray z1('c', {}, std::vector{100}, nd4j::DataType::INT64); NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::INT64); NDArray z3('c', {3}, {100,100,100}, nd4j::DataType::INT64); NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::INT64); NDArray z5('c', {2}, {100,100}, nd4j::DataType::INT64); - NDArray exp1('c', {}, {2}, nd4j::DataType::INT64); + NDArray exp1('c', {}, std::vector{2}, nd4j::DataType::INT64); NDArray exp2('c', {2,2}, {0,1,0,1}, nd4j::DataType::INT64); NDArray exp3('c', {3}, {1,1,0}, nd4j::DataType::INT64); NDArray exp4('c', {3,2}, 
{0,1,0,1,0,0}, nd4j::DataType::INT64); @@ -1655,7 +1655,7 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test2) { NDArray x('c', {2,3,2}, {0.5,2,3,-0,5,6,-7.5,0,-1,-0.5,-3.5,4}, nd4j::DataType::INT32); - NDArray exp1('c', {}, {4}, nd4j::DataType::INT64); + NDArray exp1('c', {}, std::vector{4}, nd4j::DataType::INT64); NDArray exp2('c', {2,2}, {1,1,0,2}, nd4j::DataType::INT64); NDArray exp3('c', {3}, {2,2,0}, nd4j::DataType::INT64); NDArray exp4('c', {3,2}, {1,1,0,2,0,0}, nd4j::DataType::INT64); diff --git a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp index 987817136..a53d71a65 100644 --- a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp @@ -692,7 +692,7 @@ TEST_F(ParityOpsTests, Test_Bias_Add_1) { TEST_F(ParityOpsTests, Test_Scatter_Add_1) { auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - NDArray idc('c', {1}, {0}, nd4j::DataType::INT64); + NDArray idc('c', {1}, std::vector({0}), nd4j::DataType::INT64); auto updates = NDArrayFactory::create('c', {1, 2}, {1, 1}); auto exp = NDArrayFactory::create('c', {2, 2}, {2, 3, 3, 4}); @@ -710,7 +710,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_1) { TEST_F(ParityOpsTests, Test_Scatter_Add_2) { auto vec = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - NDArray idc('c', {1, 4}, {0, 1, 2, 3}, nd4j::DataType::INT64); + NDArray idc('c', {1, 4}, {0., 1, 2, 3}, nd4j::DataType::INT64); auto updates = NDArrayFactory::create('c', {1, 4}, {1, 1, 1, 1}); auto exp = NDArrayFactory::create('c', {1, 4}, {2, 3, 4, 5}); @@ -727,7 +727,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_2) { TEST_F(ParityOpsTests, Test_Scatter_Add_3) { auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray idc('c', {1}, {0}, nd4j::DataType::INT64); + NDArray idc('c', {1}, std::vector({0}), nd4j::DataType::INT64); auto updates = NDArrayFactory::create('c', {1, 2, 2}, {1, 1, 1, 1}); auto exp = NDArrayFactory::create('c', {2, 2, 2}, {2, 3, 4, 5, 5, 6, 7, 8}); @@ -744,7 +744,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_3) { TEST_F(ParityOpsTests, Test_Scatter_Add_4) { auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray idc('c', {1, 2}, {0, 0}, nd4j::DataType::INT64); + NDArray idc('c', {1, 2}, std::vector{0, 0}, nd4j::DataType::INT64); auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, {1, 1, 1, 1, 1, 1, 1, 1}); auto exp = NDArrayFactory::create('c', {2, 2, 2}, {3, 4, 5, 6, 5, 6, 7, 8}); @@ -761,7 +761,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_4) { TEST_F(ParityOpsTests, Test_Scatter_Add_5) { auto matrix = NDArrayFactory::create('c', {2, 2, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - NDArray idc('c', {2, 2}, {1, 1, 0, 0}, nd4j::DataType::INT64); + NDArray idc('c', {2, 2}, {1., 1, 0, 0}, nd4j::DataType::INT64); auto updates = NDArrayFactory::create('c', {2, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); auto exp = NDArrayFactory::create('c', {2, 2, 3}, {9., 11., 13.,15., 17., 19., 9., 11., 13.,15., 17., 19.}); @@ -796,7 +796,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_6) { TEST_F(ParityOpsTests, Test_Scatter_Add_7) { auto matrix = NDArrayFactory::create('c', {10, 3}, {1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f,19.f,20.f,21.f,22.f,23.f,24.f,25.f,26.f,27.f,28.f,29.f,30.f}); - NDArray idc('c', {}, {5}, nd4j::DataType::INT64); + NDArray idc('c', {}, std::vector{5}, nd4j::DataType::INT64); auto 
updates = NDArrayFactory::create('c', {3}, {10.f, 20.f, 30.f}); auto exp = NDArrayFactory::create('c', {10, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f,11.f,12.f, 13.f,14.f,15.f, 26.f,37.f,48.f, 19.f,20.f,21.f, 22.f,23.f,24.f, 25.f,26.f,27.f, 28.f,29.f,30.f}); @@ -845,7 +845,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_9) { //////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterMax_test1) { auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - NDArray idc('c', {1}, {0.}, nd4j::DataType::INT64); + NDArray idc('c', {1}, std::vector{0.}, nd4j::DataType::INT64); auto updates = NDArrayFactory::create('c', {1, 2}, {10, 1}); auto exp = NDArrayFactory::create('c', {2, 2}, {10, 2, 3, 4}); @@ -879,7 +879,7 @@ TEST_F(ParityOpsTests, scatterMax_test2) { TEST_F(ParityOpsTests, scatterMax_test3) { auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray idc('c', {1}, {0}, nd4j::DataType::INT64); + NDArray idc('c', {1}, std::vector({0}), nd4j::DataType::INT64); auto updates = NDArrayFactory::create('c', {1, 2, 2}, {10, 1, 30, 1}); auto exp = NDArrayFactory::create('c', {2, 2, 2}, {10, 2, 30, 4, 5, 6, 7, 8}); @@ -896,7 +896,7 @@ TEST_F(ParityOpsTests, scatterMax_test3) { TEST_F(ParityOpsTests, scatterMax_test4) { auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray idc('c', {1,2}, {0,0}, nd4j::DataType::INT32); + NDArray idc('c', {1,2}, std::vector{0.,0}, nd4j::DataType::INT32); auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, {1,10,1,10, 1,1,10,1.}); auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 10, 10, 10, 5, 6, 7, 8}); @@ -948,7 +948,7 @@ TEST_F(ParityOpsTests, scatterMax_test6) { TEST_F(ParityOpsTests, scatterMin_test1) { auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - NDArray idc('c', {1}, {0}, nd4j::DataType::INT32); + NDArray idc('c', {1}, std::vector({0}), nd4j::DataType::INT32); auto updates = NDArrayFactory::create('c', {1, 2}, {-1, 1}); auto exp = NDArrayFactory::create('c', {2, 2}, {-1, 1, 3, 4}); @@ -982,7 +982,7 @@ TEST_F(ParityOpsTests, scatterMin_test2) { TEST_F(ParityOpsTests, scatterMin_test3) { auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray idc('c', {1}, {0}, nd4j::DataType::INT32); + NDArray idc('c', {1}, std::vector({0}), nd4j::DataType::INT32); auto updates = NDArrayFactory::create('c', {1, 2, 2}, {10, 1, 30, 2}); auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 3, 2, 5, 6, 7, 8}); @@ -999,7 +999,7 @@ TEST_F(ParityOpsTests, scatterMin_test3) { TEST_F(ParityOpsTests, scatterMin_test4) { auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray idc('c', {1,2}, {0,0}, nd4j::DataType::INT32); + NDArray idc('c', {1,2}, std::vector{0.,0}, nd4j::DataType::INT32); auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, {1,10,1,10, 1,1,10,1.}); auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 1, 1, 5, 6, 7, 8}); diff --git a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp index e6cf01521..b40b74939 100644 --- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp @@ -1005,24 +1005,24 @@ TEST_F(RNGTests, test_uniform_119) { } TEST_F(RNGTests, test_multinomial_1) { - + NDArray probs('f', { 3, 3 }, { 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3 }, nd4j::DataType::FLOAT32); - NDArray expected('f', { 3, 3 }, { 0, 1, 2, 2, 0, 0, 1, 2, 1 }, 
nd4j::DataType::INT64); + NDArray expected('f', { 3, 3 }, { 0., 1, 2, 2, 0, 0, 1, 2, 1 }, nd4j::DataType::INT64); NDArray output('f', { 3, 3 }, nd4j::DataType::INT64); - NDArray samples('f', { 1 }, { 3 }, nd4j::DataType::INT32); - + NDArray samples('f', { 1 }, std::vector({3}), nd4j::DataType::INT32); + nd4j::ops::random_multinomial op; RandomGenerator rng(1234, 1234); ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64}, {}, {}, false) ); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + NDArray probsZ('c', { 1, 3 }, { 0.3, 0.3, 0.3 }, nd4j::DataType::FLOAT32); - NDArray expectedZ('c', { 3, 3 }, { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, nd4j::DataType::INT64); + NDArray expectedZ('c', { 3, 3 }, { 0., 0, 0, 0, 0, 0, 0, 0, 0 }, nd4j::DataType::INT64); auto result = op.evaluate({ &probsZ, &samples }, { }, { 1, INT64 }); auto outputZ = result->at(0); - + ASSERT_EQ(Status::OK(), result->status()); ASSERT_TRUE(expectedZ.isSameShape(outputZ)); ASSERT_TRUE(expectedZ.equalsTo(outputZ)); @@ -1031,7 +1031,7 @@ TEST_F(RNGTests, test_multinomial_1) { TEST_F(RNGTests, test_multinomial_2) { - NDArray samples('c', { 1 }, { 20 }, nd4j::DataType::INT32); + NDArray samples('c', { 1 }, std::vector{ 20 }, nd4j::DataType::INT32); NDArray probs('c', { 3, 5 }, { 0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, 0.25, 0.25, 0.5 }, nd4j::DataType::FLOAT32); NDArray expected('c', { 3, 20 }, { 0, 2, 0, 2, 0, 4, 2, 0, 1, 2, 0, 2, 3, 0, 0, 2, 4, 4, 1, 0, 2, 3, 2, 3, 0, 1, 3, 1, 1, 1, 2, 4, 3, 3, 1, 4, 4, 2, 0, 0, 3, 3, 3, 0, 0, 2, 2, 3, 3, 0, 0, 2, 3, 4, 2, 2, 3, 2, 1, 2 }, nd4j::DataType::INT64); NDArray output('c', { 3, 20 }, nd4j::DataType::INT64); @@ -1041,11 +1041,11 @@ TEST_F(RNGTests, test_multinomial_2) { ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false)); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + NDArray probs2('c', { 5, 3 }, { 0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, 0.25, 0.25, 0.5 }, nd4j::DataType::FLOAT32); NDArray expected2('c', { 20, 3 }, { 0, 2, 3, 2, 3, 3, 0, 2, 3, 2, 3, 0, 0, 0, 0, 4, 1, 2, 2, 3, 2, 3, 1, 3, 1, 1, 3, 2, 1, 0, 0, 2, 0, 2, 4, 2, 3, 3, 3, 0, 3, 4, 0, 1, 2, 2, 0, 2, 4, 4, 0, 4, 2, 2, 1, 0, 1, 0, 0, 2 }, nd4j::DataType::INT64); NDArray output2('c', { 20, 3 }, nd4j::DataType::INT64); - + rng.setStates(1234, 1234); ASSERT_EQ(Status::OK(), op.execute(rng, { &probs2, &samples }, { &output2 }, {}, { 1, INT64 }, {}, {}, false)); ASSERT_TRUE(expected2.isSameShape(output2)); @@ -1053,16 +1053,17 @@ TEST_F(RNGTests, test_multinomial_2) { } TEST_F(RNGTests, test_multinomial_3) { - + NDArray probs('c', { 4, 3 }, { 0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3 }, nd4j::DataType::FLOAT32); NDArray expected('c', { 4, 5 }, nd4j::DataType::INT64); NDArray output('c', { 4, 5 }, nd4j::DataType::INT64); - NDArray samples('c', { 1 }, { 5 }, nd4j::DataType::INT32); + NDArray samples('c', { 1 }, std::vector{ 5 }, nd4j::DataType::INT32); RandomGenerator rng(1234, 1234); nd4j::ops::random_multinomial op; + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 0, INT64 }, {}, {}, false)); - + rng.setStates(1234, 1234); ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false)); ASSERT_TRUE(expected.isSameShape(output)); @@ -1074,7 +1075,7 @@ TEST_F(RNGTests, test_multinomial_4) { NDArray probs('c', { 3, 4 }, { 0.3, 0.3, 0.4, 
0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3 }, nd4j::DataType::FLOAT32); NDArray expected('c', { 5, 4 }, nd4j::DataType::INT64); NDArray output('c', { 5, 4 }, nd4j::DataType::INT64); - NDArray samples('c', { 1 }, { 5 }, nd4j::DataType::INT32); + NDArray samples('c', { 1 }, std::vector{ 5 }, nd4j::DataType::INT32); RandomGenerator rng(1234, 1234); nd4j::ops::random_multinomial op; @@ -1092,15 +1093,15 @@ TEST_F(RNGTests, test_multinomial_5) { int ClassValue = 2; int Samples = 100000; - NDArray samples('c', { 1 }, { 1.*Samples }, nd4j::DataType::INT32); - + NDArray samples('c', { 1 }, std::vector{ 1.*Samples }, nd4j::DataType::INT32); + NDArray probs('c', { ClassValue, batchValue }, { 1.0, 1.0 }, nd4j::DataType::FLOAT32); - + nd4j::ops::random_multinomial op; NDArray output('c', { Samples, batchValue }, nd4j::DataType::INT64); RandomGenerator rng(1234, 1234); - + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1 }, {}, {}, false)); auto deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false); @@ -1109,7 +1110,7 @@ TEST_F(RNGTests, test_multinomial_5) { // theoretical values for binomial ASSERT_NEAR(0.5, deviation.e(0), 4e-3); // 1000000 3e-3); ASSERT_NEAR(0.5, mean.e(0), 4e-3); // 1000000 3e-3); - + for (int i = 0; i < output.lengthOf(); i++) { auto value = output.e(i); ASSERT_TRUE(value >= 0 && value < ClassValue); @@ -1139,8 +1140,8 @@ TEST_F(RNGTests, test_multinomial_6) { int batchValue = 1; int ClassValue = 5; int Samples = 100000; - - NDArray samples('c', { 1 }, { 1. * Samples }, nd4j::DataType::INT32); + + NDArray samples('c', { 1 }, std::vector{ 1. * Samples }, nd4j::DataType::INT32); nd4j::ops::random_multinomial op; NDArray probExpect('c', { ClassValue }, { 0.058, 0.096, 0.1576, 0.2598, 0.4287 }, nd4j::DataType::DOUBLE); @@ -1152,8 +1153,8 @@ TEST_F(RNGTests, test_multinomial_6) { auto outputR = resultR->at(0); ASSERT_EQ(Status::OK(), resultR->status()); - NDArray countsR('c', { ClassValue }, { 0, 0, 0, 0, 0 }, nd4j::DataType::DOUBLE); - + NDArray countsR('c', { ClassValue }, { 0., 0, 0, 0, 0 }, nd4j::DataType::DOUBLE); + for (int i = 0; i < outputR->lengthOf(); i++) { auto value = outputR->e(i); ASSERT_TRUE(value >= 0 && value < ClassValue); @@ -1179,11 +1180,11 @@ TEST_F(RNGTests, test_multinomial_6) { RandomGenerator rng(1234, 1234); NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, nd4j::DataType::FLOAT32); NDArray output('c', { batchValue, Samples }, nd4j::DataType::INT64); - + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false)); - NDArray counts('c', { ClassValue }, { 0, 0, 0, 0, 0 }, nd4j::DataType::DOUBLE); - + NDArray counts('c', { ClassValue }, { 0., 0, 0, 0, 0 }, nd4j::DataType::DOUBLE); + for (int i = 0; i < output.lengthOf(); i++) { auto value = output.e(i); ASSERT_TRUE(value >= 0 && value < ClassValue); diff --git a/libnd4j/tests_cpu/layers_tests/StringTests.cpp b/libnd4j/tests_cpu/layers_tests/StringTests.cpp index ec7821f21..8b9d92f2f 100644 --- a/libnd4j/tests_cpu/layers_tests/StringTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/StringTests.cpp @@ -1,5 +1,6 @@ -/******************************************************************************* +/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,6 +17,7 @@ // // @author raver119@gmail.com +// @author Oleg Semeniv // @@ -30,7 +32,7 @@ class StringTests : public testing::Test { public: }; - +///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_1) { std::string f("alpha"); auto array = NDArrayFactory::string(f); @@ -43,7 +45,7 @@ TEST_F(StringTests, Basic_Test_1) { ASSERT_EQ(f, z); } - +///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_2) { std::string f("alpha"); auto array = NDArrayFactory::string(f.c_str()); @@ -56,23 +58,213 @@ TEST_F(StringTests, Basic_Test_2) { ASSERT_EQ(f, z); } - +///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_3) { - auto array = NDArrayFactory::string('c', {3, 2}, {"alpha", "beta", "gamma", "phi", "theta", "omega"}); + + auto array = NDArrayFactory::string({3, 2}, {"alpha", "beta", "gamma", "phi", "theta", "omega"}); + + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_4) { + + NDArray array( { 3, 2 }, std::vector{ U"alpha", U"beta", U"gamma€한", U"pÿqwe", U"ß水𝄋", U"omega" }); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_5) { + + NDArray array( { 3, 2 }, std::vector{ u"alpha", u"beta", u"gamma€한", u"pÿqwe", u"ß水𝄋", u"omega" }); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_6) { + + NDArray array( { 3, 2 }, std::vector{ "alpha", "beta", "gamma€한", "pÿqwe", "ß水𝄋", "omega" }); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_7) { + + NDArray array( { 3, 2 }, std::vector{ U"alpha", U"beta", U"gamma€한", U"pÿqwe", U"ß水𝄋", U"omega" }); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_8) { + + NDArray array( { 3, 2 }, std::vector{ u"alpha", u"beta", u"gamma€한", u"pÿqwe", u"ß水𝄋", u"omega" }); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_9) { + + NDArray array( { 3, 2 }, std::vector{ "alpha", "beta", "gamma€한", "pÿqwe", "ß水𝄋", "omega" }); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_10) { + + NDArray array(std::u32string(U"gamma€한")); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + array.printIndexedBuffer("String array"); +} 
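+/////////////////////////////////////////////////////////////////////////
+// Editorial sketch, not part of the original change: the scalar string
+// constructors above (e.g. Basic_Test_10) can be read back through the
+// templated accessor used later in this file. Assumes NDArray::e<T>()
+// with T = std::u32string, as in the dup/cast tests further below:
+//
+//     NDArray scalar(std::u32string(U"gamma"));
+//     std::u32string back = scalar.e<std::u32string>(0);
+//     assert(back == U"gamma");   // rank 0, length 1 round-trip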
+///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_11) { + + NDArray array(U"gamma€한"); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_12) { + + NDArray array(std::u16string(u"gamma€한")); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_13) { + + NDArray array(u"gamma€한"); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_14) { + + NDArray array(std::string("gamma€한")); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_15) { + + NDArray array("gamma€한"); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_16) { + + auto array = NDArrayFactory::string( { 3, 2 }, std::vector{ "alpha", "beta", "gamma", "phi", "theta", "omega" }); ASSERT_EQ(6, array.lengthOf()); ASSERT_EQ(2, array.rankOf()); array.printIndexedBuffer("String array"); } +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_17) { + auto array = NDArrayFactory::string({ 3, 2 }, std::vector{ "alpha", "beta", "gamma", "phi", "theta", "omega" }); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_18) { + + auto array = NDArrayFactory::string({ 3, 2 }, std::vector{ u"alpha", u"beta", u"gamma", u"phi", u"theta", u"omega" }); + + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_19) { + + auto array = NDArrayFactory::string( { 3, 2 }, std::vector{ u"alpha", u"beta", u"gamma", u"phi", u"theta", u"omega" }); + + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_20) { + + auto array = NDArrayFactory::string( { 3, 2 }, std::vector{ U"alpha", U"beta", U"gamma", U"phi", U"theta", U"omega" }); + + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_21) { + + auto array = NDArrayFactory::string( { 3, 2 }, std::vector{ U"alpha", U"òèçùà12345¤z", U"ß水𝄋ÿ€한𐍈®кею90ощъ]ї", U"phi", U"theta", U"omega" }); + + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_22) { + 
std::u16string f(u"ß水𝄋ÿ€한𐍈®кею90ощъ]ї"); + auto array = NDArrayFactory::string(f.c_str()); + ASSERT_EQ(nd4j::DataType::UTF16, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto z = array.e(0); + + ASSERT_EQ(f, z); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_23) { + std::u32string f(U"ß水𝄋ÿ€한𐍈®кею90ощъ]ї"); + auto array = NDArrayFactory::string(f.c_str()); + ASSERT_EQ(nd4j::DataType::UTF32, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto z = array.e(0); + + ASSERT_EQ(f, z); +} +///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Export_Test_1) { - auto array = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"}); - + auto array = NDArrayFactory::string( {3}, {"alpha", "beta", "gamma"}); auto vector = array.asByteVector(); } - +///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_dup_1) { std::string f("alpha"); auto array = NDArrayFactory::string(f); @@ -91,20 +283,20 @@ TEST_F(StringTests, Basic_dup_1) { delete dup; } - +///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, byte_length_test_1) { std::string f("alpha"); auto array = NDArrayFactory::string(f); ASSERT_EQ(f.length(), StringUtils::byteLength(array)); } - +///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, byte_length_test_2) { - auto array = NDArrayFactory::string('c', {2}, {"alpha", "beta"}); + auto array = NDArrayFactory::string( {2}, {"alpha", "beta"}); ASSERT_EQ(9, StringUtils::byteLength(array)); } - +///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, test_split_1) { auto split = StringUtils::split("alpha beta gamma", " "); @@ -112,4 +304,562 @@ TEST_F(StringTests, test_split_1) { ASSERT_EQ(std::string("alpha"), split[0]); ASSERT_EQ(std::string("beta"), split[1]); ASSERT_EQ(std::string("gamma"), split[2]); -} \ No newline at end of file +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, test_unicode_utf8_utf16) { + + std::string utf8 = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u16string utf16Exp = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + + std::u16string utf16Res; + ASSERT_TRUE(StringUtils::u8StringToU16String(utf8, utf16Res)); + + ASSERT_EQ(utf16Res.size(), utf16Exp.size()); + for (auto i = 0; i < utf16Exp.size(); i++) { + ASSERT_EQ(utf16Exp[i], utf16Res[i]); + } +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, test_unicode_utf8_utf32) { + + std::string utf8 = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u32string utf32Exp = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + + std::u32string utf32Res; + ASSERT_TRUE(StringUtils::u8StringToU32String(utf8, utf32Res)); + + ASSERT_EQ(utf32Res.size(), utf32Exp.size()); + for (auto i = 0; i < utf32Exp.size(); i++) { + ASSERT_EQ(utf32Exp[i], utf32Res[i]); + } +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, test_unicode_utf16_utf8) { + + std::string utf8Exp = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u16string utf16 = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + + std::string utf8Res; + ASSERT_TRUE(StringUtils::u16StringToU8String(utf16, utf8Res)); + + ASSERT_EQ(utf8Res.size(), 
utf8Exp.size()); + for (auto i = 0; i < utf8Exp.size(); i++) { + ASSERT_EQ(utf8Exp[i], utf8Res[i]); + } +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, test_unicode_utf32_utf8) { + + std::string utf8Exp = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею 90ощъ]їїщkk1q\n\t\rop~"; + std::u32string utf32 = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею 90ощъ]їїщkk1q\n\t\rop~"; + + std::string utf8Res; + ASSERT_TRUE(StringUtils::u32StringToU8String(utf32, utf8Res)); + + ASSERT_EQ(utf8Res.size(), utf8Exp.size()); + for (auto i = 0; i < utf8Exp.size(); i++) { + ASSERT_EQ(utf8Exp[i], utf8Res[i]); + } +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, test_unicode_utf16_utf32) { + + std::u32string utf32Exp = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u16string utf16 = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + + std::u32string utf32Res; + ASSERT_TRUE(StringUtils::u16StringToU32String(utf16, utf32Res)); + + ASSERT_EQ(utf32Res.size(), utf32Exp.size()); + for (auto i = 0; i < utf32Exp.size(); i++) { + ASSERT_EQ(utf32Exp[i], utf32Res[i]); + } +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, test_unicode_utf32_utf16) { + + std::u16string utf16Exp = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u32string utf32 = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + + std::u16string utf16Res; + ASSERT_TRUE(StringUtils::u32StringToU16String(utf32, utf16Res)); + + ASSERT_EQ(utf16Res.size(), utf16Exp.size()); + for (auto i = 0; i < utf16Exp.size(); i++) { + ASSERT_EQ(utf16Exp[i], utf16Res[i]); + } +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, byte_length_test_Default) { + + std::string f("alpha"); + auto array = NDArrayFactory::string(f); + + ASSERT_EQ(f.length(), StringUtils::byteLength(array)); + + std::u16string f16(u"alpha"); + auto array16 = NDArrayFactory::string(f16); + + ASSERT_EQ(sizeof(char16_t)*f16.length(), StringUtils::byteLength(array16)); + + std::u32string f32(U"alpha"); + auto array32 = NDArrayFactory::string(f32); + + ASSERT_EQ(sizeof(char32_t) * f32.length(), StringUtils::byteLength(array32)); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, byte_length_test_UTF16) { + std::string f(u8"alpha"); + auto array = NDArrayFactory::string(f, nd4j::DataType::UTF16); + + ASSERT_EQ(sizeof(char16_t) * f.length(), StringUtils::byteLength(array)); + + std::u16string f16(u"alpha"); + auto array16 = NDArrayFactory::string(f16, nd4j::DataType::UTF16); + + ASSERT_EQ(sizeof(char16_t) * f16.length(), StringUtils::byteLength(array16)); + + std::u32string f32(U"alpha"); + auto array32 = NDArrayFactory::string(f32, nd4j::DataType::UTF16); + + ASSERT_EQ(sizeof(char16_t) * f32.length(), StringUtils::byteLength(array32)); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_UTF16toU8) { + + std::u16string f16(u"alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f16, nd4j::DataType::UTF8); + ASSERT_EQ(nd4j::DataType::UTF8, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto z = array.e(0); + + std::string f(u8"alpha水𝄋ÿ€한𐍈®кею"); + ASSERT_EQ(f, z); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_UTF32toU8) { + std::u32string f32(U"alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f32.c_str(), 
nd4j::DataType::UTF8); + ASSERT_EQ(nd4j::DataType::UTF8, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto z = array.e(0); + std::string f(u8"alpha水𝄋ÿ€한𐍈®кею"); + ASSERT_EQ(f, z); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_UTF16toU16) { + + std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f16, nd4j::DataType::UTF16); + ASSERT_EQ(nd4j::DataType::UTF16, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + auto z = array.e(0); + + ASSERT_EQ(z, f16); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_UTF32toU16) { + + std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f32, nd4j::DataType::UTF16); + ASSERT_EQ(nd4j::DataType::UTF16, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + auto z = array.e(0); + std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею"); + ASSERT_EQ(z, f16); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_UTF16toU32) { + + std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f16, nd4j::DataType::UTF32); + ASSERT_EQ(nd4j::DataType::UTF32, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto z = array.e(0); + std::u32string fres(U"€alpha水𝄋ÿ€한𐍈®кею"); + ASSERT_EQ(z, fres); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_UTF32toU32) { + + std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f32); + ASSERT_EQ(nd4j::DataType::UTF32, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + auto z = array.e(0); + ASSERT_EQ(f32, z); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_UTF8toU32) { + + std::string f(u8"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f, nd4j::DataType::UTF32); + ASSERT_EQ(nd4j::DataType::UTF32, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею"); + auto z = array.e(0); + ASSERT_EQ(f32, z); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_StringVecU8toUTF16) { + auto array = NDArrayFactory::string({ 3, 2 }, { "alpha€", "beta", "gamma水", "phi", "theta", "omega水" }, nd4j::DataType::UTF16); + + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_StringVecU8toUTF32) { + auto array = NDArrayFactory::string( { 3, 2 }, { "alpha€", "beta水", "gamma", "phi", "theta", "omega" }, nd4j::DataType::UTF32); + + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Export_Test_U8toUTF16) { + auto array = NDArrayFactory::string({ 3 }, { "alpha", "beta", "gamma" }, nd4j::DataType::UTF16); + + auto vector = array.asByteVector(); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Export_Test_U8toUTF32) { + auto array = NDArrayFactory::string({ 3 }, { "alpha", "beta", "gamma" }, 
nd4j::DataType::UTF32);
+
+    auto vector = array.asByteVector();
+}
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Basic_Test_StringVecU16toUTF16) {
+    auto array = NDArrayFactory::string({ 3, 2 }, { u"alpha水", u"beta", u"gamma", u"phi", u"theta水", u"omega" }, nd4j::DataType::UTF16);
+
+    ASSERT_EQ(6, array.lengthOf());
+    ASSERT_EQ(2, array.rankOf());
+
+    array.printIndexedBuffer("String array");
+}
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Basic_Test_StringVecU16toUTF32) {
+    auto array = NDArrayFactory::string( { 3, 2 }, { u"alpha水", u"beta", u"gamma水", u"phi", u"theta", u"omega" }, nd4j::DataType::UTF32);
+
+    ASSERT_EQ(6, array.lengthOf());
+    ASSERT_EQ(2, array.rankOf());
+
+    array.printIndexedBuffer("String array");
+}
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Basic_Test_StringVecU16toUTF8) {
+    auto array = NDArrayFactory::string( { 3, 2 }, { u"alpha€", u"beta水", u"gamma", u"phi水", u"theta", u"omega" }, nd4j::DataType::UTF8);
+
+    ASSERT_EQ(6, array.lengthOf());
+    ASSERT_EQ(2, array.rankOf());
+
+    array.printIndexedBuffer("String array");
+}
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Export_Test_U16toUTF8) {
+    auto array = NDArrayFactory::string( { 3 }, { u"alpha", u"beta", u"gamma" }, nd4j::DataType::UTF8);
+
+    auto vector = array.asByteVector();
+}
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Export_Test_U16toUTF16) {
+    auto array = NDArrayFactory::string( { 3 }, { u"alpha", u"beta", u"gamma" }, nd4j::DataType::UTF16);
+
+    auto vector = array.asByteVector();
+}
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Export_Test_U16toUTF32) {
+    auto array = NDArrayFactory::string( { 3 }, { u"alpha水", u"beta", u"gamma水" }, nd4j::DataType::UTF32);
+
+    auto vector = array.asByteVector();
+}
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Basic_Test_StringVecU32toUTF32) {
+    auto array = NDArrayFactory::string( { 3, 2 }, { U"alpha€", U"beta水", U"gamma", U"phi", U"theta", U"omega水" }, nd4j::DataType::UTF32);
+
+    ASSERT_EQ(6, array.lengthOf());
+    ASSERT_EQ(2, array.rankOf());
+
+    array.printIndexedBuffer("String array");
+}
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Basic_Test_StringVecU32toUTF16) {
+    auto array = NDArrayFactory::string({ 3, 2 }, { U"alpha水", U"水beta", U"gamma", U"phi水", U"theta", U"omega" }, nd4j::DataType::UTF16);
+
+    ASSERT_EQ(6, array.lengthOf());
+    ASSERT_EQ(2, array.rankOf());
+
+    array.printIndexedBuffer("String array");
+
+    printf("Array elements size: \n");
+    for (int e = 0; e < array.lengthOf(); e++) {
+        printf("Element %d size: %d\n", e, static_cast<int>(array.e<std::u16string>(e).size()));
+    }
+}
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Basic_Test_StringVecU32toUTF8) {
+    auto array = NDArrayFactory::string( { 3, 2 }, { U"alpha水", U"beta", U"gamma水", U"phi", U"theta", U"omega" }, nd4j::DataType::UTF8);
+
+    ASSERT_EQ(6, array.lengthOf());
+    ASSERT_EQ(2, array.rankOf());
+
+    array.printIndexedBuffer("String array");
+}
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Export_Test_U32toUTF32) {
+    auto array = NDArrayFactory::string( { 3 }, { U"alpha", U"beta", U"gamma" }, nd4j::DataType::UTF32);
+
+    auto vector = array.asByteVector();
+}
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Export_Test_U32toUTF16) {
+    auto array = NDArrayFactory::string( { 3 }, { U"alpha", U"beta水", U"gamma水" }, nd4j::DataType::UTF16);
+
+    auto vector = array.asByteVector();
+}
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Export_Test_U32toUTF8) {
+    auto array = NDArrayFactory::string( { 3 }, { U"alpha", U"beta", U"gamma水" }, nd4j::DataType::UTF8);
+
+    auto vector = array.asByteVector();
+}
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Basic_dup_UTF16) {
+    std::u16string f(u"€alpha水𝄋ÿ€한𐍈®кею");
+    auto array = NDArrayFactory::string(f);
+    ASSERT_EQ(nd4j::DataType::UTF16, array.dataType());
+
+    ASSERT_EQ(1, array.lengthOf());
+    ASSERT_EQ(0, array.rankOf());
+
+    auto dup = new NDArray(array.dup());
+
+    auto z0 = array.e<std::u16string>(0);
+    auto z1 = dup->e<std::u16string>(0);
+
+    ASSERT_EQ(f, z0);
+    ASSERT_EQ(f, z1);
+
+    delete dup;
+}
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Basic_dup_UTF32) {
+    std::u32string f(U"€alpha水𝄋ÿ€한𐍈®кею");
+    auto array = NDArrayFactory::string(f);
+    ASSERT_EQ(nd4j::DataType::UTF32, array.dataType());
+
+    ASSERT_EQ(1, array.lengthOf());
+    ASSERT_EQ(0, array.rankOf());
+
+    auto dup = new NDArray(array.dup());
+
+    auto z0 = array.e<std::u32string>(0);
+    auto z1 = dup->e<std::u32string>(0);
+
+    ASSERT_EQ(f, z0);
+    ASSERT_EQ(f, z1);
+
+    delete dup;
+}
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Basic_cast_UTF32toUTF8) {
+
+    std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею");
+
+    std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею");
+
+    auto array = NDArrayFactory::string(u32);
+    ASSERT_EQ(nd4j::DataType::UTF32, array.dataType());
+
+    ASSERT_EQ(1, array.lengthOf());
+    ASSERT_EQ(0, array.rankOf());
+
+    auto aCast = array.cast(nd4j::DataType::UTF8);
+
+    auto z0 = array.e<std::u32string>(0);
+    auto z1 = aCast.e<std::string>(0);
+
+    ASSERT_EQ(u32, z0);
+    ASSERT_EQ(u8, z1);
+}
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Basic_cast_UTF32toUTF16) {
+
+    std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею");
+
+    std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею");
+
+    auto array = NDArrayFactory::string(u32);
+    ASSERT_EQ(nd4j::DataType::UTF32, array.dataType());
+
+    ASSERT_EQ(1, array.lengthOf());
+    ASSERT_EQ(0, array.rankOf());
+
+    auto aCast = array.cast(nd4j::DataType::UTF16);
+
+    auto z0 = array.e<std::u32string>(0);
+    auto z1 = aCast.e<std::u16string>(0);
+
+    ASSERT_EQ(u32, z0);
+    ASSERT_EQ(u16, z1);
+}
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Basic_cast_UTF32toUTF32) {
+
+    std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею");
+
+    auto array = NDArrayFactory::string(u32);
+    ASSERT_EQ(nd4j::DataType::UTF32, array.dataType());
+
+    ASSERT_EQ(1, array.lengthOf());
+    ASSERT_EQ(0, array.rankOf());
+
+    auto aCast = array.cast(nd4j::DataType::UTF32);
+
+    auto z0 = array.e<std::u32string>(0);
+    auto z1 = aCast.e<std::u32string>(0);
+
+    ASSERT_EQ(u32, z0);
+    ASSERT_EQ(u32, z1);
+}
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Basic_cast_UTF16toUTF16) {
+
+    std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею");
+
+    auto array = NDArrayFactory::string(u16);
+    ASSERT_EQ(nd4j::DataType::UTF16, array.dataType());
+
+    ASSERT_EQ(1, array.lengthOf());
+    ASSERT_EQ(0, array.rankOf());
+
+    auto aCast = array.cast(nd4j::DataType::UTF16);
+
+    auto z0 = array.e<std::u16string>(0);
+    auto z1 = aCast.e<std::u16string>(0);
+
+    ASSERT_EQ(u16, z0);
+    ASSERT_EQ(u16, z1);
+}
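+/////////////////////////////////////////////////////////////////////////
+// Editorial sketch, not part of the original change: the cast tests above
+// and below share one pattern - construct in one string dtype, cast(),
+// then compare both sides element-wise. Assumes NDArray::cast() re-encodes
+// the string payload without changing shape or rank, which is what the
+// assertions in these tests check:
+//
+//     auto a = NDArrayFactory::string(std::string(u8"alpha"));  // UTF8
+//     auto b = a.cast(nd4j::DataType::UTF32);
+//     assert(b.e<std::u32string>(0) == U"alpha");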
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Basic_cast_UTF16toUTF32) {
+
+    std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею");
+
+    std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею");
+
+    auto array = NDArrayFactory::string(u16);
+    ASSERT_EQ(nd4j::DataType::UTF16, array.dataType());
+
+    ASSERT_EQ(1, array.lengthOf());
+    ASSERT_EQ(0, array.rankOf());
+
+    auto aCast = array.cast(nd4j::DataType::UTF32);
+
+    auto z0 = array.e<std::u16string>(0);
+    auto z1 = aCast.e<std::u32string>(0);
+
+    ASSERT_EQ(u32, z1);
+    ASSERT_EQ(u16, z0);
+}
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Basic_cast_UTF16toUTF8) {
+
+    std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею");
+
+    std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею");
+
+    auto array = NDArrayFactory::string(u16);
+    ASSERT_EQ(nd4j::DataType::UTF16, array.dataType());
+
+    ASSERT_EQ(1, array.lengthOf());
+    ASSERT_EQ(0, array.rankOf());
+
+    auto aCast = array.cast(nd4j::DataType::UTF8);
+
+    auto z0 = array.e<std::u16string>(0);
+    auto z1 = aCast.e<std::string>(0);
+
+    ASSERT_EQ(u8, z1);
+    ASSERT_EQ(u16, z0);
+}
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Basic_cast_UTF8toUTF8) {
+
+    std::string u8("€alpha水𝄋ÿ€한𐍈®кею");
+
+    auto array = NDArrayFactory::string(u8);
+    ASSERT_EQ(nd4j::DataType::UTF8, array.dataType());
+
+    ASSERT_EQ(1, array.lengthOf());
+    ASSERT_EQ(0, array.rankOf());
+
+    auto aCast = array.cast(nd4j::DataType::UTF8);
+
+    auto z0 = array.e<std::string>(0);
+    auto z1 = aCast.e<std::string>(0);
+
+    ASSERT_EQ(u8, z1);
+    ASSERT_EQ(u8, z0);
+}
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Basic_cast_UTF8toUTF16) {
+
+    std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею");
+
+    std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею");
+
+    auto array = NDArrayFactory::string(u8);
+    ASSERT_EQ(nd4j::DataType::UTF8, array.dataType());
+
+    ASSERT_EQ(1, array.lengthOf());
+    ASSERT_EQ(0, array.rankOf());
+
+    auto aCast = array.cast(nd4j::DataType::UTF16);
+
+    auto z0 = array.e<std::string>(0);
+    auto z1 = aCast.e<std::u16string>(0);
+
+    ASSERT_EQ(u8, z0);
+    ASSERT_EQ(u16, z1);
+}
+/////////////////////////////////////////////////////////////////////////
+TEST_F(StringTests, Basic_cast_UTF8toUTF32) {
+
+    std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею");
+
+    std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею");
+
+    auto array = NDArrayFactory::string(u8);
+    ASSERT_EQ(nd4j::DataType::UTF8, array.dataType());
+
+    ASSERT_EQ(1, array.lengthOf());
+    ASSERT_EQ(0, array.rankOf());
+
+    auto aCast = array.cast(nd4j::DataType::UTF32);
+
+    auto z0 = array.e<std::string>(0);
+    auto z1 = aCast.e<std::u32string>(0);
+
+    ASSERT_EQ(u8, z0);
+    ASSERT_EQ(u32, z1);
+}

From 0756e3fe7073686b513c19b4b6aa7fa8ed69c410 Mon Sep 17 00:00:00 2001
From: Alex Black
Date: Sat, 1 Feb 2020 18:19:36 +1100
Subject: [PATCH 12/17] Small fixes.
(#206) * Logging format tweaks for file logging Signed-off-by: AlexDBlack * Min abs error tweak for Util layer gradient checks Signed-off-by: AlexDBlack * #8648 Fix SameDiff NPE instead of error for missing placeholders Signed-off-by: AlexDBlack * Test runtime reduction Signed-off-by: AlexDBlack --- .../CapsnetGradientCheckTest.java | 6 ----- .../OutputLayerGradientChecks.java | 4 ---- .../gradientcheck/RnnGradientChecks.java | 4 ---- .../UtilLayerGradientChecks.java | 17 +++++-------- .../gradientcheck/YoloGradientCheckTests.java | 5 ---- .../src/test/resources/logback-test.xml | 4 ++-- .../src/test/resources/logback.xml | 4 ++-- .../src/test/resources/logback.xml | 4 ++-- .../src/test/resources/logback.xml | 4 ++-- .../src/test/resources/logback.xml | 4 ++-- .../src/test/resources/logback.xml | 4 ++-- .../src/test/resources/logback.xml | 4 ++-- .../src/test/resources/logback.xml | 4 ++-- .../dl4j-spark/src/test/resources/logback.xml | 4 ++-- .../src/test/resources/logback.xml | 4 ++-- .../src/main/resources/logback.xml | 4 ++-- .../src/test/resources/logback.xml | 4 ++-- .../src/test/resources/logback.xml | 4 ++-- .../src/test/resources/logback-test.xml | 4 ++-- .../samediff/internal/AbstractSession.java | 7 +++++- .../src/test/resources/logback.xml | 4 ++-- .../nd4j/autodiff/samediff/SameDiffTests.java | 24 +++++++++++++++++++ .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 4 ++-- .../nd4j-tests/src/test/resources/logback.xml | 4 ++-- .../src/test/resources/logback.xml | 4 ++-- .../src/test/resources/logback.xml | 4 ++-- .../src/test/resources/logback.xml | 4 ++-- .../src/test/resources/logback.xml | 4 ++-- .../src/test/resources/logback.xml | 4 ++-- .../nd4j-aeron/src/test/resources/logback.xml | 4 ++-- 30 files changed, 82 insertions(+), 77 deletions(-) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java index e69766677..ac0224b7a 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java @@ -43,12 +43,6 @@ import java.util.Random; public class CapsnetGradientCheckTest extends BaseDL4JTest { - private static final boolean PRINT_RESULTS = true; - private static final boolean RETURN_ON_FIRST_FAILURE = false; - private static final double DEFAULT_EPS = 1e-6; - private static final double DEFAULT_MAX_REL_ERROR = 1e-3; - private static final double DEFAULT_MIN_ABS_ERROR = 1e-8; - @Test public void testCapsNet() { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java index 67fc4c11c..6796b2790 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java @@ -43,10 +43,6 @@ import static org.junit.Assert.assertTrue; public class OutputLayerGradientChecks extends BaseDL4JTest { private static final boolean PRINT_RESULTS = true; - private static final boolean RETURN_ON_FIRST_FAILURE = false; - private static final double DEFAULT_EPS = 1e-6; - private static final 
double DEFAULT_MAX_REL_ERROR = 1e-3; - private static final double DEFAULT_MIN_ABS_ERROR = 1e-8; static { Nd4j.setDataType(DataType.DOUBLE); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java index 2980cad7c..afb4fd67c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java @@ -47,10 +47,6 @@ import static org.junit.Assert.assertTrue; public class RnnGradientChecks extends BaseDL4JTest { private static final boolean PRINT_RESULTS = true; - private static final boolean RETURN_ON_FIRST_FAILURE = false; - private static final double DEFAULT_EPS = 1e-6; - private static final double DEFAULT_MAX_REL_ERROR = 1e-3; - private static final double DEFAULT_MIN_ABS_ERROR = 1e-8; static { Nd4j.setDataType(DataType.DOUBLE); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java index 2d889a6a1..6482984cb 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java @@ -48,12 +48,6 @@ import static org.junit.Assert.assertTrue; public class UtilLayerGradientChecks extends BaseDL4JTest { - private static final boolean PRINT_RESULTS = true; - private static final boolean RETURN_ON_FIRST_FAILURE = false; - private static final double DEFAULT_EPS = 1e-6; - private static final double DEFAULT_MAX_REL_ERROR = 1e-3; - private static final double DEFAULT_MIN_ABS_ERROR = 1e-6; - static { Nd4j.setDataType(DataType.DOUBLE); } @@ -182,9 +176,9 @@ public class UtilLayerGradientChecks extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(input) - .minAbsoluteError(1e-7) - .labels(label).inputMask(inMask)); + boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net) + .minAbsoluteError(1e-6) + .input(input).labels(label).inputMask(inMask)); assertTrue(gradOK); TestUtils.testModelSerialization(net); @@ -233,8 +227,9 @@ public class UtilLayerGradientChecks extends BaseDL4JTest { //Test ComputationGraph equivalent: ComputationGraph g = net.toComputationGraph(); - boolean gradOKCG = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(g).inputs(new INDArray[]{in}) - .labels(new INDArray[]{labels}).excludeParams(excludeParams)); + boolean gradOKCG = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(g) + .minAbsoluteError(1e-6) + .inputs(new INDArray[]{in}).labels(new INDArray[]{labels}).excludeParams(excludeParams)); assertTrue(gradOKCG); TestUtils.testModelSerialization(g); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java index 147150aa8..a47716740 100644 --- 
a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java @@ -56,11 +56,6 @@ import static org.junit.Assert.assertTrue; * @author Alex Black */ public class YoloGradientCheckTests extends BaseDL4JTest { - private static final boolean PRINT_RESULTS = true; - private static final boolean RETURN_ON_FIRST_FAILURE = false; - private static final double DEFAULT_EPS = 1e-6; - private static final double DEFAULT_MAX_REL_ERROR = 1e-3; - private static final double DEFAULT_MIN_ABS_ERROR = 1e-8; static { Nd4j.setDataType(DataType.DOUBLE); diff --git a/deeplearning4j/deeplearning4j-core/src/test/resources/logback-test.xml b/deeplearning4j/deeplearning4j-core/src/test/resources/logback-test.xml index c6f89b60a..69246755b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/resources/logback-test.xml +++ b/deeplearning4j/deeplearning4j-core/src/test/resources/logback-test.xml @@ -19,8 +19,8 @@ logs/application.log - %date - [%level] - from %logger in %thread - %n%message%n%xException%n + %logger{15} - %message%n%xException{5} + diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-cuda/src/test/resources/logback.xml index c6f89b60a..69246755b 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/resources/logback.xml +++ b/deeplearning4j/deeplearning4j-cuda/src/test/resources/logback.xml @@ -19,8 +19,8 @@ logs/application.log - %date - [%level] - from %logger in %thread - %n%message%n%xException%n + %logger{15} - %message%n%xException{5} + diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml index 7953c2712..7d49481af 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml @@ -19,8 +19,8 @@ logs/application.log - %date - [%level] - from %logger in %thread - %n%message%n%xException%n + %logger{15} - %message%n%xException{5} + diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/resources/logback.xml index 59b35644e..cbcbed5d6 100644 --- a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/resources/logback.xml +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/resources/logback.xml @@ -21,8 +21,8 @@ logs/application.log - %date - [%level] - from %logger in %thread - %n%message%n%xException%n + %logger{15} - %message%n%xException{5} + diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/src/test/resources/logback.xml index f1ffbc8ac..f6b823056 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/src/test/resources/logback.xml +++ 
b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/src/test/resources/logback.xml @@ -21,8 +21,8 @@ logs/application.log - %date - [%level] - from %logger in %thread - %n%message%n%xException%n + %logger{15} - %message%n%xException{5} + diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/logback.xml index 9dec22fae..4d94f2516 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/logback.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/logback.xml @@ -21,8 +21,8 @@ logs/application.log - %date - [%level] - from %logger in %thread - %n%message%n%xException%n + %logger{15} - %message%n%xException{5} + diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/logback.xml index 9dec22fae..4d94f2516 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/logback.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/logback.xml @@ -21,8 +21,8 @@ logs/application.log - %date - [%level] - from %logger in %thread - %n%message%n%xException%n + %logger{15} - %message%n%xException{5} + diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/logback.xml index 9dec22fae..4d94f2516 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/logback.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/logback.xml @@ -21,8 +21,8 @@ logs/application.log - %date - [%level] - from %logger in %thread - %n%message%n%xException%n + %logger{15} - %message%n%xException{5} + diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/resources/logback.xml index 9dec22fae..4d94f2516 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/resources/logback.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/resources/logback.xml @@ -21,8 +21,8 @@ logs/application.log - %date - [%level] - from %logger in %thread - %n%message%n%xException%n + %logger{15} - %message%n%xException{5} + diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/resources/logback.xml index 2c204cafa..9baf66a0d 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/resources/logback.xml +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/resources/logback.xml @@ -21,8 +21,8 @@ logs/application.log - %date - [%level] - from %logger in %thread - %n%message%n%xException%n + %logger{15} - %message%n%xException{5} + diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-standalone/src/main/resources/logback.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-standalone/src/main/resources/logback.xml index 1753f88dc..2283bdc50 100644 --- 
a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-standalone/src/main/resources/logback.xml +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-standalone/src/main/resources/logback.xml @@ -21,8 +21,8 @@ logs/application.log - %date - [%level] - from %logger in %thread - %n%message%n%xException%n + %logger{15} - %message%n%xException{5} + diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/resources/logback.xml index 2c204cafa..9baf66a0d 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/resources/logback.xml +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/resources/logback.xml @@ -21,8 +21,8 @@ logs/application.log - %date - [%level] - from %logger in %thread - %n%message%n%xException%n + %logger{15} - %message%n%xException{5} + diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/resources/logback.xml index 2c204cafa..9baf66a0d 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/resources/logback.xml +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/resources/logback.xml @@ -21,8 +21,8 @@ logs/application.log - %date - [%level] - from %logger in %thread - %n%message%n%xException%n + %logger{15} - %message%n%xException{5} + diff --git a/deeplearning4j/dl4j-integration-tests/src/test/resources/logback-test.xml b/deeplearning4j/dl4j-integration-tests/src/test/resources/logback-test.xml index c6f89b60a..69246755b 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/resources/logback-test.xml +++ b/deeplearning4j/dl4j-integration-tests/src/test/resources/logback-test.xml @@ -19,8 +19,8 @@ logs/application.log - %date - [%level] - from %logger in %thread - %n%message%n%xException%n + %logger{15} - %message%n%xException{5} + diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java index c95f26b1f..d89fe05a5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java @@ -25,6 +25,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.VariableType; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.controlflow.compat.*; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.function.Predicate; @@ -295,9 +296,13 @@ public abstract class AbstractSession { } } else if (es.getType() == ExecType.PLACEHOLDER) { VarId vid = new VarId(es.getName(), OUTER_FRAME, 0, null); - nodeOutputs.put(vid, placeholderValues.get(es.getName())); + T phVal = placeholderValues == null ? 
null : placeholderValues.get(es.getName()); + + nodeOutputs.put(vid, phVal); outFrameIter = new FrameIter(OUTER_FRAME, 0, null); if (allRequired.contains(es.getName())) { + Preconditions.checkState(placeholderValues != null && placeholderValues.containsKey(es.getName()), + "No array was provided for the placeholder variable \"%s\" that is required for execution", es.getName()); //User requested placeholder value as one of the outputs out.put(es.getName(), placeholderValues.get(es.getName())); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/test/resources/logback.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/test/resources/logback.xml index 63b2ca84e..3f88b92bb 100755 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/test/resources/logback.xml +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/test/resources/logback.xml @@ -21,8 +21,8 @@ logs/application.log - %date - [%level] - from %logger in %thread - %n%message%n%xException%n + %logger{15} - %message%n%xException{5} + diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index 774af9657..89e9a8abd 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -38,6 +38,7 @@ import org.junit.Ignore; import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.nd4j.OpValidationSuite; +import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.api.OutAndGrad; import org.nd4j.autodiff.validation.OpValidation; import org.nd4j.autodiff.validation.TestCase; @@ -3528,4 +3529,27 @@ public class SameDiffTests extends BaseNd4jTest { String s = out.toString(); } } + + @Test + public void testMissingPlaceholderError(){ + + SameDiff sd = SameDiff.create(); + + int nOut = 4; + int minibatch = 10; + SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut); + SDVariable labels = sd.placeHolder("labels", DataType.DOUBLE, -1, nOut); + + LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT; + + SDVariable loss = sd.loss().absoluteDifference("loss", labels, predictions, null, reduction); + + try { + loss.eval(); + fail("Exception should have been thrown"); + } catch (IllegalStateException e){ + String msg = e.getMessage(); + assertTrue(msg, msg.contains("\"labels\"") && msg.contains("No array was provided")); + } + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 23fd775b8..91cec2a5c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -6935,9 +6935,9 @@ public class Nd4jTestsC extends BaseNd4jTest { val arrayY = Nd4j.create(128, 128, 'f'); val arrayZ = Nd4j.create(128, 128, 'f'); - int iterations = 10000; + int iterations = 100; // warmup - for (int e = 0; e < 1000; e++) + for (int e = 0; e < 10; e++) arrayX.addi(arrayY); for (int e = 0; e < iterations; e++) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/resources/logback.xml b/nd4j/nd4j-backends/nd4j-tests/src/test/resources/logback.xml index 9d7518dfb..4a795a852 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/resources/logback.xml +++ 
b/nd4j/nd4j-backends/nd4j-tests/src/test/resources/logback.xml @@ -21,8 +21,8 @@ logs/application.log - %date - [%level] - from %logger in %thread - %n%message%n%xException%n + %logger{15} - %message%n%xException{5} + diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/resources/logback.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/resources/logback.xml index 2fe9c103d..ca35ddc13 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/resources/logback.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/resources/logback.xml @@ -21,8 +21,8 @@ logs/application.log - %date - [%level] - from %logger in %thread - %n%message%n%xException%n + %logger{15} - %message%n%xException{5} + diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/resources/logback.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/resources/logback.xml index b64707465..c483a01c6 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/resources/logback.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/resources/logback.xml @@ -21,8 +21,8 @@ logs/application.log - %date - [%level] - from %logger in %thread - %n%message%n%xException%n + %logger{15} - %message%n%xException{5} + diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/resources/logback.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/resources/logback.xml index 2fe9c103d..ca35ddc13 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/resources/logback.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/test/resources/logback.xml @@ -21,8 +21,8 @@ logs/application.log - %date - [%level] - from %logger in %thread - %n%message%n%xException%n + %logger{15} - %message%n%xException{5} + diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/resources/logback.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/resources/logback.xml index 2fe9c103d..ca35ddc13 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/resources/logback.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/test/resources/logback.xml @@ -21,8 +21,8 @@ logs/application.log - %date - [%level] - from %logger in %thread - %n%message%n%xException%n + %logger{15} - %message%n%xException{5} + diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/resources/logback.xml b/nd4j/nd4j-remote/nd4j-json-server/src/test/resources/logback.xml index 59b35644e..cbcbed5d6 100644 --- a/nd4j/nd4j-remote/nd4j-json-server/src/test/resources/logback.xml +++ b/nd4j/nd4j-remote/nd4j-json-server/src/test/resources/logback.xml @@ -21,8 +21,8 @@ logs/application.log - %date - [%level] - from %logger in %thread - %n%message%n%xException%n + %logger{15} - %message%n%xException{5} + diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/resources/logback.xml b/nd4j/nd4j-serde/nd4j-aeron/src/test/resources/logback.xml index 30cd40807..56edb3183 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/resources/logback.xml +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/resources/logback.xml @@ -21,8 +21,8 @@ logs/application.log - %date - [%level] - from %logger in %thread - %n%message%n%xException%n + %logger{15} - %message%n%xException{5} + From 81efa5c3b6e8bfebd9a787fdd47ae8445b49756a Mon Sep 17 00:00:00 2001 From: 
raver119 Date: Sun, 2 Feb 2020 19:17:26 +0300 Subject: [PATCH 13/17] [WIP] one small fix (#207) * one small fix Signed-off-by: raver119 * assert added Signed-off-by: raver119 --- .../linalg/jcublas/buffer/BaseCudaDataBuffer.java | 12 ++++++++++++ .../src/main/java/org/nd4j/nativeblas/Nd4jCuda.java | 1 - .../src/main/java/org/nd4j/nativeblas/Nd4jCpu.java | 1 - .../test/java/org/nd4j/linalg/shape/EmptyTests.java | 6 ++++++ 4 files changed, 18 insertions(+), 2 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java index 07614a0ad..02b857f7f 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java @@ -1050,21 +1050,33 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda @Override public void setData(int[] data) { + if (data.length == 0) + return; + set(data, data.length, 0, 0); } @Override public void setData(long[] data) { + if (data.length == 0) + return; + set(data, data.length, 0, 0); } @Override public void setData(float[] data) { + if (data.length == 0) + return; + set(data, data.length, 0, 0); } @Override public void setData(double[] data) { + if (data.length == 0) + return; + set(data, data.length, 0, 0); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 8d0029bc3..f85ae9cf1 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -3804,7 +3804,6 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); } private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype); - /** * This method returns new array with the same shape & data type * @return diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 93fbb71d7..5522141be 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -3807,7 +3807,6 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); } private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype); - /** * This method returns new array with the same shape & data type * @return diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java 
b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java
index 3bef69c19..aa81097d1 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java
@@ -320,6 +320,12 @@ public class EmptyTests extends BaseNd4jTest {
         Nd4j.exec(op);
     }

+    @Test
+    public void testEmptyConstructor_1() {
+        val x = Nd4j.create(new double[0]);
+        assertTrue(x.isEmpty());
+    }
+
     @Override
     public char ordering() {
         return 'c';

From 9bb5798caca11753a1957c3f94929a09a709f618 Mon Sep 17 00:00:00 2001
From: raver119
Date: Sun, 2 Feb 2020 23:14:00 +0300
Subject: [PATCH 14/17] Null arrays fix (#208)

* don't skip null arrays

Signed-off-by: raver119

* one test tweak

Signed-off-by: raver119
---
 libnd4j/include/ops/declarable/DeclarableOp.h |  4 +--
 .../declarable/generic/nn/convo/conv1d.cpp    |  2 +-
 .../ops/declarable/impl/DeclarableOp.cpp      | 21 ++++---------
 .../layers_tests/DeclarableOpsTests19.cpp     | 31 ++++++++++++++++++-
 .../java/org/nd4j/linalg/rng/RandomTests.java |  4 ++-
 5 files changed, 42 insertions(+), 20 deletions(-)

diff --git a/libnd4j/include/ops/declarable/DeclarableOp.h b/libnd4j/include/ops/declarable/DeclarableOp.h
index ff8fe9e83..78f5fcaa4 100644
--- a/libnd4j/include/ops/declarable/DeclarableOp.h
+++ b/libnd4j/include/ops/declarable/DeclarableOp.h
@@ -171,7 +171,7 @@ namespace nd4j {
             Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs);

-            template <typename T>
+            template <typename T, typename = std::enable_if<DataTypeUtils::scalarTypesForExecution<T>::value>>
             Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, std::initializer_list<T> tArgs);

             Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<nd4j::DataType> &dArgs = std::vector<nd4j::DataType>(), bool isInplace = false);
@@ -179,7 +179,7 @@ namespace nd4j {
             nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs);

-            template <typename T>
+            template <typename T, typename = std::enable_if<DataTypeUtils::scalarTypesForExecution<T>::value>>
             nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs, std::initializer_list<T> args);

             nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<nd4j::DataType> &dArgs = std::vector<nd4j::DataType>(), bool isInplace = false);
diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp
index 2800e7185..9cd3285f3 100644
--- a/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp
+++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp
@@ -222,7 +222,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
     auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)});  // [kW, iC, oC] -> [1, kW, iC, oC]

     nd4j::ops::conv2d_bp conv2dBP;
-    const Nd4jStatus status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});
+    auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});
     if (status != ND4J_STATUS_OK)
         return status;
diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp
index 7c4138d36..46d10b51c 100644
--- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp
+++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp
@@ -165,10 +165,7 @@ namespace nd4j {
             // we build list of input shapes
             if (ctx.isFastPath()) {
                 for (const auto p:ctx.fastpath_in()) {
-                    if (p == nullptr)
-                        continue;
-
-                    inSha.push_back(p->getShapeInfo());
+                    inSha.push_back(p == nullptr ? nullptr : p->getShapeInfo());
                 }
             } else {
                 for (auto p: *ctx.inputs()) {
@@ -184,6 +181,11 @@
                 }
             }

+            // if we override shape function, we'll return size of fastPath
+            if (ctx.isFastPath() && ctx.shapeFunctionOverride()) {
+                return (int) ctx.fastpath_out().size();
+            }
+
             // optionally saving input time
             if (Environment::getInstance()->isProfiling() && node != nullptr) {
                 inputEnd = std::chrono::system_clock::now();
@@ -193,11 +195,6 @@
                 shapeStart = std::chrono::system_clock::now();
             }

-            // if we override shape function, we'll return size of fastPath
-            if (ctx.isFastPath() && ctx.shapeFunctionOverride()) {
-                return (int) ctx.fastpath_out().size();
-            }
-
             auto outSha = this->calculateOutputShape(&inSha, ctx);
             results = outSha->size();
@@ -870,16 +867,10 @@
             Context ctx(1);

             for (int e = 0; e < inputs.size(); e++) {
-                if (inputs[e] == nullptr)
-                    break;
-
                 ctx.setInputArray(e, inputs[e]);
             }

             for (int e = 0; e < outputs.size(); e++) {
-                if (outputs[e] == nullptr)
-                    break;
-
                 ctx.setOutputArray(e, outputs[e]);
             }
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp
index 871bfe186..9883a9d79 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp
@@ -37,4 +37,33 @@ public:
         printf("\n");
         fflush(stdout);
     }
-};
\ No newline at end of file
+};
+
+TEST_F(DeclarableOpsTests19, test_conv1d_bp_1) {
+    /*
+    DynamicCustomOp op = DynamicCustomOp.builder("conv1d_bp")
+            .addInputs(
+                    Nd4j.create(DataType.FLOAT, 2,2,12),
+                    Nd4j.create(DataType.FLOAT, 3,2,3),
+                    Nd4j.create(DataType.FLOAT, 2,3,6)
+            )
+            .addOutputs(
+                    Nd4j.create(DataType.FLOAT, 2,2,12),
+                    Nd4j.create(DataType.FLOAT, 3,2,3))
+            .addIntegerArguments(3,2,0,1,2,0)
+            .build();
+
+    Nd4j.exec(op);
+    */
+
+    auto t = NDArrayFactory::create<float>('c', {2, 2, 12});
+    auto u = NDArrayFactory::create<float>('c', {3, 2, 3});
+    auto v = NDArrayFactory::create<float>('c', {2, 3, 6});
+
+    nd4j::ops::conv1d_bp op;
+    auto result = op.evaluate({&t, &u, &v}, {3, 2, 0, 1, 2, 0});
+    ASSERT_EQ(Status::OK(), result->status());
+
+
+    delete result;
+}
\ No newline at end of file
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java
index 8a06bd7e9..b2de46e1b 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java
@@ -810,9 +810,11 @@ public class RandomTests extends BaseNd4jTest {
             threads[x].start();
         }

-        for (int x = 0; x < threads.length; x++) {
+        // we want all threads finished before comparing arrays
+        for (int x = 0; x < threads.length; x++)
             threads[x].join();

+        for (int x = 0; x < threads.length; x++) {
             assertNotEquals(null, list.get(x));

             if (x > 0) {

From ddf70ac450cb4a108fa7a0889d710ac743487d7a Mon Sep 17 00:00:00 2001
From: Alex Black
Date: Mon, 3 Feb 2020 22:18:01 +1100
Subject: [PATCH 15/17] Avoid double printing of start/stop test in a few
 cases (#210)

Signed-off-by: AlexDBlack
---
 .../src/test/java/org/nd4j/linalg/BaseNd4jTest.java          | 4 +---
 .../nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java | 2 --
.../src/test/java/org/nd4j/linalg/Nd4jTestsComparisonC.java | 2 -- .../test/java/org/nd4j/linalg/Nd4jTestsComparisonFortran.java | 2 -- 4 files changed, 1 insertion(+), 9 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java index e5a30b65b..0c2cb7a95 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java @@ -81,10 +81,8 @@ public abstract class BaseNd4jTest extends BaseND4JTest { return ret; } - @Override @Before - public void beforeTest(){ - super.beforeTest(); + public void beforeTest2(){ Nd4j.factory().setOrder(ordering()); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 91cec2a5c..d96c0ed31 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -138,7 +138,6 @@ public class Nd4jTestsC extends BaseNd4jTest { @Before public void before() throws Exception { - super.beforeTest(); Nd4j.setDataType(DataType.DOUBLE); Nd4j.getRandom().setSeed(123); Nd4j.getExecutioner().enableDebugMode(false); @@ -147,7 +146,6 @@ public class Nd4jTestsC extends BaseNd4jTest { @After public void after() throws Exception { - super.afterTest(); Nd4j.setDataType(initialType); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonC.java index 946da0fb3..e8ce1202c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonC.java @@ -57,13 +57,11 @@ public class Nd4jTestsComparisonC extends BaseNd4jTest { @Before public void before() throws Exception { - super.beforeTest(); DataTypeUtil.setDTypeForContext(DataType.DOUBLE); } @After public void after() throws Exception { - super.afterTest(); DataTypeUtil.setDTypeForContext(initialType); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonFortran.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonFortran.java index ef32c1f99..df929bbae 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonFortran.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonFortran.java @@ -58,7 +58,6 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest { @Before public void before() throws Exception { - super.beforeTest(); DataTypeUtil.setDTypeForContext(DataType.DOUBLE); Nd4j.getRandom().setSeed(SEED); @@ -66,7 +65,6 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest { @After public void after() throws Exception { - super.afterTest(); DataTypeUtil.setDTypeForContext(initialType); } From 57d5eb473b43f5b5b228e3fdc7cce2c3a442f388 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Tue, 4 Feb 2020 15:38:06 +1100 Subject: [PATCH 16/17] Fixes for global pooling + masking with different mask datatypes (#212) * Fixes for global pooling + masking with different mask datatypes Signed-off-by: AlexDBlack * Global pooling backprop dtype fixes Signed-off-by: AlexDBlack --- 
.../pooling/GlobalPoolingMaskingTests.java | 56 +++++++++++++++++-- .../nn/layers/recurrent/RnnOutputLayer.java | 2 +- .../util/MaskedReductionUtil.java | 25 +++------ 3 files changed, 61 insertions(+), 22 deletions(-) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java index 957d22a08..aebf23673 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java @@ -21,14 +21,13 @@ import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; -import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; -import org.deeplearning4j.nn.conf.layers.GravesLSTM; -import org.deeplearning4j.nn.conf.layers.OutputLayer; -import org.deeplearning4j.nn.conf.layers.PoolingType; +import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.junit.Test; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp; import org.nd4j.linalg.factory.Nd4j; @@ -416,4 +415,53 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest { } } } + + @Test + public void testMaskLayerDataTypes(){ + + for(DataType dt : new DataType[]{DataType.FLOAT16, DataType.BFLOAT16, DataType.FLOAT, DataType.DOUBLE, + DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64, + DataType.UINT8, DataType.UINT16, DataType.UINT32, DataType.UINT64}){ + INDArray mask = Nd4j.rand(DataType.FLOAT, 2, 10).addi(0.3).castTo(dt); + + for(DataType networkDtype : new DataType[]{DataType.FLOAT16, DataType.BFLOAT16, DataType.FLOAT, DataType.DOUBLE}){ + + INDArray in = Nd4j.rand(networkDtype, 2, 5, 10); + INDArray label1 = Nd4j.rand(networkDtype, 2, 5); + INDArray label2 = Nd4j.rand(networkDtype, 2, 5, 10); + + for(PoolingType pt : PoolingType.values()) { + //System.out.println("Net: " + networkDtype + ", mask: " + dt + ", pt=" + pt); + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .list() + .layer(new GlobalPoolingLayer(pt)) + .layer(new OutputLayer.Builder().nIn(5).nOut(5).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build()) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + net.output(in, false, mask, null); + net.output(in, false, mask, null); + + + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + + .list() + .layer(new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build()) + .build(); + + MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); + net2.init(); + + net2.output(in, false, mask, mask); + net2.output(in, false, mask, mask); + + net.fit(in, label1, mask, null); + net2.fit(in, label2, mask, mask); + } + } + } + } } diff --git 
a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java index 27cc0f9df..93f77ad67 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java @@ -131,7 +131,7 @@ public class RnnOutputLayer extends BaseOutputLayer Date: Tue, 4 Feb 2020 07:59:11 +0200 Subject: [PATCH 17/17] Shugeo solve linear (#191) * linear equations systems solve op. Initial commit. Signed-off-by: shugeo * Fixed compiling issues. Signed-off-by: shugeo * Linear equations systems solve. The next stage commit. Signed-off-by: shugeo * Added test for linear equations systems solve operation. Signed-off-by: shugeo * Added additional test and fixed lower matrix retrievance. * Implementation for solve of the systems of linear equations." Signed-off-by: shugeo * Refactored permutation generation. Signed-off-by: shugeo * Added restore for permutations batched with cuda helper for solve op. Signed-off-by: shugeo * Finished cuda implementation for solve op helpers. Signed-off-by: shugeo * Refactored cpu helpers for solve op. Signed-off-by: shugeo * Fix gtest output on Windows * Fixed issue with permutation matrix for cuda implementation. Signed-off-by: shugeo * Fixed issue with permutation matrix for cpu implementation. Signed-off-by: shugeo * Eliminated waste comments. Signed-off-by: shugeo * LinearSolve added * Mapping added * Javadoc added * Refactored implementation of triangular_solve helpers and tests for solve matrix equations generally. Signed-off-by: shugeo * Added a test for solve op. 
Signed-off-by: shugeo * Solve test added * Fix for TF import Co-authored-by: Serhii Shepel <9946053+sshepel@users.noreply.github.com> Co-authored-by: raver119 Co-authored-by: Alexander Stoyakin --- .../declarable/generic/parity_ops/solve.cpp | 75 ++++++++ .../ops/declarable/headers/parity_ops.h | 18 ++ .../ops/declarable/helpers/cpu/lup.cpp | 75 ++++++-- .../ops/declarable/helpers/cpu/solve.cpp | 100 +++++++++++ .../helpers/cpu/triangular_solve.cpp | 23 ++- .../ops/declarable/helpers/cuda/solve.cu | 140 +++++++++++++++ .../helpers/cuda/triangular_solve.cu | 67 ++++---- .../include/ops/declarable/helpers/solve.h | 34 ++++ .../layers_tests/DeclarableOpsTests11.cpp | 160 ++++++++++++++++++ .../layers_tests/DeclarableOpsTests12.cpp | 30 ++++ libnd4j/tests_cpu/run_tests.sh | 6 +- .../converters/ImportClassMapping.java | 3 +- .../org/nd4j/linalg/api/buffer/DataType.java | 22 +++ .../linalg/api/ops/custom/LinearSolve.java | 77 +++++++++ .../nd4j/linalg/custom/CustomOpsTests.java | 46 +++++ 15 files changed, 818 insertions(+), 58 deletions(-) create mode 100644 libnd4j/include/ops/declarable/generic/parity_ops/solve.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/solve.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/solve.cu create mode 100644 libnd4j/include/ops/declarable/helpers/solve.h create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/LinearSolve.java diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/solve.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/solve.cpp new file mode 100644 index 000000000..5790ae960 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/parity_ops/solve.cpp @@ -0,0 +1,75 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit, K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// Created by GS at 01/22/2020
+//
+
+#include
+#if NOT_EXCLUDED(OP_solve)
+
+#include
+#include
+namespace nd4j {
+    namespace ops {
+        CUSTOM_OP_IMPL(solve, 2, 1, false, 0, 0) {
+            auto a = INPUT_VARIABLE(0);
+            auto b = INPUT_VARIABLE(1);
+            auto z = OUTPUT_VARIABLE(0);
+
+            bool useAdjoint = false;
+
+            if (block.numB() > 0) {
+                useAdjoint = B_ARG(0);
+            }
+
+            REQUIRE_TRUE(a->rankOf() >= 2, 0, "solve: The rank of input left tensor should not be less than 2, but %i is given", a->rankOf());
+            REQUIRE_TRUE(b->rankOf() >= 2, 0, "solve: The rank of input right tensor should not be less than 2, but %i is given", b->rankOf());
+
+            REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "solve: The last two dimensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2));
+            REQUIRE_TRUE(a->sizeAt(-1) == b->sizeAt(-2), 0, "solve: The last dimension of the left part should be equal to the second-to-last dimension of the right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2));
+            auto input = a;
+            if (useAdjoint) {
+                auto adjointA = a->ulike();
+                helpers::adjointMatrix(block.launchContext(), a, &adjointA);
+                input = new NDArray(adjointA); //.detach();
+            };
+
+            auto res = helpers::solveFunctor(block.launchContext(), input, b, useAdjoint, z);
+            if (input != a)
+                delete input;
+
+            return res;
+        }
+
+        DECLARE_SHAPE_FN(solve) {
+            auto in0 = inputShape->at(0);
+            auto in1 = inputShape->at(1);
+            auto luShape = ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace());
+
+            return SHAPELIST(CONSTANT(luShape));
+        }
+
+        DECLARE_TYPES(solve) {
+            getOpDescriptor()
+                    ->setAllowedInputTypes({ALL_FLOATS})
+                    ->setAllowedOutputTypes({ALL_FLOATS})
+                    ->setSameMode(false);
+        }
+    }
+}
+
+#endif
\ No newline at end of file
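For orientation before the header and helper diffs that follow, here is a minimal usage sketch of the op implemented above, written in the style of the existing libnd4j tests. The test name, input values and expected output are illustrative assumptions, not part of the patch:

TEST_F(DeclarableOpsTests11, sketch_solve_usage) {
    // one 2x2 system A*x = b, whose exact solution is x = {2, 3}
    auto a = NDArrayFactory::create<float>('c', {2, 2}, {3.f, 1.f, 1.f, 2.f});
    auto b = NDArrayFactory::create<float>('c', {2, 1}, {9.f, 8.f});
    auto exp = NDArrayFactory::create<float>('c', {2, 1}, {2.f, 3.f});

    nd4j::ops::solve op;
    auto result = op.evaluate({&a, &b});    // boolean arg 0 (adjoint) defaults to false
    ASSERT_EQ(Status::OK(), result->status());
    ASSERT_TRUE(exp.equalsTo(result->at(0)));

    delete result;
}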
diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h
index c5d0ff207..1b5bda29a 100644
--- a/libnd4j/include/ops/declarable/headers/parity_ops.h
+++ b/libnd4j/include/ops/declarable/headers/parity_ops.h
@@ -1076,6 +1076,24 @@ namespace nd4j {
         DECLARE_CUSTOM_OP(triangular_solve, 2, 1, true, 0, 0);
         #endif

+        /**
+         * solve op. - solve systems of linear equations - general method.
+         *
+         * input params:
+         *    0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of equations
+         *    1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations
+         *
+         * boolean args:
+         *    0 - adjoint - default is false (optional) - indicates whether the input matrix or its adjoint (Hermitian transpose) should be used
+         *
+         * return value:
+         *    tensor with dimension (x * y * z * ::: * M * K) with solutions
+         *
+         */
+        #if NOT_EXCLUDED(OP_solve)
+        DECLARE_CUSTOM_OP(solve, 2, 1, true, 0, 0);
+        #endif
+
        /**
        * lu op. - make LUP decomposition of given batch of 2D square matrices
        *
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp
index 9c7cb1bfe..2856e73b9 100644
--- a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp
+++ b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp
@@ -237,25 +237,65 @@ namespace helpers {
         samediff::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1);
     }

+    template <typename T>
+    static void doolitleLU(LaunchContext* context, NDArray* compound, Nd4jLong rowNum) {
+        auto input = compound->dup();
+        compound->nullify();
+
+        // Decomposing matrix into Upper and Lower
+        // triangular matrix
+        for (auto i = 0; i < rowNum; i++) {
+
+            // Upper Triangular
+            for (auto k = i; k < rowNum; k++) {
+
+                // Summation of L(i, j) * U(j, k) - accumulated in T, not int, to avoid truncation
+                T sum = 0;
+                for (int j = 0; j < i; j++)
+                    sum += compound->t<T>(i, j) * compound->t<T>(j, k);
+
+                // Evaluating U(i, k)
+                compound->t<T>(i, k) = input.t<T>(i, k) - sum;
+            }
+
+            // Lower Triangular
+            for (int k = i + 1; k < rowNum; k++) {
+                // Summation of L(k, j) * U(j, i)
+                T sum = 0;
+                for (int j = 0; j < i; j++)
+                    sum += compound->t<T>(k, j) * compound->t<T>(j, i);
+
+                // Evaluating L(k, i)
+                compound->t<T>(k, i) = (input.t<T>(k, i) - sum) / compound->t<T>(i, i);
+            }
+        }
+    }
+
     template <typename T, typename I>
     static void luNN_(LaunchContext *context, NDArray* compound, NDArray* permutation, Nd4jLong rowNum) {
         //const int rowNum = compound->rows();
        // const int columnNum = output->columns();
-        permutation->linspace(0);
-        auto permutationBuf = permutation->bufferAsT<I>(); //dataBuffer()->primaryAsT();
-        auto compoundBuf = compound->bufferAsT<T>();
-        auto compoundShape = compound->shapeInfo();
-        auto permutationShape = permutation->shapeInfo();
-        for (auto i = 0; i < rowNum - 1; i++) {
-            auto pivotIndex = argmaxCol(i, compoundBuf, compoundShape);
-            if (pivotIndex < 0) {
-                throw std::runtime_error("helpers::luNN_: input matrix is singular.");
-            }
-            math::nd4j_swap(permutationBuf[shape::getIndexOffset(i, permutationShape)], permutationBuf[shape::getIndexOffset(pivotIndex, permutationShape)]);
-            swapRows(compoundBuf, compoundShape, i, pivotIndex);
+        if (permutation) { // LUP algorithm
+            permutation->linspace(0);
+            auto permutationBuf = permutation->bufferAsT<I>(); //dataBuffer()->primaryAsT();
+            auto compoundBuf = compound->bufferAsT<T>();
+            auto compoundShape = compound->shapeInfo();
+            auto permutationShape = permutation->shapeInfo();
+            for (auto i = 0; i < rowNum - 1; i++) {
+                auto pivotIndex = argmaxCol(i, compoundBuf, compoundShape);
+                if (pivotIndex < 0) {
+                    throw std::runtime_error("helpers::luNN_: input matrix is singular.");
+                }
+                math::nd4j_swap(permutationBuf[shape::getIndexOffset(i, permutationShape)],
+                                permutationBuf[shape::getIndexOffset(pivotIndex, permutationShape)]);
+                swapRows(compoundBuf, compoundShape, i, pivotIndex);

-            processColumns(i, rowNum, compoundBuf, compoundShape);
+                processColumns(i, rowNum, compoundBuf, compoundShape);
+            }
+        }
+        else { // Doolittle algorithm with LU decomposition
+            doolitleLU<T>(context, compound, rowNum);
         }
     }
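For reference, the Doolittle update added above follows the standard recurrences (0-based indices, both factors packed into compound, with an implicit unit diagonal on L):

    U(i, k) = A(i, k) - sum over j < i of L(i, j) * U(j, k),              for k >= i
    L(k, i) = (A(k, i) - sum over j < i of L(k, j) * U(j, i)) / U(i, i),  for k > i

Each partial sum is a product of T-valued matrix entries, which is why the accumulators are declared as T rather than int.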
@@ -265,17 +305,20 @@ namespace helpers {
         output->assign(input); // fill up output tensor with zeros
         ResultSet outputs = output->allTensorsAlongDimension({-2, -1});
-        ResultSet permutations = permutationVectors->allTensorsAlongDimension({-1});
+        ResultSet permutations;
+        if (permutationVectors)
+            permutations = permutationVectors->allTensorsAlongDimension({-1});
+
         auto loop = PRAGMA_THREADS_FOR {
             for (auto i = start; i < stop; i += increment) {
-                luNN_<T, I>(context, outputs.at(i), permutations.at(i), n);
+                luNN_<T, I>(context, outputs.at(i), permutationVectors ? permutations.at(i) : nullptr, n);
             }
         };
         samediff::Threads::parallel_for(loop, 0, outputs.size(), 1);
     }

     void lu(LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation) {
-        BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lu_, (context, input, output, permutation), FLOAT_TYPES, INDEXING_TYPES);
+        BUILD_DOUBLE_SELECTOR(input->dataType(), permutation ? permutation->dataType() : DataType::INT32, lu_, (context, input, output, permutation), FLOAT_TYPES, INDEXING_TYPES);
     }
//    BUILD_DOUBLE_TEMPLATE(template NDArray lu_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, INDEXING_TYPES);
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp b/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp
new file mode 100644
index 000000000..8583d9cba
--- /dev/null
+++ b/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp
@@ -0,0 +1,100 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit, K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author GS
+//
+#include
+#include
+#include
+#include
+#include
+
+#include "../triangular_solve.h"
+#include "../lup.h"
+#include "../solve.h"
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+
+// --------------------------------------------------------------------------------------------------------------------------------------- //
+    template <typename T>
+    static void adjointMatrix_(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
+        auto inputPart = input->allTensorsAlongDimension({-2, -1});
+        auto outputPart = output->allTensorsAlongDimension({-2, -1});
+        output->assign(input);
+        auto batchLoop = PRAGMA_THREADS_FOR {
+            for (auto batch = start; batch < stop; batch += increment) {
+                for (auto r = 0; r < input->rows(); r++) {
+                    for (auto c = 0; c < r; c++) {
+                        math::nd4j_swap(outputPart[batch]->t<T>(r, c), outputPart[batch]->t<T>(c, r));
+                    }
+                }
+            }
+        };
+        samediff::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1);
+    }
+
+// --------------------------------------------------------------------------------------------------------------------------------------- //
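+// A remark on adjointMatrix_ above: this helper is instantiated for FLOAT_TYPES only,
+// i.e. real-valued matrices, so the Hermitian adjoint adj(A)(i, j) = conj(A(j, i))
+// reduces to the plain transpose A(j, i) - hence the swap of mirrored entries below
+// the diagonal.
+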
+    template <typename T>
+    static int solveFunctor_(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool const adjoint, NDArray* output) {
+
+        // stage 1: LU decomposition batched
+        auto leftOutput = leftInput->ulike();
+        auto permuShape = rightInput->getShapeAsVector(); permuShape.pop_back();
+        auto permutations = NDArrayFactory::create<int>('c', permuShape, context);
+        helpers::lu(context, leftInput, &leftOutput, &permutations);
+        auto P = leftInput->ulike(); // permutations batched matrix
+        P.nullify(); // to fill up matrices with zeros
+        auto PPart = P.allTensorsAlongDimension({-2, -1});
+        auto permutationsPart = permutations.allTensorsAlongDimension({-1});
+
+        for (auto batch = 0; batch < permutationsPart.size(); ++batch) {
+            for (auto row = 0; row < PPart[batch]->rows(); ++row) {
+                PPart[batch]->t<T>(row, permutationsPart[batch]->t<int>(row)) = T(1.f);
+            }
+        }
+
+        auto leftLower = leftOutput.dup();
+        auto rightOutput = rightInput->ulike();
+        auto rightPermuted = rightOutput.ulike();
+        MmulHelper::matmul(&P, rightInput, &rightPermuted, 0, 0);
+        ResultSet leftLowerPart = leftLower.allTensorsAlongDimension({-2, -1});
+        for (auto i = 0; i < leftLowerPart.size(); i++) {
+            for (auto r = 0; r < leftLowerPart[i]->rows(); r++)
+                leftLowerPart[i]->t<T>(r, r) = (T)1.f;
+        }
+        // stage 2: triangularSolveFunctor for Lower with given b
+        helpers::triangularSolveFunctor(context, &leftLower, &rightPermuted, true, false, &rightOutput);
+        // stage 3: triangularSolveFunctor for Upper with output of previous stage
+        helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, false, false, output);
+
+        return Status::OK();
+    }
+
+// --------------------------------------------------------------------------------------------------------------------------------------- //
+    int solveFunctor(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool const adjoint, NDArray* output) {
+        BUILD_SINGLE_SELECTOR(leftInput->dataType(), return solveFunctor_, (context, leftInput, rightInput, adjoint, output), FLOAT_TYPES);
+    }
+// --------------------------------------------------------------------------------------------------------------------------------------- //
+    void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
+        BUILD_SINGLE_SELECTOR(input->dataType(), adjointMatrix_, (context, input, output), FLOAT_TYPES);
+    }
+// --------------------------------------------------------------------------------------------------------------------------------------- //
+}
+}
+}
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp b/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp
index ab409a0c6..e904d219c 100644
--- a/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp
+++ b/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp
@@ -41,13 +41,16 @@ namespace helpers {
     template <typename T>
     static void lowerTriangularSolve(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) {
         auto rows = leftInput->rows();
+        auto cols = rightInput->columns();
         //output->t<T>(0,0) = rightInput->t<T>(0,0) / leftInput->t<T>(0,0);
         for (auto r = 0; r < rows; r++) {
-            auto sum = rightInput->t<T>(r, 0);
-            for (auto c = 0; c < r; c++) {
-                sum -= leftInput->t<T>(r, c) * output->t<T>(c, 0);
+            for (auto j = 0; j < cols; j++) {
+                auto sum = rightInput->t<T>(r, j);
+                for (auto c = 0; c < r; c++) {
+                    sum -= leftInput->t<T>(r, c) * output->t<T>(c, j);
+                }
+                output->t<T>(r, j) = sum / leftInput->t<T>(r, r);
             }
-            output->t<T>(r, 0) = sum / leftInput->t<T>(r, r);
         }
     }

@@ -68,13 +71,15 @@ namespace helpers {
     template <typename T>
     static void upperTriangularSolve(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) {
         auto rows = leftInput->rows();
-
+        auto cols = rightInput->columns();
         for (auto r = rows; r > 0; r--) {
-            auto sum = rightInput->t<T>(r - 1, 0);
-            for (auto c = r; c < rows; c++) {
-                sum -= leftInput->t<T>(r - 1, c) * output->t<T>(c, 0);
+            for (auto j = 0; j < cols; j++) {
+                auto sum = rightInput->t<T>(r - 1, j);
+                for (auto c = r; c < rows; c++) {
+                    sum -= leftInput->t<T>(r - 1, c) * output->t<T>(c, j);
+                }
+                output->t<T>(r - 1, j) = sum / leftInput->t<T>(r - 1, r - 1);
             }
-            output->t<T>(r - 1, 0) = sum / leftInput->t<T>(r - 1, r - 1);
         }
     }
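To summarize the control flow shared by the CPU implementation above and the CUDA port below: lu() factors each batched matrix as P * A = L * U with a unit diagonal on L, so A * x = b is solved by two substitutions,

    L * y = P * b    (forward substitution, lower triangular)
    U * x = y        (back substitution, upper triangular)

The column loop added to the two triangular solvers above is what allows the right-hand side to be an M x K matrix rather than a single column, matching the shape contract documented for the solve op.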
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/solve.cu b/libnd4j/include/ops/declarable/helpers/cuda/solve.cu
new file mode 100644
index 000000000..6437b80bd
--- /dev/null
+++ b/libnd4j/include/ops/declarable/helpers/cuda/solve.cu
@@ -0,0 +1,140 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit, K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+//  @author GS
+//
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include "../triangular_solve.h"
+#include "../lup.h"
+#include "../solve.h"
+
+namespace nd4j {
+    namespace ops {
+        namespace helpers {
+
+            template <typename T>
+            static __global__ void oneOnDiagonalKernel(T* ioBuf, Nd4jLong* ioShape, Nd4jLong* tadShape, Nd4jLong* tadOffsets, Nd4jLong batchNum, Nd4jLong rowNum) {
+                for (auto i = blockIdx.x; i < batchNum; i += gridDim.x) {
+                    auto matrixPart = ioBuf + tadOffsets[i];
+                    for (auto j = threadIdx.x; j < rowNum; j += blockDim.x) {
+                        Nd4jLong pos[] = {j, j};
+                        auto offset = shape::getOffset(tadShape, pos);
+
+                        matrixPart[offset] = T(1.f);
+                    }
+                }
+            }
+
+            template <typename T>
+            static __global__ void restorePermutationsKernel(T* PBuf, Nd4jLong* PShapeInfo, int const* permutationsBuf,
+                    Nd4jLong* PTadShapeInfo, Nd4jLong* PTadSOffsets, Nd4jLong* permutationsTadShapeInfo, Nd4jLong* permutationsTadOffsets, Nd4jLong batchNum, Nd4jLong rowNum) {
+                for (auto batch = blockIdx.x; batch < batchNum; batch += gridDim.x) {
+                    auto permutations = permutationsBuf + permutationsTadOffsets[batch];
+                    auto P = PBuf + PTadSOffsets[batch];
+
+                    for (auto row = threadIdx.x; row < rowNum; row += blockDim.x) {
+                        //auto posX[] = {row};
+                        Nd4jLong posZ[] = {row, permutations[row]};
+                        auto zOffset = shape::getOffset(PTadShapeInfo, posZ);
+                        P[zOffset] = T(1.f);
+                    }
+                }
+            }
+
+            template <typename T>
+            static int solveFunctor_(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput,
+                    bool adjoint, NDArray* output) {
+                NDArray::prepareSpecialUse({output}, {leftInput, rightInput});
+                // stage 1: LU decomposition batched
+                auto leftOutput = leftInput->ulike();
+                auto permuShape = rightInput->getShapeAsVector(); permuShape.pop_back();
+                auto permutations = NDArrayFactory::create<int>('c', permuShape, context);
+                helpers::lu(context, leftInput, &leftOutput, &permutations);
+                auto leftLower = leftOutput.dup();
+                auto rightOutput = rightInput->ulike();
+                auto leftLowerTad = ConstantTadHelper::getInstance()->tadForDimensions(leftLower.getShapeInfo(), {-2, -1});
+                auto stream = context->getCudaStream();
+                oneOnDiagonalKernel<T><<<128, 256, 256, *stream>>>(leftLower.dataBuffer()->specialAsT<T>(), leftLower.specialShapeInfo(), leftLowerTad.specialShapeInfo(), leftLowerTad.specialOffsets(), leftLowerTad.numberOfTads(), leftLower.sizeAt(-1));
+                auto P = leftOutput.ulike(); P.nullify();
+                auto PTad = ConstantTadHelper::getInstance()->tadForDimensions(P.getShapeInfo(), {-2, -1});
+                auto permutationsTad = ConstantTadHelper::getInstance()->tadForDimensions(permutations.getShapeInfo(), {-1});
+                restorePermutationsKernel<T><<<128, 256, 256, *stream>>>(P.dataBuffer()->specialAsT<T>(), P.specialShapeInfo(), permutations.dataBuffer()->specialAsT<int>(),
+                        PTad.specialShapeInfo(), PTad.specialOffsets(), permutationsTad.specialShapeInfo(), permutationsTad.specialOffsets(), permutationsTad.numberOfTads(), permutations.sizeAt(-1));
+                P.tickWriteDevice();
+                auto rightPart = rightInput->ulike();
+                MmulHelper::matmul(&P, rightInput, &rightPart, 0, 0);
+
+                // stage 2: triangularSolveFunctor for Lower with given b
+                helpers::triangularSolveFunctor(context, &leftLower, &rightPart, true, false, &rightOutput);
+                // stage 3: triangularSolveFunctor for Upper with output of previous stage
+                helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, false, false, output);
+                NDArray::registerSpecialUse({output}, {leftInput, rightInput});
+
+                return Status::OK();
+            }
+
+            int solveFunctor(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) {
+                BUILD_SINGLE_SELECTOR(leftInput->dataType(), return solveFunctor_, (context, leftInput, rightInput, adjoint, output), FLOAT_TYPES);
+            }
+
+            template <typename T>
+            static __global__ void adjointKernel(T* output, Nd4jLong batchSize, Nd4jLong rows, Nd4jLong columns, Nd4jLong* outputTads,
+                    Nd4jLong* outputOffsets) {
+
+                for (auto b = blockIdx.x; b < batchSize; b += gridDim.x) {
+                    auto outputPart = output + outputOffsets[b];
+                    for (auto r = threadIdx.x; r < rows; r += blockDim.x) {
+                        for (auto c = threadIdx.y; c < r; c += blockDim.y) {
+                            Nd4jLong zPos[] = {r, c};
+                            Nd4jLong xPos[] = {c, r};
+                            auto zIndex = shape::getOffset(outputTads, zPos);
+                            auto xIndex = shape::getOffset(outputTads, xPos);
+                            math::nd4j_swap(outputPart[zIndex], outputPart[xIndex]);
+                        }
+                    }
+                }
+
+            }
+
+            template <typename T>
+            static void adjointMatrix_(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
+                NDArray::prepareSpecialUse({output}, {input});
+                auto inputTads = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {-2, -1});
+                auto outputTads = ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {-2, -1});
+                auto stream = context->getCudaStream();
+                auto outputBuf = reinterpret_cast<T*>(output->specialBuffer());
+                auto rows = input->sizeAt(-2);
+                auto columns = input->sizeAt(-1);
+                output->assign(input);
+                adjointKernel<T><<<128, 256, 256, *stream>>>(outputBuf, outputTads.numberOfTads(), rows, columns, outputTads.specialShapeInfo(), outputTads.specialOffsets());
+                NDArray::registerSpecialUse({output}, {input});
+            }
+
+            void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
+                BUILD_SINGLE_SELECTOR(input->dataType(), adjointMatrix_, (context, input, output), FLOAT_TYPES);
+            }
+
+        }
+    }
+}
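As a serial reference for what the two kernels above compute per batch: oneOnDiagonalKernel writes ones on each TAD's main diagonal, and restorePermutationsKernel scatters a 1 into column perm[row] of each row. The Java sketch below assumes contiguous row-major n x n batches, which is a simplification of the shape::getOffset/TAD-offset arithmetic the kernels actually perform:

    // Serial equivalents of the two kernels, assuming contiguous row-major n x n batches.
    static void oneOnDiagonal(double[] buf, int batchNum, int n) {
        for (int b = 0; b < batchNum; b++)
            for (int j = 0; j < n; j++)
                buf[b * n * n + j * n + j] = 1.0;         // matrix[j][j] = 1
    }

    static void restorePermutations(double[] pBuf, int[] perm, int batchNum, int n) {
        for (int b = 0; b < batchNum; b++)
            for (int row = 0; row < n; row++)             // P[row][perm[row]] = 1
                pBuf[b * n * n + row * n + perm[b * n + row]] = 1.0;
    }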
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu b/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu
index 8846be45c..6f5fe6b8c 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu
@@ -44,24 +44,26 @@ namespace nd4j {
         static __device__ void lowerTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape,
                                                     T const* rightInput, Nd4jLong const* rightInputShape,
                                                     bool const adjoint, T* output, Nd4jLong* outputShape,
-                                                    Nd4jLong rows) {
+                                                    Nd4jLong rows, Nd4jLong cols) {

             for (auto r = 0; r < rows; r++) {
-                Nd4jLong posY[] = {r, 0};
-                Nd4jLong posX[] = {r, r};
-                auto xIndex = shape::getOffset(leftInputShape, posX, 0);
-                auto yIndex = shape::getOffset(rightInputShape, posY, 0);
-                auto zIndex = shape::getOffset(outputShape, posY, 0);
+                for (auto j = 0; j < cols; j++) {
+                    Nd4jLong posY[] = {r, j};
+                    Nd4jLong posX[] = {r, r};
+                    auto xIndex = shape::getOffset(leftInputShape, posX, 0);
+                    auto yIndex = shape::getOffset(rightInputShape, posY, 0);
+                    auto zIndex = shape::getOffset(outputShape, posY, 0);

-                auto sum = rightInput[yIndex];
-                for (auto c = 0; c < r; c++) {
-                    Nd4jLong posZ[] = {c, 0};
-                    Nd4jLong pos[] = {r, c};
-                    auto xcIndex = shape::getOffset(leftInputShape, pos, 0);
-                    auto zcIndex = shape::getOffset(outputShape, posZ, 0);
-                    sum -= leftInput[xcIndex] * output[zcIndex];
+                    auto sum = rightInput[yIndex];
+                    for (auto c = 0; c < r; c++) {
+                        Nd4jLong posZ[] = {c, j};
+                        Nd4jLong pos[] = {r, c};
+                        auto xcIndex = shape::getOffset(leftInputShape, pos, 0);
+                        auto zcIndex = shape::getOffset(outputShape, posZ, 0);
+                        sum -= leftInput[xcIndex] * output[zcIndex];
+                    }
+                    output[zIndex] = sum / leftInput[xIndex];
                 }
-                output[zIndex] = sum / leftInput[xIndex];
             }
         }

@@ -82,23 +84,25 @@ namespace nd4j {
         template <typename T>
         static __device__ void upperTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape,
                                                     T const* rightInput, Nd4jLong const* rightInputShape,
                                                     bool const adjoint, T* output,
-                                                    Nd4jLong* outputShape, Nd4jLong rows) {
+                                                    Nd4jLong* outputShape, Nd4jLong rows, Nd4jLong cols) {

             for (auto r = rows; r > 0; r--) {
-                Nd4jLong posY[] = {r - 1, 0};
-                Nd4jLong posX[] = {r - 1, r - 1};
-                auto xIndex = shape::getOffset(leftInputShape, posX, 0);
-                auto yIndex = shape::getOffset(rightInputShape, posY, 0);
-                auto zIndex = shape::getOffset(outputShape, posY, 0);
-                auto sum = rightInput[yIndex];
-                for (auto c = r; c < rows; c++) {
-                    Nd4jLong posZ[] = {c, 0};
-                    Nd4jLong pos[] = {r - 1, c};
-                    auto zcIndex = shape::getOffset(outputShape, posZ, 0);
-                    auto xcIndex = shape::getOffset(leftInputShape, pos, 0);
-                    sum -= leftInput[xcIndex] * output[zcIndex];
+                for (auto j = 0; j < cols; j++) {
+                    Nd4jLong posY[] = {r - 1, j};
+                    Nd4jLong posX[] = {r - 1, r - 1};
+                    auto xIndex = shape::getOffset(leftInputShape, posX, 0);
+                    auto yIndex = shape::getOffset(rightInputShape, posY, 0);
+                    auto zIndex = shape::getOffset(outputShape, posY, 0);
+                    auto sum = rightInput[yIndex];
+                    for (auto c = r; c < rows; c++) {
+                        Nd4jLong posZ[] = {c, j};
+                        Nd4jLong pos[] = {r - 1, c};
+                        auto zcIndex = shape::getOffset(outputShape, posZ, 0);
+                        auto xcIndex = shape::getOffset(leftInputShape, pos, 0);
+                        sum -= leftInput[xcIndex] * output[zcIndex];
+                    }
+                    output[zIndex] = sum / leftInput[xIndex];
                 }
-                output[zIndex] = sum / leftInput[xIndex];
             }
         }

@@ -109,8 +113,11 @@ namespace nd4j {
                 Nd4jLong* tadRightOffset, Nd4jLong* tadOutputShape, Nd4jLong* tadOutputOffset, Nd4jLong batchNum) {

             __shared__ Nd4jLong rows;
+            __shared__ Nd4jLong cols;
+
             if (threadIdx.x == 0) {
                 rows = shape::sizeAt(leftPartShape, -2);
+                cols = shape::sizeAt(rightPartShape, -1);
             }
             __syncthreads();

@@ -123,9 +130,9 @@ namespace nd4j {
                 auto pRightPart = rightInput + tadRightOffset[i];
                 auto pOutputPart = output + tadOutputOffset[i];
                 if (lower) {
-                    lowerTriangularSolve<T>(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows);
+                    lowerTriangularSolve<T>(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows, cols);
                 } else {
-                    upperTriangularSolve<T>(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows);
+                    upperTriangularSolve<T>(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows, cols);
                 }
             }
         }
diff --git a/libnd4j/include/ops/declarable/helpers/solve.h b/libnd4j/include/ops/declarable/helpers/solve.h
new file mode 100644
index 000000000..d097fa217
--- /dev/null
+++ b/libnd4j/include/ops/declarable/helpers/solve.h
@@ -0,0 +1,34 @@
+/*******************************************************************************
+ * Copyright (c) Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+//  @author GS
+//
+#ifndef __SOLVE__H_HELPERS__
+#define __SOLVE__H_HELPERS__
+#include
+#include
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+
+    int solveFunctor(nd4j::LaunchContext* context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output);
+    void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, NDArray* output);
+}
+}
+}
+#endif
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp
index 71ebdc7e6..de4bdc31b 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp
@@ -1541,6 +1541,166 @@ TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {
     delete []arr;
 }

+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests11, Solve_Test_1) {
+
+    auto a = NDArrayFactory::create<float>('c', {3, 3}, {
+        2.f, -1.f, -2.f, -4.f, 6.f, 3.f, -4.f, -2.f, 8.f
+    });
+
+    auto b = NDArrayFactory::create<float>('c', {3, 1}, {
+        2.f, 4.f, 3.f
+    });
+
+    auto exp = NDArrayFactory::create<float>('c', {3, 1}, {
+        7.625f, 3.25f, 5.f
+    });
+
+    nd4j::ops::solve op;
+
+    auto res = op.evaluate({&a, &b});
+    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+    auto z = res->at(0);
+
+//    z->printIndexedBuffer("Solve of 3x3");
+
+    ASSERT_TRUE(exp.equalsTo(z));
+    delete res;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests11, Solve_Test_2) {
+
+    auto a = NDArrayFactory::create<float>('c', {4, 4}, {
+        1.f, 1.f, 1.f, 1.f,
+        0.f, 1.f, 1.f, 0.f,
+        0.f, 0.f, 2.f, 1.f,
+        0.f, 0.f, 0.f, 3.f,
+    });
+
+    auto b = NDArrayFactory::create<float>('c', {4, 1}, {
+        2.f, 4.f, 2.f, 4.f
+    });
+
+    auto exp = NDArrayFactory::create<float>('c', {4, 1}, {
+        -3.3333333f, 3.6666666f, 0.333333f, 1.3333333f
+    });
+
+    nd4j::ops::solve op;
+
+    auto res = op.evaluate({&a, &b});
+    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+    auto z = res->at(0);
+
+//    z->printIndexedBuffer("Solve 4x4");
+
+    ASSERT_TRUE(exp.equalsTo(z));
+    delete res;
+}
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests11, Solve_Test_3) {
+
+    auto a = NDArrayFactory::create<float>('c', {2, 4, 4}, {
+        1.f, 1.f, 1.f, 1.f,
+        0.f, 1.f, 1.f, 0.f,
+        0.f, 0.f, 2.f, 1.f,
+        0.f, 0.f, 0.f, 3.f,
+
+        3.f, 0.f, 0.f, 0.f,
+        2.f, 1.f, 0.f, 0.f,
+        1.f, 0.f, 1.f, 0.f,
+        1.f, 1.f, 1.f, 1.f
+
+    });
+
+    auto b = NDArrayFactory::create<float>('c', {2, 4, 1}, {
+        2.f, 4.f, 2.f, 4.f,
+        4.f, 2.f, 4.f, 2.f
+    });
+
+    auto exp = NDArrayFactory::create<float>('c', {2, 4, 1}, {
+        -3.3333333f, 3.6666666f, 0.333333f, 1.3333333f,
+        1.333333f, -0.6666667f, 2.6666667f, -1.3333333f
+    });
+
+    nd4j::ops::solve op;
+
+    auto res = op.evaluate({&a, &b});
+    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+    auto z = res->at(0);
+
+//    z->printIndexedBuffer("Solve 4x4");
+
+    ASSERT_TRUE(exp.equalsTo(z));
+    delete res;
+}
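Solve_Test_1's expectation can be verified by hand, since multiplying the input matrix by the expected solution must reproduce b (for row 0: 2*7.625 - 3.25 - 2*5 = 2, and likewise for the other rows). A throwaway Java check with the values copied from the test:

    // Checks A * x == b for Solve_Test_1: x = {7.625, 3.25, 5}, b = {2, 4, 3}.
    static void checkSolveTest1() {
        double[][] a = {{2, -1, -2}, {-4, 6, 3}, {-4, -2, 8}};
        double[] x = {7.625, 3.25, 5};
        double[] b = {2, 4, 3};
        for (int r = 0; r < 3; r++) {
            double sum = 0;
            for (int c = 0; c < 3; c++)
                sum += a[r][c] * x[c];
            assert Math.abs(sum - b[r]) < 1e-9;   // row r reproduces b[r]
        }
    }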
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests11, Solve_Test_4) {
+
+    auto a = NDArrayFactory::create<float>('c', {2, 2, 2}, {
+        0.7788f, 0.8012f, 0.7244f, 0.2309f,
+        0.7271f, 0.1804f, 0.5056f, 0.8925f
+    });
+
+    auto b = NDArrayFactory::create<float>('c', {2, 2, 2}, {
+        0.7717f, 0.9281f, 0.9846f, 0.4838f,
+        0.6433f, 0.6041f, 0.6501f, 0.7612f
+    });
+
+    auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {
+//        1.524494767f, 0.432706356f,-0.518630624f, 0.737760842f,
+//        0.819143713f, 0.720401764f, 0.264349997f, 0.444699198f
+        1.5245394f, 0.4326952f, -0.51873577f, 0.7377896f,
+        0.81915987f, 0.72049433f, 0.2643504f, 0.44472617f
+    });
+
+    nd4j::ops::solve op;
+
+    auto res = op.evaluate({&a, &b});
+    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+    auto z = res->at(0);
+
+//    z->printBuffer("4 Solve 4x4");
+//    exp.printBuffer("4 Expec 4x4");
+
+    ASSERT_TRUE(exp.equalsTo(z));
+    delete res;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests11, Solve_Test_5) {
+
+    auto a = NDArrayFactory::create<float>('c', {3, 3}, {
+        0.7788f, 0.8012f, 0.7244f,
+        0.2309f, 0.7271f, 0.1804f,
+        0.5056f, 0.8925f, 0.5461f
+    });
+
+    auto b = NDArrayFactory::create<float>('c', {3, 3}, {
+        0.7717f, 0.9281f, 0.9846f,
+        0.4838f, 0.6433f, 0.6041f,
+        0.6501f, 0.7612f, 0.7605f
+    });
+
+    auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
+        1.5504692f,  1.8953944f,  2.2765768f,
+        0.03399149f, 0.2883001f,  0.5377323f,
+        -0.8774802f, -1.2155888f, -1.8049058f
+    });
+
+    nd4j::ops::solve op;
+
+    auto res = op.evaluate({&a, &b}, {true});
+    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+    auto z = res->at(0);
+
+//    z->printBuffer("4 Solve 4x4");
+//    exp.printBuffer("4 Expec 4x4");
+
+    ASSERT_TRUE(exp.equalsTo(z));
+    delete res;
+}

 ///////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test1) {
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp
index 142a3dbd4..6025216f9 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp
@@ -3008,3 +3008,33 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_5) {
     ASSERT_TRUE(exp.equalsTo(z));
     delete res;
 }
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests12, TriangularSolve_Test_6) {
+
+    auto a = NDArrayFactory::create<float>('c', {4, 4}, {
+        5.f, 1.f, -3.f,  3.f,
+        0.f, 1.f,  1.f, -1.f,
+        0.f, 0.f,  2.f, -9.f,
+        0.f, 0.f,  0.f,  4.f
+    });
+
+    auto b = NDArrayFactory::create<float>('c', {4, 2}, {
+        5.f, 1.f, 2.f, 1.f, 0.f, 1.f, -3.f, 1.f
+    });
+
+    auto exp = NDArrayFactory::create<float>('c', {4, 2}, {
+        1.f,0.2f, 1.f,0.8f, 1.f,0.4f, 1.f,1.2f
+    });
+
+    nd4j::ops::triangular_solve op;
+
+    auto res = op.evaluate({&a, &b}, {}, {}, {false, true});
+    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+    auto z = res->at(0);
+
+//    z->printIndexedBuffer("TriangularSolve with adjoint");
+
+    ASSERT_TRUE(exp.equalsTo(z));
+    delete res;
+}
\ No newline at end of file
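TriangularSolve_Test_6 exercises the new adjoint path: with boolean args {false, true} (lower = false, adjoint = true) the op solves transpose(a) * x = b, and since a is upper triangular its transpose is lower triangular, so plain forward substitution applies. The first expected column (all ones) can be confirmed by hand; a small Java check with values taken from the test:

    // Checks column 0 of TriangularSolve_Test_6: forward substitution on transpose(a).
    static void checkAdjointColumn0() {
        double[][] aT = {{5, 0, 0, 0}, {1, 1, 0, 0}, {-3, 1, 2, 0}, {3, -1, -9, 4}};
        double[] b = {5, 2, 0, -3};               // first column of the test's b
        double[] x = new double[4];
        for (int r = 0; r < 4; r++) {
            double sum = b[r];
            for (int c = 0; c < r; c++)
                sum -= aT[r][c] * x[c];
            x[r] = sum / aT[r][r];
        }
        for (double v : x)
            assert Math.abs(v - 1.0) < 1e-9;      // expected first column is all ones
    }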
diff --git a/libnd4j/tests_cpu/run_tests.sh b/libnd4j/tests_cpu/run_tests.sh
index 9b1271df6..8f412dee5 100755
--- a/libnd4j/tests_cpu/run_tests.sh
+++ b/libnd4j/tests_cpu/run_tests.sh
@@ -39,7 +39,7 @@ do
 done

 CHIP="${CHIP:-cpu}"
-export GTEST_OUTPUT="xml:../target/surefire-reports/TEST-${CHIP}-results.xml"
+export GTEST_OUTPUT="xml:surefire-reports/TEST-${CHIP}-results.xml"

 # On Mac, make sure it can find libraries for GCC
 export DYLD_LIBRARY_PATH=/usr/local/lib/gcc/8/:/usr/local/lib/gcc/7/:/usr/local/lib/gcc/6/:/usr/local/lib/gcc/5/
@@ -48,9 +48,11 @@ export DYLD_LIBRARY_PATH=/usr/local/lib/gcc/8/:/usr/local/lib/gcc/7/:/usr/local/
 if [ -n "$BUILD_PATH" ]; then
     if which cygpath; then
         BUILD_PATH=$(cygpath -p $BUILD_PATH)
-        export GTEST_OUTPUT="xml:'..\target\surefire-reports\TEST-${CHIP}-results.xml'"
     fi
     export PATH="$PATH:$BUILD_PATH"
 fi

 ../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests
+
+# Workaround to fix posix path conversion problem on Windows (http://mingw.org/wiki/Posix_path_conversion)
+[ -f "${GTEST_OUTPUT#*:}" ] && cp -a surefire-reports/ ../target && rm -rf surefire-reports/
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
index f804d8c95..3ed96fe9c 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
@@ -623,7 +623,8 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.custom.Igammac.class,
             org.nd4j.linalg.api.ops.custom.Digamma.class,
             org.nd4j.linalg.api.ops.custom.Lu.class,
-            org.nd4j.linalg.api.ops.custom.TriangularSolve.class
+            org.nd4j.linalg.api.ops.custom.TriangularSolve.class,
+            org.nd4j.linalg.api.ops.custom.LinearSolve.class
     );

     static {
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/DataType.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/DataType.java
index 7555bce21..94cfdca43 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/DataType.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/DataType.java
@@ -16,26 +16,48 @@

 package org.nd4j.linalg.api.buffer;

+/**
+ * Enumeration of the supported data types.
+ */
 public enum DataType {

     DOUBLE,
     FLOAT,

+    /**
+     * @deprecated Replaced by {@link DataType#FLOAT16}, use that instead
+     */
     @Deprecated
     HALF,

+    /**
+     * @deprecated Replaced by {@link DataType#INT64}, use that instead
+     */
     @Deprecated
     LONG,

+    /**
+     * @deprecated Replaced by {@link DataType#INT32}, use that instead
+     */
     @Deprecated
     INT,

+    /**
+     * @deprecated Replaced by {@link DataType#INT16}, use that instead
+     */
     @Deprecated
     SHORT,

+    /**
+     * @deprecated Replaced by {@link DataType#UINT8}, use that instead
+     */
     @Deprecated
     UBYTE,

+    /**
+     * @deprecated Replaced by {@link DataType#INT8}, use that instead
+     */
     @Deprecated
     BYTE,
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/LinearSolve.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/LinearSolve.java
new file mode 100644
index 000000000..7c835006e
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/LinearSolve.java
@@ -0,0 +1,77 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.linalg.api.ops.custom;
+
+import lombok.NoArgsConstructor;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.GraphDef;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+@NoArgsConstructor
+public class LinearSolve extends DynamicCustomOp {
+
+    public LinearSolve(INDArray a, INDArray b, boolean adjoint) {
+        addInputArgument(a, b);
+        addBArgument(adjoint);
+    }
+
+    public LinearSolve(INDArray a, INDArray b) {
+        this(a, b, false);
+    }
+
+    public LinearSolve(SameDiff sameDiff, SDVariable a, SDVariable b, SDVariable adjoint) {
+        super(sameDiff, new SDVariable[] {a, b, adjoint});
+    }
+
+    public LinearSolve(SameDiff sameDiff, SDVariable a, SDVariable b, boolean adjoint) {
+        super(sameDiff, new SDVariable[] {a, b});
+        addBArgument(adjoint);
+    }
+
+    @Override
+    public String opName() {
+        return "solve";
+    }
+
+    @Override
+    public String tensorflowName() {
+        return "MatrixSolve";
+    }
+
+    @Override
+    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
+        boolean adjoint = attributesForNode.containsKey("adjoint") ? attributesForNode.get("adjoint").getB() : false;
+        addBArgument(adjoint);
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        int n = args().length;
+        Preconditions.checkState(dataTypes != null && dataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), dataTypes);
+        return Collections.singletonList(dataTypes.get(0));
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
index 49ff345e7..b8d795460 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
@@ -1691,4 +1691,50 @@ public class CustomOpsTests extends BaseNd4jTest {

         assertEquals(e, x);
     }
+
+    @Test
+    public void testLinearSolve() {
+        INDArray a = Nd4j.createFromArray(new float[]{
+                2.f, -1.f, -2.f, -4.f, 6.f, 3.f, -4.f, -2.f, 8.f
+        }).reshape(3, 3);
+
+        INDArray b = Nd4j.createFromArray(new float[]{
+                2.f, 4.f, 3.f
+        }).reshape(3, 1);
+
+        INDArray expected = Nd4j.createFromArray(new float[]{
+                7.625f, 3.25f, 5.f
+        }).reshape(3, 1);
+
+        val op = new LinearSolve(a, b);
+        INDArray[] ret = Nd4j.exec(op);
+
+        assertEquals(expected, ret[0]);
+    }
+
+    @Test
+    public void testLinearSolveAdjoint() {
+        INDArray a = Nd4j.createFromArray(new float[]{
+                0.7788f, 0.8012f, 0.7244f,
+                0.2309f, 0.7271f, 0.1804f,
+                0.5056f, 0.8925f, 0.5461f
+        }).reshape(3, 3);
+
+        INDArray b = Nd4j.createFromArray(new float[]{
+                0.7717f, 0.9281f, 0.9846f,
+                0.4838f, 0.6433f, 0.6041f,
+                0.6501f, 0.7612f, 0.7605f
+        }).reshape(3, 3);
+
+        INDArray expected = Nd4j.createFromArray(new float[]{
+                1.5504692f,  1.8953944f,  2.2765768f,
+                0.03399149f, 0.2883001f,  0.5377323f,
+                -0.8774802f, -1.2155888f, -1.8049058f
+        }).reshape(3, 3);
+
+        val op = new LinearSolve(a, b, true);
+        INDArray[] ret = Nd4j.exec(op);
+
+        assertEquals(expected, ret[0]);
+    }
+}
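Outside the test harness, the new Java wrapper is invoked like any other custom op. A minimal usage sketch built only from API that appears in this patch; the expected values come from testLinearSolve above:

    // Minimal usage sketch of the new op: solves a * x = b (LU decomposition under the hood).
    INDArray a = Nd4j.createFromArray(new float[]{
            2.f, -1.f, -2.f, -4.f, 6.f, 3.f, -4.f, -2.f, 8.f}).reshape(3, 3);
    INDArray b = Nd4j.createFromArray(new float[]{2.f, 4.f, 3.f}).reshape(3, 1);
    INDArray x  = Nd4j.exec(new LinearSolve(a, b))[0];        // expect {7.625, 3.25, 5}
    INDArray xt = Nd4j.exec(new LinearSolve(a, b, true))[0];  // adjoint: solves transpose(a) * x = b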