From e910ce75ecd75e6e0c7f2b732a716d096fd3b0d4 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Sat, 23 Nov 2019 20:06:12 +1100 Subject: [PATCH 01/30] Various Fixes (#75) * #8431 Cast loss function weights array automatically Signed-off-by: AlexDBlack * Add 'regex verbose mode' printing (ExecDebugListener) for TFGraphTestAllSameDiff' Signed-off-by: AlexDBlack * Class import mapping fix Signed-off-by: AlexDBlack * Reshape fixes Signed-off-by: AlexDBlack * Don't swallow first exception in NativeOpExecutioner.exec(CustomOp) Signed-off-by: AlexDBlack --- .../conf/preprocessor/TestPreProcessors.java | 32 +++++- .../keras/layers/core/KerasFlatten.java | 2 +- .../keras/layers/core/KerasReshape.java | 12 +-- .../preprocessors/ReshapePreprocessor.java | 102 +++++++++--------- .../Keras2ModelConfigurationTest.java | 3 + .../converters/ImportClassMapping.java | 2 + .../linalg/lossfunctions/impl/LossMCXENT.java | 2 +- .../lossfunctions/impl/LossSparseMCXENT.java | 2 +- .../nativecpu/ops/NativeOpExecutioner.java | 7 +- .../java/org/nd4j/autodiff/TestOpMapping.java | 4 +- .../TFGraphs/TFGraphTestAllSameDiff.java | 19 +++- .../lossfunctions/LossFunctionTest.java | 69 +++++++++++- 12 files changed, 189 insertions(+), 67 deletions(-) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java index e1ce77cd3..21fd3368a 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java @@ -1,5 +1,6 @@ -/******************************************************************************* +/* ****************************************************************************** * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -28,6 +29,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer; import org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer; +import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor; import org.deeplearning4j.nn.weights.WeightInit; import org.junit.Test; import org.nd4j.linalg.activations.Activation; @@ -485,4 +487,32 @@ public class TestPreProcessors extends BaseDL4JTest { assertEquals(15 * 15 * 10, ((FeedForwardLayer) conf.getConf(1).getLayer()).getNIn()); } + + + @Test + public void testPreprocessorVertex(){ + for(boolean withMinibatchDim : new boolean[]{true, false}){ + long[] inShape = withMinibatchDim ? new long[]{-1, 32} : new long[]{32}; + long[] targetShape = withMinibatchDim ? 
new long[]{-1, 2, 4, 4} : new long[]{2, 4, 4}; + + for( long minibatch : new long[]{1, 3}) { + long[] inArrayShape = new long[]{minibatch, 32}; + long[] targetArrayShape = new long[]{minibatch, 2, 4, 4}; + long length = minibatch * 32; + + INDArray in = Nd4j.linspace(1, length, length).reshape('c', inArrayShape); + + ReshapePreprocessor pp = new ReshapePreprocessor(inShape, targetShape, withMinibatchDim); + + for( int i=0; i<3; i++ ) { + INDArray out = pp.preProcess(in, (int) minibatch, LayerWorkspaceMgr.noWorkspaces()); + INDArray expOut = in.reshape(targetArrayShape); + assertEquals(expOut, out); + + INDArray backprop = pp.backprop(expOut, (int)minibatch, LayerWorkspaceMgr.noWorkspaces()); + assertEquals(in, backprop); + } + } + } + } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasFlatten.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasFlatten.java index e0a6628a2..196f9d3d9 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasFlatten.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasFlatten.java @@ -111,7 +111,7 @@ public class KerasFlatten extends KerasLayer { // to RNN type. Otherwise we add this trivial preprocessor (since there's nothing to flatten). InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0]; val inputShape = new long[]{it.getSize()}; - preprocessor = new ReshapePreprocessor(inputShape, inputShape); + preprocessor = new ReshapePreprocessor(inputShape, inputShape, false); } return preprocessor; } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java index 4035e9298..e5f1375d1 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java @@ -111,11 +111,11 @@ public class KerasReshape extends KerasLayer { } else { targetShape = new long[]{targetShape[1], targetShape[0], targetShape[2]}; } - preprocessor = new ReshapePreprocessor(inputShape, targetShape); + preprocessor = new ReshapePreprocessor(inputShape, targetShape, false); } else { // (dimOrder == DimOrder.TENSORFLOW || dimOrder == DimOrder.NONE && kerasMajorVersion == 2) if (inputShape[0] != targetShape[0]) targetShape = new long[]{targetShape[2], targetShape[0], targetShape[1]}; - preprocessor = new ReshapePreprocessor(inputShape, targetShape); + preprocessor = new ReshapePreprocessor(inputShape, targetShape, false); } } else if (inputType[0] instanceof InputType.InputTypeConvolutional3D) { @@ -128,23 +128,23 @@ public class KerasReshape extends KerasLayer { } else { targetShape = new long[] { targetShape[2], targetShape[1], targetShape[0], targetShape[3] }; } - preprocessor = new ReshapePreprocessor(inputShape, targetShape); + preprocessor = new ReshapePreprocessor(inputShape, targetShape, false); } else { if (inputShape[0] != targetShape[0]) targetShape = new long[] { targetShape[3], targetShape[0], targetShape[1], targetShape[2] }; - preprocessor = new 
ReshapePreprocessor(inputShape, targetShape); + preprocessor = new ReshapePreprocessor(inputShape, targetShape, false); } } else if (inputType[0] instanceof InputType.InputTypeRecurrent) { InputType.InputTypeRecurrent it = (InputType.InputTypeRecurrent) inputType[0]; val inputShape = new long[]{it.getSize(), it.getTimeSeriesLength()}; - preprocessor = new ReshapePreprocessor(inputShape, this.targetShape); + preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false); } else if (inputType[0] instanceof InputType.InputTypeFeedForward) { InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0]; val inputShape = new long[]{it.getSize()}; if (targetShape.length == 3) { targetShape = targetShapeForDimOrder(inputShape, targetShape); } - preprocessor = new ReshapePreprocessor(inputShape, this.targetShape); + preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false); } return preprocessor; } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java index e9aef5b90..afc9392a5 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java @@ -1,5 +1,6 @@ -/******************************************************************************* +/* ****************************************************************************** * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -20,7 +21,6 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.apache.commons.lang3.ArrayUtils; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor; @@ -36,73 +36,72 @@ import java.util.Arrays; import static org.nd4j.linalg.util.ArrayUtil.prodLong; /** - * Generic reshape preprocessor + * Generic reshape preprocessor. + * Note that shapes may be specified with or without the leading minibatch dimension, as long as hasMiniBatchDimension + * is set appropriately in {@link #ReshapePreprocessor(long[], long[], boolean)}
+ * For example, to reshape from [minibatch, 32] to [minibatch, 2, 4, 4] you could use:
+ * hasMiniBatchDimension = true with inputShape = [-1, 32] and targetShape = [-1, 2, 4, 4] OR
+ * hasMiniBatchDimension = false with inputShape = [32] and targetShape = [2, 4, 4] * * @author Max Pumperla */ @Data @Slf4j @EqualsAndHashCode(callSuper = false) -@JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize", "staticTargetShape"}) +@JsonIgnoreProperties({"miniBatchSize", "staticTargetShape"}) public class ReshapePreprocessor extends BaseInputPreProcessor { - private long[] inputShape; - private long[] targetShape; - private boolean hasMiniBatchDimension = false; - private int miniBatchSize; - private long[] staticTargetShape; + private final long[] inputShape; + private final long[] targetShape; + private boolean hasMiniBatchDimension; - public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape) { - this.inputShape = inputShape; - this.targetShape = targetShape; + /** + * @deprecated Use constructor {@link #ReshapePreprocessor(long[], long[], boolean)} + */ + @Deprecated + public ReshapePreprocessor(long[] inputShape, long[] targetShape) { + this(inputShape, targetShape, false); } - private static int prod(int[] array) { - int prod = 1; - for (int i : array) { - prod *= i; + /** + * @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension + * @param targetShape Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension + * @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...] + */ + public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape, + @JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension) { + this.inputShape = inputShape; + this.targetShape = targetShape; + this.hasMiniBatchDimension = hasMiniBatchDimension; + } + + private long[] getShape(long[] originalShape, long minibatch) { + long[] newShape = (hasMiniBatchDimension ? originalShape : prependMiniBatchSize(originalShape, minibatch)); + if (newShape[0] != minibatch) { + newShape = newShape.clone(); + newShape[0] = minibatch; } - return prod; + return newShape; } private static long[] prependMiniBatchSize(long[] shape, long miniBatchSize) { int shapeLength = shape.length; val miniBatchShape = new long[shapeLength + 1]; - for (int i = 0; i < miniBatchShape.length; i++) { - if (i == 0) - miniBatchShape[i] = miniBatchSize; - else - miniBatchShape[i] = shape[i - 1]; + miniBatchShape[0] = miniBatchSize; + for (int i = 1; i < miniBatchShape.length; i++) { + miniBatchShape[i] = shape[i - 1]; } return miniBatchShape; } @Override public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { - // the target shape read from a keras config does not have mini-batch size - // included. We prepend it here dynamically. + // the target shape read from a keras config does not have mini-batch size included. We prepend it here dynamically. 
+ long[] targetShape = getShape(this.targetShape, miniBatchSize); + long[] inputShape = getShape(this.inputShape, miniBatchSize); - long[] targetShape; - if (staticTargetShape != null){ - targetShape = prependMiniBatchSize(staticTargetShape, miniBatchSize); - hasMiniBatchDimension = true; - this.miniBatchSize = miniBatchSize; - } - else{ - targetShape = this.targetShape; - } - if (!this.hasMiniBatchDimension) { - targetShape = prependMiniBatchSize(targetShape, miniBatchSize); - inputShape = prependMiniBatchSize(inputShape, miniBatchSize); - this.miniBatchSize = miniBatchSize; - } - if (this.miniBatchSize != miniBatchSize) { - targetShape = prependMiniBatchSize(ArrayUtils.subarray(targetShape, 1, targetShape.length), miniBatchSize); - inputShape = prependMiniBatchSize(ArrayUtils.subarray(inputShape, 1, targetShape.length), miniBatchSize); - this.miniBatchSize = miniBatchSize; - } if (prodLong(input.shape()) == prodLong((targetShape))) { - if(input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)){ + if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)) { input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c'); } return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(targetShape)); @@ -114,15 +113,18 @@ public class ReshapePreprocessor extends BaseInputPreProcessor { @Override public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { + long[] targetShape = getShape(this.targetShape, miniBatchSize); + long[] inputShape = getShape(this.inputShape, miniBatchSize); + if (!Arrays.equals(targetShape, output.shape())) { throw new IllegalStateException("Unexpected output shape" + Arrays.toString(output.shape()) + " (expected to be " + Arrays.toString(targetShape) + ")"); } if (prodLong(output.shape()) == prodLong((targetShape))) { - if(output.ordering() != 'c' || !Shape.hasDefaultStridesForShape(output)){ + if (output.ordering() != 'c' || !Shape.hasDefaultStridesForShape(output)) { output = workspaceMgr.dup(ArrayType.ACTIVATIONS, output, 'c'); } - return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.reshape(this.inputShape)); + return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.reshape(inputShape)); } else { throw new IllegalStateException("Output shape" + Arrays.toString(output.shape()) + " and input shape" + Arrays.toString(targetShape) + " do not match"); @@ -131,7 +133,7 @@ public class ReshapePreprocessor extends BaseInputPreProcessor { @Override public InputType getOutputType(InputType inputType) throws InvalidInputTypeException { - val shape = hasMiniBatchDimension ? 
targetShape : prependMiniBatchSize(targetShape, 0); + long[] shape = getShape(this.targetShape, 0); InputType ret; switch (shape.length) { case 2: @@ -141,18 +143,16 @@ public class ReshapePreprocessor extends BaseInputPreProcessor { ret = InputType.recurrent(shape[2], shape[1]); break; case 4: - if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN){ + if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN) { ret = InputType.convolutional(shape[1], shape[2], shape[3]); - }else { + } else { ret = InputType.convolutional(shape[2], shape[3], shape[1]); } break; default: throw new UnsupportedOperationException( "Cannot infer input type for reshape array " + Arrays.toString(shape)); - } - this.staticTargetShape = ret.getShape(); return ret; } } \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java index db03128f7..d776ed63e 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java @@ -257,12 +257,15 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest { @Test public void ReshapeEmbeddingConcatTest() throws Exception{ + //TODO AB 2019/11/23 - known issue - see https://github.com/eclipse/deeplearning4j/issues/8373 and https://github.com/eclipse/deeplearning4j/issues/8441 + try(InputStream is = Resources.asStream("/modelimport/keras/configs/keras2/reshape_embedding_concat.json")) { ComputationGraphConfiguration config = new KerasModel().modelBuilder().modelJsonInputStream(is) .enforceTrainingConfig(false).buildModel().getComputationGraphConfiguration(); ComputationGraph model = new ComputationGraph(config); model.init(); +// System.out.println(model.summary()); model.outputSingle(Nd4j.zeros(1, 1), Nd4j.zeros(1, 1), Nd4j.zeros(1, 1)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 7b19406ef..8826858e5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -540,6 +540,8 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.strict.Log.class, org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p.class, org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid.class, + org.nd4j.linalg.api.ops.impl.transforms.strict.Mish.class, + org.nd4j.linalg.api.ops.impl.transforms.strict.MishDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU.class, org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java index 
22bb27e0e..78c21d95c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java @@ -164,7 +164,7 @@ public class LossMCXENT implements ILossFunction { throw new IllegalStateException("Weights vector (length " + weights.length() + ") does not match output.size(1)=" + output.size(1)); } - INDArray temp = labels.mulRowVector(weights); + INDArray temp = labels.mulRowVector(weights.castTo(labels.dataType())); INDArray col = temp.sum(true,1); grad = output.mulColumnVector(col).sub(temp); } else { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java index 2ea0feb52..f472fae5f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java @@ -117,7 +117,7 @@ public class LossSparseMCXENT extends LossMCXENT { private INDArray toOneHot(INDArray labels, INDArray preOutput){ Preconditions.checkState(labels.size(-1) == 1, "Labels for LossSparseMCXENT should be an array of integers " + - "with last dimension having size 1. Got labels array with shape %ndShape", labels); + "with first dimension equal to minibatch size, and last dimension having size 1. Got labels array with shape %ndShape", labels); INDArray oneHotLabels = preOutput.ulike(); Nd4j.exec(new OneHot(labels.reshape(labels.length()), oneHotLabels, (int)preOutput.size(-1))); return oneHotLabels; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index 751f75cea..b6af2e5f2 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -1662,7 +1662,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { * This method executes given CustomOp * * PLEASE NOTE: You're responsible for input/output validation - * @param op + * @param op Operation to execute */ @Override public INDArray[] exec(@NonNull CustomOp op) { @@ -1671,11 +1671,12 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { try { val list = this.calculateOutputShape(op); if (list.isEmpty()) - throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified"); + throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to calculate output datatypes"); for (LongShapeDescriptor shape : list) op.addOutputArgument(Nd4j.create(shape, false)); - + } catch (ND4JIllegalStateException e){ + throw e; } catch (Exception e) { throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. 
You can't execute non-inplace CustomOp without outputs being specified"); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java index c4225ce2a..ab56ae281 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java @@ -68,7 +68,9 @@ public class TestOpMapping extends BaseNd4jTest { } String opName = df.opName(); - assertTrue("Op is missing - not defined in ImportClassMapping: " + opName, opNameMapping.containsKey(opName)); + assertTrue("Op is missing - not defined in ImportClassMapping: " + opName + + "\nInstructions to fix: Add class to org.nd4j.imports.converters.ImportClassMapping", opNameMapping.containsKey(opName) + ); try{ String[] tfNames = df.tensorflowNames(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java index a690bc5a8..d30ba87f6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java @@ -129,6 +129,13 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a "resize_bilinear/int32.*" }; + /* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have + all arrays printed during execution. + If a test name matches any regex here, an ExecPrintListener will be added to the listeners, and all output + arrays will be printed during execution + */ + private final List debugModeRegexes = null; //Arrays.asList("resize_nearest_neighbor/.*", "add_n.*"); + @BeforeClass public static void beforeClass() { Nd4j.setDataType(DataType.FLOAT); @@ -194,8 +201,18 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a Double maxRE = (precisionOverride == null ? null : precisionOverride.getFirst()); Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond()); + boolean verboseDebugMode = false; + if(debugModeRegexes != null){ + for(String regex : debugModeRegexes){ + if(modelName.matches(regex)){ + verboseDebugMode = true; + break; + } + } + } + try { - TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs, false); + TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs, verboseDebugMode); //TFGraphTestAllHelper.checkIntermediate(inputs, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, localTestDir); } catch (Throwable t){ log.error("ERROR Executing test: {} - input keys {}", modelName, (inputs == null ? 
null : inputs.keySet()), t); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java index a6f8dddf3..9822962c4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java @@ -20,13 +20,15 @@ import org.junit.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationSigmoid; +import org.nd4j.linalg.activations.impl.ActivationSoftmax; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.conditions.Conditions; -import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT; +import org.nd4j.linalg.lossfunctions.impl.*; import static junit.framework.TestCase.assertFalse; import static junit.framework.TestCase.assertTrue; @@ -70,6 +72,71 @@ public class LossFunctionTest extends BaseNd4jTest { assertEquals(0, match2); } + @Test + public void testWeightedLossFunctionDTypes(){ + + for(DataType activationsDt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}){ + for(DataType weightsDt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}){ + for( boolean rank1W : new boolean[]{false, true}) { + + INDArray preOut = Nd4j.rand(activationsDt, 2, 3); + INDArray l = Nd4j.rand(activationsDt, 2, 3); + + INDArray w = Nd4j.createFromArray(1.0f, 2.0f, 3.0f).castTo(weightsDt); + if(!rank1W){ + w = w.reshape(1, 3); + } + + ILossFunction lf = null; + for (int i = 0; i < 10; i++) { + switch (i) { + case 0: + lf = new LossBinaryXENT(w); + break; + case 1: + lf = new LossL1(w); + break; + case 2: + lf = new LossL2(w); + break; + case 3: + lf = new LossMAE(w); + break; + case 4: + lf = new LossMAPE(w); + break; + case 5: + lf = new LossMCXENT(w); + break; + case 6: + lf = new LossMSE(w); + break; + case 7: + lf = new LossMSLE(w); + break; + case 8: + lf = new LossNegativeLogLikelihood(w); + break; + case 9: + lf = new LossSparseMCXENT(w); + l = Nd4j.createFromArray(1,2).reshape(2, 1).castTo(activationsDt); + break; + default: + throw new RuntimeException(); + } + } + + //Check score + lf.computeScore(l, preOut, new ActivationSoftmax(), null, true); + + //Check backward + lf.computeGradient(l, preOut, new ActivationSoftmax(), null); + } + } + } + + } + @Override public char ordering() { From 5b2ee7267310c36a4c72e8e61e9b8322b7798623 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Mon, 25 Nov 2019 16:00:21 +1100 Subject: [PATCH 02/30] DL4J Time Distributed + fixes + Vertx module profiles fix (#78) * Add test profiles to vertx module * Arbiter test tweaks Signed-off-by: AlexDBlack * Add TimeDistributed wrapper layer Signed-off-by: AlexDBlack * Tests for TimeDistributed layer Signed-off-by: AlexDBlack * Small test dependency exclusion for Spark module * Fixes, more thorough tests Signed-off-by: AlexDBlack --- .../TestGraphLocalExecution.java | 4 +- .../util/TestDataFactoryProviderMnist.java | 2 +- .../layers/recurrent/TestTimeDistributed.java | 88 ++++++++++++++ .../layers/recurrent/TimeDistributed.java | 81 +++++++++++++ 
.../recurrent/TimeDistributedLayer.java | 110 ++++++++++++++++++ .../spark/dl4j-spark/pom.xml | 6 + .../deeplearning4j-vertx/pom.xml | 9 ++ 7 files changed, 297 insertions(+), 3 deletions(-) create mode 100644 deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java create mode 100644 deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java create mode 100644 deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/TimeDistributedLayer.java diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java index 9d9db6261..c64a06040 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java @@ -305,7 +305,7 @@ public class TestGraphLocalExecution { @Test public void testLocalExecutionEarlyStopping() throws Exception { EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder() - .epochTerminationConditions(new MaxEpochsTerminationCondition(6)) + .epochTerminationConditions(new MaxEpochsTerminationCondition(4)) .scoreCalculator(new ScoreProvider()) .modelSaver(new InMemoryModelSaver()).build(); Map commands = new HashMap<>(); @@ -348,7 +348,7 @@ public class TestGraphLocalExecution { .dataProvider(dataProvider) .scoreFunction(ScoreFunctions.testSetF1()) .modelSaver(new FileModelSaver(modelSavePath)) - .terminationConditions(new MaxTimeCondition(30, TimeUnit.SECONDS), + .terminationConditions(new MaxTimeCondition(45, TimeUnit.SECONDS), new MaxCandidatesCondition(10)) .build(); diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/util/TestDataFactoryProviderMnist.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/util/TestDataFactoryProviderMnist.java index 1e652cdbe..4416dd8cf 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/util/TestDataFactoryProviderMnist.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/util/TestDataFactoryProviderMnist.java @@ -32,7 +32,7 @@ public class TestDataFactoryProviderMnist implements DataSetIteratorFactory { private int terminationIter; public TestDataFactoryProviderMnist(){ - this(16, 10); + this(16, 4); } @Override diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java new file mode 100644 index 000000000..5c456e206 --- /dev/null +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java @@ -0,0 +1,88 @@ +package org.deeplearning4j.nn.layers.recurrent; + +import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.TestUtils; +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.WorkspaceMode; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.layers.LSTM; +import 
org.deeplearning4j.nn.conf.layers.RnnOutputLayer; +import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.junit.Test; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.learning.config.Adam; +import org.nd4j.linalg.lossfunctions.LossFunctions; + +import static org.junit.Assert.assertEquals; + +public class TestTimeDistributed extends BaseDL4JTest { + + @Test + public void testTimeDistributed(){ + for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) { + + MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() + .trainingWorkspaceMode(wsm) + .inferenceWorkspaceMode(wsm) + .seed(12345) + .updater(new Adam(0.1)) + .list() + .layer(new LSTM.Builder().nIn(3).nOut(3).build()) + .layer(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build()) + .layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX) + .lossFunction(LossFunctions.LossFunction.MCXENT).build()) + .setInputType(InputType.recurrent(3)) + .build(); + + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + .trainingWorkspaceMode(wsm) + .inferenceWorkspaceMode(wsm) + .seed(12345) + .updater(new Adam(0.1)) + .list() + .layer(new LSTM.Builder().nIn(3).nOut(3).build()) + .layer(new TimeDistributed(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build(), 2)) + .layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX) + .lossFunction(LossFunctions.LossFunction.MCXENT).build()) + .setInputType(InputType.recurrent(3)) + .build(); + + MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); + MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); + net1.init(); + net2.init(); + + for( int mb : new int[]{1, 5}) { + for(char inLabelOrder : new char[]{'c', 'f'}) { + INDArray in = Nd4j.rand(DataType.FLOAT, mb, 3, 5).dup(inLabelOrder); + + INDArray out1 = net1.output(in); + INDArray out2 = net2.output(in); + + assertEquals(out1, out2); + + INDArray labels = TestUtils.randomOneHotTimeSeries(mb, 3, 5).dup(inLabelOrder); + + DataSet ds = new DataSet(in, labels); + net1.fit(ds); + net2.fit(ds); + + assertEquals(net1.params(), net2.params()); + + MultiLayerNetwork net3 = TestUtils.testModelSerialization(net2); + out2 = net2.output(in); + INDArray out3 = net3.output(in); + + assertEquals(out2, out3); + } + } + } + } +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java new file mode 100644 index 000000000..bd9685ef9 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java @@ -0,0 +1,81 @@ +package org.deeplearning4j.nn.conf.layers.recurrent; + +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NonNull; +import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; +import org.deeplearning4j.nn.layers.recurrent.TimeDistributedLayer; +import 
org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.shade.jackson.annotation.JsonProperty; + +import java.util.Collection; + +/** + * TimeDistributed wrapper layer.
+ * Note: only the "Feed forward layer time distributed in an RNN" is currently supported. + * For example, a time distributed dense layer.
+ * Usage: {@code .layer(new TimeDistributed(new DenseLayer.Builder()....build(), timeAxis))}
+ * Note that for DL4J RNNs, time axis is always 2 - i.e., RNN activations have shape [minibatch, size, sequenceLength] + * + * @author Alex Black + */ +@Data +@EqualsAndHashCode(callSuper = true) +public class TimeDistributed extends BaseWrapperLayer { + + private final int timeAxis; + + /** + * @param underlying Underlying (internal) layer - should be a feed forward type such as DenseLayer + * @param timeAxis Time axis, should be 2 for DL4J RNN activations (shape [minibatch, size, sequenceLength]) + */ + public TimeDistributed(@JsonProperty("underlying") @NonNull Layer underlying, @JsonProperty("timeAxis") int timeAxis) { + super(underlying); + this.timeAxis = timeAxis; + } + + + @Override + public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + NeuralNetConfiguration conf2 = conf.clone(); + conf2.setLayer(((TimeDistributed) conf2.getLayer()).getUnderlying()); + return new TimeDistributedLayer(underlying.instantiate(conf2, trainingListeners, layerIndex, layerParamsView, + initializeParams, networkDataType), timeAxis); + } + + @Override + public InputType getOutputType(int layerIndex, InputType inputType) { + if (inputType.getType() != InputType.Type.RNN) { + throw new IllegalStateException("Only RNN input type is supported as input to TimeDistributed layer (layer #" + layerIndex + ")"); + } + + InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent) inputType; + InputType ff = InputType.feedForward(rnn.getSize()); + InputType ffOut = underlying.getOutputType(layerIndex, ff); + return InputType.recurrent(ffOut.arrayElementsPerExample(), rnn.getTimeSeriesLength()); + } + + @Override + public void setNIn(InputType inputType, boolean override) { + if (inputType.getType() != InputType.Type.RNN) { + throw new IllegalStateException("Only RNN input type is supported as input to TimeDistributed layer"); + } + + InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent) inputType; + InputType ff = InputType.feedForward(rnn.getSize()); + underlying.setNIn(ff, override); + } + + @Override + public InputPreProcessor getPreProcessorForInputType(InputType inputType) { + //No preprocessor - the wrapper layer operates as the preprocessor + return null; + } +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/TimeDistributedLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/TimeDistributedLayer.java new file mode 100644 index 000000000..874fb136f --- /dev/null +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/TimeDistributedLayer.java @@ -0,0 +1,110 @@ +package org.deeplearning4j.nn.layers.recurrent; + +import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer; +import org.deeplearning4j.nn.workspace.ArrayType; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.primitives.Pair; +import org.nd4j.linalg.util.ArrayUtil; + +/** + * TimeDistributed wrapper layer.
+ * Note: only the "Feed forward layer time distributed in an RNN" is currently supported. + * For example, a time distributed dense layer.
+ * Usage: {@code .layer(new TimeDistributed(new DenseLayer.Builder()....build(), timeAxis))}
+ * Note that for DL4J RNNs, time axis is always 2 - i.e., RNN activations have shape [minibatch, size, sequenceLength] + * + * @author Alex Black + */ +public class TimeDistributedLayer extends BaseWrapperLayer { + + private final int timeAxis; + + public TimeDistributedLayer(Layer underlying, int timeAxis) { + super(underlying); + this.timeAxis = timeAxis; + } + + + @Override + public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { + INDArray reshapedEps = reshape(epsilon); + Pair p = underlying.backpropGradient(reshapedEps, workspaceMgr); + INDArray reverted = revertReshape(p.getSecond(), epsilon.size(0)); + reverted = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, reverted); + p.setSecond(reverted); + return p; + } + + @Override + public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { + return activate(input(), training, workspaceMgr); + } + + @Override + public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr workspaceMgr) { + INDArray reshaped = reshape(input); + INDArray out = underlying.activate(reshaped, training, workspaceMgr); + INDArray ret = revertReshape(out, input.size(0)); + return workspaceMgr.dup(ArrayType.ACTIVATIONS, ret); + } + + protected INDArray reshape(INDArray array){ + //Reshape the time axis to the minibatch axis + //For example, for RNN -> FF (dense time distributed): [mb, size, seqLen] -> [mb x seqLen, size] + int axis = timeAxis; + if(axis < 0) + axis += array.rank(); + + int[] permuteAxis = permuteAxes(array.rank(), axis); + INDArray permute = array.permute(permuteAxis); + + long[] newShape = new long[array.rank()-1]; + newShape[0] = array.size(0) * array.size(axis); + int j=1; + for( int i=1; ideeplearning4j-ui ${deeplearning4j.version} test + + + net.jpountz.lz4 + lz4 + + diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/pom.xml index 8a621b40b..5e5ae75c1 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/pom.xml +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/pom.xml @@ -434,4 +434,13 @@ + + + + test-nd4j-native + + + test-nd4j-cuda-10.1 + + \ No newline at end of file From aa44fd6850ae0f15be5719aec4b098b23ffc6dd8 Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 25 Nov 2019 08:52:11 +0300 Subject: [PATCH 03/30] one more BitCast test Signed-off-by: raver119 --- .../java/org/nd4j/linalg/custom/CustomOpsTests.java | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index ca075c872..c01cf1942 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -1088,6 +1088,16 @@ public class CustomOpsTests extends BaseNd4jTest { assertArrayEquals(new long[]{1,10, 2}, lsd.get(0).getShape()); } + @Test + public void testBitCastShape_3(){ + val x = Nd4j.createFromArray(new int[]{1, 2, 3, 4, 5, 6, 7, 8}).reshape(1, 4, 2); + val e = Nd4j.createFromArray(new long[]{8589934593L, 17179869187L, 25769803781L, 34359738375L}).reshape(1, 4); + val z = Nd4j.exec(new BitCast(x, DataType.LONG.toInt()))[0]; + + assertEquals(e, z); + } + + @Test public void testMatch_1() { INDArray x = Nd4j.ones(DataType.FLOAT, 3,3); From 7f90930e7a5cec6eaed87121c6deaf3209b932f3 Mon 
Sep 17 00:00:00 2001 From: raver119 Date: Mon, 25 Nov 2019 09:17:35 +0300 Subject: [PATCH 04/30] bring back cuda cc 30 Signed-off-by: raver119 --- libnd4j/blas/CMakeLists.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt index c804ce5ec..9674e28cd 100755 --- a/libnd4j/blas/CMakeLists.txt +++ b/libnd4j/blas/CMakeLists.txt @@ -175,9 +175,9 @@ if(CUDA_BLAS) if(CUDA_VERSION VERSION_GREATER "9.2") # cuda 10 if ("${COMPUTE}" STREQUAL "all") if (APPLE) - list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60) else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70) endif() else() list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) @@ -185,9 +185,9 @@ if(CUDA_BLAS) elseif(CUDA_VERSION VERSION_GREATER "8.0") # cuda 9 if ("${COMPUTE}" STREQUAL "all") if (APPLE) - list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60) else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70) endif() else() list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) From 0e3fcdc24dc28e018149a23e66b37f39146e5b78 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Mon, 25 Nov 2019 18:46:34 +1100 Subject: [PATCH 05/30] [WIP] DL4J nearestneighbors-sever: Play ->Vertx (#79) * Switch Nearest neighbors server implementation from Play to Vertx Signed-off-by: AlexDBlack * No more scala version suffix for nearest neighbor server Signed-off-by: AlexDBlack * logback.xml fixes Signed-off-by: AlexDBlack * Header tweaks Signed-off-by: AlexDBlack --- .../pom.xml | 53 ++---- .../server/NearestNeighborsServer.java | 160 +++++++++++------- 
.../server/NearestNeighborTest.java | 11 +- .../src/test/resources/logback.xml | 42 +++++ .../pom.xml | 1 + 5 files changed, 158 insertions(+), 109 deletions(-) create mode 100644 deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml index 66c639f56..1fe0d20ff 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml @@ -1,5 +1,6 @@ - 2.11.12 - 2.11 - @@ -73,29 +69,17 @@ - com.typesafe.play - play-java_2.11 - ${playframework.version} - - - com.google.code.findbugs - jsr305 - - - org.apache.tomcat - tomcat-servlet-api - - - net.jodah - typetools - - + io.vertx + vertx-core + ${vertx.version} + - net.jodah - typetools - ${jodah.typetools.version} + io.vertx + vertx-web + ${vertx.version} + com.mashape.unirest unirest-java @@ -108,25 +92,16 @@ ${project.version} test - - com.typesafe.play - play-json_2.11 - ${playframework.version} - - - com.typesafe.play - play-server_2.11 - ${playframework.version} - com.beust jcommander ${jcommander.version} + - com.typesafe.play - play-netty-server_2.11 - ${playframework.version} + ch.qos.logback + logback-classic + test diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java index a79b57b19..6610e75f9 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java @@ -1,5 +1,6 @@ -/******************************************************************************* +/* ****************************************************************************** * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -19,6 +20,11 @@ package org.deeplearning4j.nearestneighbor.server; import com.beust.jcommander.JCommander; import com.beust.jcommander.Parameter; import com.beust.jcommander.ParameterException; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.vertx.core.AbstractVerticle; +import io.vertx.core.Vertx; +import io.vertx.ext.web.Router; +import io.vertx.ext.web.handler.BodyHandler; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.deeplearning4j.clustering.sptree.DataPoint; @@ -26,6 +32,7 @@ import org.deeplearning4j.clustering.vptree.VPTree; import org.deeplearning4j.clustering.vptree.VPTreeFillSearch; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nearestneighbor.model.*; +import org.deeplearning4j.nn.conf.serde.JsonMappers; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; @@ -33,19 +40,10 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.serde.base64.Nd4jBase64; import org.nd4j.serde.binary.BinarySerde; -import play.BuiltInComponents; -import play.Mode; -import play.libs.Json; -import play.routing.Router; -import play.routing.RoutingDsl; -import play.server.Server; import java.io.File; import java.util.*; -import static play.mvc.Controller.request; -import static play.mvc.Results.*; - /** * A rest server for using an * {@link VPTree} based on loading an ndarray containing @@ -57,22 +55,33 @@ import static play.mvc.Results.*; * @author Adam Gibson */ @Slf4j -public class NearestNeighborsServer { - @Parameter(names = {"--ndarrayPath"}, arity = 1, required = true) - private String ndarrayPath = null; - @Parameter(names = {"--labelsPath"}, arity = 1, required = false) - private String labelsPath = null; - @Parameter(names = {"--nearestNeighborsPort"}, arity = 1) - private int port = 9000; - @Parameter(names = {"--similarityFunction"}, arity = 1) - private String similarityFunction = "euclidean"; - @Parameter(names = {"--invert"}, arity = 1) - private boolean invert = false; +public class NearestNeighborsServer extends AbstractVerticle { - private Server server; + private static class RunArgs { + @Parameter(names = {"--ndarrayPath"}, arity = 1, required = true) + private String ndarrayPath = null; + @Parameter(names = {"--labelsPath"}, arity = 1, required = false) + private String labelsPath = null; + @Parameter(names = {"--nearestNeighborsPort"}, arity = 1) + private int port = 9000; + @Parameter(names = {"--similarityFunction"}, arity = 1) + private String similarityFunction = "euclidean"; + @Parameter(names = {"--invert"}, arity = 1) + private boolean invert = false; + } - public void runMain(String... args) throws Exception { - JCommander jcmdr = new JCommander(this); + private static RunArgs instanceArgs; + private static NearestNeighborsServer instance; + + public NearestNeighborsServer(){ } + + public static NearestNeighborsServer getInstance(){ + return instance; + } + + public static void runMain(String... 
args) { + RunArgs r = new RunArgs(); + JCommander jcmdr = new JCommander(r); try { jcmdr.parse(args); @@ -84,7 +93,7 @@ public class NearestNeighborsServer { //User provides invalid input -> print the usage info jcmdr.usage(); - if (ndarrayPath == null) + if (r.ndarrayPath == null) log.error("Json path parameter is missing (null)"); try { Thread.sleep(500); @@ -93,16 +102,20 @@ public class NearestNeighborsServer { System.exit(1); } + instanceArgs = r; try { - runHelper(); + Vertx vertx = Vertx.vertx(); + vertx.deployVerticle(NearestNeighborsServer.class.getName()); } catch (Throwable t){ log.error("Error in NearestNeighboursServer run method",t); } } - protected void runHelper() throws Exception { + @Override + public void start() throws Exception { + instance = this; - String[] pathArr = ndarrayPath.split(","); + String[] pathArr = instanceArgs.ndarrayPath.split(","); //INDArray[] pointsArr = new INDArray[pathArr.length]; // first of all we reading shapes of saved eariler files int rows = 0; @@ -111,7 +124,7 @@ public class NearestNeighborsServer { DataBuffer shape = BinarySerde.readShapeFromDisk(new File(pathArr[i])); log.info("Loading shape {} of {}; Shape: [{} x {}]", i + 1, pathArr.length, Shape.size(shape, 0), - Shape.size(shape, 1)); + Shape.size(shape, 1)); if (Shape.rank(shape) != 2) throw new DL4JInvalidInputException("NearestNeighborsServer assumes 2D chunks"); @@ -122,12 +135,12 @@ public class NearestNeighborsServer { cols = Shape.size(shape, 1); else if (cols != Shape.size(shape, 1)) throw new DL4JInvalidInputException( - "NearestNeighborsServer requires equal 2D chunks. Got columns mismatch."); + "NearestNeighborsServer requires equal 2D chunks. Got columns mismatch."); } final List labels = new ArrayList<>(); - if (labelsPath != null) { - String[] labelsPathArr = labelsPath.split(","); + if (instanceArgs.labelsPath != null) { + String[] labelsPathArr = instanceArgs.labelsPath.split(","); for (int i = 0; i < labelsPathArr.length; i++) { labels.addAll(FileUtils.readLines(new File(labelsPathArr[i]), "utf-8")); } @@ -149,7 +162,7 @@ public class NearestNeighborsServer { System.gc(); } - VPTree tree = new VPTree(points, similarityFunction, invert); + VPTree tree = new VPTree(points, instanceArgs.similarityFunction, instanceArgs.invert); //Set play secret key, if required //http://www.playframework.com/documentation/latest/ApplicationSecret @@ -163,40 +176,57 @@ public class NearestNeighborsServer { System.setProperty("play.crypto.secret", base64); } + Router r = Router.router(vertx); + r.route().handler(BodyHandler.create()); //NOTE: Setting this is required to receive request body content at all + createRoutes(r, labels, tree, points); - server = Server.forRouter(Mode.PROD, port, b -> createRouter(tree, labels, points, b)); + vertx.createHttpServer() + .requestHandler(r) + .listen(instanceArgs.port); } - protected Router createRouter(VPTree tree, List labels, INDArray points, BuiltInComponents builtInComponents){ - RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents); - //return the host information for a given id - routingDsl.POST("/knn").routingTo(request -> { + private void createRoutes(Router r, List labels, VPTree tree, INDArray points){ + + r.post("/knn").handler(rc -> { try { - NearestNeighborRequest record = Json.fromJson(request.body().asJson(), NearestNeighborRequest.class); + String json = rc.getBodyAsJson().encode(); + NearestNeighborRequest record = JsonMappers.getMapper().readValue(json, NearestNeighborRequest.class); + NearestNeighbor 
nearestNeighbor = NearestNeighbor.builder().points(points).record(record).tree(tree).build(); - if (record == null) - return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed."))); + if (record == null) { + rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code()) + .putHeader("content-type", "application/json") + .end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed."))); + return; + } - NearestNeighborsResults results = - NearestNeighborsResults.builder().results(nearestNeighbor.search()).build(); - - - return ok(Json.toJson(results)); + NearestNeighborsResults results = NearestNeighborsResults.builder().results(nearestNeighbor.search()).build(); + rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code()) + .putHeader("content-type", "application/json") + .end(JsonMappers.getMapper().writeValueAsString(results)); + return; } catch (Throwable e) { log.error("Error in POST /knn",e); e.printStackTrace(); - return internalServerError(e.getMessage()); + rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) + .end("Error parsing request - " + e.getMessage()); + return; } }); - routingDsl.POST("/knnnew").routingTo(request -> { + r.post("/knnnew").handler(rc -> { try { - Base64NDArrayBody record = Json.fromJson(request.body().asJson(), Base64NDArrayBody.class); - if (record == null) - return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed."))); + String json = rc.getBodyAsJson().encode(); + Base64NDArrayBody record = JsonMappers.getMapper().readValue(json, Base64NDArrayBody.class); + if (record == null) { + rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code()) + .putHeader("content-type", "application/json") + .end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed."))); + return; + } INDArray arr = Nd4jBase64.fromBase64(record.getNdarray()); List results; @@ -214,9 +244,10 @@ public class NearestNeighborsServer { } if (results.size() != distances.size()) { - return internalServerError( - String.format("results.size == %d != %d == distances.size", - results.size(), distances.size())); + rc.response() + .setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) + .end(String.format("results.size == %d != %d == distances.size", results.size(), distances.size())); + return; } List nnResult = new ArrayList<>(); @@ -228,30 +259,29 @@ public class NearestNeighborsServer { } NearestNeighborsResults results2 = NearestNeighborsResults.builder().results(nnResult).build(); - return ok(Json.toJson(results2)); - + String j = JsonMappers.getMapper().writeValueAsString(results2); + rc.response() + .putHeader("content-type", "application/json") + .end(j); } catch (Throwable e) { log.error("Error in POST /knnnew",e); e.printStackTrace(); - return internalServerError(e.getMessage()); + rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) + .end("Error parsing request - " + e.getMessage()); + return; } }); - - return routingDsl.build(); } /** * Stop the server */ - public void stop() { - if (server != null) { - log.info("Attempting to stop server"); - server.stop(); - } + public void stop() throws Exception { + super.stop(); } public static void main(String[] args) throws Exception { - new NearestNeighborsServer().runMain(args); + runMain(args); } } diff --git 
a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java index 9f8fd7241..b42c407e5 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java @@ -1,5 +1,6 @@ -/******************************************************************************* +/* ****************************************************************************** * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -50,7 +51,6 @@ public class NearestNeighborTest extends BaseDL4JTest { public TemporaryFolder testDir = new TemporaryFolder(); @Test - //@Ignore("AB 2019/05/21 - Failing - Issue #7657") public void testNearestNeighbor() { double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}}; INDArray arr = Nd4j.create(data); @@ -119,14 +119,15 @@ public class NearestNeighborTest extends BaseDL4JTest { File writeToTmp = testDir.newFile(); writeToTmp.deleteOnExit(); BinarySerde.writeArrayToDisk(rand, writeToTmp); - NearestNeighborsServer server = new NearestNeighborsServer(); - server.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort", + NearestNeighborsServer.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort", String.valueOf(localPort)); + Thread.sleep(3000); + NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:" + localPort); NearestNeighborsResults result = client.knnNew(5, rand.getRow(0)); assertEquals(5, result.getResults().size()); - server.stop(); + NearestNeighborsServer.getInstance().stop(); } diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml new file mode 100644 index 000000000..7953c2712 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml @@ -0,0 +1,42 @@ +<configuration> + +<appender name="FILE" class="ch.qos.logback.core.FileAppender"> +<file>logs/application.log</file> +<encoder> +<pattern>%date - [%level] - from %logger in %thread +%n%message%n%xException%n</pattern> +</encoder> +</appender> + +<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender"> +<encoder> +<pattern>%logger{15} - %message%n%xException{5}</pattern> +</encoder> +</appender> + +<root level="ERROR"> +<appender-ref ref="STDOUT"/> +<appender-ref ref="FILE"/> +</root> + +</configuration> \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml index d820dd6b7..720e7a5f3 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml @@ -1,5 +1,6 @@ - 10.1 + 10.2 7.6 1.5.2 @@ -106,7 +106,7 @@ - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/pom.xml b/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/pom.xml index eb6b8fec7..b1adbe93e 100644 ---
a/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/pom.xml +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/pom.xml @@ -51,7 +51,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/pom.xml b/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/pom.xml index 4df187848..c87a94b37 100644 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/pom.xml +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/pom.xml @@ -46,7 +46,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/pom.xml b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/pom.xml index eb3ca627f..462bebc95 100644 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/pom.xml +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/pom.xml @@ -48,7 +48,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-data/pom.xml b/deeplearning4j/deeplearning4j-data/pom.xml index a4ffc7d2d..ca29f35b7 100644 --- a/deeplearning4j/deeplearning4j-data/pom.xml +++ b/deeplearning4j/deeplearning4j-data/pom.xml @@ -38,7 +38,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-dataimport-solrj/pom.xml b/deeplearning4j/deeplearning4j-dataimport-solrj/pom.xml index c6227f27f..ecc33b57e 100644 --- a/deeplearning4j/deeplearning4j-dataimport-solrj/pom.xml +++ b/deeplearning4j/deeplearning4j-dataimport-solrj/pom.xml @@ -116,11 +116,11 @@ - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 org.nd4j - nd4j-cuda-10.1 + nd4j-cuda-10.2 ${project.version} test diff --git a/deeplearning4j/deeplearning4j-graph/pom.xml b/deeplearning4j/deeplearning4j-graph/pom.xml index 6f3e5f555..9c4b25ac3 100644 --- a/deeplearning4j/deeplearning4j-graph/pom.xml +++ b/deeplearning4j/deeplearning4j-graph/pom.xml @@ -58,7 +58,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/pom.xml b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/pom.xml index 2353a6bdd..37002b5e1 100644 --- a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/pom.xml +++ b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/pom.xml @@ -62,7 +62,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-manifold/pom.xml b/deeplearning4j/deeplearning4j-manifold/pom.xml index 5ac781ff3..921ee9653 100644 --- a/deeplearning4j/deeplearning4j-manifold/pom.xml +++ b/deeplearning4j/deeplearning4j-manifold/pom.xml @@ -41,7 +41,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml b/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml index c550547d4..02ce30a40 100644 --- a/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml +++ b/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml @@ -302,11 +302,11 @@ - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 org.nd4j - nd4j-cuda-10.1 + nd4j-cuda-10.2 ${project.version} test diff --git a/deeplearning4j/deeplearning4j-modelimport/pom.xml b/deeplearning4j/deeplearning4j-modelimport/pom.xml index dec29266f..223aebdaa 100644 --- a/deeplearning4j/deeplearning4j-modelimport/pom.xml +++ 
b/deeplearning4j/deeplearning4j-modelimport/pom.xml @@ -115,11 +115,11 @@ - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 org.nd4j - nd4j-cuda-10.1 + nd4j-cuda-10.2 ${project.version} test diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml index 1fe0d20ff..ab28d78c4 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml @@ -119,11 +119,11 @@ - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 org.nd4j - nd4j-cuda-10.1 + nd4j-cuda-10.2 ${project.version} test diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml index d6b64b025..e3ca20366 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml @@ -54,7 +54,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/pom.xml index 609d48a39..bfd004c41 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/pom.xml +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/pom.xml @@ -53,7 +53,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml index f95f9268d..87bb7e68e 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml @@ -83,11 +83,11 @@ - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 org.nd4j - nd4j-cuda-10.1 + nd4j-cuda-10.2 ${project.version} test diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml index 720e7a5f3..23d5d225d 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml @@ -44,7 +44,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/pom.xml index 35eb2903d..219301c56 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/pom.xml +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/pom.xml @@ -66,7 +66,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/pom.xml index 260de2004..beeb07d34 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/pom.xml +++ 
b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/pom.xml @@ -68,7 +68,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml index aa9eb15b0..e11e9044f 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml @@ -61,7 +61,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/pom.xml index 619d36db6..39eda5e50 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/pom.xml +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/pom.xml @@ -79,7 +79,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml index 3f367689c..da4fd9cba 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml @@ -84,7 +84,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-nlp-parent/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/pom.xml index 838a3bb6e..6c4eea3fd 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/pom.xml +++ b/deeplearning4j/deeplearning4j-nlp-parent/pom.xml @@ -42,7 +42,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-nn/pom.xml b/deeplearning4j/deeplearning4j-nn/pom.xml index 2564330b3..c1ff45a61 100644 --- a/deeplearning4j/deeplearning4j-nn/pom.xml +++ b/deeplearning4j/deeplearning4j-nn/pom.xml @@ -124,7 +124,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/pom.xml b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/pom.xml index a9b7879f0..62f95f736 100644 --- a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/pom.xml +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/pom.xml @@ -93,14 +93,14 @@ - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 false org.nd4j - nd4j-cuda-10.1 + nd4j-cuda-10.2 ${project.version} test diff --git a/deeplearning4j/deeplearning4j-remote/pom.xml b/deeplearning4j/deeplearning4j-remote/pom.xml index c1244455a..4ef2e06dd 100644 --- a/deeplearning4j/deeplearning4j-remote/pom.xml +++ b/deeplearning4j/deeplearning4j-remote/pom.xml @@ -24,7 +24,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml index 94f66b405..969025e50 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml @@ -93,7 +93,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml index f95b9935b..97515cf5e 100644 --- 
a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml @@ -102,11 +102,11 @@ - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 org.nd4j - nd4j-cuda-10.1 + nd4j-cuda-10.2 ${project.version} test diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml index e3b6444bf..08eed7f15 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml @@ -97,7 +97,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-scaleout/pom.xml b/deeplearning4j/deeplearning4j-scaleout/pom.xml index 65ddf17f5..2c192cc10 100644 --- a/deeplearning4j/deeplearning4j-scaleout/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/pom.xml @@ -39,7 +39,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml index 8a19b3b68..42e799b69 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml @@ -67,7 +67,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml index 8d49f4d19..0a92d19ab 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml @@ -72,7 +72,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml index f22f2f6b8..fc1e96ec0 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml @@ -73,7 +73,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml index e70e61939..7b49dfce3 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml @@ -98,7 +98,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml index a24676022..8a4fb02d5 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml @@ -181,7 +181,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/pom.xml index 88f94a5a1..2fe90e566 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/pom.xml +++ 
b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/pom.xml @@ -74,7 +74,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/pom.xml index dd8447a6a..b6e1f2f67 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/pom.xml +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/pom.xml @@ -106,7 +106,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-standalone/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-standalone/pom.xml index 654e34dc1..1b85a1d87 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-standalone/pom.xml +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-standalone/pom.xml @@ -32,7 +32,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml index d829971b4..282867e7e 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml @@ -72,7 +72,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/pom.xml index 5e5ae75c1..3f37266c9 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/pom.xml +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/pom.xml @@ -440,7 +440,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-ui-parent/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/pom.xml index 087261f0c..70c32b984 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/pom.xml +++ b/deeplearning4j/deeplearning4j-ui-parent/pom.xml @@ -49,7 +49,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-util/pom.xml b/deeplearning4j/deeplearning4j-util/pom.xml index 36845c54e..b49239a9e 100644 --- a/deeplearning4j/deeplearning4j-util/pom.xml +++ b/deeplearning4j/deeplearning4j-util/pom.xml @@ -52,7 +52,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/deeplearning4j-zoo/pom.xml b/deeplearning4j/deeplearning4j-zoo/pom.xml index aec30b64e..976d7500b 100644 --- a/deeplearning4j/deeplearning4j-zoo/pom.xml +++ b/deeplearning4j/deeplearning4j-zoo/pom.xml @@ -78,7 +78,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/dl4j-integration-tests/pom.xml b/deeplearning4j/dl4j-integration-tests/pom.xml index d38410fcd..27461c923 100644 --- a/deeplearning4j/dl4j-integration-tests/pom.xml +++ b/deeplearning4j/dl4j-integration-tests/pom.xml @@ -113,7 +113,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 \ No newline at end of file diff --git a/deeplearning4j/dl4j-perf/pom.xml b/deeplearning4j/dl4j-perf/pom.xml index cfd0347be..239eead6b 100644 --- a/deeplearning4j/dl4j-perf/pom.xml +++ b/deeplearning4j/dl4j-perf/pom.xml @@ -122,7 +122,7 @@ test-nd4j-native - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 diff --git a/deeplearning4j/pom.xml b/deeplearning4j/pom.xml index a139c4f44..f9b1eecce 
100644 --- a/deeplearning4j/pom.xml +++ b/deeplearning4j/pom.xml @@ -226,7 +226,7 @@ ${skipBackendChoice} - test-nd4j-native,test-nd4j-cuda-10.1 + test-nd4j-native,test-nd4j-cuda-10.2 false @@ -501,7 +501,7 @@ - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 false @@ -514,7 +514,7 @@ org.nd4j - nd4j-cuda-10.1 + nd4j-cuda-10.2 ${nd4j.version} test diff --git a/docs/deeplearning4j/templates/config-cudnn.md b/docs/deeplearning4j/templates/config-cudnn.md index 5044b3ca0..24f69da87 100644 --- a/docs/deeplearning4j/templates/config-cudnn.md +++ b/docs/deeplearning4j/templates/config-cudnn.md @@ -10,17 +10,8 @@ weight: 3 Deeplearning4j supports CUDA but can be further accelerated with cuDNN. Most 2D CNN layers (such as ConvolutionLayer and SubsamplingLayer), as well as LSTM and BatchNormalization layers, support cuDNN. -The only thing we need to do to have DL4J load cuDNN is to add a dependency on `deeplearning4j-cuda-9.2`, `deeplearning4j-cuda-10.0`, or `deeplearning4j-cuda-10.1`, for example: +The only thing we need to do to have DL4J load cuDNN is to add a dependency on `deeplearning4j-cuda-10.0`, `deeplearning4j-cuda-10.1`, or `deeplearning4j-cuda-10.2`, for example: -```xml -<dependency> - <groupId>org.deeplearning4j</groupId> - <artifactId>deeplearning4j-cuda-9.2</artifactId> - <version>{{page.version}}</version> -</dependency> -``` - -or ```xml <dependency> <groupId>org.deeplearning4j</groupId> <artifactId>deeplearning4j-cuda-10.1</artifactId> <version>{{page.version}}</version> </dependency> ``` @@ -38,6 +29,16 @@ or ``` +or +```xml +<dependency> + <groupId>org.deeplearning4j</groupId> + <artifactId>deeplearning4j-cuda-10.2</artifactId> + <version>{{page.version}}</version> +</dependency> +``` + + The actual library for cuDNN is not bundled, so be sure to download and install the appropriate package for your platform from NVIDIA: * [NVIDIA cuDNN](https://developer.nvidia.com/cudnn) @@ -48,39 +49,20 @@ Note there are multiple combinations of cuDNN and CUDA supported. At this time t CUDA Version | cuDNN Version - 9.2 | 7.2 10.0 | 7.4 10.1 | 7.6 + 10.2 | 7.6 - To install, simply extract the library to a directory found in the system path used by native libraries. The easiest way is to place it alongside other libraries from CUDA in the default directory (`/usr/local/cuda/lib64/` on Linux, `/usr/local/cuda/lib/` on Mac OS X, and `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.2\bin\`, `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.0\bin\`, or `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1\bin\` on Windows). + To install, simply extract the library to a directory found in the system path used by native libraries. The easiest way is to place it alongside other libraries from CUDA in the default directory (`/usr/local/cuda/lib64/` on Linux, `/usr/local/cuda/lib/` on Mac OS X, and `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.0\bin\`, `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1\bin\`, or `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.2\bin\` on Windows). -Alternatively, in the case of CUDA 10.1, cuDNN comes bundled with the "redist" package of the [JavaCPP Presets for CUDA](https://github.com/bytedeco/javacpp-presets/tree/master/cuda). [After agreeing to the license](https://github.com/bytedeco/javacpp-presets/tree/master/cuda#license-agreements), we can add the following dependencies instead of installing CUDA and cuDNN: +Alternatively, in the case of CUDA 10.2, cuDNN comes bundled with the "redist" package of the [JavaCPP Presets for CUDA](https://github.com/bytedeco/javacpp-presets/tree/master/cuda).
[After agreeing to the license](https://github.com/bytedeco/javacpp-presets/tree/master/cuda#license-agreements), we can add the following dependencies instead of installing CUDA and cuDNN: org.bytedeco - cuda - 10.1-7.6-1.5.2 - linux-x86_64-redist - - - org.bytedeco - cuda - 10.1-7.6-1.5.2 - linux-ppc64le-redist - - - org.bytedeco - cuda - 10.1-7.6-1.5.2 - macosx-x86_64-redist - - - org.bytedeco - cuda - 10.1-7.6-1.5.2 - windows-x86_64-redist + cuda-platform-redist + 10.2-7.6-1.5.2 Also note that, by default, Deeplearning4j will use the fastest algorithms available according to cuDNN, but memory usage may be excessive, causing strange launch errors. When this happens, try to reduce memory usage by using the [`NO_WORKSPACE` mode settable via the network configuration](/api/{{page.version}}/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.Builder.html#cudnnAlgoMode-org.deeplearning4j.nn.conf.layers.ConvolutionLayer.AlgoMode-), instead of the default of `ConvolutionLayer.AlgoMode.PREFER_FASTEST`, for example: diff --git a/libnd4j/pom.xml b/libnd4j/pom.xml index 3e766b944..374bc5640 100644 --- a/libnd4j/pom.xml +++ b/libnd4j/pom.xml @@ -64,7 +64,7 @@ - 10.1 + 10.2 7.6 release cpu diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda-platform/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda-platform/pom.xml index c99b496d1..027b49844 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda-platform/pom.xml +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda-platform/pom.xml @@ -22,12 +22,12 @@ 4.0.0 - nd4j-cuda-10.1-platform + nd4j-cuda-10.2-platform nd4j-cuda-platform - 10.1 + 10.2 7.6 1.5.2 nd4j-cuda-${cuda.version} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml index df22d05ab..ec0eab208 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml @@ -22,12 +22,12 @@ 4.0.0 - nd4j-cuda-10.1 + nd4j-cuda-10.2 nd4j-cuda - 10.1 + 10.2 7.6 1.5.2 @@ -95,6 +95,17 @@ nd4j-native-api ${project.version} + + org.bytedeco + cuda + ${cuda.version}-${cudnn.version}-${javacpp-presets.cuda.version} + + + org.bytedeco + cuda + ${cuda.version}-${cudnn.version}-${javacpp-presets.cuda.version} + ${dependency.platform} + ${javacpp.platform} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index cdfb45257..45a20bfbc 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -7627,9 +7627,9 @@ public static final int PREALLOC_SIZE = 33554432; * the given shape info buffer * represents a scalar shape */ - @Namespace("shape") public static native int isScalar(@Cast("Nd4jLong*") LongPointer info); - @Namespace("shape") public static native int isScalar(@Cast("Nd4jLong*") LongBuffer info); - @Namespace("shape") public static native int isScalar(@Cast("Nd4jLong*") long[] info); + @Namespace("shape") public static native int isScalar(@Cast("const Nd4jLong*") LongPointer info); + @Namespace("shape") public static native int isScalar(@Cast("const Nd4jLong*") LongBuffer info); + @Namespace("shape") public static native int isScalar(@Cast("const Nd4jLong*") long[] info); /** * Returns whether @@ -8033,6 +8033,9 @@ public static final int 
PREALLOC_SIZE = 33554432; @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, IntPointer coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, IntBuffer coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, int[] coords); /** * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! */ @@ -9127,6 +9130,8 @@ public static final int PREALLOC_SIZE = 33554432; ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// + diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java index 113960951..8c2109f7c 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java @@ -16,6 +16,10 @@ package org.nd4j.nativeblas; +import java.util.List; +import org.bytedeco.javacpp.ClassProperties; +import org.bytedeco.javacpp.LoadEnabled; +import org.bytedeco.javacpp.Loader; import org.bytedeco.javacpp.annotation.Platform; import org.bytedeco.javacpp.annotation.Properties; import org.bytedeco.javacpp.tools.Info; @@ -111,7 +115,42 @@ import org.bytedeco.javacpp.tools.InfoMapper; @Platform(value = "linux-arm64", preloadpath = {"/usr/aarch64-linux-gnu/lib/", "/usr/lib/aarch64-linux-gnu/"}), @Platform(value = "linux-ppc64", preloadpath = {"/usr/powerpc64-linux-gnu/lib/", "/usr/powerpc64le-linux-gnu/lib/", "/usr/lib/powerpc64-linux-gnu/", "/usr/lib/powerpc64le-linux-gnu/"}), @Platform(value = "windows", preload = {"libwinpthread-1", "libgcc_s_seh-1", "libgomp-1", "libstdc++-6", "libnd4jcpu"}) }) -public class Nd4jCudaPresets implements InfoMapper { +public class Nd4jCudaPresets implements LoadEnabled, InfoMapper { + + @Override public void init(ClassProperties properties) { + String platform = properties.getProperty("platform"); + List preloads = properties.get("platform.preload"); + List resources = properties.get("platform.preloadresource"); + + // Only apply this at load time since we don't want to copy the CUDA libraries here + if (!Loader.isLoadLibraries()) { + return; + } + int i = 0; + String[] libs = {"cudart", "cublasLt", "cublas", "cusolver", "cusparse", "cudnn"}; + for (String lib : libs) { + switch (platform) { + case "linux-arm64": + case "linux-ppc64le": + case "linux-x86_64": + case "macosx-x86_64": + lib += lib.equals("cudnn") ? "@.7" : lib.equals("cudart") ? "@.10.2" : "@.10"; + break; + case "windows-x86_64": + lib += lib.equals("cudnn") ? 
"64_7" : lib.equals("cudart") ? "64_102" : "64_10"; + break; + default: + continue; // no CUDA + } + if (!preloads.contains(lib)) { + preloads.add(i++, lib); + } + } + if (i > 0) { + resources.add("/org/bytedeco/cuda/"); + } + } + @Override public void map(InfoMap infoMap) { infoMap.put(new Info("thread_local", "ND4J_EXPORT", "INLINEDEF", "CUBLASWINAPI", "FORCEINLINE", diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 85c1f92f4..ff1619490 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -7630,9 +7630,9 @@ public static final int PREALLOC_SIZE = 33554432; * the given shape info buffer * represents a scalar shape */ - @Namespace("shape") public static native int isScalar(@Cast("Nd4jLong*") LongPointer info); - @Namespace("shape") public static native int isScalar(@Cast("Nd4jLong*") LongBuffer info); - @Namespace("shape") public static native int isScalar(@Cast("Nd4jLong*") long[] info); + @Namespace("shape") public static native int isScalar(@Cast("const Nd4jLong*") LongPointer info); + @Namespace("shape") public static native int isScalar(@Cast("const Nd4jLong*") LongBuffer info); + @Namespace("shape") public static native int isScalar(@Cast("const Nd4jLong*") long[] info); /** * Returns whether @@ -8036,6 +8036,9 @@ public static final int PREALLOC_SIZE = 33554432; @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, IntPointer coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, IntBuffer coords); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, int[] coords); /** * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! 
*/ @@ -9130,6 +9133,8 @@ public static final int PREALLOC_SIZE = 33554432; ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// + diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml b/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml index 0ad010efd..bc468c874 100644 --- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml +++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml @@ -225,7 +225,7 @@ org.nd4j - nd4j-cuda-10.1 + nd4j-cuda-10.2 ${project.version} diff --git a/nd4j/nd4j-backends/nd4j-tests/pom.xml b/nd4j/nd4j-backends/nd4j-tests/pom.xml index 50fa24bf9..e6861d257 100644 --- a/nd4j/nd4j-backends/nd4j-tests/pom.xml +++ b/nd4j/nd4j-backends/nd4j-tests/pom.xml @@ -208,7 +208,7 @@ org.nd4j - nd4j-cuda-10.1 + nd4j-cuda-10.2 ${project.version} diff --git a/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml b/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml index 9dbdcbf24..93d171535 100644 --- a/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml +++ b/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml @@ -107,7 +107,7 @@ org.nd4j - nd4j-cuda-10.1 + nd4j-cuda-10.2 ${project.version} test diff --git a/nd4j/nd4j-remote/nd4j-json-server/pom.xml b/nd4j/nd4j-remote/nd4j-json-server/pom.xml index 47ef995a9..6c307f71c 100644 --- a/nd4j/nd4j-remote/nd4j-json-server/pom.xml +++ b/nd4j/nd4j-remote/nd4j-json-server/pom.xml @@ -163,7 +163,7 @@ org.nd4j - nd4j-cuda-10.1 + nd4j-cuda-10.2 ${project.version} test diff --git a/nd4j/nd4j-uberjar/pom.xml b/nd4j/nd4j-uberjar/pom.xml index bb61f0f70..c3398dea9 100644 --- a/nd4j/nd4j-uberjar/pom.xml +++ b/nd4j/nd4j-uberjar/pom.xml @@ -284,7 +284,7 @@ org.nd4j - nd4j-cuda-10.1 + nd4j-cuda-10.2 ${project.version} @@ -316,12 +316,12 @@ org.nd4j - nd4j-cuda-10.1 + nd4j-cuda-10.2 ${project.version} org.nd4j - nd4j-cuda-10.1-platform + nd4j-cuda-10.2-platform ${project.version} diff --git a/nd4s/pom.xml b/nd4s/pom.xml index f165cfae9..f10f8cf41 100644 --- a/nd4s/pom.xml +++ b/nd4s/pom.xml @@ -312,11 +312,11 @@ - test-nd4j-cuda-10.1 + test-nd4j-cuda-10.2 org.nd4j - nd4j-cuda-10.1 + nd4j-cuda-10.2 ${project.version} test diff --git a/rl4j/rl4j-core/pom.xml b/rl4j/rl4j-core/pom.xml index 67b050ba1..ebdfec1a6 100644 --- a/rl4j/rl4j-core/pom.xml +++ b/rl4j/rl4j-core/pom.xml @@ -127,7 +127,7 @@ org.nd4j - nd4j-cuda-10.1 + nd4j-cuda-10.2 ${project.version} test diff --git a/scalnet/pom.xml b/scalnet/pom.xml index a6e220280..39408dbf5 100644 --- a/scalnet/pom.xml +++ b/scalnet/pom.xml @@ -293,7 +293,7 @@ org.nd4j - nd4j-cuda-10.1 + nd4j-cuda-10.2 ${project.version} test From d19eeaec5204e78bc53acf2362e1970128cf7239 Mon Sep 17 00:00:00 2001 From: Yurii Shyrma Date: Fri, 29 Nov 2019 13:14:30 +0200 Subject: [PATCH 12/30] Shyrma causal conv1d (#90) * - add causal mode of padding to convolutions Signed-off-by: Yurii * - add additional tests for causal conv1d Signed-off-by: Yurii * - add causal mode for cuda conv kernels Signed-off-by: Yurii * Java side of Conv1D changes Signed-off-by: raver119 * Add Conv1DDerivative op Signed-off-by: Alex Black * Causal Conv1D gradient checks Signed-off-by: Alex Black * Tweaks Signed-off-by: Alex Black * - add causal padding mode to conv2d_bp Signed-off-by: Yurii * More thorough causal conv1d tests Signed-off-by: Alex Black ---
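A rough sketch of how the causal mode described above surfaces on the Java side, via the `Conv1DConfig` and `PaddingMode` classes this patch touches (see the diffstat below). Builder-field and method names here are assumptions inferred from the files in this patch, not verified API:

```java
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode;
import org.nd4j.linalg.factory.Nd4j;

public class CausalConv1dSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        // NCW input [minibatch, channels, width] and weights [kernel, inChannels, outChannels]
        SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, 2, 3, 10));
        SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 2, 3, 4));

        // k/s/d = kernel, stride, dilation; dilation and paddingMode are what this patch adds
        // (field names per Conv1DConfig.java / PaddingMode.java in the diffstat - assumed).
        Conv1DConfig conf = Conv1DConfig.builder()
                .k(2).s(1).d(2)
                .paddingMode(PaddingMode.CAUSAL) // left-pads by (k-1)*d, so width is preserved
                .build();

        SDVariable out = sd.cnn().conv1d(in, w, conf);
        System.out.println(out.eval().shapeInfoToString()); // expected shape: [2, 4, 10]
    }
}
```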
.../declarable/generic/nn/convo/conv1d.cpp | 44 ++-- .../declarable/generic/nn/convo/conv3d.cpp | 24 +- .../include/ops/declarable/headers/convo.h | 32 +-- .../ops/declarable/helpers/convolutions.h | 151 ++++++------ .../declarable/helpers/cpu/convolutions.cpp | 60 +++-- .../declarable/helpers/cuda/convolutions.cu | 54 +++-- .../layers_tests/ConvolutionTests1.cpp | 220 +++++++++++++++++- .../tests_cpu/layers_tests/ParityOpsTests.cpp | 2 +- .../ops/impl/layers/convolution/Conv1D.java | 34 +-- .../layers/convolution/Conv1DDerivative.java | 152 ++++++++++++ .../convolution/config/Conv1DConfig.java | 23 +- .../convolution/config/PaddingMode.java | 24 ++ .../org/nd4j/linalg/util/ConvConfigUtil.java | 4 +- .../java/org/nd4j/nativeblas/Nd4jCpu.java | 28 +-- .../opvalidation/LayerOpValidation.java | 62 ++++- 15 files changed, 695 insertions(+), 219 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1DDerivative.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/PaddingMode.java diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp index 2d82346ff..2800e7185 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp @@ -31,7 +31,7 @@ namespace ops { -CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 4) { +CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) { auto input = INPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW) auto weights = INPUT_VARIABLE(1); // [kW, iC, oC] always @@ -42,8 +42,9 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 4) { int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) width int sW = INT_ARG(1); // strides width int pW = INT_ARG(2); // paddings width - int isSameMode = INT_ARG(3); // 0-VALID, 1-SAME - int isNCW = block.getIArguments()->size() > 4 ? !INT_ARG(4) : 1; // INT_ARG(4): 0-NCW, 1-NWC + int dW = INT_ARG(3); // dilations width + int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME, 2-CAUSAL + int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(5): 0-NCW, 1-NWC const int rank = 3; REQUIRE_TRUE(input->rankOf() == rank, 0, "CUSTOM CONV1D OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf()); @@ -81,7 +82,12 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 4) { auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput); auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] - ConvolutionUtils::conv2d(block, &inputReshaped, &weightsReshaped, bias, &outputReshaped, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW); + nd4j::ops::conv2d conv2d; + const Nd4jStatus status = conv2d.execute({&inputReshaped, &weightsReshaped, bias}, {&outputReshaped}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {}); + if (status != ND4J_STATUS_OK) + return status; + + // ConvolutionUtils::conv2d(block, &inputReshaped, &weightsReshaped, bias, &outputReshaped, 1,kW, 1,sW, 0,pW, 1,dW, paddingMode, isNCW); return Status::OK(); } @@ -96,8 +102,9 @@ DECLARE_SHAPE_FN(conv1d) { int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0)); // filter(kernel) width int sW = INT_ARG(1); // strides width int pW = INT_ARG(2); // paddings width - int isSameMode = INT_ARG(3); // 0-VALID, 1-SAME - int isNCW = block.getIArguments()->size() > 4 ?
!INT_ARG(4) : 1; // INT_ARG(4): 1-NWC, 0-NCW + int dW = INT_ARG(3); // dilations width + int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME + int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW int indIOioC, indIiW, indWoC(2); if(!isNCW) { @@ -122,7 +129,7 @@ DECLARE_SHAPE_FN(conv1d) { REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); int oH, oW; // output height, width - ConvolutionUtils::calcOutSizePool2D(oH,oW, 1,kW, 1,sW, 0,pW, 1,1, 1,iW, isSameMode); + ConvolutionUtils::calcOutSizePool2D(oH,oW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode); Nd4jLong* outputShapeInfo = nullptr; ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); @@ -153,7 +160,7 @@ DECLARE_TYPES(conv1d) { ////////////////////////////////////////////////////////////////////////// -CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 4) { +CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) { auto input = INPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW) auto weights = INPUT_VARIABLE(1); // [kW, iC, oC] always @@ -167,8 +174,9 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 4) { int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) width int sW = INT_ARG(1); // strides width int pW = INT_ARG(2); // paddings width - int isSameMode = INT_ARG(3); // 0-VALID, 1-SAME - int isNCW = block.getIArguments()->size() > 4 ? !INT_ARG(4) : 1; // INT_ARG(4): 1-NWC, 0-NCW + int dW = INT_ARG(3); // dilations width + int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME, 2-CAUSAL + int isNCW = block.getIArguments()->size() > 5 ? 
!INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW const int rank = 3; REQUIRE_TRUE(input->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf()); @@ -188,7 +196,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 4) { const int oC = weights->sizeAt(indWoC); // output channels int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,1, 1,iW, isSameMode); + ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode); std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW})); std::string expectedWeightsShape = ShapeUtils::shapeAsString({kW, iC, oC}); @@ -213,7 +221,12 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 4) { auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] - ConvolutionUtils::conv2dBP(block, &inputReshaped, &weightsReshaped, bias, &gradOReshaped, &gradIReshaped, &gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW); + nd4j::ops::conv2d_bp conv2dBP; + const Nd4jStatus status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {}); + if (status != ND4J_STATUS_OK) + return status; + + // ConvolutionUtils::conv2dBP(block, &inputReshaped, &weightsReshaped, bias, &gradOReshaped, &gradIReshaped, &gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,dW, paddingMode, isNCW); return Status::OK(); } @@ -234,8 +247,9 @@ DECLARE_SHAPE_FN(conv1d_bp) { int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) width int sW = INT_ARG(1); // strides width int pW = INT_ARG(2); // paddings width - int isSameMode = INT_ARG(3); // 0-VALID, 1-SAME - int isNCW = block.getIArguments()->size() > 4 ? !INT_ARG(4) : 1; // INT_ARG(4): 1-NWC, 0-NCW + int dW = INT_ARG(3); // dilations width + int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME + int isNCW = block.getIArguments()->size() > 5 ? 
!INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW int indIOioC, indIiW, indWoC(2); if(!isNCW) { @@ -251,7 +265,7 @@ DECLARE_SHAPE_FN(conv1d_bp) { const int oC = weightsShapeInfo[indWoC+1]; // output channels int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,1, 1,iW, isSameMode); + ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode); std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW})); std::string expectedWeightsShape = ShapeUtils::shapeAsString({kW, iC, oC}); diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp index 38138de0e..98223c5b4 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp @@ -51,20 +51,20 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) { int dD = INT_ARG(9); // dilations depth int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID + int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC}); REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str()); if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - if(isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode); nd4j_debug("MKL-DNN is not used for conv3dnew!\n", 0); @@ -116,10 +116,11 @@ DECLARE_SHAPE_FN(conv3dnew) { int dD = INT_ARG(9); // dilations depth int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID; + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID; int isNCDHW = block.getIArguments()->size() > 13 ? 
!INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW const int rank = 5; + REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo); REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo); @@ -144,7 +145,7 @@ DECLARE_SHAPE_FN(conv3dnew) { REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); int oD, oH, oW; // output depth, height, width - ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); + ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); Nd4jLong* outputShapeInfo = nullptr; ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong); @@ -197,7 +198,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { int dD = INT_ARG(9); // dilations depth int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID int isNDHWC = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; @@ -205,8 +206,9 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { ConvolutionUtils::getSizesAndIndexesConv3d(isNDHWC, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); int trueoD, trueoH, trueoW; // true output depth/height/width - ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); + ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); + REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2})); std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC}); REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); @@ -214,8 +216,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { if(bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - if(isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode); nd4j_debug("MKL-DNN is not used for 
conv3dnew_bp!\n", 0); @@ -285,10 +286,11 @@ DECLARE_SHAPE_FN(conv3dnew_bp) { int dD = INT_ARG(9); // dilations depth int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID int isNDHWC = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW const int rank = 5; + REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo); REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo); REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo); @@ -309,7 +311,7 @@ DECLARE_SHAPE_FN(conv3dnew_bp) { int oC = weightsShapeInfo[indWoC+1]; // output channels int trueoD, trueoH, trueoW; // true output depth/height/width - ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); + ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2})); std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC}); diff --git a/libnd4j/include/ops/declarable/headers/convo.h b/libnd4j/include/ops/declarable/headers/convo.h index bd262a7c1..89824c342 100644 --- a/libnd4j/include/ops/declarable/headers/convo.h +++ b/libnd4j/include/ops/declarable/headers/convo.h @@ -28,28 +28,28 @@ namespace nd4j { /** * 1D temporal convolution implementation - * Expected input: + * Expected input: * x: 3D array * weight: 3D Array * bias: optional vector - * + * * Int args: * 0: kernel * 1: stride * 2: padding */ #if NOT_EXCLUDED(OP_conv1d) - DECLARE_CUSTOM_OP(conv1d, 2, 1, false, 0, 4); - DECLARE_CUSTOM_OP(conv1d_bp, 3, 2, false, 0, 4); + DECLARE_CUSTOM_OP(conv1d, 2, 1, false, 0, 5); + DECLARE_CUSTOM_OP(conv1d_bp, 3, 2, false, 0, 5); #endif /** * 2D convolution implementation - * Expected input: + * Expected input: * x: 4D array * weight: 4D Array * bias: optional vector, length of outputChannels - * + * * IntArgs: * 0: kernel height * 1: kernel width @@ -83,7 +83,7 @@ namespace nd4j { /** * 2D deconvolution implementation - * + * * IntArgs: * 0: kernel height * 1: kernel width @@ -102,7 +102,7 @@ namespace nd4j { /** * 3D deconvolution implementation - * + * * IntArgs: * 0: filter(kernel) depth * 1: filter(kernel) height @@ -190,7 +190,7 @@ namespace nd4j { /** * This op implements im2col algorithm, widely used in convolution neural networks * Input: 4D input expected - * + * * Int args: * 0: kernel height * 1: kernel width @@ -210,7 +210,7 @@ namespace nd4j { /** * This op implements col2im algorithm, widely used in convolution neural networks * Input: 6D input expected (like output of im2col op) - * + * * Int args: * 0: stride height * 1: stride width @@ -227,7 +227,7 @@ namespace nd4j { /** * Expected input: 4D array - * + * * IntArgs: * 0: scale factor for rows (height) * 1: scale factor for columns (width) @@ 
-240,7 +240,7 @@ namespace nd4j { /** * Expected input: 4D array - * + * * IntArgs: * 0: scale factor for depth * 1: scale factor for rows (height) @@ -249,13 +249,13 @@ namespace nd4j { */ #if NOT_EXCLUDED(OP_upsampling3d) DECLARE_CUSTOM_OP(upsampling3d, 1, 1, false, 0, 3); - DECLARE_CUSTOM_OP(upsampling3d_bp, 2, 1, false, 0, 0); + DECLARE_CUSTOM_OP(upsampling3d_bp, 2, 1, false, 0, 0); #endif /** * This op produces a binary matrix w.r.t. the target dimension. * Maximum value within each TAD is replaced with 1, other values are set to 0. - * + * * Int args: * 0: axis */ @@ -265,7 +265,7 @@ namespace nd4j { /** * Dilation2D op - * + * * Int args: * 0: isSameMode */ @@ -295,7 +295,7 @@ namespace nd4j { * Output: * 0 - 4D tensor as input * 1 - 4D tensor with max value indexes - * + * * Int params: * 9 int with 2x4 vectors and 1 bool value */ diff --git a/libnd4j/include/ops/declarable/helpers/convolutions.h b/libnd4j/include/ops/declarable/helpers/convolutions.h index 81695c9ac..65544960a 100644 --- a/libnd4j/include/ops/declarable/helpers/convolutions.h +++ b/libnd4j/include/ops/declarable/helpers/convolutions.h @@ -37,79 +37,93 @@ namespace nd4j { class ConvolutionUtils { public: - static inline void calcOutSizePool2D(int& oH, int& oW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int iH, const int iW, const int isSameMode) { - if(isSameMode > 0) { - oH = (int) math::nd4j_ceil(iH * 1. / sH); - oW = (int) math::nd4j_ceil(iW * 1. / sW); - } - else { + static inline void calcOutSizePool2D(int& oH, int& oW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int iH, const int iW, const int paddingMode) { + + if(paddingMode == 0) { // valid oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1; oW = (iW - (kW + (kW-1)*(dW-1)) + 2*pW)/sW + 1; } + else if (paddingMode == 1) { // same + oH = (int) math::nd4j_ceil(iH * 1. / sH); + oW = (int) math::nd4j_ceil(iW * 1. / sW); + } + else { // causal + oH = (iH - 1) / sH + 1; // 2*pH = (kH-1)*dH + oW = (iW - 1) / sW + 1; + } }
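+ // Worked example for the causal branch above (values chosen for illustration):
+ // iH = 10, kH = 3, dH = 2, sH = 1 -> effective kernel eKH = (kH - 1)*dH + 1 = 5.
+ // Causal mode pads on the left only, by (kH - 1)*dH = 4; the "2*pH = (kH-1)*dH" note
+ // means the 2*pH term of the valid formula is replaced by this one-sided pad, so
+ // oH = (10 + 4 - 5)/1 + 1 = 10 = (iH - 1)/sH + 1: output length equals input length,
+ // and output[t] depends only on inputs at positions <= t.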
/ sW); + + } + else { // causal + oD = (iD - 1) / sD + 1; + oH = (iH - 1) / sH + 1; // 2*pH = (kH-1)*dH + oW = (iW - 1) / sW + 1; } } - static inline void calcPadding2D(int& pH, int& pW, int oH, int oW, int iH, int iW, int kH, int kW, int sH, int sW, int dH, int dW) { - int eKH, eKW; - if (dH == 1 && dW == 1) { - eKH = kH; - eKW = kW; - } else { - eKH = (kH - 1) * dH + 1; - eKW = (kW - 1) * dW + 1; - } + static inline void calcPadding2D(int& pH, int& pW, int oH, int oW, int iH, int iW, int kH, int kW, int sH, int sW, int dH, int dW, const int paddingMode = 1 /* default is same mode*/) { - pH = ((oH - 1) * sH + eKH - iH) / 2; //Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2 - pW = ((oW - 1) * sW + eKW - iW) / 2; + if(paddingMode == 0) // valid + return; + + if(paddingMode == 1) { // same + + const int eKH = (kH - 1) * dH + 1; + const int eKW = (kW - 1) * dW + 1; + + pH = ((oH - 1) * sH + eKH - iH) / 2; //Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2 + pW = ((oW - 1) * sW + eKW - iW) / 2; + } + else { // causal + pH = (kH - 1) * dH; + pW = (kW - 1) * dW; + } } - static inline void calcPadding3D(int& pD, int& pH, int& pW, const int oD, const int oH, const int oW, const int iD, const int iH, const int iW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int dD, const int dH, const int dW) { - int eKD, eKH, eKW; - if (dD == 1 && dH == 1 && dW == 1) { - eKD = kD; - eKH = kH; - eKW = kW; - } else { - eKD = (kD - 1) * dD + 1; - eKH = (kH - 1) * dH + 1; - eKW = (kW - 1) * dW + 1; + static inline void calcPadding3D(int& pD, int& pH, int& pW, const int oD, const int oH, const int oW, const int iD, const int iH, const int iW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int dD, const int dH, const int dW, const int paddingMode = 1 /* default is same mode*/) { + + if(paddingMode == 0) // valid + return; + + if(paddingMode == 1) { // same + + const int eKD = (kD - 1) * dD + 1; + const int eKH = (kH - 1) * dH + 1; + const int eKW = (kW - 1) * dW + 1; + + pD = ((oD - 1) * sD + eKD - iD) / 2; + pH = ((oH - 1) * sH + eKH - iH) / 2; //Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2 + pW = ((oW - 1) * sW + eKW - iW) / 2; + } + else { // causal + pD = (kD - 1) * dD; + pH = (kH - 1) * dH; + pW = (kW - 1) * dW; } - - pD = ((oD - 1) * sD + eKD - iD) / 2; // Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2 - pH = ((oH - 1) * sH + eKH - iH) / 2; - pW = ((oW - 1) * sW + eKW - iW) / 2; - } // calculation of output height and width in 2D deconvolution procedure - static inline void calcOutSizeDeconv2D(int& oH, int& oW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int iH, const int iW, const int isSameMode) { - if (isSameMode) { + static inline void calcOutSizeDeconv2D(int& oH, int& oW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int iH, const int iW, const int paddingMode) { + + if (paddingMode) { oH = sH * iH; oW = sW * iW; } else { - int ekH, ekW; - if (dH == 1 && dW == 1) { - ekH = kH; - ekW = kW; - } else { - ekH = (kH - 1) * dH + 1; - ekW = (kW - 1) * dW + 1; - } + const int ekH = (kH - 1) * dH + 1; + const int ekW = (kW - 1) * dW + 1; oH = sH * (iH - 1) + ekH - 2 * pH; oW = sW * (iW - 1) + ekW - 2 * pW; @@ -117,24 +131,19 @@ namespace nd4j { } // 
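// Worked example for the three padding modes implemented above; the numbers follow
// directly from calcOutSizePool2D and calcPadding2D, and the CAUSAL case matches the
// conv1d_causal_2 test shapes further down (iW = 16, kW = 2, sW = 2, dW = 1, pW = 0):
//   VALID  (paddingMode = 0): oW = (16 - 2 + 0)/2 + 1 = 8, pW is left untouched
//   SAME   (paddingMode = 1): oW = ceil(16/2) = 8, then pW = ((8-1)*2 + 2 - 16)/2 = 0
//   CAUSAL (paddingMode = 2): oW = (16 - 1)/2 + 1 = 8, with pW = (kW-1)*dW = 1 applied
//          on the leading side only, so no output element sees "future" time steps.
// A minimal call sketch (hypothetical values, argument order as declared above):
//   int oH, oW, pH = 0, pW = 0;
//   ConvolutionUtils::calcOutSizePool2D(oH, oW, 2, 2, 2, 2, pH, pW, 1, 1, 16, 16, /*paddingMode*/ 2);
//   ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, 16, 16, 2, 2, 2, 2, 1, 1, /*paddingMode*/ 2);
//   // now oH == oW == 8, and pH == pW == 1 for the causal case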
calculation of output height and width in 3D deconvolution procedure - static inline void calcOutSizeDeconv3D(int& oD, int& oH, int& oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int iD, const int iH, const int iW, const int isSameMode) { - if (isSameMode) { + static inline void calcOutSizeDeconv3D(int& oD, int& oH, int& oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int iD, const int iH, const int iW, const int paddingMode) { + + if (paddingMode) { oD = sD * iD; oH = sH * iH; oW = sW * iW; } else { - int ekD, ekH, ekW; - if (dD == 1 && dH == 1 && dW == 1) { - ekD = kD; - ekH = kH; - ekW = kW; - } - else { - ekD = (kD - 1) * dD + 1; - ekH = (kH - 1) * dH + 1; - ekW = (kW - 1) * dW + 1; - } + + const int ekD = (kD - 1) * dD + 1; + const int ekH = (kH - 1) * dH + 1; + const int ekW = (kW - 1) * dW + 1; + oD = sD * (iD - 1) + ekD - 2 * pD; oH = sH * (iH - 1) + ekH - 2 * pH; oW = sW * (iW - 1) + ekW - 2 * pW; @@ -194,10 +203,10 @@ namespace nd4j { } - // static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int isSameMode, int& pH, int& pW, int& dH, int& dW) { + // static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int paddingMode, int& pH, int& pW, int& dH, int& dW) { // if(kH != 1) { - // if(isSameMode) { + // if(paddingMode) { // pH = (oH - 1) * sH - iH + kH - pH; // dH = dH - 1; // } @@ -205,7 +214,7 @@ namespace nd4j { // dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1); // } // if(kW != 1) { - // if(isSameMode) { + // if(paddingMode) { // pW = (oW - 1) * sW - iW + kW - pW; // dW = dW - 1; // } @@ -214,10 +223,10 @@ namespace nd4j { // } // } - // static inline void calcPaddingAndDilationForConv3DMKL(const int iD, const int iH, const int iW, const int oD, const int oH, const int oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int isSameMode, int& pD, int& pH, int& pW, int& dD, int& dH, int& dW) { + // static inline void calcPaddingAndDilationForConv3DMKL(const int iD, const int iH, const int iW, const int oD, const int oH, const int oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int paddingMode, int& pD, int& pH, int& pW, int& dD, int& dH, int& dW) { // if(kD != 1) { - // if(isSameMode) { + // if(paddingMode) { // pD = (oD - 1) * sD - iD + kD - pD; // dD = dD - 1; // } @@ -225,7 +234,7 @@ namespace nd4j { // dD = (iD + 2*pD - (oD - 1) * sD - kD) / (kD - 1); // } // if(kH != 1) { - // if(isSameMode) { + // if(paddingMode) { // pH = (oH - 1) * sH - iH + kH - pH; // dH = dH - 1; // } @@ -233,7 +242,7 @@ namespace nd4j { // dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1); // } // if(kW != 1) { - // if(isSameMode) { + // if(paddingMode) { // pW = (oW - 1) * sW - iW + kW - pW; // dW = dW - 1; // } @@ -242,19 +251,19 @@ namespace nd4j { // } // } - static void conv2d(nd4j::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int 
isNCHW); + static void conv2d(nd4j::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW); // static void conv2d(nd4j::graph::Context & block, const std::vector& inArrs, NDArray* output, const std::vector& intArgs); // static void conv2dBP(nd4j::graph::Context & block, const std::vector& inArrs, const std::vector& outArrs, const std::vector& intArgs); - static void conv2dBP(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW); + static void conv2dBP(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW); - static void depthwiseConv2d(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW); + static void depthwiseConv2d(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW); - static void depthwiseConv2dBP(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW); + static void depthwiseConv2dBP(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW); - static void sconv2d(nd4j::graph::Context & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW); + static void sconv2d(nd4j::graph::Context & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW); static void vol2col(nd4j::graph::Context & block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp index 0829bcbe6..47938e9fb 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp +++ 
b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp @@ -258,7 +258,7 @@ namespace nd4j { ////////////////////////////////////////////////////////////////////////// template - static void conv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { + static void conv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) // weights [kH, kW, iC, oC] always @@ -273,15 +273,14 @@ namespace nd4j { // pW paddings width // dH dilations height // dW dilations width - // isSameMode 0-VALID, 1-SAME - // isNCHW 1-NCHW, 0-NHWC + // paddingMode 0-VALID, 1-SAME + // isNCHW 1-NCHW, 0-NHWC int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - if(isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); nd4j_debug("MKL-DNN is not used for conv2d!\n", 0); @@ -320,7 +319,7 @@ namespace nd4j { ////////////////////////////////////////////////////////////////////////// template - static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { + static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) // weights [kH, kW, iC, oC] always @@ -339,15 +338,14 @@ namespace nd4j { // pW paddings width // dH dilations height // dW dilations width - // isSameMode 0-VALID, 1-SAME - // isNCHW 0-NHWC, 1-NCHW + // paddingMode 0-VALID, 1-SAME + // isNCHW 0-NHWC, 1-NCHW int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - if(isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); nd4j_debug("MKL-DNN is not used for conv2d_bp!\n", 0); @@ -393,7 +391,7 @@ namespace nd4j { ////////////////////////////////////////////////////////////////////////// template - static void depthwiseConv2d_(nd4j::graph::Context& block, 
const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { + static void depthwiseConv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) // weights [kH, kW, iC, mC] always @@ -408,7 +406,7 @@ // pW paddings width // dH dilations height // dW dilations width - // isSameMode 0-VALID, 1-SAME + // paddingMode 0-VALID, 1-SAME // isNCHW 1-NCHW, 0-NHWC int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width @@ -430,7 +428,7 @@ modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] } - if(isSameMode) // SAME + if(paddingMode == 1) // SAME ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); @@ -449,7 +447,7 @@ ////////////////////////////////////////////////////////////////////////// template - static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { + static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) // weights [kH, kW, iC, mC] always @@ -467,7 +465,7 @@ // pW paddings width // dH dilations height // dW dilations width - // isSameMode 0-VALID, 1-SAME + // paddingMode 0-VALID, 1-SAME // isNCHW 0-NHWC, 1-NCHW int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width @@ -492,7 +490,7 @@ modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] } - if(isSameMode) // SAME + if(paddingMode == 1) // SAME ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); @@ -526,7 +524,7 @@ ////////////////////////////////////////////////////////////////////////// template - static void sconv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { + static void sconv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, 
const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) // weightsDepth [kH, kW, iC, mC] always @@ -542,8 +540,8 @@ namespace nd4j { // pW paddings width // dH dilations height // dW dilations width - // isSameMode 0-VALID, 1-SAME - // isNCHW 1-NCHW, 0-NHWC + // paddingMode 0-VALID, 1-SAME + // isNCHW 1-NCHW, 0-NHWC int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes @@ -555,11 +553,11 @@ namespace nd4j { outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector({bS, oH, oW, iC*mC}) : std::vector({bS, iC*mC, oH, oW}), input->dataType(), input->getContext()); // ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- // - ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); + ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW); // ----- perform pointwise convolution (oH = iH, oW = iW) ----- // if (weightsPoint) { - ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW); // in this case oH=iH, oW=iW + ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW); // in this case oH=iH, oW=iW delete outputDepth; } } @@ -1774,20 +1772,20 @@ namespace nd4j { - void ConvolutionUtils::conv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); + void ConvolutionUtils::conv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); } - void ConvolutionUtils::conv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); + void ConvolutionUtils::conv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, 
(block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); } - void ConvolutionUtils::depthwiseConv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); + void ConvolutionUtils::depthwiseConv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); } - void ConvolutionUtils::depthwiseConv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); + void ConvolutionUtils::depthwiseConv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); } - void ConvolutionUtils::sconv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); + void ConvolutionUtils::sconv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); } void ConvolutionUtils::vol2col(nd4j::graph::Context& block, const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); diff --git 
a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu index 4887b7266..6b86ce302 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu @@ -217,7 +217,7 @@ void ConvolutionUtils::col2vol(nd4j::graph::Context& block, const NDArray& col, ////////////////////////////////////////////////////////////////////////// template -static void conv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { +static void conv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) // weights [kH, kW, iC, oC] always @@ -232,15 +232,14 @@ static void conv2d_(nd4j::graph::Context& block, const NDArray* input, const NDA // pW paddings width // dH dilations height // dW dilations width - // isSameMode 0-VALID, 1-SAME + // paddingMode 0-VALID, 1-SAME // isNCHW 1-NCHW, 0-NHWC int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - if(isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); std::vector permutForOutput; @@ -276,13 +275,13 @@ static void conv2d_(nd4j::graph::Context& block, const NDArray* input, const NDA } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::conv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); +void ConvolutionUtils::conv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// template -static void depthwiseConv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { +static void depthwiseConv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, 
const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) // weights [kH, kW, iC, mC] always @@ -297,7 +296,7 @@ static void depthwiseConv2d_(nd4j::graph::Context& block, const NDArray* input, // pW paddings width // dH dilations height // dW dilations width - // isSameMode 0-VALID, 1-SAME + // paddingMode 0-VALID, 1-SAME // isNCHW 1-NCHW, 0-NHWC int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width @@ -319,7 +318,7 @@ static void depthwiseConv2d_(nd4j::graph::Context& block, const NDArray* input, modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] } - if(isSameMode) // SAME + if(paddingMode == 1) // SAME ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); @@ -337,13 +336,13 @@ static void depthwiseConv2d_(nd4j::graph::Context& block, const NDArray* input, } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::depthwiseConv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); +void ConvolutionUtils::depthwiseConv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// template -static void sconv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { +static void sconv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) // weightsDepth [kH, kW, iC, mC] always @@ -359,7 +358,7 @@ static void sconv2d_(nd4j::graph::Context& block, const NDArray* input, const ND // pW paddings width // dH dilations height // dW dilations width - // isSameMode 0-VALID, 1-SAME + // paddingMode 0-VALID, 1-SAME // isNCHW 1-NCHW, 0-NHWC int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width @@ -372,18 +371,18 @@ static void sconv2d_(nd4j::graph::Context& block, const NDArray* input, const ND outputDepth = new 
NDArray(output->ordering(), !isNCHW ? std::vector({bS, oH, oW, iC*mC}) : std::vector({bS, iC*mC, oH, oW}), input->dataType(), input->getContext()); // ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- // - ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); + ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW); // ----- perform pointwise convolution (oH = iH, oW = iW) ----- // if (weightsPoint) { - ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW); // in this case oH=iH, oW=iW + ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW); // in this case oH=iH, oW=iW delete outputDepth; } } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::sconv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); +void ConvolutionUtils::sconv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// @@ -1177,7 +1176,7 @@ void ConvolutionUtils::pooling3dBP(nd4j::graph::Context& block, const NDArray& i ////////////////////////////////////////////////////////////////////////// template -static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { +static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) // weights [kH, kW, iC, oC] always @@ -1196,15 +1195,14 @@ static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const N // pW paddings width // dH dilations height // dW dilations width - // isSameMode 0-VALID, 1-SAME + // paddingMode 0-VALID, 1-SAME // isNCHW 0-NHWC, 1-NCHW int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes 
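// Note on the change just below: calcPadding2D is now invoked unconditionally rather than
// guarded by "if(isSameMode)". Per its definition in convolutions.h above, it returns
// immediately for VALID (paddingMode == 0), computes the usual symmetric pads for SAME
// (paddingMode == 1), and for CAUSAL (paddingMode == 2) sets the full left-hand pads
// pH = (kH - 1) * dH and pW = (kW - 1) * dW.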
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - if(isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); std::vector gradOaxesForDot; @@ -1247,13 +1245,13 @@ static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const N } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::conv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); +void ConvolutionUtils::conv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// template -static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { +static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) // weights [kH, kW, iC, mC] always @@ -1271,7 +1269,7 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con // pW paddings width // dH dilations height // dW dilations width - // isSameMode 0-VALID, 1-SAME + // paddingMode 0-VALID, 1-SAME // isNCHW 0-NHWC, 1-NCHW int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width @@ -1296,7 +1294,7 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] } - if(isSameMode) // SAME + if(paddingMode == 1) // SAME ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); @@ -1328,8 +1326,8 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::depthwiseConv2dBP(nd4j::graph::Context& block, const NDArray* 
input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); +void ConvolutionUtils::depthwiseConv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES); } diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp index 3b49c735d..bb4fe7b3c 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -31,6 +31,7 @@ #include #include #include +#include #ifdef HAVE_MKLDNN #include @@ -771,7 +772,7 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) { bias.linspace(1); nd4j::ops::conv1d op; - auto result_FF = op.execute({&input, &weights, &bias}, {}, {2, 1, 0, 0}); + auto result_FF = op.execute({&input, &weights, &bias}, {}, {2, 1, 0, 1, 0, 0}); ASSERT_EQ(ND4J_STATUS_OK, result_FF->status()); @@ -785,7 +786,7 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) { auto epsilonNxt = z->dup(); epsilonNxt->linspace(1); - auto result_BP = op_bp.execute({&input, &weights, &bias, epsilonNxt}, {}, {2, 1, 0, 0}); + auto result_BP = op_bp.execute({&input, &weights, &bias, epsilonNxt}, {}, {2, 1, 0, 1, 0, 0}); ASSERT_EQ(ND4J_STATUS_OK, result_BP->status()); auto eps = result_BP->at(0); @@ -813,7 +814,7 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_2) { input.linspace(1); nd4j::ops::conv1d op; - auto result = op.execute({&input, &weights}, {}, {2, 1, 0, 1}); + auto result = op.execute({&input, &weights}, {}, {2, 1, 0, 1, 1,0}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -822,6 +823,219 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_2) { delete result; } +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_1) { + + int bS=2, iW=3, iC=4,oC=3, kW=2, sW=1, pW=0, dW=1; + int oW = (iW-1)/sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1,-2,-3}); + + NDArray expOutput('c', {bS, oW, oC}, {18. , 18. , 18. , 53. , 55.6, 58.2, 89.8, 95.6, 101.4, 102. 
, 106.8, 111.6, 163.4, 175.6, 187.8, 200.2, 215.6, 231.}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + nd4j::ops::conv1d op; + auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_2) { + + int bS=2, iW=16, iC=3,oC=4, kW=2, sW=2, pW=0, dW=1; + int oW = (iW-1)/sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1,-2,-3,-4}); + + NDArray expOutput('c', {bS, oW, oC}, { 10. , 9.6, 9.2, 8.8, 48.9, 51.8, 54.7, 57.6, 88.5, 95. , 101.5, 108. , 128.1, 138.2, 148.3, 158.4, + 167.7, 181.4, 195.1, 208.8, 207.3, 224.6, 241.9, 259.2, 246.9, 267.8, 288.7, 309.6, 286.5, 311. , 335.5, 360. , + 254.8, 268.8, 282.8, 296.8, 365.7, 397.4, 429.1, 460.8, 405.3, 440.6, 475.9, 511.2, 444.9, 483.8, 522.7, 561.6, + 484.5, 527. , 569.5, 612. , 524.1, 570.2, 616.3, 662.4, 563.7, 613.4, 663.1, 712.8, 603.3, 656.6, 709.9, 763.2}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + nd4j::ops::conv1d op; + auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_3) { + + int bS=2, iW=16, iC=3,oC=4, kW=3, sW=3, pW=0, dW=1; + int oW = (iW-1)/sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1,-2,-3,-4}); + + NDArray expOutput('c', {bS, oW, oC}, {17.2, 16.8, 16.4, 16.,145.4, 151.6, 157.8, 164.,283.1, 297.4, 311.7, 326., 420.8, 443.2, 465.6, 488., + 558.5, 589., 619.5, 650.,696.2001, 734.8, 773.4, 812., 434.8, 448.8, 462.8, 476.8, 879.8, 929.2, 978.6, 1028., + 1017.5, 1075., 1132.5, 1190.,1155.2001, 1220.8, 1286.4, 1352.,1292.8999, 1366.6, 1440.3, 1514., 1430.6001, 1512.4, 1594.2, 1676.}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + nd4j::ops::conv1d op; + auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_4) { + + int bS=2, iW=8, iC=3,oC=4, kW=3, sW=1, pW=0, dW=3; + int oW = (iW-1)/sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1,-2,-3,-4}); + + NDArray expOutput('c', {bS, oW, oC}, {17.2, 16.8, 16.4, 16. ,43.3, 43.8, 44.3, 44.8,69.4, 70.8, 72.2, 73.6,106.5, 109.4, 112.3, 115.2,147.9, 152.6, 157.3, 162. ,189.3, 195.8, 202.3, + 208.8,234.5, 243.4, 252.3, 261.2,280.4, 292. , 303.6, 315.2, 226. 
, 232.8, 239.6, 246.4, 252.1, 259.8, 267.5, 275.2,278.2, 286.8, 295.4, 304. ,437.7, + 455. , 472.3, 489.6,479.1, 498.2, 517.3, 536.4,520.5, 541.4, 562.3, 583.2, 601.7, 632.2, 662.7, 693.2, 647.6, 680.8, 714. , 747.2}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + nd4j::ops::conv1d op; + auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_5) { + + int bS=2, iW=8, iC=3,oC=4, kW=3, sW=1, pW=0, dW=3; + int oW = (iW-1)/sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iW}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1,-2,-3,-4}); + + NDArray expOutput('c', {bS, oC, oW}, { 83.7, 92.4, 101.1, 162.1, 175.9, 189.7, 223.4, 238.7,85.4, 94.4, 103.4, 167.4, 181.8, 196.2, 233.2, 249.4,87.1, 96.4, 105.7, 172.7, 187.7, 202.7, 243. , 260.1, + 88.8, 98.4, 108. , 178. , 193.6, 209.2, 252.8, 270.8, 292.5, 301.2, 309.9, 493.3, 507.1, 520.9, 590.6, 605.9, 301.4, 310.4, 319.4, 513. , 527.4, 541.8, 622. , 638.2, + 310.3, 319.6, 328.9, 532.7, 547.7, 562.7, 653.4, 670.5, 319.2, 328.8, 338.4, 552.4, 568. , 583.6, 684.8, 702.8}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + nd4j::ops::conv1d op; + auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_6) { + + int bS=2, iW=16, iC=3,oC=4, kW=3, sW=3, pW=0, dW=1; + int oW = (iW-1)/sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iW}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1,-2,-3,-4}); + + NDArray expOutput('c', {bS, oC, oW}, {159.7,335.3,381.2,427.1,473. ,518.9,163.8,351.4,400. ,448.6,497.2,545.8,167.9,367.5,418.8,470.1,521.4,572.7,172. ,383.6,437.6,491.6,545.6,599.6, + 577.3, 1069.7, 1115.6, 1161.5, 1207.4, 1253.3,595.8, 1129. 
, 1177.6, 1226.2, 1274.8, 1323.4,614.3, 1188.3, 1239.6, 1290.9, 1342.2, 1393.5, + 632.8, 1247.6, 1301.6, 1355.6, 1409.6, 1463.6}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + nd4j::ops::conv1d op; + auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_bp_1) { + + int bS=2, iW=3, iC=4,oC=3, kW=2, sW=1, pW=0, dW=1; + int oW = (iW-1)/sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1,-2,-3}); + NDArray gradO('c', {bS, oW, oC}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + gradO.linspace(-1.5, 0.1); + + const OpArgsHolder argsHolderFF({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); + const OpArgsHolder argsHolderBP({&input, &weights, &bias, &gradO}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); + + nd4j::ops::conv1d opFF; + nd4j::ops::conv1d_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + TEST_F(ConvolutionTests1, Test_Dilation2D_1) { auto input = NDArrayFactory::create('c', {2, 6, 6, 3}); auto weights = NDArrayFactory::create('c', {3, 2, 3}); diff --git a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp index 17f00011c..6d58e6e41 100644 --- a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp @@ -908,7 +908,7 @@ TEST_F(ParityOpsTests, scatterMax_test4) { auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 10, 10, 10, 5, 6, 7, 8}); nd4j::ops::scatter_max op; - auto result = op.execute({&matrix, &idc, &updates}, {}, {}); + auto result = op.execute({&matrix, &idc, &updates}, {}, {true}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java index 852c865f7..577b1177c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java @@ -29,12 +29,11 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode; import org.nd4j.linalg.util.ArrayUtil; import java.lang.reflect.Field; -import java.util.Collections; -import java.util.List; -import java.util.Map; +import java.util.*; /** @@ -79,7 +78,8 @@ public class Conv1D extends DynamicCustomOp { addIArgument(config.getK(), config.getS(), config.getP(), - ArrayUtil.fromBoolean(config.isSameMode()), + config.getD(), + config.getPaddingMode().ordinal(), ArrayUtil.fromBoolean(config.isNWC())); } @@ -95,10 +95,12 @@ public class Conv1D 
extends DynamicCustomOp { public Object getValue(Field property) { if (config == null && !iArguments.isEmpty()) { config = Conv1DConfig.builder() - .s(iArguments.get(0)) - .p(iArguments.get(1)) - .isSameMode(iArguments.get(2) == 1) - .dataFormat(iArguments.get(3) == 1 ? Conv1DConfig.NCW : Conv1DConfig.NWC) + .k(iArguments.get(0)) + .s(iArguments.get(1)) + .p(iArguments.get(2)) + .d(iArguments.get(3)) + .paddingMode(PaddingMode.values()[iArguments.get(4).intValue()]) + .dataFormat(iArguments.get(5) == 1 ? Conv1DConfig.NCW : Conv1DConfig.NWC) .build(); } @@ -125,16 +127,20 @@ public class Conv1D extends DynamicCustomOp { return "conv1d"; } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); - } - @Override public List calculateOutputDataTypes(List inputDataTypes){ int n = args().length; Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(0)); } + + @Override + public List doDiff(List grads){ + List args = new ArrayList<>(); + Collections.addAll(args, args()); + args.add(grads.get(0)); + + Conv1DDerivative gradFn = new Conv1DDerivative(sameDiff, args.toArray(new SDVariable[0]), config); + return Arrays.asList(gradFn.outputVariables()); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1DDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1DDerivative.java new file mode 100644 index 000000000..6bbe36c58 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1DDerivative.java @@ -0,0 +1,152 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.layers.convolution; + +import lombok.Builder; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode; +import org.nd4j.linalg.util.ArrayUtil; + +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + + +/** + * Conv1D Backprop operation + * + * @author Alex Black + */ +@Slf4j +@Getter +@NoArgsConstructor +public class Conv1DDerivative extends DynamicCustomOp { + + protected Conv1DConfig config; + private static final String INVALID_CONFIGURATION = "Invalid Conv1D configuration : s = %s p = %s "; + + public Conv1DDerivative(@NonNull SameDiff sameDiff, + @NonNull SDVariable[] inputs, + @NonNull Conv1DConfig config) { + super(sameDiff, inputs); + initConfig(config); + } + + public Conv1DDerivative(@NonNull SameDiff sd, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, SDVariable gradOut, @NonNull Conv1DConfig config){ + this(sd, wrapFilterNull(input, weights, bias, gradOut), config); + } + + public Conv1DDerivative(INDArray[] inputs, INDArray[] outputs, Conv1DConfig config){ + super(inputs, outputs); + + initConfig(config); + } + + public Conv1DDerivative(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, @NonNull INDArray gradOut, INDArray output, @NonNull Conv1DConfig config){ + this(wrapFilterNull(input, weights, bias, gradOut), wrapOrNull(output), config); + } + + private void initConfig(Conv1DConfig config){ + this.config = config; + Preconditions.checkState(config.getS() >= 1 && config.getP() >= 0, INVALID_CONFIGURATION, config.getS(), config.getP()); + addArgs(); + } + + protected void addArgs() { + if (config == null) + config = Conv1DConfig.builder().build(); + + addIArgument(config.getK(), + config.getS(), + config.getP(), + config.getD(), + config.getPaddingMode().ordinal(), + ArrayUtil.fromBoolean(config.isNWC())); + } + + @Override + public long[] iArgs() { + if (iArguments.size() == 0) + addArgs(); + + return super.iArgs(); + } + + @Override + public Object getValue(Field property) { + if (config == null && !iArguments.isEmpty()) { + config = Conv1DConfig.builder() + .k(iArguments.get(0)) + .s(iArguments.get(1)) + .p(iArguments.get(2)) + .d(iArguments.get(3)) + .paddingMode(PaddingMode.values()[iArguments.get(4).intValue()]) + .dataFormat(iArguments.get(5) == 1 ? 
Conv1DConfig.NCW : Conv1DConfig.NWC) + .build(); + } + + return config.getValue(property); + } + + @Override + public Map<String, Object> propertiesForFunction() { + return config.toProperties(); + } + + @Override + public boolean isConfigProperties() { + return true; + } + + @Override + public String configFieldName() { + return "config"; + } + + @Override + public String opName() { + return "conv1d_bp"; + } + + @Override + public int getNumOutputs(){ + if(args().length == 4){ + return 3; //Includes bias + } else { + return 2; //No bias - only input + weight grads + } + } + + @Override + public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return new ArrayList<>(inputDataTypes.subList(0, inputDataTypes.size()-1)); //All except gradient input variable + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv1DConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv1DConfig.java index f04e27533..196876cb2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv1DConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv1DConfig.java @@ -21,6 +21,7 @@ import java.util.Map; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; +import lombok.NonNull; import org.nd4j.base.Preconditions; import org.nd4j.linalg.util.ConvConfigUtil; @@ -38,15 +39,28 @@ public class Conv1DConfig extends BaseConvolutionConfig { @Builder.Default private long p = 0; // padding @Builder.Default + private long d = 1; // dilation + @Builder.Default private String dataFormat = NCW; - private boolean isSameMode; + private PaddingMode paddingMode; + + public Conv1DConfig(long k, long s, long p, long d, String dataFormat, @NonNull PaddingMode paddingMode) { + this.k = k; + this.s = s; + this.p = p; + this.d = d; + this.dataFormat = dataFormat; + this.paddingMode = paddingMode; + + validate(); + } public Conv1DConfig(long k, long s, long p, String dataFormat, boolean isSameMode) { this.k = k; this.s = s; this.p = p; this.dataFormat = dataFormat; - this.isSameMode = isSameMode; + this.paddingMode = isSameMode ? 
PaddingMode.SAME : PaddingMode.VALID; validate(); } @@ -71,14 +85,15 @@ public class Conv1DConfig extends BaseConvolutionConfig { ret.put("k", k); ret.put("s", s); ret.put("p", p); - ret.put("isSameMode", isSameMode); + ret.put("d", d); + ret.put("isSameMode", paddingMode); ret.put("dataFormat", dataFormat); return ret; } @Override protected void validate() { - ConvConfigUtil.validate1D(k, s, p); + ConvConfigUtil.validate1D(k, s, p, d); Preconditions.checkArgument(dataFormat != null, "Data format can't be null"); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/PaddingMode.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/PaddingMode.java new file mode 100644 index 000000000..21c3f3f8e --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/PaddingMode.java @@ -0,0 +1,24 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.layers.convolution.config; + + +public enum PaddingMode { + VALID, + SAME, + CAUSAL +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/ConvConfigUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/ConvConfigUtil.java index 91b854923..d532dfe92 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/ConvConfigUtil.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/ConvConfigUtil.java @@ -76,11 +76,13 @@ public class ConvConfigUtil { /** * Validate a 1D convolution's Kernel, Stride, and Padding */ - public static void validate1D(long k, long s, long p){ + public static void validate1D(long k, long s, long p, long d){ Preconditions.checkArgument(k != 0, "Kernel can not be 0"); Preconditions.checkArgument(s > 0, "Stride can not be negative or 0, got: %s", s); + Preconditions.checkArgument(d > 0, "Dilation can not be negative or 0, got: %s", d); + + Preconditions.checkArgument(p >= 0, "Padding can not be negative, got: %s", p); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index ff1619490..e9a36d49f 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -14592,11 +14592,11 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * 1D temporal convolution implementation - * Expected input: + * Expected input: * x: 3D array * 
weight: 3D Array * bias: optional vector - * + * * Int args: * 0: kernel * 1: stride @@ -14637,11 +14637,11 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * 2D convolution implementation - * Expected input: + * Expected input: * x: 4D array * weight: 4D Array * bias: optional vector, length of outputChannels - * + * * IntArgs: * 0: kernel height * 1: kernel width @@ -14745,7 +14745,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * 2D deconvolution implementation - * + * * IntArgs: * 0: kernel height * 1: kernel width @@ -14792,7 +14792,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * 3D deconvolution implementation - * + * * IntArgs: * 0: filter(kernel) depth * 1: filter(kernel) height @@ -14992,7 +14992,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * This op implements im2col algorithm, widely used in convolution neural networks * Input: 4D input expected - * + * * Int args: * 0: kernel height * 1: kernel width @@ -15040,7 +15040,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * This op implements col2im algorithm, widely used in convolution neural networks * Input: 6D input expected (like output of im2col op) - * + * * Int args: * 0: stride height * 1: stride width @@ -15071,7 +15071,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * Expected input: 4D array - * + * * IntArgs: * 0: scale factor for rows (height) * 1: scale factor for columns (width) @@ -15112,7 +15112,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * Expected input: 4D array - * + * * IntArgs: * 0: scale factor for depth * 1: scale factor for rows (height) @@ -15149,13 +15149,13 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public upsampling3d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } + } // #endif /** * This op produces binary matrix wrt to target dimension. * Maximum value within each TAD is replaced with 1, other values are set to true. 
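// For orientation, a minimal sketch of driving the causal mode added in this patch through
// the SameDiff API (the names "in"/"w"/"cfg" and all shape values here are hypothetical;
// builder defaults follow the Conv1DConfig changes above, and the real coverage is in
// testConv1dCausal further below):
//
//   SameDiff sd = SameDiff.create();
//   SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, 2, 3, 8)); // NCW: [minibatch, nIn, length]
//   SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 2, 3, 4));   // [kernel, nIn, nOut]
//   Conv1DConfig cfg = Conv1DConfig.builder()
//           .k(2).s(1).d(1)                  // kernel, stride, dilation; p defaults to 0
//           .paddingMode(PaddingMode.CAUSAL) // the enum replacing the old isSameMode boolean
//           .dataFormat(Conv1DConfig.NCW)
//           .build();
//   SDVariable out = sd.cnn().conv1d(in, w, cfg); // causal mode left-pads by (k - 1) * d,
//                                                 // so an output step never sees future inputs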
- * + * * Int args: * 0: axis */ @@ -15179,7 +15179,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * Dilation2D op - * + * * Int args: * 0: isSameMode */ @@ -15307,7 +15307,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * Output: * 0 - 4D tensor as input * 1 - 4D tensor with max value indexes - * + * * Int params: * 9 int with 2x4 vectors and 1 bool value */ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index 92bcc71f9..6f4acd079 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -39,14 +39,7 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*; import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm; import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize; import org.nd4j.linalg.factory.Nd4j; @@ -944,7 +937,7 @@ public class LayerOpValidation extends BaseOpValidation { Conv1DConfig conv1DConfig = Conv1DConfig.builder() .k(k).p(0).s(1) - .isSameMode(false) + .paddingMode(PaddingMode.VALID) .build(); SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig); @@ -960,6 +953,55 @@ public class LayerOpValidation extends BaseOpValidation { assertNull(err); } + @Test + public void testConv1dCausal() { + Nd4j.getRandom().setSeed(12345); + int nIn = 3; + int nOut = 4; + int mb = 2; + + for( int k : new int[]{2, 3}) { + for (int sz : new int[]{3, 4, 5}) { + for (int s : new int[]{1, 2}) { + for (int d : new int[]{1, 2}) { + for (boolean ncw : new boolean[]{true, false}) { + + SameDiff sd = SameDiff.create(); + INDArray wArr = Nd4j.rand(DataType.DOUBLE, k, nIn, nOut); + INDArray inArr = Nd4j.rand(DataType.DOUBLE, (ncw ? new long[]{mb, nIn, sz} : new long[]{mb, sz, nIn})); + INDArray bArr = Nd4j.rand(DataType.DOUBLE, nOut); + + SDVariable in = sd.var("in", inArr); + SDVariable w = sd.var("W", wArr); + SDVariable b = sd.var("b", bArr); + + Conv1DConfig conv1DConfig = Conv1DConfig.builder() + .dataFormat(ncw ? 
Conv1DConfig.NCW : Conv1DConfig.NWC) + .k(k).p(0).s(s).d(d) + .paddingMode(PaddingMode.CAUSAL) + .build(); + + SDVariable out = sd.cnn().conv1d(in, w, b, conv1DConfig); + SDVariable loss = sd.nn().tanh(out).std(true).rename("loss"); + + sd.setLossVariables("loss"); + + String name = "k=" + k + ", sz=" + sz + ", ncw=" + ncw; + + System.out.println(name); + + TestCase tc = new TestCase(sd).testName(name).gradientCheck(true); + String err = OpValidation + .validate(tc); + assertNull(err); + } + } + } + } + } + } + + @Test public void testConv1dForward(){ int nIn = 2; @@ -1254,7 +1296,7 @@ public class LayerOpValidation extends BaseOpValidation { Conv1DConfig conv1DConfig = Conv1DConfig.builder() .k(k).p(-1).s(0) - .isSameMode(false) + .paddingMode(PaddingMode.VALID) .build(); SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig); From dc66a52bc7c6fcda7d37cd9fa1b9f85abaa673db Mon Sep 17 00:00:00 2001 From: shugeo Date: Fri, 29 Nov 2019 15:05:08 +0200 Subject: [PATCH 13/30] [WIP] Shugeo release fixes4 (#91) * Fixed fake_quant_with_min_max_vars op. * Refactored bitcast op. * bad linspace removed Signed-off-by: raver119 * Corrected tests for bitcast op. * Eliminated debug prints. * one fix Signed-off-by: raver119 * one fix Signed-off-by: raver119 * Added a pair of comments. --- libnd4j/include/array/DataBuffer.h | 2 + libnd4j/include/array/cpu/DataBuffer.cpp | 12 +++- libnd4j/include/array/cuda/DataBuffer.cu | 13 +++++ .../declarable/generic/datatypes/bitcast.cpp | 7 ++- .../layers_tests/DeclarableOpsTests15.cpp | 56 ++++++++++++++++++- 5 files changed, 83 insertions(+), 7 deletions(-) diff --git a/libnd4j/include/array/DataBuffer.h b/libnd4j/include/array/DataBuffer.h index 37d575b13..034f16a25 100644 --- a/libnd4j/include/array/DataBuffer.h +++ b/libnd4j/include/array/DataBuffer.h @@ -116,6 +116,8 @@ class ND4J_EXPORT DataBuffer { void setToZeroBuffers(const bool both = false); void copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetOther = 0); + + static void memcpy(const DataBuffer &dst, const DataBuffer &src); }; diff --git a/libnd4j/include/array/cpu/DataBuffer.cpp b/libnd4j/include/array/cpu/DataBuffer.cpp index 5d27bf9a1..d13ca0def 100644 --- a/libnd4j/include/array/cpu/DataBuffer.cpp +++ b/libnd4j/include/array/cpu/DataBuffer.cpp @@ -33,7 +33,6 @@ void DataBuffer::setCountersToZero() { void DataBuffer::copyCounters(const DataBuffer& other) { } - //////////////////////////////////////////////////////////////////////// void DataBuffer::allocateBuffers(const bool allocBoth) { // always allocate primary buffer only (cpu case) @@ -49,7 +48,7 @@ void DataBuffer::copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinByte return; if(other._primaryBuffer != nullptr) - memcpy(static_cast(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast(other._primaryBuffer) + offsetOther * DataTypeUtils::sizeOfElement(other._dataType), sizeToCopyinBytes); + std::memcpy(static_cast(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast(other._primaryBuffer) + offsetOther * DataTypeUtils::sizeOfElement(other._dataType), sizeToCopyinBytes); } //////////////////////////////////////////////////////////////////////// @@ -61,7 +60,7 @@ void DataBuffer::copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinB return; if(hostBuffer != nullptr) - memcpy(static_cast(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast(hostBuffer) + offsetHostBuffer * 
DataTypeUtils::sizeOfElement(_dataType), sizeToCopyinBytes); + std::memcpy(static_cast(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast(hostBuffer) + offsetHostBuffer * DataTypeUtils::sizeOfElement(_dataType), sizeToCopyinBytes); } @@ -100,6 +99,13 @@ void DataBuffer::allocateSpecial() { void DataBuffer::migrate() { } +/////////////////////////////////////////////////////////////////////// +void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) { + if (src._lenInBytes < dst._lenInBytes) + throw std::runtime_error("DataBuffer::memcpy: Source data buffer is smaller than destination"); + + std::memcpy(dst._primaryBuffer, src._primaryBuffer, dst._lenInBytes); +} //////////////////////////////////////////////////////////////////////// void DataBuffer::writePrimary() const { } diff --git a/libnd4j/include/array/cuda/DataBuffer.cu b/libnd4j/include/array/cuda/DataBuffer.cu index e71ed4b49..5cb227e69 100644 --- a/libnd4j/include/array/cuda/DataBuffer.cu +++ b/libnd4j/include/array/cuda/DataBuffer.cu @@ -97,6 +97,19 @@ void DataBuffer::copyCounters(const DataBuffer& other) { _readPrimary.store(other._writeSpecial); _readSpecial.store(other._writePrimary); } +//////////////////////////////////////////////////////////////////////// +void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) { + if (src._lenInBytes < dst._lenInBytes) + throw std::runtime_error("DataBuffer::memcpy: Source data buffer is smaller than destination"); + + if (src.isSpecialActual()) { + cudaMemcpy(dst._specialBuffer, src._specialBuffer, dst.getLenInBytes(), cudaMemcpyDeviceToDevice); + } else if (src.isPrimaryActual()) { + cudaMemcpy(dst._specialBuffer, src._primaryBuffer, dst.getLenInBytes(), cudaMemcpyHostToDevice); + } + + dst.writeSpecial(); +} //////////////////////////////////////////////////////////////////////// void DataBuffer::copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, const Nd4jLong offsetOther) { // copies only to special buffer diff --git a/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp b/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp index 533b4e2f9..24f96f7a7 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp @@ -45,9 +45,10 @@ namespace nd4j { REQUIRE_TRUE(output->isEmpty(), 0, "BITCAST: If input is empty, output array must also be empty."); return Status::OK(); } - // buffers for both input and output should be equals - DataBuffer buf(input->buffer(), input->specialBuffer(), input->lengthOf() * input->sizeOfT(), input->dataType()); - *(output->dataBuffer()) = buf; + + // just memcpy data +// output->dataBuffer()->copyBufferFrom(*input->dataBuffer()); // as variant + DataBuffer::memcpy(*output->dataBuffer(), *input->dataBuffer()); // this is modern approach return Status::OK(); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index 50f8de9f0..fdaa7b549 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -282,6 +282,60 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_4) { } + +TEST_F(DeclarableOpsTests15, Test_BitCast_5) { + auto x = NDArrayFactory::create('c', {4, 4}, { + 0.4922f, 0.2969f, 0.6172f, 0.8906f, + 0.9297f, 0.0859f, 0.2344f, 0.3828f, + 0.5781f, 0.7969f, 0.0391f, 0.1719f, + 0.8359f, 0.9297f, 0.3438f, 0.0938f}); + + 
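// A note on what the expected values below encode: bitcast reinterprets the raw bytes of
// the input rather than converting element values. For a {4, 4} input to come out as a {4}
// array of 64-bit integers, each element must be 2 bytes wide (i.e. a half-precision type
// here), so every row of four elements packs into one int64 -- the LL constants are simply
// those packed bit patterns.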
auto e = NDArrayFactory::create('c', {4}, {4260467851820808160LL, 3900173902914993008LL, 3566895990128523424LL, + 3314989625590692528LL}); + nd4j::ops::bitcast op; + auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), result->status()); + auto res = result->at(0); +// res->printIndexedBuffer("BITCAST5"); + ASSERT_TRUE(e.equalsTo(res)); + delete result; +} + +TEST_F(DeclarableOpsTests15, Test_BitCast_6) { + auto x = NDArrayFactory::create('c', {4, 4}, { + 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 16.f}); + + auto e = NDArrayFactory::create('c', {4}, {4899988963420290048LL, 5188224837230806272LL, 5332342774136064128LL, + 5476460161268730496LL}); + nd4j::ops::bitcast op; + auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), result->status()); + auto res = result->at(0); +// res->printIndexedBuffer("BITCAST6"); + ASSERT_TRUE(e.equalsTo(res)); + delete result; +} +TEST_F(DeclarableOpsTests15, Test_BitCast_7) { + auto x = NDArrayFactory::create('c', {4, 4}, { + 1.1f, 2.2f, 3.3f, 4.4f, + 5.1f, 6.2f, 7.3f, 8.4f, + 9.1f, 10.2f, 11.3f, 12.4f, + 13.f, 14.2f, 15.3f, 16.4f}); + + auto e = NDArrayFactory::create('c', {4}, { + 4928700072476425318LL, 5202580391758873882LL, 5346698272827918477LL, 5483778673873668736LL}); + nd4j::ops::bitcast op; + auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), result->status()); + auto res = result->at(0); +// res->printIndexedBuffer("BITCAST7"); + ASSERT_TRUE(e.equalsTo(res)); + delete result; +} + TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) { auto in = NDArrayFactory::create('c', {4, 8, 64, 64}); auto w = NDArrayFactory::create('c', {2, 2, 8, 2}); @@ -609,4 +663,4 @@ TEST_F(DeclarableOpsTests15, test_empty_decreasing_1) { ASSERT_EQ(Status::OK(), status); ASSERT_EQ(true, z.e(0)); -} \ No newline at end of file +} From 4fb9fa7748b60673edfd2c76beff4fb961d914e4 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Sat, 30 Nov 2019 18:39:32 +1100 Subject: [PATCH 14/30] Add ND4J namespaces (#83) * Add NDValidation Signed-off-by: AlexDBlack * Add bitwise namespace Signed-off-by: AlexDBlack * Math namespace op constructor fixes Signed-off-by: AlexDBlack * Constructor fixes Signed-off-by: AlexDBlack * Add Math namespace Signed-off-by: AlexDBlack * Update NDBitwise Signed-off-by: AlexDBlack * Add random namespaces Signed-off-by: AlexDBlack * Update Signed-off-by: AlexDBlack * NN namespace Signed-off-by: AlexDBlack * Small cleanup Signed-off-by: AlexDBlack --- .../api/ops/impl/broadcast/BiasAdd.java | 4 + .../api/ops/impl/indexaccum/FirstIndex.java | 5 + .../linalg/api/ops/impl/indexaccum/IAMax.java | 5 + .../linalg/api/ops/impl/indexaccum/IAMin.java | 5 + .../api/ops/impl/indexaccum/LastIndex.java | 5 + .../impl/layers/convolution/BatchNorm.java | 9 + .../linalg/api/ops/impl/reduce/Moments.java | 7 + .../api/ops/impl/reduce/NormalizeMoments.java | 6 + .../api/ops/impl/reduce/ZeroFraction.java | 7 +- .../linalg/api/ops/impl/scalar/PRelu.java | 4 + .../api/ops/impl/shape/ConfusionMatrix.java | 35 +- .../nd4j/linalg/api/ops/impl/shape/Cross.java | 7 +- .../nd4j/linalg/api/ops/impl/shape/Diag.java | 9 +- .../linalg/api/ops/impl/shape/DiagPart.java | 4 + .../nd4j/linalg/api/ops/impl/shape/Eye.java | 39 +- .../linalg/api/ops/impl/shape/MergeAvg.java | 5 + .../linalg/api/ops/impl/shape/MergeMax.java | 5 + .../api/ops/impl/transforms/ReluLayer.java | 5 + .../ops/impl/transforms/clip/ClipByNorm.java | 6 +- 
.../ops/impl/transforms/clip/ClipByValue.java | 6 +- .../api/ops/impl/transforms/custom/ATan2.java | 11 +- .../custom/DotProductAttention.java | 14 + .../transforms/custom/IsNonDecreasing.java | 9 +- .../custom/IsStrictlyIncreasing.java | 9 +- .../ops/impl/transforms/custom/LayerNorm.java | 4 + .../impl/transforms/custom/LogSoftMax.java | 5 + .../transforms/custom/MatrixDeterminant.java | 6 + .../impl/transforms/custom/MatrixInverse.java | 6 + .../impl/transforms/custom/MatrixSetDiag.java | 6 + .../custom/MultiHeadDotProductAttention.java | 18 + .../api/ops/impl/transforms/custom/Pow.java | 6 + .../ops/impl/transforms/custom/SoftMax.java | 7 +- .../impl/transforms/custom/Standardize.java | 6 +- .../api/ops/impl/transforms/custom/Trace.java | 6 + .../ops/impl/transforms/custom/XwPlusB.java | 7 + .../gradient/LeakyReLUDerivative.java | 1 + .../gradient/SigmoidDerivative.java | 5 + .../impl/transforms/gradient/SoftmaxBp.java | 4 + .../pairwise/arithmetic/MergeAddOp.java | 5 + .../ops/random/custom/RandomExponential.java | 5 + .../random/impl/BernoulliDistribution.java | 5 + .../ops/random/impl/BinomialDistribution.java | 5 + .../ops/random/impl/GaussianDistribution.java | 5 + .../random/impl/LogNormalDistribution.java | 5 + .../impl/TruncatedNormalDistribution.java | 5 + .../ops/random/impl/UniformDistribution.java | 5 + .../org/nd4j/linalg/factory/NDValidation.java | 236 +++ .../java/org/nd4j/linalg/factory/Nd4j.java | 51 +- .../nd4j/linalg/factory/ops/NDBitwise.java | 211 +++ .../org/nd4j/linalg/factory/ops/NDMath.java | 1324 +++++++++++++++++ .../org/nd4j/linalg/factory/ops/NDNN.java | 522 +++++++ .../org/nd4j/linalg/factory/ops/NDRandom.java | 138 ++ .../nd4j/autodiff/samediff/SameDiffTests.java | 6 +- .../TFGraphs/TFGraphTestAllHelper.java | 1 + .../org/nd4j/linalg/api/TestNamespaces.java | 68 + 55 files changed, 2879 insertions(+), 26 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBitwise.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRandom.java create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNamespaces.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java index 7d5dbf4fc..3487cc216 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java @@ -48,6 +48,10 @@ public class BiasAdd extends DynamicCustomOp { this.nchw = nchw; } + public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, boolean nchw){ + this(input, bias, null, nchw); + } + public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output, boolean nchw){ super(new INDArray[]{input, bias}, wrapOrNull(output)); bArguments.clear(); diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java index d2046140a..8d660eba8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java @@ -54,7 +54,12 @@ public class FirstIndex extends BaseIndexAccumulation { public FirstIndex(INDArray x, @NonNull Condition condition, int... dimension) { + this(x, condition, false, dimension); + } + + public FirstIndex(INDArray x, @NonNull Condition condition, boolean keepDims, int... dimension) { this(x, condition, Nd4j.EPS_THRESHOLD, dimension); + this.keepDims = keepDims; } public FirstIndex(INDArray x, @NonNull Condition condition, double eps, int... dimension) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java index 30f51c56c..4c8465ef7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java @@ -38,7 +38,12 @@ public class IAMax extends BaseIndexAccumulation { public IAMax() {} public IAMax(INDArray x, int... dimensions) { + this(x, false, dimensions); + } + + public IAMax(INDArray x, boolean keepDims, int... dimensions) { this(x, null, dimensions); + this.keepDims = keepDims; } public IAMax(INDArray x, INDArray z, int... dimensions) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java index 5a3e950e1..0a1383a67 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java @@ -41,6 +41,11 @@ public class IAMin extends BaseIndexAccumulation { super(x, dimensions); } + public IAMin(INDArray in, boolean keepDims, int... dimensions){ + super(in, null, dimensions); + this.keepDims = keepDims; + } + public IAMin(INDArray x, INDArray z, int... dimensions) { super(x, z, dimensions); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java index 792547d7c..b29af5042 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java @@ -58,6 +58,11 @@ public class LastIndex extends BaseIndexAccumulation { this(x, condition, Nd4j.EPS_THRESHOLD, dimensions); } + public LastIndex(INDArray x, @NonNull Condition condition, boolean keepDim, int... dimensions) { + this(x, condition, Nd4j.EPS_THRESHOLD, dimensions); + this.keepDims = keepDim; + } + public LastIndex(INDArray x, @NonNull Condition condition, double eps, int... 
dimensions) { super(x,null, dimensions); this.condition = condition; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java index 20ff5918c..e3716bc24 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java @@ -76,6 +76,15 @@ public class BatchNorm extends DynamicCustomOp { addArgs(); } + public BatchNorm(INDArray input, INDArray mean, INDArray variance, INDArray gamma, INDArray beta, double epsilon, int... axis){ + super(wrapFilterNull(input, mean, variance, gamma, beta), null); + this.jaxis = axis; + this.applyBeta = beta != null; + this.applyGamma = gamma != null; + this.epsilon = epsilon; + addArgs(); + } + public void addArgs() { addIArgument(ArrayUtil.fromBoolean(applyGamma)); addIArgument(ArrayUtil.fromBoolean(applyBeta)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java index c7aef3c62..152b93980 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.reduce; import lombok.NoArgsConstructor; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -40,6 +41,12 @@ public class Moments extends DynamicCustomOp { private int[] axes; + public Moments(@NonNull INDArray input, int... 
axes){ + super(new INDArray[]{input}, null); + this.axes = axes; + addArgs(); + } + public Moments(SameDiff sameDiff, SDVariable input) { this(sameDiff, input, null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/NormalizeMoments.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/NormalizeMoments.java index 945cd505d..be33a458d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/NormalizeMoments.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/NormalizeMoments.java @@ -47,6 +47,12 @@ public class NormalizeMoments extends DynamicCustomOp { addArgs(); } + public NormalizeMoments(INDArray counts, INDArray means, INDArray variances, double shift) { + super(null, new INDArray[]{counts, means, variances}, null); + this.shift = shift; + addArgs(); + } + public NormalizeMoments(INDArray counts, INDArray ssSum, INDArray ssSqSum, INDArray outMean, INDArray outVar) { super(null, new INDArray[]{counts, ssSum, ssSqSum}, new INDArray[]{outMean, outVar}, new ArrayList(), new ArrayList()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/ZeroFraction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/ZeroFraction.java index 18009a466..42ecc2f57 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/ZeroFraction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/ZeroFraction.java @@ -17,11 +17,13 @@ package org.nd4j.linalg.api.ops.impl.reduce; import lombok.NoArgsConstructor; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Collections; @@ -36,10 +38,13 @@ import java.util.List; public class ZeroFraction extends DynamicCustomOp { public ZeroFraction(SameDiff sameDiff, SDVariable input) { - super(null, sameDiff, new SDVariable[] {input}, false); } + public ZeroFraction(@NonNull INDArray input){ + super(new INDArray[]{input}, null); + } + @Override public String opName() { return "zero_fraction"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java index 32c07ad96..f9e30be9c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java @@ -45,6 +45,10 @@ public class PRelu extends DynamicCustomOp { addIArgument(sharedAxes); } + public PRelu(@NonNull INDArray x, @NonNull INDArray alpha, @NonNull int... sharedAxes) { + this(x, null, alpha, sharedAxes); + } + public PRelu(@NonNull INDArray x, INDArray z, @NonNull INDArray alpha, @NonNull int... 
sharedAxes) { super(new INDArray[]{x, alpha}, new INDArray[]{z}); this.sharedAxes = sharedAxes; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java index 6275ce210..f21a0d291 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NonNull; import lombok.val; import org.apache.commons.lang3.NotImplementedException; import org.nd4j.autodiff.samediff.SDVariable; @@ -23,6 +24,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -41,6 +43,35 @@ public class ConfusionMatrix extends DynamicCustomOp { public ConfusionMatrix(){ } + public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, @NonNull DataType dataType){ + super(new INDArray[]{labels, predicted}, null); + this.outputType = dataType; + } + + public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, int numClasses){ + this(labels, predicted, numClasses, DEFAULT_DTYPE); + } + + public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, INDArray weights) { + this(labels, predicted, weights, null); + } + + public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, INDArray weights, Integer numClasses) { + this(labels, predicted, weights, numClasses, DEFAULT_DTYPE); + } + + public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, Integer numClasses, @NonNull DataType dataType) { + this(labels, predicted, null, numClasses, dataType); + } + + public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, INDArray weights, Integer numClasses, @NonNull DataType dataType) { + super(wrapFilterNull(labels, predicted, weights), null); + this.outputType = dataType; + if(numClasses != null) { + addIArgument(numClasses); + } + } + public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, DataType dataType){ super(null, sameDiff, new SDVariable[]{labels, pred}); this.outputType = dataType; @@ -57,7 +88,9 @@ public class ConfusionMatrix extends DynamicCustomOp { public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights){ super(null, sameDiff, new SDVariable[]{labels, pred, weights}); - addIArgument(numClasses); + if(numClasses != null) { + addIArgument(numClasses); + } } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Cross.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Cross.java index f45f9aa87..3e94cb126 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Cross.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Cross.java @@ -44,13 +44,16 @@ public class Cross extends 
DynamicCustomOp { public Cross() { } - public Cross(SameDiff sameDiff, SDVariable[] args) { super(null, sameDiff, args, false); } + public Cross(INDArray a, INDArray b){ + this(a,b,null); + } + public Cross(INDArray a, INDArray b, INDArray out){ - super(null, new INDArray[]{a,b}, out == null ? null : new INDArray[]{out}, null, (int[])null); + super(null, new INDArray[]{a,b}, wrapOrNull(out), null, (int[])null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Diag.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Diag.java index b6d08784b..95947a94b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Diag.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Diag.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NonNull; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -44,8 +45,12 @@ public class Diag extends DynamicCustomOp { public Diag() { } - public Diag(INDArray[] inputs, INDArray[] outputs) { - super(null, inputs, outputs); + public Diag(@NonNull INDArray input) { + this(input, null); + } + + public Diag(@NonNull INDArray input, @NonNull INDArray output){ + super(null, new INDArray[]{input}, wrapOrNull(output)); } public Diag(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/DiagPart.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/DiagPart.java index 6b1688602..1d2b93d9a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/DiagPart.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/DiagPart.java @@ -51,6 +51,10 @@ public class DiagPart extends DynamicCustomOp { super(null, sameDiff, args, inPlace); } + public DiagPart(INDArray in){ + this(in, null); + } + public DiagPart(INDArray in, INDArray out){ super(null, in, out, null, null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java index 3472be2de..3a8bb8f15 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java @@ -16,11 +16,14 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.shape.LongShapeDescriptor; +import org.nd4j.shade.guava.base.Preconditions; import java.util.Collections; import java.util.List; @@ -55,6 +58,23 @@ public class Eye extends DynamicCustomOp { public Eye() { } + public Eye(@NonNull INDArray rows){ + this(rows.getInt(0)); + Preconditions.checkArgument(rows.isScalar(), "Rows INDArray must be a scalar"); + } + + public Eye(@NonNull INDArray rows, @NonNull INDArray columns){ + this(rows.getInt(0), columns.getInt(0)); + 
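// (A delegating this(...) call must be the first statement in a Java constructor, which is
// why these scalar checks can only run after rows/columns have already been read out.)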
Preconditions.checkArgument(rows.isScalar(), "Rows INDArray must be a scalar"); + Preconditions.checkArgument(columns.isScalar(), "Columns INDArray must be a scalar"); + } + + public Eye(int rows){ + this.numRows = rows; + this.numCols = rows; + addArgs(); + } + public Eye(SameDiff sameDiff, SDVariable numRows){ super(null, sameDiff, new SDVariable[] {numRows}, false); } @@ -66,10 +86,7 @@ public class Eye extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {numRows, numCols, batch_shape}, false); } public Eye(SameDiff sameDiff, int numRows) { - super(null, sameDiff, new SDVariable[] {}, false); - this.numRows = numRows; - this.numCols = numRows; - addArgs(); + this(sameDiff, numRows, numRows); } public Eye(SameDiff sameDiff, int numRows, int numCols) { @@ -77,13 +94,25 @@ public class Eye extends DynamicCustomOp { } public Eye(SameDiff sameDiff, int numRows, int numCols, DataType dataType) { - super(null, sameDiff, new SDVariable[] {}, false); + this(sameDiff, numRows, numCols, dataType, null); + } + + public Eye(int numRows, int numCols, DataType dataType, int[] batchDimension) { this.numRows = numRows; this.numCols = numCols; + this.batchDimension = batchDimension; this.dataType = dataType; addArgs(); } + public Eye(int numRows, int numCols) { + this(numRows, numCols, DEFAULT_DTYPE); + } + + public Eye(int numRows, int numCols, DataType dataType) { + this(numRows, numCols, dataType, null); + } + public Eye(SameDiff sameDiff, int numRows, int numCols, DataType dataType, int[] batchDimension) { super(null, sameDiff, new SDVariable[] {}, false); this.numRows = numRows; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java index 448ae1d16..b63052eb5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; @@ -36,6 +37,10 @@ import java.util.Map; @Slf4j public class MergeAvg extends DynamicCustomOp { + public MergeAvg(INDArray... inputs){ + super(inputs, null); + } + public MergeAvg(SameDiff sameDiff, SDVariable... 
inputs) { super(null, sameDiff, inputs); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java index 11578b902..6c342200d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java @@ -23,6 +23,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -40,6 +41,10 @@ public class MergeMax extends DynamicCustomOp { super(null, sameDiff, inputs); } + public MergeMax(INDArray... inputs){ + super(inputs, null); + } + public MergeMax(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/ReluLayer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/ReluLayer.java index 6a8f2965d..fcd220004 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/ReluLayer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/ReluLayer.java @@ -17,9 +17,11 @@ package org.nd4j.linalg.api.ops.impl.transforms; import lombok.NoArgsConstructor; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB; import java.util.Collections; @@ -37,7 +39,10 @@ public class ReluLayer extends XwPlusB { public ReluLayer(SameDiff sameDiff, SDVariable input, SDVariable weights, SDVariable bias) { super(sameDiff, input, weights, bias); + } + public ReluLayer(@NonNull INDArray input, @NonNull INDArray weights, @NonNull INDArray bias){ + super(new INDArray[]{input, weights, bias}, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java index 4d6ba3e66..026930e4a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java @@ -49,8 +49,12 @@ public class ClipByNorm extends DynamicCustomOp { addTArgument(clipValue); } + public ClipByNorm(INDArray in, double clipValue, int... dimensions){ + this(in, null, clipValue, dimensions); + } + public ClipByNorm(INDArray in, INDArray out, double clipValue, int... dimensions){ - super(null, new INDArray[]{in}, (out == null ? 
null : new INDArray[]{out}), Collections.singletonList(clipValue), dimensions); + super(null, new INDArray[]{in}, wrapOrNull(out), Collections.singletonList(clipValue), dimensions); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java index 11d3e9004..3927ba2bc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.clip; +import lombok.NonNull; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -38,11 +39,10 @@ public class ClipByValue extends DynamicCustomOp { private double clipValueMin; private double clipValueMax; - public ClipByValue(INDArray[] inputs, INDArray[] outputs, double clipValueMin, double clipValueMax, boolean inPlace) { - super(null, inputs, outputs); + public ClipByValue(@NonNull INDArray input, double clipValueMin, double clipValueMax) { + super(null, new INDArray[]{input}, null); this.clipValueMin = clipValueMin; this.clipValueMax = clipValueMax; - this.inplaceCall = inPlace; addTArgument(clipValueMin, clipValueMax); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java index 8a782acf6..d6230e153 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java @@ -41,13 +41,22 @@ public class ATan2 extends BaseDynamicTransformOp { super(sameDiff, new SDVariable[] {y, x} ,false); } + /** + * Note that the order of x and y match {@link java.lang.Math#atan2(double, double)}, + * and are reversed when compared to OldATan2. + * See {@link Transforms#atan2(org.nd4j.linalg.api.ndarray.INDArray, org.nd4j.linalg.api.ndarray.INDArray)} + */ + public ATan2(INDArray x, INDArray y) { + this(x,y,null); + } + /** * Note that the order of x and y match {@link java.lang.Math#atan2(double, double)}, * and are reversed when compared to OldATan2. 
* See {@link Transforms#atan2(org.nd4j.linalg.api.ndarray.INDArray, org.nd4j.linalg.api.ndarray.INDArray)} */ public ATan2(INDArray x, INDArray y, INDArray z) { - super(new INDArray[]{x, y}, new INDArray[]{ z }); + super(new INDArray[]{x, y}, wrapOrNull(z)); } public ATan2() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DotProductAttention.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DotProductAttention.java index cf72ea7be..d3a5c9676 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DotProductAttention.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DotProductAttention.java @@ -17,10 +17,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; import lombok.NoArgsConstructor; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -49,6 +51,18 @@ public class DotProductAttention extends DynamicCustomOp { addIArgument(withWeights ? 1 : 0); } + public DotProductAttention(@NonNull INDArray queries, @NonNull INDArray keys, @NonNull INDArray values, INDArray mask, boolean scaled){ + this(queries, keys, values, mask, scaled, false); + } + + public DotProductAttention(@NonNull INDArray queries, @NonNull INDArray keys, @NonNull INDArray values, INDArray mask, boolean scaled, boolean withWeights){ + super(wrapFilterNull(queries, keys, values, mask), null); + this.scaled = scaled; + this.withWeights = withWeights; + addIArgument(scaled ? 1 : 0); + addIArgument(withWeights ? 
1 : 0); + } + @Override public String opName() { return "dot_product_attention"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNonDecreasing.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNonDecreasing.java index 08a9b0faf..83cad14a5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNonDecreasing.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNonDecreasing.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -40,8 +41,12 @@ public class IsNonDecreasing extends DynamicCustomOp { super(null, sameDiff, args, inPlace); } - public IsNonDecreasing(INDArray[] inputs, INDArray[] outputs) { - super(null, inputs, outputs); + public IsNonDecreasing(@NonNull INDArray input){ + this(input, null); + } + + public IsNonDecreasing(@NonNull INDArray input, INDArray output) { + super(null, new INDArray[]{input}, wrapOrNull(output)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsStrictlyIncreasing.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsStrictlyIncreasing.java index 02a527cb8..55b866cad 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsStrictlyIncreasing.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsStrictlyIncreasing.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -38,8 +39,12 @@ public class IsStrictlyIncreasing extends DynamicCustomOp { super(null, sameDiff, args, inPlace); } - public IsStrictlyIncreasing( INDArray[] inputs, INDArray[] outputs) { - super(null, inputs, outputs); + public IsStrictlyIncreasing(@NonNull INDArray input){ + this(input, null); + } + + public IsStrictlyIncreasing(@NonNull INDArray input, INDArray output) { + super(null, new INDArray[]{input}, wrapOrNull(output)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java index 7c7c34fc5..0c4990bb2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java @@ -62,6 +62,10 @@ public class LayerNorm extends DynamicCustomOp { setDimensions(dimensions); } + public LayerNorm(@NonNull INDArray input, @NonNull INDArray gain, boolean channelsFirst, int... dimensions) { + this(input, gain, null, channelsFirst, dimensions); + } + public LayerNorm(INDArray input, INDArray gain, INDArray result, boolean channelsFirst, int... 
dimensions) { this(input, gain, null, result, channelsFirst, dimensions); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java index 09a4823e0..86c9d9c0a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java @@ -52,6 +52,11 @@ public class LogSoftMax extends DynamicCustomOp { this(x, x); } + public LogSoftMax(INDArray x, int dimension) { + this(x, null); + this.dimension = dimension; + } + public LogSoftMax(SameDiff sameDiff, SDVariable i_v, int dimension) { this(sameDiff, i_v); this.dimension = dimension; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixDeterminant.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixDeterminant.java index c079b0fc0..67ba9f343 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixDeterminant.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixDeterminant.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; @@ -39,6 +41,10 @@ public class MatrixDeterminant extends DynamicCustomOp { // } + public MatrixDeterminant(@NonNull INDArray input){ + super(new INDArray[]{input}, null); + } + public MatrixDeterminant(SameDiff sameDiff, SDVariable in, boolean inPlace) { super(null, sameDiff, new SDVariable[]{in}, inPlace); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixInverse.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixInverse.java index 0bbe7f25d..4ff0f942b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixInverse.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixInverse.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Collections; @@ -36,6 +38,10 @@ public class MatrixInverse extends DynamicCustomOp { // } + public MatrixInverse(@NonNull INDArray input){ + super(new INDArray[]{input}, null); + } + public MatrixInverse(SameDiff sameDiff, SDVariable in, boolean inPlace) { super(null, sameDiff, new SDVariable[]{in}, inPlace); } diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java index 3d00afe5b..9bbf6c50f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -32,6 +34,10 @@ public class MatrixSetDiag extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{in, diag}, inPlace); } + public MatrixSetDiag(@NonNull INDArray in, @NonNull INDArray diag){ + super(new INDArray[]{in, diag}, null); + } + public MatrixSetDiag(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiHeadDotProductAttention.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiHeadDotProductAttention.java index f55b21263..54167bd8b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiHeadDotProductAttention.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiHeadDotProductAttention.java @@ -17,10 +17,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; import lombok.NoArgsConstructor; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -54,6 +56,22 @@ public class MultiHeadDotProductAttention extends DynamicCustomOp { addIArgument(withWeights ? 1 : 0); } + public MultiHeadDotProductAttention(@NonNull INDArray queries, @NonNull INDArray keys, @NonNull INDArray values, + @NonNull INDArray Wq, @NonNull INDArray Wk, @NonNull INDArray Wv, @NonNull INDArray Wo, + INDArray mask, boolean scaled) { + this(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false); + } + + public MultiHeadDotProductAttention(@NonNull INDArray queries, @NonNull INDArray keys, @NonNull INDArray values, + @NonNull INDArray Wq, @NonNull INDArray Wk, @NonNull INDArray Wv, @NonNull INDArray Wo, + INDArray mask, boolean scaled, boolean withWeights) { + super(wrapFilterNull(queries, keys, values, Wq, Wk, Wv, Wo, mask), null); + this.scaled = scaled; + this.withWeights = withWeights; + addIArgument(scaled ? 1 : 0); + addIArgument(withWeights ? 
1 : 0); + } + @Override public String opName() { return "multi_head_dot_product_attention"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java index 07e72b6b3..df41438fe 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -39,6 +41,10 @@ public class Pow extends DynamicCustomOp { public Pow(){ } + public Pow(@NonNull INDArray x, @NonNull INDArray y){ + super(new INDArray[]{x,y}, null); + } + @Override public String opName(){ return "Pow"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java index bfa1c27c1..d8db2569c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -69,8 +70,12 @@ public class SoftMax extends BaseDynamicTransformOp { addIArgument(dimension); } + public SoftMax(@NonNull INDArray input, int dimension){ + this(input, null, dimension); + } + public SoftMax(INDArray input, INDArray result, int dimension){ - super(new INDArray[]{input}, new INDArray[]{result}); + super(new INDArray[]{input}, wrapOrNull(result)); this.dimension = dimension; addIArgument(dimension); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Standardize.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Standardize.java index 140aef355..467b36a4e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Standardize.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Standardize.java @@ -34,8 +34,12 @@ public class Standardize extends DynamicCustomOp { setDimensions(dimensions); } + public Standardize(INDArray input, int... dimensions){ + this(input, null, dimensions); + } + public Standardize(INDArray input, INDArray result, int... 
dimensions){ - super("standardize", new INDArray[]{input}, new INDArray[]{result}); + super("standardize", new INDArray[]{input},wrapOrNull(result)); setDimensions(dimensions); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java index e43918918..9d61de1c8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; @@ -37,6 +39,10 @@ public class Trace extends DynamicCustomOp { super(null, sd, new SDVariable[]{in}); } + public Trace(@NonNull INDArray in){ + super(wrapOrNull(in), null); + } + public Trace(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/XwPlusB.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/XwPlusB.java index 0c10159e3..563c4a7f6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/XwPlusB.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/XwPlusB.java @@ -46,7 +46,14 @@ public class XwPlusB extends DynamicCustomOp { public XwPlusB(SameDiff sameDiff, SDVariable input, SDVariable weights, SDVariable bias) { super(null, sameDiff, new SDVariable[] {input, weights, bias}, false); + } + public XwPlusB(INDArray input, INDArray weights, INDArray bias) { + super(new INDArray[] {input, weights, bias}, null); + } + + public XwPlusB(INDArray[] inputs, INDArray output){ + super(inputs, wrapOrNull(output)); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUDerivative.java index baaa87e1f..202f7e291 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUDerivative.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.gradient; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SigmoidDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SigmoidDerivative.java index 9c4d478c7..b47d41462 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SigmoidDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SigmoidDerivative.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.gradient; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -35,6 +36,10 @@ public class SigmoidDerivative extends DynamicCustomOp { super(sameDiff, new SDVariable[]{i_v1, i_v2}); } + public SigmoidDerivative(@NonNull INDArray x, @NonNull INDArray y) { + this(x, y, null); + } + public SigmoidDerivative(INDArray x, INDArray y, INDArray z) { super(null, new INDArray[]{x,y}, new INDArray[]{z}, null, (int[])null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftmaxBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftmaxBp.java index dbbdb8dde..37a1d8632 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftmaxBp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftmaxBp.java @@ -42,6 +42,10 @@ public class SoftmaxBp extends DynamicCustomOp { addIArgument(dimension); } + public SoftmaxBp(@NonNull INDArray input, @NonNull INDArray grad, Integer dimension){ + this(input, grad, null, dimension); + } + public SoftmaxBp(@NonNull INDArray input, @NonNull INDArray grad, INDArray output, Integer dimension){ super(new INDArray[]{input, grad}, wrapOrNull(output)); if(dimension != null) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java index c2d15df19..f64bfe902 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -41,6 +42,10 @@ public class MergeAddOp extends BaseDynamicTransformOp { super(sameDiff, args, inPlace); } + public MergeAddOp(@NonNull INDArray... 
inputs){ + this(inputs, null); + } + public MergeAddOp(INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomExponential.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomExponential.java index e67c362fc..5b9faa005 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomExponential.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomExponential.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Collections; import java.util.List; @@ -47,6 +48,10 @@ public class RandomExponential extends DynamicCustomOp { addTArgument(lambda); } + public RandomExponential(double lambda, DataType datatype, long... shape){ + this(Nd4j.createFromArray(shape), Nd4j.createUninitialized(datatype, shape), lambda); + } + public RandomExponential(INDArray shape,INDArray out, double lambda){ super(null, new INDArray[]{shape}, new INDArray[]{out}, Collections.singletonList(lambda), (List)null); this.lambda = lambda; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BernoulliDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BernoulliDistribution.java index dec04f11f..0ffc8e72e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BernoulliDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BernoulliDistribution.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.BaseRandomOp; import org.nd4j.linalg.exception.ND4JIllegalStateException; +import org.nd4j.linalg.factory.Nd4j; import java.util.Collections; import java.util.LinkedHashMap; @@ -49,6 +50,10 @@ public class BernoulliDistribution extends BaseRandomOp { super(); } + public BernoulliDistribution(double p, DataType datatype, long... 
shape){ + this(Nd4j.createUninitialized(datatype, shape), p); + } + /** * This op fills Z with bernoulli trial results, so 0, or 1, depending by common probability * @param z diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java index 69a5460f2..41bf909cc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java @@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.BaseRandomOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Collections; import java.util.LinkedHashMap; @@ -46,6 +47,10 @@ public class BinomialDistribution extends BaseRandomOp { this.extraArgs = new Object[] {(double) this.trials, this.probability}; } + public BinomialDistribution(int trials, double probability, DataType dt, long[] shape){ + this(Nd4j.createUninitialized(dt, shape), trials, probability); + } + public BinomialDistribution() { super(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java index 42f6def0f..0bb41b655 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java @@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.BaseRandomOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Collections; import java.util.LinkedHashMap; @@ -50,6 +51,10 @@ public class GaussianDistribution extends BaseRandomOp { super(); } + public GaussianDistribution(double mean, double stddev, DataType datatype, long... 
shape){ + this(Nd4j.createUninitialized(datatype, shape), mean, stddev); + } + /** * This op fills Z with random values within stddev..mean..stddev boundaries * @param z diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java index 4a0b36b32..b42e311a7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java @@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.BaseRandomOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Collections; import java.util.LinkedHashMap; @@ -50,6 +51,10 @@ public class LogNormalDistribution extends BaseRandomOp { this.extraArgs = new Object[] {this.mean, this.stddev}; } + public LogNormalDistribution(double mean, double stddev, DataType datatype, long... shape){ + this(Nd4j.createUninitialized(datatype, shape), mean, stddev); + } + /** * This op fills Z with random values within stddev..mean..stddev boundaries * @param z diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java index bd453fe0a..24e52a532 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java @@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.BaseRandomOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Collections; import java.util.List; @@ -48,6 +49,10 @@ public class TruncatedNormalDistribution extends BaseRandomOp { this.extraArgs = new Object[] {this.mean, this.stddev}; } + public TruncatedNormalDistribution(double mean, double stddev, DataType datatype, long... 
shape){ + this(Nd4j.createUninitialized(datatype, shape), mean, stddev); + } + /** * This op fills Z with random values within stddev..mean..stddev boundaries * @param z diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/UniformDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/UniformDistribution.java index 2b4adfc1a..408af9ce2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/UniformDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/UniformDistribution.java @@ -24,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.BaseRandomOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Collections; import java.util.List; @@ -46,6 +47,10 @@ public class UniformDistribution extends BaseRandomOp { this.extraArgs = new Object[] {this.from, this.to}; } + public UniformDistribution(double min, double max, DataType datatype, long... shape){ + this(Nd4j.createUninitialized(datatype, shape), min, max); + } + /** * This op fills Z with random values within from...to boundaries * @param z diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java new file mode 100644 index 000000000..f60726c36 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java @@ -0,0 +1,236 @@ +/* ***************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.factory; + +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; + +import java.util.Arrays; + +public class NDValidation { + + private NDValidation() { + } + + /** + * Validate that the operation is being applied on a numerical INDArray (not boolean or utf8). + * Some operations (such as sum, norm2, add(Number) etc) don't make sense when applied to boolean/utf8 arrays + * + * @param opName Operation name to print in the exception + * @param v Variable to perform operation on + */ + public static void validateNumerical(String opName, INDArray v) { + if (v == null) + return; + if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8) + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to array with non-numerical data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on numerical INDArrays (not boolean or utf8). 
+ * Some operations (such as sum, norm2, add(Number) etc) don't make sense when applied to boolean/utf8 arrays + * + * @param opName Operation name to print in the exception + * @param v Variable to perform operation on + */ + public static void validateNumerical(String opName, INDArray[] v) { + if (v == null) + return; + for (int i = 0; i < v.length; i++) { + if (v[i].dataType() == DataType.BOOL || v[i].dataType() == DataType.UTF8) + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to input array " + i + " with non-numerical data type " + v[i].dataType()); + } + } + + /** + * Validate that the operation is being applied on a numerical INDArray (not boolean or utf8). + * Some operations (such as sum, norm2, add(Number) etc) don't make sense when applied to boolean/utf8 arrays + * + * @param opName Operation name to print in the exception + * @param inputName Name of the input to the op to validate + * @param v Variable to validate datatype for (input to operation) + */ + public static void validateNumerical(String opName, String inputName, INDArray v) { + if (v == null) + return; + if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8) + throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be a numerical type;" + + " got array with non-numerical data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on numerical INDArrays (not boolean or utf8). 
+ * Some operations (such as sum, norm2, add(Number) etc) don't make sense when applied to boolean/utf8 arrays + * + * @param opName Operation name to print in the exception + * @param v1 Variable to validate datatype for (input to operation) + * @param v2 Variable to validate datatype for (input to operation) + */ + public static void validateNumerical(String opName, INDArray v1, INDArray v2) { + if (v1.dataType() == DataType.BOOL || v1.dataType() == DataType.UTF8 || v2.dataType() == DataType.BOOL || v2.dataType() == DataType.UTF8) + throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on arrays if one or both variables" + + " are non-numerical: got " + v1.dataType() + " and " + v2.dataType()); + } + + /** + * Validate that the operation is being applied on an integer type INDArray + * + * @param opName Operation name to print in the exception + * @param v Variable to validate datatype for (input to operation) + */ + public static void validateInteger(String opName, INDArray v) { + if (v == null) + return; + if (!v.dataType().isIntType()) + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to array with non-integer data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on an integer type INDArray + * + * @param opName Operation name to print in the exception + * @param inputName Name of the input to the op to validate + * @param v Variable to validate datatype for (input to operation) + */ + public static void validateInteger(String opName, String inputName, INDArray v) { + if (v == null) + return; + if (!v.dataType().isIntType()) + throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer" + + " type; got array with non-integer data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on a floating point type INDArray + * + * @param opName Operation name to print in the exception + * @param v Variable to validate datatype for (input to operation) + */ + public static void validateFloatingPoint(String opName, INDArray v) { + if (v == null) + return; + if (!v.dataType().isFPType()) + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to array with non-floating point data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on a floating point type INDArray + * + * @param opName Operation name to print in the exception + * @param inputName Name of the input to the op to validate + * @param v Variable to validate datatype for (input to operation) + */ + public static void validateFloatingPoint(String opName, String inputName, INDArray v) { + if (v == null) + return; + if (!v.dataType().isFPType()) + throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + + "\" must be a floating point type; got array with non-floating point data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on a boolean type INDArray + * + * @param opName Operation name to print in the exception + * @param v Variable to validate datatype for (input to operation) + */ + public static void validateBool(String opName, INDArray v) { + if (v == null) + return; + if (v.dataType() != DataType.BOOL) + throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to array with non-boolean data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on a boolean type INDArray + * + * @param 
opName Operation name to print in the exception + * @param inputName Name of the input to the op to validate + * @param v Variable to validate datatype for (input to operation) + */ + public static void validateBool(String opName, String inputName, INDArray v) { + if (v == null) + return; + if (v.dataType() != DataType.BOOL) + throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + + "\" must be a boolean variable; got array with non-boolean data type " + v.dataType()); + } + + /** + * Validate that the operation is being applied on boolean INDArrays + * + * @param opName Operation name to print in the exception + * @param v1 Variable to validate datatype for (input to operation) + * @param v2 Variable to validate datatype for (input to operation) + */ + public static void validateBool(String opName, INDArray v1, INDArray v2) { + if (v1.dataType() != DataType.BOOL || v2.dataType() != DataType.BOOL) + throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on arrays if one or both variables are non-boolean: " + + v1.dataType() + " and " + v2.dataType()); + } + + /** + * Validate that the operation is being applied on arrays with the exact same datatypes (which may optionally be + * restricted to numerical INDArrays only (not boolean or utf8)) + * + * @param opName Operation name to print in the exception + * @param numericalOnly If true, the variables must all be the same type, and must be numerical (not boolean/utf8) + * @param vars Variables to perform operation on + */ + public static void validateSameType(String opName, boolean numericalOnly, INDArray... vars) { + if (vars.length == 0) + return; + if (vars.length == 1) { + if (numericalOnly) { + validateNumerical(opName, vars[0]); + } + } else { + DataType first = vars[0].dataType(); + if (numericalOnly) + validateNumerical(opName, vars[0]); + for (int i = 1; i < vars.length; i++) { + if (first != vars[i].dataType()) { + DataType[] dtypes = new DataType[vars.length]; + for (int j = 0; j < vars.length; j++) { + dtypes[j] = vars[j].dataType(); + } + throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on arrays with different datatypes:" + + " Got arrays with datatypes " + Arrays.toString(dtypes)); + } + } + } + } + + public static boolean isSameType(INDArray x, INDArray y) { + return x.dataType() == y.dataType(); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 5e62dd198..2e2efadda 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -16,6 +16,10 @@ package org.nd4j.linalg.factory; +import org.nd4j.linalg.factory.ops.NDBitwise; +import org.nd4j.linalg.factory.ops.NDMath; +import org.nd4j.linalg.factory.ops.NDNN; +import org.nd4j.linalg.factory.ops.NDRandom; import org.nd4j.shade.guava.primitives.Ints; import org.nd4j.shade.guava.primitives.Longs; import lombok.NonNull; @@ -114,6 +118,51 @@ import java.util.logging.Logger; */ public class Nd4j { + /** + * Bitwise namespace - operations related to bitwise manipulation of arrays + */ + public static final NDBitwise bitwise = new NDBitwise(); + /** + * Math namespace - general mathematical operations + */ + public static final NDMath math = new NDMath(); + /** + * Random namespace - (pseudo) random number 
generation methods + */ + public static final NDRandom random = new NDRandom(); + /** + * Neural network namespace - operations related to neural networks + */ + public static final NDNN nn = new NDNN(); + + /** + * Bitwise namespace - operations related to bitwise manipulation of arrays + */ + public static NDBitwise bitwise() { + return bitwise; + } + + /** + * Math namespace - general mathematical operations + */ + public static NDMath math() { + return math; + } + + /** + * Random namespace - (pseudo) random number generation methods + */ + public static NDRandom random() { + return random; + } + + /** + * Neural network namespace - operations related to neural networks + */ + public static NDNN nn() { + return nn; + } + private final static String DATA_BUFFER_OPS = "databufferfactory"; private final static String CONVOLUTION_OPS = "convops"; /**@deprecated Use {@link ND4JSystemProperties#DTYPE}*/ @@ -2638,7 +2687,7 @@ public class Nd4j { INDArray ret; if(x.isVectorOrScalar() || x.isRowVector() || x.isColumnVector()) { ret = Nd4j.create(x.dataType(), x.length(), x.length()); - Nd4j.getExecutioner().execAndReturn(new Diag(new INDArray[]{x},new INDArray[]{ret})); + Nd4j.getExecutioner().execAndReturn(new Diag(x, ret)); } else { ret = Nd4j.createUninitialized(x.dataType(), Math.min(x.size(0), x.size(1))); Nd4j.getExecutioner().execAndReturn(new DiagPart(x,ret)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBitwise.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBitwise.java new file mode 100644 index 000000000..f77d5c823 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBitwise.java @@ -0,0 +1,211 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.linalg.factory.ops; + +import static org.nd4j.linalg.factory.NDValidation.isSameType; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.NDValidation; +import org.nd4j.linalg.factory.Nd4j; + +public class NDBitwise { + public NDBitwise() { + } + + /** + * Bitwise AND operation. Supports broadcasting.
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * Must have broadcastable shapes: isBroadcastableShapes(x, y)
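+ * A minimal usage sketch via the new namespace accessor (hypothetical values; both inputs must be an integer type):
+ * <pre>{@code
+ * INDArray x = Nd4j.createFromArray(0b0110, 0b1010);
+ * INDArray y = Nd4j.createFromArray(0b0011, 0b0110);
+ * INDArray out = Nd4j.bitwise().and(x, y); // [0b0010, 0b0010]
+ * }</pre>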
+ * + * @param x First input array (INT type) + * @param y Second input array (INT type) + * @return output Bitwise AND array (INT type) + */ + public INDArray and(INDArray x, INDArray y) { + NDValidation.validateInteger("and", "x", x); + NDValidation.validateInteger("and", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(x, y))[0]; + } + + /** + * Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)
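+ * A hypothetical sketch of a single left rotation (note the shift amount is itself an INT array):
+ * <pre>{@code
+ * INDArray x = Nd4j.createFromArray(1);     // ...0001
+ * INDArray shift = Nd4j.createFromArray(2);
+ * INDArray out = Nd4j.bitwise().bitRotl(x, shift); // ...0100 = 4
+ * }</pre>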
+ * + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public INDArray bitRotl(INDArray x, INDArray shift) { + NDValidation.validateInteger("bitRotl", "x", x); + NDValidation.validateInteger("bitRotl", "shift", shift); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(x, shift))[0]; + } + + /** + * Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)
+ * + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public INDArray bitRotr(INDArray x, INDArray shift) { + NDValidation.validateInteger("bitRotr", "x", x); + NDValidation.validateInteger("bitRotr", "shift", shift); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(x, shift))[0]; + } + + /** + * Shift integer bits to the left, i.e. var << 4
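+ * For example (hypothetical values): {@code bitShift(1, 4) -> 16}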
+ * + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public INDArray bitShift(INDArray x, INDArray shift) { + NDValidation.validateInteger("bitShift", "x", x); + NDValidation.validateInteger("bitShift", "shift", shift); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(x, shift))[0]; + } + + /** + * Shift integer bits to the right, i.e. var >> 4
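+ * For example (hypothetical values): {@code bitShiftRight(16, 2) -> 4}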
+ * + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public INDArray bitShiftRight(INDArray x, INDArray shift) { + NDValidation.validateInteger("bitShiftRight", "x", x); + NDValidation.validateInteger("bitShiftRight", "shift", shift); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(x, shift))[0]; + } + + /** + * Bitwise Hamming distance reduction over all elements of both input arrays.
+ * For example, if x=01100000 and y=10100000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
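+ * A hypothetical sketch matching the example above:
+ * <pre>{@code
+ * INDArray x = Nd4j.createFromArray(0b01100000);
+ * INDArray y = Nd4j.createFromArray(0b10100000);
+ * INDArray dist = Nd4j.bitwise().bitsHammingDistance(x, y); // 2
+ * }</pre>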
+ * + * @param x First input array. (INT type) + * @param y Second input array. (INT type) + * @return output bitwise Hamming distance (INT type) + */ + public INDArray bitsHammingDistance(INDArray x, INDArray y) { + NDValidation.validateInteger("bitsHammingDistance", "x", x); + NDValidation.validateInteger("bitsHammingDistance", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(x, y))[0]; + } + + /** + * Bitwise left shift operation. Supports broadcasting.
+ * + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise shifted input x (INT type) + */ + public INDArray leftShift(INDArray x, INDArray y) { + NDValidation.validateInteger("leftShift", "x", x); + NDValidation.validateInteger("leftShift", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(x, y))[0]; + } + + /** + * Bitwise left cyclical shift operation. Supports broadcasting.
+ * Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":
+ * {@code leftShiftCyclic(01110000, 2) -> 11000001}
+ * + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise cyclic shifted input x (INT type) + */ + public INDArray leftShiftCyclic(INDArray x, INDArray y) { + NDValidation.validateInteger("leftShiftCyclic", "x", x); + NDValidation.validateInteger("leftShiftCyclic", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(x, y))[0]; + } + + /** + * Bitwise OR operation. Supports broadcasting.
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * Must have broadcastable shapes: isBroadcastableShapes(x, y)
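+ * For example (hypothetical values): {@code or(0b0110, 0b0011) -> 0b0111}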
+ * + * @param x First input array (INT type) + * @param y Second input array (INT type) + * @return output Bitwise OR array (INT type) + */ + public INDArray or(INDArray x, INDArray y) { + NDValidation.validateInteger("or", "x", x); + NDValidation.validateInteger("or", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(x, y))[0]; + } + + /** + * Bitwise right shift operation. Supports broadcasting.
+ * + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise shifted input x (INT type) + */ + public INDArray rightShift(INDArray x, INDArray y) { + NDValidation.validateInteger("rightShift", "x", x); + NDValidation.validateInteger("rightShift", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(x, y))[0]; + } + + /** + * Bitwise right cyclical shift operation. Supports broadcasting.
+ * Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":
+ * {@code rightShiftCyclic(00001110, 2) -> 10000011}
+ * + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise cyclic shifted input x (INT type) + */ + public INDArray rightShiftCyclic(INDArray x, INDArray y) { + NDValidation.validateInteger("rightShiftCyclic", "x", x); + NDValidation.validateInteger("rightShiftCyclic", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(x, y))[0]; + } + + /** + * Bitwise XOR operation (exclusive OR). Supports broadcasting.
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * Must have broadcastable shapes: isBroadcastableShapes(x, y)
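+ * For example (hypothetical values): {@code xor(0b0110, 0b0011) -> 0b0101}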
+ * + * @param x First input array (INT type) + * @param y Second input array (INT type) + * @return output Bitwise XOR array (INT type) + */ + public INDArray xor(INDArray x, INDArray y) { + NDValidation.validateInteger("xor", "x", x); + NDValidation.validateInteger("xor", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(x, y))[0]; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java new file mode 100644 index 000000000..8e194fcd4 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java @@ -0,0 +1,1324 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.linalg.factory.ops; + +import static org.nd4j.linalg.factory.NDValidation.isSameType; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.NDValidation; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.conditions.Condition; + +public class NDMath { + public NDMath() { + } + + /** + * Elementwise absolute value operation: out = abs(x)
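+ * A minimal sketch of the new math namespace API (hypothetical values):
+ * <pre>{@code
+ * INDArray x = Nd4j.createFromArray(-1.5, 0.0, 2.5);
+ * INDArray out = Nd4j.math().abs(x); // [1.5, 0.0, 2.5]
+ * }</pre>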
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray abs(INDArray x) { + NDValidation.validateNumerical("abs", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Abs(x)); + } + + /** + * Elementwise acos (arccosine, inverse cosine) operation: out = arccos(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray acos(INDArray x) { + NDValidation.validateNumerical("acos", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ACos(x)); + } + + /** + * Elementwise acosh (inverse hyperbolic cosine) function: out = acosh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray acosh(INDArray x) { + NDValidation.validateNumerical("acosh", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh(x)); + } + + /** + * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray amax(INDArray in, int... dimensions) { + NDValidation.validateNumerical("amax", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.AMax(in, dimensions)); + } + + /** + * Absolute mean array reduction operation, optionally along specified dimensions: out = mean(abs(x))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray amean(INDArray in, int... dimensions) { + NDValidation.validateNumerical("amean", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.AMean(in, dimensions)); + } + + /** + * Absolute min array reduction operation, optionally along specified dimensions: out = min(abs(x))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray amin(INDArray in, int... dimensions) { + NDValidation.validateNumerical("amin", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.AMin(in, dimensions)); + } + + /** + * Boolean AND operation: elementwise (x != 0) && (y != 0)
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
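+ * For example (hypothetical values): {@code and([1, 0, 1], [1, 1, 0]) -> [1, 0, 0]}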
+ * + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + */ + public INDArray and(INDArray x, INDArray y) { + NDValidation.validateBool("and", "x", x); + NDValidation.validateBool("and", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And(x, y)); + } + + /** + * Elementwise asin (arcsin, inverse sine) operation: out = arcsin(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray asin(INDArray x) { + NDValidation.validateNumerical("asin", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ASin(x)); + } + + /** + * Elementwise asinh (inverse hyperbolic sine) function: out = asinh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray asinh(INDArray x) { + NDValidation.validateNumerical("asinh", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh(x)); + } + + /** + * Absolute sum array reduction operation, optionally along specified dimensions: out = sum(abs(x))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray asum(INDArray in, int... dimensions) { + NDValidation.validateNumerical("asum", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.ASum(in, dimensions)); + } + + /** + * Elementwise atan (arctangent, inverse tangent) operation: out = arctangent(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray atan(INDArray x) { + NDValidation.validateNumerical("atan", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ATan(x)); + } + + /** + * Elementwise atan (arctangent, inverse tangent) operation: out = atan2(x,y).
+ * Similar to atan(y/x) but the signs of x and y are used to determine the quadrant of the result
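+ * For example (a hypothetical value pair): {@code atan2(y=1, x=-1)} gives 3*pi/4, whereas plain {@code atan(1/-1)} would give -pi/4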
+ * + * @param y Input Y variable (NUMERIC type) + * @param x Input X variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray atan2(INDArray y, INDArray x) { + NDValidation.validateNumerical("atan2", "y", y); + NDValidation.validateNumerical("atan2", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2(y, x))[0]; + } + + /** + * Elementwise atanh (inverse hyperbolic tangent) function: out = atanh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray atanh(INDArray x) { + NDValidation.validateNumerical("atanh", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh(x)); + } + + /** + * Element-wise ceiling function: out = ceil(x).
+ * Rounds each value up to the nearest integer value (if not already an integer)
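+ * For example (hypothetical values): {@code ceil([1.2, 2.0, -1.7]) -> [2.0, 2.0, -1.0]}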
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray ceil(INDArray x) { + NDValidation.validateNumerical("ceil", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Ceil(x)); + } + + /** + * Clipping by L2 norm, optionally along dimension(s)
+ * if l2Norm(x,dimension) < clipValue, then input is returned unmodified
+ * Otherwise, out[i] = in[i] * clipValue / l2Norm(in, dimensions) where each value is clipped according
+ * to the corresponding l2Norm along the specified dimensions
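+ * For example (hypothetical values): {@code clipByNorm([3, 4], 1.0)} returns [0.6, 0.8], since l2Norm([3, 4]) = 5 and each value is scaled by clipValue/5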
+ * + * @param x Input variable (NUMERIC type) + * @param clipValue Clipping value (maximum l2 norm) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray clipByNorm(INDArray x, double clipValue, int... dimensions) { + NDValidation.validateNumerical("clipByNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm(x, clipValue, dimensions))[0]; + } + + /** + * Element-wise clipping function:
+ * out[i] = in[i] if in[i] >= clipValueMin and in[i] <= clipValueMax
+ * out[i] = clipValueMin if in[i] < clipValueMin
+ * out[i] = clipValueMax if in[i] > clipValueMax
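+ * For example (hypothetical values): {@code clipByValue([-2.0, 0.5, 3.0], 0.0, 1.0) -> [0.0, 0.5, 1.0]}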
+ * + * @param x Input variable (NUMERIC type) + * @param clipValueMin Minimum value for clipping + * @param clipValueMax Maximum value for clipping + * @return output Output variable (NUMERIC type) + */ + public INDArray clipByValue(INDArray x, double clipValueMin, double clipValueMax) { + NDValidation.validateNumerical("clipByValue", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue(x, clipValueMin, clipValueMax))[0]; + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
+ * For example, if labels = [0, 1, 1] and predicted = [0, 2, 1] then output is:
+ * [1, 0, 0]
+ * [0, 1, 1]
+ * [0, 0, 0]
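+ * A hypothetical call producing the example above:
+ * <pre>{@code
+ * INDArray labels = Nd4j.createFromArray(0, 1, 1);
+ * INDArray pred = Nd4j.createFromArray(0, 2, 1);
+ * INDArray cm = Nd4j.math().confusionMatrix(labels, pred, DataType.INT);
+ * }</pre>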
+ * + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param dataType Data type + * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public INDArray confusionMatrix(INDArray labels, INDArray pred, DataType dataType) { + NDValidation.validateNumerical("confusionMatrix", "labels", labels); + NDValidation.validateNumerical("confusionMatrix", "pred", pred); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(labels, pred, dataType))[0]; + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values.
+ * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], and numClasses=4 then output is:
+ * [1, 0, 0, 0]
+ * [0, 1, 1, 0]
+ * [0, 0, 0, 0]
+ * [0, 0, 0, 0]
+ * + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param numClasses Number of classes + * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public INDArray confusionMatrix(INDArray labels, INDArray pred, int numClasses) { + NDValidation.validateNumerical("confusionMatrix", "labels", labels); + NDValidation.validateNumerical("confusionMatrix", "pred", pred); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(labels, pred, numClasses))[0]; + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
+ * For example, if labels = [0, 1, 1], predicted = [0, 2, 1] and weights = [1, 2, 3] then output is:
+ * [1, 0, 0]
+ * [0, 3, 2]
+ * [0, 0, 0]
+ * + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of each prediction. Must be same length as both labels and predictions arrays (NUMERIC type) + * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public INDArray confusionMatrix(INDArray labels, INDArray pred, INDArray weights) { + NDValidation.validateNumerical("confusionMatrix", "labels", labels); + NDValidation.validateNumerical("confusionMatrix", "pred", pred); + NDValidation.validateNumerical("confusionMatrix", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(labels, pred, weights))[0]; + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values.
+ * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], numClasses = 4, and weights = [1, 2, 3], then output is:
+ * [1, 0, 0, 0]
+ * [0, 3, 2, 0]
+ * [0, 0, 0, 0]
+ * [0, 0, 0, 0]
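+ *
+ * A usage sketch for the weighted form (instance name {@code math} assumed for illustration):
+ * <pre>
+ * {@code INDArray labels = Nd4j.createFromArray(0, 1, 1);
+ * INDArray pred = Nd4j.createFromArray(0, 2, 1);
+ * INDArray weights = Nd4j.createFromArray(1.0, 2.0, 3.0);
+ * // each prediction contributes its weight rather than 1:
+ * INDArray cm = math.confusionMatrix(labels, pred, weights, 4); // 4x4 matrix as above}
+ * </pre>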
+ *
+ * @param labels Labels - 1D array of integer values representing label values (NUMERIC type)
+ * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type)
+ * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of each prediction. Must be same length as both labels and predictions arrays (NUMERIC type)
+ * @param numClasses Number of classes
+ * @return output Output variable (2D, shape [numClasses, numClasses]) (NUMERIC type)
+ */
+ public INDArray confusionMatrix(INDArray labels, INDArray pred, INDArray weights,
+ int numClasses) {
+ NDValidation.validateNumerical("confusionMatrix", "labels", labels);
+ NDValidation.validateNumerical("confusionMatrix", "pred", pred);
+ NDValidation.validateNumerical("confusionMatrix", "weights", weights);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(labels, pred, weights, numClasses))[0];
+ }
+
+ /**
+ * Elementwise cosine operation: out = cos(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray cos(INDArray x) { + NDValidation.validateNumerical("cos", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Cos(x)); + } + + /** + * Elementwise cosh (hyperbolic cosine) operation: out = cosh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray cosh(INDArray x) { + NDValidation.validateNumerical("cosh", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Cosh(x)); + } + + /** + * Cosine distance reduction operation. The output contains the cosine distance for each
+ * tensor/subset along the specified dimensions:
+ * out = 1.0 - cosineSimilarity(x,y)
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate cosineDistance over (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray cosineDistance(INDArray x, INDArray y, int... dimensions) { + NDValidation.validateNumerical("cosineDistance", "x", x); + NDValidation.validateNumerical("cosineDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance(x, y, dimensions)); + } + + /** + * Cosine similarity pairwise reduction operation. The output contains the cosine similarity for each tensor/subset
+ * along the specified dimensions:
+ * out = (sum_i x[i] * y[i]) / ( sqrt(sum_i x[i]^2) * sqrt(sum_i y[i]^2) )
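+ *
+ * As a worked sketch (instance name {@code math} assumed): for x = [1, 0] and y = [1, 1], the full-array
+ * reduction gives 1 / sqrt(2):
+ * <pre>
+ * {@code INDArray x = Nd4j.createFromArray(1.0, 0.0);
+ * INDArray y = Nd4j.createFromArray(1.0, 1.0);
+ * INDArray sim = math.cosineSimilarity(x, y, 0); // scalar, approximately 0.7071}
+ * </pre>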
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate cosineSimilarity over (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray cosineSimilarity(INDArray x, INDArray y, int... dimensions) { + NDValidation.validateNumerical("cosineSimilarity", "x", x); + NDValidation.validateNumerical("cosineSimilarity", "y", y); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity(x, y, dimensions)); + } + + /** + * Count non zero array reduction operation, optionally along specified dimensions: out = count(x != 0)
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray countNonZero(INDArray in, int... dimensions) { + NDValidation.validateNumerical("countNonZero", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero(in, dimensions)); + } + + /** + * Count zero array reduction operation, optionally along specified dimensions: out = count(x == 0)
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray countZero(INDArray in, int... dimensions) { + NDValidation.validateNumerical("countZero", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero(in, dimensions)); + } + + /** + * Returns the pair-wise cross product of equal size arrays a and b: a x b = ||a||x||b|| sin(theta).
+ * Can take rank 1 or above inputs (of equal shapes), but note that the last dimension must have size 3
+ * + * @param a First input (NUMERIC type) + * @param b Second input (NUMERIC type) + * @return output Element-wise cross product (NUMERIC type) + */ + public INDArray cross(INDArray a, INDArray b) { + NDValidation.validateNumerical("cross", "a", a); + NDValidation.validateNumerical("cross", "b", b); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Cross(a, b))[0]; + } + + /** + * Element-wise cube function: out = x^3
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray cube(INDArray x) { + NDValidation.validateNumerical("cube", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Cube(x)); + } + + /** + * Returns an output variable with diagonal values equal to the specified values; off-diagonal values will be set to 0
+ * For example, if input = [1,2,3], then output is given by:
+ * [ 1, 0, 0]
+ * [ 0, 2, 0]
+ * [ 0, 0, 3]
+ *
+ * Higher input ranks are also supported: if input has shape [a,...,R-1] then output[i,...,k,i,...,k] = input[i,...,k].
+ * i.e., for input rank R, output has rank 2R
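+ *
+ * A usage sketch (instance name {@code math} assumed for illustration):
+ * <pre>
+ * {@code INDArray d = math.diag(Nd4j.createFromArray(1.0, 2.0, 3.0)); // the 3x3 matrix shown above}
+ * </pre>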
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray diag(INDArray x) { + NDValidation.validateNumerical("diag", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Diag(x))[0]; + } + + /** + * Extract the diagonal part from the input array.
+ * If input is
+ * [ 1, 0, 0]
+ * [ 0, 2, 0]
+ * [ 0, 0, 3]
+ * then output is [1, 2, 3].
+ * Supports higher dimensions: in general, out[i,...,k] = in[i,...,k,i,...,k]
+ * + * @param x Input variable (NUMERIC type) + * @return output Diagonal part of the input (NUMERIC type) + */ + public INDArray diagPart(INDArray x) { + NDValidation.validateNumerical("diagPart", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.DiagPart(x))[0]; + } + + /** + * Entropy reduction: -sum(x * log(x))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray entropy(INDArray in, int... dimensions) { + NDValidation.validateNumerical("entropy", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy(in, dimensions)); + } + + /** + * Element-wise Gaussian error function - out = erf(in)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray erf(INDArray x) { + NDValidation.validateNumerical("erf", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Erf(x)); + } + + /** + * Element-wise complementary Gaussian error function - out = erfc(in) = 1 - erf(in)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray erfc(INDArray x) { + NDValidation.validateNumerical("erfc", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc(x)); + } + + /** + * Euclidean distance (l2 norm, l2 distance) reduction operation. The output contains the Euclidean distance for each
+ * tensor/subset along the specified dimensions:
+ * out = sqrt( sum_i (x[i] - y[i])^2 )
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate euclideanDistance over (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray euclideanDistance(INDArray x, INDArray y, int... dimensions) { + NDValidation.validateNumerical("euclideanDistance", "x", x); + NDValidation.validateNumerical("euclideanDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance(x, y, dimensions)); + } + + /** + * Elementwise exponent function: out = exp(x) = 2.71828...^x
+ *
+ * @param x Input variable (NUMERIC type)
+ * @return output Output variable (NUMERIC type)
+ */
+ public INDArray exp(INDArray x) {
+ NDValidation.validateNumerical("exp", "x", x);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Exp(x));
+ }
+
+ /**
+ * Elementwise exponential minus one function: out = exp(x) - 1 = 2.71828...^x - 1
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray expm1(INDArray x) { + NDValidation.validateNumerical("expm1", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1(x)); + } + + /** + * Generate an identity matrix with the specified number of rows and columns.
+ *
+ * @param rows Number of rows
+ * @return output Identity matrix (NUMERIC type)
+ */
+ public INDArray eye(int rows) {
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Eye(rows))[0];
+ }
+
+ /**
+ * As per eye(int, int, DataType) but with the default datatype, Eye.DEFAULT_DTYPE
+ * + * @param rows Number of rows + * @param cols Number of columns + * @return output (NUMERIC type) + */ + public INDArray eye(int rows, int cols) { + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Eye(rows, cols))[0]; + } + + /** + * Generate an identity matrix with the specified number of rows and columns
+ * Example:
+ * <pre>
+ * {@code INDArray eye = eye(3,2)
+ * eye:
+ * [ 1, 0]
+ * [ 0, 1]
+ * [ 0, 0]}
+ * </pre>
+ *
+ * @param rows Number of rows
+ * @param cols Number of columns
+ * @param dataType Data type
+ * @return output Identity matrix (NUMERIC type)
+ */
+ public INDArray eye(int rows, int cols, DataType dataType) {
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Eye(rows, cols, dataType))[0];
+ }
+
+ /**
+ * As per eye(int, int) but with the number of rows/columns specified as scalar INDArrays
+ *
+ * @param rows Number of rows (INT type)
+ * @param cols Number of columns (INT type)
+ * @return output Identity matrix (NUMERIC type)
+ */
+ public INDArray eye(INDArray rows, INDArray cols) {
+ NDValidation.validateInteger("eye", "rows", rows);
+ NDValidation.validateInteger("eye", "cols", cols);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Eye(rows, cols))[0];
+ }
+
+ /**
+ * As per eye(int) but with the number of rows specified as a scalar INDArray
+ *
+ * @param rows Number of rows (INT type)
+ * @return output Identity matrix (NUMERIC type)
+ */
+ public INDArray eye(INDArray rows) {
+ NDValidation.validateInteger("eye", "rows", rows);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Eye(rows))[0];
+ }
+
+ /**
+ * First index reduction operation.
+ * Returns a variable that contains the index of the first element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
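+ *
+ * A usage sketch (instance name {@code math} assumed; {@code Conditions} refers to org.nd4j.linalg.indexing.conditions.Conditions):
+ * <pre>
+ * {@code INDArray in = Nd4j.createFromArray(new double[][]{{0, 3, 0}, {5, 0, 0}});
+ * INDArray idx = math.firstIndex(in, Conditions.greaterThan(0), 1); // [1, 0]: first match in each row}
+ * </pre>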
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray firstIndex(INDArray in, Condition condition, int... dimensions) { + NDValidation.validateNumerical("firstIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(in, condition, dimensions)); + } + + /** + * First index reduction operation.
+ * Returns a variable that contains the index of the first element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray firstIndex(INDArray in, Condition condition, boolean keepDims, + int... dimensions) { + NDValidation.validateNumerical("firstIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(in, condition, keepDims, dimensions)); + } + + /** + * Element-wise floor function: out = floor(x).
+ * Rounds each value down to the nearest integer value (if not already an integer)
+ *
+ * @param x Input variable (NUMERIC type)
+ * @return output Output variable (NUMERIC type)
+ */
+ public INDArray floor(INDArray x) {
+ NDValidation.validateNumerical("floor", "x", x);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Floor(x));
+ }
+
+ /**
+ * Hamming distance reduction operation. The output contains the Hamming distance for each
+ * tensor/subset along the specified dimensions:
+ * out = count( x[i] != y[i] )
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate hammingDistance over (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray hammingDistance(INDArray x, INDArray y, int... dimensions) { + NDValidation.validateNumerical("hammingDistance", "x", x); + NDValidation.validateNumerical("hammingDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance(x, y, dimensions)); + } + + /** + * Index of the max absolute value: argmax(abs(in))
+ * see argmax(INDArray, boolean, int...)
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray iamax(INDArray in, int... dimensions) { + NDValidation.validateNumerical("iamax", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(in, dimensions)); + } + + /** + * Index of the max absolute value: argmax(abs(in))
+ * see argmax(INDArray, boolean, int...)
+ * + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray iamax(INDArray in, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("iamax", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(in, keepDims, dimensions)); + } + + /** + * Index of the min absolute value: argmin(abs(in))
+ * see argmin(INDArray, boolean, int...)
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray iamin(INDArray in, int... dimensions) { + NDValidation.validateNumerical("iamin", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(in, dimensions)); + } + + /** + * Index of the min absolute value: argmin(abs(in))
+ * see argmin(INDArray, boolean, int...)
+ * + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray iamin(INDArray in, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("iamin", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(in, keepDims, dimensions)); + } + + /** + * Is finite operation: elementwise isFinite(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray isFinite(INDArray x) { + NDValidation.validateNumerical("isFinite", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.bool.IsFinite(x)); + } + + /** + * Is infinite operation: elementwise isInfinite(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray isInfinite(INDArray x) { + NDValidation.validateNumerical("isInfinite", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf(x)); + } + + /** + * Is maximum operation: elementwise x == max(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray isMax(INDArray x) { + NDValidation.validateNumerical("isMax", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.any.IsMax(x))[0]; + } + + /** + * Is Not a Number operation: elementwise isNaN(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray isNaN(INDArray x) { + NDValidation.validateNumerical("isNaN", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN(x)); + } + + /** + * Is the array non decreasing?
+ * An array is non-decreasing if for every valid i, x[i] <= x[i+1]. For Rank 2+ arrays, values are compared
+ * in 'c' (row major) order
+ * + * @param x Input variable (NUMERIC type) + * @return output Scalar variable with value 1 if non-decreasing, or 0 otherwise (NUMERIC type) + */ + public INDArray isNonDecreasing(INDArray x) { + NDValidation.validateNumerical("isNonDecreasing", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing(x))[0]; + } + + /** + * Is the array strictly increasing?
+ * An array is strictly increasing if for every valid i, x[i] < x[i+1]. For Rank 2+ arrays, values are compared
+ * in 'c' (row major) order
+ *
+ * @param x Input variable (NUMERIC type)
+ * @return output Scalar variable with value 1 if strictly increasing, or 0 otherwise (NUMERIC type)
+ */
+ public INDArray isStrictlyIncreasing(INDArray x) {
+ NDValidation.validateNumerical("isStrictlyIncreasing", "x", x);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing(x))[0];
+ }
+
+ /**
+ * Jaccard distance reduction operation. The output contains the Jaccard distance for each
+ * tensor along the specified dimensions.
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate jaccardDistance over (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray jaccardDistance(INDArray x, INDArray y, int... dimensions) { + NDValidation.validateNumerical("jaccardDistance", "x", x); + NDValidation.validateNumerical("jaccardDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance(x, y, dimensions)); + } + + /** + * Last index reduction operation.
+ * Returns a variable that contains the index of the last element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray lastIndex(INDArray in, Condition condition, int... dimensions) { + NDValidation.validateNumerical("lastIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(in, condition, dimensions)); + } + + /** + * Last index reduction operation.
+ * Returns a variable that contains the index of the last element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray lastIndex(INDArray in, Condition condition, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("lastIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(in, condition, keepDims, dimensions)); + } + + /** + * Element-wise logarithm function (base e - natural logarithm): out = log(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray log(INDArray x) { + NDValidation.validateNumerical("log", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Log(x)); + } + + /** + * Element-wise logarithm function (with specified base): out = log_{base}(x)
+ * + * @param x Input variable (NUMERIC type) + * @param base Logarithm base (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray log(INDArray x, INDArray base) { + NDValidation.validateNumerical("log", "x", x); + NDValidation.validateNumerical("log", "base", base); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Log(x, base)); + } + + /** + * Elementwise natural logarithm function: out = log_e (1 + x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray log1p(INDArray x) { + NDValidation.validateNumerical("log1p", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p(x)); + } + + /** + * Log entropy reduction: log(-sum(x * log(x)))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray logEntropy(INDArray in, int... dimensions) { + NDValidation.validateNumerical("logEntropy", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy(in, dimensions)); + } + + /** + * Log-sum-exp reduction (optionally along dimension).
+ * Computes log(sum(exp(x)))
+ * + * @param input Input variable (NUMERIC type) + * @param dimensions Optional dimensions to reduce along (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray logSumExp(INDArray input, int... dimensions) { + NDValidation.validateNumerical("logSumExp", "input", input); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp(input, dimensions))[0]; + } + + /** + * Manhattan distance (l1 norm, l1 distance) reduction operation. The output contains the Manhattan distance for each
+ * tensor/subset along the specified dimensions:
+ * out = sum_i abs(x[i]-y[i])
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate manhattanDistance over (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray manhattanDistance(INDArray x, INDArray y, int... dimensions) { + NDValidation.validateNumerical("manhattanDistance", "x", x); + NDValidation.validateNumerical("manhattanDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance(x, y, dimensions)); + } + + /** + * Matrix determinant op. For 2D input, this returns the standard matrix determinant.
+ * For higher dimensional input with shape [..., m, m] the matrix determinant is returned for each
+ * shape [m,m] sub-matrix.
+ * + * @param in Input (NUMERIC type) + * @return output Matrix determinant variable (NUMERIC type) + */ + public INDArray matrixDeterminant(INDArray in) { + NDValidation.validateNumerical("matrixDeterminant", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant(in))[0]; + } + + /** + * Matrix inverse op. For 2D input, this returns the standard matrix inverse.
+ * For higher dimensional input with shape [..., m, m] the matrix inverse is returned for each
+ * shape [m,m] sub-matrix.
+ * + * @param in Input (NUMERIC type) + * @return output Matrix inverse variable (NUMERIC type) + */ + public INDArray matrixInverse(INDArray in) { + NDValidation.validateNumerical("matrixInverse", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse(in))[0]; + } + + /** + * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise addition:
+ * out = sum_i in[i]
+ * + * @param inputs Input variables (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray mergeAdd(INDArray[] inputs) { + NDValidation.validateNumerical("mergeAdd", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(inputs))[0]; + } + + /** + * Merge average function: merges an arbitrary number of equal shaped arrays using element-wise mean operation:
+ * out = mean_i in[i]
+ * + * @param inputs Input variables (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray mergeAvg(INDArray[] inputs) { + NDValidation.validateNumerical("mergeAvg", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(inputs))[0]; + } + + /** + * Merge max function: merges an arbitrary number of equal shaped arrays using element-wise maximum operation:
+ * out = max_i in[i]
+ * + * @param inputs Input variables (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray mergeMax(INDArray[] inputs) { + NDValidation.validateNumerical("mergeMax", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MergeMax(inputs))[0]; + } + + /** + * Calculate the mean and (population) variance for the input variable, for the specified axis
+ * + * @param input Input to calculate moments for (NUMERIC type) + * @param axes Dimensions to perform calculation over (Size: AtLeast(min=0)) + * @return output Mean and variance variables (NUMERIC type) + */ + public INDArray moments(INDArray input, int... axes) { + NDValidation.validateNumerical("moments", "input", input); + Preconditions.checkArgument(axes.length >= 0, "axes has incorrect size/length. Expected: axes.length >= 0, got %s", axes.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Moments(input, axes))[0]; + } + + /** + * Elementwise negative operation: out = -x
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray neg(INDArray x) { + NDValidation.validateNumerical("neg", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Negative(x)); + } + + /** + * Calculate the mean and variance from the sufficient statistics
+ *
+ * @param counts Rank 0 (scalar) value with the total number of values used to calculate the sufficient statistics (NUMERIC type)
+ * @param means Mean-value sufficient statistics: this is the SUM of all data values (NUMERIC type)
+ * @param variances Variance sufficient statistics: this is the squared sum of all data values (NUMERIC type)
+ * @param shift Shift value, possibly 0, used when calculating the sufficient statistics (for numerical stability)
+ * @return output Output variables: mean and population variance (NUMERIC type)
+ */
+ public INDArray normalizeMoments(INDArray counts, INDArray means, INDArray variances,
+ double shift) {
+ NDValidation.validateNumerical("normalizeMoments", "counts", counts);
+ NDValidation.validateNumerical("normalizeMoments", "means", means);
+ NDValidation.validateNumerical("normalizeMoments", "variances", variances);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments(counts, means, variances, shift))[0];
+ }
+
+ /**
+ * Boolean OR operation: elementwise (x != 0) || (y != 0)
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + */ + public INDArray or(INDArray x, INDArray y) { + NDValidation.validateBool("or", "x", x); + NDValidation.validateBool("or", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or(x, y)); + } + + /** + * Element-wise power function: out = x^value
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray pow(INDArray x, double value) { + NDValidation.validateNumerical("pow", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.Pow(x, value)); + } + + /** + * Element-wise (broadcastable) power function: out = x[i]^y[i]
+ * + * @param x Input variable (NUMERIC type) + * @param y Power (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray pow(INDArray x, INDArray y) { + NDValidation.validateNumerical("pow", "x", x); + NDValidation.validateNumerical("pow", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Pow(x, y))[0]; + } + + /** + * Element-wise reciprocal (inverse) function: out[i] = 1 / in[i]
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray reciprocal(INDArray x) { + NDValidation.validateNumerical("reciprocal", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal(x)); + } + + /** + * Element-wise round function: out = round(x).
+ * Rounds (up or down depending on value) to the nearest integer value.
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray round(INDArray x) { + NDValidation.validateNumerical("round", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Round(x)); + } + + /** + * Element-wise reciprocal (inverse) of square root: out = 1.0 / sqrt(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray rsqrt(INDArray x) { + NDValidation.validateNumerical("rsqrt", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt(x)); + } + + /** + * Set the diagonal value to the specified values
+ * If input is
+ * [ a, b, c]
+ * [ d, e, f]
+ * [ g, h, i]
+ * and diag = [ 1, 2, 3] then output is
+ * [ 1, b, c]
+ * [ d, 2, f]
+ * [ g, h, 3]
+ * + * @param in Input variable (NUMERIC type) + * @param diag Diagonal (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray setDiag(INDArray in, INDArray diag) { + NDValidation.validateNumerical("setDiag", "in", in); + NDValidation.validateNumerical("setDiag", "diag", diag); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag(in, diag))[0]; + } + + /** + * Shannon Entropy reduction: -sum(x * log2(x))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray shannonEntropy(INDArray in, int... dimensions) { + NDValidation.validateNumerical("shannonEntropy", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy(in, dimensions)); + } + + /** + * Element-wise sign (signum) function:
+ * out = -1 if in < 0
+ * out = 0 if in = 0
+ * out = 1 if in > 0
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray sign(INDArray x) { + NDValidation.validateNumerical("sign", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Sign(x)); + } + + /** + * Elementwise sine operation: out = sin(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray sin(INDArray x) { + NDValidation.validateNumerical("sin", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Sin(x)); + } + + /** + * Elementwise sinh (hyperbolic sine) operation: out = sinh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray sinh(INDArray x) { + NDValidation.validateNumerical("sinh", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh(x)); + } + + /** + * Element-wise square root function: out = sqrt(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray sqrt(INDArray x) { + NDValidation.validateNumerical("sqrt", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt(x)); + } + + /** + * Element-wise square function: out = x^2
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray square(INDArray x) { + NDValidation.validateNumerical("square", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Square(x)); + } + + /** + * Standardize input variable along given axis
+ * <p>
+ * out = (x - mean) / stdev
+ * <p>
+ * with mean and stdev being calculated along the given dimension.
+ * <p>
+ * For example: given x as a mini batch of the shape [numExamples, exampleLength]:
+ * <ul>
+ * <li>use dimension 1 to use the statistics (mean, stdev) for each example</li>
+ * <li>use dimension 0 if you want to use the statistics for each column across all examples</li>
+ * <li>use dimensions 0,1 if you want to use the statistics across all columns and examples</li>
+ * </ul>
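+ *
+ * A usage sketch (instance name {@code math} assumed for illustration):
+ * <pre>
+ * {@code INDArray x = Nd4j.rand(DataType.FLOAT, 10, 4); // [numExamples, exampleLength]
+ * INDArray z = math.standardize(x, 1); // zero mean, unit stdev within each example}
+ * </pre>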
+ *
+ * @param x Input variable (NUMERIC type)
+ * @param dimensions Dimensions to calculate the statistics (mean, stdev) over (Size: AtLeast(min=1))
+ * @return output Output variable (NUMERIC type)
+ */
+ public INDArray standardize(INDArray x, int... dimensions) {
+ NDValidation.validateNumerical("standardize", "x", x);
+ Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize(x, dimensions))[0];
+ }
+
+ /**
+ * Elementwise step function:
+ * out(x) = 1 if x >= cutoff
+ * out(x) = 0 otherwise
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray step(INDArray x, double value) { + NDValidation.validateNumerical("step", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.Step(x, value)); + } + + /** + * Elementwise tangent operation: out = tan(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray tan(INDArray x) { + NDValidation.validateNumerical("tan", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Tan(x)); + } + + /** + * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray tanh(INDArray x) { + NDValidation.validateNumerical("tanh", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(x)); + } + + /** + * Matrix trace operation
+ * For rank 2 matrices, the output is a scalar with the trace - i.e., sum of the main diagonal.
+ * For higher rank inputs, output[a,b,c] = trace(in[a,b,c,:,:])
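+ *
+ * For example (instance name {@code math} assumed for illustration):
+ * <pre>
+ * {@code INDArray in = Nd4j.createFromArray(new double[][]{{1, 2}, {3, 4}});
+ * INDArray tr = math.trace(in); // scalar 5.0 = 1 + 4}
+ * </pre>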
+ * + * @param in Input variable (NUMERIC type) + * @return output Trace (NUMERIC type) + */ + public INDArray trace(INDArray in) { + NDValidation.validateNumerical("trace", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Trace(in))[0]; + } + + /** + * Boolean XOR (exclusive OR) operation: elementwise (x != 0) XOR (y != 0)
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + */ + public INDArray xor(INDArray x, INDArray y) { + NDValidation.validateBool("xor", "x", x); + NDValidation.validateBool("xor", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor(x, y)); + } + + /** + * Full array zero fraction array reduction operation, optionally along specified dimensions: out = (count(x == 0) / length(x))
+ * + * @param input Input variable (NUMERIC type) + * @return output Reduced array of rank 0 (scalar) (NUMERIC type) + */ + public INDArray zeroFraction(INDArray input) { + NDValidation.validateNumerical("zeroFraction", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction(input))[0]; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java new file mode 100644 index 000000000..815f22e5b --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java @@ -0,0 +1,522 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.linalg.factory.ops; + +import static org.nd4j.linalg.factory.NDValidation.isSameType; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.NDValidation; +import org.nd4j.linalg.factory.Nd4j; + +public class NDNN { + public NDNN() { + } + + /** + * Neural network batch normalization operation.
+ * For details, see https://arxiv.org/abs/1502.03167
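+ *
+ * An illustrative sketch only (shapes are arbitrary; axis 1 selects the channel dimension of NCHW activations):
+ * <pre>
+ * {@code INDArray in = Nd4j.rand(DataType.FLOAT, 2, 3, 8, 8); // NCHW
+ * INDArray mean = Nd4j.zeros(DataType.FLOAT, 3);
+ * INDArray variance = Nd4j.ones(DataType.FLOAT, 3);
+ * INDArray gamma = Nd4j.ones(DataType.FLOAT, 3);
+ * INDArray beta = Nd4j.zeros(DataType.FLOAT, 3);
+ * INDArray out = new NDNN().batchNorm(in, mean, variance, gamma, beta, 1e-5, 1);}
+ * </pre>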
+ * + * @param input Input variable. (NUMERIC type) + * @param mean Mean value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param variance Variance value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param gamma Gamma value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param beta Beta value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param epsilon Epsilon constant for numerical stability (to avoid division by 0) + * @param axis For 2d CNN activations: 1 for NCHW format activations, or 3 for NHWC format activations. + * For 3d CNN activations: 1 for NCDHW format, 4 for NDHWC + * For 1d/RNN activations: 1 for NCW format, 2 for NWC (Size: AtLeast(min=1)) + * @return output variable for batch normalization (NUMERIC type) + */ + public INDArray batchNorm(INDArray input, INDArray mean, INDArray variance, INDArray gamma, + INDArray beta, double epsilon, int... axis) { + NDValidation.validateNumerical("batchNorm", "input", input); + NDValidation.validateNumerical("batchNorm", "mean", mean); + NDValidation.validateNumerical("batchNorm", "variance", variance); + NDValidation.validateNumerical("batchNorm", "gamma", gamma); + NDValidation.validateNumerical("batchNorm", "beta", beta); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm(input, mean, variance, gamma, beta, epsilon, axis))[0]; + } + + /** + * Bias addition operation: a special case of addition, typically used with CNN 4D activations and a 1D bias vector
+ * + * @param input 4d input variable (NUMERIC type) + * @param bias 1d bias (NUMERIC type) + * @param nchw The format - nchw=true means [minibatch, channels, height, width] format; nchw=false - [minibatch, height, width, channels]. + * Unused for 2d inputs + * @return output Output variable, after applying bias add operation (NUMERIC type) + */ + public INDArray biasAdd(INDArray input, INDArray bias, boolean nchw) { + NDValidation.validateNumerical("biasAdd", "input", input); + NDValidation.validateNumerical("biasAdd", "bias", bias); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd(input, bias, nchw))[0]; + } + + /** + * This operation performs dot product attention on the given timeseries input with the given queries
+ * out = sum(similarity(k_i, q) * v_i)
+ *
+ * similarity(k, q) = softmax(k * q) where k * q is the dot product of k and q
+ *
+ * Optionally with normalization step:
+ * similarity(k, q) = softmax(k * q / sqrt(size(q)))
+ *
+ * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1)
+ *
+ * Note: This supports multiple queries at once; if only one query is available, the queries vector still has to
+ * be 3D but can have queryCount = 1
+ *
+ * Note: keys and values are usually the same array. If you want to use the same array for both, simply pass it in for
+ * both.
+ *
+ * Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them doesn't work. The
+ * output rank will depend on the input rank.
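+ *
+ * An illustrative shape sketch only (random values; a mask of ones keeps every timestep):
+ * <pre>
+ * {@code INDArray q = Nd4j.rand(DataType.FLOAT, 2, 4, 1); // [batchSize, featureKeys, queryCount]
+ * INDArray k = Nd4j.rand(DataType.FLOAT, 2, 4, 3); // [batchSize, featureKeys, timesteps]
+ * INDArray v = Nd4j.rand(DataType.FLOAT, 2, 6, 3); // [batchSize, featureValues, timesteps]
+ * INDArray mask = Nd4j.ones(DataType.FLOAT, 2, 3); // [batchSize, timesteps]
+ * INDArray out = new NDNN().dotProductAttention(q, k, v, mask, true); // shape [2, 6, 1]}
+ * </pre>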
+ * + * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] + * or 4D array of shape [batchSize, numHeads, featureKeys, queryCount] (NUMERIC type) + * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] + * or 4D array of shape [batchSize, numHeads, featureKeys, timesteps] (NUMERIC type) + * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] + * or 4D array of shape [batchSize, numHeads, featureValues, timesteps] (NUMERIC type) + * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type) + * @param scaled normalization, false -> do not apply normalization, true -> apply normalization + * @return output Attention result arrays of shape [batchSize, featureValues, queryCount] or [batchSize, numHeads, featureValues, queryCount], + * (optionally) Attention Weights of shape [batchSize, timesteps, queryCount] or [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) + */ + public INDArray dotProductAttention(INDArray queries, INDArray keys, INDArray values, + INDArray mask, boolean scaled) { + NDValidation.validateNumerical("dotProductAttention", "queries", queries); + NDValidation.validateNumerical("dotProductAttention", "keys", keys); + NDValidation.validateNumerical("dotProductAttention", "values", values); + NDValidation.validateNumerical("dotProductAttention", "mask", mask); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(queries, keys, values, mask, scaled))[0]; + } + + /** + * Dropout operation
+ * + * @param input Input array (NUMERIC type) + * @param inputRetainProbability Probability of retaining an input (set to 0 with probability 1-p) + * @return output Output (NUMERIC type) + */ + public INDArray dropout(INDArray input, double inputRetainProbability) { + NDValidation.validateNumerical("dropout", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.DropOut(input, inputRetainProbability)); + } + + /** + * Element-wise exponential linear unit (ELU) function:
+ * out = x if x > 0
+ * out = a * (exp(x) - 1) if x <= 0
+ * with constant a = 1.0
+ *
+ * See: https://arxiv.org/abs/1511.07289
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray elu(INDArray x) { + NDValidation.validateNumerical("elu", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ELU(x))[0]; + } + + /** + * GELU activation function - Gaussian Error Linear Units
+ * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
+ * This method uses the sigmoid approximation
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray gelu(INDArray x) { + NDValidation.validateNumerical("gelu", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.GELU(x)); + } + + /** + * Element-wise hard sigmoid function:
+ * out[i] = 0 if in[i] <= -2.5
+ * out[i] = 0.2*in[i]+0.5 if -2.5 < in[i] < 2.5
+ * out[i] = 1 if in[i] >= 2.5
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray hardSigmoid(INDArray x) { + NDValidation.validateNumerical("hardSigmoid", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid(x)); + } + + /** + * Element-wise hard tanh function:
+ * out[i] = -1 if in[i] <= -1
+ * out[i] = in[i] if -1 < in[i] < 1
+ * out[i] = 1 if in[i] >= 1
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray hardTanh(INDArray x) { + NDValidation.validateNumerical("hardTanh", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh(x)); + } + + /** + * Derivative (dOut/dIn) of the element-wise hard Tanh function - hardTanh(INDArray)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray hardTanhDerivative(INDArray x) { + NDValidation.validateNumerical("hardTanhDerivative", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative(x)); + } + + /** + * Apply Layer Normalization
+ *
+ * y = gain * standardize(x) + bias
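+ *
+ * A usage sketch with illustrative shapes:
+ * <pre>
+ * {@code INDArray x = Nd4j.rand(DataType.FLOAT, 16, 128); // [minibatch, size]
+ * INDArray gain = Nd4j.ones(DataType.FLOAT, 128);
+ * INDArray bias = Nd4j.zeros(DataType.FLOAT, 128);
+ * INDArray out = new NDNN().layerNorm(x, gain, bias, true, 1); // channelsFirst is unused for 2d input}
+ * </pre>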
+ * + * @param input Input variable (NUMERIC type) + * @param gain Gain (NUMERIC type) + * @param bias Bias (NUMERIC type) + * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data + * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray layerNorm(INDArray input, INDArray gain, INDArray bias, boolean channelsFirst, + int... dimensions) { + NDValidation.validateNumerical("layerNorm", "input", input); + NDValidation.validateNumerical("layerNorm", "gain", gain); + NDValidation.validateNumerical("layerNorm", "bias", bias); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(input, gain, bias, channelsFirst, dimensions))[0]; + } + + /** + * Apply Layer Normalization
+ *
+ * y = gain * standardize(x) + bias
+ * + * @param input Input variable (NUMERIC type) + * @param gain Gain (NUMERIC type) + * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data + * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray layerNorm(INDArray input, INDArray gain, boolean channelsFirst, + int... dimensions) { + NDValidation.validateNumerical("layerNorm", "input", input); + NDValidation.validateNumerical("layerNorm", "gain", gain); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(input, gain, channelsFirst, dimensions))[0]; + } + + /** + * Element-wise leaky ReLU function:
+ * out = x if x >= 0.0
+ * out = alpha * x if x < 0.0
+ * Alpha value is most commonly set to 0.01
+ * + * @param x Input variable (NUMERIC type) + * @param alpha Cutoff - commonly 0.01 (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray leakyRelu(INDArray x, INDArray alpha) { + NDValidation.validateNumerical("leakyRelu", "x", x); + NDValidation.validateNumerical("leakyRelu", "alpha", alpha); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU(x, alpha)); + } + + /** + * Leaky ReLU derivative: dOut/dIn given input.
+ * + * @param x Input variable (NUMERIC type) + * @param alpha Cutoff - commonly 0.01 (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray leakyReluDerivative(INDArray x, INDArray alpha) { + NDValidation.validateNumerical("leakyReluDerivative", "x", x); + NDValidation.validateNumerical("leakyReluDerivative", "alpha", alpha); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative(x, alpha)); + } + + /** + * Linear layer operation: out = mmul(in,w) + bias
+ * Note that bias array is optional
+ *
+ * @param input Input data (NUMERIC type)
+ * @param weights Weights variable, shape [nIn, nOut] (NUMERIC type)
+ * @param bias Optional bias variable (may be null) (NUMERIC type)
+ * @return output Output variable (NUMERIC type)
+ */
+ public INDArray linear(INDArray input, INDArray weights, INDArray bias) {
+ NDValidation.validateNumerical("linear", "input", input);
+ NDValidation.validateNumerical("linear", "weights", weights);
+ NDValidation.validateNumerical("linear", "bias", bias);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(input, weights, bias))[0];
+ }
+
+ /**
+ * Element-wise log-sigmoid function: out[i] = log(sigmoid(in[i]))
+  /**
+   * Element-wise log sigmoid function: out[i] = log(sigmoid(in[i]))
+   *
+   * @param x Input variable (NUMERIC type)
+   * @return output Output variable (NUMERIC type)
+   */
+  public INDArray logSigmoid(INDArray x) {
+    NDValidation.validateNumerical("logSigmoid", "x", x);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid(x));
+  }
+
+  /**
+   * Log softmax activation
+   *
+   * @param x (NUMERIC type)
+   * @return output (NUMERIC type)
+   */
+  public INDArray logSoftmax(INDArray x) {
+    NDValidation.validateNumerical("logSoftmax", "x", x);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(x))[0];
+  }
+
+  /**
+   * Log softmax activation
+   *
+   * @param x Input (NUMERIC type)
+   * @param dimension Dimension along which to apply log softmax
+   * @return output Output - log(softmax(input)) (NUMERIC type)
+   */
+  public INDArray logSoftmax(INDArray x, int dimension) {
+    NDValidation.validateNumerical("logSoftmax", "x", x);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(x, dimension))[0];
+  }
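A quick sketch of the dimension argument (illustrative only; values are arbitrary assumptions, and Transforms is org.nd4j.linalg.ops.transforms.Transforms):

    INDArray logits = Nd4j.createFromArray(new float[][]{{1f, 2f, 3f}, {1f, 1f, 1f}});
    INDArray logP = Nd4j.nn.logSoftmax(logits, 1);   // along each row
    INDArray p = Transforms.exp(logP);               // each row of p sums to 1.0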
+  /**
+   * This performs multi-headed dot product attention on the given timeseries input
+   * out = concat(head_1, head_2, ..., head_n) * Wo
+   * head_i = dot_product_attention(Wq_i*q, Wk_i*k, Wv_i*v)
+   *
+   * Optionally with normalization when calculating the attention for each head.
+   *
+   * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. 4,5, "3.2.2 Multi-Head Attention")
+   *
+   * This makes use of dot_product_attention OP support for rank 4 inputs.
+   * see dotProductAttention(INDArray, INDArray, INDArray, INDArray, boolean, boolean)
+   *
+   * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] (NUMERIC type)
+   * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] (NUMERIC type)
+   * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] (NUMERIC type)
+   * @param Wq input query projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type)
+   * @param Wk input key projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type)
+   * @param Wv input value projection weights of shape [numHeads, projectedValues, featureValues] (NUMERIC type)
+   * @param Wo output projection weights of shape [numHeads * projectedValues, outSize] (NUMERIC type)
+   * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type)
+   * @param scaled normalization, false -> do not apply normalization, true -> apply normalization
+   * @return output Attention result arrays of shape [batchSize, outSize, queryCount]
+   * (optionally) Attention Weights of shape [batchSize, numHeads, timesteps, queryCount] (NUMERIC type)
+   */
+  public INDArray multiHeadDotProductAttention(INDArray queries, INDArray keys, INDArray values,
+      INDArray Wq, INDArray Wk, INDArray Wv, INDArray Wo, INDArray mask, boolean scaled) {
+    NDValidation.validateNumerical("multiHeadDotProductAttention", "queries", queries);
+    NDValidation.validateNumerical("multiHeadDotProductAttention", "keys", keys);
+    NDValidation.validateNumerical("multiHeadDotProductAttention", "values", values);
+    NDValidation.validateNumerical("multiHeadDotProductAttention", "Wq", Wq);
+    NDValidation.validateNumerical("multiHeadDotProductAttention", "Wk", Wk);
+    NDValidation.validateNumerical("multiHeadDotProductAttention", "Wv", Wv);
+    NDValidation.validateNumerical("multiHeadDotProductAttention", "Wo", Wo);
+    NDValidation.validateNumerical("multiHeadDotProductAttention", "mask", mask);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled))[0];
+  }
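A shape-focused sketch of the attention helper (illustrative only; all sizes are arbitrary assumptions, and passing null for the optional mask is assumed to be accepted, matching the "may be null" convention used elsewhere in these builders):

    // batchSize=2, featureKeys=featureValues=4, timesteps=5, queryCount=3
    // numHeads=2, projectedKeys=projectedValues=3, outSize=6
    INDArray q  = Nd4j.rand(DataType.FLOAT, 2, 4, 3);
    INDArray k  = Nd4j.rand(DataType.FLOAT, 2, 4, 5);
    INDArray v  = Nd4j.rand(DataType.FLOAT, 2, 4, 5);
    INDArray wq = Nd4j.rand(DataType.FLOAT, 2, 3, 4);
    INDArray wk = Nd4j.rand(DataType.FLOAT, 2, 3, 4);
    INDArray wv = Nd4j.rand(DataType.FLOAT, 2, 3, 4);
    INDArray wo = Nd4j.rand(DataType.FLOAT, 6, 6);   // [numHeads * projectedValues, outSize]
    INDArray out = Nd4j.nn.multiHeadDotProductAttention(q, k, v, wq, wk, wv, wo, null, true);
    // out shape: [2, 6, 3] = [batchSize, outSize, queryCount]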
+  /**
+   * PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:
+   * out[i] = in[i] if in[i] >= 0
+   * out[i] = in[i] * alpha[i] otherwise
+   *
+   * sharedAxes allows you to share learnable parameters along axes.
+   * For example, if the input has shape [batchSize, channels, height, width]
+   * and you want each channel to have its own cutoff, use sharedAxes = [2, 3] and an
+   * alpha with shape [channels].
+   *
+   * @param input Input data (NUMERIC type)
+   * @param alpha The cutoff variable. Note that the batch dimension (the 0th, whether it is batch or not) should not be part of alpha. (NUMERIC type)
+   * @param sharedAxes Which axes to share cutoff parameters along. (Size: AtLeast(min=1))
+   * @return output Output (NUMERIC type)
+   */
+  public INDArray prelu(INDArray input, INDArray alpha, int... sharedAxes) {
+    NDValidation.validateNumerical("prelu", "input", input);
+    NDValidation.validateNumerical("prelu", "alpha", alpha);
+    Preconditions.checkArgument(sharedAxes.length >= 1, "sharedAxes has incorrect size/length. Expected: sharedAxes.length >= 1, got %s", sharedAxes.length);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.PRelu(input, alpha, sharedAxes))[0];
+  }
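The sharedAxes example from the javadoc above, as a sketch (illustrative only; shapes and the 0.25 alpha are arbitrary assumptions):

    // NCHW input; one learnable alpha per channel, shared across height and width
    INDArray in = Nd4j.rand(DataType.FLOAT, 2, 3, 4, 4).subi(0.5);
    INDArray alpha = Nd4j.create(DataType.FLOAT, 3).assign(0.25);
    INDArray out = Nd4j.nn.prelu(in, alpha, 2, 3);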
+  /**
+   * Element-wise rectified linear function with specified cutoff:
+   * out[i] = in[i] if in[i] >= cutoff
+   * out[i] = 0 otherwise
+   *
+   * @param x Input (NUMERIC type)
+   * @param cutoff Cutoff value for ReLU operation - x > cutoff ? x : 0. Usually 0
+   * @return output Output (NUMERIC type)
+   */
+  public INDArray relu(INDArray x, double cutoff) {
+    NDValidation.validateNumerical("relu", "x", x);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear(x, cutoff));
+  }
+
+  /**
+   * Element-wise "rectified linear 6" function with specified cutoff:
+   * out[i] = min(max(in, cutoff), 6)
+   *
+   * @param x Input (NUMERIC type)
+   * @param cutoff Cutoff value for ReLU operation. Usually 0
+   * @return output Output (NUMERIC type)
+   */
+  public INDArray relu6(INDArray x, double cutoff) {
+    NDValidation.validateNumerical("relu6", "x", x);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.Relu6(x, cutoff));
+  }
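A sketch contrasting the two (illustrative only; input values are arbitrary assumptions):

    INDArray x = Nd4j.createFromArray(-2.0f, 0.5f, 3.0f, 9.0f);
    INDArray r  = Nd4j.nn.relu(x, 0.0);    // [0, 0.5, 3, 9]
    INDArray r6 = Nd4j.nn.relu6(x, 0.0);   // [0, 0.5, 3, 6] - additionally clamped at 6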
+  /**
+   * ReLU (Rectified Linear Unit) layer operation: out = relu(mmul(in,w) + bias)
+   * Note that bias array is optional
+   *
+   * @param input Input data (NUMERIC type)
+   * @param weights Weights variable (NUMERIC type)
+   * @param bias Optional bias variable (may be null) (NUMERIC type)
+   * @return output Output variable (NUMERIC type)
+   */
+  public INDArray reluLayer(INDArray input, INDArray weights, INDArray bias) {
+    NDValidation.validateNumerical("reluLayer", "input", input);
+    NDValidation.validateNumerical("reluLayer", "weights", weights);
+    NDValidation.validateNumerical("reluLayer", "bias", bias);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.ReluLayer(input, weights, bias))[0];
+  }
+
+  /**
+   * Element-wise SeLU function - Scaled exponential Linear Unit: see Self-Normalizing Neural Networks
+   *
+   * out[i] = scale * in[i] if in[i] > 0, or scale * alpha * (exp(in[i])-1) if in[i] <= 0
+   * Uses default scale and alpha values.
+   *
+   * @param x Input variable (NUMERIC type)
+   * @return output Output variable (NUMERIC type)
+   */
+  public INDArray selu(INDArray x) {
+    NDValidation.validateNumerical("selu", "x", x);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.SELU(x));
+  }
+
+  /**
+   * Element-wise sigmoid function: out[i] = 1.0/(1+exp(-in[i]))
+   *
+   * @param x Input variable (NUMERIC type)
+   * @return output Output variable (NUMERIC type)
+   */
+  public INDArray sigmoid(INDArray x) {
+    NDValidation.validateNumerical("sigmoid", "x", x);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid(x));
+  }
+
+  /**
+   * Element-wise sigmoid function derivative: dL/dIn given input and dL/dOut
+   *
+   * @param x Input Variable (NUMERIC type)
+   * @param wrt Gradient at the output - dL/dOut. Must have same shape as the input (NUMERIC type)
+   * @return output Output (gradient at input of sigmoid) (NUMERIC type)
+   */
+  public INDArray sigmoidDerivative(INDArray x, INDArray wrt) {
+    NDValidation.validateNumerical("sigmoidDerivative", "x", x);
+    NDValidation.validateNumerical("sigmoidDerivative", "wrt", wrt);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative(x, wrt))[0];
+  }
+
+  /**
+   * Softmax activation, along the specified dimension
+   *
+   * @param x Input (NUMERIC type)
+   * @param dimension Dimension along which to apply softmax
+   * @return output Output variable (NUMERIC type)
+   */
+  public INDArray softmax(INDArray x, int dimension) {
+    NDValidation.validateNumerical("softmax", "x", x);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(x, dimension))[0];
+  }
+
+  /**
+   * Softmax derivative function
+   *
+   * @param x Softmax input (NUMERIC type)
+   * @param wrt Gradient at output, dL/dx (NUMERIC type)
+   * @param dimension Softmax dimension
+   * @return output (NUMERIC type)
+   */
+  public INDArray softmaxDerivative(INDArray x, INDArray wrt, int dimension) {
+    NDValidation.validateNumerical("softmaxDerivative", "x", x);
+    NDValidation.validateNumerical("softmaxDerivative", "wrt", wrt);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp(x, wrt, dimension))[0];
+  }
+
+  /**
+   * Element-wise softplus function: out = log(exp(x) + 1)
+   *
+   * @param x Input variable (NUMERIC type)
+   * @return output Output variable (NUMERIC type)
+   */
+  public INDArray softplus(INDArray x) {
+    NDValidation.validateNumerical("softplus", "x", x);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus(x));
+  }
+
+  /**
+   * Element-wise softsign function: out = x / (abs(x) + 1)
+   *
+   * @param x Input variable (NUMERIC type)
+   * @return output Output variable (NUMERIC type)
+   */
+  public INDArray softsign(INDArray x) {
+    NDValidation.validateNumerical("softsign", "x", x);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign(x));
+  }
+
+  /**
+   * Element-wise derivative (dOut/dIn) of the softsign function softsign(INDArray)
+   *
+   * @param x Input variable (NUMERIC type)
+   * @return output Output (NUMERIC type)
+   */
+  public INDArray softsignDerivative(INDArray x) {
+    NDValidation.validateNumerical("softsignDerivative", "x", x);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative(x));
+  }
+
+  /**
+   * Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0
+   * See: https://arxiv.org/abs/1710.05941
+   *
+   * @param x Input variable (NUMERIC type)
+   * @return output Output variable (NUMERIC type)
+   */
+  public INDArray swish(INDArray x) {
+    NDValidation.validateNumerical("swish", "x", x);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Swish(x));
+  }
+}
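Composing the methods above, a tiny two-layer forward pass as a sketch (illustrative only; all shapes are arbitrary assumptions):

    // hidden = relu(x*W1 + b1); probs = softmax(hidden*W2 + b2) along dimension 1
    INDArray x  = Nd4j.rand(DataType.FLOAT, 5, 10);
    INDArray w1 = Nd4j.rand(DataType.FLOAT, 10, 8);
    INDArray b1 = Nd4j.rand(DataType.FLOAT, 8);
    INDArray w2 = Nd4j.rand(DataType.FLOAT, 8, 3);
    INDArray b2 = Nd4j.rand(DataType.FLOAT, 3);
    INDArray h = Nd4j.nn.reluLayer(x, w1, b1);
    INDArray probs = Nd4j.nn.softmax(Nd4j.nn.linear(h, w2, b2), 1);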
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRandom.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRandom.java
new file mode 100644
index 000000000..5737ced1f
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRandom.java
@@ -0,0 +1,138 @@
+/* ******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
+
+package org.nd4j.linalg.factory.ops;
+
+import static org.nd4j.linalg.factory.NDValidation.isSameType;
+
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+public class NDRandom {
+  public NDRandom() {
+  }
+
+  /**
+   * Generate a new random INDArray, where values are randomly sampled according to a Bernoulli distribution,
+   * with the specified probability. Array values will have value 1 with probability P and value 0 with probability
+   * 1-P.
+   *
+   * @param p Probability of value 1
+   * @param datatype Data type of the output variable
+   * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
+   * @return output Tensor with the given shape where values are randomly sampled according to a Bernoulli distribution (NUMERIC type)
+   */
+  public INDArray bernoulli(double p, DataType datatype, long... shape) {
+    Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution(p, datatype, shape));
+  }
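A one-line sketch of the new accessor (illustrative only; p and shape are arbitrary assumptions):

    INDArray coinFlips = Nd4j.random.bernoulli(0.5, DataType.FLOAT, 2, 5);  // ~50% ones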
+
+  /**
+   * Generate a new random INDArray, where values are randomly sampled according to a Binomial distribution,
+   * with the specified number of trials and probability.
+   *
+   * @param nTrials Number of trials parameter for the binomial distribution
+   * @param p Probability of success for each trial
+   * @param datatype Data type of the output variable
+   * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
+   * @return output Tensor with the given shape where values are randomly sampled according to a Binomial distribution (NUMERIC type)
+   */
+  public INDArray binomial(int nTrials, double p, DataType datatype, long... shape) {
+    Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.BinomialDistribution(nTrials, p, datatype, shape));
+  }
+
+  /**
+   * Generate a new random INDArray, where values are randomly sampled according to an exponential distribution:
+   * P(x) = lambda * exp(-lambda * x)
+   *
+   * Inputs must satisfy the following constraints:
+   * Must be positive: lambda > 0
+   *
+   * @param lambda lambda parameter
+   * @param datatype Data type of the output variable
+   * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
+   */
+  public INDArray[] exponential(double lambda, DataType datatype, long... shape) {
+    Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
+    Preconditions.checkArgument(lambda > 0, "Must be positive");
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.random.custom.RandomExponential(lambda, datatype, shape));
+  }
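Note that, unlike its siblings, exponential is declared to return an INDArray[] (it is backed by a custom op rather than a legacy random op), so callers take the first element. A sketch (illustrative only; lambda and shape are arbitrary assumptions, and lambda must be strictly positive or the precondition above throws):

    INDArray samples = Nd4j.random.exponential(2.0, DataType.FLOAT, 100)[0];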
+
+  /**
+   * Generate a new random INDArray, where values are randomly sampled according to a Log Normal distribution,
+   * i.e., {@code log(x) ~ N(mean, stdev)}
+   *
+   * @param mean Mean value for the random array
+   * @param stddev Standard deviation for the random array
+   * @param datatype Data type of the output variable
+   * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
+   * @return output Tensor with the given shape where values are randomly sampled according to a Log Normal distribution (NUMERIC type)
+   */
+  public INDArray logNormal(double mean, double stddev, DataType datatype, long... shape) {
+    Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution(mean, stddev, datatype, shape));
+  }
+
+  /**
+   * Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,
+   * N(mean, stdev)
+   *
+   * @param mean Mean value for the random array
+   * @param stddev Standard deviation for the random array
+   * @param datatype Data type of the output variable
+   * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
+   * @return output Tensor with the given shape where values are randomly sampled according to a Gaussian (normal) distribution (NUMERIC type)
+   */
+  public INDArray normal(double mean, double stddev, DataType datatype, long... shape) {
+    Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.GaussianDistribution(mean, stddev, datatype, shape));
+  }
+
+  /**
+   * Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,
+   * N(mean, stdev). However, any values more than 1 standard deviation from the mean are dropped and re-sampled
+   *
+   * @param mean Mean value for the random array
+   * @param stddev Standard deviation for the random array
+   * @param datatype Data type of the output variable
+   * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
+   * @return output Tensor with the given shape where values are randomly sampled according to a truncated Gaussian (normal) distribution (NUMERIC type)
+   */
+  public INDArray normalTruncated(double mean, double stddev, DataType datatype, long... shape) {
+    Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution(mean, stddev, datatype, shape));
+  }
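A sketch of the truncation behaviour (illustrative only; sizes are arbitrary assumptions):

    INDArray n = Nd4j.random.normal(0, 1, DataType.FLOAT, 1000);
    INDArray t = Nd4j.random.normalTruncated(0, 1, DataType.FLOAT, 1000);
    // t.amaxNumber() should not exceed 1.0 (values beyond 1 stdev are re-sampled), unlike n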
+
+  /**
+   * Generate a new random INDArray, where values are randomly sampled according to a uniform distribution,
+   * U(min,max)
+   *
+   * @param min Minimum value
+   * @param max Maximum value.
+   * @param datatype Data type of the output variable
+   * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
+   * @return output Tensor with the given shape where values are randomly sampled according to a uniform distribution (NUMERIC type)
+   */
+  public INDArray uniform(double min, double max, DataType datatype, long... shape) {
+    Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.UniformDistribution(min, max, datatype, shape));
+  }
+}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java
index ed3b5a7cb..e10ffcddb 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java
@@ -1620,11 +1620,11 @@ public class SameDiffTests extends BaseNd4jTest {
             switch (i) {
                 case 0:
                     t = sd.math().isNonDecreasing(in1);
-                    Nd4j.exec(new IsNonDecreasing(new INDArray[]{ia}, new INDArray[]{expOut}));
+                    Nd4j.exec(new IsNonDecreasing(ia, expOut));
                     break;
                 case 1:
                     t = sd.math().isStrictlyIncreasing(in1);
-                    Nd4j.exec(new IsStrictlyIncreasing(new INDArray[]{ia}, new INDArray[]{expOut}));
+                    Nd4j.exec(new IsStrictlyIncreasing(ia, expOut));
                     break;
                 case 2:
                     t = sd.isNumericTensor(in1);
@@ -1650,7 +1650,7 @@ public class SameDiffTests extends BaseNd4jTest {
         INDArray ia = Nd4j.randn(minibatch, nOut);

         INDArray expOut = Nd4j.create(DataType.BOOL, ia.shape());
-        Nd4j.exec(new IsStrictlyIncreasing(new INDArray[]{ia}, new INDArray[]{expOut}));
+        Nd4j.exec(new IsStrictlyIncreasing(ia, expOut));
         System.out.println(expOut);
     }
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java
index 83145a048..6582d38db 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java
@@ -31,6 +31,7 @@
 import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
 import org.nd4j.autodiff.execution.conf.OutputMode;
 import org.nd4j.autodiff.functions.DifferentialFunction;
 import org.nd4j.autodiff.listeners.Listener;
+import org.nd4j.autodiff.listeners.debugging.ExecDebuggingListener;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.autodiff.samediff.internal.InferenceSession;
 import org.nd4j.autodiff.samediff.internal.SameDiffOp;
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNamespaces.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNamespaces.java
new file mode 100644
index 000000000..445f72342
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNamespaces.java
@@ -0,0 +1,68 @@
+package org.nd4j.linalg.api;
+
+import org.junit.Test;
+import org.nd4j.linalg.BaseNd4jTest;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.factory.Nd4jBackend;
+
+public class TestNamespaces extends BaseNd4jTest {
+
+    public TestNamespaces(Nd4jBackend backend) {
+        super(backend);
+    }
+
+    @Test
+    public void testBitwiseSimple(){
+
+        INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT);
+        INDArray y = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT);
+
+        INDArray and = Nd4j.bitwise.and(x, y);
+        INDArray or = Nd4j.bitwise.or(x, y);
+        INDArray xor = Nd4j.bitwise.xor(x, y);
+
+        System.out.println(and);
+        System.out.println(or);
+        System.out.println(xor);
+
+    }
+
+    @Test
+    public void testMathSimple(){
+        INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(2).subi(1);
+        INDArray abs = Nd4j.math.abs(x);
+        System.out.println(x);
+        System.out.println(abs);
+
+
+        INDArray c1 = Nd4j.createFromArray(0, 2, 1);
+        INDArray c2 = Nd4j.createFromArray(1, 2, 1);
+
+        INDArray cm = Nd4j.math.confusionMatrix(c1, c2, 3);
+        System.out.println(cm);
+    }
+
+    @Test
+    public void testRandomSimple(){
+        INDArray normal = Nd4j.random.normal(0, 1, DataType.FLOAT, 10);
+        System.out.println(normal);
+        INDArray uniform = Nd4j.random.uniform(0, 1, DataType.FLOAT, 10);
+        System.out.println(uniform);
+    }
+
+    @Test
+    public void testNeuralNetworkSimple(){
+        INDArray out = Nd4j.nn.elu(Nd4j.random.normal(0, 1, DataType.FLOAT, 10));
+        System.out.println(out);
+        INDArray out2 = Nd4j.nn.softmax(Nd4j.random.normal(0, 1, DataType.FLOAT, 4, 5), 1);
+        System.out.println(out2);
+    }
+
+    @Override
+    public char ordering() {
+        return 'c';
+    }
+
+}
From 35ab4a72ba44079257bec3254175806f5b0a03e9 Mon Sep 17 00:00:00 2001
From: Alex Black
Date: Sat, 30 Nov 2019 18:58:37 +1100
Subject: [PATCH 15/30] TF import test resources loading precision fixes (#92)

* Fix precision issues when loading from CSV
Signed-off-by: AlexDBlack

* Small tweak
Signed-off-by: AlexDBlack
---
 .../imports/graphmapper/tf/TFGraphMapper.java |   8 ++
 .../TFGraphs/TFGraphTestAllHelper.java        | 127 +++++++++++++++---
 2 files changed, 118 insertions(+), 17 deletions(-)

diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java
index 8605467cc..f54b532e8 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java
@@ -55,6 +55,14 @@ import java.util.*;
 @Slf4j
 public class TFGraphMapper {

+    /**
+     * @deprecated Use static methods - {@link #importGraph(File)} etc
+     */
+    @Deprecated
+    public static TFGraphMapper getInstance(){
+        return new TFGraphMapper();
+    }
+
     /**
      * Import a frozen TensorFlow protobuf (.pb) file from the specified file
      *
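The deprecation above points callers at the new static entry point; a one-line sketch (illustrative only; the model path is a hypothetical placeholder):

    // Old: TFGraphMapper.getInstance().importGraph(...)  ->  new static form:
    SameDiff sd = TFGraphMapper.importGraph(new File("/path/to/frozen_model.pb"));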
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java
index 6582d38db..eae14b230 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java
@@ -594,7 +594,7 @@ public class TFGraphTestAllHelper {
                     val key = modelDir + "/" + okey;

                     // parse type directly
-                    val value = ArrayOptionsHelper.dataType(split[1]);
+                    DataType value = ArrayOptionsHelper.dataType(split[1]);

                     // adding key directly
                     //if (dtypes.containsKey(key))
@@ -672,12 +672,35 @@ public class TFGraphTestAllHelper {
                 INDArray varValue;
                 if(filtered.size() == 0){
                     //Scalar
-                    float[] varContents;
-                    try(InputStream is = new BufferedInputStream(resources.get(i).getSecond().getInputStream())){
-                        varContents = Nd4j.readNumpy(is, ",").data().asFloat();
+                    String content = IOUtils.toString(resources.get(i).getSecond().getInputStream(), StandardCharsets.UTF_8);
+                    switch (type){
+                        case DOUBLE:
+                        case FLOAT:
+                        case HALF:
+                        case BFLOAT16:
+                            varValue = Nd4j.scalar(type, parseDouble(content));
+                            break;
+                        case LONG:
+                        case INT:
+                        case SHORT:
+                        case UBYTE:
+                        case BYTE:
+                        case UINT16:
+                        case UINT32:
+                        case UINT64:
+                            varValue = Nd4j.scalar(type, parseLong(content));
+                            break;
+                        case BOOL:
+                            varValue = Nd4j.scalar(parseBoolean(content));
+                            break;
+                        case UTF8:
+                            varValue = Nd4j.scalar(content);
+                            break;
+                        case COMPRESSED:
+                        case UNKNOWN:
+                        default:
+                            throw new UnsupportedOperationException("Unknown / not implemented datatype: " + type);
                     }
-                    Preconditions.checkState(varContents.length == 1, "Expected length 1 content for scalar shape; got length %s", varContents.length);
-                    varValue = Nd4j.scalar(type, varContents[0]);
                 } else {
                     int[] varShape = new int[filtered.size()];
                     for( int j=0; j<filtered.size(); j++ ){

     public static Pair<Double,Double> testPrecisionOverride(String testName){
         if("conv_4".equalsIgnoreCase(testName)){
From 2be47082c901162ab7e99b0708df13f3901ddd99 Mon Sep 17 00:00:00 2001
From: Alex Black
Date: Sat, 30 Nov 2019 20:08:30 +1100
Subject: [PATCH 16/30] #8470 TrainingConfig json fix for Evaluation instances
 (#93)

Signed-off-by: AlexDBlack
---
 .../autodiff/samediff/TrainingConfig.java     |  4 +-
 .../org/nd4j/evaluation/BaseEvaluation.java   | 51 +++----------------
 .../org/nd4j/evaluation/curves/BaseCurve.java |  9 ++--
 .../nd4j/evaluation/curves/BaseHistogram.java |  9 ++--
 .../nd4j/evaluation/serde/ROCSerializer.java  | 12 +++--
 .../java/org/nd4j/serde/json/JsonMappers.java | 39 ++++++++++++--
 .../nd4j/autodiff/samediff/SameDiffTests.java | 19 +++++++
 7 files changed, 81 insertions(+), 62 deletions(-)

diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java
index d50daddb8..25aec0028 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java
@@ -1,4 +1,4 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2019 Skymind, Inc.
  *
  * This program and the accompanying materials are made available under the
@@ -18,7 +18,6 @@ package org.nd4j.autodiff.samediff;

 import lombok.*;
 import lombok.extern.slf4j.Slf4j;
-import org.nd4j.autodiff.listeners.ListenerEvaluations;
 import org.nd4j.base.Preconditions;
 import org.nd4j.evaluation.IEvaluation;
 import org.nd4j.linalg.learning.config.IUpdater;
@@ -64,6 +63,7 @@ public class TrainingConfig {

     private int iterationCount;
     private int epochCount;
+
     private Map<String,List<IEvaluation>> trainEvaluations = new HashMap<>();
     private Map<String,Integer> trainEvaluationLabels = new HashMap<>();
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java
index fd08e4270..3f4ce04f3 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -17,7 +18,6 @@
 package org.nd4j.evaluation;

 import lombok.EqualsAndHashCode;
-import lombok.Getter;
 import lombok.NonNull;
 import org.nd4j.base.Preconditions;
 import org.nd4j.evaluation.classification.*;
@@ -27,24 +27,13 @@
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo;
 import org.nd4j.linalg.api.shape.Shape;
 import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.primitives.AtomicBoolean;
-import org.nd4j.linalg.primitives.AtomicDouble;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.primitives.Triple;
-import org.nd4j.linalg.primitives.serde.JsonDeserializerAtomicBoolean;
-import org.nd4j.linalg.primitives.serde.JsonDeserializerAtomicDouble;
-import org.nd4j.linalg.primitives.serde.JsonSerializerAtomicBoolean;
-import org.nd4j.linalg.primitives.serde.JsonSerializerAtomicDouble;
 import org.nd4j.linalg.util.ArrayUtil;
-import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
+import org.nd4j.serde.json.JsonMappers;
 import org.nd4j.shade.jackson.core.JsonProcessingException;
-import org.nd4j.shade.jackson.databind.DeserializationFeature;
-import org.nd4j.shade.jackson.databind.MapperFeature;
 import org.nd4j.shade.jackson.databind.ObjectMapper;
-import org.nd4j.shade.jackson.databind.SerializationFeature;
 import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
-import org.nd4j.shade.jackson.databind.module.SimpleModule;
-import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;

 import java.io.IOException;
 import java.io.Serializable;
@@ -60,32 +49,6 @@
 @EqualsAndHashCode
 public abstract class BaseEvaluation<T extends BaseEvaluation> implements IEvaluation<T> {

-    @Getter
-    private static ObjectMapper objectMapper = configureMapper(new ObjectMapper());
-    @Getter
-    private static ObjectMapper yamlMapper = configureMapper(new ObjectMapper(new YAMLFactory()));
-
-    private static ObjectMapper configureMapper(ObjectMapper ret) {
-        ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
-        ret.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
-        ret.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, false);
-        ret.enable(SerializationFeature.INDENT_OUTPUT);
-        SimpleModule atomicModule = new SimpleModule();
-        atomicModule.addSerializer(AtomicDouble.class, new JsonSerializerAtomicDouble());
-        atomicModule.addSerializer(AtomicBoolean.class, new JsonSerializerAtomicBoolean());
-        atomicModule.addDeserializer(AtomicDouble.class, new JsonDeserializerAtomicDouble());
-        atomicModule.addDeserializer(AtomicBoolean.class, new JsonDeserializerAtomicBoolean());
-        ret.registerModule(atomicModule);
-        //Serialize fields only, not using getters
-        ret.setVisibilityChecker(ret.getSerializationConfig().getDefaultVisibilityChecker()
-                .withFieldVisibility(JsonAutoDetect.Visibility.ANY)
-                .withGetterVisibility(JsonAutoDetect.Visibility.NONE)
-                .withSetterVisibility(JsonAutoDetect.Visibility.NONE)
-                .withCreatorVisibility(JsonAutoDetect.Visibility.ANY)
-        );
-        return ret;
-    }
-
     /**
      * @param yaml YAML representation
      * @param clazz Class
@@ -94,7 +57,7 @@ public abstract class BaseEvaluation<T extends BaseEvaluation> implements IEvalu
      */
     public static <T extends IEvaluation> T fromYaml(String yaml, Class<T> clazz) {
         try {
-            return yamlMapper.readValue(yaml, clazz);
+            return JsonMappers.getYamlMapper().readValue(yaml, clazz);
         } catch (IOException e) {
             throw new RuntimeException(e);
         }
@@ -108,7 +71,7 @@
      */
     public static <T extends IEvaluation> T fromJson(String json, Class<T> clazz) {
         try {
-            return objectMapper.readValue(json, clazz);
+            return JsonMappers.getMapper().readValue(json, clazz);
         } catch (InvalidTypeIdException e) {
             if (e.getMessage().contains("Could not resolve type id")) {
                 try {
@@ -332,7 +295,7 @@
     @Override
     public String toJson() {
         try {
-            return objectMapper.writeValueAsString(this);
+            return JsonMappers.getMapper().writeValueAsString(this);
         } catch (JsonProcessingException e) {
             throw new RuntimeException(e);
         }
@@ -349,7 +312,7 @@
     @Override
     public String toYaml() {
         try {
-            return yamlMapper.writeValueAsString(this);
+            return JsonMappers.getYamlMapper().writeValueAsString(this);
         } catch (JsonProcessingException e) {
             throw new RuntimeException(e);
         }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/curves/BaseCurve.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/curves/BaseCurve.java
index ee9339da4..2e61e80bd 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/curves/BaseCurve.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/curves/BaseCurve.java
@@ -17,6 +17,7 @@
 package org.nd4j.evaluation.curves;

 import org.nd4j.evaluation.BaseEvaluation;
+import org.nd4j.serde.json.JsonMappers;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
 import org.nd4j.shade.jackson.core.JsonProcessingException;
@@ -87,7 +88,7 @@ public abstract class BaseCurve {
      */
     public String toJson() {
         try {
-            return BaseEvaluation.getObjectMapper().writeValueAsString(this);
+            return JsonMappers.getMapper().writeValueAsString(this);
         } catch (JsonProcessingException e) {
             throw new RuntimeException(e);
         }
@@ -98,7 +99,7 @@
      */
     public String toYaml() {
         try {
-            return BaseEvaluation.getYamlMapper().writeValueAsString(this);
+            return JsonMappers.getYamlMapper().writeValueAsString(this);
         } catch (JsonProcessingException e) {
             throw new RuntimeException(e);
         }
@@ -113,7 +114,7 @@ public abstract class BaseCurve {
      */
     public static <T extends BaseCurve> T fromJson(String json, Class<T> curveClass) {
         try {
-            return BaseEvaluation.getObjectMapper().readValue(json, curveClass);
+            return JsonMappers.getMapper().readValue(json, curveClass);
         } catch (IOException e) {
             throw new RuntimeException(e);
         }
@@ -128,7 +129,7 @@
      */
     public static <T extends BaseCurve> T fromYaml(String yaml, Class<T> curveClass) {
         try {
-            return BaseEvaluation.getYamlMapper().readValue(yaml, curveClass);
+            return JsonMappers.getYamlMapper().readValue(yaml, curveClass);
         } catch (IOException e) {
             throw new RuntimeException(e);
         }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/curves/BaseHistogram.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/curves/BaseHistogram.java
index a941f2088..1adcc32d0 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/curves/BaseHistogram.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/curves/BaseHistogram.java
@@ -17,6 +17,7 @@
 package org.nd4j.evaluation.curves;

 import org.nd4j.evaluation.BaseEvaluation;
+import org.nd4j.serde.json.JsonMappers;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
 import org.nd4j.shade.jackson.core.JsonProcessingException;
@@ -46,7 +47,7 @@ public abstract class BaseHistogram {
      */
     public String toJson() {
         try {
-            return BaseEvaluation.getObjectMapper().writeValueAsString(this);
+            return JsonMappers.getMapper().writeValueAsString(this);
         } catch (JsonProcessingException e) {
             throw new RuntimeException(e);
         }
@@ -57,7 +58,7 @@
      */
     public String toYaml() {
         try {
-            return BaseEvaluation.getYamlMapper().writeValueAsString(this);
+            return JsonMappers.getYamlMapper().writeValueAsString(this);
         } catch (JsonProcessingException e) {
             throw new RuntimeException(e);
         }
@@ -72,7 +73,7 @@
      */
     public static <T extends BaseHistogram> T fromJson(String json, Class<T> curveClass) {
         try {
-            return BaseEvaluation.getObjectMapper().readValue(json, curveClass);
+            return JsonMappers.getMapper().readValue(json, curveClass);
         } catch (IOException e) {
             throw new RuntimeException(e);
         }
@@ -87,7 +88,7 @@
      */
     public static <T extends BaseHistogram> T fromYaml(String yaml, Class<T> curveClass) {
         try {
-            return BaseEvaluation.getYamlMapper().readValue(yaml, curveClass);
+            return JsonMappers.getYamlMapper().readValue(yaml, curveClass);
         } catch (IOException e) {
             throw new RuntimeException(e);
         }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/serde/ROCSerializer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/serde/ROCSerializer.java
index 236407527..331585ad4 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/serde/ROCSerializer.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/serde/ROCSerializer.java
@@ -36,7 +36,9 @@ public class ROCSerializer extends JsonSerializer<ROC> {
     @Override
     public void serialize(ROC roc, JsonGenerator jsonGenerator, SerializerProvider serializerProvider)
                     throws IOException {
-        if (roc.isExact()) {
+        boolean empty = roc.getExampleCount() == 0;
+
+        if (roc.isExact() && !empty) {
             //For exact ROC implementation: force AUC and AUPRC calculation, so result can be stored in JSON, such
             //that we have them once deserialized.
             //Due to potentially huge size, exact mode doesn't store the original predictions in JSON
@@ -47,9 +49,11 @@ public class ROCSerializer extends JsonSerializer<ROC> {
         jsonGenerator.writeNumberField("countActualPositive", roc.getCountActualPositive());
         jsonGenerator.writeNumberField("countActualNegative", roc.getCountActualNegative());
         jsonGenerator.writeObjectField("counts", roc.getCounts());
-        jsonGenerator.writeNumberField("auc", roc.calculateAUC());
-        jsonGenerator.writeNumberField("auprc", roc.calculateAUCPR());
-        if (roc.isExact()) {
+        if(!empty) {
+            jsonGenerator.writeNumberField("auc", roc.calculateAUC());
+            jsonGenerator.writeNumberField("auprc", roc.calculateAUCPR());
+        }
+        if (roc.isExact() && !empty) {
             //Store ROC and PR curves only for exact mode... they are redundant + can be calculated again for thresholded mode
             jsonGenerator.writeObjectField("rocCurve", roc.getRocCurve());
             jsonGenerator.writeObjectField("prCurve", roc.getPrecisionRecallCurve());
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/json/JsonMappers.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/json/JsonMappers.java
index 81bb46e75..4a1344ae3 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/json/JsonMappers.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/json/JsonMappers.java
@@ -17,10 +17,19 @@
 package org.nd4j.serde.json;

 import lombok.extern.slf4j.Slf4j;
+import org.nd4j.linalg.primitives.AtomicBoolean;
+import org.nd4j.linalg.primitives.AtomicDouble;
+import org.nd4j.linalg.primitives.serde.JsonDeserializerAtomicBoolean;
+import org.nd4j.linalg.primitives.serde.JsonDeserializerAtomicDouble;
+import org.nd4j.linalg.primitives.serde.JsonSerializerAtomicBoolean;
+import org.nd4j.linalg.primitives.serde.JsonSerializerAtomicDouble;
+import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
 import org.nd4j.shade.jackson.databind.DeserializationFeature;
 import org.nd4j.shade.jackson.databind.MapperFeature;
 import org.nd4j.shade.jackson.databind.ObjectMapper;
 import org.nd4j.shade.jackson.databind.SerializationFeature;
+import org.nd4j.shade.jackson.databind.module.SimpleModule;
+import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;

 /**
  * JSON mappers for serializing/deserializing objects
@@ -30,19 +39,41 @@
 @Slf4j
 public class JsonMappers {

-    private static ObjectMapper jsonMapper = new ObjectMapper();
+    private static ObjectMapper jsonMapper = configureMapper(new ObjectMapper());
+    private static ObjectMapper yamlMapper = configureMapper(new ObjectMapper(new YAMLFactory()));

     /**
-     * @return The default/primary ObjectMapper for deserializing JSON network configurations in DL4J
+     * @return The default/primary ObjectMapper for deserializing JSON objects
      */
     public static ObjectMapper getMapper(){
         return jsonMapper;
     }

+    /**
+     * @return The default/primary ObjectMapper for deserializing YAML objects
+     */
+    public static ObjectMapper getYamlMapper(){
+        return yamlMapper;
+    }
+
-    private static void configureMapper(ObjectMapper ret) {
+    private static ObjectMapper configureMapper(ObjectMapper ret) {
         ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
         ret.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
-        ret.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, true);
+        ret.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, false);
         ret.enable(SerializationFeature.INDENT_OUTPUT);
+        SimpleModule atomicModule = new SimpleModule();
+        atomicModule.addSerializer(AtomicDouble.class, new JsonSerializerAtomicDouble());
+        atomicModule.addSerializer(AtomicBoolean.class, new JsonSerializerAtomicBoolean());
+        atomicModule.addDeserializer(AtomicDouble.class, new JsonDeserializerAtomicDouble());
+        atomicModule.addDeserializer(AtomicBoolean.class, new JsonDeserializerAtomicBoolean());
+        ret.registerModule(atomicModule);
+        //Serialize fields only, not using getters
+        ret.setVisibilityChecker(ret.getSerializationConfig().getDefaultVisibilityChecker()
+                .withFieldVisibility(JsonAutoDetect.Visibility.ANY)
+                .withGetterVisibility(JsonAutoDetect.Visibility.NONE)
+                .withSetterVisibility(JsonAutoDetect.Visibility.NONE)
+                .withCreatorVisibility(JsonAutoDetect.Visibility.ANY)
+        );
+        return ret;
     }
 }
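Taken together, a round-trip sketch of what this commit enables (illustrative only; class names as in the diffs above, an ND4J runtime assumed). Before the empty guard in ROCSerializer, serializing an evaluation that had not yet seen any examples could fail when AUC calculation was forced:

    ROC roc = new ROC();                                   // empty - no examples recorded yet
    String json = roc.toJson();                            // now safe on an empty instance
    ROC restored = BaseEvaluation.fromJson(json, ROC.class);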
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java
index e10ffcddb..db8d7d551 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java
@@ -43,6 +43,9 @@
 import org.nd4j.autodiff.samediff.api.OutAndGrad;
 import org.nd4j.autodiff.samediff.impl.DefaultSameDiffConditional;
 import org.nd4j.autodiff.validation.OpValidation;
 import org.nd4j.autodiff.validation.TestCase;
+import org.nd4j.evaluation.IEvaluation;
+import org.nd4j.evaluation.classification.*;
+import org.nd4j.evaluation.regression.RegressionEvaluation;
 import org.nd4j.linalg.BaseNd4jTest;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.api.blas.params.MMulTranspose;
@@ -3501,4 +3504,20 @@
         Map<String,INDArray> map = sd.calculateGradients(null,"input", "concat");
         assertEquals(map.get("input"), map.get("concat"));
     }
+
+    @Test
+    public void testTrainingConfigJson(){
+        for(IEvaluation e : new IEvaluation[]{new Evaluation(), new RegressionEvaluation(), new EvaluationBinary(), new ROC(),
+                new ROCMultiClass(), new ROCBinary(), new EvaluationCalibration()}) {
+            TrainingConfig config = new TrainingConfig.Builder()
+                    .l2(1e-4)
+                    .updater(new Adam(0.1))
+                    .dataSetFeatureMapping("out").dataSetLabelMapping("label")
+                    .trainEvaluation("out", 0, e)
+                    .build();
+            String json = config.toJson();
+            TrainingConfig fromJson = TrainingConfig.fromJson(json);
+            assertEquals(config, fromJson);
+        }
+    }
 }
From 4ada65b3845827d0213d3c1c3909e54eff45e849 Mon Sep 17 00:00:00 2001
From: raver119
Date: Sat, 30 Nov 2019 16:02:07 +0300
Subject: [PATCH 17/30] [WIP] MSVC-related tests fixes (#88)

* fix narrowing down cast
Signed-off-by: raver119

* trigger jenkins
Signed-off-by: raver119

* few more fixes for MSVC and Windows
Signed-off-by: raver119

* few more fixes for MSVC and Windows
Signed-off-by: raver119

* few more fixes for MSVC and Windows
Signed-off-by: raver119

* few more fixes for MSVC and Windows
Signed-off-by: raver119

* few more tweaks
Signed-off-by: raver119

* few more tweaks
Signed-off-by: raver119

* few more tweaks
Signed-off-by: raver119

* few more tweaks
Signed-off-by: raver119

* few more tweaks
Signed-off-by: raver119

* - few more tweaks
- tensormmul dtype validation
Signed-off-by: raver119

* - few more tweaks
- batched gemm dtype validation
Signed-off-by: raver119

* - few more tweaks
Signed-off-by: raver119

* - few more tweaks
Signed-off-by: raver119

* - few more tweaks
Signed-off-by: raver119

* - few more tweaks
Signed-off-by: raver119
---
 .../declarable/generic/blas/batched_gemm.cpp  |  16 +-
 .../declarable/generic/blas/tensormmul.cpp    |  11 +-
 libnd4j/tests_cpu/layers_tests/CMakeLists.txt |   5 +-
 .../layers_tests/ConvolutionTests1.cpp        | 385 ++++++++-------
 .../layers_tests/ConvolutionTests2.cpp        | 385 ++++++++-------
 .../layers_tests/DeclarableOpsTests10.cpp     |  32 +-
 .../layers_tests/DeclarableOpsTests12.cpp     | 136 +++---
 .../layers_tests/DeclarableOpsTests13.cpp     |  20 +-
 .../layers_tests/DeclarableOpsTests15.cpp     |  34 +-
 .../layers_tests/DeclarableOpsTests16.cpp     |   2 +-
 .../layers_tests/DeclarableOpsTests2.cpp      | 194 ++++----
 .../layers_tests/DeclarableOpsTests3.cpp      | 452 +++++++++---------
 .../layers_tests/DeclarableOpsTests4.cpp      | 144 +++---
 .../layers_tests/DeclarableOpsTests5.cpp      |  52 +-
 .../layers_tests/DeclarableOpsTests6.cpp      | 136 +++---
 .../layers_tests/DeclarableOpsTests7.cpp      |  56 +-
 .../layers_tests/DeclarableOpsTests8.cpp      |  96 ++--
 .../layers_tests/DeclarableOpsTests9.cpp      |  12 +-
 .../layers_tests/JavaInteropTests.cpp         |  24 +-
 .../tests_cpu/layers_tests/NDArrayTests2.cpp  | 162 +++----
 .../tests_cpu/layers_tests/NativeOpsTests.cpp |  40 +-
 .../tests_cpu/layers_tests/PairwiseTests.cpp  |   6 +-
 .../tests_cpu/layers_tests/ReduceTests.cpp    |  24 +-
 23 files changed, 1209 insertions(+), 1215 deletions(-)

diff --git a/libnd4j/include/ops/declarable/generic/blas/batched_gemm.cpp b/libnd4j/include/ops/declarable/generic/blas/batched_gemm.cpp
index 6812b287b..67a839e7b 100644
--- a/libnd4j/include/ops/declarable/generic/blas/batched_gemm.cpp
+++ b/libnd4j/include/ops/declarable/generic/blas/batched_gemm.cpp
@@ -110,19 +110,7 @@ DECLARE_SHAPE_FN(batched_gemm) {
     auto shapeList = SHAPELIST();

     if (!(M > 0 && N > 0 && K > 0 && ldA > 0 && ldB > 0 && ldC > 0 && batchSize > 0)) {
-        Nd4jLong *newShape;
-        ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong);
-
-        newShape[0] = 2;
-        newShape[1] = 1;
-        newShape[2] = 1;
-        newShape[3] = 1;
-        newShape[4] = 1;
-        newShape[5] = 0;
-        newShape[6] = 1;
-        newShape[7] = 99;
-
-        shapeList->push_back(newShape);
+        shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(0)), 'c', {1, 1}));
         return shapeList;
     }
@@ -130,7 +118,7 @@ DECLARE_SHAPE_FN(batched_gemm) {
     std::vector<Nd4jLong> shape({M, N});

     for (int e = 0; e < batchSize; e++) {
-        auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), 'f', shape);
+        auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(0)), 'f', shape);
         shapeList->push_back(newShape);
     }
diff --git a/libnd4j/include/ops/declarable/generic/blas/tensormmul.cpp b/libnd4j/include/ops/declarable/generic/blas/tensormmul.cpp
index 78d78712f..2c362b23d 100644
--- a/libnd4j/include/ops/declarable/generic/blas/tensormmul.cpp
+++ b/libnd4j/include/ops/declarable/generic/blas/tensormmul.cpp
@@ -31,7 +31,9 @@ namespace nd4j {
         auto a = INPUT_VARIABLE(0);
         auto b = INPUT_VARIABLE(1);

-        auto c = OUTPUT_VARIABLE(0); //
+        auto c = OUTPUT_VARIABLE(0); //
+
+        REQUIRE_TRUE(a->dataType() == b->dataType(), 0, "tensormmul: A, B and C data types must be the same");

         // building axes
         int axe0_size = INT_ARG(0);
         int axe1_size = INT_ARG(axe0_size+1);
@@ -54,7 +56,10 @@ DECLARE_SHAPE_FN(tensormmul) {
     auto aShapeInfo = inputShape->at(0);
-    auto bShapeInfo = inputShape->at(1);
+    auto bShapeInfo = inputShape->at(1);
+
+    REQUIRE_TRUE(ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo), 0, "tensormmul: A and B data types must be the same");
+
     // building axes
     int axe0_size = INT_ARG(0);
     int axe1_size = INT_ARG(axe0_size+1);
@@ -70,7 +75,7 @@
     std::vector shapeAt, shapeBt;
     auto outShape = nd4j::ShapeUtils::evalShapeForTensorDot(aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);

-    return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(block.dataType(), 'c', outShape)));
+    return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(aShapeInfo), 'c', outShape)));
 }

 DECLARE_TYPES(tensormmul) {
diff --git a/libnd4j/tests_cpu/layers_tests/CMakeLists.txt b/libnd4j/tests_cpu/layers_tests/CMakeLists.txt
index 8a58fe3a5..1d5a1df98 100644
--- a/libnd4j/tests_cpu/layers_tests/CMakeLists.txt
+++ b/libnd4j/tests_cpu/layers_tests/CMakeLists.txt
@@ -45,7 +45,10 @@ endif()
 if (APPLE)
     set(CMAKE_CXX_FLAGS  " -fPIC -std=c++11 -fmax-errors=2 -D__APPLE_OS__=true")
 elseif(WIN32)
-    set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -march=native -mtune=native -O3")
+    if (CPU_BLAS)
+        set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -fPIC -march=native -mtune=native -O3")
+    endif()
+
 if (CPU_BLAS AND LINUX)
     set(CMAKE_CXX_FLAGS  " -fPIC -std=c++11 -fmax-errors=2")
 endif()
diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp
index bb4fe7b3c..b3552a00f 100644
--- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp
+++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp
@@ -134,7 +134,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_1) {
 TYPED_TEST(TypedConvolutionTests1, conv2d_2) {
     auto input = NDArrayFactory::create('c', {1, 1, 1, 4});
     auto weights = NDArrayFactory::create('c', {1, 1, 1, 4});
-    auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.});
+    auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f});

     weights.assign(2.0);
     input.linspace(1);
@@ -161,7 +161,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_3) {
     auto input = NDArrayFactory::create('c', {bS, iH, iW, iC});
     auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC});
-    auto bias = NDArrayFactory::create('c', {oC}, {1,2,3});
+    auto bias = NDArrayFactory::create('c', {oC}, {1.f, 2.f, 3.f});

     auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{ 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f,
@@ -762,10 +762,10 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) {
     auto input = NDArrayFactory::create('c', {2, 2, 6});
     auto weights = NDArrayFactory::create('c', {2, 2, 3}, {1,5,9,3,7,11,2,6,10,4,8,12});
     auto bias = NDArrayFactory::create('c', {3});
-    auto expFF = NDArrayFactory::create('c', {2, 3, 5}, {59.0, 69.0, 79.0, 89.0, 99.0, 132.0, 158.0, 184.0, 210.0, 236.0, 205.0, 247.0, 289.0, 331.0, 373.0, 179.0, 189.0, 199.0, 209.0, 219.0, 444.0, 470.0, 496.0, 522.0, 548.0, 709.0, 751.0, 793.0, 835.0, 877.0});
-    auto expEps = NDArrayFactory::create('c', {2, 2, 6}, {130.0, 293.0, 326.0, 359.0, 392.0, 220.0, 166.0, 371.0, 416.0, 461.0, 506.0, 280.0, 355.0, 788.0, 821.0, 854.0, 887.0, 490.0, 481.0, 1046.0, 1091.0, 1136.0, 1181.0, 640.0});
-    auto expGW = NDArrayFactory::create('c', {3, 2, 2}, {1415.0, 1520.0, 2045.0, 2150.0, 1865.0, 2020.0, 2795.0, 2950.0, 2315.0, 2520.0, 3545.0, 3750.0});
-    auto expGB = NDArrayFactory::create('c', {3}, {105.0, 155.0, 205.0});
+    auto expFF = NDArrayFactory::create('c', {2, 3, 5}, {59.0f, 69.0f, 79.0f, 89.0f, 99.0f, 132.0f, 158.0f, 184.0f, 210.0f, 236.0f, 205.0f, 247.0f, 289.0f, 331.0f, 373.0f, 179.0f, 189.0f, 199.0f, 209.0f, 219.0f, 444.0f, 470.0f, 496.0f, 522.0f, 548.0f, 709.0f, 751.0f, 793.0f, 835.0f, 877.0f});
+    auto expEps = NDArrayFactory::create('c', {2, 2, 6}, {130.0f, 293.0f, 326.0f, 359.0f, 392.0f, 220.0f, 166.0f, 371.0f, 416.0f, 461.0f, 506.0f, 280.0f, 355.0f, 788.0f, 821.0f, 854.0f, 887.0f, 490.0f, 481.0f, 1046.0f, 1091.0f, 1136.0f, 1181.0f, 640.0f});
+    auto expGW = NDArrayFactory::create('c', {3, 2, 2}, {1415.0f, 1520.0f, 2045.0f, 2150.0f, 1865.0f, 2020.0f, 2795.0f, 2950.0f, 2315.0f, 2520.0f, 3545.0f, 3750.0f});
+    auto expGB = NDArrayFactory::create('c', {3}, {105.0f, 155.0f, 205.0f});
     expGW.permutei({2,1,0});

     input.linspace(1);
@@ -809,7 +809,7 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) {
 TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_2) {
     auto input = NDArrayFactory::create('c', {2, 2, 6});
-    auto weights = NDArrayFactory::create('c', {2, 2, 3}, {1,5,9,3,7,11,2,6,10,4,8,12});
+    auto weights = NDArrayFactory::create('c', {2, 2, 3}, {1.f, 5.f, 9.f, 3.f, 7.f, 11.f, 2.f, 6.f, 10.f, 4.f, 8.f, 12.f});

     input.linspace(1);
@@ -1164,7 +1164,6 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test2) {

 //////////////////////////////////////////////////////////////////////
 TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test3) {
-
     int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
     int oH=2,oW=2;
     int paddingMode = 0; // 1-SAME, 0-VALID;
@@ -1175,16 +1174,16 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test3) {
     auto bias = NDArrayFactory::create('c', {oC}, {1,2,3});
     auto gradO = NDArrayFactory::create('c', {bS, oC, oH, oW});

-    auto expGradI = NDArrayFactory::create('c', {bS, iC, iH, iW},{ 0.567, 1.224,0.66 ,1.314, 2.82 ,1.512,1.386, 2.976,1.596,0.801, 1.71 ,0.912,0.657, 1.422,0.768,1.53 , 3.288,1.764,1.602, 3.444,1.848,0.927, 1.98 ,1.056,
-                                                                   0.747, 1.62 ,0.876,1.746, 3.756,2.016,1.818, 3.912,2.1 ,1.053, 2.25 ,1.2 ,0.837, 1.818,0.984,1.962, 4.224,2.268,2.034, 4.38 ,2.352,1.179, 2.52 ,1.344,
-                                                                   1.467, 3.06 ,1.596,3.186, 6.636,3.456,3.402, 7.08 ,3.684,1.845, 3.834,1.992,1.773, 3.69 ,1.92 ,3.834, 7.968,4.14 ,4.05 , 8.412,4.368,2.187, 4.536,2.352,
-                                                                   2.079, 4.32 ,2.244,4.482, 9.3 ,4.824,4.698, 9.744,5.052,2.529, 5.238,2.712,2.385, 4.95 ,2.568,5.13 ,10.632,5.508,5.346,11.076,5.736,2.871, 5.94 ,3.072});
+    auto expGradI = NDArrayFactory::create('c', {bS, iC, iH, iW},{ 0.567f, 1.224f, 0.66f, 1.314f, 2.82f, 1.512f, 1.386f, 2.976f, 1.596f, 0.801f, 1.71f, 0.912f, 0.657f, 1.422f, 0.768f, 1.53f, 3.288f, 1.764f, 1.602f, 3.444f, 1.848f, 0.927f, 1.98f, 1.056f,
+                                                                   0.747f, 1.62f, 0.876f, 1.746f, 3.756f, 2.016f, 1.818f, 3.912f, 2.1f, 1.053f, 2.25f, 1.2f, 0.837f, 1.818f, 0.984f, 1.962f, 4.224f, 2.268f, 2.034f, 4.38f, 2.352f, 1.179f, 2.52f, 1.344f,
+                                                                   1.467f, 3.06f, 1.596f, 3.186f, 6.636f, 3.456f, 3.402f, 7.08f, 3.684f, 1.845f, 3.834f, 1.992f, 1.773f, 3.69f, 1.92f, 3.834f, 7.968f, 4.14f, 4.05f, 8.412f, 4.368f, 2.187f, 4.536f, 2.352f,
+                                                                   2.079f, 4.32f, 2.244f, 4.482f, 9.3f, 4.824f, 4.698f, 9.744f, 5.052f, 2.529f, 5.238f, 2.712f, 2.385f, 4.95f, 2.568f, 5.13f, 10.632f, 5.508f, 5.346f, 11.076f, 5.736f, 2.871f, 5.94f, 3.072f});

-    auto expGradW = NDArrayFactory::create('c', {oC, iC, kH, kW},{1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,
-                                                                  1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,
-                                                                  2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,
-                                                                  2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00});
-    auto expGradB = NDArrayFactory::create('c', {oC},{0.68, 1., 1.32});
+    auto expGradW = NDArrayFactory::create('c', {oC, iC, kH, kW},{1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f,
+                                                                  1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f,
+                                                                  2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f,
+                                                                  2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f});
+    auto expGradB = NDArrayFactory::create('c', {oC},{0.68f, 1.f, 1.32f});

     input = 2.;
     weights.linspace(0.1, 0.1);
@@ -1253,21 +1252,21 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test1) {
     auto bias = NDArrayFactory::create('c', {oC}, {1,2,3});
     auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, oC});

-    auto expGradI = NDArrayFactory::create('c', {bS, iD, iH, iW, iC},{0.226, 0.343, 0.46 , 0.577, 1.172, 1.46 , 1.748, 2.036, 1.892, 2.288, 2.684, 3.08 , 1.284, 1.581, 1.878, 2.175, 4.458, 5.133, 5.808, 6.483, 6.186, 7.023, 7.86 , 8.697, 3.39 , 3.93 , 4.47 , 5.01 , 9.642, 10.803, 11.964, 13.125, 11.37 , 12.693, 14.016, 15.339,
-                                                                      5.266, 5.707, 6.148, 6.589, 12.98 , 13.916, 14.852, 15.788, 14.564, 15.608, 16.652, 17.696, 6.284, 7.166, 8.048, 8.93 , 17.896, 19.768, 21.64 , 23.512, 21.928, 24.016, 26.104, 28.192, 18.12 , 19.686, 21.252, 22.818, 45.852, 49.146, 52.44 , 55.734, 53.196, 56.814, 60.432, 64.05 ,
-                                                                      28.164, 30.216, 32.268, 34.32 , 67.884, 72.15 , 76.416, 80.682, 75.228, 79.818, 84.408, 88.998, 29.324, 30.854, 32.384, 33.914, 67.432, 70.6 , 73.768, 76.936, 73.192, 76.576, 79.96 , 83.344, 27.884, 30.062, 32.24 , 34.418, 66.28 , 70.744, 75.208, 79.672, 70.312, 74.992, 79.672, 84.352,
-                                                                      58.296, 61.806, 65.316, 68.826,133.98 , 141.162, 148.344, 155.526,141.324, 148.83 , 156.336, 163.842, 68.34 , 72.336, 76.332, 80.328,156.012, 164.166, 172.32 , 180.474,163.356, 171.834, 180.312, 188.79 , 61.292, 64.118, 66.944, 69.77 ,136.552, 142.312, 148.072, 153.832,142.312, 148.288, 154.264, 160.24 ,
-                                                                      9.298, 11.359, 13.42 , 15.481, 27.092, 31.268, 35.444, 39.62 , 27.812, 32.096, 36.38 , 40.664, 26.556, 29.769, 32.982, 36.195, 66.666, 73.173, 79.68 , 86.187, 68.394, 75.063, 81.732, 88.401, 28.662, 32.118, 35.574, 39.03 , 71.85 , 78.843, 85.836, 92.829, 73.578, 80.733, 87.888, 95.043,
-                                                                      29.89 , 32.275, 34.66 , 37.045, 70.004, 74.828, 79.652, 84.476, 71.588, 76.52 , 81.452, 86.384, 71.084, 75.854, 80.624, 85.394,163.048, 172.696, 182.344, 191.992,167.08 , 176.944, 186.808, 196.672,138.648, 146.046, 153.444, 160.842,310.236, 325.194, 340.152, 355.11 ,317.58 , 332.862, 348.144, 363.426,
-                                                                      148.692, 156.576, 164.46 , 172.344,332.268, 348.198, 364.128, 380.058,339.612, 355.866, 372.12 , 388.374,125.228, 130.646, 136.064, 141.482,274.792, 285.736, 296.68 , 307.624,280.552, 291.712, 302.872, 314.032, 92.684, 98.75 , 104.816, 110.882,211.432, 223.672, 235.912, 248.152,215.464, 227.92 , 240.376, 252.832,
-                                                                      178.824, 188.166, 197.508, 206.85 ,398.364, 417.21 , 436.056, 454.902,405.708, 424.878, 444.048, 463.218,188.868, 198.696, 208.524, 218.352,420.396, 440.214, 460.032, 479.85 ,427.74 , 447.882, 468.024, 488.166,157.196, 163.91 , 170.624, 177.338,343.912, 357.448, 370.984, 384.52 ,349.672, 363.424, 377.176, 390.928});
+    auto expGradI = NDArrayFactory::create('c', {bS, iD, iH, iW, iC},{0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, 2.036f, 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, 1.878f, 2.175f, 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, 7.023f, 7.86f, 8.697f, 3.39f, 3.93f, 4.47f, 5.01f, 9.642f, 10.803f, 11.964f, 13.125f, 11.37f, 12.693f, 14.016f, 15.339f,
+                                                                      5.266f, 5.707f, 6.148f, 6.589f, 12.98f, 13.916f, 14.852f, 15.788f, 14.564f, 15.608f, 16.652f, 17.696f, 6.284f, 7.166f, 8.048f, 8.93f, 17.896f, 19.768f, 21.64f, 23.512f, 21.928f, 24.016f, 26.104f, 28.192f, 18.12f, 19.686f, 21.252f, 22.818f, 45.852f, 49.146f, 52.44f, 55.734f, 53.196f, 56.814f, 60.432f, 64.05f,
+                                                                      28.164f, 30.216f, 32.268f, 34.32f, 67.884f, 72.15f, 76.416f, 80.682f, 75.228f, 79.818f, 84.408f, 88.998f, 29.324f, 30.854f, 32.384f, 33.914f, 67.432f, 70.6f, 73.768f, 76.936f, 73.192f, 76.576f, 79.96f, 83.344f, 27.884f, 30.062f, 32.24f, 34.418f, 66.28f, 70.744f, 75.208f, 79.672f, 70.312f, 74.992f, 79.672f, 84.352f,
+                                                                      58.296f, 61.806f, 65.316f, 68.826f, 133.98f, 141.162f, 148.344f, 155.526f, 141.324f, 148.83f, 156.336f, 163.842f, 68.34f, 72.336f, 76.332f, 80.328f, 156.012f, 164.166f, 172.32f, 180.474f, 163.356f, 171.834f, 180.312f, 188.79f, 61.292f, 64.118f, 66.944f, 69.77f, 136.552f, 142.312f, 148.072f, 153.832f, 142.312f, 148.288f, 154.264f, 160.24f,
+                                                                      9.298f, 11.359f, 13.42f, 15.481f, 27.092f, 31.268f, 35.444f, 39.62f, 27.812f, 32.096f, 36.38f, 40.664f, 26.556f, 29.769f, 32.982f, 36.195f, 66.666f, 73.173f, 79.68f, 86.187f, 68.394f, 75.063f, 81.732f, 88.401f, 28.662f, 32.118f, 35.574f, 39.03f, 71.85f, 78.843f, 85.836f, 92.829f, 73.578f, 80.733f, 87.888f, 95.043f,
+                                                                      29.89f, 32.275f, 34.66f, 37.045f, 70.004f, 74.828f, 79.652f, 84.476f, 71.588f, 76.52f, 81.452f, 86.384f, 71.084f, 75.854f, 80.624f, 85.394f, 163.048f, 172.696f, 182.344f, 191.992f, 167.08f, 176.944f, 186.808f, 196.672f, 138.648f, 146.046f, 153.444f, 160.842f, 310.236f, 325.194f, 340.152f, 355.11f, 317.58f, 332.862f, 348.144f, 363.426f,
+                                                                      148.692f, 156.576f, 164.46f, 172.344f, 332.268f, 348.198f, 364.128f, 380.058f, 339.612f, 355.866f, 372.12f, 388.374f, 125.228f, 130.646f, 136.064f, 141.482f, 274.792f, 285.736f, 296.68f, 307.624f, 280.552f, 291.712f, 302.872f, 314.032f, 92.684f, 98.75f, 104.816f, 110.882f, 211.432f, 223.672f, 235.912f, 248.152f, 215.464f, 227.92f, 240.376f, 252.832f,
+                                                                      178.824f, 188.166f, 197.508f, 206.85f, 398.364f, 417.21f, 436.056f, 454.902f, 405.708f, 424.878f, 444.048f, 463.218f, 188.868f, 198.696f, 208.524f, 218.352f, 420.396f, 440.214f, 460.032f, 479.85f, 427.74f, 447.882f, 468.024f, 488.166f, 157.196f, 163.91f, 170.624f, 177.338f, 343.912f, 357.448f, 370.984f, 384.52f, 349.672f, 363.424f,
377.176f, 390.928f}); - auto expGradW = NDArrayFactory::create('c', {kD, kH, kW, iC, oC},{120.96, 122.04, 123.12,120.96, 122.04, 123.12,120.96, 122.04, 123.12,120.96, 122.04, 123.12, 79.56, 80.28, 81. , 79.56, 80.28, 81. , 79.56, 80.28, 81. , 79.56, 80.28, 81. , - 154.8 , 156.24, 157.68,154.8 , 156.24, 157.68,154.8 , 156.24, 157.68,154.8 , 156.24, 157.68,101.76, 102.72, 103.68,101.76, 102.72, 103.68,101.76, 102.72, 103.68,101.76, 102.72, 103.68, - 111.24, 112.32, 113.4 ,111.24, 112.32, 113.4 ,111.24, 112.32, 113.4 ,111.24, 112.32, 113.4 , 73.08, 73.8 , 74.52, 73.08, 73.8 , 74.52, 73.08, 73.8 , 74.52, 73.08, 73.8 , 74.52, - 67.68, 68.4 , 69.12, 67.68, 68.4 , 69.12, 67.68, 68.4 , 69.12, 67.68, 68.4 , 69.12, 44.4 , 44.88, 45.36, 44.4 , 44.88, 45.36, 44.4 , 44.88, 45.36, 44.4 , 44.88, 45.36, - 85.92, 86.88, 87.84, 85.92, 86.88, 87.84, 85.92, 86.88, 87.84, 85.92, 86.88, 87.84, 56.32, 56.96, 57.6 , 56.32, 56.96, 57.6 , 56.32, 56.96, 57.6 , 56.32, 56.96, 57.6 , - 61.2 , 61.92, 62.64, 61.2 , 61.92, 62.64, 61.2 , 61.92, 62.64, 61.2 , 61.92, 62.64, 40.08, 40.56, 41.04, 40.08, 40.56, 41.04, 40.08, 40.56, 41.04, 40.08, 40.56, 41.04}); + auto expGradW = NDArrayFactory::create('c', {kD, kH, kW, iC, oC},{120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, + 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, + 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, + 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, + 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, + 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f}); // auto expGradB('c', {oC},{}); input = 2.; @@ -1303,19 +1302,19 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test2) { auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); - auto expGradI = NDArrayFactory::create('c', {bS, iD, iH, iW, iC},{ 0.014, 0.032, 0.05 , 0.068, 0.118, 0.181, 0.244, 0.307, 0.212, 0.257, 0.302, 0.347, 0.208, 0.298, 0.388, 0.478, 1.028, 1.262, 1.496, 1.73 , 1.036, 1.18 , 1.324, 1.468, 0.928, 1.018, 1.108, 1.198, 2.9 , 3.134, 3.368, 3.602, 2.188, 2.332, 2.476, 2.62 , - 1.202, 1.274, 1.346, 1.418, 3.142, 3.313, 3.484, 3.655, 2.048, 2.147, 2.246, 2.345, 0.532, 0.676, 0.82 , 0.964, 2.324, 2.666, 3.008, 3.35 , 2.008, 2.206, 2.404, 2.602, 3.584, 3.98 , 4.376, 4.772,10.552,11.452,12.352,13.252, 7.4 , 7.904, 8.408, 8.912, - 6.752, 7.148, 7.544, 7.94 ,17.752,18.652,19.552,20.452,11.432,11.936,12.44 ,12.944, 5.932, 6.184, 6.436, 6.688,14.42 ,14.978,15.536,16.094, 8.704, 9.01 , 9.316, 9.622, 3.11 , 3.236, 3.362, 3.488, 7.39 , 7.669, 7.948, 8.227, 4.388, 4.541, 4.694, 4.847, - 8.56 , 8.866, 9.172, 9.478,19.892,20.558,21.224,21.89 
,11.548,11.908,12.268,12.628,11.008,11.314,11.62 ,11.926,25.22 ,25.886,26.552,27.218,14.428,14.788,15.148,15.508, 7.322, 7.502, 7.682, 7.862,16.462,16.849,17.236,17.623, 9.248, 9.455, 9.662, 9.869, - 0.158, 0.392, 0.626, 0.86 , 1.27 , 1.765, 2.26 , 2.755, 1.22 , 1.481, 1.742, 2.003, 2.224, 2.746, 3.268, 3.79 , 6.788, 7.886, 8.984,10.082, 4.78 , 5.356, 5.932, 6.508, 6.4 , 6.922, 7.444, 7.966,15.572,16.67 ,17.768,18.866, 9.388, 9.964,10.54 ,11.116, - 4.802, 5.09 , 5.378, 5.666,11.206,11.809,12.412,13.015, 6.512, 6.827, 7.142, 7.457, 6.004, 6.58 , 7.156, 7.732,14.996,16.202,17.408,18.614, 9.208, 9.838,10.468,11.098,17.984,19.244,20.504,21.764,42.808,45.436,48.064,50.692,25.256,26.624,27.992,29.36 , - 28.064,29.324,30.584,31.844,63.832,66.46 ,69.088,71.716,36.2 ,37.568,38.936,40.304,18.316,19. ,19.684,20.368,40.916,42.338,43.76 ,45.182,22.816,23.554,24.292,25.03 , 8.438, 8.78 , 9.122, 9.464,18.91 ,19.621,20.332,21.043,10.58 ,10.949,11.318,11.687, - 20.944,21.682,22.42 ,23.158,46.388,47.918,49.448,50.978,25.66 ,26.452,27.244,28.036,26.848,27.586,28.324,29.062,58.628,60.158,61.688,63.218,31.996,32.788,33.58 ,34.372,16.106,16.502,16.898,17.294,34.894,35.713,36.532,37.351,18.896,19.319,19.742,20.165}); + auto expGradI = NDArrayFactory::create('c', {bS, iD, iH, iW, iC},{ 0.014f, 0.032f, 0.05f, 0.068f, 0.118f, 0.181f, 0.244f, 0.307f, 0.212f, 0.257f, 0.302f, 0.347f, 0.208f, 0.298f, 0.388f, 0.478f, 1.028f, 1.262f, 1.496f, 1.73f, 1.036f, 1.18f, 1.324f, 1.468f, 0.928f, 1.018f, 1.108f, 1.198f, 2.9f, 3.134f, 3.368f, 3.602f, 2.188f, 2.332f, 2.476f, 2.62f, + 1.202f, 1.274f, 1.346f, 1.418f, 3.142f, 3.313f, 3.484f, 3.655f, 2.048f, 2.147f, 2.246f, 2.345f, 0.532f, 0.676f, 0.82f, 0.964f, 2.324f, 2.666f, 3.008f, 3.35f, 2.008f, 2.206f, 2.404f, 2.602f, 3.584f, 3.98f, 4.376f, 4.772f, 10.552f, 11.452f, 12.352f, 13.252f, 7.4f, 7.904f, 8.408f, 8.912f, + 6.752f, 7.148f, 7.544f, 7.94f, 17.752f, 18.652f, 19.552f, 20.452f, 11.432f, 11.936f, 12.44f, 12.944f, 5.932f, 6.184f, 6.436f, 6.688f, 14.42f, 14.978f, 15.536f, 16.094f, 8.704f, 9.01f, 9.316f, 9.622f, 3.11f, 3.236f, 3.362f, 3.488f, 7.39f, 7.669f, 7.948f, 8.227f, 4.388f, 4.541f, 4.694f, 4.847f, + 8.56f, 8.866f, 9.172f, 9.478f, 19.892f, 20.558f, 21.224f, 21.89f, 11.548f, 11.908f, 12.268f, 12.628f, 11.008f, 11.314f, 11.62f, 11.926f, 25.22f, 25.886f, 26.552f, 27.218f, 14.428f, 14.788f, 15.148f, 15.508f, 7.322f, 7.502f, 7.682f, 7.862f, 16.462f, 16.849f, 17.236f, 17.623f, 9.248f, 9.455f, 9.662f, 9.869f, + 0.158f, 0.392f, 0.626f, 0.86f, 1.27f, 1.765f, 2.26f, 2.755f, 1.22f, 1.481f, 1.742f, 2.003f, 2.224f, 2.746f, 3.268f, 3.79f, 6.788f, 7.886f, 8.984f, 10.082f, 4.78f, 5.356f, 5.932f, 6.508f, 6.4f, 6.922f, 7.444f, 7.966f, 15.572f, 16.67f, 17.768f, 18.866f, 9.388f, 9.964f, 10.54f, 11.116f, + 4.802f, 5.09f, 5.378f, 5.666f, 11.206f, 11.809f, 12.412f, 13.015f, 6.512f, 6.827f, 7.142f, 7.457f, 6.004f, 6.58f, 7.156f, 7.732f, 14.996f, 16.202f, 17.408f, 18.614f, 9.208f, 9.838f, 10.468f, 11.098f, 17.984f, 19.244f, 20.504f, 21.764f, 42.808f, 45.436f, 48.064f, 50.692f, 25.256f, 26.624f, 27.992f, 29.36f, + 28.064f, 29.324f, 30.584f, 31.844f, 63.832f, 66.46f, 69.088f, 71.716f, 36.2f, 37.568f, 38.936f, 40.304f, 18.316f, 19.f, 19.684f, 20.368f, 40.916f, 42.338f, 43.76f, 45.182f, 22.816f, 23.554f, 24.292f, 25.03f, 8.438f, 8.78f, 9.122f, 9.464f, 18.91f, 19.621f, 20.332f, 21.043f, 10.58f, 10.949f, 11.318f, 11.687f, + 20.944f, 21.682f, 22.42f, 23.158f, 46.388f, 47.918f, 49.448f, 50.978f, 25.66f, 26.452f, 27.244f, 28.036f, 26.848f, 27.586f, 28.324f, 29.062f, 58.628f, 60.158f, 61.688f, 63.218f, 
31.996f, 32.788f, 33.58f, 34.372f, 16.106f, 16.502f, 16.898f, 17.294f, 34.894f, 35.713f, 36.532f, 37.351f, 18.896f, 19.319f, 19.742f, 20.165f}); - auto expGradW = NDArrayFactory::create('c', {kD, kH, kW, iC, oC},{7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16, - 7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16, - 7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16, - 7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16}); + auto expGradW = NDArrayFactory::create('c', {kD, kH, kW, iC, oC},{7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, + 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, + 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, + 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f}); // auto expGradB('c', {oC},{}); input = 2.; @@ -1351,23 +1350,23 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test3) { auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); auto gradO = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); - auto expGradI = NDArrayFactory::create('c', {bS, iC, iD, iH, iW},{2.091, 4.356, 2.268, 4.53 , 9.42 , 4.896, 4.65 , 9.672, 5.028, 2.517, 5.226, 2.712, 4.932,10.242, 5.316,10.62 ,22.02 ,11.412,10.908,22.62 ,11.724, 5.868,12.15 , 6.288, 2.913, 6.03 , 3.12 , 6.234,12.888, 6.66 , 6.402,13.236, 6.84 , 3.423, 7.068, 3.648, - 2.415, 5.04 , 2.628, 5.25 ,10.932, 5.688, 5.37 ,11.184, 5.82 , 2.913, 6.054, 3.144, 5.724,11.898, 6.18 ,12.348,25.62 ,13.284,12.636,26.22 ,13.596, 6.804,14.094, 7.296, 3.381, 7.002, 3.624, 7.242,14.976, 7.74 , 7.41 ,15.324, 7.92 , 3.963, 8.184, 4.224, - 2.739, 5.724, 2.988, 5.97 ,12.444, 6.48 , 6.09 ,12.696, 6.612, 3.309, 6.882, 3.576, 6.516,13.554, 7.044,14.076,29.22 ,15.156,14.364,29.82 ,15.468, 7.74 ,16.038, 8.304, 3.849, 7.974, 4.128, 8.25 ,17.064, 8.82 , 8.418,17.412, 9. 
, 4.503, 9.3 , 4.8 , - 3.063, 6.408, 3.348, 6.69 ,13.956, 7.272, 6.81 ,14.208, 7.404, 3.705, 7.71 , 4.008, 7.308,15.21 , 7.908,15.804,32.82 ,17.028,16.092,33.42 ,17.34 , 8.676,17.982, 9.312, 4.317, 8.946, 4.632, 9.258,19.152, 9.9 , 9.426,19.5 ,10.08 , 5.043,10.416, 5.376, - 5.619,11.484, 5.868,11.73 ,23.964,12.24 ,12.138,24.792,12.66 , 6.333,12.93 , 6.6 ,12.42 ,25.362,12.948,25.884,52.836,26.964,26.748,54.588,27.852,13.932,28.422,14.496, 6.873,14.022, 7.152,14.298,29.16 ,14.868,14.754,30.084,15.336, 7.671,15.636, 7.968, - 6.807,13.896, 7.092,14.178,28.932,14.76 ,14.586,29.76 ,15.18 , 7.593,15.486, 7.896,14.94 ,30.474,15.54 ,31.068,63.348,32.292,31.932,65.1 ,33.18 ,16.596,33.822,17.232, 8.205,16.722, 8.52 ,17.034,34.704,17.676,17.49 ,35.628,18.144, 9.075,18.48 , 9.408, - 7.995,16.308, 8.316,16.626,33.9 ,17.28 ,17.034,34.728,17.7 , 8.853,18.042, 9.192,17.46 ,35.586,18.132,36.252,73.86 ,37.62 ,37.116,75.612,38.508,19.26 ,39.222,19.968, 9.537,19.422, 9.888,19.77 ,40.248,20.484,20.226,41.172,20.952,10.479,21.324,10.848, - 9.183,18.72 , 9.54 ,19.074,38.868,19.8 ,19.482,39.696,20.22 ,10.113,20.598,10.488,19.98 ,40.698,20.724,41.436,84.372,42.948,42.3 ,86.124,43.836,21.924,44.622,22.704,10.869,22.122,11.256,22.506,45.792,23.292,22.962,46.716,23.76 ,11.883,24.168,12.288}); + auto expGradI = NDArrayFactory::create('c', {bS, iC, iD, iH, iW},{2.091f, 4.356f, 2.268f, 4.53f, 9.42f, 4.896f, 4.65f, 9.672f, 5.028f, 2.517f, 5.226f, 2.712f, 4.932f, 10.242f, 5.316f, 10.62f, 22.02f, 11.412f, 10.908f, 22.62f, 11.724f, 5.868f, 12.15f, 6.288f, 2.913f, 6.03f, 3.12f, 6.234f, 12.888f, 6.66f, 6.402f, 13.236f, 6.84f, 3.423f, 7.068f, 3.648f, + 2.415f, 5.04f, 2.628f, 5.25f, 10.932f, 5.688f, 5.37f, 11.184f, 5.82f, 2.913f, 6.054f, 3.144f, 5.724f, 11.898f, 6.18f, 12.348f, 25.62f, 13.284f, 12.636f, 26.22f, 13.596f, 6.804f, 14.094f, 7.296f, 3.381f, 7.002f, 3.624f, 7.242f, 14.976f, 7.74f, 7.41f, 15.324f, 7.92f, 3.963f, 8.184f, 4.224f, + 2.739f, 5.724f, 2.988f, 5.97f, 12.444f, 6.48f, 6.09f, 12.696f, 6.612f, 3.309f, 6.882f, 3.576f, 6.516f, 13.554f, 7.044f, 14.076f, 29.22f, 15.156f, 14.364f, 29.82f, 15.468f, 7.74f, 16.038f, 8.304f, 3.849f, 7.974f, 4.128f, 8.25f, 17.064f, 8.82f, 8.418f, 17.412f, 9.f, 4.503f, 9.3f, 4.8f, + 3.063f, 6.408f, 3.348f, 6.69f, 13.956f, 7.272f, 6.81f, 14.208f, 7.404f, 3.705f, 7.71f, 4.008f, 7.308f, 15.21f, 7.908f, 15.804f, 32.82f, 17.028f, 16.092f, 33.42f, 17.34f, 8.676f, 17.982f, 9.312f, 4.317f, 8.946f, 4.632f, 9.258f, 19.152f, 9.9f, 9.426f, 19.5f, 10.08f, 5.043f, 10.416f, 5.376f, + 5.619f, 11.484f, 5.868f, 11.73f, 23.964f, 12.24f, 12.138f, 24.792f, 12.66f, 6.333f, 12.93f, 6.6f, 12.42f, 25.362f, 12.948f, 25.884f, 52.836f, 26.964f, 26.748f, 54.588f, 27.852f, 13.932f, 28.422f, 14.496f, 6.873f, 14.022f, 7.152f, 14.298f, 29.16f, 14.868f, 14.754f, 30.084f, 15.336f, 7.671f, 15.636f, 7.968f, + 6.807f, 13.896f, 7.092f, 14.178f, 28.932f, 14.76f, 14.586f, 29.76f, 15.18f, 7.593f, 15.486f, 7.896f, 14.94f, 30.474f, 15.54f, 31.068f, 63.348f, 32.292f, 31.932f, 65.1f, 33.18f, 16.596f, 33.822f, 17.232f, 8.205f, 16.722f, 8.52f, 17.034f, 34.704f, 17.676f, 17.49f, 35.628f, 18.144f, 9.075f, 18.48f, 9.408f, + 7.995f, 16.308f, 8.316f, 16.626f, 33.9f, 17.28f, 17.034f, 34.728f, 17.7f, 8.853f, 18.042f, 9.192f, 17.46f, 35.586f, 18.132f, 36.252f, 73.86f, 37.62f, 37.116f, 75.612f, 38.508f, 19.26f, 39.222f, 19.968f, 9.537f, 19.422f, 9.888f, 19.77f, 40.248f, 20.484f, 20.226f, 41.172f, 20.952f, 10.479f, 21.324f, 10.848f, + 9.183f, 18.72f, 9.54f, 19.074f, 38.868f, 19.8f, 19.482f, 39.696f, 20.22f, 10.113f, 20.598f, 10.488f, 19.98f, 
40.698f, 20.724f, 41.436f, 84.372f, 42.948f, 42.3f, 86.124f, 43.836f, 21.924f, 44.622f, 22.704f, 10.869f, 22.122f, 11.256f, 22.506f, 45.792f, 23.292f, 22.962f, 46.716f, 23.76f, 11.883f, 24.168f, 12.288f}); - auto expGradW = NDArrayFactory::create('c', {oC, iC, kD, kH, kW},{5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, - 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, - 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, - 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, - 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, - 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4}); + auto expGradW = NDArrayFactory::create('c', {oC, iC, kD, kH, kW},{5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, + 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, + 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, + 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, + 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, + 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f}); - auto expGradB = NDArrayFactory::create('c', {oC},{2.64, 3.92, 5.2 }); + auto expGradB = NDArrayFactory::create('c', {oC},{2.64f, 3.92f, 5.2f}); input = 2.; weights.linspace(0.1, 0.1); @@ -1408,10 +1407,10 @@ TYPED_TEST(TypedConvolutionTests1, depthwise_conv2d_1) { auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{12. , 12.8, 13.6, 14.4,12. , 12.8, 13.6, 14.4, 5.2, 5.6, 6. , 6.4,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8, 5.4, 6. , 6.6, 7.2, - 13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8, 5.4, 6. , 6.6, 7.2, 5.6, 6.4, 7.2, 8. , 5.6, 6.4, 7.2, 8. , 2. , 2.4, 2.8, 3.2, - 12. , 12.8, 13.6, 14.4,12. , 12.8, 13.6, 14.4, 5.2, 5.6, 6. , 6.4,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8, 5.4, 6. , 6.6, 7.2, - 13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8, 5.4, 6. , 6.6, 7.2, 5.6, 6.4, 7.2, 8. , 5.6, 6.4, 7.2, 8. , 2. 
, 2.4, 2.8, 3.2}); + auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, + 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f, + 12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, + 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f}); input = 2.; weights.linspace(0.1, 0.1); @@ -1440,8 +1439,8 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_2) { auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8, - 13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8}); + auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, + 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f}); input = 2.; weights.linspace(0.1, 0.1); @@ -1698,14 +1697,14 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test1) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto expected = NDArrayFactory::create('c', {2, 3, 4, 3, 3}, {64.,64.,64.,64.,64.,64.,32.,32.,32.,96.,96.,96.,96.,96.,96.,48.,48.,48.,96.,96.,96.,96.,96.,96.,48.,48.,48., - 64.,64.,64.,64.,64.,64.,32.,32.,32.,64.,64.,64.,64.,64.,64.,32.,32.,32.,96.,96.,96.,96.,96.,96.,48.,48.,48., - 96.,96.,96.,96.,96.,96.,48.,48.,48.,64.,64.,64.,64.,64.,64.,32.,32.,32.,32.,32.,32.,32.,32.,32.,16.,16.,16., - 48.,48.,48.,48.,48.,48.,24.,24.,24.,48.,48.,48.,48.,48.,48.,24.,24.,24.,32.,32.,32.,32.,32.,32.,16.,16.,16., - 64.,64.,64.,64.,64.,64.,32.,32.,32.,96.,96.,96.,96.,96.,96.,48.,48.,48.,96.,96.,96.,96.,96.,96.,48.,48.,48., - 64.,64.,64.,64.,64.,64.,32.,32.,32.,64.,64.,64.,64.,64.,64.,32.,32.,32.,96.,96.,96.,96.,96.,96.,48.,48.,48., - 96.,96.,96.,96.,96.,96.,48.,48.,48.,64.,64.,64.,64.,64.,64.,32.,32.,32.,32.,32.,32.,32.,32.,32.,16.,16.,16., - 48.,48.,48.,48.,48.,48.,24.,24.,24.,48.,48.,48.,48.,48.,48.,24.,24.,24.,32.,32.,32.,32.,32.,32.,16.,16.,16.}); + auto expected = NDArrayFactory::create('c', {2, 3, 4, 3, 3}, {64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, + 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, + 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f, + 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f, + 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, + 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 
64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, + 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f, + 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f}); input = 2.; weights = 1.; @@ -1730,14 +1729,14 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test2) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto expected = NDArrayFactory::create('c', {2, 3, 4, 3, 3}, {534.4,540.8,547.2,534.4,540.8,547.2,248. ,251.2,254.4,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. , - 380.8,387.2,393.6,380.8,387.2,393.6,171.2,174.4,177.6,534.4,540.8,547.2,534.4,540.8,547.2,248. ,251.2,254.4,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. , - 686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,380.8,387.2,393.6,380.8,387.2,393.6,171.2,174.4,177.6,152. ,155.2,158.4,152. ,155.2,158.4, 66.4, 68. , 69.6, - 170.4,175.2,180. ,170.4,175.2,180. , 70.8, 73.2, 75.6,170.4,175.2,180. ,170.4,175.2,180. , 70.8, 73.2, 75.6, 75.2, 78.4, 81.6, 75.2, 78.4, 81.6, 28. , 29.6, 31.2, - 534.4,540.8,547.2,534.4,540.8,547.2,248. ,251.2,254.4,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. , - 380.8,387.2,393.6,380.8,387.2,393.6,171.2,174.4,177.6,534.4,540.8,547.2,534.4,540.8,547.2,248. ,251.2,254.4,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. , - 686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,380.8,387.2,393.6,380.8,387.2,393.6,171.2,174.4,177.6,152. ,155.2,158.4,152. ,155.2,158.4, 66.4, 68. , 69.6, - 170.4,175.2,180. ,170.4,175.2,180. , 70.8, 73.2, 75.6,170.4,175.2,180. ,170.4,175.2,180. , 70.8, 73.2, 75.6, 75.2, 78.4, 81.6, 75.2, 78.4, 81.6, 28. 
, 29.6, 31.2}); + auto expected = NDArrayFactory::create('c', {2, 3, 4, 3, 3}, {534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f, + 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f}); input = 2.; weights.linspace(0.1, 0.1); @@ -1761,10 +1760,10 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test3) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto expected = NDArrayFactory::create('c', {2, 2, 2, 2, 3}, {686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6, - 686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6, - 686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6, - 686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6}); + auto expected = NDArrayFactory::create('c', {2, 2, 2, 2, 3}, {686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f}); input = 2.; weights.linspace(0.1, 0.1); @@ -1844,10 +1843,10 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test6) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto bias = NDArrayFactory::create('c', {oC},{1,2,3}); - auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{49., 49.,49., 49., 49., 49.,49., 49., 50., 50.,50., 50., 50., 50.,50., 50., - 51., 51.,51., 51., 51., 51.,51., 51., 49., 49.,49., 49., 49., 49.,49., 49., - 50., 50.,50., 50., 50., 50.,50., 50., 51., 51.,51., 51., 51., 51.,51., 51.}); + auto bias = NDArrayFactory::create('c', {oC},{1.f, 2.f, 3.f}); + auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 
50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, + 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, + 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f}); input = 2.; weights = 0.5; @@ -1873,11 +1872,11 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test7) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); - auto bias = NDArrayFactory::create('c', {oC},{1,2,3}); - auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 698. , 698. , 698. , 698. , - 698. , 698. , 698. , 698. ,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8, - 236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 698. , 698. , 698. , 698. , - 698. , 698. , 698. , 698. ,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8}); + auto bias = NDArrayFactory::create('c', {oC},{1.f, 2.f, 3.f}); + auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f, + 698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, + 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f, + 698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f}); input = 2.; weights.linspace(0.1, 0.1); weights.permutei({2, 3, 4, 1, 0}); @@ -1904,9 +1903,9 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test8) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); - auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{235.2, 235.2, 235.2, 235.2, 235.2, 235.2, 235.2, 235.2, 696. , 696. , 696. , 696. , 696. , 696. , 696. , 696. , - 1156.8,1156.8,1156.8,1156.8,1156.8,1156.8,1156.8,1156.8, 235.2, 235.2, 235.2, 235.2, 235.2, 235.2, 235.2, 235.2, - 696. , 696. , 696. , 696. , 696. , 696. , 696. , 696. 
,1156.8,1156.8,1156.8,1156.8,1156.8,1156.8,1156.8,1156.8}); + auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, + 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, + 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f}); input = 2.; weights.linspace(0.1, 0.1); weights.permutei({2, 3, 4, 1, 0}); @@ -1998,10 +1997,10 @@ TYPED_TEST(TypedConvolutionTests1, pointwise_conv2d_test1) { auto bias = NDArrayFactory::create('c', {oC}); - auto expOutput = NDArrayFactory::create('c', {bS, iH, iW, oC},{ 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, - 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, - 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, - 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0}); + auto expOutput = NDArrayFactory::create('c', {bS, iH, iW, oC},{ 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, + 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, + 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, + 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f}); input = 2.; weights.linspace(0.1, 0.1); bias = 1.; @@ -2111,21 +2110,21 @@ TEST_F(ConvolutionTests1, vol2col_test2) { auto columns = NDArrayFactory::create('c', {kD, iC, kH, oW, kW, bS, oD, oH}); columns.permutei({5, 1, 0, 2, 4, 6, 7, 3}); columns = -1.; - auto columnsExpected = NDArrayFactory::create('c', {bS, iC, kD, kH, kW, oD, oH, oW}, {1., 2., 3., 4., 5., 6., 7., 8., 9., -10., 11., 12., 2., 0., 4., 0., 6., 0., 8., 0., 10., 0., 12., 0., 3., 4., 5., 6., 0., 0., 9., 10., 11., 12., 0., 0., 4., 0., 6., 0., 0.,0., 10., 0., 12., 0., 0., 0., 5., 6., 0., 0., 0., 0., 11., 12., 0., 0., 0., 0., 6., 0., 0., 0., 0., 0., 12., 0., 0., 0., 0., 0., 7., 8., -9., 10., 11., 12., 0., 0., 0., 0., 0., 0., 8., 0., 10., 0., 12., 0., 0., 0., 0., 0., 0., 0., 9., 10., 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 10., 0., 12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 12., 0., 0., 0., 0., 0., -0., 0., 0., 0., 0., 0., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 14., 0., 16., 0., 18., 0., 20., 0., 22., 0., 24., 0., 15., 16., 17., 18., 0., 0., 21., 22., 23., 24., 0., 0., 16., 0., 18., 0., 0., 0., 22., 0., 24., 0., 0., 0., 17., 18., 0., 0., 0., 0., -23., 24., 0., 0., 0., 0., 18., 0., 0., 0., 0., 0., 24., 0., 0., 0., 0., 0., 19., 20., 21., 22., 23., 24., 0., 0., 0., 0., 0., 0., 20., 0., 22., 0., 24., 0., 0., 0., 0., 0., 0., 0., 21., 22., 23., 24., 0., 0., 0., 0., 0., 0., 0., 0., 22., 0., 24., 0., 0., 0., 0., 0., 0., -0., 0., 0., 23., 24., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 24., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 26., 0., 28., 0., 30., 0., 32., 0., 34., 0., 36., 0., 27., 28., 29., 30., 0., 0., 33., 34., 35., 36., -0., 0., 28., 0., 30., 0., 0., 0., 34., 0., 36., 0., 0., 0., 29., 30., 0., 0., 0., 0., 35., 36., 0., 0., 0., 0., 30., 0., 0., 0., 0., 0., 36., 0., 0., 
0., 0., 0., 31., 32., 33., 34., 35., 36., 0., 0., 0., 0., 0., 0., 32., 0., 34., 0., 36., 0., 0., 0., 0., 0., 0., 0., 33., -34., 35., 36., 0., 0., 0., 0., 0., 0., 0., 0., 34., 0., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 35., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 38., 0., 40., -0., 42., 0., 44., 0., 46., 0., 48., 0., 39., 40., 41., 42., 0., 0., 45., 46., 47., 48., 0., 0., 40., 0., 42., 0., 0., 0., 46., 0., 48., 0., 0., 0., 41., 42., 0., 0., 0., 0., 47., 48., 0., 0., 0., 0., 42., 0., 0., 0., 0., 0., 48., 0., 0., 0., 0., 0., 43., 44., 45., 46., 47., -48., 0., 0., 0., 0., 0., 0., 44., 0., 46., 0., 48., 0., 0., 0., 0., 0., 0., 0., 45., 46., 47., 48., 0., 0., 0., 0., 0., 0., 0., 0., 46., 0., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0., 47., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0., -0., 0., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 50., 0., 52., 0., 54., 0., 56., 0., 58., 0., 60., 0., 51., 52., 53., 54., 0., 0., 57., 58., 59., 60., 0., 0., 52., 0., 54., 0., 0., 0., 58., 0., 60., 0., 0., 0., 53., 54., 0., 0., 0., 0., 59., 60., 0., 0., -0., 0., 54., 0., 0., 0., 0., 0., 60., 0., 0., 0., 0., 0., 55., 56., 57., 58., 59., 60., 0., 0., 0., 0., 0., 0., 56., 0., 58., 0., 60., 0., 0., 0., 0., 0., 0., 0., 57., 58., 59., 60., 0., 0., 0., 0., 0., 0., 0., 0., 58., 0., 60., 0., 0., 0., 0., 0., 0., 0., 0., 0., 59., 60., -0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 60., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 62., 0., 64., 0., 66., 0., 68., 0., 70., 0., 72., 0., 63., 64., 65., 66., 0., 0., 69., 70., 71., 72., 0., 0., 64., 0., 66., -0., 0., 0., 70., 0., 72., 0., 0., 0., 65., 66., 0., 0., 0., 0., 71., 72., 0., 0., 0., 0., 66., 0., 0., 0., 0., 0., 72., 0., 0., 0., 0., 0., 67., 68., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., 68., 0., 70., 0., 72., 0., 0., 0., 0., 0., 0., 0., 69., 70., 71., 72., 0., 0., -0., 0., 0., 0., 0., 0., 70., 0., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 71., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}); + auto columnsExpected = NDArrayFactory::create('c', {bS, iC, kD, kH, kW, oD, oH, oW}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, +10.f, 11.f, 12.f, 2.f, 0.f, 4.f, 0.f, 6.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 3.f, 4.f, 5.f, 6.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 4.f, 0.f, 6.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 7.f, 8.f, +9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 14.f, 0.f, 16.f, 0.f, 18.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 15.f, 16.f, 17.f, 18.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 16.f, 0.f, 18.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f, +23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 26.f, 0.f, 28.f, 0.f, 30.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 27.f, 28.f, 29.f, 30.f, 0.f, 0.f, 33.f, 34.f, 35.f, 36.f, +0.f, 0.f, 28.f, 0.f, 30.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 29.f, 30.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 30.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 33.f, +34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 38.f, 0.f, 40.f, +0.f, 42.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 39.f, 40.f, 41.f, 42.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 40.f, 0.f, 42.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 41.f, 42.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 42.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 43.f, 44.f, 45.f, 46.f, 47.f, +48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 50.f, 0.f, 52.f, 0.f, 54.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 51.f, 52.f, 53.f, 54.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 52.f, 0.f, 54.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 53.f, 54.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, 0.f, 0.f, +0.f, 0.f, 54.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 62.f, 0.f, 64.f, 0.f, 66.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 63.f, 64.f, 65.f, 66.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 64.f, 0.f, 66.f, +0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 65.f, 66.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 66.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); graph::Context context(1); nd4j::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); @@ -2146,7 +2145,7 @@ 
TEST_F(ConvolutionTests1, col2im_test1) { auto columns = NDArrayFactory::create('c', {bS, iC, kH, kW, oH, oW}); columns.linspace(1); - auto imageExpected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {1., 7., 12., 34., 17., 39., 44., 98., 33., 71., 76., 162., 49., 103., 108., 226.}); + auto imageExpected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {1.f, 7.f, 12.f, 34.f, 17.f, 39.f, 44.f, 98.f, 33.f, 71.f, 76.f, 162.f, 49.f, 103.f, 108.f, 226.f}); LaunchContext ctx; nd4j::ops::helpers::col2im(ctx, columns, image, sH, sW, pH, pW, iH, iW, dH, dW); @@ -2165,12 +2164,12 @@ TEST_F(ConvolutionTests1, upsampling2d_test1) { auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); input.linspace(1); - auto expOutput = NDArrayFactory::create('c', {bS, iH*factorH, iW*factorW, iC}, {1., 2., 3., 1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 4., 5., 6., 1., 2., 3., 1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 4., 5., 6., - 7., 8., 9., 7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12.,10., 11., 12., 7., 8., 9., 7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12.,10., 11., 12., - 13., 14., 15.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,16., 17., 18.,13., 14., 15.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,16., 17., 18., - 19., 20., 21.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,22., 23., 24.,19., 20., 21.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,22., 23., 24., - 25., 26., 27.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,28., 29., 30.,25., 26., 27.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,28., 29., 30., - 31., 32., 33.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,34., 35., 36.,31., 32., 33.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,34., 35., 36.}); + auto expOutput = NDArrayFactory::create('c', {bS, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f}); nd4j::ops::upsampling2d op; auto results = op.execute({&input}, {}, {factorH, factorW, isNCHW}); @@ -2193,12 +2192,12 @@ TEST_F(ConvolutionTests1, upsampling2d_test2) { auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); input.linspace(1); - auto expOutput = NDArrayFactory::create('c', {bS, iC, iH*factorH, iW*factorW}, {1., 1., 1., 2., 2., 2., 1., 1., 1., 2., 2., 2., 3., 3., 3., 4., 4., 4., 3., 3., 3., 4., 4., 4., - 5., 5., 5., 
6., 6., 6., 5., 5., 5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 7., 7., 7., 8., 8., 8., 9., 9., 9., 10., 10., 10., 9., 9., 9., 10., 10., 10.,11., 11., 11., 12., 12., 12.,11., 11., 11., 12., 12., 12., - 13., 13., 13., 14., 14., 14.,13., 13., 13., 14., 14., 14.,15., 15., 15., 16., 16., 16.,15., 15., 15., 16., 16., 16.,17., 17., 17., 18., 18., 18.,17., 17., 17., 18., 18., 18.,19., 19., 19., 20., 20., 20.,19., 19., 19., 20., 20., 20., - 21., 21., 21., 22., 22., 22.,21., 21., 21., 22., 22., 22.,23., 23., 23., 24., 24., 24.,23., 23., 23., 24., 24., 24.,25., 25., 25., 26., 26., 26.,25., 25., 25., 26., 26., 26.,27., 27., 27., 28., 28., 28.,27., 27., 27., 28., 28., 28., - 29., 29., 29., 30., 30., 30.,29., 29., 29., 30., 30., 30.,31., 31., 31., 32., 32., 32.,31., 31., 31., 32., 32., 32., - 33., 33., 33., 34., 34., 34.,33., 33., 33., 34., 34., 34.,35., 35., 35., 36., 36., 36.,35., 35., 35., 36., 36., 36.}); + auto expOutput = NDArrayFactory::create('c', {bS, iC, iH*factorH, iW*factorW}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, + 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, + 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, + 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, + 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, + 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f}); nd4j::ops::upsampling2d op; auto results = op.execute({&input}, {}, {factorH, factorW, isNCHW}); @@ -2222,21 +2221,21 @@ TEST_F(ConvolutionTests1, upsampling3d_test1) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); input.linspace(1); - auto expOutput = NDArrayFactory::create('c', {bS, iD*factorD, iH*factorH, iW*factorW, iC}, {1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12., 7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12., - 7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12., 1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12., - 7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12., 7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18., - 19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,13., 14., 15.,13., 14., 
15.,16., 17., 18.,16., 17., 18.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18., - 13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30., - 25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36., - 25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36., - 31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42.,43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48., - 43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48.,43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42., - 43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48.,43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48.,43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48.,49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54.,49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54., - 49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60.,49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54., - 49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54.,49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60., - 61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72.,67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72., - 67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72.,61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72., - 67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72.,67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72.}); + auto expOutput = NDArrayFactory::create('c', {bS, iD*factorD, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 
4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, + 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, + 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, + 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, + 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, + 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 
51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, + 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, + 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, + 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f}); nd4j::ops::upsampling3d op; auto results = op.execute({&input}, {}, {factorD, factorH, factorW, isNCDHW}); @@ -2259,18 +2258,18 @@ TEST_F(ConvolutionTests1, upsampling3d_test2) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); input.linspace(1); - auto expOutput = NDArrayFactory::create('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, { 1., 1., 2., 2., 1., 1., 2., 2., 1., 1., 2., 2., 3., 3., 4., 4., 3., 3., 4., 4., 3., 3., 4., 4., 1., 1., 2., 2., 1., 1., 2., 2., 1., 1., 2., 2., 3., 3., 4., 4., 3., 3., 4., 4., 3., 3., 4., 4., 5., 5., 6., 6., 5., 5., 6., 6., 5., 5., 6., 6., 7., 7., 8., 8., 7., 7., 8., 8., 7., 7., 8., 8., - 5., 5., 6., 6., 5., 5., 6., 6., 5., 5., 6., 6., 7., 7., 8., 8., 7., 7., 8., 8., 7., 7., 8., 8., 9., 9., 10., 10., 9., 9., 10., 10., 9., 9., 10., 10.,11., 11., 12., 12.,11., 11., 12., 12.,11., 11., 12., 12., 9., 9., 10., 10., 9., 9., 10., 10., 9., 9., 10., 10.,11., 11., 12., 12.,11., 11., 12., 12.,11., 11., 12., 12., - 13., 13., 14., 14.,13., 13., 14., 14.,13., 13., 14., 14.,15., 15., 16., 16.,15., 15., 16., 16.,15., 15., 16., 16.,13., 13., 14., 14.,13., 13., 14., 14.,13., 13., 14., 14.,15., 15., 16., 16.,15., 15., 16., 16.,15., 15., 16., 16.,17., 17., 18., 18.,17., 17., 18., 18.,17., 17., 18., 18.,19., 19., 20., 20.,19., 19., 20., 20.,19., 19., 20., 20., - 17., 17., 18., 18.,17., 17., 18., 18.,17., 17., 18., 18.,19., 19., 20., 20.,19., 19., 20., 20.,19., 19., 20., 20.,21., 21., 22., 22.,21., 21., 22., 22.,21., 21., 22., 22.,23., 23., 24., 24.,23., 23., 24., 24.,23., 23., 24., 24.,21., 21., 22., 22.,21., 21., 22., 22.,21., 21., 22., 22.,23., 23., 24., 24.,23., 23., 24., 24.,23., 23., 24., 24., - 25., 25., 26., 26.,25., 25., 26., 26.,25., 25., 26., 26.,27., 27., 28., 28.,27., 27., 28., 28.,27., 27., 28., 28.,25., 25., 26., 26.,25., 25., 26., 26.,25., 25., 26., 26.,27., 27., 28., 28.,27., 27., 28., 28.,27., 27., 28., 28.,29., 29., 30., 30.,29., 29., 30., 30.,29., 29., 30., 30.,31., 31., 32., 32.,31., 31., 32., 32.,31., 31., 32., 32., - 29., 29., 30., 30.,29., 29., 30., 30.,29., 29., 30., 30.,31., 31., 32., 32.,31., 31., 32., 32.,31., 31., 32., 32.,33., 33., 34., 34.,33., 33., 34., 34.,33., 33., 34., 34.,35., 35., 36., 36.,35., 35., 36., 36.,35., 35., 36., 36.,33., 33., 34., 34.,33., 33., 34., 34.,33., 33., 34., 34.,35., 35., 36., 36.,35., 35., 36., 36.,35., 35., 36., 36., - 37., 37., 38., 38.,37., 37., 38., 38.,37., 37., 38., 38.,39., 39., 40., 
40.,39., 39., 40., 40.,39., 39., 40., 40.,37., 37., 38., 38.,37., 37., 38., 38.,37., 37., 38., 38.,39., 39., 40., 40.,39., 39., 40., 40.,39., 39., 40., 40.,41., 41., 42., 42.,41., 41., 42., 42.,41., 41., 42., 42.,43., 43., 44., 44.,43., 43., 44., 44.,43., 43., 44., 44., - 41., 41., 42., 42.,41., 41., 42., 42.,41., 41., 42., 42.,43., 43., 44., 44.,43., 43., 44., 44.,43., 43., 44., 44.,45., 45., 46., 46.,45., 45., 46., 46.,45., 45., 46., 46.,47., 47., 48., 48.,47., 47., 48., 48.,47., 47., 48., 48.,45., 45., 46., 46.,45., 45., 46., 46.,45., 45., 46., 46.,47., 47., 48., 48.,47., 47., 48., 48.,47., 47., 48., 48., - 49., 49., 50., 50.,49., 49., 50., 50.,49., 49., 50., 50.,51., 51., 52., 52.,51., 51., 52., 52.,51., 51., 52., 52.,49., 49., 50., 50.,49., 49., 50., 50.,49., 49., 50., 50.,51., 51., 52., 52.,51., 51., 52., 52.,51., 51., 52., 52.,53., 53., 54., 54.,53., 53., 54., 54.,53., 53., 54., 54.,55., 55., 56., 56.,55., 55., 56., 56.,55., 55., 56., 56., - 53., 53., 54., 54.,53., 53., 54., 54.,53., 53., 54., 54.,55., 55., 56., 56.,55., 55., 56., 56.,55., 55., 56., 56.,57., 57., 58., 58.,57., 57., 58., 58.,57., 57., 58., 58.,59., 59., 60., 60.,59., 59., 60., 60.,59., 59., 60., 60.,57., 57., 58., 58.,57., 57., 58., 58.,57., 57., 58., 58.,59., 59., 60., 60.,59., 59., 60., 60.,59., 59., 60., 60., - 61., 61., 62., 62.,61., 61., 62., 62.,61., 61., 62., 62.,63., 63., 64., 64.,63., 63., 64., 64.,63., 63., 64., 64.,61., 61., 62., 62.,61., 61., 62., 62.,61., 61., 62., 62.,63., 63., 64., 64.,63., 63., 64., 64.,63., 63., 64., 64.,65., 65., 66., 66.,65., 65., 66., 66.,65., 65., 66., 66.,67., 67., 68., 68.,67., 67., 68., 68.,67., 67., 68., 68., - 65., 65., 66., 66.,65., 65., 66., 66.,65., 65., 66., 66.,67., 67., 68., 68.,67., 67., 68., 68.,67., 67., 68., 68.,69., 69., 70., 70.,69., 69., 70., 70.,69., 69., 70., 70.,71., 71., 72., 72.,71., 71., 72., 72.,71., 71., 72., 72.,69., 69., 70., 70.,69., 69., 70., 70.,69., 69., 70., 70.,71., 71., 72., 72.,71., 71., 72., 72.,71., 71., 72., 72.}); + auto expOutput = NDArrayFactory::create('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, { 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, + 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, + 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, + 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 
23.f, 23.f, 24.f, 24.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, + 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, + 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, + 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, + 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, + 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, + 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, + 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, + 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 
69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f}); nd4j::ops::upsampling3d op; auto results = op.execute({&input}, {}, {factorD, factorH, factorW, isNCDHW}); @@ -2413,14 +2412,14 @@ TEST_F(ConvolutionTests1, deconv2d_test1) { auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); auto weights = NDArrayFactory::create('c', {kH, kW, oC, iC}); - auto exp = NDArrayFactory::create('c', {bS, oH, oW, oC}, { 2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 27.75, 32.75, 37.75, 42.75, 47.75, - 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 , - 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 , - 52.75, 57.75, 62.75, 67.75, 72.75,130.5 , 140.5 , 150.5 , 160.5 , 170.5 ,130.5 , 140.5 , 150.5 , 160.5 , 170.5 , 77.75, 82.75, 87.75, 92.75, 97.75, - 2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 27.75, 32.75, 37.75, 42.75, 47.75, - 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 , - 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 , - 52.75, 57.75, 62.75, 67.75, 72.75,130.5 , 140.5 , 150.5 , 160.5 , 170.5 ,130.5 , 140.5 , 150.5 , 160.5 , 170.5 , 77.75, 82.75, 87.75, 92.75, 97.75}); + auto exp = NDArrayFactory::create('c', {bS, oH, oW, oC}, { 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f, + 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f}); input = 0.5; weights.linspace(0.1, 0.1); @@ -2446,14 +2445,14 @@ TEST_F(ConvolutionTests1, deconv2d_test2) { auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); - auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, {2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , - 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. , - 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. 
,161. , 181. , 201. , 221. , 241. , - 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. , - 2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , - 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. , - 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. , - 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. }); + auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f }); input = 0.5; weights.linspace(0.1, 0.1); @@ -2480,10 +2479,10 @@ TEST_F(ConvolutionTests1, deconv2d_test3) { auto weights = NDArrayFactory::create('c', {kH, kW, oC, iC}); auto bias = NDArrayFactory::create('c', {oC}); - auto exp = NDArrayFactory::create('c', {bS, oH, oW, oC}, {-2.9, -6.8, -10.7, -2.6, -6.1, -9.6, -16.9, -23.9, -30.9, -13.1, -16.6, -20.1, -11.6, -14.7, -17.8, -2.0, -4.7, -7.4, -1.7, -4.0, -6.3, -11.5, -16.1, - -20.7, -8.6, -10.9, -13.2, -7.1, -9.0, -10.9, -27.4, -32.8, -38.2, -24.4, -29.0, -33.6, -65.0, -74.2, -83.4, -38.2, -42.8, -47.4, - -32.8, -36.6, -40.4, -18.2, -20.9, -23.6, -15.5, -17.8, -20.1, -39.1, -43.7, -48.3, -22.4, -24.7, -27.0, -18.5, -20.4, -22.3, -10.1, -11.6, -13.1, - -7.4, -8.5, -9.6, -19.3, -21.5, -23.7, -10.7, -11.8, -12.9, -6.8, -7.5, -8.2}); + auto exp = NDArrayFactory::create('c', {bS, oH, oW, oC}, {-2.9f, -6.8f, -10.7f, -2.6f, -6.1f, -9.6f, -16.9f, -23.9f, -30.9f, -13.1f, -16.6f, -20.1f, -11.6f, -14.7f, -17.8f, -2.0f, -4.7f, -7.4f, -1.7f, -4.0f, -6.3f, -11.5f, -16.1f, + -20.7f, -8.6f, -10.9f, -13.2f, -7.1f, -9.0f, -10.9f, -27.4f, -32.8f, -38.2f, -24.4f, -29.0f, -33.6f, -65.0f, -74.2f, -83.4f, -38.2f, -42.8f, -47.4f, + -32.8f, -36.6f, -40.4f, -18.2f, -20.9f, -23.6f, -15.5f, -17.8f, -20.1f, -39.1f, -43.7f, -48.3f, -22.4f, -24.7f, -27.0f, -18.5f, -20.4f, -22.3f, -10.1f, -11.6f, -13.1f, + -7.4f, -8.5f, -9.6f, -19.3f, -21.5f, -23.7f, -10.7f, -11.8f, -12.9f, -6.8f, -7.5f, -8.2f}); input.linspace(-10, 0.5); weights.linspace(0.1, 0.1); @@ -2568,17 +2567,17 @@ TYPED_TEST(TypedConvolutionTests1, deconv2d_test6) { int dataFormat = 0; // 1-NHWC, 0-NCHW auto 
input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto weights = NDArrayFactory::create('c', {kH, kW, oC, iC}, {1., 76., 151., 26., 101., 176., 51., 126., 201., 2., 77., 152., 27., 102., 177., 52., 127., 202., 3., 78., 153., 28., 103., 178., 53., 128., 203., - 4., 79., 154., 29., 104., 179., 54., 129., 204., 5., 80., 155., 30., 105., 180., 55., 130., 205., 6., 81., 156., 31., 106., 181., 56., 131., 206., - 7., 82., 157., 32., 107., 182., 57., 132., 207., 8., 83., 158., 33., 108., 183., 58., 133., 208., 9., 84., 159., 34., 109., 184., 59., 134., 209., - 10., 85., 160., 35., 110., 185., 60., 135., 210., 11., 86., 161., 36., 111., 186., 61., 136., 211., 12., 87., 162., 37., 112., 187., 62., 137., 212., - 13., 88., 163., 38., 113., 188., 63., 138., 213., 14., 89., 164., 39., 114., 189., 64., 139., 214., 15., 90., 165., 40., 115., 190., 65., 140., 215., - 16., 91., 166., 41., 116., 191., 66., 141., 216., 17., 92., 167., 42., 117., 192., 67., 142., 217., 18., 93., 168., 43., 118., 193., 68., 143., 218., - 19., 94., 169., 44., 119., 194., 69., 144., 219., 20., 95., 170., 45., 120., 195., 70., 145., 220., 21., 96., 171., 46., 121., 196., 71., 146., 221., - 22., 97., 172., 47., 122., 197., 72., 147., 222., 23., 98., 173., 48., 123., 198., 73., 148., 223., 24., 99., 174., 49., 124., 199., 74., 149., 224., - 25., 100., 175.,50., 125., 200.,75., 150., 225.}); + auto weights = NDArrayFactory::create('c', {kH, kW, oC, iC}, {1.f, 76.f, 151.f, 26.f, 101.f, 176.f, 51.f, 126.f, 201.f, 2.f, 77.f, 152.f, 27.f, 102.f, 177.f, 52.f, 127.f, 202.f, 3.f, 78.f, 153.f, 28.f, 103.f, 178.f, 53.f, 128.f, 203.f, + 4.f, 79.f, 154.f, 29.f, 104.f, 179.f, 54.f, 129.f, 204.f, 5.f, 80.f, 155.f, 30.f, 105.f, 180.f, 55.f, 130.f, 205.f, 6.f, 81.f, 156.f, 31.f, 106.f, 181.f, 56.f, 131.f, 206.f, + 7.f, 82.f, 157.f, 32.f, 107.f, 182.f, 57.f, 132.f, 207.f, 8.f, 83.f, 158.f, 33.f, 108.f, 183.f, 58.f, 133.f, 208.f, 9.f, 84.f, 159.f, 34.f, 109.f, 184.f, 59.f, 134.f, 209.f, + 10.f, 85.f, 160.f, 35.f, 110.f, 185.f, 60.f, 135.f, 210.f, 11.f, 86.f, 161.f, 36.f, 111.f, 186.f, 61.f, 136.f, 211.f, 12.f, 87.f, 162.f, 37.f, 112.f, 187.f, 62.f, 137.f, 212.f, + 13.f, 88.f, 163.f, 38.f, 113.f, 188.f, 63.f, 138.f, 213.f, 14.f, 89.f, 164.f, 39.f, 114.f, 189.f, 64.f, 139.f, 214.f, 15.f, 90.f, 165.f, 40.f, 115.f, 190.f, 65.f, 140.f, 215.f, + 16.f, 91.f, 166.f, 41.f, 116.f, 191.f, 66.f, 141.f, 216.f, 17.f, 92.f, 167.f, 42.f, 117.f, 192.f, 67.f, 142.f, 217.f, 18.f, 93.f, 168.f, 43.f, 118.f, 193.f, 68.f, 143.f, 218.f, + 19.f, 94.f, 169.f, 44.f, 119.f, 194.f, 69.f, 144.f, 219.f, 20.f, 95.f, 170.f, 45.f, 120.f, 195.f, 70.f, 145.f, 220.f, 21.f, 96.f, 171.f, 46.f, 121.f, 196.f, 71.f, 146.f, 221.f, + 22.f, 97.f, 172.f, 47.f, 122.f, 197.f, 72.f, 147.f, 222.f, 23.f, 98.f, 173.f, 48.f, 123.f, 198.f, 73.f, 148.f, 223.f, 24.f, 99.f, 174.f, 49.f, 124.f, 199.f, 74.f, 149.f, 224.f, + 25.f, 100.f, 175.f,50.f, 125.f, 200.f,75.f, 150.f, 225.f}); - auto exp = NDArrayFactory::create('c', {bS, oC, oH, oW}, {6276.0, 12831.0, 19668.0, 26790.0, 27012.0, 20703.0, 14100.0, 7200.0, 13719.0, 28023.0, 42918.0, 58410.0, 58902.0, 45105.0, 30693.0, 15660.0, 22389.0, 45696.0, 69930.0, 95100.0, 95910.0, 73386.0, 49899.0, 25440.0, 32346.0, 65970.0, 100884.0, 137100.0, 138276.0, 105726.0, 71838.0, 36600.0, 33726.0, 68790.0, 105204.0, 142980.0, 144156.0, 110226.0, 74898.0, 38160.0, 27555.0, 56154.0, 85806.0, 116520.0, 117474.0, 89748.0, 60933.0, 31020.0, 19917.0, 40557.0, 61926.0, 84030.0, 84714.0, 64671.0, 43875.0, 22320.0, 10752.0, 21879.0, 33384.0, 45270.0, 45636.0, 34815.0, 
23604.0, 12000.0, 7551.0, 15456.0, 23718.0, 32340.0, 32562.0, 24978.0, 17025.0, 8700.0, 16569.0, 33873.0, 51918.0, 70710.0, 71202.0, 54555.0, 37143.0, 18960.0, 27114.0, 55371.0, 84780.0, 115350.0, 116160.0, 88911.0, 60474.0, 30840.0, 39246.0, 80070.0, 122484.0, 166500.0, 167676.0, 128226.0, 87138.0, 44400.0, 40626.0, 82890.0, 126804.0, 172380.0, 173556.0, 132726.0, 90198.0, 45960.0, 33180.0, 67629.0, 103356.0, 140370.0, 141324.0, 107973.0, 73308.0, 37320.0, 23967.0, 48807.0, 74526.0, 101130.0, 101814.0, 77721.0, 52725.0, 26820.0, 12927.0, 26304.0, 40134.0, 54420.0, 54786.0, 41790.0, 28329.0, 14400.0, 8826.0, 18081.0, 27768.0, 37890.0, 38112.0, 29253.0, 19950.0, 10200.0, 19419.0, 39723.0, 60918.0, 83010.0, 83502.0, 64005.0, 43593.0, 22260.0, 31839.0, 65046.0, 99630.0, 135600.0, 136410.0, 104436.0, 71049.0, 36240.0, 46146.0, 94170.0, 144084.0, 195900.0, 197076.0, 150726.0, 102438.0, 52200.0, 47526.0, 96990.0, 148404.0, 201780.0, 202956.0, 155226.0, 105498.0, 53760.0, 38805.0, 79104.0, 120906.0, 164220.0, 165174.0, 126198.0, 85683.0, 43620.0, 28017.0, 57057.0, 87126.0, 118230.0, 118914.0, 90771.0, 61575.0, 31320.0, 15102.0, 30729.0, 46884.0, 63570.0, 63936.0, 48765.0, 33054.0, 16800.0, 17220.0, 34863.0, 52932.0, 71430.0, 72228.0, 54831.0, 36996.0, 18720.0, 36327.0, 73527.0, 111606.0, 150570.0, 152214.0, 115521.0, 77925.0, 39420.0, 57381.0, 116112.0, 176202.0, 237660.0, 240198.0, 182250.0, 122907.0, 62160.0, 80442.0, 162738.0, 246900.0, 332940.0, 336420.0, 255198.0, 172062.0, 87000.0, 84702.0, 171318.0, 259860.0, 350340.0, 353820.0, 268338.0, 180882.0, 91440.0, 66867.0, 135210.0, 205038.0, 276360.0, 279042.0, 211572.0, 142581.0, 72060.0, 46845.0, 94701.0, 143574.0, 193470.0, 195306.0, 148047.0, 99747.0, 50400.0, 24576.0, 49671.0, 75288.0, 101430.0, 102372.0, 77583.0, 52260.0, 26400.0, 22095.0, 44688.0, 67782.0, 91380.0, 92178.0, 69906.0, 47121.0, 23820.0, 46377.0, 93777.0, 142206.0, 191670.0, 193314.0, 146571.0, 98775.0, 49920.0, 72906.0, 147387.0, 223452.0, 301110.0, 303648.0, 230175.0, 155082.0, 78360.0, 101742.0, 205638.0, 311700.0, 419940.0, 423420.0, 320898.0, 216162.0, 109200.0, 106002.0, 214218.0, 324660.0, 437340.0, 440820.0, 334038.0, 224982.0, 113640.0, 83292.0, 168285.0, 254988.0, 343410.0, 346092.0, 262197.0, 176556.0, 89160.0, 58095.0, 117351.0, 177774.0, 239370.0, 241206.0, 182697.0, 122997.0, 62100.0, 30351.0, 61296.0, 92838.0, 124980.0, 125922.0, 95358.0, 64185.0, 32400.0, 26970.0, 54513.0, 82632.0, 111330.0, 112128.0, 84981.0, 57246.0, 28920.0, 56427.0, 114027.0, 172806.0, 232770.0, 234414.0, 177621.0, 119625.0, 60420.0, 88431.0, 178662.0, 270702.0, 364560.0, 367098.0, 278100.0, 187257.0, 94560.0, 123042.0, 248538.0, 376500.0, 506940.0, 510420.0, 386598.0, 260262.0, 131400.0, 127302.0, 257118.0, 389460.0, 524340.0, 527820.0, 399738.0, 269082.0, 135840.0, 99717.0, 201360.0, 304938.0, 410460.0, 413142.0, 312822.0, 210531.0, 106260.0, 69345.0, 140001.0, 211974.0, 285270.0, 287106.0, 217347.0, 146247.0, 73800.0, 36126.0, 72921.0, 110388.0, 148530.0, 149472.0, 113133.0, 76110.0, 38400.0}); + auto exp = NDArrayFactory::create('c', {bS, oC, oH, oW}, {6276.0f, 12831.0f, 19668.0f, 26790.0f, 27012.0f, 20703.0f, 14100.0f, 7200.0f, 13719.0f, 28023.0f, 42918.0f, 58410.0f, 58902.0f, 45105.0f, 30693.0f, 15660.0f, 22389.0f, 45696.0f, 69930.0f, 95100.0f, 95910.0f, 73386.0f, 49899.0f, 25440.0f, 32346.0f, 65970.0f, 100884.0f, 137100.0f, 138276.0f, 105726.0f, 71838.0f, 36600.0f, 33726.0f, 68790.0f, 105204.0f, 142980.0f, 144156.0f, 110226.0f, 74898.0f, 38160.0f, 27555.0f, 56154.0f, 85806.0f, 
116520.0f, 117474.0f, 89748.0f, 60933.0f, 31020.0f, 19917.0f, 40557.0f, 61926.0f, 84030.0f, 84714.0f, 64671.0f, 43875.0f, 22320.0f, 10752.0f, 21879.0f, 33384.0f, 45270.0f, 45636.0f, 34815.0f, 23604.0f, 12000.0f, 7551.0f, 15456.0f, 23718.0f, 32340.0f, 32562.0f, 24978.0f, 17025.0f, 8700.0f, 16569.0f, 33873.0f, 51918.0f, 70710.0f, 71202.0f, 54555.0f, 37143.0f, 18960.0f, 27114.0f, 55371.0f, 84780.0f, 115350.0f, 116160.0f, 88911.0f, 60474.0f, 30840.0f, 39246.0f, 80070.0f, 122484.0f, 166500.0f, 167676.0f, 128226.0f, 87138.0f, 44400.0f, 40626.0f, 82890.0f, 126804.0f, 172380.0f, 173556.0f, 132726.0f, 90198.0f, 45960.0f, 33180.0f, 67629.0f, 103356.0f, 140370.0f, 141324.0f, 107973.0f, 73308.0f, 37320.0f, 23967.0f, 48807.0f, 74526.0f, 101130.0f, 101814.0f, 77721.0f, 52725.0f, 26820.0f, 12927.0f, 26304.0f, 40134.0f, 54420.0f, 54786.0f, 41790.0f, 28329.0f, 14400.0f, 8826.0f, 18081.0f, 27768.0f, 37890.0f, 38112.0f, 29253.0f, 19950.0f, 10200.0f, 19419.0f, 39723.0f, 60918.0f, 83010.0f, 83502.0f, 64005.0f, 43593.0f, 22260.0f, 31839.0f, 65046.0f, 99630.0f, 135600.0f, 136410.0f, 104436.0f, 71049.0f, 36240.0f, 46146.0f, 94170.0f, 144084.0f, 195900.0f, 197076.0f, 150726.0f, 102438.0f, 52200.0f, 47526.0f, 96990.0f, 148404.0f, 201780.0f, 202956.0f, 155226.0f, 105498.0f, 53760.0f, 38805.0f, 79104.0f, 120906.0f, 164220.0f, 165174.0f, 126198.0f, 85683.0f, 43620.0f, 28017.0f, 57057.0f, 87126.0f, 118230.0f, 118914.0f, 90771.0f, 61575.0f, 31320.0f, 15102.0f, 30729.0f, 46884.0f, 63570.0f, 63936.0f, 48765.0f, 33054.0f, 16800.0f, 17220.0f, 34863.0f, 52932.0f, 71430.0f, 72228.0f, 54831.0f, 36996.0f, 18720.0f, 36327.0f, 73527.0f, 111606.0f, 150570.0f, 152214.0f, 115521.0f, 77925.0f, 39420.0f, 57381.0f, 116112.0f, 176202.0f, 237660.0f, 240198.0f, 182250.0f, 122907.0f, 62160.0f, 80442.0f, 162738.0f, 246900.0f, 332940.0f, 336420.0f, 255198.0f, 172062.0f, 87000.0f, 84702.0f, 171318.0f, 259860.0f, 350340.0f, 353820.0f, 268338.0f, 180882.0f, 91440.0f, 66867.0f, 135210.0f, 205038.0f, 276360.0f, 279042.0f, 211572.0f, 142581.0f, 72060.0f, 46845.0f, 94701.0f, 143574.0f, 193470.0f, 195306.0f, 148047.0f, 99747.0f, 50400.0f, 24576.0f, 49671.0f, 75288.0f, 101430.0f, 102372.0f, 77583.0f, 52260.0f, 26400.0f, 22095.0f, 44688.0f, 67782.0f, 91380.0f, 92178.0f, 69906.0f, 47121.0f, 23820.0f, 46377.0f, 93777.0f, 142206.0f, 191670.0f, 193314.0f, 146571.0f, 98775.0f, 49920.0f, 72906.0f, 147387.0f, 223452.0f, 301110.0f, 303648.0f, 230175.0f, 155082.0f, 78360.0f, 101742.0f, 205638.0f, 311700.0f, 419940.0f, 423420.0f, 320898.0f, 216162.0f, 109200.0f, 106002.0f, 214218.0f, 324660.0f, 437340.0f, 440820.0f, 334038.0f, 224982.0f, 113640.0f, 83292.0f, 168285.0f, 254988.0f, 343410.0f, 346092.0f, 262197.0f, 176556.0f, 89160.0f, 58095.0f, 117351.0f, 177774.0f, 239370.0f, 241206.0f, 182697.0f, 122997.0f, 62100.0f, 30351.0f, 61296.0f, 92838.0f, 124980.0f, 125922.0f, 95358.0f, 64185.0f, 32400.0f, 26970.0f, 54513.0f, 82632.0f, 111330.0f, 112128.0f, 84981.0f, 57246.0f, 28920.0f, 56427.0f, 114027.0f, 172806.0f, 232770.0f, 234414.0f, 177621.0f, 119625.0f, 60420.0f, 88431.0f, 178662.0f, 270702.0f, 364560.0f, 367098.0f, 278100.0f, 187257.0f, 94560.0f, 123042.0f, 248538.0f, 376500.0f, 506940.0f, 510420.0f, 386598.0f, 260262.0f, 131400.0f, 127302.0f, 257118.0f, 389460.0f, 524340.0f, 527820.0f, 399738.0f, 269082.0f, 135840.0f, 99717.0f, 201360.0f, 304938.0f, 410460.0f, 413142.0f, 312822.0f, 210531.0f, 106260.0f, 69345.0f, 140001.0f, 211974.0f, 285270.0f, 287106.0f, 217347.0f, 146247.0f, 73800.0f, 36126.0f, 72921.0f, 110388.0f, 148530.0f, 149472.0f, 113133.0f, 76110.0f, 
38400.0f}); input.linspace(1); @@ -2674,14 +2673,14 @@ TYPED_TEST(TypedConvolutionTests1, deconv2d_tf_test1) { auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); auto outShape = NDArrayFactory::create('c', {4}, {static_cast(bS), static_cast(iH), static_cast(iW), static_cast(iC)}); - auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, { 2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 27.75, 32.75, 37.75, 42.75, 47.75, - 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 , - 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 , - 52.75, 57.75, 62.75, 67.75, 72.75,130.5 , 140.5 , 150.5 , 160.5 , 170.5 ,130.5 , 140.5 , 150.5 , 160.5 , 170.5 , 77.75, 82.75, 87.75, 92.75, 97.75, - 2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 27.75, 32.75, 37.75, 42.75, 47.75, - 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 , - 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 , - 52.75, 57.75, 62.75, 67.75, 72.75,130.5 , 140.5 , 150.5 , 160.5 , 170.5 ,130.5 , 140.5 , 150.5 , 160.5 , 170.5 , 77.75, 82.75, 87.75, 92.75, 97.75}); + auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, { 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f, + 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f}); input = 0.5; weights.linspace(0.1, 0.1); diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp index ec67d6dbb..4cbf6b6dd 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp @@ -110,14 +110,14 @@ TYPED_TEST(TypedConvolutionTests2, deconv2d_tf_test2) { auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); auto outShape = NDArrayFactory::create('c', {4}, {static_cast(bS), static_cast(iH), static_cast(iW), static_cast(iC)}); - auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, {2.75, 7.75, 
12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , - 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. , - 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. , - 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. , - 2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , - 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. , - 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. , - 55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. }); + auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f}); input = 0.5; weights.linspace(0.1, 0.1); @@ -150,7 +150,7 @@ TYPED_TEST(TypedConvolutionTests2, Test_DeConv2D_TF_1) { ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, Test_DeConv2D_TF_2) { - auto input0 = NDArrayFactory::create('c', {4}, {3, 8, 8, 16}); + auto input0 = NDArrayFactory::create('c', {4}, {3.f, 8.f, 8.f, 16.f}); auto input1 = NDArrayFactory::create('c', {7, 7, 16, 5}, {1.05293429f, -0.89349967f, 0.31027254f, 1.22991478f, -0.62926656f, 0.56918693f, -1.60992694f, 1.10167944f, -0.80843484f, 0.07521993f, -1.15994942f, 0.76016301f, -0.40056285f, -1.16872537f, -0.91384381f, -0.36700436f, 1.82389200f, -1.18200207f, 0.51612782f, -0.92479187f, -0.09307563f, -0.55122334f, 1.23532486f, -1.11124146f, -0.05812126f, 0.68159896f, 0.69125599f, -0.77127314f, -0.10874277f, 0.86469102f, @@ -569,7 +569,6 @@ TEST_F(ConvolutionTests2, deconv3d_test4) { ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_test5) { - int bS=1, oD=5,oH=5,oW=5, oC=3,iC=2, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2; int iD=3,iH=3,iW=3; int paddingMode = 0; // 1-SAME, 0-VALID; @@ -579,22 +578,22 @@ TEST_F(ConvolutionTests2, 
deconv3d_test5) { auto weights = NDArrayFactory::create('c', {kD, kH, kW, oC, iC}); auto bias = NDArrayFactory::create('c', {oC}); - auto exp = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}, {-2.9, -6.8, -10.7, -2.6, -6.1, -9.6, -16.9, -23.9, -30.9, -13.1, -16.6, -20.1, -11.6, -14.7, -17.8, -2.0, -4.7, -7.4, -1.7, -4.0, -6.3, -11.5, - -16.1, -20.7, -8.6, -10.9, -13.2, -7.1, -9.0, -10.9, -27.4, -32.8, -38.2, -24.4, -29.0, -33.6, -65.0, -74.2, -83.4, -38.2, -42.8, -47.4, -32.8, - -36.6, -40.4, -18.2, -20.9, -23.6, -15.5, -17.8, -20.1, -39.1, -43.7, -48.3, -22.4, -24.7, -27.0, -18.5, -20.4, -22.3, -10.1, -11.6, -13.1, -7.4, - -8.5, -9.6, -19.3, -21.5, -23.7, -10.7, -11.8, -12.9, -6.8, -7.5, -8.2, -0.2, -0.5, -0.8, 0.1, 0.2, 0.3, -0.7, -0.5, -0.3, 0.4, 0.5, 0.6, 1.9, 2.4, - 2.9, 0.7, 1.6, 2.5, 1.0, 2.3, 3.6, 4.7, 7.3, 9.9, 4.9, 6.2, 7.5, 6.4, 8.1, 9.8, -0.4, 1.4, 3.2, 2.6, 5.2, 7.8, 10.6, 15.8, 21.0, 10.4, 13.0, 15.6, - 15.8, 19.2, 22.6, 6.1, 7.0, 7.9, 8.8, 10.1, 11.4, 20.3, 22.9, 25.5, 12.7, 14.0, 15.3, 16.6, 18.3, 20.0, 14.2, 16.3, 18.4, 16.9, 19.4, 21.9, 40.1, - 45.1, 50.1, 24.4, 26.9, 29.4, 28.3, 31.2, 34.1, -47.2, -47.8, -48.4, -41.8, -41.6, -41.4, -85.4, -85., -84.6, -41.2, -41.0, -40.8, -33.4, -32.4, -31.4, - -31., -29.2, -27.4, -25.6, -23.0, -20.4, -45.8, -40.6, -35.4, -17.8, -15.2, -12.6, -10.0, -6.6, -3.2, -65.6, -62.0, -58.4, -50.0, -44.8, -39.6, -89.2, - -78.8, -68.4, -34.4, -29.2, -24., -14.0, -7.2, -0.4, -20.2, -18.4, -16.6, -10., -7.4, -4.8, -14.6, -9.4, -4.2, -2.2, 0.4, 3.0, 10.4, 13.8, 17.2, 10.4, - 14.6, 18.8, 20.6, 25.6, 30.6, 53.8, 63.8, 73.8, 35.6, 40.6, 45.6, 48.2, 54.0, 59.8, -3.8, -4.1, -4.4, 1.3, 1.4, 1.5, 1.7, 1.9, 2.1, 1.6, 1.7, 1.8, 7.9, - 8.4, 8.9, 11.5, 12.4, 13.3, 16.6, 17.9, 19.2, 35.9, 38.5, 41.1, 20.5, 21.8, 23.1, 26.8, 28.5, 30.2, 21.2, 23.0, 24.8, 33.8, 36.4, 39.0, 73.0, 78.2, - 83.4, 41.6, 44.2, 46.8, 56.6, 60.0, 63.4, 16.9, 17.8, 18.7, 24.4, 25.7, 27., 51.5, 54.1, 56.7, 28.3, 29.6, 30.9, 37.0, 38.7, 40.4, 39.4, 41.5, - 43.6, 46.9, 49.4, 51.9, 100.1, 105.1, 110.1, 54.4, 56.9, 59.4, 63.1, 66.0, 68.9, 42.1, 45.4, 48.7, 47.2, 50.9, 54.6, 104.3, 111.7, - 119.1, 58.3, 62.0, 65.7, 64.6, 68.7, 72.8, 57.4, 61.9, 66.4, 62.5, 67.4, 72.3, 138.5, 148.3, 158.1, 77.2, 82.1, 87.0, 83.5, 88.8, 94.1, - 134.6, 143.6, 152.6, 147.2, 157.0, 166.8, 321.4, 341.0, 360.6, 176.6, 186.4, 196.2, 191.6, 202.2, 212.8, 84.4, 88.9, - 93.4, 91.9, 96.8, 101.7, 197.3, 207.1, 216.9, 106.6, 111.5, 116.4, 115.3, 120.6, 125.9, 106.9, 112.6, 118.3, 114.4, 120.5, 126.6, 245.9, 258.1, 270.3, 132.7, 138.8, 144.9, 141.4, 147.9, 154.4}); + auto exp = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}, {-2.9f, -6.8f, -10.7f, -2.6f, -6.1f, -9.6f, -16.9f, -23.9f, -30.9f, -13.1f, -16.6f, -20.1f, -11.6f, -14.7f, -17.8f, -2.0f, -4.7f, -7.4f, -1.7f, -4.0f, -6.3f, -11.5f, + -16.1f, -20.7f, -8.6f, -10.9f, -13.2f, -7.1f, -9.0f, -10.9f, -27.4f, -32.8f, -38.2f, -24.4f, -29.0f, -33.6f, -65.0f, -74.2f, -83.4f, -38.2f, -42.8f, -47.4f, -32.8f, + -36.6f, -40.4f, -18.2f, -20.9f, -23.6f, -15.5f, -17.8f, -20.1f, -39.1f, -43.7f, -48.3f, -22.4f, -24.7f, -27.0f, -18.5f, -20.4f, -22.3f, -10.1f, -11.6f, -13.1f, -7.4f, + -8.5f, -9.6f, -19.3f, -21.5f, -23.7f, -10.7f, -11.8f, -12.9f, -6.8f, -7.5f, -8.2f, -0.2f, -0.5f, -0.8f, 0.1f, 0.2f, 0.3f, -0.7f, -0.5f, -0.3f, 0.4f, 0.5f, 0.6f, 1.9f, 2.4f, + 2.9f, 0.7f, 1.6f, 2.5f, 1.0f, 2.3f, 3.6f, 4.7f, 7.3f, 9.9f, 4.9f, 6.2f, 7.5f, 6.4f, 8.1f, 9.8f, -0.4f, 1.4f, 3.2f, 2.6f, 5.2f, 7.8f, 10.6f, 15.8f, 21.0f, 10.4f, 13.0f, 15.6f, + 15.8f, 19.2f, 22.6f, 6.1f, 7.0f, 7.9f, 8.8f, 10.1f, 11.4f, 
20.3f, 22.9f, 25.5f, 12.7f, 14.0f, 15.3f, 16.6f, 18.3f, 20.0f, 14.2f, 16.3f, 18.4f, 16.9f, 19.4f, 21.9f, 40.1f, + 45.1f, 50.1f, 24.4f, 26.9f, 29.4f, 28.3f, 31.2f, 34.1f, -47.2f, -47.8f, -48.4f, -41.8f, -41.6f, -41.4f, -85.4f, -85.f, -84.6f, -41.2f, -41.0f, -40.8f, -33.4f, -32.4f, -31.4f, + -31.f, -29.2f, -27.4f, -25.6f, -23.0f, -20.4f, -45.8f, -40.6f, -35.4f, -17.8f, -15.2f, -12.6f, -10.0f, -6.6f, -3.2f, -65.6f, -62.0f, -58.4f, -50.0f, -44.8f, -39.6f, -89.2f, + -78.8f, -68.4f, -34.4f, -29.2f, -24.f, -14.0f, -7.2f, -0.4f, -20.2f, -18.4f, -16.6f, -10.f, -7.4f, -4.8f, -14.6f, -9.4f, -4.2f, -2.2f, 0.4f, 3.0f, 10.4f, 13.8f, 17.2f, 10.4f, + 14.6f, 18.8f, 20.6f, 25.6f, 30.6f, 53.8f, 63.8f, 73.8f, 35.6f, 40.6f, 45.6f, 48.2f, 54.0f, 59.8f, -3.8f, -4.1f, -4.4f, 1.3f, 1.4f, 1.5f, 1.7f, 1.9f, 2.1f, 1.6f, 1.7f, 1.8f, 7.9f, + 8.4f, 8.9f, 11.5f, 12.4f, 13.3f, 16.6f, 17.9f, 19.2f, 35.9f, 38.5f, 41.1f, 20.5f, 21.8f, 23.1f, 26.8f, 28.5f, 30.2f, 21.2f, 23.0f, 24.8f, 33.8f, 36.4f, 39.0f, 73.0f, 78.2f, + 83.4f, 41.6f, 44.2f, 46.8f, 56.6f, 60.0f, 63.4f, 16.9f, 17.8f, 18.7f, 24.4f, 25.7f, 27.f, 51.5f, 54.1f, 56.7f, 28.3f, 29.6f, 30.9f, 37.0f, 38.7f, 40.4f, 39.4f, 41.5f, + 43.6f, 46.9f, 49.4f, 51.9f, 100.1f, 105.1f, 110.1f, 54.4f, 56.9f, 59.4f, 63.1f, 66.0f, 68.9f, 42.1f, 45.4f, 48.7f, 47.2f, 50.9f, 54.6f, 104.3f, 111.7f, + 119.1f, 58.3f, 62.0f, 65.7f, 64.6f, 68.7f, 72.8f, 57.4f, 61.9f, 66.4f, 62.5f, 67.4f, 72.3f, 138.5f, 148.3f, 158.1f, 77.2f, 82.1f, 87.0f, 83.5f, 88.8f, 94.1f, + 134.6f, 143.6f, 152.6f, 147.2f, 157.0f, 166.8f, 321.4f, 341.0f, 360.6f, 176.6f, 186.4f, 196.2f, 191.6f, 202.2f, 212.8f, 84.4f, 88.9f, + 93.4f, 91.9f, 96.8f, 101.7f, 197.3f, 207.1f, 216.9f, 106.6f, 111.5f, 116.4f, 115.3f, 120.6f, 125.9f, 106.9f, 112.6f, 118.3f, 114.4f, 120.5f, 126.6f, 245.9f, 258.1f, 270.3f, 132.7f, 138.8f, 144.9f, 141.4f, 147.9f, 154.4f}); input.linspace(-10, 0.5); weights.linspace(0.1, 0.1); @@ -699,7 +698,7 @@ TEST_F(ConvolutionTests2, deconv3d_bp_test3) { int dataFormat = 0; // 1-NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}, {0.1,0.9,0.2,0.1,0.3,1.1,0.4,1.2,0.5,1.3,0.6,1.4,0.7,1.5,0.8,1.6}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}, {0.1f, 0.9f, 0.2f, 0.1f, 0.3f, 1.1f, 0.4f, 1.2f, 0.5f, 1.3f, 0.6f, 1.4f, 0.7f, 1.5f, 0.8f, 1.6f}); auto gradO = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); NDArray expGradI('c', {bS, oD, oH, oW, oC}, {33.8, 37.4, 44.6, 48.2, 66.2, 69.8, 77., 80.6, 77.25, 86.35, 104.55, 113.65, 159.15, 168.25, 186.45, 195.55}, nd4j::DataType::FLOAT32); @@ -734,7 +733,7 @@ TEST_F(ConvolutionTests2, deconv3d_bp_test4) { int dataFormat = 0; // 1-NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}, {0.1,0.9,0.2,0.1,0.3,1.1,0.4,1.2,0.5,1.3,0.6,1.4,0.7,1.5,0.8,1.6}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}, {0.1f, 0.9f, 0.2f, 0.1f, 0.3f, 1.1f, 0.4f, 1.2f, 0.5f, 1.3f, 0.6f, 1.4f, 0.7f, 1.5f, 0.8f, 1.6f}); auto gradO = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); NDArray expGradI('c', {bS, oC, oD, oH, oW}, {0.4, 1.55, 1.05, 2.3, 5.7, 3.2, 1.5, 3.35, 1.75, 3.8, 8.3, 4.3, 9.0, 18.6, 9.2, 4.4, 8.7, 4.1, 1.8, 3.55, 1.65, 3.5, 6.5, 2.8, 1.3, 2.15, 0.75, 0.8, 3.15, 2.25, 4.7, 12.1, 7.2, 3.5, 8.15, 4.55, 7.8, 17.9, 9.9, 19.75, 42.85, 23.6, 9.35, 21.55, 12.9, 5.4, 11.55, 6.05, 8.25, 20.75, 13.2, 0.65, 6.6, 6.75}, nd4j::DataType::FLOAT32); @@ -1062,14 +1061,14 @@ 
TYPED_TEST(TypedConvolutionTests2, maxpool2d_10) { int paddingMode = 0; // 1-SAME, 0-VALID; int dataFormat = 0; // 1-NHWC, 0-NCHW - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.27620894, 0.21801452, 0.062078513, 7.348895E-4, 0.24149609, 0.4948205, 0.93483436, 0.52035654, 0.30292067, 0.3289706, 0.7977864, - 0.03180518, 0.1455722, 0.90352905, 0.9405744, 0.0048329555, 0.44062102, 0.111197524, 0.31742015, 0.1933705, 0.23825112, 0.35076278, 0.7135856, 0.28229436, 0.18310733, - 0.9613717, 0.56823575, 0.78289545, 0.62195826, 0.5244586, 0.5040889, 0.025349546, 0.41400263, 0.28420195, 0.8536445, 0.3044107, 0.7997134, 0.45762005, 0.7653578, - 0.07198584, 0.5304998, 0.7334402, 0.85019743, 0.031957153, 0.37088063, 0.85722464, 0.06376881, 0.39791203}); + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.27620894f, 0.21801452f, 0.062078513f, 7.348895E-4f, 0.24149609f, 0.4948205f, 0.93483436f, 0.52035654f, 0.30292067f, 0.3289706f, 0.7977864f, + 0.03180518f, 0.1455722f, 0.90352905f, 0.9405744f, 0.0048329555f, 0.44062102f, 0.111197524f, 0.31742015f, 0.1933705f, 0.23825112f, 0.35076278f, 0.7135856f, 0.28229436f, 0.18310733f, + 0.9613717f, 0.56823575f, 0.78289545f, 0.62195826f, 0.5244586f, 0.5040889f, 0.025349546f, 0.41400263f, 0.28420195f, 0.8536445f, 0.3044107f, 0.7997134f, 0.45762005f, 0.7653578f, + 0.07198584f, 0.5304998f, 0.7334402f, 0.85019743f, 0.031957153f, 0.37088063f, 0.85722464f, 0.06376881f, 0.39791203f}); - auto expOutput = NDArrayFactory::create('c', {bS, iC, oH, oW}, {0.4948205, 0.93483436, 0.93483436, 0.4948205, 0.93483436, 0.93483436, 0.90352905, 0.9405744, 0.9405744, 0.44062102, 0.7135856, - 0.7135856, 0.9613717, 0.9613717, 0.78289545, 0.9613717, 0.9613717, 0.78289545, 0.7997134, 0.8536445, 0.8536445, 0.7997134, 0.85019743, 0.85019743, - 0.85722464, 0.85722464, 0.85019743}); + auto expOutput = NDArrayFactory::create('c', {bS, iC, oH, oW}, {0.4948205f, 0.93483436f, 0.93483436f, 0.4948205f, 0.93483436f, 0.93483436f, 0.90352905f, 0.9405744f, 0.9405744f, 0.44062102f, 0.7135856f, + 0.7135856f, 0.9613717f, 0.9613717f, 0.78289545f, 0.9613717f, 0.9613717f, 0.78289545f, 0.7997134f, 0.8536445f, 0.8536445f, 0.7997134f, 0.85019743f, 0.85019743f, + 0.85722464f, 0.85722464f, 0.85019743f}); nd4j::ops::maxpool2d op; auto results = op.execute({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode}); @@ -1108,9 +1107,9 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_test1) { int dataFormat = 0; // 1-NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}, {10.5, 11.5, 13.5, 14.5, 22.5, 23.5, 25.5, 26.5, 46.5, 47.5, 49.5, 50.5, 58.5, 59.5, 61.5, 62.5, - 82.5, 83.5, 85.5, 86.5, 94.5, 95.5, 97.5, 98.5,118.5,119.5,121.5,122.5,130.5,131.5,133.5,134.5, - 154.5,155.5,157.5,158.5,166.5,167.5,169.5,170.5,190.5,191.5,193.5,194.5,202.5,203.5,205.5,206.5}); + auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}, {10.5f, 11.5f, 13.5f, 14.5f, 22.5f, 23.5f, 25.5f, 26.5f, 46.5f, 47.5f, 49.5f, 50.5f, 58.5f, 59.5f, 61.5f, 62.5f, + 82.5f, 83.5f, 85.5f, 86.5f, 94.5f, 95.5f, 97.5f, 98.5f,118.5f,119.5f,121.5f,122.5f,130.5f,131.5f,133.5f,134.5f, + 154.5f,155.5f,157.5f,158.5f,166.5f,167.5f,169.5f,170.5f,190.5f,191.5f,193.5f,194.5f,202.5f,203.5f,205.5f,206.5f}); input.linspace(1.); nd4j::ops::avgpool3dnew op; @@ -1133,12 +1132,12 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_test2) { int dataFormat = 1; // 1-NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto expected 
= NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 25. , 26. , 27. , 28. , 29. , 30. , 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 32.5, 33.5, 34.5, 34. , 35. , 36. , 38.5, 39.5, 40.5, 41.5, 42.5, 43.5, 43. , 44. , 45. , 43. , 44. , 45. , 46. , 47. , 48. , 47.5, 48.5, 49.5, - 61. , 62. , 63. , 64. , 65. , 66. , 65.5, 66.5, 67.5, 65.5, 66.5, 67.5, 68.5, 69.5, 70.5, 70. , 71. , 72. , 74.5, 75.5, 76.5, 77.5, 78.5, 79.5, 79. , 80. , 81. , 79. , 80. , 81. , 82. , 83. , 84. , 83.5, 84.5, 85.5, - 79. , 80. , 81. , 82. , 83. , 84. , 83.5, 84.5, 85.5, 83.5, 84.5, 85.5, 86.5, 87.5, 88.5, 88. , 89. , 90. , 92.5, 93.5, 94.5, 95.5, 96.5, 97.5, 97. , 98. , 99. , 97. , 98. , 99. ,100. ,101. ,102. ,101.5,102.5,103.5, - 133. ,134. ,135. ,136. ,137. ,138. ,137.5,138.5,139.5,137.5,138.5,139.5,140.5,141.5,142.5,142. ,143. ,144. ,146.5,147.5,148.5,149.5,150.5,151.5,151. ,152. ,153. ,151. ,152. ,153. ,154. ,155. ,156. ,155.5,156.5,157.5, - 169. ,170. ,171. ,172. ,173. ,174. ,173.5,174.5,175.5,173.5,174.5,175.5,176.5,177.5,178.5,178. ,179. ,180. ,182.5,183.5,184.5,185.5,186.5,187.5,187. ,188. ,189. ,187. ,188. ,189. ,190. ,191. ,192. ,191.5,192.5,193.5, - 187. ,188. ,189. ,190. ,191. ,192. ,191.5,192.5,193.5,191.5,192.5,193.5,194.5,195.5,196.5,196. ,197. ,198. ,200.5,201.5,202.5,203.5,204.5,205.5,205. ,206. ,207. ,205. ,206. ,207. ,208. ,209. ,210. ,209.5,210.5,211.5}); + auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 29.5f, 30.5f, 31.5f, 29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 34.f, 35.f, 36.f, 38.5f, 39.5f, 40.5f, 41.5f, 42.5f, 43.5f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 47.5f, 48.5f, 49.5f, + 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 65.5f, 66.5f, 67.5f, 65.5f, 66.5f, 67.5f, 68.5f, 69.5f, 70.5f, 70.f, 71.f, 72.f, 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, 79.f, 80.f, 81.f, 79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f, + 79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f, 83.5f, 84.5f, 85.5f, 86.5f, 87.5f, 88.5f, 88.f, 89.f, 90.f, 92.5f, 93.5f, 94.5f, 95.5f, 96.5f, 97.5f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 100.f, 101.f, 102.f, 101.5f, 102.5f, 103.5f, + 133.f, 134.f, 135.f, 136.f, 137.f, 138.f, 137.5f, 138.5f, 139.5f, 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 142.f, 143.f, 144.f, 146.5f, 147.5f, 148.5f, 149.5f, 150.5f, 151.5f, 151.f, 152.f, 153.f, 151.f, 152.f, 153.f, 154.f, 155.f, 156.f, 155.5f, 156.5f, 157.5f, + 169.f, 170.f, 171.f, 172.f, 173.f, 174.f, 173.5f, 174.5f, 175.5f, 173.5f, 174.5f, 175.5f, 176.5f, 177.5f, 178.5f, 178.f, 179.f, 180.f, 182.5f, 183.5f, 184.5f, 185.5f, 186.5f, 187.5f, 187.f, 188.f, 189.f, 187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 191.5f, 192.5f, 193.5f, + 187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 191.5f, 192.5f, 193.5f, 191.5f, 192.5f, 193.5f, 194.5f, 195.5f, 196.5f, 196.f, 197.f, 198.f, 200.5f, 201.5f, 202.5f, 203.5f, 204.5f, 205.5f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 208.f, 209.f, 210.f, 209.5f, 210.5f, 211.5f}); input.linspace(1.); nd4j::ops::avgpool3dnew op; @@ -1161,9 +1160,9 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_test3) { int dataFormat = 1; // 1-NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 29.5, 30.5, 31.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 41.5, 42.5, 43.5, 65.5, 66.5, 67.5, 68.5, 69.5, 70.5, - 74.5, 75.5, 76.5, 77.5, 78.5, 79.5,137.5,138.5,139.5,140.5,141.5,142.5,146.5,147.5,148.5,149.5,150.5,151.5, - 
173.5,174.5,175.5,176.5,177.5,178.5,182.5,183.5,184.5,185.5,186.5,187.5}); + auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, 40.5f, 41.5f, 42.5f, 43.5f, 65.5f, 66.5f, 67.5f, 68.5f, 69.5f, 70.5f, + 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 146.5f, 147.5f, 148.5f, 149.5f, 150.5f, 151.5f, + 173.5f, 174.5f, 175.5f, 176.5f, 177.5f, 178.5f, 182.5f, 183.5f, 184.5f, 185.5f, 186.5f, 187.5f}); input.linspace(1.); nd4j::ops::avgpool3dnew op; @@ -1186,24 +1185,24 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_test4) { int dataFormat = 0; // 1-NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW},{0.416667, 1.00, 1.333333, 0.75, 1.00, 2.25, 2.75, 1.50, 1.75, 3.75, 4.25, 2.25, 1.416667, 3.00, 3.333333, 1.75, 2.833333, 6.00, 6.666667, 3.50, 5.00, 10.50, 11.50, 6.00, 6.50, - 13.50, 14.50, 7.50, 4.833333, 10.00, 10.666667, 5.50, 6.833333, 14.00, 14.666667, 7.50, 11.00, 22.50, 23.50, 12.00, 12.50, 25.50, 26.50, 13.50, 8.833333, 18.00, 18.666666, 9.50, - 4.416667, 9.00, 9.333333, 4.75, 7.00, 14.25, 14.75, 7.50, 7.75, 15.75, 16.25, 8.25, 5.416667, 11.00, 11.333333, 5.75, 6.416667, 13.00, 13.333333, 6.75, 10.00, 20.25, 20.75, - 10.50, 10.75, 21.75, 22.25, 11.25, 7.416667, 15.00, 15.333333, 7.75, 14.833333, 30.00, 30.666666, 15.50, 23.00, 46.50, 47.50, 24.00, 24.50, 49.50, 50.50, 25.50, 16.833334, - 34.00, 34.666668, 17.50, 18.833334, 38.00, 38.666668, 19.50, 29.00, 58.50, 59.50, 30.00, 30.50, 61.50, 62.50, 31.50, 20.833334, 42.00, 42.666668, 21.50, 10.416667, 21.00, - 21.333334, 10.75, 16.00, 32.25, 32.75, 16.50, 16.75, 33.75, 34.25, 17.25, 11.416667, 23.00, 23.333334, 11.75, 12.416667, 25.00, 25.333334, 12.75, 19.00, 38.25, 38.75, 19.50, - 19.75, 39.75, 40.25, 20.25, 13.416667, 27.00, 27.333334, 13.75, 26.833334, 54.00, 54.666668, 27.50, 41.00, 82.50, 83.50, 42.00, 42.50, 85.50, 86.50, 43.50, 28.833334, 58.00, - 58.666668, 29.50, 30.833334, 62.00, 62.666668, 31.50, 47.00, 94.50, 95.50, 48.00, 48.50, 97.50, 98.50, 49.50, 32.833332, 66.00, 66.666664, 33.50, 16.416666, 33.00, 33.333332, - 16.75, 25.00, 50.25, 50.75, 25.50, 25.75, 51.75, 52.25, 26.25, 17.416666, 35.00, 35.333332, 17.75, 18.416666, 37.00, 37.333332, 18.75, 28.00, 56.25, 56.75, 28.50, 28.75, - 57.75, 58.25, 29.25, 19.416666, 39.00, 39.333332, 19.75, 38.833332, 78.00, 78.666664, 39.50, 59.00, 118.50, 119.50, 60.00, 60.50, 121.50, 122.50, 61.50, 40.833332, 82.00, - 82.666664, 41.50, 42.833332, 86.00, 86.666664, 43.50, 65.00, 130.50, 131.50, 66.00, 66.50, 133.50, 134.50, 67.50, 44.833332, 90.00, 90.666664, 45.50, 22.416666, 45.00, - 45.333332, 22.75, 34.00, 68.25, 68.75, 34.50, 34.75, 69.75, 70.25, 35.25, 23.416666, 47.00, 47.333332, 23.75, 24.416666, 49.00, 49.333332, 24.75, 37.00, 74.25, 74.75, - 37.50, 37.75, 75.75, 76.25, 38.25, 25.416666, 51.00, 51.333332, 25.75, 50.833332, 102.00, 102.666664, 51.50, 77.00, 154.50, 155.50, 78.00, 78.50, 157.50, 158.50, 79.50, - 52.833332, 106.00, 106.666664, 53.50, 54.833332, 110.00, 110.666664, 55.50, 83.00, 166.50, 167.50, 84.00, 84.50, 169.50, 170.50, 85.50, 56.833332, 114.00, 114.666664, - 57.50, 28.416666, 57.00, 57.333332, 28.75, 43.00, 86.25, 86.75, 43.50, 43.75, 87.75, 88.25, 44.25, 29.416666, 59.00, 59.333332, 29.75, 30.416666, 61.00, 61.333332, 30.75, - 46.00, 92.25, 92.75, 46.50, 46.75, 93.75, 94.25, 47.25, 31.416666, 63.00, 63.333332, 31.75, 62.833332, 126.00, 126.666664, 63.50, 95.00, 
190.50, 191.50, 96.00, 96.50, - 193.50, 194.50, 97.50, 64.833336, 130.00, 130.666672, 65.50, 66.833336, 134.00, 134.666672, 67.50, 101.00, 202.50, 203.50, 102.00, 102.50, 205.50, 206.50, 103.50, - 68.833336, 138.00, 138.666672, 69.50, 34.416668, 69.00, 69.333336, 34.75, 52.00, 104.25, 104.75, 52.50, 52.75, 105.75, 106.25, 53.25, 35.416668, 71.00, 71.333336, 35.75}); + auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW},{0.416667f, 1.00f, 1.333333f, 0.75f, 1.00f, 2.25f, 2.75f, 1.50f, 1.75f, 3.75f, 4.25f, 2.25f, 1.416667f, 3.00f, 3.333333f, 1.75f, 2.833333f, 6.00f, 6.666667f, 3.50f, 5.00f, 10.50f, 11.50f, 6.00f, 6.50f, + 13.50f, 14.50f, 7.50f, 4.833333f, 10.00f, 10.666667f, 5.50f, 6.833333f, 14.00f, 14.666667f, 7.50f, 11.00f, 22.50f, 23.50f, 12.00f, 12.50f, 25.50f, 26.50f, 13.50f, 8.833333f, 18.00f, 18.666666f, 9.50f, + 4.416667f, 9.00f, 9.333333f, 4.75f, 7.00f, 14.25f, 14.75f, 7.50f, 7.75f, 15.75f, 16.25f, 8.25f, 5.416667f, 11.00f, 11.333333f, 5.75f, 6.416667f, 13.00f, 13.333333f, 6.75f, 10.00f, 20.25f, 20.75f, + 10.50f, 10.75f, 21.75f, 22.25f, 11.25f, 7.416667f, 15.00f, 15.333333f, 7.75f, 14.833333f, 30.00f, 30.666666f, 15.50f, 23.00f, 46.50f, 47.50f, 24.00f, 24.50f, 49.50f, 50.50f, 25.50f, 16.833334f, + 34.00f, 34.666668f, 17.50f, 18.833334f, 38.00f, 38.666668f, 19.50f, 29.00f, 58.50f, 59.50f, 30.00f, 30.50f, 61.50f, 62.50f, 31.50f, 20.833334f, 42.00f, 42.666668f, 21.50f, 10.416667f, 21.00f, + 21.333334f, 10.75f, 16.00f, 32.25f, 32.75f, 16.50f, 16.75f, 33.75f, 34.25f, 17.25f, 11.416667f, 23.00f, 23.333334f, 11.75f, 12.416667f, 25.00f, 25.333334f, 12.75f, 19.00f, 38.25f, 38.75f, 19.50f, + 19.75f, 39.75f, 40.25f, 20.25f, 13.416667f, 27.00f, 27.333334f, 13.75f, 26.833334f, 54.00f, 54.666668f, 27.50f, 41.00f, 82.50f, 83.50f, 42.00f, 42.50f, 85.50f, 86.50f, 43.50f, 28.833334f, 58.00f, + 58.666668f, 29.50f, 30.833334f, 62.00f, 62.666668f, 31.50f, 47.00f, 94.50f, 95.50f, 48.00f, 48.50f, 97.50f, 98.50f, 49.50f, 32.833332f, 66.00f, 66.666664f, 33.50f, 16.416666f, 33.00f, 33.333332f, + 16.75f, 25.00f, 50.25f, 50.75f, 25.50f, 25.75f, 51.75f, 52.25f, 26.25f, 17.416666f, 35.00f, 35.333332f, 17.75f, 18.416666f, 37.00f, 37.333332f, 18.75f, 28.00f, 56.25f, 56.75f, 28.50f, 28.75f, + 57.75f, 58.25f, 29.25f, 19.416666f, 39.00f, 39.333332f, 19.75f, 38.833332f, 78.00f, 78.666664f, 39.50f, 59.00f, 118.50f, 119.50f, 60.00f, 60.50f, 121.50f, 122.50f, 61.50f, 40.833332f, 82.00f, + 82.666664f, 41.50f, 42.833332f, 86.00f, 86.666664f, 43.50f, 65.00f, 130.50f, 131.50f, 66.00f, 66.50f, 133.50f, 134.50f, 67.50f, 44.833332f, 90.00f, 90.666664f, 45.50f, 22.416666f, 45.00f, + 45.333332f, 22.75f, 34.00f, 68.25f, 68.75f, 34.50f, 34.75f, 69.75f, 70.25f, 35.25f, 23.416666f, 47.00f, 47.333332f, 23.75f, 24.416666f, 49.00f, 49.333332f, 24.75f, 37.00f, 74.25f, 74.75f, + 37.50f, 37.75f, 75.75f, 76.25f, 38.25f, 25.416666f, 51.00f, 51.333332f, 25.75f, 50.833332f, 102.00f, 102.666664f, 51.50f, 77.00f, 154.50f, 155.50f, 78.00f, 78.50f, 157.50f, 158.50f, 79.50f, + 52.833332f, 106.00f, 106.666664f, 53.50f, 54.833332f, 110.00f, 110.666664f, 55.50f, 83.00f, 166.50f, 167.50f, 84.00f, 84.50f, 169.50f, 170.50f, 85.50f, 56.833332f, 114.00f, 114.666664f, + 57.50f, 28.416666f, 57.00f, 57.333332f, 28.75f, 43.00f, 86.25f, 86.75f, 43.50f, 43.75f, 87.75f, 88.25f, 44.25f, 29.416666f, 59.00f, 59.333332f, 29.75f, 30.416666f, 61.00f, 61.333332f, 30.75f, + 46.00f, 92.25f, 92.75f, 46.50f, 46.75f, 93.75f, 94.25f, 47.25f, 31.416666f, 63.00f, 63.333332f, 31.75f, 62.833332f, 126.00f, 126.666664f, 63.50f, 95.00f, 190.50f, 191.50f, 96.00f, 96.50f, 
+ 193.50f, 194.50f, 97.50f, 64.833336f, 130.00f, 130.666672f, 65.50f, 66.833336f, 134.00f, 134.666672f, 67.50f, 101.00f, 202.50f, 203.50f, 102.00f, 102.50f, 205.50f, 206.50f, 103.50f, + 68.833336f, 138.00f, 138.666672f, 69.50f, 34.416668f, 69.00f, 69.333336f, 34.75f, 52.00f, 104.25f, 104.75f, 52.50f, 52.75f, 105.75f, 106.25f, 53.25f, 35.416668f, 71.00f, 71.333336f, 35.75f}); input.linspace(1.); nd4j::ops::avgpool3dnew op; @@ -1226,8 +1225,8 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_test1) { int dataFormat = 0; // 1-NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}, {20., 21., 23., 24., 32., 33., 35., 36., 56., 57., 59., 60., 68., 69., 71., 72., 92., 93., 95., 96.,104.,105.,107.,108., - 128.,129.,131.,132.,140.,141.,143.,144.,164.,165.,167.,168.,176.,177.,179.,180.,200.,201.,203.,204.,212.,213.,215.,216.}); + auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}, {20.f, 21.f, 23.f, 24.f, 32.f, 33.f, 35.f, 36.f, 56.f, 57.f, 59.f, 60.f, 68.f, 69.f, 71.f, 72.f, 92.f, 93.f, 95.f, 96.f, 104.f, 105.f, 107.f, 108.f, + 128.f, 129.f, 131.f, 132.f, 140.f, 141.f, 143.f, 144.f, 164.f, 165.f, 167.f, 168.f, 176.f, 177.f, 179.f, 180.f, 200.f, 201.f, 203.f, 204.f, 212.f, 213.f, 215.f, 216.f}); input.linspace(1.); nd4j::ops::maxpool3dnew op; @@ -1250,12 +1249,12 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_test2) { int dataFormat = 1; // 1-NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 49., 50., 51., 52., 53., 54., 52., 53., 54., 58., 59., 60., 61., 62., 63., 61., 62., 63., 67., 68., 69., 70., 71., 72., 70., 71., 72., 67., 68., 69., 70., 71., 72., 70., 71., 72., - 85., 86., 87., 88., 89., 90., 88., 89., 90., 94., 95., 96., 97., 98., 99., 97., 98., 99.,103., 104., 105.,106., 107., 108.,106., 107., 108.,103., 104., 105.,106., 107., 108.,106., 107., 108., - 85., 86., 87., 88., 89., 90., 88., 89., 90., 94., 95., 96., 97., 98., 99., 97., 98., 99.,103., 104., 105.,106., 107., 108.,106., 107., 108.,103., 104., 105.,106., 107., 108.,106., 107., 108., - 157., 158., 159.,160., 161., 162.,160., 161., 162.,166., 167., 168.,169., 170., 171.,169., 170., 171.,175., 176., 177.,178., 179., 180.,178., 179., 180.,175., 176., 177.,178., 179., 180.,178., 179., 180., - 193., 194., 195.,196., 197., 198.,196., 197., 198.,202., 203., 204.,205., 206., 207.,205., 206., 207.,211., 212., 213.,214., 215., 216.,214., 215., 216.,211., 212., 213.,214., 215., 216.,214., 215., 216., - 193., 194., 195.,196., 197., 198.,196., 197., 198.,202., 203., 204.,205., 206., 207.,205., 206., 207.,211., 212., 213.,214., 215., 216.,214., 215., 216.,211., 212., 213.,214., 215., 216.,214., 215., 216.}); + auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, + 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 88.f, 89.f, 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, + 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 88.f, 89.f, 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 
108.f, + 157.f, 158.f, 159.f, 160.f, 161.f, 162.f, 160.f, 161.f, 162.f, 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, 169.f, 170.f, 171.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 178.f, 179.f, 180.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 178.f, 179.f, 180.f, + 193.f, 194.f, 195.f, 196.f, 197.f, 198.f, 196.f, 197.f, 198.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, + 193.f, 194.f, 195.f, 196.f, 197.f, 198.f, 196.f, 197.f, 198.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f}); input.linspace(1.); nd4j::ops::maxpool3dnew op; @@ -1278,8 +1277,8 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_test3) { int dataFormat = 1; // 1-NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, {58., 59., 60., 61., 62., 63., 67., 68., 69., 70., 71., 72., 94., 95., 96., 97., 98., 99.,103., 104., 105.,106., 107., 108., - 166., 167., 168.,169., 170., 171.,175., 176., 177.,178., 179., 180.,202., 203., 204.,205., 206., 207.,211., 212., 213.,214., 215., 216.}); + auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, {58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, + 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f}); input.linspace(1.); nd4j::ops::maxpool3dnew op; @@ -1302,14 +1301,14 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_test4) { int dataFormat = 0; // -NDHWC, 0-NCDHW auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW},{ 4., 5., 6., 6., 7., 8., 9., 9., 10., 11., 12., 12., 10., 11., 12., 12., 16., 17., 18., 18., 19., 20., 21., 21., 22., 23., 24., 24., 22., 23., 24., 24., 28., 29., 30., 30., 31., 32., 33., 33., 34., 35., 36., 36., 34., 35., 36., 36., - 28., 29., 30., 30., 31., 32., 33., 33., 34., 35., 36., 36., 34., 35., 36., 36., 40., 41., 42., 42., 43., 44., 45., 45., 46., 47., 48., 48., 46., 47., 48., 48., 52., 53., 54., 54., 55., 56., 57., 57., 58., 59., 60., 60., 58., 59., 60., 60., - 64., 65., 66., 66., 67., 68., 69., 69., 70., 71., 72., 72., 70., 71., 72., 72., 64., 65., 66., 66., 67., 68., 69., 69., 70., 71., 72., 72., 70., 71., 72., 72., 76., 77., 78., 78., 79., 80., 81., 81., 82., 83., 84., 84., 82., 83., 84., 84., - 88., 89., 90., 90., 91., 92., 93., 93., 94., 95., 96., 96., 94., 95., 96., 96.,100., 101., 102., 102.,103., 104., 105., 105.,106., 107., 108., 108.,106., 107., 108., 108.,100., 101., 102., 102.,103., 104., 105., 105.,106., 107., 108., 108.,106., 107., 108., 108., - 112., 113., 114., 114.,115., 116., 117., 117.,118., 119., 120., 120.,118., 119., 120., 120.,124., 125., 126., 126.,127., 128., 129., 129.,130., 131., 132., 132.,130., 131., 132., 132.,136., 137., 138., 138.,139., 140., 141., 141.,142., 143., 144., 144.,142., 143., 144., 144., - 136., 137., 138., 138.,139., 140., 141., 141.,142., 143., 144., 144.,142., 143., 144., 144.,148., 149., 150., 150.,151., 152., 153., 153.,154., 155., 156., 156.,154., 155., 156., 156.,160., 161., 162., 162.,163., 164., 165., 165.,166., 167., 168., 168.,166., 167., 
168., 168., - 172., 173., 174., 174.,175., 176., 177., 177.,178., 179., 180., 180.,178., 179., 180., 180.,172., 173., 174., 174.,175., 176., 177., 177.,178., 179., 180., 180.,178., 179., 180., 180.,184., 185., 186., 186.,187., 188., 189., 189.,190., 191., 192., 192.,190., 191., 192., 192., - 196., 197., 198., 198.,199., 200., 201., 201.,202., 203., 204., 204.,202., 203., 204., 204.,208., 209., 210., 210.,211., 212., 213., 213.,214., 215., 216., 216.,214., 215., 216., 216.,208., 209., 210., 210.,211., 212., 213., 213.,214., 215., 216., 216.,214., 215., 216., 216.}); + auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW},{ 4.f, 5.f, 6.f, 6.f, 7.f, 8.f, 9.f, 9.f, 10.f, 11.f, 12.f, 12.f, 10.f, 11.f, 12.f, 12.f, 16.f, 17.f, 18.f, 18.f, 19.f, 20.f, 21.f, 21.f, 22.f, 23.f, 24.f, 24.f, 22.f, 23.f, 24.f, 24.f, 28.f, 29.f, 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, 34.f, 35.f, 36.f, 36.f, 34.f, 35.f, 36.f, 36.f, + 28.f, 29.f, 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, 34.f, 35.f, 36.f, 36.f, 34.f, 35.f, 36.f, 36.f, 40.f, 41.f, 42.f, 42.f, 43.f, 44.f, 45.f, 45.f, 46.f, 47.f, 48.f, 48.f, 46.f, 47.f, 48.f, 48.f, 52.f, 53.f, 54.f, 54.f, 55.f, 56.f, 57.f, 57.f, 58.f, 59.f, 60.f, 60.f, 58.f, 59.f, 60.f, 60.f, + 64.f, 65.f, 66.f, 66.f, 67.f, 68.f, 69.f, 69.f, 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, 72.f, 72.f, 64.f, 65.f, 66.f, 66.f, 67.f, 68.f, 69.f, 69.f, 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, 72.f, 72.f, 76.f, 77.f, 78.f, 78.f, 79.f, 80.f, 81.f, 81.f, 82.f, 83.f, 84.f, 84.f, 82.f, 83.f, 84.f, 84.f, + 88.f, 89.f, 90.f, 90.f, 91.f, 92.f, 93.f, 93.f, 94.f, 95.f, 96.f, 96.f, 94.f, 95.f, 96.f, 96.f, 100.f, 101.f, 102.f, 102.f, 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, 108.f, 108.f, 106.f, 107.f, 108.f, 108.f, 100.f, 101.f, 102.f, 102.f, 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, 108.f, 108.f, 106.f, 107.f, 108.f, 108.f, + 112.f, 113.f, 114.f, 114.f, 115.f, 116.f, 117.f, 117.f, 118.f, 119.f, 120.f, 120.f, 118.f, 119.f, 120.f, 120.f, 124.f, 125.f, 126.f, 126.f, 127.f, 128.f, 129.f, 129.f, 130.f, 131.f, 132.f, 132.f, 130.f, 131.f, 132.f, 132.f, 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, 141.f, 141.f, 142.f, 143.f, 144.f, 144.f, 142.f, 143.f, 144.f, 144.f, + 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, 141.f, 141.f, 142.f, 143.f, 144.f, 144.f, 142.f, 143.f, 144.f, 144.f, 148.f, 149.f, 150.f, 150.f, 151.f, 152.f, 153.f, 153.f, 154.f, 155.f, 156.f, 156.f, 154.f, 155.f, 156.f, 156.f, 160.f, 161.f, 162.f, 162.f, 163.f, 164.f, 165.f, 165.f, 166.f, 167.f, 168.f, 168.f, 166.f, 167.f, 168.f, 168.f, + 172.f, 173.f, 174.f, 174.f, 175.f, 176.f, 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, 178.f, 179.f, 180.f, 180.f, 172.f, 173.f, 174.f, 174.f, 175.f, 176.f, 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, 178.f, 179.f, 180.f, 180.f, 184.f, 185.f, 186.f, 186.f, 187.f, 188.f, 189.f, 189.f, 190.f, 191.f, 192.f, 192.f, 190.f, 191.f, 192.f, 192.f, + 196.f, 197.f, 198.f, 198.f, 199.f, 200.f, 201.f, 201.f, 202.f, 203.f, 204.f, 204.f, 202.f, 203.f, 204.f, 204.f, 208.f, 209.f, 210.f, 210.f, 211.f, 212.f, 213.f, 213.f, 214.f, 215.f, 216.f, 216.f, 214.f, 215.f, 216.f, 216.f, 208.f, 209.f, 210.f, 210.f, 211.f, 212.f, 213.f, 213.f, 214.f, 215.f, 216.f, 216.f, 214.f, 215.f, 216.f, 216.f}); input.linspace(1.); nd4j::ops::maxpool3dnew op; @@ -1333,15 +1332,15 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test1) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.166667, 0.333333, 0.166667,0.333333, 0.666667, 
0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.666667, 1.333333, 0.666667,0.666667, 1.333333, 0.666667,0.333333, 0.666667, 0.333333, - 0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667, - 0.333333, 0.666667, 0.333333,0.666667, 1.333333, 0.666667,0.666667, 1.333333, 0.666667,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667, - 0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.666667, 1.333333, 0.666667,0.666667, 1.333333, 0.666667,0.333333, 0.666667, 0.333333, - 0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667, - 0.333333, 0.666667, 0.333333,0.666667, 1.333333, 0.666667,0.666667, 1.333333, 0.666667,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667, - 0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.666667, 1.333333, 0.666667,0.666667, 1.333333, 0.666667,0.333333, 0.666667, 0.333333, - 0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667, - 0.333333, 0.666667, 0.333333,0.666667, 1.333333, 0.666667,0.666667, 1.333333, 0.666667,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667}); + auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 
0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f}); input.linspace(1.); gradO = 2.; @@ -1367,15 +1366,15 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test2) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333, - 1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333, - 1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333, - 1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333, - 1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333, - 1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333, - 1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333, - 1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333, - 1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. 
,1.333333, 1.333333, 1.333333}); + auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f}); input.linspace(1.); gradO = 2.; @@ -1403,14 +1402,14 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test3) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.41667, 0.41667, 0.41667,0.83333, 0.83333, 0.83333,1.25, 1.25, 1.25 ,0.58333, 0.58333, 0.58333,1.16667, 1.16667, 1.16667,1.75, 1.75, 1.75 ,0.58333, 0.58333, 0.58333,1.16667, 1.16667, 1.16667,1.75, 1.75, 1.75 , - 0.41667, 0.41667, 0.41667,0.83333, 0.83333, 0.83333,1.25, 1.25, 1.25 ,0.83333, 0.83333, 0.83333,1.66667, 1.66667, 1.66667,2.5 , 2.5 , 2.5 ,1.16667, 1.16667, 1.16667,2.33333, 2.33333, 2.33333,3.5 , 3.5 , 3.5 , - 1.16667, 1.16667, 1.16667,2.33333, 2.33333, 2.33333,3.5 , 3.5 , 3.5 ,0.83333, 0.83333, 0.83333,1.66667, 1.66667, 1.66667,2.5 , 2.5 , 2.5 ,1.25 , 1.25 , 1.25 ,2.5 , 2.5 , 2.5 ,3.75, 3.75, 3.75 , - 1.75 , 1.75 , 1.75 ,3.5 , 3.5 , 3.5 ,5.25, 5.25, 5.25 ,1.75 , 1.75 , 1.75 ,3.5 , 3.5 , 3.5 ,5.25, 5.25, 5.25 ,1.25 , 1.25 , 1.25 ,2.5 , 2.5 , 2.5 ,3.75, 3.75, 3.75 , - 0.41667, 0.41667, 0.41667,0.83333, 0.83333, 0.83333,1.25, 1.25, 1.25 ,0.58333, 0.58333, 0.58333,1.16667, 1.16667, 1.16667,1.75, 1.75, 1.75 ,0.58333, 0.58333, 0.58333,1.16667, 1.16667, 1.16667,1.75, 1.75, 1.75 , - 0.41667, 0.41667, 0.41667,0.83333, 0.83333, 0.83333,1.25, 1.25, 1.25 ,0.83333, 0.83333, 0.83333,1.66667, 1.66667, 1.66667,2.5 , 2.5 , 2.5 ,1.16667, 1.16667, 1.16667,2.33333, 2.33333, 2.33333,3.5 , 3.5 , 3.5 , - 1.16667, 1.16667, 1.16667,2.33333, 2.33333, 2.33333,3.5 , 3.5 , 3.5 ,0.83333, 0.83333, 0.83333,1.66667, 1.66667, 1.66667,2.5 , 2.5 , 2.5 ,1.25 , 1.25 , 1.25 ,2.5 , 2.5 , 2.5 ,3.75, 
3.75, 3.75 , - 1.75 , 1.75 , 1.75 ,3.5 , 3.5 , 3.5 ,5.25, 5.25, 5.25 ,1.75 , 1.75 , 1.75 ,3.5 , 3.5 , 3.5 ,5.25, 5.25, 5.25 ,1.25 , 1.25 , 1.25 ,2.5 , 2.5 , 2.5 ,3.75, 3.75, 3.75 }); + auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, + 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, + 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, + 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, + 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, + 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, + 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, + 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f}); input.linspace(1.); gradO = 2.; @@ -1435,14 +1434,14 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test4) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.16667, 0.16667, 0.16667,0.33333, 0.33333, 0.33333,0.5 , 0.5 , 0.5 ,0.33333, 0.33333, 0.33333,0.66667, 0.66667, 0.66667,1. , 1. , 1. ,0.58333, 0.58333, 0.58333,1.16667, 1.16667, 1.16667,1.75, 1.75 , 1.75 , - 0.91667, 0.91667, 0.91667,1.83333, 1.83333, 1.83333,2.75, 2.75 , 2.75 ,0.33333, 0.33333, 0.33333,0.66667, 0.66667, 0.66667,1. , 1. , 1. ,0.66667, 0.66667, 0.66667,1.33333, 1.33333, 1.33333,2. , 2. , 2. , - 1.16667, 1.16667, 1.16667,2.33333, 2.33333, 2.33333,3.5 , 3.5 , 3.5 ,1.83333, 1.83333, 1.83333,3.66667, 3.66667, 3.66667,5.5 , 5.5 , 5.5 ,0.5 , 0.5 , 0.5 ,1. , 1. , 1. ,1.5 , 1.5 , 1.5 , - 1. , 1. , 1. ,2. , 2. , 2. ,3. , 3. , 3. ,1.75 , 1.75 , 1.75 ,3.5 , 3.5 , 3.5 ,5.25, 5.25 , 5.25 ,2.75 , 2.75 , 2.75 ,5.5 , 5.5 , 5.5 ,8.25, 8.25 , 8.25 , - 0.16667, 0.16667, 0.16667,0.33333, 0.33333, 0.33333,0.5 , 0.5 , 0.5 ,0.33333, 0.33333, 0.33333,0.66667, 0.66667, 0.66667,1. , 1. , 1. ,0.58333, 0.58333, 0.58333,1.16667, 1.16667, 1.16667,1.75, 1.75 , 1.75 , - 0.91667, 0.91667, 0.91667,1.83333, 1.83333, 1.83333,2.75, 2.75 , 2.75 ,0.33333, 0.33333, 0.33333,0.66667, 0.66667, 0.66667,1. , 1. , 1. ,0.66667, 0.66667, 0.66667,1.33333, 1.33333, 1.33333,2. , 2. , 2. 
, - 1.16667, 1.16667, 1.16667,2.33333, 2.33333, 2.33333,3.5 , 3.5 , 3.5 ,1.83333, 1.83333, 1.83333,3.66667, 3.66667, 3.66667,5.5 , 5.5 , 5.5 ,0.5 , 0.5 , 0.5 ,1. , 1. , 1. ,1.5 , 1.5 , 1.5 , - 1. , 1. , 1. ,2. , 2. , 2. ,3. , 3. , 3. ,1.75 , 1.75 , 1.75 ,3.5 , 3.5 , 3.5 ,5.25, 5.25 , 5.25 ,2.75 , 2.75 , 2.75 ,5.5 , 5.5 , 5.5 ,8.25, 8.25 , 8.25 }); + auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.16667f, 0.16667f, 0.16667f, 0.33333f, 0.33333f, 0.33333f, 0.5f, 0.5f, 0.5f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, + 0.91667f, 0.91667f, 0.91667f, 1.83333f, 1.83333f, 1.83333f, 2.75f, 2.75f, 2.75f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.66667f, 0.66667f, 0.66667f, 1.33333f, 1.33333f, 1.33333f, 2.f, 2.f, 2.f, + 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 1.83333f, 1.83333f, 1.83333f, 3.66667f, 3.66667f, 3.66667f, 5.5f, 5.5f, 5.5f, 0.5f, 0.5f, 0.5f, 1.f, 1.f, 1.f, 1.5f, 1.5f, 1.5f, + 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 2.75f, 2.75f, 2.75f, 5.5f, 5.5f, 5.5f, 8.25f, 8.25f, 8.25f, + 0.16667f, 0.16667f, 0.16667f, 0.33333f, 0.33333f, 0.33333f, 0.5f, 0.5f, 0.5f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, + 0.91667f, 0.91667f, 0.91667f, 1.83333f, 1.83333f, 1.83333f, 2.75f, 2.75f, 2.75f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.66667f, 0.66667f, 0.66667f, 1.33333f, 1.33333f, 1.33333f, 2.f, 2.f, 2.f, + 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 1.83333f, 1.83333f, 1.83333f, 3.66667f, 3.66667f, 3.66667f, 5.5f, 5.5f, 5.5f, 0.5f, 0.5f, 0.5f, 1.f, 1.f, 1.f, 1.5f, 1.5f, 1.5f, + 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 2.75f, 2.75f, 2.75f, 5.5f, 5.5f, 5.5f, 8.25f, 8.25f, 8.25f}); input.linspace(1.); gradO = 2.; @@ -1467,12 +1466,12 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test1) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0.1, 0.2,0. , 0.3, 0.4,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0.5, 0.6,0. , 0.7, 0.8, - 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0.9, 1. ,0. , 1.1, 1.2,0. , 0. , 0. ,0. , 0. , 0. ,0. , 1.3, 1.4,0. , 1.5, 1.6, - 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 1.7, 1.8,0. , 1.9, 2. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 2.1, 2.2,0. , 2.3, 2.4, - 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 2.5, 2.6,0. , 2.7, 2.8,0. , 0. , 0. ,0. , 0. , 0. ,0. , 2.9, 3. ,0. , 3.1, 3.2, - 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 3.3, 3.4,0. , 3.5, 3.6,0. , 0. , 0. ,0. , 0. , 0. ,0. , 3.7, 3.8,0. , 3.9, 4. , - 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 4.1, 4.2,0. , 4.3, 4.4,0. , 0. , 0. ,0. , 0. , 0. ,0. , 4.5, 4.6,0. 
, 4.7, 4.8}); + auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.6f, 0.f, 2.7f, 2.8f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.9f, 3.f, 0.f, 3.1f, 3.2f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.3f, 3.4f, 0.f, 3.5f, 3.6f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 0.f, 3.9f, 4.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.1f, 4.2f, 0.f, 4.3f, 4.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.5f, 4.6f, 0.f, 4.7f, 4.8f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1497,15 +1496,15 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test2) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.000e+00, 0.000e+00, 0.000e+00,1.000e-01, 2.000e-01, 7.000e-01,5.000e-01, 6.000e-01, 1.500e+00,2.200e+00, 2.400e+00, 5.400e+00,0.000e+00, 0.000e+00, 0.000e+00,1.700e+00, 1.800e+00, 3.900e+00,2.100e+00, 2.200e+00, 4.700e+00,5.400e+00, 5.600e+00, 1.180e+01, - 0.000e+00, 0.000e+00, 0.000e+00,8.200e+00, 8.400e+00, 1.740e+01,9.000e+00, 9.200e+00, 1.900e+01,2.040e+01, 2.080e+01, 4.280e+01,0.000e+00, 0.000e+00, 0.000e+00,6.500e+00, 6.600e+00, 1.350e+01,6.900e+00, 7.000e+00, 1.430e+01,1.500e+01, 1.520e+01, 3.100e+01, - 0.000e+00, 0.000e+00, 0.000e+00,8.100e+00, 8.200e+00, 1.670e+01,8.500e+00, 8.600e+00, 1.750e+01,1.820e+01, 1.840e+01, 3.740e+01,0.000e+00, 0.000e+00, 0.000e+00,2.100e+01, 2.120e+01, 4.300e+01,2.180e+01, 2.200e+01, 4.460e+01,4.600e+01, 4.640e+01, 9.400e+01, - 0.000e+00, 0.000e+00, 0.000e+00,1.290e+01, 1.300e+01, 2.630e+01,1.330e+01, 1.340e+01, 2.710e+01,2.780e+01, 2.800e+01, 5.660e+01,0.000e+00, 0.000e+00, 0.000e+00,1.450e+01, 1.460e+01, 2.950e+01,1.490e+01, 1.500e+01, 3.030e+01,3.100e+01, 3.120e+01, 6.300e+01, - 0.000e+00, 0.000e+00, 0.000e+00,3.380e+01, 3.400e+01, 6.860e+01,3.460e+01, 3.480e+01, 7.020e+01,7.160e+01, 7.200e+01, 1.452e+02,0.000e+00, 0.000e+00, 0.000e+00,1.930e+01, 1.940e+01, 3.910e+01,1.970e+01, 1.980e+01, 3.990e+01,4.060e+01, 4.080e+01, 8.220e+01, - 0.000e+00, 0.000e+00, 0.000e+00,2.090e+01, 2.100e+01, 4.230e+01,2.130e+01, 2.140e+01, 4.310e+01,4.380e+01, 4.400e+01, 8.860e+01,0.000e+00, 0.000e+00, 0.000e+00,4.660e+01, 4.680e+01, 9.420e+01,4.740e+01, 4.760e+01, 9.580e+01,9.720e+01, 9.760e+01, 1.964e+02, - 0.000e+00, 0.000e+00, 0.000e+00,2.570e+01, 2.580e+01, 5.190e+01,2.610e+01, 2.620e+01, 5.270e+01,5.340e+01, 5.360e+01, 1.078e+02,0.000e+00, 0.000e+00, 0.000e+00,2.730e+01, 2.740e+01, 5.510e+01,2.770e+01, 2.780e+01, 5.590e+01,5.660e+01, 5.680e+01, 1.142e+02, - 0.000e+00, 0.000e+00, 0.000e+00,5.940e+01, 5.960e+01, 1.198e+02,6.020e+01, 6.040e+01, 1.214e+02,1.228e+02, 1.232e+02, 2.476e+02,0.000e+00, 0.000e+00, 0.000e+00,3.210e+01, 3.220e+01, 
6.470e+01,3.250e+01, 3.260e+01, 6.550e+01,6.620e+01, 6.640e+01, 1.334e+02, - 0.000e+00, 0.000e+00, 0.000e+00,3.370e+01, 3.380e+01, 6.790e+01,3.410e+01, 3.420e+01, 6.870e+01,6.940e+01, 6.960e+01, 1.398e+02,0.000e+00, 0.000e+00, 0.000e+00,7.220e+01, 7.240e+01, 1.454e+02,7.300e+01, 7.320e+01, 1.470e+02,1.484e+02, 1.488e+02, 2.988e+02}); + auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.000e+00f, 0.000e+00f, 0.000e+00f, 1.000e-01f, 2.000e-01f, 7.000e-01f, 5.000e-01f, 6.000e-01f, 1.500e+00f, 2.200e+00f, 2.400e+00f, 5.400e+00f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.700e+00f, 1.800e+00f, 3.900e+00f, 2.100e+00f, 2.200e+00f, 4.700e+00f, 5.400e+00f, 5.600e+00f, 1.180e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 8.200e+00f, 8.400e+00f, 1.740e+01f, 9.000e+00f, 9.200e+00f, 1.900e+01f, 2.040e+01f, 2.080e+01f, 4.280e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 6.500e+00f, 6.600e+00f, 1.350e+01f, 6.900e+00f, 7.000e+00f, 1.430e+01f, 1.500e+01f, 1.520e+01f, 3.100e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 8.100e+00f, 8.200e+00f, 1.670e+01f, 8.500e+00f, 8.600e+00f, 1.750e+01f, 1.820e+01f, 1.840e+01f, 3.740e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.100e+01f, 2.120e+01f, 4.300e+01f, 2.180e+01f, 2.200e+01f, 4.460e+01f, 4.600e+01f, 4.640e+01f, 9.400e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.290e+01f, 1.300e+01f, 2.630e+01f, 1.330e+01f, 1.340e+01f, 2.710e+01f, 2.780e+01f, 2.800e+01f, 5.660e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.450e+01f, 1.460e+01f, 2.950e+01f, 1.490e+01f, 1.500e+01f, 3.030e+01f, 3.100e+01f, 3.120e+01f, 6.300e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.380e+01f, 3.400e+01f, 6.860e+01f, 3.460e+01f, 3.480e+01f, 7.020e+01f, 7.160e+01f, 7.200e+01f, 1.452e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.930e+01f, 1.940e+01f, 3.910e+01f, 1.970e+01f, 1.980e+01f, 3.990e+01f, 4.060e+01f, 4.080e+01f, 8.220e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.090e+01f, 2.100e+01f, 4.230e+01f, 2.130e+01f, 2.140e+01f, 4.310e+01f, 4.380e+01f, 4.400e+01f, 8.860e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 4.660e+01f, 4.680e+01f, 9.420e+01f, 4.740e+01f, 4.760e+01f, 9.580e+01f, 9.720e+01f, 9.760e+01f, 1.964e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.570e+01f, 2.580e+01f, 5.190e+01f, 2.610e+01f, 2.620e+01f, 5.270e+01f, 5.340e+01f, 5.360e+01f, 1.078e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.730e+01f, 2.740e+01f, 5.510e+01f, 2.770e+01f, 2.780e+01f, 5.590e+01f, 5.660e+01f, 5.680e+01f, 1.142e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 5.940e+01f, 5.960e+01f, 1.198e+02f, 6.020e+01f, 6.040e+01f, 1.214e+02f, 1.228e+02f, 1.232e+02f, 2.476e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.210e+01f, 3.220e+01f, 6.470e+01f, 3.250e+01f, 3.260e+01f, 6.550e+01f, 6.620e+01f, 6.640e+01f, 1.334e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.370e+01f, 3.380e+01f, 6.790e+01f, 3.410e+01f, 3.420e+01f, 6.870e+01f, 6.940e+01f, 6.960e+01f, 1.398e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 7.220e+01f, 7.240e+01f, 1.454e+02f, 7.300e+01f, 7.320e+01f, 1.470e+02f, 1.484e+02f, 1.488e+02f, 2.988e+02f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1530,14 +1529,14 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test3) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, { 0., 0., 0., 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , - 0., 0., 0., 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. 
, 0. , 0. , 0. , 0. , 0.1, 0.2 , 0.3, 1.1, 1.3 , 1.5, - 0., 0., 0., 1. , 1.1, 1.2, 2.9, 3.1 , 3.3, 0. , 0. , 0. , 4.7, 4.9 , 5.1, 11.2, 11.6 , 12. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , - 0., 0., 0., 11. , 11.2, 11.4, 23.8, 24.2 , 24.6, 0. , 0. , 0. , 12.8, 13. , 13.2, 27.4, 27.8 , 28.2, 0. , 0. , 0. , 31. , 31.4 , 31.8, 65.6, 66.39999, 67.2, - 0., 0., 0., 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , - 0., 0., 0., 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 10.9, 11. , 11.1, 22.7, 22.9 , 23.1, - 0., 0., 0., 11.8, 11.9, 12. , 24.5, 24.7 , 24.9, 0. , 0. , 0. , 26.3, 26.5 , 26.7, 54.4, 54.8 , 55.2, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , - 0., 0., 0., 32.6, 32.8, 33. , 67. , 67.4 , 67.8, 0. , 0. , 0. , 34.4, 34.6 , 34.8, 70.6, 71. , 71.4, 0. , 0. , 0. , 74.2, 74.6 , 75. ,152. , 152.8 ,153.6}); + auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, { 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, + 0.f, 0.f, 0.f, 1.f, 1.1f, 1.2f, 2.9f, 3.1f, 3.3f, 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, 11.2f, 11.6f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 11.f, 11.2f, 11.4f, 23.8f, 24.2f, 24.6f, 0.f, 0.f, 0.f, 12.8f, 13.f, 13.2f, 27.4f, 27.8f, 28.2f, 0.f, 0.f, 0.f, 31.f, 31.4f, 31.8f, 65.6f, 66.39999f, 67.2f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f, + 0.f, 0.f, 0.f, 11.8f, 11.9f, 12.f, 24.5f, 24.7f, 24.9f, 0.f, 0.f, 0.f, 26.3f, 26.5f, 26.7f, 54.4f, 54.8f, 55.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 32.6f, 32.8f, 33.f, 67.f, 67.4f, 67.8f, 0.f, 0.f, 0.f, 34.4f, 34.6f, 34.8f, 70.6f, 71.f, 71.4f, 0.f, 0.f, 0.f, 74.2f, 74.6f, 75.f, 152.f, 152.8f, 153.6f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1563,13 +1562,13 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test4) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0.2, 0.3, 1.1, 1.3, 1.5, 0, 0, 0, 5.7, 6, 6.3, - 14.1, 14.7, 15.3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 11.2, 11.4, 23.8, 24.2, - 24.6, 0, 0, 0, 43.8, 44.4, 45, 93, 94.2, 95.4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 10.9, 11, 11.1, 22.7, 22.9, 23.1, 0, 0, 0, 38.1, 38.4, 38.7, 78.9, 79.5, 80.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32.6, 32.8, 33, 67, 67.4, 67.8, 0, 0, 0, 108.6, 109.2, 109.8, 222.6, 223.8, 225,}); + auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, 0.f, 0.f, 0.f, 5.7f, 6.f, 6.3f, + 14.1f, 14.7f, 15.3f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 11.f, 11.2f, 11.4f, 23.8f, 24.2f, + 24.6f, 0.f, 0.f, 0.f, 43.8f, 44.4f, 45.f, 93.f, 94.2f, 95.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f, 0.f, 0.f, 0.f, 38.1f, 38.4f, 38.7f, 78.9f, 79.5f, 80.1f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 32.6f, 32.8f, 33.f, 67.f, 67.4f, 67.8f, 0.f, 0.f, 0.f, 108.6f, 109.2f, 109.8f, 222.6f, 223.8f, 225.f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1652,9 +1651,9 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_3) { auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0. , 0. , 0. ,0. , 0. , 0. ,0. , 0.1, 0.2,0. , 0.3, 0.4,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0.5, 0.6,0. , 0.7, 0.8, - 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0.9, 1. ,0. , 1.1, 1.2,0. , 0. , 0. ,0. , 0. , 0. ,0. , 1.3, 1.4,0. , 1.5, 1.6, - 0. , 0. , 0. ,0. , 0. , 0. ,0. , 1.7, 1.8,0. , 1.9, 2. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 2.1, 2.2,0. , 2.3, 2.4}); + auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1679,9 +1678,9 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_4) { auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0. , 0. , 0. , 0.1, 0.2, 0.7, 0.5, 0.6, 1.5, 2.2, 2.4, 5.4, 0. , 0. , 0. , 1.7, 1.8, 3.9, 2.1, 2.2, 4.7, 5.4, 5.6, 11.8, - 0. , 0. , 0. , 3.3, 3.4, 7.1, 3.7, 3.8, 7.9, 8.6, 8.8, 18.2, 0. , 0. , 0. , 4.9, 5. , 10.3, 5.3, 5.4, 11.1,11.8, 12. , 24.6, - 0. , 0. , 0. , 6.5, 6.6, 13.5, 6.9, 7. , 14.3,15. , 15.2, 31. , 0. , 0. , 0. 
, 8.1, 8.2, 16.7, 8.5, 8.6, 17.5,18.2, 18.4, 37.4}); + auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.f, 0.f, 0.f, 0.1f, 0.2f, 0.7f, 0.5f, 0.6f, 1.5f, 2.2f, 2.4f, 5.4f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 3.9f, 2.1f, 2.2f, 4.7f, 5.4f, 5.6f, 11.8f, + 0.f, 0.f, 0.f, 3.3f, 3.4f, 7.1f, 3.7f, 3.8f, 7.9f, 8.6f, 8.8f, 18.2f, 0.f, 0.f, 0.f, 4.9f, 5.f, 10.3f, 5.3f, 5.4f, 11.1f, 11.8f, 12.f, 24.6f, + 0.f, 0.f, 0.f, 6.5f, 6.6f, 13.5f, 6.9f, 7.f, 14.3f, 15.f, 15.2f, 31.f, 0.f, 0.f, 0.f, 8.1f, 8.2f, 16.7f, 8.5f, 8.6f, 17.5f, 18.2f, 18.4f, 37.4f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1706,9 +1705,9 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_5) { auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.1, 0.2, 0.3, 1.1, 1.3, 1.5, 0. , 0. , 0. , 1. , 1.1, 1.2, 2.9, 3.1, 3.3, - 0. , 0. , 0. , 4.7, 4.9, 5.1,11.2,11.6,12. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 3.7, 3.8, 3.9, 8.3, 8.5, 8.7, - 0. , 0. , 0. , 4.6, 4.7, 4.8,10.1,10.3,10.5, 0. , 0. , 0. ,11.9,12.1,12.3,25.6,26. ,26.4}); + auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, 0.f, 0.f, 0.f, 1.f, 1.1f, 1.2f, 2.9f, 3.1f, 3.3f, + 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, 11.2f, 11.6f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 3.9f, 8.3f, 8.5f, 8.7f, + 0.f, 0.f, 0.f, 4.6f, 4.7f, 4.8f, 10.1f, 10.3f, 10.5f, 0.f, 0.f, 0.f, 11.9f, 12.1f, 12.3f, 25.6f, 26.f, 26.4f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1733,9 +1732,9 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_6) { auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0.1, 0.2, 0.3,0.4, 0.5, 0.6, - 0. , 0. , 0. ,0.7, 0.8, 0.9,1. , 1.1, 1.2,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0. , 0. , 0. ,1.3, 1.4, 1.5,1.6, 1.7, 1.8,0. , 0. , 0. ,1.9, 2. 
, 2.1,2.2, 2.3, 2.4}); + auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, + 0.f, 0.f, 0.f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 0.f, 0.f, 0.f, 1.9f, 2.f, 2.1f, 2.2f, 2.3f, 2.4f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1815,8 +1814,8 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_2) { // TypeParam expectedBuff[] = {0.875, 2., 2.5,1.375, 2.75 , 6., 7., 3.75, 4.75 ,10., 11., 5.75, 2.875, 6., 6.5, 3.375, 4.875, 10.,10.5, 5.375, 10.75, 22.,23., 11.75, 12.75, 26.,27., 13.75, 6.875, 14.,14.5, 7.375}; auto input = NDArrayFactory::create('c', {bS,iD,iH,iW}); - auto epsilon = NDArrayFactory::create('c', {bS,iD,oH,oW}, {3.5 , 4.5 , 5.5, 7.5 , 8.5 , 9.5, 11.5, 12.5, 13.5, 19.5, 20.5, 21.5, 23.5, 24.5, 25.5, 27.5, 28.5, 29.5}); - auto expected = NDArrayFactory::create('c', {bS,iD,iH,iW}, {0.875, 2., 2.5,1.375, 2.75 , 6., 7., 3.75, 4.75 ,10., 11., 5.75, 2.875, 6., 6.5, 3.375, 4.875, 10.,10.5, 5.375, 10.75, 22.,23., 11.75, 12.75, 26.,27., 13.75, 6.875, 14.,14.5, 7.375}); + auto epsilon = NDArrayFactory::create('c', {bS,iD,oH,oW}, {3.5f, 4.5f, 5.5f, 7.5f, 8.5f, 9.5f, 11.5f, 12.5f, 13.5f, 19.5f, 20.5f, 21.5f, 23.5f, 24.5f, 25.5f, 27.5f, 28.5f, 29.5f}); + auto expected = NDArrayFactory::create('c', {bS,iD,iH,iW}, {0.875f, 2.f, 2.5f, 1.375f, 2.75f, 6.f, 7.f, 3.75f, 4.75f, 10.f, 11.f, 5.75f, 2.875f, 6.f, 6.5f, 3.375f, 4.875f, 10.f, 10.5f, 5.375f, 10.75f, 22.f, 23.f, 11.75f, 12.75f, 26.f, 27.f, 13.75f, 6.875f, 14.f, 14.5f, 7.375f}); input.linspace(1.); @@ -1842,12 +1841,12 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_3) { auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.016667,0.05 ,0.033333,0.066667,0.166667,0.1 ,0.066667,0.166667,0.1 ,0.05 ,0.116667,0.066667, - 0.083333,0.183333,0.1 ,0.2 ,0.433333,0.233333,0.2 ,0.433333,0.233333,0.116667,0.25 ,0.133333, - 0.15 ,0.316667,0.166667,0.333333,0.7 ,0.366667,0.333333,0.7 ,0.366667,0.183333,0.383333,0.2 , - 0.216667,0.45 ,0.233333,0.466667,0.966667,0.5 ,0.466667,0.966667,0.5 ,0.25 ,0.516667,0.266667, - 0.283333,0.583333,0.3 ,0.6 ,1.233333,0.633333,0.6 ,1.233333,0.633333,0.316667,0.65 ,0.333333, - 0.35 ,0.716667,0.366667,0.733333,1.5 ,0.766667,0.733333,1.5 ,0.766667,0.383333,0.783333,0.4 }); + auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.016667f, 0.05f, 0.033333f, 0.066667f, 0.166667f, 0.1f, 0.066667f, 0.166667f, 0.1f, 0.05f, 0.116667f, 0.066667f, + 0.083333f, 0.183333f, 0.1f, 0.2f, 0.433333f, 0.233333f, 0.2f, 0.433333f, 0.233333f, 0.116667f, 0.25f, 0.133333f, + 0.15f, 0.316667f, 0.166667f, 0.333333f, 0.7f, 0.366667f, 0.333333f, 0.7f, 0.366667f, 0.183333f, 0.383333f, 0.2f, + 0.216667f, 0.45f, 0.233333f, 0.466667f, 0.966667f, 0.5f, 0.466667f, 0.966667f, 0.5f, 0.25f, 0.516667f, 0.266667f, + 0.283333f, 0.583333f, 0.3f, 0.6f, 1.233333f, 0.633333f, 0.6f, 1.233333f, 0.633333f, 0.316667f, 0.65f, 0.333333f, + 0.35f, 0.716667f, 0.366667f, 0.733333f, 1.5f, 0.766667f, 0.733333f, 1.5f, 0.766667f, 0.383333f, 0.783333f, 0.4f }); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1873,12 +1872,12 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_4) { auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); auto 
gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.233333,0.3 ,0.366667,0.55 ,0.65 ,0.75 ,0.95 ,1.05 ,1.15 ,0.766667,0.833333,0.9 , - 1.3 ,1.366667,1.433333,2.15 ,2.25 ,2.35 ,2.55 ,2.65 ,2.75 ,1.833333,1.9 ,1.966667, - 2.366667,2.433333,2.5 ,3.75 ,3.85 ,3.95 ,4.15 ,4.25 ,4.35 ,2.9 ,2.966667,3.033333, - 3.433333,3.5 ,3.566667,5.35 ,5.45 ,5.55 ,5.75 ,5.85 ,5.95 ,3.966667,4.033333,4.1 , - 4.5 ,4.566667,4.633333,6.95 ,7.05 ,7.15 ,7.35 ,7.45 ,7.55 ,5.033333,5.1 ,5.166667, - 5.566667,5.633333,5.7 ,8.549999,8.65 ,8.75 ,8.95 ,9.05 ,9.150001,6.1 ,6.166667,6.233334}); + auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.233333f, 0.3f, 0.366667f, 0.55f, 0.65f, 0.75f, 0.95f, 1.05f, 1.15f, 0.766667f, 0.833333f, 0.9f, + 1.3f, 1.366667f, 1.433333f, 2.15f, 2.25f, 2.35f, 2.55f, 2.65f, 2.75f, 1.833333f, 1.9f, 1.966667f, + 2.366667f, 2.433333f, 2.5f, 3.75f, 3.85f, 3.95f, 4.15f, 4.25f, 4.35f, 2.9f, 2.966667f, 3.033333f, + 3.433333f, 3.5f, 3.566667f, 5.35f, 5.45f, 5.55f, 5.75f, 5.85f, 5.95f, 3.966667f, 4.033333f, 4.1f, + 4.5f, 4.566667f, 4.633333f, 6.95f, 7.05f, 7.15f, 7.35f, 7.45f, 7.55f, 5.033333f, 5.1f, 5.166667f, + 5.566667f, 5.633333f, 5.7f, 8.549999f, 8.65f, 8.75f, 8.95f, 9.05f, 9.150001f, 6.1f, 6.166667f, 6.233334f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1904,10 +1903,10 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_5) { auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.19167, 0.23333, 0.275, 0.50833, 0.59167, 0.675, 1.2 , 1.325, 1.45 ,0.50833,0.56667, 0.625, 1.19167,1.30833, 1.425, 2.4 ,2.575, 2.75 , - 1.18333, 1.24167, 1.3 , 2.54167, 2.65833, 2.775, 4.425, 4.6 , 4.775,1.01667,1.05833, 1.1 , 2.15833,2.24167, 2.325, 3.675,3.8 , 3.925, - 1.69167, 1.73333, 1.775, 3.50833, 3.59167, 3.675, 5.7 , 5.825, 5.95 ,2.60833,2.66667, 2.725, 5.39167,5.50833, 5.625, 8.7 ,8.875, 9.05 , - 3.28333, 3.34167, 3.4 , 6.74167, 6.85833, 6.975,10.725,10.9 ,11.075,2.51667,2.55833, 2.6 , 5.15833,5.24167, 5.325, 8.175,8.3 , 8.425}); + auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.19167f, 0.23333f, 0.275f, 0.50833f, 0.59167f, 0.675f, 1.2f, 1.325f, 1.45f, 0.50833f, 0.56667f, 0.625f, 1.19167f, 1.30833f, 1.425f, 2.4f, 2.575f, 2.75f, + 1.18333f, 1.24167f, 1.3f, 2.54167f, 2.65833f, 2.775f, 4.425f, 4.6f, 4.775f, 1.01667f, 1.05833f, 1.1f, 2.15833f, 2.24167f, 2.325f, 3.675f, 3.8f, 3.925f, + 1.69167f, 1.73333f, 1.775f, 3.50833f, 3.59167f, 3.675f, 5.7f, 5.825f, 5.95f, 2.60833f, 2.66667f, 2.725f, 5.39167f, 5.50833f, 5.625f, 8.7f, 8.875f, 9.05f, + 3.28333f, 3.34167f, 3.4f, 6.74167f, 6.85833f, 6.975f, 10.725f, 10.9f, 11.075f, 2.51667f, 2.55833f, 2.6f, 5.15833f, 5.24167f, 5.325f, 8.175f, 8.3f, 8.425f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1933,10 +1932,10 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_6) { auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.01667,0.03333,0.05,0.08333,0.11667,0.15,0.06667,0.08333,0.1,0.13333,0.16667,0.2 ,0.36667,0.43333,0.5 ,0.23333,0.26667,0.3, - 0.13333,0.16667,0.2 ,0.36667,0.43333,0.5 ,0.23333,0.26667,0.3,0.11667,0.13333,0.15,0.28333,0.31667,0.35,0.16667,0.18333,0.2, - 0.21667,0.23333,0.25,0.48333,0.51667,0.55,0.26667,0.28333,0.3,0.53333,0.56667,0.6 ,1.16667,1.23333,1.3 ,0.63333,0.66667,0.7, - 
0.53333,0.56667,0.6 ,1.16667,1.23333,1.3 ,0.63333,0.66667,0.7,0.31667,0.33333,0.35,0.68333,0.71667,0.75,0.36667,0.38333,0.4}); + auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.01667f, 0.03333f, 0.05f, 0.08333f, 0.11667f, 0.15f, 0.06667f, 0.08333f, 0.1f, 0.13333f, 0.16667f, 0.2f, 0.36667f, 0.43333f, 0.5f, 0.23333f, 0.26667f, 0.3f, + 0.13333f, 0.16667f, 0.2f, 0.36667f, 0.43333f, 0.5f, 0.23333f, 0.26667f, 0.3f, 0.11667f, 0.13333f, 0.15f, 0.28333f, 0.31667f, 0.35f, 0.16667f, 0.18333f, 0.2f, + 0.21667f, 0.23333f, 0.25f, 0.48333f, 0.51667f, 0.55f, 0.26667f, 0.28333f, 0.3f, 0.53333f, 0.56667f, 0.6f, 1.16667f, 1.23333f, 1.3f, 0.63333f, 0.66667f, 0.7f, + 0.53333f, 0.56667f, 0.6f, 1.16667f, 1.23333f, 1.3f, 0.63333f, 0.66667f, 0.7f, 0.31667f, 0.33333f, 0.35f, 0.68333f, 0.71667f, 0.75f, 0.36667f, 0.38333f, 0.4f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -1995,12 +1994,12 @@ TYPED_TEST(TypedConvolutionTests2, pnormpool2d_bp_2) { auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {9.661570e-04,9.671602e-03,1.306569e-02,3.679184e-02,1.297220e-01,1.040181e-01,1.126750e-01,3.320884e-01,2.340406e-01,1.333333e-01,3.352886e-01,2.070211e-01, - 8.991618e-02,2.160601e-01,1.283173e-01,2.744226e-01,6.364498e-01,3.662123e-01,3.869788e-01,8.808994e-01,4.984556e-01,2.613189e-01,5.818475e-01,3.225517e-01, - 2.065654e-01,4.553546e-01,2.501175e-01,5.190718e-01,1.131343e+00,6.148388e-01,6.362602e-01,1.377521e+00,7.439550e-01,3.833026e-01,8.227519e-01,4.407146e-01, - 3.261206e-01,6.969233e-01,3.717564e-01,7.627507e-01,1.620991e+00,8.600952e-01,8.814538e-01,1.866888e+00,9.873542e-01,5.046682e-01,1.064004e+00,5.602558e-01, - 4.464697e-01,9.389536e-01,4.932274e-01,1.005908e+00,2.108550e+00,1.104095e+00,1.125322e+00,2.354009e+00,1.230180e+00,6.258913e-01,1.305581e+00,6.804127e-01, - 5.671396e-01,1.181128e+00,6.145977e-01,1.248783e+00,2.595083e+00,1.347494e+00,1.368600e+00,2.840157e+00,1.472778e+00,7.470673e-01,1.547362e+00,8.008900e-01}); + auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {9.661570e-04f, 9.671602e-03f, 1.306569e-02f, 3.679184e-02f, 1.297220e-01f, 1.040181e-01f, 1.126750e-01f, 3.320884e-01f, 2.340406e-01f, 1.333333e-01f, 3.352886e-01f, 2.070211e-01f, + 8.991618e-02f, 2.160601e-01f, 1.283173e-01f, 2.744226e-01f, 6.364498e-01f, 3.662123e-01f, 3.869788e-01f, 8.808994e-01f, 4.984556e-01f, 2.613189e-01f, 5.818475e-01f, 3.225517e-01f, + 2.065654e-01f, 4.553546e-01f, 2.501175e-01f, 5.190718e-01f, 1.131343e+00f, 6.148388e-01f, 6.362602e-01f, 1.377521e+00f, 7.439550e-01f, 3.833026e-01f, 8.227519e-01f, 4.407146e-01f, + 3.261206e-01f, 6.969233e-01f, 3.717564e-01f, 7.627507e-01f, 1.620991e+00f, 8.600952e-01f, 8.814538e-01f, 1.866888e+00f, 9.873542e-01f, 5.046682e-01f, 1.064004e+00f, 5.602558e-01f, + 4.464697e-01f, 9.389536e-01f, 4.932274e-01f, 1.005908e+00f, 2.108550e+00f, 1.104095e+00f, 1.125322e+00f, 2.354009e+00f, 1.230180e+00f, 6.258913e-01f, 1.305581e+00f, 6.804127e-01f, + 5.671396e-01f, 1.181128e+00f, 6.145977e-01f, 1.248783e+00f, 2.595083e+00f, 1.347494e+00f, 1.368600e+00f, 2.840157e+00f, 1.472778e+00f, 7.470673e-01f, 1.547362e+00f, 8.008900e-01f}); input.linspace(1.); gradO.linspace(0.1, 0.1); @@ -2028,12 +2027,12 @@ TYPED_TEST(TypedConvolutionTests2, pnormpool2d_bp_3) { auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, 
iH, iW}, {0.007931,0.042891,0.040544,0.09369 ,0.276841,0.191675,0.163957,0.442946,0.287512,0.154919,0.373153,0.221172, - 0.15901 ,0.365232,0.207846,0.428282,0.959455,0.534076,0.508585,1.128771,0.623089,0.319794,0.698063,0.379547, - 0.321068,0.692438,0.372316,0.757521,1.620323,0.864566,0.838684,1.787943,0.951023,0.483194,1.023434,0.541058, - 0.483937,1.019414,0.536145,1.085348,2.276996,1.192917,1.166749,2.443606,1.278126,0.646499,1.349361,0.703463, - 0.647021,1.346249,0.699745,1.412654,2.932174,1.520512,1.494153,3.098146,1.604985,0.809791,1.675544,0.866229, - 0.810192,1.673009,0.863237,1.739711,3.58665 ,1.847753,1.82126 ,3.752188,1.931741,0.973081,2.001861,1.029173}); + auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.007931f, 0.042891f, 0.040544f, 0.09369f, 0.276841f, 0.191675f, 0.163957f, 0.442946f, 0.287512f, 0.154919f, 0.373153f, 0.221172f, + 0.15901f, 0.365232f, 0.207846f, 0.428282f, 0.959455f, 0.534076f, 0.508585f, 1.128771f, 0.623089f, 0.319794f, 0.698063f, 0.379547f, + 0.321068f, 0.692438f, 0.372316f, 0.757521f, 1.620323f, 0.864566f, 0.838684f, 1.787943f, 0.951023f, 0.483194f, 1.023434f, 0.541058f, + 0.483937f, 1.019414f, 0.536145f, 1.085348f, 2.276996f, 1.192917f, 1.166749f, 2.443606f, 1.278126f, 0.646499f, 1.349361f, 0.703463f, + 0.647021f, 1.346249f, 0.699745f, 1.412654f, 2.932174f, 1.520512f, 1.494153f, 3.098146f, 1.604985f, 0.809791f, 1.675544f, 0.866229f, + 0.810192f, 1.673009f, 0.863237f, 1.739711f, 3.58665f, 1.847753f, 1.82126f, 3.752188f, 1.931741f, 0.973081f, 2.001861f, 1.029173f}); input.linspace(1.); gradO.linspace(0.1, 0.1); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 8a889ea44..7bea1e820 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -2857,8 +2857,8 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) { 20.07843f, 21.019608f, 21.960783f, 23.058823f, 20.07843f, 21.019608f, 21.960783f, 23.058823f }); - NDArray min = NDArrayFactory::create({-20., -19., -18., -17}); - NDArray max = NDArrayFactory::create({20., 21., 22., 23}); + NDArray min = NDArrayFactory::create({-20.f, -19.f, -18.f, -17.f}); + NDArray max = NDArrayFactory::create({20.f, 21.f, 22.f, 23.f}); x.linspace(-60.); nd4j::ops::fake_quant_with_min_max_vars_per_channel op; auto results = op.execute({&x, &min, &max}, {}, {}); @@ -3033,8 +3033,8 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test2) { auto gamma = NDArrayFactory::create('c', {4}); auto beta = NDArrayFactory::create('c', {4}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {-0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, 0.49088821, 0.66059214, 0.83029607, 1. 
, 1.16970393, 1.33940786, - 1.50911179, 1.67881572, 1.84851965, 2.01822358, 2.18792751, 2.35763144, 2.52733537, 2.6970393 , 2.86674323, 3.03644717, 3.2061511 , 3.37585503}); + auto expected = NDArrayFactory::create('c', {2,3,4}, {-0.52733537f, -0.35763144f, -0.18792751f, -0.01822358f, 0.15148035f, 0.32118428f, 0.49088821f, 0.66059214f, 0.83029607f, 1.f, 1.16970393f, 1.33940786f, + 1.50911179f, 1.67881572f, 1.84851965f, 2.01822358f, 2.18792751f, 2.35763144f, 2.52733537f, 2.6970393f, 2.86674323f, 3.03644717f, 3.2061511f, 3.37585503f}); input.linspace(0.1, 0.1); mean.assign(1.); @@ -3061,13 +3061,13 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test2) { TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test3) { auto input = NDArrayFactory::create('c', {2,3,4}); - auto mean = NDArrayFactory::create('c', {3}, {1.05, 1.1, 1.15}); - auto variance = NDArrayFactory::create('c', {3}, {0.5, 0.6, 0.7}); - auto gamma = NDArrayFactory::create('c', {3}, {1.2, 1.3, 1.4}); - auto beta = NDArrayFactory::create('c', {3}, {0.1, 0.2, 0.3}); + auto mean = NDArrayFactory::create('c', {3}, {1.05f, 1.1f, 1.15f}); + auto variance = NDArrayFactory::create('c', {3}, {0.5f, 0.6f, 0.7f}); + auto gamma = NDArrayFactory::create('c', {3}, {1.2f, 1.3f, 1.4f}); + auto beta = NDArrayFactory::create('c', {3}, {0.1f, 0.2f, 0.3f}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {-1.51218734,-1.34248341,-1.17277948,-1.00307555,-0.80696728,-0.6391394 ,-0.47131152,-0.30348364,-0.11832703, 0.04900378, 0.21633459, 0.38366541, - 0.52425983, 0.69396376, 0.86366769, 1.03337162, 1.20696728, 1.37479516, 1.54262304, 1.71045092, 1.8896427 , 2.05697351, 2.22430432, 2.39163513,}); + auto expected = NDArrayFactory::create('c', {2,3,4}, {-1.51218734f, -1.34248341f, -1.17277948f, -1.00307555f, -0.80696728f, -0.6391394f, -0.47131152f, -0.30348364f, -0.11832703f, 0.04900378f, 0.21633459f, 0.38366541f, + 0.52425983f, 0.69396376f, 0.86366769f, 1.03337162f, 1.20696728f, 1.37479516f, 1.54262304f, 1.71045092f, 1.8896427f, 2.05697351f, 2.22430432f, 2.39163513f}); input.linspace(0.1, 0.1); @@ -3089,13 +3089,13 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test3) { TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test4) { auto input = NDArrayFactory::create('c', {2,3,4}); - auto mean = NDArrayFactory::create('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}); - auto variance = NDArrayFactory::create('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}); - auto gamma = NDArrayFactory::create('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}); - auto beta = NDArrayFactory::create('c', {2,1,4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.66, 0.7, 0.8}); + auto mean = NDArrayFactory::create('c', {2,1,4}, {1.05f, 1.1f, 1.15f, 1.2f, 1.25f, 1.3f, 1.35f, 1.4f}); + auto variance = NDArrayFactory::create('c', {2,1,4}, {0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f}); + auto gamma = NDArrayFactory::create('c', {2,1,4}, {1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 1.9f}); + auto beta = NDArrayFactory::create('c', {2,1,4}, {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.66f, 0.7f, 0.8f}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {-1.51218734,-1.31045092,-1.12231189,-0.9416324 ,-0.83337162,-0.6391394 ,-0.45298865,-0.2708162 ,-0.1545559 , 0.03217212, 0.21633459, 0.4, - 0.58432694, 0.82999915, 0.95743373, 1.14688951, 1.25894242, 1.50999575, 1.64392367, 1.84066852, 1.93355791, 2.18999235, 2.33041362, 2.53444754}); + auto expected = NDArrayFactory::create('c', {2,3,4}, {-1.51218734f, -1.31045092f, -1.12231189f, -0.9416324f, -0.83337162f, -0.6391394f, 
-0.45298865f, -0.2708162f, -0.1545559f, 0.03217212f, 0.21633459f, 0.4f, + 0.58432694f, 0.82999915f, 0.95743373f, 1.14688951f, 1.25894242f, 1.50999575f, 1.64392367f, 1.84066852f, 1.93355791f, 2.18999235f, 2.33041362f, 2.53444754f}); input.linspace(0.1, 0.1); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index 59da5edb4..67ecf5576 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -1406,9 +1406,9 @@ TEST_F(DeclarableOpsTests12, pad_tests1) { // REFLECT mode 2D TEST_F(DeclarableOpsTests12, pad_tests2) { - float inBuff[] = {1,2,3,4,5,6}; + float inBuff[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; int padBuff[] = {1,1,2,2}; - float expBuff[] = {6,5,4,5,6,5,4, 3,2,1,2,3,2,1, 6,5,4,5,6,5,4, 3,2,1,2,3,2,1}; + float expBuff[] = {6.f, 5.f, 4.f, 5.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f, 2.f, 3.f, 2.f, 1.f, 6.f, 5.f, 4.f, 5.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f, 2.f, 3.f, 2.f, 1.f}; auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); @@ -1433,9 +1433,9 @@ TEST_F(DeclarableOpsTests12, pad_tests2) { // SYMMETRIC mode 2D TEST_F(DeclarableOpsTests12, pad_tests3) { - float inBuff[] = {1,2,3,4,5,6}; + float inBuff[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; int padBuff[] = {1,1,2,2}; - float expBuff[] = {2,1,1,2,3,3,2, 2,1,1,2,3,3,2, 5,4,4,5,6,6,5, 5,4,4,5,6,6,5}; + float expBuff[] = {2.f, 1.f, 1.f, 2.f, 3.f, 3.f, 2.f, 2.f,1.f,1.f,2.f,3.f,3.f,2.f, 5.f,4.f,4.f,5.f,6.f,6.f,5.f, 5.f,4.f,4.f,5.f,6.f,6.f,5.f}; auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); @@ -1460,13 +1460,13 @@ TEST_F(DeclarableOpsTests12, pad_tests3) { // CONSTANT mode 3D TEST_F(DeclarableOpsTests12, pad_tests4) { - float inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; + float inBuff[] = {1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f}; int padBuff[] = {1,1,2,2,2,2}; - float expBuff[] = {0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, - 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 1, 2, 3,0,0,0,0, 4, 5, 6,0,0,0,0, - 7, 8, 9,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0,10,11,12,0, - 0,0,0,13,14,15,0,0,0,0,16,17,18,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, - 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0}; + float expBuff[] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 2.f, 3.f, 0.f, 0.f, 0.f, 0.f, 4.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, + 7.f, 8.f, 9.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 11.f, 12.f, 0.f, + 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 0.f, 0.f, 0.f, 0.f, 16.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); @@ -1499,12 +1499,12 @@ TEST_F(DeclarableOpsTests12, pad_tests4) { // REFLECT mode 3D TEST_F(DeclarableOpsTests12, pad_tests5) { - float inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; + double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; int padBuff[] = {1,1,2,2,2,2}; - float expBuff[] = {18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}; - auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); + double expBuff[] = {18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}; + auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {1}); @@ -1525,13 +1525,13 @@ TEST_F(DeclarableOpsTests12, pad_tests5) { // SYMMETRIC mode 3D TEST_F(DeclarableOpsTests12, pad_tests6) { - float inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; + double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; int padBuff[] = {1,1,2,2,2,2}; - float expBuff[] = {5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14}; + double expBuff[] = {5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 
14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14}; - auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); + auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {2}); @@ -1552,12 +1552,12 @@ TEST_F(DeclarableOpsTests12, pad_tests6) { TEST_F(DeclarableOpsTests12, pad_tests7) { - float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; - float expBuff[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 10, 0, 0, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 0, 0, 15, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + double expBuff[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 10, 0, 0, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 0, 0, 15, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {0}); @@ -1578,12 +1578,12 @@ TEST_F(DeclarableOpsTests12, pad_tests7) TEST_F(DeclarableOpsTests12, pad_tests8) { - float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; - float expBuff[] = {16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 
16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1}; - auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + double expBuff[] = {16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {1}); @@ -1604,12 +1604,12 @@ TEST_F(DeclarableOpsTests12, pad_tests8) TEST_F(DeclarableOpsTests12, pad_tests9) { - float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; - float expBuff[] = {1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16}; - auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + double expBuff[] = {1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 
13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {2}); @@ -2151,13 +2151,13 @@ TEST_F(DeclarableOpsTests12, pad_tests34) { // CONSTANT mode 2D TEST_F(DeclarableOpsTests12, Pad_1) { - float inBuff[] = {1,2,3,4,5,6}; + double inBuff[] = {1,2,3,4,5,6}; int padBuff[] = {1,1,2,2}; - float expBuff[] = {0,0,0,0,0,0,0, 0,0,1,2,3,0,0, 0,0,4,5,6,0,0, 0,0,0,0,0,0,0}; + double expBuff[] = {0,0,0,0,0,0,0, 0,0,1,2,3,0,0, 0,0,4,5,6,0,0, 0,0,0,0,0,0,0}; - auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); + auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {0}); @@ -2178,13 +2178,13 @@ TEST_F(DeclarableOpsTests12, Pad_1) { // REFLECT mode 2D TEST_F(DeclarableOpsTests12, Pad_2) { - float inBuff[] = {1,2,3,4,5,6}; + double inBuff[] = {1,2,3,4,5,6}; int padBuff[] = {1,1,2,2}; - float expBuff[] = {6,5,4,5,6,5,4, 3,2,1,2,3,2,1, 6,5,4,5,6,5,4, 3,2,1,2,3,2,1}; + double expBuff[] = {6,5,4,5,6,5,4, 3,2,1,2,3,2,1, 6,5,4,5,6,5,4, 3,2,1,2,3,2,1}; - auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); + auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {1}); @@ -2205,13 +2205,13 @@ TEST_F(DeclarableOpsTests12, Pad_2) { // SYMMETRIC mode 2D TEST_F(DeclarableOpsTests12, Pad_3) { - float inBuff[] = {1,2,3,4,5,6}; + double inBuff[] = {1,2,3,4,5,6}; int padBuff[] = {1,1,2,2}; - float expBuff[] = {2,1,1,2,3,3,2, 2,1,1,2,3,3,2, 5,4,4,5,6,6,5, 5,4,4,5,6,6,5}; + double expBuff[] = {2,1,1,2,3,3,2, 2,1,1,2,3,3,2, 5,4,4,5,6,6,5, 5,4,4,5,6,6,5}; - auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); + auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {2}); @@ -2232,13 +2232,13 @@ TEST_F(DeclarableOpsTests12, Pad_3) { // CONSTANT mode 3D TEST_F(DeclarableOpsTests12, Pad_4) { - float inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; + double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; int padBuff[] = {1,1,2,2,2,2}; - float expBuff[] = {0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 1, 2, 3,0,0,0,0, 4, 5, 6,0,0,0,0, 7, 8, 9,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 
0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0,10,11,12,0,0,0,0,13,14,15,0,0,0,0,16,17,18,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0}; + double expBuff[] = {0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 1, 2, 3,0,0,0,0, 4, 5, 6,0,0,0,0, 7, 8, 9,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0,10,11,12,0,0,0,0,13,14,15,0,0,0,0,16,17,18,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0}; - auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); + auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {0}); @@ -2260,12 +2260,12 @@ TEST_F(DeclarableOpsTests12, Pad_4) { // REFLECT mode 3D TEST_F(DeclarableOpsTests12, Pad_5) { - float inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; + double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; int padBuff[] = {1,1,2,2,2,2}; - float expBuff[] = {18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}; - auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); + double expBuff[] = {18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}; + auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {1}); @@ -2286,13 +2286,13 @@ TEST_F(DeclarableOpsTests12, Pad_5) { // SYMMETRIC mode 3D TEST_F(DeclarableOpsTests12, Pad_6) { - float inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; + double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; int padBuff[] = {1,1,2,2,2,2}; - float expBuff[] = {5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 
2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14}; + double expBuff[] = {5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14}; - auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); + auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {2}); @@ -2313,12 +2313,12 @@ TEST_F(DeclarableOpsTests12, Pad_6) { TEST_F(DeclarableOpsTests12, Pad_7) { - float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; - float expBuff[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 10, 0, 0, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 0, 0, 15, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + double expBuff[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 10, 0, 0, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 0, 0, 15, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); - auto expected = 
NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {0}); @@ -2339,12 +2339,12 @@ TEST_F(DeclarableOpsTests12, Pad_7) TEST_F(DeclarableOpsTests12, Pad_8) { - float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; - float expBuff[] = {16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1}; - auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + double expBuff[] = {16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {1}); @@ -2365,12 +2365,12 @@ TEST_F(DeclarableOpsTests12, Pad_8) TEST_F(DeclarableOpsTests12, Pad_9) { - float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; - float expBuff[] = {1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 
14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16}; - auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + double expBuff[] = {1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); nd4j::ops::pad op; auto results = op.execute({&input, &paddings}, {}, {2}); @@ -2387,8 +2387,8 @@ TEST_F(DeclarableOpsTests12, Pad_9) } TEST_F(DeclarableOpsTests12, Test_Expose_1) { - auto input0 = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 6, 5, 4}); - auto input1 = NDArrayFactory::create('c', {2, 3}, {3, 2, 1, 4, 5, 6}); + auto input0 = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 6, 5, 4}); + auto input1 = NDArrayFactory::create('c', {2, 3}, {3, 2, 1, 4, 5, 6}); nd4j::ops::expose op; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index 71ee8a04e..209b16a7a 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -1027,13 +1027,13 @@ TEST_F(DeclarableOpsTests13, lstmLayer_1) { std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - auto expH = NDArrayFactory::create('c', {sL, bS, nOut}, {0.57574,0.57574,0.57574,0.58006,0.58006,0.58006,0.58434,0.58434,0.58434, - 0.55114,0.55114,0.55114,0.55732,0.55732,0.55732,0.56338,0.56338,0.56338, - 0.53763,0.53763,0.53763,0.54534,0.54534,0.54534,0.55287,0.55287,0.55287, - 0.53626,0.53626,0.53626,0.54487,0.54487,0.54487,0.55327,0.55327,0.55327, - 0.54484,0.54484,0.54484,0.55379,0.55379,0.55379,0.5625 ,0.5625 ,0.5625}); + auto expH = NDArrayFactory::create('c', {sL, bS, nOut}, {0.57574f, 0.57574f, 0.57574f, 0.58006f, 0.58006f, 0.58006f, 0.58434f, 0.58434f, 0.58434f, + 0.55114f, 0.55114f, 0.55114f, 0.55732f, 0.55732f, 0.55732f, 0.56338f, 0.56338f, 0.56338f, + 0.53763f, 0.53763f, 0.53763f, 0.54534f, 0.54534f, 0.54534f, 0.55287f, 0.55287f, 0.55287f, + 0.53626f, 0.53626f, 0.53626f, 0.54487f, 0.54487f, 0.54487f, 0.55327f, 0.55327f, 0.55327f, + 0.54484f, 0.54484f, 0.54484f, 0.55379f, 0.55379f, 0.55379f, 
0.5625f, 0.5625f, 0.5625f});
-    auto expClast = NDArrayFactory::create<float>('c', {bS, nOut}, {1.1589154,1.1589154,1.1589154,1.1892855,1.1892855,1.1892855,1.219861 ,1.219861 ,1.219861});
+    auto expClast = NDArrayFactory::create<float>('c', {bS, nOut}, {1.1589154f, 1.1589154f, 1.1589154f, 1.1892855f, 1.1892855f, 1.1892855f, 1.219861f, 1.219861f, 1.219861f});
     nd4j::ops::lstmLayer op;
     auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
@@ -1097,11 +1097,11 @@ TEST_F(DeclarableOpsTests13, lstmLayer_2) {
     std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
     std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
-    auto expH = NDArrayFactory::create<float>('c', {bS, sL, nOut}, {0.575735, 0.575735, 0.575735, 0.541562, 0.541562, 0.541562, 0.514003, 0.514003, 0.514003, 0.495597, 0.495597, 0.495597, 0.485999, 0.485999, 0.485999,
-                                                           0.596965, 0.596965, 0.596965, 0.571978, 0.571978, 0.571978, 0.552888, 0.552888, 0.552888, 0.540606, 0.540606, 0.540606, 0.534764, 0.534764, 0.534764,
-                                                           0.61725 , 0.61725 , 0.61725 , 0.599828, 0.599828, 0.599828, 0.587627, 0.587627, 0.587627, 0.580408, 0.580408, 0.580408, 0.577735, 0.577735, 0.577735});
+    auto expH = NDArrayFactory::create<float>('c', {bS, sL, nOut}, {0.575735f, 0.575735f, 0.575735f, 0.541562f, 0.541562f, 0.541562f, 0.514003f, 0.514003f, 0.514003f, 0.495597f, 0.495597f, 0.495597f, 0.485999f, 0.485999f, 0.485999f,
+                                                           0.596965f, 0.596965f, 0.596965f, 0.571978f, 0.571978f, 0.571978f, 0.552888f, 0.552888f, 0.552888f, 0.540606f, 0.540606f, 0.540606f, 0.534764f, 0.534764f, 0.534764f,
+                                                           0.61725f, 0.61725f, 0.61725f, 0.599828f, 0.599828f, 0.599828f, 0.587627f, 0.587627f, 0.587627f, 0.580408f, 0.580408f, 0.580408f, 0.577735f, 0.577735f, 0.577735f});
-    auto expClast = NDArrayFactory::create<float>('c', {bS, nOut}, {0.996965, 0.996965, 0.996965, 1.146756, 1.146756, 1.146756, 1.301922, 1.301922, 1.301922});
+    auto expClast = NDArrayFactory::create<float>('c', {bS, nOut}, {0.996965f, 0.996965f, 0.996965f, 1.146756f, 1.146756f, 1.146756f, 1.301922f, 1.301922f, 1.301922f});
     nd4j::ops::lstmLayer op;
     auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp
index fdaa7b549..41dc12a14 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp
@@ -178,10 +178,10 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) {
 TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) {
     auto x = NDArrayFactory::create<float>('c', {1, 4,4,3});
     auto e = NDArrayFactory::create<float>('c', {1, 4,4,3}, {
-            -21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
-            2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5,
-            26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5,
-            50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5
+            -21.5f, -20.5f, -19.5f, -15.5f, -14.5f, -13.5f, -9.5f, -8.5f, -7.5f, -3.5f, -2.5f, -1.5f,
+            2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f, 20.5f, 21.5f, 22.5f,
+            26.5f, 27.5f, 28.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, 40.5f, 44.5f, 45.5f, 46.5f,
+            50.5f, 51.5f, 52.5f, 56.5f, 57.5f, 58.5f, 62.5f, 63.5f, 64.5f, 68.5f, 69.5f, 70.5f
     });
     x.linspace(1.);
     nd4j::ops::adjust_contrast op;
@@ -196,10 +196,10 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) {
 TEST_F(DeclarableOpsTests15, Test_AdjustContrast_3) {
     auto x = NDArrayFactory::create<float>('c', {1, 4,4,3});
     auto e = NDArrayFactory::create<float>('c', {1, 4,4,3}, {
-            -21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
-            2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5,
-            26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5,
-            50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5
+            -21.5f, -20.5f, -19.5f, -15.5f, -14.5f, -13.5f, -9.5f, -8.5f, -7.5f, -3.5f, -2.5f, -1.5f,
+            2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f, 20.5f, 21.5f, 22.5f,
+            26.5f, 27.5f, 28.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, 40.5f, 44.5f, 45.5f, 46.5f,
+            50.5f, 51.5f, 52.5f, 56.5f, 57.5f, 58.5f, 62.5f, 63.5f, 64.5f, 68.5f, 69.5f, 70.5f
     });
     x.linspace(1.);
     nd4j::ops::adjust_contrast_v2 op;
@@ -243,8 +243,8 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_1) {
 TEST_F(DeclarableOpsTests15, Test_BitCast_2) {
     auto x = NDArrayFactory::create<float>('c', {2, 4});
-    auto e = NDArrayFactory::create<float16>('c', {2, 4, 2}, {0, 1.875, 0, 2., 0, 2.125, 0, 2.25,
-                                                      0, 2.312, 0, 2.375, 0, 2.438, 0., 2.5});
+    auto e = NDArrayFactory::create<float16>('c', {2, 4, 2}, {0.f, 1.875f, 0.f, 2.f, 0.f, 2.125f, 0.f, 2.25f,
+                                                      0.f, 2.312f, 0.f, 2.375f, 0.f, 2.438f, 0.f, 2.5f});
     x.linspace(1.);
     nd4j::ops::bitcast op;
     auto result = op.execute({&x}, {}, {nd4j::DataType::HALF}, {});
@@ -423,9 +423,9 @@ TEST_F(DeclarableOpsTests15, test_check_numeric_3) {
 }
 TEST_F(DeclarableOpsTests15, Test_layer_norm_1) {
-    auto x = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
-    auto g = NDArrayFactory::create<float>('c', {5}, {1., 2., 3., 4., 5.});
-    auto b = NDArrayFactory::create<float>('c', {5}, {1., 2., 3., 4., 5.});
+    auto x = NDArrayFactory::create<float>('c', {1, 5}, {1.f, 2.f, 3.f, 4.f, 5.f});
+    auto g = NDArrayFactory::create<float>('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f});
+    auto b = NDArrayFactory::create<float>('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f});
     nd4j::ops::layer_norm op;
     auto result = op.execute({&x, &g, &b}, {}, {0}, {false});
@@ -434,10 +434,10 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_1) {
 }
 TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) {
-    auto x = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
-    auto g = NDArrayFactory::create<float>('c', {5}, {1., 2., 3., 4., 5.});
-    auto b = NDArrayFactory::create<float>('c', {5}, {1., 2., 3., 4., 5.});
-    auto eps = NDArrayFactory::create<float>('c', {1, 5}, {0., 0., 0., 0., 0.});
+    auto x = NDArrayFactory::create<float>('c', {1, 5}, {1.f, 2.f, 3.f, 4.f, 5.f});
+    auto g = NDArrayFactory::create<float>('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f});
+    auto b = NDArrayFactory::create<float>('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f});
+    auto eps = NDArrayFactory::create<float>('c', {1, 5}, {0.f, 0.f, 0.f, 0.f, 0.f});
     nd4j::ops::layer_norm_bp op;
     auto result = op.execute({&x, &g, &b, &eps}, {}, {0}, {false});
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp
index c2e39cab5..38d88b469 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp
@@ -40,7 +40,7 @@ public:
 };
 TEST_F(DeclarableOpsTests16, scatter_upd_1) {
-    auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1});
+    auto x = NDArrayFactory::create<float>('c', {3}, {1.f, 1.f, 1.f});
     auto y = NDArrayFactory::create(0);
     auto w = NDArrayFactory::create(3.0f);
     auto e = NDArrayFactory::create<float>('c', {3}, {3.f, 1.f, 1.f});
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp
index 076d14385..9f9c39156 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp
@@ -400,7 +400,7 @@ TEST_F(DeclarableOpsTests2, NLP_Cbow_Test_1) {
 TEST_F(DeclarableOpsTests2, YetAnotherMatmulTest_1) {
     auto A = NDArrayFactory::create<float>('c', {3, 3});
     auto B = NDArrayFactory::create<float>('c', {3, 1});
-    auto exp = NDArrayFactory::create<float>('c', {3, 1}, {14.00, 32.00, 50.00});
+    auto exp = NDArrayFactory::create<float>('c', {3, 1}, {14.00f, 32.00f, 50.00f});
     A.linspace(1);
     B.linspace(1);
@@ -457,9 +457,9 @@ TEST_F(DeclarableOpsTests2, Test_Squeeze_2) {
 }
 TEST_F(DeclarableOpsTests2, Test_FloorMod_1) {
-    auto x = NDArrayFactory::create<float>('c', {1, 3}, {2.0, 6.0, -3.0});
-    auto y = NDArrayFactory::create<float>('c', {1, 3}, {-3.0, 2.0, -2.0});
-    auto exp = NDArrayFactory::create<float>('c', {1, 3}, {-1., 0., -1.,});
+    auto x = NDArrayFactory::create<float>('c', {1, 3}, {2.0f, 6.0f, -3.0f});
+    auto y = NDArrayFactory::create<float>('c', {1, 3}, {-3.0f, 2.0f, -2.0f});
+    auto exp = NDArrayFactory::create<float>('c', {1, 3}, {-1.f, 0.f, -1.f});
     nd4j::ops::floormod op;
@@ -475,9 +475,9 @@ TEST_F(DeclarableOpsTests2, Test_FloorMod_1) {
 }
 TEST_F(DeclarableOpsTests2, Test_FloorDiv_1) {
-    auto x = NDArrayFactory::create<float>('c', {1, 3}, {3.0, 6.0, -3.0});
-    auto y = NDArrayFactory::create<float>('c', {1, 3}, {-2.0, 2.0, -2.0});
-    auto exp = NDArrayFactory::create<float>('c', {1, 3}, {-2., 3., 1.,});
+    auto x = NDArrayFactory::create<float>('c', {1, 3}, {3.0f, 6.0f, -3.0f});
+    auto y = NDArrayFactory::create<float>('c', {1, 3}, {-2.0f, 2.0f, -2.0f});
+    auto exp = NDArrayFactory::create<float>('c', {1, 3}, {-2.f, 3.f, 1.f});
     nd4j::ops::floordiv op;
@@ -494,9 +494,9 @@ TEST_F(DeclarableOpsTests2, Test_FloorDiv_1) {
 }
 TEST_F(DeclarableOpsTests2, Test_FloorDiv_2) {
-    auto x = NDArrayFactory::create<float>('c', {1, 3}, {3.0, 6.0, -3.0});
-    auto y = NDArrayFactory::create<float>('c', {1, 3}, {-2.0, 2.0, -2.0});
-    auto eps = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});
+    auto x = NDArrayFactory::create<float>('c', {1, 3}, {3.0f, 6.0f, -3.0f});
+    auto y = NDArrayFactory::create<float>('c', {1, 3}, {-2.0f, 2.0f, -2.0f});
+    auto eps = NDArrayFactory::create<float>('c', {1, 3}, {1.f, 2.f, 3.f});
     auto exp1 = NDArrayFactory::create<float>('c', {1, 3}, {0.f, 0.f, 0.f});
     auto exp2 = NDArrayFactory::create<float>('c', {1, 3}, {0.f, 0.f, 0.f});
@@ -518,8 +518,8 @@ TEST_F(DeclarableOpsTests2, Test_FloorDiv_2) {
 }
 TEST_F(DeclarableOpsTests2, Test_CRelu_1) {
-    auto x = NDArrayFactory::create<float>('c', {2, 2}, {1.0, 2.0, 3.0, 4.0});
-    auto exp = NDArrayFactory::create<float>('c', {2, 4}, {1.0, 2.0, 0, 0, 3.0, 4.0, 0, 0});
+    auto x = NDArrayFactory::create<float>('c', {2, 2}, {1.0f, 2.0f, 3.0f, 4.0f});
+    auto exp = NDArrayFactory::create<float>('c', {2, 4}, {1.0f, 2.0f, 0.f, 0.f, 3.0f, 4.0f, 0.f, 0.f});
     nd4j::ops::crelu op;
@@ -536,9 +536,9 @@ TEST_F(DeclarableOpsTests2, Test_CRelu_1) {
 }
 TEST_F(DeclarableOpsTests2, Test_CRelu_BP_2) {
-    auto x = NDArrayFactory::create<float>('c', {2, 2}, {1.0, 2.0, -3.0, 4.0});
-    auto eps = NDArrayFactory::create<float>('c', {2, 4}, {1.0, 2.0, 4, 3, 3.0, 4.0, 2, 1});
-    auto exp = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, -2, 4});
+    auto x = NDArrayFactory::create<float>('c', {2, 2}, {1.0f, 2.0f, -3.0f, 4.0f});
+    auto eps = NDArrayFactory::create<float>('c', {2, 4}, {1.0f, 2.0f, 4.f, 3.f, 3.0f, 4.0f, 2.f, 1.f});
+    auto exp = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, -2.f, 4.f});
     nd4j::ops::crelu_bp op;
     auto result = op.execute({&x, &eps}, {}, {});
@@ -556,9 +556,9 @@ TEST_F(DeclarableOpsTests2, Test_CRelu_BP_2) {
 TEST_F(DeclarableOpsTests2, Test_Concat_BP_1) {
     auto x = NDArrayFactory::create<float>('c', {2, 2});
     auto y = NDArrayFactory::create<float>('c',
{2, 2}); - auto eps = NDArrayFactory::create('c', {2, 4}, {1.0, 2.0, 0, 1, 3.0, 4.0, 0, 1}); - auto expEX = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto expEY = NDArrayFactory::create('c', {2, 2}, {0, 1, 0, 1}); + auto eps = NDArrayFactory::create('c', {2, 4}, {1.0f, 2.0f, 0.f, 1.f, 3.0f, 4.0f, 0.f, 1.f}); + auto expEX = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + auto expEY = NDArrayFactory::create('c', {2, 2}, {0.f, 1.f, 0.f, 1.f}); nd4j::ops::concat_bp op; auto result = op.execute({&x, &y, &eps}, {}, {-1}); @@ -581,9 +581,9 @@ TEST_F(DeclarableOpsTests2, Test_Concat_BP_1) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, TestTensorDot5) { - auto x = NDArrayFactory::create('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); - auto y = NDArrayFactory::create('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); - auto expected = NDArrayFactory::create('c', {2,4,2,4}, {44,110,160, 66,132, 38, 88,154, 68,170,224,102,204, 82,136,238, 92,230,288,138,276,126,184,322, 116,290,352,174,348,170,232,406, 76,190,160,114,228,182,152,266, 100,250,224,150,300,226,200,350, 124,310,288,186,372,270,248,434, 148,370,352,222,444,314,296,518}); + auto x = NDArrayFactory::create('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); + auto y = NDArrayFactory::create('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); + auto expected = NDArrayFactory::create('c', {2,4,2,4}, {44,110,160, 66,132, 38, 88,154, 68,170,224,102,204, 82,136,238, 92,230,288,138,276,126,184,322, 116,290,352,174,348,170,232,406, 76,190,160,114,228,182,152,266, 100,250,224,150,300,226,200,350, 124,310,288,186,372,270,248,434, 148,370,352,222,444,314,296,518}); nd4j::ops::tensormmul op; auto results = op.execute({&x, &y}, {}, {1,1,1,2}); @@ -603,9 +603,9 @@ TEST_F(DeclarableOpsTests2, TestTensorDot5) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, TestTensorDot6) { - auto x = NDArrayFactory::create('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); - auto y = NDArrayFactory::create('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); - auto expected = NDArrayFactory::create('c', {2,4,2,4}, {22, 66,110,154, 44, 88,132,176, 34,102,170,238, 68,136,204,272, 46,138,230,322, 92,184,276,368, 58,174,290,406,116,232,348,464, 38,114,190,266, 76,152,228,304, 50,150,250,350,100,200,300,400, 62,186,310,434,124,248,372,496, 74,222,370,518,148,296,444,592}); + auto x = NDArrayFactory::create('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); + auto y = NDArrayFactory::create('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); + auto expected = NDArrayFactory::create('c', {2,4,2,4}, {22, 66,110,154, 44, 88,132,176, 34,102,170,238, 68,136,204,272, 46,138,230,322, 92,184,276,368, 58,174,290,406,116,232,348,464, 38,114,190,266, 76,152,228,304, 50,150,250,350,100,200,300,400, 62,186,310,434,124,248,372,496, 74,222,370,518,148,296,444,592}); nd4j::ops::tensormmul op; auto results = op.execute({&x, &y}, {}, {1,1,1,2}); @@ -624,9 +624,9 @@ TEST_F(DeclarableOpsTests2, TestTensorDot6) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, TestTensorDot7) { - auto x = NDArrayFactory::create('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); - auto y = 
NDArrayFactory::create('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); - auto expected = NDArrayFactory::create('c', {2,4,2,4}, {76,166,112,106,196, 62,136,226, 60,174,208, 98,212,230,136,250, 76,214,336,122,260,174,168,306, 124,286,240,178,340,150,232,394, 100,226,176,142,268,106,184,310, 84,234,272,134,284,274,184,334, 100,274,400,158,332,218,216,390, 148,346,304,214,412,194,280,478}); + auto x = NDArrayFactory::create('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); + auto y = NDArrayFactory::create('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); + auto expected = NDArrayFactory::create('c', {2,4,2,4}, {76,166,112,106,196, 62,136,226, 60,174,208, 98,212,230,136,250, 76,214,336,122,260,174,168,306, 124,286,240,178,340,150,232,394, 100,226,176,142,268,106,184,310, 84,234,272,134,284,274,184,334, 100,274,400,158,332,218,216,390, 148,346,304,214,412,194,280,478}); nd4j::ops::tensormmul op; auto results = op.execute({&x, &y}, {}, {1,1,1,2}); @@ -645,9 +645,9 @@ TEST_F(DeclarableOpsTests2, TestTensorDot7) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, TestTensorDot8) { - auto x = NDArrayFactory::create('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); - auto y = NDArrayFactory::create('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); - auto expected = NDArrayFactory::create('c', {2,4,2,4}, {30, 90,150,210, 60,120,180,240, 38,114,190,266, 76,152,228,304, 46,138,230,322, 92,184,276,368, 54,162,270,378,108,216,324,432, 42,126,210,294, 84,168,252,336, 50,150,250,350,100,200,300,400, 58,174,290,406,116,232,348,464, 66,198,330,462,132,264,396,528}); + auto x = NDArrayFactory::create('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); + auto y = NDArrayFactory::create('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); + auto expected = NDArrayFactory::create('c', {2,4,2,4}, {30, 90,150,210, 60,120,180,240, 38,114,190,266, 76,152,228,304, 46,138,230,322, 92,184,276,368, 54,162,270,378,108,216,324,432, 42,126,210,294, 84,168,252,336, 50,150,250,350,100,200,300,400, 58,174,290,406,116,232,348,464, 66,198,330,462,132,264,396,528}); nd4j::ops::tensormmul op; auto results = op.execute({&x, &y}, {}, {1,1,1,2}); @@ -674,9 +674,9 @@ TEST_F(DeclarableOpsTests2, TestTensorDot9) { // z.printShapeInfo(); // z.printIndexedBuffer(); - auto x = NDArrayFactory::create('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); - auto y = NDArrayFactory::create('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); - auto expected = NDArrayFactory::create('c', {3,4,4,3}, {14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422}); + auto x = NDArrayFactory::create('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); + auto y = NDArrayFactory::create('f', {2,4,3}, {2,4,6,8,10,12,14,16, 
2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); + auto expected = NDArrayFactory::create('c', {3,4,4,3}, {14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422}); nd4j::ops::tensormmul op; auto results = op.execute({&x, &y}, {}, {1,0,1,0}); @@ -695,9 +695,9 @@ TEST_F(DeclarableOpsTests2, TestTensorDot9) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, TestTensorDot10) { - auto x = NDArrayFactory::create('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); - auto y = NDArrayFactory::create('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); - auto expected = NDArrayFactory::create('c', {4,4}, {114,258,402,546, 138,314,490,666, 162,370,578,786, 186,426,666,906}); + auto x = NDArrayFactory::create('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); + auto y = NDArrayFactory::create('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); + auto expected = NDArrayFactory::create('c', {4,4}, {114,258,402,546, 138,314,490,666, 162,370,578,786, 186,426,666,906}); nd4j::ops::tensormmul op; auto results = op.execute({&x, &y}, {}, {2,0,1, 2,0,2}); @@ -717,9 +717,9 @@ TEST_F(DeclarableOpsTests2, TestTensorDot10) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, TestTensorDot11) { - auto x = NDArrayFactory::create('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); - auto y = NDArrayFactory::create('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); - auto expected = NDArrayFactory::create('c', {4,4}, {98,218,338,458, 134,302,470,638, 170,386,602,818, 206,470,734,998}); + auto x = NDArrayFactory::create('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); + auto y = NDArrayFactory::create('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); + auto expected = NDArrayFactory::create('c', {4,4}, {98,218,338,458, 134,302,470,638, 170,386,602,818, 206,470,734,998}); nd4j::ops::tensormmul op; auto results = op.execute({&x, &y}, {}, {2,0,1, 2,0,2}); @@ -738,9 +738,9 @@ TEST_F(DeclarableOpsTests2, TestTensorDot11) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, TestTensorDot12) { - auto x = NDArrayFactory::create('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); - auto y = NDArrayFactory::create('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); - auto expected = NDArrayFactory::create('c', {4,4}, {272,292,312,332, 368,396,424,452, 464,500,536,572, 560,604,648,692}); + auto x = NDArrayFactory::create('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); + auto y = NDArrayFactory::create('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); + auto expected = NDArrayFactory::create('c', {4,4}, {272,292,312,332, 368,396,424,452, 464,500,536,572, 
560,604,648,692}); nd4j::ops::tensormmul op; auto results = op.execute({&x, &y}, {}, {2,0,1, 2,0,2}); @@ -759,9 +759,9 @@ TEST_F(DeclarableOpsTests2, TestTensorDot12) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, TestTensorDot13) { - auto x = NDArrayFactory::create('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); - auto y = NDArrayFactory::create('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); - auto expected = NDArrayFactory::create('c', {3,3}, {640,560,640, 576,624,576, 640,560,640}); + auto x = NDArrayFactory::create('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); + auto y = NDArrayFactory::create('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); + auto expected = NDArrayFactory::create('c', {3,3}, {640,560,640, 576,624,576, 640,560,640}); nd4j::ops::tensormmul op; auto results = op.execute({&x, &y}, {}, {2,0,2, 2,1,0}); @@ -780,9 +780,9 @@ TEST_F(DeclarableOpsTests2, TestTensorDot13) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, TestTensorDot14) { - auto x = NDArrayFactory::create('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); - auto y = NDArrayFactory::create('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); - auto expected = NDArrayFactory::create('c', {3,3}, {648,600,520, 648,536,648, 520,600,648}); + auto x = NDArrayFactory::create('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); + auto y = NDArrayFactory::create('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); + auto expected = NDArrayFactory::create('c', {3,3}, {648,600,520, 648,536,648, 520,600,648}); nd4j::ops::tensormmul op; auto results = op.execute({&x, &y}, {}, {2,0,2, 2,1,0}); @@ -801,9 +801,9 @@ TEST_F(DeclarableOpsTests2, TestTensorDot14) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, TestTensorDot15) { - auto x = NDArrayFactory::create('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); - auto y = NDArrayFactory::create('f', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); - auto expected = NDArrayFactory::create('c', {3,3}, {624,624,624, 656,656,656, 624,624,624}); + auto x = NDArrayFactory::create('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); + auto y = NDArrayFactory::create('f', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); + auto expected = NDArrayFactory::create('c', {3,3}, {624,624,624, 656,656,656, 624,624,624}); nd4j::ops::tensormmul op; auto results = op.execute({&x, &y}, {}, {2,0,2, 2,1,0}); @@ -1449,7 +1449,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test1) { auto labels = NDArrayFactory::create('c', {2,3,4}); auto predictions = NDArrayFactory::create('c', {2,3,4}); auto weights = NDArrayFactory::create('c', {1,3,4}); - auto expected = NDArrayFactory::create('c', {1,3,4}, {-91.5,-107.5,-125.5,-145.5, -167.5,-191.5,-217.5,-245.5, -275.5,-307.5,-341.5,-377.5}); + auto expected = NDArrayFactory::create('c', {1,3,4}, {-91.5f, -107.5f, -125.5f, -145.5f, -167.5f, -191.5f, -217.5f, -245.5f, -275.5f, -307.5f, -341.5f, -377.5f}); labels.linspace(1); predictions.linspace(2); @@ -1475,7 +1475,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test2) { auto labels = NDArrayFactory::create('c', {2,3,4}); auto 
predictions = NDArrayFactory::create<float>('c', {2,3,4});
     auto weights = NDArrayFactory::create<float>('c', {2,1,4});
-    auto expected = NDArrayFactory::create<float>('c', {2,1,4}, {-3.25, -4., -4.75, -5.5,-12.25,-13.,-13.75,-14.5});
+    auto expected = NDArrayFactory::create<float>('c', {2,1,4}, {-3.25f, -4.f, -4.75f, -5.5f, -12.25f, -13.f, -13.75f, -14.5f});
     labels.linspace(1);
     weights.assign(0.5);
@@ -1502,7 +1502,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test3) {
     auto labels = NDArrayFactory::create<float>('c', {2,3,4});
     auto predictions = NDArrayFactory::create<float>('c', {2,3,4});
     auto weights = NDArrayFactory::create<float>('c', {2,3,1});
-    auto expected = NDArrayFactory::create<float>('c', {2,3,1}, {-2., -6.,-10.,-14.,-18.,-22.});
+    auto expected = NDArrayFactory::create<float>('c', {2,3,1}, {-2.f, -6.f,-10.f,-14.f,-18.f,-22.f});
     labels.linspace(1);
     weights.assign(0.5);
@@ -1527,7 +1527,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test4) {
     auto labels = NDArrayFactory::create<float>('c', {2,3,4});
     auto predictions = NDArrayFactory::create<float>('c', {2,3,4});
     auto weights = NDArrayFactory::create<float>('c', {1,1});
-    auto expected = NDArrayFactory::create<float>('c', {2,3,1}, {-2., -6.,-10.,-14.,-18.,-22.});
+    auto expected = NDArrayFactory::create<float>('c', {2,3,1}, {-2.f, -6.f,-10.f,-14.f,-18.f,-22.f});
     labels.linspace(1);
     weights.assign(0.5);
@@ -1702,10 +1702,10 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test10) {
 ///////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests2, hinge_loss_test1) {
-    auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
-    auto logits = NDArrayFactory::create<float>('c', {2,3,4});
-    auto weights = NDArrayFactory::create<float>('c', {2,3,4});
-    auto expected = NDArrayFactory::create<float>('c', {2,3,4}, {1., 0. , 0., 2.5,0., 3.5, 0., 4.5,0., 5.5, 0., 6.5, 0., 7.5, 0., 8.5,0., 9.5,10., 0. ,0.,11.5, 0.,12.5});
+    auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
+    auto logits = NDArrayFactory::create<double>('c', {2,3,4});
+    auto weights = NDArrayFactory::create<double>('c', {2,3,4});
+    auto expected = NDArrayFactory::create<double>('c', {2,3,4}, {1., 0. , 0., 2.5,0., 3.5, 0., 4.5,0., 5.5, 0., 6.5, 0., 7.5, 0., 8.5,0., 9.5,10., 0. ,0.,11.5, 0.,12.5});
     logits.linspace(1);
     weights.assign(0.5);
@@ -1727,10 +1727,10 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test1) {
 ///////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests2, hinge_loss_test2) {
-    auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
-    auto logits = NDArrayFactory::create<float>('c', {2,3,4});
-    auto weights = NDArrayFactory::create<float>('c', {1,1});
-    auto expected = NDArrayFactory::create<float>('c', {2,3,4}, {1., 0. , 0., 2.5,0., 3.5, 0., 4.5,0., 5.5, 0., 6.5, 0., 7.5, 0., 8.5,0., 9.5,10., 0. ,0.,11.5, 0.,12.5});
+    auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
+    auto logits = NDArrayFactory::create<double>('c', {2,3,4});
+    auto weights = NDArrayFactory::create<double>('c', {1,1});
+    auto expected = NDArrayFactory::create<double>('c', {2,3,4}, {1., 0. , 0., 2.5,0., 3.5, 0., 4.5,0., 5.5, 0., 6.5, 0., 7.5, 0., 8.5,0., 9.5,10., 0. ,0.,11.5, 0.,12.5});
     logits.linspace(1);
     weights.assign(0.5);
@@ -1752,10 +1752,10 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test2) {
 ///////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests2, hinge_loss_test3) {
-    auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
-    auto logits = NDArrayFactory::create<float>('c', {2,3,4});
-    auto weights = NDArrayFactory::create<float>('c', {1,3,1});
-    auto expected = NDArrayFactory::create<float>('c', {2,3,4}, {1., 0. , 0., 2.5,0., 3.5, 0., 4.5,0., 5.5, 0., 6.5, 0., 7.5, 0., 8.5,0., 9.5,10., 0. ,0.,11.5, 0.,12.5});
+    auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
+    auto logits = NDArrayFactory::create<double>('c', {2,3,4});
+    auto weights = NDArrayFactory::create<double>('c', {1,3,1});
+    auto expected = NDArrayFactory::create<double>('c', {2,3,4}, {1., 0. , 0., 2.5,0., 3.5, 0., 4.5,0., 5.5, 0., 6.5, 0., 7.5, 0., 8.5,0., 9.5,10., 0. ,0.,11.5, 0.,12.5});
     logits.linspace(1);
     weights.assign(0.5);
@@ -1777,9 +1777,9 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test3) {
 ///////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests2, hinge_loss_test4) {
-    auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
-    auto logits = NDArrayFactory::create<float>('c', {2,3,4});
-    auto weights = NDArrayFactory::create<float>('c', {2,3,4});
+    auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
+    auto logits = NDArrayFactory::create<double>('c', {2,3,4});
+    auto weights = NDArrayFactory::create<double>('c', {2,3,4});
     logits.linspace(1);
     weights.assign(0.5);
@@ -1801,9 +1801,9 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test4) {
 ///////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests2, hinge_loss_test5) {
-    auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
-    auto logits = NDArrayFactory::create<float>('c', {2,3,4});
-    auto weights = NDArrayFactory::create<float>('c', {1,1});
+    auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
+    auto logits = NDArrayFactory::create<double>('c', {2,3,4});
+    auto weights = NDArrayFactory::create<double>('c', {1,1});
     logits.linspace(1);
     weights.assign(0.5);
@@ -1825,9 +1825,9 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test5) {
 ///////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests2, hinge_loss_test6) {
-    auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
-    auto logits = NDArrayFactory::create<float>('c', {2,3,4});
-    auto weights = NDArrayFactory::create<float>('c', {2,1,1});
+    auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
+    auto logits = NDArrayFactory::create<double>('c', {2,3,4});
+    auto weights = NDArrayFactory::create<double>('c', {2,1,1});
     logits.linspace(1);
     weights.assign(0.5);
@@ -1849,9 +1849,9 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test6) {
 ///////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests2, hinge_loss_test7) {
-    auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
-    auto logits = NDArrayFactory::create<float>('c', {2,3,4});
-    auto weights = NDArrayFactory::create<float>('c', {2,3,4});
+    auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
+    auto logits = NDArrayFactory::create<double>('c', {2,3,4});
+    auto weights = NDArrayFactory::create<double>('c', {2,3,4});
     logits.linspace(1);
     weights.assign(0.5);
@@ -1873,9 +1873,9 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test7) {
 ///////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests2, hinge_loss_test8) {
-    auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
-    auto logits = NDArrayFactory::create<float>('c', {2,3,4});
-    auto weights = NDArrayFactory::create<float>('c', {1,1});
+    auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
+    auto logits = NDArrayFactory::create<double>('c', {2,3,4});
+    auto weights = NDArrayFactory::create<double>('c', {1,1});
     logits.linspace(1);
     weights.assign(0.5);
@@ -1897,9 +1897,9 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test8) {
 ///////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests2, hinge_loss_test9) {
-    auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
-    auto logits = NDArrayFactory::create<float>('c', {2,3,4});
-    auto weights = NDArrayFactory::create<float>('c', {1,1,4});
+    auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
+    auto logits = NDArrayFactory::create<double>('c', {2,3,4});
+    auto weights = NDArrayFactory::create<double>('c', {1,1,4});
     logits.linspace(1);
     weights.assign(0.5);
@@ -1921,9 +1921,9 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test9) {
 ///////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests2, hinge_loss_test10) {
-    auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
-    auto logits = NDArrayFactory::create<float>('c', {2,3,4});
-    auto weights = NDArrayFactory::create<float>('c', {2,3,4});
+    auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
+    auto logits = NDArrayFactory::create<double>('c', {2,3,4});
+    auto weights = NDArrayFactory::create<double>('c', {2,3,4});
     logits.linspace(1);
     weights.assign(0.5);
@@ -1945,9 +1945,9 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test10) {
 ///////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests2, hinge_loss_test11) {
-    auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
-    auto logits = NDArrayFactory::create<float>('c', {2,3,4});
-    auto weights = NDArrayFactory::create<float>('c', {2,1,4});
+    auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
+    auto logits = NDArrayFactory::create<double>('c', {2,3,4});
+    auto weights = NDArrayFactory::create<double>('c', {2,1,4});
     logits.linspace(1);
     weights.assign(0.5);
@@ -1969,9 +1969,9 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test11) {
 ///////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests2, hinge_loss_test12) {
-    auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
-    auto logits = NDArrayFactory::create<float>('c', {2,3,4});
-    auto weights = NDArrayFactory::create<float>('c', {2,3,4});
+    auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
+    auto logits = NDArrayFactory::create<double>('c', {2,3,4});
+    auto weights = NDArrayFactory::create<double>('c', {2,3,4});
     logits.linspace(1);
     weights.assign(0.5);
@@ -1997,9 +1997,9 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test12) {
 ///////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests2, hinge_loss_test13) {
-    auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1}); logits.linspace(1); weights.assign(0.); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index 7d166f831..269c13f51 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -36,7 +36,7 @@ public: TEST_F(DeclarableOpsTests3, Test_Tile_1) { - auto x= NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto x= NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); auto rep_vector= NDArrayFactory::create('c', {1, 2}, {2, 2}); std::vector reps({2, 2}); @@ -55,7 +55,7 @@ TEST_F(DeclarableOpsTests3, Test_Tile_1) { TEST_F(DeclarableOpsTests3, Test_Tile_2) { - auto x= NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto x= NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); std::vector reps({2, 2}); auto exp = x.tile(reps); @@ -104,8 +104,8 @@ TEST_F(DeclarableOpsTests3, Test_Permute_2) { TEST_F(DeclarableOpsTests3, Test_Unique_1) { - auto x= NDArrayFactory::create('c', {1, 5}, {1, 2, 1, 2, 3}); - auto expV= NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto x= NDArrayFactory::create('c', {1, 5}, {1.f, 2.f, 1.f, 2.f, 3.f}); + auto expV= NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); auto expI= NDArrayFactory::create('c', {5}, {0, 1, 0, 1, 2}); // auto expI= NDArrayFactory::create('c', {3}, {0, 1, 4}); @@ -130,8 +130,8 @@ TEST_F(DeclarableOpsTests3, Test_Unique_1) { } TEST_F(DeclarableOpsTests3, Test_Unique_2) { - auto x= NDArrayFactory::create('c', {1, 5}, {1, 2, 1, 2, 3}); - auto expV= NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto x= NDArrayFactory::create('c', {1, 5}, {1.f, 2.f, 1.f, 2.f, 3.f}); + auto expV= NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); auto expI= NDArrayFactory::create('c', {5}, {0, 1, 0, 1, 2}); auto expC= NDArrayFactory::create('c', {3}, {2, 2, 1}); @@ -165,8 +165,8 @@ TEST_F(DeclarableOpsTests3, Test_Unique_2) { } TEST_F(DeclarableOpsTests3, Test_Rint_1) { - auto x= NDArrayFactory::create('c', {1, 7}, {-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0}); - auto exp= NDArrayFactory::create('c', {1, 7}, {-2., -2., -0., 0., 2., 2., 2.}); + auto x= NDArrayFactory::create('c', {1, 7}, {-1.7f, -1.5f, -0.2f, 0.2f, 1.5f, 1.7f, 2.0f}); + auto exp= NDArrayFactory::create('c', {1, 7}, {-2.f, -2.f, -0.f, 0.f, 2.f, 2.f, 2.f}); nd4j::ops::rint op; auto result = op.execute({&x}, {}, {}); @@ -275,8 +275,8 @@ TEST_F(DeclarableOpsTests3, Test_ClipByAvgNorm_1) { } TEST_F(DeclarableOpsTests3, Test_ClipByAvgNorm_2) { - auto x= NDArrayFactory::create('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0}); - auto exp= NDArrayFactory::create('c', {2, 3}, {-3, 0.0, 0.0, 4, 0.0, 0.0}); + auto x= NDArrayFactory::create('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f}); + auto exp= NDArrayFactory::create('c', {2, 3}, {-3.f, 0.0f, 0.0f, 4.f, 0.0f, 0.0f}); nd4j::ops::clipbyavgnorm op; auto result = op.execute({&x}, {0.9}, {}); @@ -291,8 +291,8 @@ TEST_F(DeclarableOpsTests3, Test_ClipByAvgNorm_2) { TEST_F(DeclarableOpsTests3, Test_ClipByNorm_1) { - auto x= NDArrayFactory::create('c', {2, 3}, {-3.0, 0.0, 
0.0, 4.0, 0.0, 0.0}); - auto exp= NDArrayFactory::create('c', {2, 3}, {-2.4, 0.0, 0.0, 3.2, 0.0, 0.0}); + auto x= NDArrayFactory::create('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0}); + auto exp= NDArrayFactory::create('c', {2, 3}, {-2.4, 0.0, 0.0, 3.2, 0.0, 0.0}); nd4j::ops::clipbynorm op; auto result = op.execute({&x}, {4.0}, {}); @@ -306,8 +306,8 @@ TEST_F(DeclarableOpsTests3, Test_ClipByNorm_1) { } TEST_F(DeclarableOpsTests3, Test_ClipByNorm_2) { - auto x= NDArrayFactory::create('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0}); - auto exp= NDArrayFactory::create('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0}); + auto x= NDArrayFactory::create('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f}); + auto exp= NDArrayFactory::create('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f}); nd4j::ops::clipbynorm op; auto result = op.execute({&x}, {6.0}, {}); @@ -353,10 +353,10 @@ TEST_F(DeclarableOpsTests3, Test_ClipByNorm_3) { } TEST_F(DeclarableOpsTests3, Test_ListDiff_1) { - auto x= NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); - auto y= NDArrayFactory::create('c', {3}, {1, 3, 5}); + auto x= NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto y= NDArrayFactory::create('c', {3}, {1.f, 3.f, 5.f}); - auto exp0= NDArrayFactory::create('c', {3}, {2, 4, 6}); + auto exp0= NDArrayFactory::create('c', {3}, {2.f, 4.f, 6.f}); auto exp1= NDArrayFactory::create('c', {3}, {1, 3, 5}); nd4j::ops::listdiff op; @@ -380,10 +380,10 @@ TEST_F(DeclarableOpsTests3, Test_ListDiff_1) { } TEST_F(DeclarableOpsTests3, Test_Range_1) { - auto start = NDArrayFactory::create(0.3); - auto stop = NDArrayFactory::create(-5); - auto step = NDArrayFactory::create(-0.33); - auto exp= NDArrayFactory::create('c', {17}, { 0.3 , -0.03, -0.36, -0.69, -1.02, -1.35, -1.68, -2.01, -2.34, -2.67,-3. 
, -3.33, -3.66, -3.99, -4.32, -4.65, -4.98}); + auto start = NDArrayFactory::create(0.3f); + auto stop = NDArrayFactory::create(-5.f); + auto step = NDArrayFactory::create(-0.33f); + auto exp= NDArrayFactory::create('c', {17}, { 0.3f, -0.03f, -0.36f, -0.69f, -1.02f, -1.35f, -1.68f, -2.01f, -2.34f, -2.67f, -3.f, -3.33f, -3.66f, -3.99f, -4.32f, -4.65f, -4.98f}); nd4j::ops::range op; auto result = op.execute({&start, &stop, &step}, {}, {}); @@ -400,10 +400,10 @@ TEST_F(DeclarableOpsTests3, Test_Range_1) { TEST_F(DeclarableOpsTests3, Test_Range_2) { - auto start= NDArrayFactory::create('c', {1, 1}, {2}); - auto stop= NDArrayFactory::create('c', {1, 1}, {0.}); - auto step= NDArrayFactory::create('c', {1, 1}, {-1}); - auto exp= NDArrayFactory::create('c', {2}, {2, 1}); + auto start= NDArrayFactory::create('c', {1, 1}, {2.f}); + auto stop= NDArrayFactory::create('c', {1, 1}, {0.f}); + auto step= NDArrayFactory::create('c', {1, 1}, {-1.f}); + auto exp= NDArrayFactory::create('c', {2}, {2.f, 1.f}); nd4j::ops::range op; auto result = op.execute({&start, &stop, &step}, {}, {}); @@ -419,10 +419,10 @@ TEST_F(DeclarableOpsTests3, Test_Range_2) { } TEST_F(DeclarableOpsTests3, Test_Range_3) { - auto start= NDArrayFactory::create('c', {1, 1}, {0.}); - auto stop= NDArrayFactory::create('c', {1, 1}, {2}); - auto step= NDArrayFactory::create('c', {1, 1}, {1}); - auto exp= NDArrayFactory::create('c', {2}, {0, 1}); + auto start= NDArrayFactory::create('c', {1, 1}, {0.f}); + auto stop= NDArrayFactory::create('c', {1, 1}, {2.f}); + auto step= NDArrayFactory::create('c', {1, 1}, {1.f}); + auto exp= NDArrayFactory::create('c', {2}, {0.f, 1.f}); nd4j::ops::range op; auto result = op.execute({&start, &stop, &step}, {}, {}); @@ -439,7 +439,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_3) { TEST_F(DeclarableOpsTests3, Test_Range_4) { - auto exp= NDArrayFactory::create('c', {13}, {-10., -8.334, -6.668, -5.002, -3.336, -1.67 , -0.004, 1.662, 3.328, 4.994, 6.66 , 8.326, 9.992}); + auto exp= NDArrayFactory::create('c', {13}, {-10.f, -8.334f, -6.668f, -5.002f, -3.336f, -1.67f, -0.004f, 1.662f, 3.328f, 4.994f, 6.66f, 8.326f, 9.992f}); nd4j::ops::range op; auto result = op.execute({}, {-10., 10., 1.666}, {}); @@ -456,7 +456,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_4) { TEST_F(DeclarableOpsTests3, Test_Range_5) { - auto exp= NDArrayFactory::create('c', {2}, {2, 1}); + auto exp= NDArrayFactory::create('c', {2}, {2.f, 1.f}); nd4j::ops::range op; auto result = op.execute({}, {2, 0, -1}, {}); @@ -472,7 +472,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_5) { } TEST_F(DeclarableOpsTests3, Test_Range_6) { - auto exp= NDArrayFactory::create('c', {2}, {0, 1}); + auto exp= NDArrayFactory::create('c', {2}, {0.f, 1.f}); nd4j::ops::range op; auto result = op.execute({}, {0, 2, 1}, {}); @@ -488,7 +488,7 @@ TEST_F(DeclarableOpsTests3, Test_Range_6) { } TEST_F(DeclarableOpsTests3, Test_Range_7) { - auto exp= NDArrayFactory::create('c', {10}, {10., 8.334, 6.668, 5.002, 3.336, 1.67 , 0.004, -1.662, -3.328, -4.994}); + auto exp= NDArrayFactory::create('c', {10}, {10.f, 8.334f, 6.668f, 5.002f, 3.336f, 1.67f, 0.004f, -1.662f, -3.328f, -4.994f}); nd4j::ops::range op; auto result = op.execute({}, {10,-5,-1.666}, {}); @@ -538,10 +538,10 @@ TEST_F(DeclarableOpsTests3, Test_Range_9) { } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_1) { - auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - auto x= NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto y= 
NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x= NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y= NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); auto exp = MmulHelper::mmul(&x, &y); @@ -566,10 +566,10 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_1) { } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_2) { - auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - auto x= NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto y= NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x= NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y= NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); auto exp = MmulHelper::mmul(&x, &y); @@ -594,10 +594,10 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_2) { } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_3) { - auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - auto x= NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto y= NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x= NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y= NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); auto exp = MmulHelper::mmul(&x, &y); @@ -622,10 +622,10 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_3) { } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_4) { - auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - auto x= NDArrayFactory::create('f', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - auto y= NDArrayFactory::create('f', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x= NDArrayFactory::create('f', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + auto y= NDArrayFactory::create('f', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); auto exp = MmulHelper::mmul(&x, &y); @@ -650,10 +650,10 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_4) { } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_5) { - auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - auto x= NDArrayFactory::create('c', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - auto y= NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x= NDArrayFactory::create('c', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + auto y= NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); auto exp = MmulHelper::mmul(&x, &y); @@ -679,10 +679,10 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_5) { TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_6) { - auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - 
auto x= NDArrayFactory::create('f', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); - auto y= NDArrayFactory::create('f', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x= NDArrayFactory::create('f', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + auto y= NDArrayFactory::create('f', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); auto exp = MmulHelper::mmul(&x, &y); @@ -707,10 +707,10 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_6) { } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) { - auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - auto x= NDArrayFactory::create('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); - auto y= NDArrayFactory::create('c', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x= NDArrayFactory::create('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + auto y= NDArrayFactory::create('c', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); auto exp = MmulHelper::mmul(&x, &y); @@ -737,10 +737,10 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) { } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_1) { - auto a = NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b = NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - auto x = NDArrayFactory::create('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); - auto y = NDArrayFactory::create('c', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + auto a = NDArrayFactory::create('c', {1, 3}, {1.f, 1.f, 1.f}); + auto b = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); + auto x = NDArrayFactory::create('c', {2, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); + auto y = NDArrayFactory::create('c', {5, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); nd4j::ops::batched_gemm op; try { @@ -753,10 +753,10 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_1) { } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_2) { - auto a = NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b = NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - auto x = NDArrayFactory::create('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); - auto y = NDArrayFactory::create('c', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + auto a = NDArrayFactory::create('c', {1, 3}, {1.f, 1.f, 1.f}); + auto b = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); + auto x = NDArrayFactory::create('c', {2, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); + auto y = NDArrayFactory::create('c', {5, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); auto z = NDArrayFactory::create('c', {2, 3}); @@ -770,9 +770,9 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_2) { } TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_1) { - auto x= NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12}); - auto y= NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12}); - auto exp= NDArrayFactory::create('f', {4, 4}, {38.0, 44.0, 50.0, 56.0, 83.0, 98.0, 113.0, 128.0, 128.0, 152.0, 176.0, 200.0, 173.0, 206.0, 239.0, 272.0}); + auto x= NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12}); + auto y= NDArrayFactory::create('c', {4, 
3}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12}); + auto exp= NDArrayFactory::create('f', {4, 4}, {38.0, 44.0, 50.0, 56.0, 83.0, 98.0, 113.0, 128.0, 128.0, 152.0, 176.0, 200.0, 173.0, 206.0, 239.0, 272.0}); nd4j::ops::matmul op; auto result = op.execute({&x, &y}, {}, {1, 1}); @@ -789,9 +789,9 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_1) { TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_2) { - auto x= NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12}); - auto y= NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12}); - auto exp= NDArrayFactory::create('f', {3, 3}, {70.0, 158.0, 246.0, 80.0, 184.0, 288.0, 90.0, 210.0, 330.0}); + auto x= NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12}); + auto y= NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12}); + auto exp= NDArrayFactory::create('f', {3, 3}, {70.0, 158.0, 246.0, 80.0, 184.0, 288.0, 90.0, 210.0, 330.0}); nd4j::ops::matmul op; auto result = op.execute({&x, &y}, {}, {0, 0}); @@ -808,9 +808,9 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_2) { TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_3) { - auto x= NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); - auto y= NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); + auto x= NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); + auto y= NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); nd4j::ops::matmul op; auto result = op.execute({&x, &y}, {}, {1, 0}); @@ -827,9 +827,9 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_3) { } TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_4) { - auto x= NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); - auto y= NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); - auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); + auto x= NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); + auto y= NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); nd4j::ops::matmul op; auto result = op.execute({&x, &y}, {}, {0, 1}); @@ -846,9 +846,9 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_4) { } TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_5) { - auto x= NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); - auto y= NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); + auto x= NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); + auto y= NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); nd4j::ops::matmul op; auto result = op.execute({&x, &y}, {}, {}); @@ -865,9 +865,9 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_5) { } TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_6) { - auto x= NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); - auto y= NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto exp= NDArrayFactory::create('f', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16}); + auto x= NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto y= NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp= NDArrayFactory::create('f', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16}); 
nd4j::ops::matmul op; auto result = op.execute({&x, &y}, {}, {}); @@ -900,9 +900,9 @@ TEST_F(DeclarableOpsTests3, Test_AvgPool_1) { } TEST_F(DeclarableOpsTests3, Test_ReverseDivide_1) { - auto x= NDArrayFactory::create('c', {1, 3}, {2, 2, 2}); - auto y= NDArrayFactory::create('c', {1, 3}, {4, 6, 8}); - auto exp= NDArrayFactory::create('c', {1, 3}, {2, 3, 4}); + auto x= NDArrayFactory::create('c', {1, 3}, {2, 2, 2}); + auto y= NDArrayFactory::create('c', {1, 3}, {4, 6, 8}); + auto exp= NDArrayFactory::create('c', {1, 3}, {2, 3, 4}); nd4j::ops::reversedivide op; auto result = op.execute({&x, &y}, {}, {}); @@ -932,8 +932,8 @@ TEST_F(DeclarableOpsTests3, sruCell_test1) { w.assign(0.5); b.assign(0.7); - auto expHt= NDArrayFactory::create('c', {batchSize, inSize}, {0.96674103,0.96674103,0.96674103,0.96674103,0.96674103,0.96674103,0.96674103,0.96674103,0.96674103,0.96674103}); - auto expCt= NDArrayFactory::create('c', {batchSize, inSize}, {2.01958286,2.01958286,2.01958286,2.01958286,2.01958286, 2.01958286,2.01958286,2.01958286,2.01958286,2.01958286}); + auto expHt= NDArrayFactory::create('c', {batchSize, inSize}, {0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f}); + auto expCt= NDArrayFactory::create('c', {batchSize, inSize}, {2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f}); nd4j::ops::sruCell op; auto results = op.execute({&xt, &ct_1, &w, &b}, {}, {}); @@ -968,8 +968,8 @@ TEST_F(DeclarableOpsTests3, sruCell_test2) { w.assign(0.5); b.assign(-1.); - auto expHt= NDArrayFactory::create('c', {batchSize, inSize}, {0.97542038,0.97542038,0.97542038,0.97542038,0.97542038,0.97542038,0.97542038,0.97542038,0.97542038,0.97542038}); - auto expCt= NDArrayFactory::create('c', {batchSize, inSize}, {2.09121276,2.09121276,2.09121276,2.09121276,2.09121276,2.09121276,2.09121276,2.09121276,2.09121276,2.09121276}); + auto expHt= NDArrayFactory::create('c', {batchSize, inSize}, {0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f}); + auto expCt= NDArrayFactory::create('c', {batchSize, inSize}, {2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f}); nd4j::ops::sruCell op; auto results = op.execute({&xt, &ct_1, &w, &b}, {}, {}); @@ -1003,8 +1003,8 @@ TEST_F(DeclarableOpsTests3, sruCell_test3) { w.assign(0.5); b.assign(-1.); - auto expHt= NDArrayFactory::create('c', {batchSize, inSize}, {0.76159416,0.76159416,0.76159416,0.76159416,0.76159416,0.76159416,0.76159416,0.76159416,0.76159416,0.76159416}); - auto expCt= NDArrayFactory::create('c', {batchSize, inSize}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.}); + auto expHt= NDArrayFactory::create('c', {batchSize, inSize}, {0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f}); + auto expCt= NDArrayFactory::create('c', {batchSize, inSize}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); nd4j::ops::sruCell op; auto results = op.execute({&xt, &ct_1, &w, &b}, {}, {}); @@ -1043,7 +1043,7 @@ TEST_F(DeclarableOpsTests3, gruCell_test1) { bru.assign(0.7); bc.assign(0.7); - auto expHt = NDArrayFactory::create('c', {batchSize, numUnits}, {1.99993872,1.99993872,1.99993872,1.99993872,1.99993872,1.99993872,1.99993872,1.99993872}); + auto expHt = NDArrayFactory::create('c', {batchSize, numUnits}, {1.99993872f, 
1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f}); nd4j::ops::gruCell op; auto results = op.execute({&xt, &ht_1, &Wru, &Wc, &bru, &bc}, {}, {}); @@ -1079,7 +1079,7 @@ TEST_F(DeclarableOpsTests3, gruCell_test2) { bru.assign(-10); bc.assign(-10); - auto expHt= NDArrayFactory::create('c', {batchSize, numUnits}, {0.00669224,0.00669224,0.00669224,0.00669224,0.00669224,0.00669224,0.00669224,0.00669224}); + auto expHt= NDArrayFactory::create('c', {batchSize, numUnits}, {0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f}); nd4j::ops::gruCell op; auto results = op.execute({&xt, &ht_1, &Wru, &Wc, &bru, &bc}, {}, {}); @@ -1115,7 +1115,7 @@ TEST_F(DeclarableOpsTests3, gruCell_test3) { bru.assign(1); bc.assign(1); - auto expHt= NDArrayFactory::create('c', {batchSize, numUnits}, {0.1149149,0.1149149,0.1149149,0.1149149,0.1149149,0.1149149,0.1149149,0.1149149}); + auto expHt= NDArrayFactory::create('c', {batchSize, numUnits}, {0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f}); nd4j::ops::gruCell op; auto results = op.execute({&xt, &ht_1, &Wru, &Wc, &bru, &bc}, {}, {}); @@ -1133,8 +1133,8 @@ TEST_F(DeclarableOpsTests3, gruCell_test3) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, invertPermutation_test1) { - auto input= NDArrayFactory::create('c', {1, 8}, {5,2,7,4,6,3,1,0}); - auto expected= NDArrayFactory::create('c', {1, 8}, {7, 6, 1, 5, 3, 0, 4, 2}); + auto input= NDArrayFactory::create('c', {1, 8}, {5,2,7,4,6,3,1,0}); + auto expected= NDArrayFactory::create('c', {1, 8}, {7, 6, 1, 5, 3, 0, 4, 2}); nd4j::ops::invert_permutation op; auto results = op.execute({&input}, {}, {}); @@ -1152,8 +1152,8 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test1) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, invertPermutation_test2) { - auto input= NDArrayFactory::create('c', {1, 8}, {5,2,7,4,6,3,1,0}); - auto expected= NDArrayFactory::create('c', {1, 8}, {7, 6, 1, 5, 3, 0, 4, 2}); + auto input= NDArrayFactory::create('c', {1, 8}, {5,2,7,4,6,3,1,0}); + auto expected= NDArrayFactory::create('c', {1, 8}, {7, 6, 1, 5, 3, 0, 4, 2}); nd4j::ops::invert_permutation op; auto results = op.execute({&input}, {}, {}); @@ -1171,8 +1171,8 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test2) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, invertPermutation_test3) { - auto input= NDArrayFactory::create('c', {1, 8}, {1,2,0,4,6,3,5,7}); - auto expected= NDArrayFactory::create('c', {1, 8}, {2, 0, 1, 5, 3, 6, 4, 7}); + auto input= NDArrayFactory::create('c', {1, 8}, {1,2,0,4,6,3,5,7}); + auto expected= NDArrayFactory::create('c', {1, 8}, {2, 0, 1, 5, 3, 6, 4, 7}); nd4j::ops::invert_permutation op; auto results = op.execute({&input}, {}, {}); @@ -1190,10 +1190,10 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test3) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diag_test1) { - auto input= NDArrayFactory::create('c', {3, 2}); + auto input= NDArrayFactory::create('c', {3, 2}); input.linspace(1); - auto expected= NDArrayFactory::create('c', {3,2,3,2}, {1,0,0,0,0,0, 0,2,0,0,0,0, 0,0,3,0,0,0, 0,0,0,4,0,0, 0,0,0,0,5,0, 0,0,0,0,0,6}); + auto expected= NDArrayFactory::create('c', {3,2,3,2}, {1,0,0,0,0,0, 0,2,0,0,0,0, 0,0,3,0,0,0, 0,0,0,4,0,0, 0,0,0,0,5,0, 0,0,0,0,0,6}); nd4j::ops::diag op; auto 
results = op.execute({&input}, {}, {}); @@ -1211,10 +1211,10 @@ TEST_F(DeclarableOpsTests3, diag_test1) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diag_test2) { - auto input= NDArrayFactory::create('c', {2, 3}); + auto input= NDArrayFactory::create('c', {2, 3}); input.linspace(1); - auto expected= NDArrayFactory::create('c', {2,3,2,3}, {1,0,0,0,0,0, 0,2,0,0,0,0, 0,0,3,0,0,0, 0,0,0,4,0,0, 0,0,0,0,5,0, 0,0,0,0,0,6}); + auto expected= NDArrayFactory::create('c', {2,3,2,3}, {1,0,0,0,0,0, 0,2,0,0,0,0, 0,0,3,0,0,0, 0,0,0,4,0,0, 0,0,0,0,5,0, 0,0,0,0,0,6}); nd4j::ops::diag op; auto results = op.execute({&input}, {}, {}); @@ -1234,8 +1234,8 @@ TEST_F(DeclarableOpsTests3, diag_test2) { TEST_F(DeclarableOpsTests3, diag_test_vector) { - auto input = NDArrayFactory::linspace(1,4,4); - auto expected= NDArrayFactory::create('c', {4,4}, {1,0,0,0, 0,2,0,0, 0,0,3,0,0,0,0,4}); + auto input = NDArrayFactory::linspace(1,4,4); + auto expected= NDArrayFactory::create('c', {4,4}, {1,0,0,0, 0,2,0,0, 0,0,3,0,0,0,0,4}); nd4j::ops::diag op; auto results = op.execute({input}, {}, {}); @@ -1257,9 +1257,9 @@ TEST_F(DeclarableOpsTests3, diag_test_vector) { TEST_F(DeclarableOpsTests3, diag_test_col_vector) { - auto input = NDArrayFactory::linspace(1,4,4); + auto input = NDArrayFactory::linspace(1,4,4); input->reshapei({4,1}); - auto expected= NDArrayFactory::create('c', {4,4}, {1,0,0,0, 0,2,0,0, 0,0,3,0,0,0,0,4}); + auto expected= NDArrayFactory::create('c', {4,4}, {1,0,0,0, 0,2,0,0, 0,0,3,0,0,0,0,4}); nd4j::ops::diag op; auto results = op.execute({input}, {}, {}); @@ -1277,10 +1277,10 @@ TEST_F(DeclarableOpsTests3, diag_test_col_vector) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diag_test3) { - auto input= NDArrayFactory::create('c', {1, 3}); + auto input= NDArrayFactory::create('c', {1, 3}); input.linspace(1); - auto expected= NDArrayFactory::create('c', {3,3}, {1,0,0, 0,2,0, 0,0,3}); + auto expected= NDArrayFactory::create('c', {3,3}, {1,0,0, 0,2,0, 0,0,3}); nd4j::ops::diag op; auto results = op.execute({&input}, {}, {}); @@ -1298,10 +1298,10 @@ TEST_F(DeclarableOpsTests3, diag_test3) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diag_test4) { - auto input= NDArrayFactory::create('c', {3, 1}); + auto input= NDArrayFactory::create('c', {3, 1}); input.linspace(1); - auto expected= NDArrayFactory::create('c', {3,3}, {1,0,0, 0,2,0, 0,0,3}); + auto expected= NDArrayFactory::create('c', {3,3}, {1,0,0, 0,2,0, 0,0,3}); nd4j::ops::diag op; auto results = op.execute({&input}, {}, {}); @@ -1319,10 +1319,10 @@ TEST_F(DeclarableOpsTests3, diag_test4) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diag_test5) { - auto input= NDArrayFactory::create('c', {1, 1}); + auto input= NDArrayFactory::create('c', {1, 1}); input.linspace(2); - auto expected= NDArrayFactory::create('c', {1,1}, {2}); + auto expected= NDArrayFactory::create('c', {1,1}, {2}); nd4j::ops::diag op; auto results = op.execute({&input}, {}, {}); @@ -1340,10 +1340,10 @@ TEST_F(DeclarableOpsTests3, diag_test5) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diag_test6) { - auto input= NDArrayFactory::create('c', {2,2,2}); + auto input= NDArrayFactory::create('c', {2,2,2}); input.linspace(1); - auto expected= NDArrayFactory::create('c', {2,2,2,2,2,2}, {1,0,0,0, 0,0,0,0, 0,2,0,0, 0,0,0,0, 0,0,3,0, 0,0,0,0, 0,0,0,4, 
0,0,0,0, 0,0,0,0, 5,0,0,0, 0,0,0,0, 0,6,0,0, 0,0,0,0, 0,0,7,0, 0,0,0,0, 0,0,0,8}); + auto expected= NDArrayFactory::create('c', {2,2,2,2,2,2}, {1,0,0,0, 0,0,0,0, 0,2,0,0, 0,0,0,0, 0,0,3,0, 0,0,0,0, 0,0,0,4, 0,0,0,0, 0,0,0,0, 5,0,0,0, 0,0,0,0, 0,6,0,0, 0,0,0,0, 0,0,7,0, 0,0,0,0, 0,0,0,8}); nd4j::ops::diag op; auto results = op.execute({&input}, {}, {}); @@ -1361,12 +1361,12 @@ TEST_F(DeclarableOpsTests3, diag_test6) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, matrixSetDiag_test1) { - auto input= NDArrayFactory::create('c', {4,3,2}); - auto diagonal= NDArrayFactory::create('c', {4,2}); + auto input= NDArrayFactory::create('c', {4,3,2}); + auto diagonal= NDArrayFactory::create('c', {4,2}); input.assign(0.); diagonal.assign(1.); - auto expected= NDArrayFactory::create('c', {4,3,2}, {1,0,0,1,0,0, 1,0,0,1,0,0, 1,0,0,1,0,0, 1,0,0,1,0,0}); + auto expected= NDArrayFactory::create('c', {4,3,2}, {1,0,0,1,0,0, 1,0,0,1,0,0, 1,0,0,1,0,0, 1,0,0,1,0,0}); nd4j::ops::matrix_set_diag op; auto results = op.execute({&input, &diagonal}, {}, {}); @@ -1389,7 +1389,7 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test2) { input.assign(0.); diagonal.assign(1.); - auto expected= NDArrayFactory::create('c', {1,1,2}, {1,0}); + auto expected= NDArrayFactory::create('c', {1,1,2}, {1.f, 0.f}); nd4j::ops::matrix_set_diag op; auto results = op.execute({&input, &diagonal}, {}, {}); @@ -1407,12 +1407,12 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test2) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, matrixSetDiag_test3) { - auto input= NDArrayFactory::create('c', {2,1,4}); - auto diagonal= NDArrayFactory::create('c', {2,1}); + auto input= NDArrayFactory::create('c', {2,1,4}); + auto diagonal= NDArrayFactory::create('c', {2,1}); input.assign(0.); diagonal.assign(1.); - auto expected= NDArrayFactory::create('c', {2,1,4}, {1,0,0,0,1,0,0,0}); + auto expected= NDArrayFactory::create('c', {2,1,4}, {1,0,0,0,1,0,0,0}); nd4j::ops::matrix_set_diag op; auto results = op.execute({&input, &diagonal}, {}, {}); @@ -1430,12 +1430,12 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test3) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, matrixSetDiag_test4) { - auto input= NDArrayFactory::create('c', {2,1,4,1}); - auto diagonal= NDArrayFactory::create('c', {2,1,1}); + auto input= NDArrayFactory::create('c', {2,1,4,1}); + auto diagonal= NDArrayFactory::create('c', {2,1,1}); input.assign(0.); diagonal.assign(1.); - auto expected= NDArrayFactory::create('c', {2,1,4,1}, {1,0,0,0,1,0,0,0}); + auto expected= NDArrayFactory::create('c', {2,1,4,1}, {1,0,0,0,1,0,0,0}); nd4j::ops::matrix_set_diag op; auto results = op.execute({&input, &diagonal}, {}, {}); @@ -1453,10 +1453,10 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test4) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diagPart_test1) { - auto input= NDArrayFactory::create('c', {2,2}); + auto input= NDArrayFactory::create('c', {2,2}); input.linspace(1); - auto expected= NDArrayFactory::create('c', {2}, {1,4}); + auto expected= NDArrayFactory::create('c', {2}, {1,4}); nd4j::ops::diag_part op; auto results = op.execute({&input}, {}, {}); @@ -1475,10 +1475,10 @@ TEST_F(DeclarableOpsTests3, diagPart_test1) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diagPart_test2) { - auto input= NDArrayFactory::create('c', {2,2,2,2}); + auto input= 
NDArrayFactory::create('c', {2,2,2,2}); input.linspace(1); - auto expected= NDArrayFactory::create('c', {2,2}, {1,6,11,16}); + auto expected= NDArrayFactory::create('c', {2,2}, {1,6,11,16}); nd4j::ops::diag_part op; auto results = op.execute({&input}, {}, {}); @@ -1496,10 +1496,10 @@ TEST_F(DeclarableOpsTests3, diagPart_test2) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diagPart_test3) { - auto input= NDArrayFactory::create('c', {2,2,2,2,2,2}); + auto input= NDArrayFactory::create('c', {2,2,2,2,2,2}); input.linspace(1); - auto expected= NDArrayFactory::create('c', {2,2,2}, {1,10,19,28,37,46,55,64}); + auto expected= NDArrayFactory::create('c', {2,2,2}, {1,10,19,28,37,46,55,64}); nd4j::ops::diag_part op; auto results = op.execute({&input}, {}, {}); @@ -1525,7 +1525,7 @@ TEST_F(DeclarableOpsTests3, betainc_test1) { b.linspace((float16)0.1, (float16)0.1); x.assign(0.1); - auto expected = NDArrayFactory::create('c', {3,3}, {0.40638509,0.33668978,0.28271242,0.23973916,0.20483276,0.17604725,0.15203027,0.13180567,0.114647}); + auto expected = NDArrayFactory::create('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f}); nd4j::ops::betainc op; auto results = op.execute({&a, &b, &x}, {}, {}); @@ -1551,7 +1551,7 @@ TEST_F(DeclarableOpsTests3, betainc_test2) { b.linspace(0.1, 0.1); x.assign(0.1); - auto expected= NDArrayFactory::create('c', {3,3}, {0.40638509,0.33668978,0.28271242,0.23973916,0.20483276,0.17604725,0.15203027,0.13180567,0.114647}); + auto expected= NDArrayFactory::create('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f}); nd4j::ops::betainc op; auto results = op.execute({&a, &b, &x}, {}, {}); @@ -1577,7 +1577,7 @@ TEST_F(DeclarableOpsTests3, betainc_test3) { b.linspace(0.1, 0.1); x.assign(0.1); - auto expected= NDArrayFactory::create('c', {3,3}, {0.40638509,0.33668978,0.28271242,0.23973916,0.20483276,0.17604725,0.15203027,0.13180567,0.114647}); + auto expected= NDArrayFactory::create('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f}); nd4j::ops::betainc op; auto results = op.execute({&a, &b, &x}, {}, {}); @@ -1603,7 +1603,7 @@ TEST_F(DeclarableOpsTests3, betainc_test4) { b.linspace(1); x.assign(0.1); - auto expected= NDArrayFactory::create('c', {3,3}, {1.00000000e-01,2.80000000e-02,8.56000000e-03,2.72800000e-03,8.90920000e-04,2.95706080e-04,9.92854864e-05,3.36248880e-05,1.14644360e-05}); + auto expected= NDArrayFactory::create('c', {3,3}, {1.00000000e-01f, 2.80000000e-02f, 8.56000000e-03f, 2.72800000e-03f, 8.90920000e-04f, 2.95706080e-04f, 9.92854864e-05f, 3.36248880e-05f, 1.14644360e-05f}); nd4j::ops::betainc op; auto results = op.execute({&a, &b, &x}, {}, {}); @@ -1629,7 +1629,7 @@ TEST_F(DeclarableOpsTests3, betainc_test5) { b.linspace(3200.); x.assign(0.1); - auto expected= NDArrayFactory::create('c', {3,3}, {0.,0.,0.,0.,0.,0.,0.,0.,0.}); + auto expected= NDArrayFactory::create('c', {3,3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); nd4j::ops::betainc op; auto results = op.execute({&a, &b, &x}, {}, {}); @@ -1655,7 +1655,7 @@ TEST_F(DeclarableOpsTests3, betainc_test6) { b.linspace(10.); x.assign(0.1); - auto expected= NDArrayFactory::create('c', {3,3}, {3.92988233e-06,1.35306497e-06,4.67576826e-07,1.62083416e-07,5.63356971e-08,1.96261318e-08,6.85120307e-09,2.39594668e-09,8.39227685e-10}); + auto 
expected= NDArrayFactory::create('c', {3,3}, {3.92988233e-06f, 1.35306497e-06f, 4.67576826e-07f, 1.62083416e-07f, 5.63356971e-08f, 1.96261318e-08f, 6.85120307e-09f, 2.39594668e-09f, 8.39227685e-10f}); nd4j::ops::betainc op; auto results = op.execute({&a, &b, &x}, {}, {}); @@ -1681,7 +1681,7 @@ TEST_F(DeclarableOpsTests3, betainc_test7) { b.linspace(10.); x.assign(0.9); - auto expected= NDArrayFactory::create('c', {3,3}, {0.99999607,0.99999865,0.99999953,0.99999984,0.99999994,0.99999998,0.99999999,1.,1.}); + auto expected= NDArrayFactory::create('c', {3,3}, {0.99999607f, 0.99999865f, 0.99999953f, 0.99999984f, 0.99999994f, 0.99999998f, 0.99999999f, 1.f, 1.f}); nd4j::ops::betainc op; auto results = op.execute({&a, &b, &x}, {}, {}); @@ -1707,7 +1707,7 @@ TEST_F(DeclarableOpsTests3, betainc_test8) { b.linspace(10.); x.assign(1.); - auto expected= NDArrayFactory::create('c', {3,3}, {1.,1.,1.,1.,1.,1.,1.,1.,1.}); + auto expected= NDArrayFactory::create('c', {3,3}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); nd4j::ops::betainc op; auto results = op.execute({&a, &b, &x}, {}, {}); @@ -1733,7 +1733,7 @@ TEST_F(DeclarableOpsTests3, betainc_test9) { b.linspace(10.); x.assign(0.); - auto expected= NDArrayFactory::create('c', {3,3}, {0.,0.,0.,0.,0.,0.,0.,0.,0.}); + auto expected= NDArrayFactory::create('c', {3,3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); nd4j::ops::betainc op; auto results = op.execute({&a, &b, &x}, {}, {}); @@ -1759,7 +1759,7 @@ TEST_F(DeclarableOpsTests3, betainc_test10) { b.linspace(10.); x.assign(0.5); - auto expected= NDArrayFactory::create('c', {3,3}, {0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5}); + auto expected= NDArrayFactory::create('c', {3,3}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}); nd4j::ops::betainc op; auto results = op.execute({&a, &b, &x}, {}, {}); @@ -1783,7 +1783,7 @@ TEST_F(DeclarableOpsTests3, zeta_test1) { q.linspace(1.); x.assign(2.); - auto expected= NDArrayFactory::create('c', {3,3}, {1.64493407,0.64493407,0.39493407,0.28382296,0.22132296,0.18132296,0.15354518,0.13313701,0.11751201}); + auto expected= NDArrayFactory::create('c', {3,3}, {1.64493407f, 0.64493407f, 0.39493407f, 0.28382296f, 0.22132296f, 0.18132296f, 0.15354518f, 0.13313701f, 0.11751201f}); nd4j::ops::zeta op; auto results = op.execute({&x, &q}, {}, {}); @@ -1807,7 +1807,7 @@ TEST_F(DeclarableOpsTests3, zeta_test2) { q.linspace(10.); x.assign(2.); - auto expected= NDArrayFactory::create('c', {3,3}, {0.10516634,0.09516634,0.08690187,0.07995743,0.07404027,0.06893823,0.06449378,0.06058753,0.05712733}); + auto expected= NDArrayFactory::create('c', {3,3}, {0.10516634f, 0.09516634f, 0.08690187f, 0.07995743f, 0.07404027f, 0.06893823f, 0.06449378f, 0.06058753f, 0.05712733f}); nd4j::ops::zeta op; auto results = op.execute({&x, &q}, {}, {}); @@ -1831,7 +1831,7 @@ TEST_F(DeclarableOpsTests3, zeta_test3) { q.linspace(100.); x.assign(2.); - auto expected= NDArrayFactory::create('c', {3,3}, {0.01005017,0.00995017,0.00985214,0.00975602,0.00966176,0.0095693 ,0.0094786 ,0.0093896 ,0.00930226}); + auto expected= NDArrayFactory::create('c', {3,3}, {0.01005017f, 0.00995017f, 0.00985214f, 0.00975602f, 0.00966176f, 0.0095693f, 0.0094786f, 0.0093896f, 0.00930226f}); nd4j::ops::zeta op; auto results = op.execute({&x, &q}, {}, {}); @@ -1856,7 +1856,7 @@ TEST_F(DeclarableOpsTests3, zeta_test4) { q.linspace(100.); x.assign(2.); - auto expected= NDArrayFactory::create('c', {3,3}, {0.01005017,0.00995017,0.00985214,0.00975602,0.00966176,0.0095693 ,0.0094786 ,0.0093896 ,0.00930226}); + auto expected=
NDArrayFactory::create('c', {3,3}, {0.01005017f, 0.00995017f, 0.00985214f, 0.00975602f, 0.00966176f, 0.0095693f, 0.0094786f, 0.0093896f, 0.00930226f}); nd4j::ops::zeta op; auto results = op.execute({&x, &q}, {}, {}); @@ -1880,7 +1880,7 @@ TEST_F(DeclarableOpsTests3, zeta_test5) { q.linspace(1.); x.assign(1.1); - auto expected= NDArrayFactory::create('c', {3,3}, {10.58444846,9.58444846,9.11793197, 8.81927915,8.60164151,8.43137352, 8.29204706,8.17445116,8.07291961}); + auto expected= NDArrayFactory::create('c', {3,3}, {10.58444846f, 9.58444846f, 9.11793197f, 8.81927915f, 8.60164151f, 8.43137352f, 8.29204706f, 8.17445116f, 8.07291961f}); nd4j::ops::zeta op; auto results = op.execute({&x, &q}, {}, {}); @@ -1904,7 +1904,7 @@ TEST_F(DeclarableOpsTests3, zeta_test6) { q.linspace(1.); x.assign(1.01); - auto expected= NDArrayFactory::create('c', {3,3}, {100.57794334,99.57794334,99.08139709, 98.75170576,98.50514758,98.30834069, 98.1446337 ,98.00452955,97.88210202}); + auto expected= NDArrayFactory::create('c', {3,3}, {100.57794334f, 99.57794334f, 99.08139709f, 98.75170576f, 98.50514758f, 98.30834069f, 98.1446337f, 98.00452955f, 97.88210202f}); nd4j::ops::zeta op; auto results = op.execute({&x, &q}, {}, {}); @@ -1928,7 +1928,7 @@ TEST_F(DeclarableOpsTests3, zeta_test7) { q.linspace(1.); x.assign(10.); - auto expected= NDArrayFactory::create('c', {3,3}, {1.00099458e+00,9.94575128e-04,1.80126278e-05,1.07754001e-06,1.23865693e-07,2.14656932e-08,4.92752156e-09,1.38738839e-09,4.56065812e-10}); + auto expected= NDArrayFactory::create('c', {3,3}, {1.00099458e+00f, 9.94575128e-04f, 1.80126278e-05f, 1.07754001e-06f, 1.23865693e-07f, 2.14656932e-08f, 4.92752156e-09f, 1.38738839e-09f, 4.56065812e-10f}); nd4j::ops::zeta op; auto results = op.execute({&x, &q}, {}, {}); @@ -1946,13 +1946,13 @@ TEST_F(DeclarableOpsTests3, zeta_test7) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test8) { - auto x= NDArrayFactory::create('c', {3,4}, {1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,1.01,1.11,1.12}); - auto q= NDArrayFactory::create('c', {3,4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); + auto x= NDArrayFactory::create('c', {3,4}, {1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,1.01,1.11,1.12}); + auto q= NDArrayFactory::create('c', {3,4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); //q.linspace(1.); //x.assign(10.); - auto expected= NDArrayFactory::create('c', {3,4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); + auto expected= NDArrayFactory::create('c', {3,4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); nd4j::ops::zeta op; auto results = op.execute({&x, &q}, {}, {}); @@ -1970,14 +1970,14 @@ TEST_F(DeclarableOpsTests3, zeta_test8) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test9) { - auto x= NDArrayFactory::create('c', {3,4}, {1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,1.01,1.11,1.12}); - auto q= NDArrayFactory::create('c', {3,4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); - auto z= NDArrayFactory::create('c', {3,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.}); + auto x= NDArrayFactory::create('c', {3,4}, {1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,1.01,1.11,1.12}); + auto q= NDArrayFactory::create('c', {3,4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); + auto z= 
NDArrayFactory::create('c', {3,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.}); //q.linspace(1.); //x.assign(10.); - auto expected= NDArrayFactory::create('c', {3,4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); + auto expected= NDArrayFactory::create('c', {3,4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); nd4j::ops::zeta op; auto results = op.execute({&x, &q}, {&z}, {}, {}, {}); @@ -1995,14 +1995,14 @@ TEST_F(DeclarableOpsTests3, zeta_test9) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test10) { - auto x= NDArrayFactory::create('c', {3,4}, {1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,1.01,1.11,1.12}); - auto q= NDArrayFactory::create('c', {3,4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); - auto z= NDArrayFactory::create('c', {3,4}); + auto x= NDArrayFactory::create('c', {3,4}, {1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,1.01,1.11,1.12}); + auto q= NDArrayFactory::create('c', {3,4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); + auto z= NDArrayFactory::create('c', {3,4}); //q.linspace(1.); //x.assign(10.); - auto expected= NDArrayFactory::create('c', {3,4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); + auto expected= NDArrayFactory::create('c', {3,4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); nd4j::ops::zeta op; auto results = op.execute({&x, &q}, {&z}, {}, {}, {}); @@ -2034,13 +2034,13 @@ TEST_F(DeclarableOpsTests3, Test_SplitV_Validation_1) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, polygamma_test1) { - auto n= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); + auto n= NDArrayFactory::create('c', {3,3}); + auto x= NDArrayFactory::create('c', {3,3}); // ASSERT_FALSE(true); n.linspace(1.); x.assign(0.5); - auto expected= NDArrayFactory::create('c', {3,3}, {4.934802, -16.828796, 97.409088, -771.474243, 7691.113770, -92203.460938, 1290440.250000, -20644900.000000, 3.71595e+08}); + auto expected= NDArrayFactory::create('c', {3,3}, {4.934802, -16.828796, 97.409088, -771.474243, 7691.113770, -92203.460938, 1290440.250000, -20644900.000000, 3.71595e+08}); nd4j::ops::polygamma op; auto results = op.execute({&n, &x}, {}, {}); @@ -2059,13 +2059,13 @@ TEST_F(DeclarableOpsTests3, polygamma_test1) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, polygamma_test2) { - auto n= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); + auto n= NDArrayFactory::create('c', {3,3}); + auto x= NDArrayFactory::create('c', {3,3}); n.linspace(10.); x.linspace(0.5); - auto expected= NDArrayFactory::create('c', {3,3}, {-7.43182451e+09, 3.08334759e+05,-3.25669798e+03, 1.55186197e+02,-1.46220433e+01, 2.00905201e+00,-3.48791235e-01, 7.08016273e-02,-1.60476052e-02}); + auto expected= NDArrayFactory::create('c', {3,3}, {-7.43182451e+09, 3.08334759e+05,-3.25669798e+03, 1.55186197e+02,-1.46220433e+01, 2.00905201e+00,-3.48791235e-01, 7.08016273e-02,-1.60476052e-02}); //ASSERT_FALSE(true); @@ -2085,13 +2085,13 @@ TEST_F(DeclarableOpsTests3, polygamma_test2) { /////////////////////////////////////////////////////////////////// 
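// Editorial sketch (not part of the patch): zeta_test9 and zeta_test10 above exercise the
// in-place overload of DeclarableOp::execute -- (inputs, outputs, tArgs, iArgs, bArgs) --
// which writes into a caller-allocated output array instead of returning a ResultSet to
// delete. A minimal illustration, with hypothetical z* names, under that assumption:
auto zX   = NDArrayFactory::create<double>('c', {3, 4});
auto zQ   = NDArrayFactory::create<double>('c', {3, 4});
auto zOut = NDArrayFactory::create<double>('c', {3, 4});     // preallocated output buffer
zX.assign(2.);                                               // zeta exponent s > 1
zQ.linspace(1.);                                             // q > 0
nd4j::ops::zeta zOp;
auto zStatus = zOp.execute({&zX, &zQ}, {&zOut}, {}, {}, {}); // result lands in zOut; nothing to delete
// ... a status check and a comparison of zOut against an expected array would go here ...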
TEST_F(DeclarableOpsTests3, polygamma_test3) { - auto n= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); + auto n= NDArrayFactory::create('c', {3,3}); + auto x= NDArrayFactory::create('c', {3,3}); n.linspace(1.); x.linspace(10.); - auto expected= NDArrayFactory::create('c', {3,3}, {1.05166336e-01,-9.04983497e-03, 1.31009323e-03,-2.44459433e-04, 5.31593880e-05,-1.28049888e-05, 3.31755364e-06,-9.07408791e-07, 2.58758130e-07}); + auto expected= NDArrayFactory::create('c', {3,3}, {1.05166336e-01,-9.04983497e-03, 1.31009323e-03,-2.44459433e-04, 5.31593880e-05,-1.28049888e-05, 3.31755364e-06,-9.07408791e-07, 2.58758130e-07}); //ASSERT_FALSE(true); @@ -2111,10 +2111,10 @@ TEST_F(DeclarableOpsTests3, polygamma_test3) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test1) { - auto x= NDArrayFactory::create('c', {6,6}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16}); - auto expS= NDArrayFactory::create('c', {6}, {54.12775, 38.79293, 25.89287, 9.82168, 6.07227, 2.91827}); - auto expU= NDArrayFactory::create('c', {6,6}, {0.14692,-0.11132,-0.69568, 0.59282,-0.14881, 0.32935,-0.38751, 0.60378,-0.04927,-0.01397,-0.69456,-0.01581, 0.19293,-0.12795,-0.18682,-0.69065,-0.20597, 0.62617, 0.66806, 0.4314 ,-0.33849,-0.22166, 0.04099,-0.44967, 0.11121,-0.64065,-0.02138,-0.07378,-0.60568,-0.45216,-0.5765 ,-0.1007 ,-0.60305,-0.34175, 0.29068,-0.3042}); - auto expV= NDArrayFactory::create('c', {6,6}, {-0.24577,-0.24512, 0.00401,-0.04585,-0.62058, 0.70162, 0.27937, 0.75961, 0.43885,-0.06857,-0.3839 , 0.01669,-0.35944,-0.09629, 0.44593, 0.78602,-0.09103,-0.19125, 0.53973, 0.07613,-0.10721, 0.49559, 0.35687, 0.56431,-0.6226 , 0.39742, 0.12785,-0.15716, 0.52372, 0.37297, 0.23113,-0.43578, 0.76204,-0.32414, 0.23996, 0.11543}); + auto x= NDArrayFactory::create('c', {6,6}, {0. ,-9. 
,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16}); + auto expS= NDArrayFactory::create('c', {6}, {54.12775, 38.79293, 25.89287, 9.82168, 6.07227, 2.91827}); + auto expU= NDArrayFactory::create('c', {6,6}, {0.14692,-0.11132,-0.69568, 0.59282,-0.14881, 0.32935,-0.38751, 0.60378,-0.04927,-0.01397,-0.69456,-0.01581, 0.19293,-0.12795,-0.18682,-0.69065,-0.20597, 0.62617, 0.66806, 0.4314 ,-0.33849,-0.22166, 0.04099,-0.44967, 0.11121,-0.64065,-0.02138,-0.07378,-0.60568,-0.45216,-0.5765 ,-0.1007 ,-0.60305,-0.34175, 0.29068,-0.3042}); + auto expV= NDArrayFactory::create('c', {6,6}, {-0.24577,-0.24512, 0.00401,-0.04585,-0.62058, 0.70162, 0.27937, 0.75961, 0.43885,-0.06857,-0.3839 , 0.01669,-0.35944,-0.09629, 0.44593, 0.78602,-0.09103,-0.19125, 0.53973, 0.07613,-0.10721, 0.49559, 0.35687, 0.56431,-0.6226 , 0.39742, 0.12785,-0.15716, 0.52372, 0.37297, 0.23113,-0.43578, 0.76204,-0.32414, 0.23996, 0.11543}); nd4j::ops::svd op; auto results = op.execute({&x}, {}, {1, 1, 16}); @@ -2137,9 +2137,9 @@ TEST_F(DeclarableOpsTests3, svd_test1) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t(i)), nd4j::math::nd4j_abs(u->t(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t(i)), nd4j::math::nd4j_abs(v->t(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5); } delete results; @@ -2148,10 +2148,10 @@ TEST_F(DeclarableOpsTests3, svd_test1) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test2) { - auto x = NDArrayFactory::create('c', {7,6}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2}); - auto expS= NDArrayFactory::create('c', {6}, {56.76573, 39.11776, 26.00713, 11.83606, 6.16578, 3.99672}); - auto expU= NDArrayFactory::create('c', {7,7}, {-0.13417,-0.12443, -0.68854, 0.5196 , 0.21706, 0.03974, 0.41683, 0.347 , 0.62666, -0.04964, -0.01912, 0.66932, 0.1457 , -0.12183,-0.17329,-0.14666, -0.19639, -0.55355, 0.0614 , 0.75729, 0.1619 ,-0.64703, 0.37056, -0.37398, -0.32922, -0.0186 , -0.35656, -0.26134,-0.08027,-0.64405, -0.0127 , -0.06934, 0.59287, -0.14956, -0.44712, 0.55906,-0.06235, -0.58017, -0.12911, -0.359 , -0.00393, -0.44877, 0.30645,-0.11953, -0.09083, -0.54163, 0.14283, -0.50417, 0.56178}); - auto expV= NDArrayFactory::create('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187}); + auto x = NDArrayFactory::create('c', {7,6}, {0. ,-9. 
,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2}); + auto expS= NDArrayFactory::create('c', {6}, {56.76573, 39.11776, 26.00713, 11.83606, 6.16578, 3.99672}); + auto expU= NDArrayFactory::create('c', {7,7}, {-0.13417,-0.12443, -0.68854, 0.5196 , 0.21706, 0.03974, 0.41683, 0.347 , 0.62666, -0.04964, -0.01912, 0.66932, 0.1457 , -0.12183,-0.17329,-0.14666, -0.19639, -0.55355, 0.0614 , 0.75729, 0.1619 ,-0.64703, 0.37056, -0.37398, -0.32922, -0.0186 , -0.35656, -0.26134,-0.08027,-0.64405, -0.0127 , -0.06934, 0.59287, -0.14956, -0.44712, 0.55906,-0.06235, -0.58017, -0.12911, -0.359 , -0.00393, -0.44877, 0.30645,-0.11953, -0.09083, -0.54163, 0.14283, -0.50417, 0.56178}); + auto expV= NDArrayFactory::create('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187}); nd4j::ops::svd op; auto results = op.execute({&x}, {}, {1, 1, 16}); @@ -2174,9 +2174,9 @@ TEST_F(DeclarableOpsTests3, svd_test2) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t(i)), nd4j::math::nd4j_abs(u->t(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t(i)), nd4j::math::nd4j_abs(v->t(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5); } delete results; @@ -2185,10 +2185,10 @@ TEST_F(DeclarableOpsTests3, svd_test2) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test3) { - auto x= NDArrayFactory::create('c', {7,6}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2}); - auto expS= NDArrayFactory::create('c', {6}, {56.76573, 39.11776, 26.00713, 11.83606, 6.16578, 3.99672}); - auto expU= NDArrayFactory::create('c', {7,6}, {-0.13417, -0.12443, -0.68854, 0.5196 , 0.21706, 0.03974, 0.347 , 0.62666, -0.04964, -0.01912, 0.66932, 0.1457 ,-0.17329, -0.14666, -0.19639, -0.55355, 0.0614 , 0.75729,-0.64703, 0.37056, -0.37398, -0.32922, -0.0186 , -0.35656,-0.08027, -0.64405, -0.0127 , -0.06934, 0.59287, -0.14956, 0.55906, -0.06235, -0.58017, -0.12911, -0.359 , -0.00393, 0.30645, -0.11953, -0.09083, -0.54163, 0.14283, -0.50417}); - auto expV= NDArrayFactory::create('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187}); + auto x= NDArrayFactory::create('c', {7,6}, {0. ,-9. 
,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2}); + auto expS= NDArrayFactory::create('c', {6}, {56.76573, 39.11776, 26.00713, 11.83606, 6.16578, 3.99672}); + auto expU= NDArrayFactory::create('c', {7,6}, {-0.13417, -0.12443, -0.68854, 0.5196 , 0.21706, 0.03974, 0.347 , 0.62666, -0.04964, -0.01912, 0.66932, 0.1457 ,-0.17329, -0.14666, -0.19639, -0.55355, 0.0614 , 0.75729,-0.64703, 0.37056, -0.37398, -0.32922, -0.0186 , -0.35656,-0.08027, -0.64405, -0.0127 , -0.06934, 0.59287, -0.14956, 0.55906, -0.06235, -0.58017, -0.12911, -0.359 , -0.00393, 0.30645, -0.11953, -0.09083, -0.54163, 0.14283, -0.50417}); + auto expV= NDArrayFactory::create('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187}); nd4j::ops::svd op; auto results = op.execute({&x}, {}, {0, 1, 16}); @@ -2211,9 +2211,9 @@ TEST_F(DeclarableOpsTests3, svd_test3) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t(i)), nd4j::math::nd4j_abs(u->t(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t(i)), nd4j::math::nd4j_abs(v->t(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5); } delete results; @@ -2222,10 +2222,10 @@ TEST_F(DeclarableOpsTests3, svd_test3) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test4) { - auto x= NDArrayFactory::create('c', {6,7}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2}); - auto expS= NDArrayFactory::create('c', {6}, {53.11053, 39.09542, 28.1987, 17.7468, 11.61684, 5.36217}); - auto expU= NDArrayFactory::create('c', {6,6}, {-0.16541, 0.21276, 0.51284, 0.20472, 0.74797, 0.25102,-0.49879, 0.12076, 0.37629, -0.7211 , -0.24585, 0.12086,-0.36569,-0.70218, -0.08012, 0.21274, -0.07314, 0.56231,-0.44508, 0.4329 , 0.1356 , 0.60909, -0.47398, -0.02164, 0.61238,-0.05674, 0.59489, 0.06588, -0.3874 , 0.33685,-0.13044,-0.50644, 0.46552, 0.13236, -0.00474, -0.70161}); - auto expV= NDArrayFactory::create('c', {7,7}, {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, -0.16709, 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, -0.06862, 0.32179, 0.12812, -0.25812, 0.0691 , -0.12891, 0.26979, 0.84807,-0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.36692, 0.48118, 0.15876, -0.65132, -0.24602, 0.3963 , -0.16651, -0.27155,-0.31605, -0.46947, -0.50195, 0.0378 , -0.34937, -0.53062, 0.15069, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151 , 0.13065}); + auto x= NDArrayFactory::create('c', {6,7}, {0. ,-9. 
,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2}); + auto expS= NDArrayFactory::create('c', {6}, {53.11053, 39.09542, 28.1987, 17.7468, 11.61684, 5.36217}); + auto expU= NDArrayFactory::create('c', {6,6}, {-0.16541, 0.21276, 0.51284, 0.20472, 0.74797, 0.25102,-0.49879, 0.12076, 0.37629, -0.7211 , -0.24585, 0.12086,-0.36569,-0.70218, -0.08012, 0.21274, -0.07314, 0.56231,-0.44508, 0.4329 , 0.1356 , 0.60909, -0.47398, -0.02164, 0.61238,-0.05674, 0.59489, 0.06588, -0.3874 , 0.33685,-0.13044,-0.50644, 0.46552, 0.13236, -0.00474, -0.70161}); + auto expV= NDArrayFactory::create('c', {7,7}, {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, -0.16709, 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, -0.06862, 0.32179, 0.12812, -0.25812, 0.0691 , -0.12891, 0.26979, 0.84807,-0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.36692, 0.48118, 0.15876, -0.65132, -0.24602, 0.3963 , -0.16651, -0.27155,-0.31605, -0.46947, -0.50195, 0.0378 , -0.34937, -0.53062, 0.15069, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151 , 0.13065}); nd4j::ops::svd op; auto results = op.execute({&x}, {}, {1, 1, 16}); @@ -2248,9 +2248,9 @@ TEST_F(DeclarableOpsTests3, svd_test4) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t(i)), nd4j::math::nd4j_abs(u->t(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t(i)), nd4j::math::nd4j_abs(v->t(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5); } delete results; @@ -2259,10 +2259,10 @@ TEST_F(DeclarableOpsTests3, svd_test4) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test5) { - auto x= NDArrayFactory::create('c', {6,7}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2}); - auto expS= NDArrayFactory::create('c', {6}, {53.11053, 39.09542, 28.1987, 17.7468, 11.61684, 5.36217}); - auto expU= NDArrayFactory::create('c', {6,6}, {-0.16541, 0.21276, 0.51284, 0.20472, 0.74797, 0.25102,-0.49879, 0.12076, 0.37629, -0.7211 , -0.24585, 0.12086,-0.36569,-0.70218, -0.08012, 0.21274, -0.07314, 0.56231,-0.44508, 0.4329 , 0.1356 , 0.60909, -0.47398, -0.02164, 0.61238,-0.05674, 0.59489, 0.06588, -0.3874 , 0.33685,-0.13044,-0.50644, 0.46552, 0.13236, -0.00474, -0.70161}); - auto expV= NDArrayFactory::create('c', {7,6}, {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, 0.32179, 0.12812, -0.25812, 0.0691 , -0.12891, 0.26979,-0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.48118, 0.15876, -0.65132, -0.24602, 0.3963 , -0.16651,-0.31605, -0.46947, -0.50195, 0.0378 , -0.34937, -0.53062, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151}); + auto x= NDArrayFactory::create('c', {6,7}, {0. ,-9. 
,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2}); + auto expS= NDArrayFactory::create('c', {6}, {53.11053, 39.09542, 28.1987, 17.7468, 11.61684, 5.36217}); + auto expU= NDArrayFactory::create('c', {6,6}, {-0.16541, 0.21276, 0.51284, 0.20472, 0.74797, 0.25102,-0.49879, 0.12076, 0.37629, -0.7211 , -0.24585, 0.12086,-0.36569,-0.70218, -0.08012, 0.21274, -0.07314, 0.56231,-0.44508, 0.4329 , 0.1356 , 0.60909, -0.47398, -0.02164, 0.61238,-0.05674, 0.59489, 0.06588, -0.3874 , 0.33685,-0.13044,-0.50644, 0.46552, 0.13236, -0.00474, -0.70161}); + auto expV= NDArrayFactory::create('c', {7,6}, {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, 0.32179, 0.12812, -0.25812, 0.0691 , -0.12891, 0.26979,-0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.48118, 0.15876, -0.65132, -0.24602, 0.3963 , -0.16651,-0.31605, -0.46947, -0.50195, 0.0378 , -0.34937, -0.53062, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151}); nd4j::ops::svd op; auto results = op.execute({&x}, {}, {0, 1, 16}); @@ -2285,9 +2285,9 @@ TEST_F(DeclarableOpsTests3, svd_test5) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t(i)), nd4j::math::nd4j_abs(u->t(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t(i)), nd4j::math::nd4j_abs(v->t(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5); } delete results; @@ -2296,22 +2296,22 @@ TEST_F(DeclarableOpsTests3, svd_test5) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test6) { - auto x= NDArrayFactory::create('c', {2,2,5,5}, {-7. ,17 ,4 ,-10 ,5 ,1 ,-5 ,-19 ,13 ,-8 ,9 ,13 ,19 ,13 ,-2 + auto x= NDArrayFactory::create('c', {2,2,5,5}, {-7. 
,17 ,4 ,-10 ,5 ,1 ,-5 ,-19 ,13 ,-8 ,9 ,13 ,19 ,13 ,-2 ,-8 ,10 ,-9 ,0 ,-20 ,-2 ,14 ,19 ,5 ,-18 ,4 ,-13 ,12 ,-10 ,5 ,-10 ,-10 ,17 ,-5 ,-2 ,10 ,5 ,-4 ,-11 ,15 ,-3 ,15 ,-17 ,-20 ,-10 ,-4 ,12 ,-9 ,16 ,13 ,10 ,-19 ,2 ,-9 ,-10 ,8 ,-2 ,-4 ,3 ,7 ,10 ,-19 ,-11 ,-4 ,-6 ,2 ,-12 ,6 ,-4 ,-14 ,14 ,16 ,7 ,19 ,-17 ,2 ,-14 ,5 ,-1 ,16 ,19 ,-11 ,-14 ,-16 ,-19 ,15 ,-18 ,-12 ,-16 ,16 ,1 ,5 ,7 ,8 ,2 ,13 ,-3 ,6 ,2 ,-5}); - auto expS= NDArrayFactory::create('c', {2,2,5}, {40.95395, 31.46869, 24.79993, 12.33768, 1.80031, + auto expS= NDArrayFactory::create('c', {2,2,5}, {40.95395, 31.46869, 24.79993, 12.33768, 1.80031, 38.18412, 31.52287, 23.52755, 11.79484, 1.90195, 39.34498, 32.54861, 17.52492, 7.03003, 2.2399, 44.72126, 32.3164 , 16.60139, 6.88783, 0.78122}); - auto expU= NDArrayFactory::create('c', {2,2,5,5}, {0.25441, 0.16908, -0.68564, 0.58844, -0.30054, + auto expU= NDArrayFactory::create('c', {2,2,5,5}, {0.25441, 0.16908, -0.68564, 0.58844, -0.30054, -0.32285, -0.58332, 0.3451 , 0.4746 , -0.45953,0.58332, 0.10605, 0.51533, 0.50234, 0.36136,0.12588, -0.73123, -0.37812, -0.00215, 0.55361, 0.68915, -0.2919 , 0.04767, -0.4197 , -0.51132,0.44464, -0.25326, -0.42493, -0.01712, -0.74653,0.516 , -0.16688, 0.1854 , -0.77155, 0.27611, -0.19321, -0.14317, -0.85886, -0.15224, 0.42585,-0.60155, -0.68323, 0.18819, -0.29053, -0.22696,-0.36993, 0.64862, -0.10956, -0.54483, -0.36552, -0.57697, -0.32277, 0.11229, 0.55495, 0.4923 ,-0.02937, 0.01689, -0.63257, 0.57075, -0.52245,-0.56002, -0.2036 , -0.53119, -0.6022 , 0.01017, -0.33605, -0.35257, 0.53215, -0.04936, -0.69075,0.48958, -0.85427, -0.14796, -0.03449, 0.08633,0.15008, 0.60996, 0.31071, -0.67721, 0.22421, 0.67717, -0.59857, 0.04372, -0.2565 , 0.33979,0.68116, 0.49852, -0.13441, 0.51374, -0.07421,-0.20066, 0.04504, 0.42865, 0.44418, 0.75939,0.12113, -0.13826, 0.83651, 0.11988, -0.50209}); - auto expV= NDArrayFactory::create('c', {2,2,5,5}, {0.01858, 0.17863, 0.51259, 0.14048, 0.82781, + auto expV= NDArrayFactory::create('c', {2,2,5,5}, {0.01858, 0.17863, 0.51259, 0.14048, 0.82781, 0.59651, -0.13439, -0.395 , 0.66979, 0.14654,0.73731, 0.47061, 0.19357, -0.41127, -0.16817,0.1047 , -0.29727, 0.73711, 0.38235, -0.45951, -0.29873, 0.80012, -0.02078, 0.4651 , -0.23201,-0.05314, -0.0419 , -0.52146, 0.77792, 0.344 ,-0.66438, 0.05648, 0.03756, -0.31531, 0.67422, 0.74471, 0.01504, -0.03081, -0.24335, 0.62049,0.03172, 0.91947, 0.30828, 0.23713, 0.04796,-0.01311, 0.38652, -0.79415, -0.42423, -0.19945, @@ -2340,9 +2340,9 @@ TEST_F(DeclarableOpsTests3, svd_test6) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t(i)), nd4j::math::nd4j_abs(u->t(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t(i)), nd4j::math::nd4j_abs(v->t(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5); } delete results; @@ -2351,14 +2351,14 @@ TEST_F(DeclarableOpsTests3, svd_test6) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test7) { - auto x= NDArrayFactory::create('c', {2,2,5,5}, {-7. ,17 ,4 ,-10 ,5 ,1 ,-5 ,-19 ,13 ,-8 ,9 ,13 ,19 ,13 ,-2 + auto x= NDArrayFactory::create('c', {2,2,5,5}, {-7. 
,17 ,4 ,-10 ,5 ,1 ,-5 ,-19 ,13 ,-8 ,9 ,13 ,19 ,13 ,-2 ,-8 ,10 ,-9 ,0 ,-20 ,-2 ,14 ,19 ,5 ,-18 ,4 ,-13 ,12 ,-10 ,5 ,-10 ,-10 ,17 ,-5 ,-2 ,10 ,5 ,-4 ,-11 ,15 ,-3 ,15 ,-17 ,-20 ,-10 ,-4 ,12 ,-9 ,16 ,13 ,10 ,-19 ,2 ,-9 ,-10 ,8 ,-2 ,-4 ,3 ,7 ,10 ,-19 ,-11 ,-4 ,-6 ,2 ,-12 ,6 ,-4 ,-14 ,14 ,16 ,7 ,19 ,-17 ,2 ,-14 ,5 ,-1 ,16 ,19 ,-11 ,-14 ,-16 ,-19 ,15 ,-18 ,-12 ,-16 ,16 ,1 ,5 ,7 ,8 ,2 ,13 ,-3 ,6 ,2 ,-5}); - auto expS= NDArrayFactory::create('c', {2,2,5}, {40.95395, 31.46869, 24.79993, 12.33768, 1.80031, + auto expS= NDArrayFactory::create('c', {2,2,5}, {40.95395, 31.46869, 24.79993, 12.33768, 1.80031, 38.18412, 31.52287, 23.52755, 11.79484, 1.90195, 39.34498, 32.54861, 17.52492, 7.03003, 2.2399, 44.72126, 32.3164 , 16.60139, 6.88783, 0.78122}); @@ -2520,9 +2520,9 @@ TEST_F(DeclarableOpsTests3, svd_test7) { // } // else { // for(uint i = 0; i < expU.lengthOf(); ++i) - // ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t(i)), nd4j::math::nd4j_abs(u->t(i)), 1e-5); + // ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5); // for(uint i = 0; i < expV.lengthOf(); ++i) - // ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t(i)), nd4j::math::nd4j_abs(v->t(i)), 1e-5); + // ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5); // } // delete results; @@ -2531,18 +2531,18 @@ TEST_F(DeclarableOpsTests3, svd_test7) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test9) { - auto x= NDArrayFactory::create('c', {2,2,5,6}, {17 ,-11 ,20 ,-10 ,19 ,13 ,-18 ,6 ,-2 ,-6 ,-10 ,4 ,-6 ,-4 ,3 ,16 ,12 , + auto x= NDArrayFactory::create('c', {2,2,5,6}, {17 ,-11 ,20 ,-10 ,19 ,13 ,-18 ,6 ,-2 ,-6 ,-10 ,4 ,-6 ,-4 ,3 ,16 ,12 , -15 ,8 ,-8 ,12 ,-1 ,20 ,19 ,-13 ,0 ,20 ,17 ,-8 ,16 ,-19 ,7 ,-16 ,-14 ,-5 ,7 ,7 ,-5 ,12 ,-15 ,7 ,8 , 1 ,-8 ,-17 ,10 ,-11 ,8 ,-10 ,1 ,-6 ,10 ,15 ,19 ,-15 ,8 ,2 ,8 ,12 ,7 ,-5 ,1 ,8 ,4 ,-13 ,2 ,19 ,-2 ,-10 , -8 ,11 ,1 ,20 ,-11 ,4 ,1 ,-17 ,-15 ,0 ,-9 ,-4 ,-1 ,-6 ,-9 ,-13 ,10 ,7 ,-2 ,15 ,-10 ,-1 ,11 ,-20 ,-2 , -1 ,-18 ,12 ,16 ,8 ,-9 ,-20 ,-7 ,-20 ,3 ,-9 ,12 ,8 ,-19 ,-2 ,2 ,1 ,7 ,10 ,-18 ,13 ,6 ,14 ,0 ,19 ,8}); - auto expS= NDArrayFactory::create('c', {2,2,5}, {50.46507, 35.75599, 28.12787, 12.45245, 9.08545, + auto expS= NDArrayFactory::create('c', {2,2,5}, {50.46507, 35.75599, 28.12787, 12.45245, 9.08545, 38.56035, 30.62846, 26.31646, 19.42605, 3.01162, 38.56369, 29.18881, 19.54565, 10.89746, 2.017 , 44.99108, 34.95059, 26.00453, 15.43898, 7.18752}); - auto expU= NDArrayFactory::create('c', {2,2,5,5}, {-0.73644, -0.10751, 0.10081, -0.00325, 0.66025, + auto expU= NDArrayFactory::create('c', {2,2,5,5}, {-0.73644, -0.10751, 0.10081, -0.00325, 0.66025, 0.26329, 0.3079 , 0.38582, 0.77696, 0.28872, 0.03076, 0.03015, -0.9128 , 0.36387, 0.18039, -0.61335, 0.10076, 0.01381, 0.40922, -0.66783, @@ -2563,7 +2563,7 @@ TEST_F(DeclarableOpsTests3, svd_test9) { 0.34555, 0.12488, -0.50703, -0.29269, 0.72267, -0.34713, 0.3847 , -0.7532 , 0.22176, -0.33913}); - auto expV= NDArrayFactory::create('c', {2,2,6,6}, {-4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01, + auto expV= NDArrayFactory::create('c', {2,2,6,6}, {-4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01, -1.10690000e-01, 1.37280000e-01, 2.86620000e-01, 5.88200000e-02, 1.68760000e-01, -2.55000000e-03, -1.00090000e-01, 9.35890000e-01, @@ -2633,9 +2633,9 @@ TEST_F(DeclarableOpsTests3, svd_test9) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t(i)), nd4j::math::nd4j_abs(u->t(i)), 1e-5); + 
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t(i)), nd4j::math::nd4j_abs(v->t(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5); } delete results; @@ -2644,18 +2644,18 @@ TEST_F(DeclarableOpsTests3, svd_test9) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test10) { - auto x= NDArrayFactory::create('c', {2,2,5,6}, {17 ,-11 ,20 ,-10 ,19 ,13 ,-18 ,6 ,-2 ,-6 ,-10 ,4 ,-6 ,-4 ,3 ,16 ,12 , + auto x= NDArrayFactory::create('c', {2,2,5,6}, {17 ,-11 ,20 ,-10 ,19 ,13 ,-18 ,6 ,-2 ,-6 ,-10 ,4 ,-6 ,-4 ,3 ,16 ,12 , -15 ,8 ,-8 ,12 ,-1 ,20 ,19 ,-13 ,0 ,20 ,17 ,-8 ,16 ,-19 ,7 ,-16 ,-14 ,-5 ,7 ,7 ,-5 ,12 ,-15 ,7 ,8 , 1 ,-8 ,-17 ,10 ,-11 ,8 ,-10 ,1 ,-6 ,10 ,15 ,19 ,-15 ,8 ,2 ,8 ,12 ,7 ,-5 ,1 ,8 ,4 ,-13 ,2 ,19 ,-2 ,-10 , -8 ,11 ,1 ,20 ,-11 ,4 ,1 ,-17 ,-15 ,0 ,-9 ,-4 ,-1 ,-6 ,-9 ,-13 ,10 ,7 ,-2 ,15 ,-10 ,-1 ,11 ,-20 ,-2 , -1 ,-18 ,12 ,16 ,8 ,-9 ,-20 ,-7 ,-20 ,3 ,-9 ,12 ,8 ,-19 ,-2 ,2 ,1 ,7 ,10 ,-18 ,13 ,6 ,14 ,0 ,19 ,8}); - auto expS= NDArrayFactory::create('c', {2,2,5}, {50.46507, 35.75599, 28.12787, 12.45245, 9.08545, + auto expS= NDArrayFactory::create('c', {2,2,5}, {50.46507, 35.75599, 28.12787, 12.45245, 9.08545, 38.56035, 30.62846, 26.31646, 19.42605, 3.01162, 38.56369, 29.18881, 19.54565, 10.89746, 2.017 , 44.99108, 34.95059, 26.00453, 15.43898, 7.18752}); - auto expU= NDArrayFactory::create('c', {2,2,5,5}, {-0.73644, -0.10751, 0.10081, -0.00325, 0.66025, + auto expU= NDArrayFactory::create('c', {2,2,5,5}, {-0.73644, -0.10751, 0.10081, -0.00325, 0.66025, 0.26329, 0.3079 , 0.38582, 0.77696, 0.28872, 0.03076, 0.03015, -0.9128 , 0.36387, 0.18039, -0.61335, 0.10076, 0.01381, 0.40922, -0.66783, @@ -2676,7 +2676,7 @@ TEST_F(DeclarableOpsTests3, svd_test10) { 0.34555, 0.12488, -0.50703, -0.29269, 0.72267, -0.34713, 0.3847 , -0.7532 , 0.22176, -0.33913}); - auto expV= NDArrayFactory::create('c', {2,2,6,5}, { -4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01, + auto expV= NDArrayFactory::create('c', {2,2,6,5}, { -4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01, -1.10690000e-01, 2.86620000e-01, 5.88200000e-02, 1.68760000e-01, -2.55000000e-03, -1.00090000e-01, @@ -2746,9 +2746,9 @@ TEST_F(DeclarableOpsTests3, svd_test10) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t(i)), nd4j::math::nd4j_abs(u->t(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t(i)), nd4j::math::nd4j_abs(v->t(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5); } delete results; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp index 478a31d4a..1155c72de 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp @@ -201,7 +201,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_10) { TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_11) { auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); - auto exp = NDArrayFactory::create('c', {1, 1, 2, 2}, {3, 4, 6, 7}); + auto exp = NDArrayFactory::create('c', {1, 1, 2, 2}, {3.f, 4.f, 6.f, 7.f}); x.linspace(1); @@ -1582,17 +1582,17 @@ TEST_F(DeclarableOpsTests4, relu6_bp_test1) { 
//////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_1) { - auto x = NDArrayFactory::create('c', {2, 2, 2, 2}, { 5.5, 0., 0.3, 5.5, - 8.6, 0., 0., 0.4, - 1.5, 1., 1.3, 1.5, - 2.6, 2., 3., 1.4} + auto x = NDArrayFactory::create('c', {2, 2, 2, 2}, { 5.5f, 0.f, 0.3f, 5.5f, + 8.6f, 0.f, 0.f, 0.4f, + 1.5f, 1.f, 1.3f, 1.5f, + 2.6f, 2.f, 3.f, 1.4f} ); auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, { - 0.98386997, 0., 0.05358852, 0.9824562, - 0.99330735, 0., 0., 0.37139067, - 0.72760683, 0.4850712, 0.5848977, 0.67488194, - 0.7581754, 0.58321184, 0.86747235, 0.4048204} + 0.98386997f, 0.f, 0.05358852f, 0.9824562f, + 0.99330735f, 0.f, 0.f, 0.37139067f, + 0.72760683f, 0.4850712f, 0.5848977f, 0.67488194f, + 0.7581754f, 0.58321184f, 0.86747235f, 0.4048204f} ); nd4j::ops::lrn op; @@ -1612,16 +1612,16 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_1) { //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_2) { - auto x = NDArrayFactory::create('c', {2, 2, 2, 2}, { 5.5, 0., 0.3, 5.5, - 8.6, 0., 0., 0.4, - 1.5, 1., 1.3, 1.5, - 2.6, 2., 3., 1.4}); + auto x = NDArrayFactory::create('c', {2, 2, 2, 2}, { 5.5f, 0.f, 0.3f, 5.5f, + 8.6f, 0.f, 0.f, 0.4f, + 1.5f, 1.f, 1.3f, 1.5f, + 2.6f, 2.f, 3.f, 1.4f}); auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, { - 0.98386997, 0., 0.05358852, 0.9824562, - 0.99330735, 0., 0., 0.37139067, - 0.72760683, 0.4850712, 0.5848977, 0.67488194, - 0.7581754, 0.58321184, 0.86747235, 0.4048204}); + 0.98386997f, 0.f, 0.05358852f, 0.9824562f, + 0.99330735f, 0.f, 0.f, 0.37139067f, + 0.72760683f, 0.4850712f, 0.5848977f, 0.67488194f, + 0.7581754f, 0.58321184f, 0.86747235f, 0.4048204f}); nd4j::ops::lrn op; auto results = op.execute({&x}, {1.0, 1.0, 0.5}, {2}, {}, false, nd4j::DataType::DOUBLE); @@ -1641,25 +1641,25 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_3) { auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, { - 5.5, 0., 0.3, 5.5, - 1.5, 0., 1.3, 6.5, - 8.6, 0., 0., 0.4, - 2.5, 1., 0.3, 4.5, - 1.5, 1., 1.3, 1.5, - 3.5, 0., 1.3, 2.5, - 2.6, 2., 3., 1.4, - 4.5, 1., 0.3, 0.5} + 5.5f, 0.f, 0.3f, 5.5f, + 1.5f, 0.f, 1.3f, 6.5f, + 8.6f, 0.f, 0.f, 0.4f, + 2.5f, 1.f, 0.3f, 4.5f, + 1.5f, 1.f, 1.3f, 1.5f, + 3.5f, 0.f, 1.3f, 2.5f, + 2.6f, 2.f, 3.f, 1.4f, + 4.5f, 1.f, 0.3f, 0.5f} ); auto exp = NDArrayFactory::create('c', {2, 2, 2, 4}, { - 0.9824562, 0., 0.03822664, 0.9824562, - 0.67488194, 0., 0.18924236, 0.96960944, - 0.99330735, 0., 0., 0.37139067, - 0.86567914, 0.18702209, 0.05610663, 0.9520745, - 0.6154575, 0.34942827, 0.45425674, 0.6154575, - 0.905509, 0. 
, 0.2824086, 0.8361251, - 0.57063663, 0.41959068, 0.629386, 0.3504383, - 0.9520745, 0.21039814, 0.06311944, 0.3268602 } + 0.9824562f, 0.f, 0.03822664f, 0.9824562f, + 0.67488194f, 0.f, 0.18924236f, 0.96960944f, + 0.99330735f, 0.f, 0.f, 0.37139067f, + 0.86567914f, 0.18702209f, 0.05610663f, 0.9520745f, + 0.6154575f, 0.34942827f, 0.45425674f, 0.6154575f, + 0.905509f, 0.f, 0.2824086f, 0.8361251f, + 0.57063663f, 0.41959068f, 0.629386f, 0.3504383f, + 0.9520745f, 0.21039814f, 0.06311944f, 0.3268602f } ); nd4j::ops::lrn op; @@ -1680,25 +1680,25 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_4) { auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, { - 5.5, 0., 0.3, 5.5, - 1.5, 0., 1.3, 6.5, - 8.6, 0., 0., 0.4, - 2.5, 1., 0.3, 4.5, - 1.5, 1., 1.3, 1.5, - 3.5, 0., 1.3, 2.5, - 2.6, 2., 3., 1.4, - 4.5, 1., 0.3, 0.5} + 5.5f, 0.f, 0.3f, 5.5f, + 1.5f, 0.f, 1.3f, 6.5f, + 8.6f, 0.f, 0.f, 0.4f, + 2.5f, 1.f, 0.3f, 4.5f, + 1.5f, 1.f, 1.3f, 1.5f, + 3.5f, 0.f, 1.3f, 2.5f, + 2.6f, 2.f, 3.f, 1.4f, + 4.5f, 1.f, 0.3f, 0.5f} ); auto exp = NDArrayFactory::create('c', {2, 2, 2, 4}, { - 0.70082176, 0., 0.03822664, 0.70082176, - 0.21835658, 0., 0.18924236, 0.9462118, - 0.9922489, 0., 0., 0.04615111, - 0.46755522, 0.18702209, 0.05610663, 0.8415994, - 0.5241424, 0.34942827, 0.45425674, 0.5241424, - 0.76033086, 0., 0.2824086, 0.54309344, - 0.54546785, 0.41959068, 0.629386, 0.29371348, - 0.94679165, 0.21039814, 0.06311944, 0.10519907} + 0.70082176f, 0.f, 0.03822664f, 0.70082176f, + 0.21835658f, 0.f, 0.18924236f, 0.9462118f, + 0.9922489f, 0.f, 0.f, 0.04615111f, + 0.46755522f, 0.18702209f, 0.05610663f, 0.8415994f, + 0.5241424f, 0.34942827f, 0.45425674f, 0.5241424f, + 0.76033086f, 0.f, 0.2824086f, 0.54309344f, + 0.54546785f, 0.41959068f, 0.629386f, 0.29371348f, + 0.94679165f, 0.21039814f, 0.06311944f, 0.10519907f} ); nd4j::ops::lrn op; @@ -1719,29 +1719,29 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_5) { auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, { - 5.5,0., 0.3, 5.5, - 1.5,0., 1.3, 6.5, - 8.6,0., 0., 0.4, - 2.5,1., 0.3, 4.5, - 1.5,1., 1.3, 1.5, - 3.5,0., 1.3, 2.5, - 2.6,2., 3., 1.4, - 4.5,1., 0.3, 0.5} + 5.5f, 0.f, 0.3f, 5.5f, + 1.5f, 0.f, 1.3f, 6.5f, + 8.6f, 0.f, 0.f, 0.4f, + 2.5f, 1.f, 0.3f, 4.5f, + 1.5f, 1.f, 1.3f, 1.5f, + 3.5f, 0.f, 1.3f, 2.5f, + 2.6f, 2.f, 3.f, 1.4f, + 4.5f, 1.f, 0.3f, 0.5f} ); auto eps = NDArrayFactory::create('c', {2, 2, 2, 4}, { - 0.70082176, 0., 0.03822664, 0.70082176, - 0.21835658, 0., 0.18924236, 0.9462118, + 0.70082176f, 0.f, 0.03822664f, 0.70082176f, + 0.21835658f, 0.f, 0.18924236f, 0.9462118f, - 0.9922489, 0., 0. 
, 0.04615111, - 0.46755522, 0.18702209, 0.05610663, 0.8415994, + 0.9922489f, 0.f, 0.f, 0.04615111f, + 0.46755522f, 0.18702209f, 0.05610663f, 0.8415994f, - 0.5241424, 0.34942827, 0.45425674, 0.5241424, - 0.76033086, 0., 0.2824086 , 0.54309344, + 0.5241424f, 0.34942827f, 0.45425674f, 0.5241424f, + 0.76033086f, 0.f, 0.2824086f, 0.54309344f, - 0.54546785, 0.41959068, 0.629386 , 0.29371348, - 0.94679165, 0.21039814, 0.06311944, 0.10519907} + 0.54546785f, 0.41959068f, 0.629386f, 0.29371348f, + 0.94679165f, 0.21039814f, 0.06311944f, 0.10519907f} ); auto exp = NDArrayFactory::create('c', {2, 2, 2, 4}); @@ -1766,7 +1766,7 @@ TEST_F(DeclarableOpsTests4, tri_test1) { const int rows = 3; const int cols = 5; - auto expected = NDArrayFactory::create('c', {rows, cols}, {1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0}); + auto expected = NDArrayFactory::create('c', {rows, cols}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f}); nd4j::ops::tri op; auto results = op.execute({}, {}, {rows, cols}); @@ -1789,7 +1789,7 @@ TEST_F(DeclarableOpsTests4, tri_test2) { const int cols = 5; const int diag = 2; - auto expected = NDArrayFactory::create('c', {rows, cols}, {1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1}); + auto expected = NDArrayFactory::create('c', {rows, cols}, {1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 1.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f}); nd4j::ops::tri op; auto results = op.execute({}, {}, {rows, cols, diag}); @@ -1810,7 +1810,7 @@ TEST_F(DeclarableOpsTests4, tri_test3) { const int cols = 5; const int diag = -1; - auto expected = NDArrayFactory::create('c', {rows, cols}, {0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0}); + auto expected = NDArrayFactory::create('c', {rows, cols}, {0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f}); nd4j::ops::tri op; auto results = op.execute({}, {}, {rows, cols, diag}); @@ -1831,7 +1831,7 @@ TEST_F(DeclarableOpsTests4, tri_test4) { const int cols = 5; const int diag = -2; - auto expected = NDArrayFactory::create('c', {rows, cols}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0}); + auto expected = NDArrayFactory::create('c', {rows, cols}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f}); nd4j::ops::tri op; auto results = op.execute({}, {}, {rows, cols, diag}); @@ -1850,7 +1850,7 @@ TEST_F(DeclarableOpsTests4, tri_test5) { const int rows = 5; - auto expected = NDArrayFactory::create('c', {rows, rows}, {1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1}); + auto expected = NDArrayFactory::create('c', {rows, rows}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 1.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f}); nd4j::ops::tri op; auto results = op.execute({}, {}, {rows}); @@ -1871,7 +1871,7 @@ TEST_F(DeclarableOpsTests4, tri_test6) { const int cols = 5; const int diag = -20; - auto expected = NDArrayFactory::create('c', {rows, cols}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + auto expected = NDArrayFactory::create('c', {rows, cols}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); nd4j::ops::tri op; auto results = op.execute({}, {}, {rows, cols, diag}); @@ -1892,7 +1892,7 @@ TEST_F(DeclarableOpsTests4, tri_test7) { const int cols = 5; const int diag = 20; - auto expected = NDArrayFactory::create('c', {rows, cols}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + auto expected = NDArrayFactory::create('c', {rows, cols}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); nd4j::ops::tri op; 
     auto results = op.execute({}, {}, {rows, cols, diag});
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp
index 2c15f24bc..c30ad5f89 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp
@@ -242,10 +242,10 @@ TEST_F(DeclarableOpsTests5, Test_SetSeed_1) {
 }
 ////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests5, scatterMul_test1) {
-    auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4});
+    auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
     NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64);
-    auto updates = NDArrayFactory::create('c', {1, 2}, {10, 1});
-    auto exp = NDArrayFactory::create('c', {2, 2}, {10, 2, 3, 4});
+    auto updates = NDArrayFactory::create('c', {1, 2}, {10.f, 1.f});
+    auto exp = NDArrayFactory::create('c', {2, 2}, {10.f, 2.f, 3.f, 4.f});
 
     nd4j::ops::scatter_mul op;
     auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {});
@@ -260,10 +260,10 @@ TEST_F(DeclarableOpsTests5, scatterMul_test1) {
 ////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests5, scatterDiv_test1) {
-    auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4});
+    auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
     NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64);
-    auto updates = NDArrayFactory::create('c', {1, 2}, {10, 1});
-    auto exp = NDArrayFactory::create('c', {2, 2}, {0.10, 2, 3, 4});
+    auto updates = NDArrayFactory::create('c', {1, 2}, {10.f, 1.f});
+    auto exp = NDArrayFactory::create('c', {2, 2}, {0.10f, 2.f, 3.f, 4.f});
 
     nd4j::ops::scatter_div op;
     auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {});
@@ -278,10 +278,10 @@ TEST_F(DeclarableOpsTests5, scatterDiv_test1) {
 ////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests5, scatterSub_test1) {
-    auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4});
+    auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
     NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64);
-    auto updates = NDArrayFactory::create('c', {1, 2}, {10, 1});
-    auto exp = NDArrayFactory::create('c', {2, 2}, {-9, 1, 3, 4});
+    auto updates = NDArrayFactory::create('c', {1, 2}, {10.f, 1.f});
+    auto exp = NDArrayFactory::create('c', {2, 2}, {-9.f, 1.f, 3.f, 4.f});
 
     nd4j::ops::scatter_sub op;
     auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {});
@@ -296,8 +296,8 @@ TEST_F(DeclarableOpsTests5, scatterSub_test1) {
 ////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests5, hardsigmoid_test1) {
-    auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4});
-    auto exp = NDArrayFactory::create('c', {2, 2}, {0.7, 0.9, 1, 1});
+    auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
+    auto exp = NDArrayFactory::create('c', {2, 2}, {0.7f, 0.9f, 1.f, 1.f});
 
     nd4j::ops::hardsigmoid op;
     auto result = op.execute({&matrix}, {}, {}, {});
@@ -311,9 +311,9 @@ TEST_F(DeclarableOpsTests5, hardsigmoid_test1) {
 ////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests5, hardsigmoid_test2) {
-    auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4});
-    auto eps = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4});
-    auto exp = NDArrayFactory::create('c', {2, 2}, {0.2, 0.4, 0, 0});
+    auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
+    auto eps = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
+    auto exp = NDArrayFactory::create('c', {2, 2}, {0.2f, 0.4f, 0.f, 0.f});
 
     nd4j::ops::hardsigmoid_bp op;
     auto result = op.execute({&matrix, &eps}, {}, {}, {});
@@ -327,8 +327,8 @@ TEST_F(DeclarableOpsTests5, hardsigmoid_test2) {
 ////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests5, hardtanh_test1) {
-    auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4});
-    auto exp = NDArrayFactory::create('c', {3, 3}, {-1, -1, -1, -1, 0, 1, 1, 1, 1});
+    auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4});
+    auto exp = NDArrayFactory::create('c', {3, 3}, {-1, -1, -1, -1, 0, 1, 1, 1, 1});
 
     nd4j::ops::hardtanh op;
     auto result = op.execute({&matrix}, {}, {}, {});
@@ -342,9 +342,9 @@ TEST_F(DeclarableOpsTests5, hardtanh_test1) {
 }
 ////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests5, hardtanh_test2) {
-    auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4});
-    auto eps = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
-    auto exp = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 4, 5, 6, 0, 0, 0});
+    auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4});
+    auto eps = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
+    auto exp = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 4, 5, 6, 0, 0, 0});
 
     nd4j::ops::hardtanh_bp op;
     auto result = op.execute({&matrix, &eps}, {}, {}, {});
@@ -389,7 +389,7 @@ TEST_F(DeclarableOpsTests5, histogram_test2) {
 ////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests5, Identity_test1) {
-    auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4});
+    auto matrix = NDArrayFactory::create('c', {3, 3}, {-4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f});
 //    auto exp = NDArrayFactory::create('c', {3, 3}, {3, 3, 3});
 
     nd4j::ops::identity op;
@@ -404,8 +404,8 @@ TEST_F(DeclarableOpsTests5, Identity_test1) {
 ////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests5, Identity_test2) {
-    auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4});
-    auto eps = NDArrayFactory::create('c', {3, 3}, {1,2,3,4,5,6,7,8,9});
+    auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4});
+    auto eps = NDArrayFactory::create('c', {3, 3}, {1,2,3,4,5,6,7,8,9});
 //    auto exp = NDArrayFactory::create('c', {3,3});
     nd4j::ops::identity_bp op;
     auto result = op.execute({&matrix, &eps}, {}, {}, {});
@@ -418,8 +418,8 @@ TEST_F(DeclarableOpsTests5, Identity_test2) {
 }
 ////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests5, Log1p_test1) {
-    auto matrix = NDArrayFactory::create('c', {3, 3}, {4, 3, 2, 1, 0, 1, 2, 3, 4});
-    auto y = NDArrayFactory::create('c', {3,3}, {5,4,3,2,1,2,3,4,5});
+    auto matrix = NDArrayFactory::create('c', {3, 3}, {4, 3, 2, 1, 0, 1, 2, 3, 4});
+    auto y = NDArrayFactory::create('c', {3,3}, {5,4,3,2,1,2,3,4,5});
 //    auto eps = NDArrayFactory::create('c', {3, 3}, {1,2,3,4,5,6,7,8,9});
 //    auto exp = NDArrayFactory::create('c', {3,3});
     nd4j::ops::Log1p op;
@@ -599,7 +599,7 @@ TEST_F(DeclarableOpsTests5, Test_BatchToSpace_4) {
 //////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests5, eye_test1) {
-    auto expected = NDArrayFactory::create('c', {3, 3}, {1, 0, 0, 0, 1, 0, 0, 0, 1});
+    auto expected = NDArrayFactory::create('c', {3, 3}, {1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f});
 
     nd4j::ops::eye op;
     auto results = op.execute({}, {}, {-99, 3});
@@ -616,7 +616,7 @@ TEST_F(DeclarableOpsTests5, eye_test1) {
 //////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests5, eye_test2) {
-    auto expected = NDArrayFactory::create('c', {3, 4}, {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0});
+    auto expected = NDArrayFactory::create('c', {3, 4}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f});
 
     nd4j::ops::eye op;
     auto results = op.execute({}, {}, {-99, 3, 4});
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp
index 3ebcbd016..2fbd42af7 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp
@@ -348,8 +348,8 @@ TEST_F(DeclarableOpsTests6, cumSum_1) {
 }
 
 TEST_F(DeclarableOpsTests6, cumSum_2) {
-    auto x= NDArrayFactory::create('c', {2, 4}, {1, 2, 3, 4, 1, 2, 3, 4});
-    auto exp= NDArrayFactory::create('c', {2, 4}, {1, 3, 6, 10, 1, 3, 6, 10});
+    auto x= NDArrayFactory::create('c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
+    auto exp= NDArrayFactory::create('c', {2, 4}, {1.f, 3.f, 6.f, 10.f, 1.f, 3.f, 6.f, 10.f});
 
     nd4j::ops::cumsum op;
     auto result = op.execute({&x}, {}, {0, 0, 1});
@@ -365,8 +365,8 @@ TEST_F(DeclarableOpsTests6, cumSum_2) {
 }
 
 TEST_F(DeclarableOpsTests6, cumSum_3) {
-    auto x= NDArrayFactory::create('c', {2, 4}, {1, 2, 3, 4, 1, 2, 3, 4});
-    auto exp= NDArrayFactory::create('c', {2, 4}, {1, 2, 3, 4, 2, 4, 6, 8});
+    auto x= NDArrayFactory::create('c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
+    auto exp= NDArrayFactory::create('c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 2.f, 4.f, 6.f, 8.f});
 
     nd4j::ops::cumsum op;
     auto result = op.execute({&x}, {}, {0, 0, 0});
@@ -649,13 +649,13 @@ TEST_F(DeclarableOpsTests6, cumSum_17) {
     NDArray exp0 = exp(0, {0});
     NDArray exp1 = exp(1, {0});
 
-    exp0.p(0, 1.);
-    exp1.p(0, 1.);
+    exp0.p(0, 1.);
+    exp1.p(0, 1.);
 
     for (int i = 1; i < 1500; ++i) {
         const auto prev = exp0.e(i-1);
-        exp0.p(i, prev + i + 1);
-        exp1.p(i, prev + i + 1);
+        exp0.p(i, prev + i + 1);
+        exp1.p(i, prev + i + 1);
     }
 
     nd4j::ops::cumsum op;
@@ -682,13 +682,13 @@ TEST_F(DeclarableOpsTests6, cumSum_18) {
     NDArray exp0 = exp(0, {0});
     NDArray exp1 = exp(1, {0});
 
-    exp0.p(0, 0.);
-    exp1.p(0, 0.);
+    exp0.p(0, 0.);
+    exp1.p(0, 0.);
 
     for (int i = 1; i < 1500; ++i) {
         const auto prev = exp0.e(i-1);
-        exp0.p(i, prev + i);
-        exp1.p(i, prev + i);
+        exp0.p(i, prev + i);
+        exp1.p(i, prev + i);
     }
 
     nd4j::ops::cumsum op;
@@ -715,13 +715,13 @@ TEST_F(DeclarableOpsTests6, cumSum_19) {
     NDArray exp0 = exp(0, {0});
     NDArray exp1 = exp(1, {0});
 
-    exp0.p(1499, 1500.);
-    exp1.p(1499, 1500.);
+    exp0.p(1499, 1500.f);
+    exp1.p(1499, 1500.f);
 
     for (int i = 1498; i >= 0; --i) {
         const auto prev = exp0.e(i + 1);
-        exp0.p(i, prev + i + 1);
-        exp1.p(i, prev + i + 1);
+        exp0.p(i, prev + i + 1);
+        exp1.p(i, prev + i + 1);
     }
 
     nd4j::ops::cumsum op;
@@ -749,13 +749,13 @@ TEST_F(DeclarableOpsTests6, cumSum_20) {
     NDArray exp0 = exp(0, {0});
     NDArray exp1 = exp(1, {0});
 
-    exp0.p(1499, 0.);
-    exp1.p(1499, 0.);
+    exp0.p(1499, 0.);
+    exp1.p(1499, 0.);
 
     for (int i = 1498; i >= 0; --i) {
         const auto prev = exp0.e(i + 1);
-        exp0.p(i, prev + i + 2);
-        exp1.p(i, prev + i + 2);
+        exp0.p(i, prev + i + 2);
+        exp1.p(i, prev + i + 2);
     }
 
     nd4j::ops::cumsum op;
@@ -1576,7 +1576,7 @@ TEST_F(DeclarableOpsTests6, LogDet_3) {
//////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_1) { - auto x = NDArrayFactory::create('c', {2, 5, 5}, { + auto x = NDArrayFactory::create('c', {2, 5, 5}, { 2., 4., 60., 8., 10., 0., 1., 2., 3., 4., 0., 0., 2., 4., 6., @@ -1590,7 +1590,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) { 5., 4., 3., 2., 1., }); - auto exp = NDArrayFactory::create('c', {2, 5, 5}, { + auto exp = NDArrayFactory::create('c', {2, 5, 5}, { 0.5, -2.0, -13.0, 54.0, -6.75, 0.0, 1.0, -1.0, 1.0, 0.0, 0, 0, 0.5, -2.0, 0.25, @@ -1620,8 +1620,8 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_010) { - auto x = NDArrayFactory::create('c', {1, 5, 5}, {1., 0., 0., 0., 0.,2., 1., 0., 0., 0.,30., 2., 1., 0., 0.,4., 3., 2., 1., 0.,5., 4., 3., 2., 1.,}); - auto exp = NDArrayFactory::create('c', {1, 5, 5}, {1.0, 0.0, 0.0, 0.0, 0.,-2.0, 1.0, 0., 0., 0.,-26.0, -2.0, 1, 0, 0.,54.0, 1.0, -2.0, 1, 0.,-27.0, 0.0, 1.0, -2.0, 1.}); + auto x = NDArrayFactory::create('c', {1, 5, 5}, {1., 0., 0., 0., 0.,2., 1., 0., 0., 0.,30., 2., 1., 0., 0.,4., 3., 2., 1., 0.,5., 4., 3., 2., 1.,}); + auto exp = NDArrayFactory::create('c', {1, 5, 5}, {1.0, 0.0, 0.0, 0.0, 0.,-2.0, 1.0, 0., 0., 0.,-26.0, -2.0, 1, 0, 0.,54.0, 1.0, -2.0, 1, 0.,-27.0, 0.0, 1.0, -2.0, 1.}); nd4j::ops::matrix_inverse op; auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); @@ -1639,9 +1639,9 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_010) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_01) { - auto x = NDArrayFactory::create('c', {1, 5, 5}, {2., 4., 60., 8., 10., 0., 1., 2., 3., 4., 0., 0., 2., 4., 6., 0., 0., 0., 1., 2., 0., 0., 0., 0., 4. }); + auto x = NDArrayFactory::create('c', {1, 5, 5}, {2., 4., 60., 8., 10., 0., 1., 2., 3., 4., 0., 0., 2., 4., 6., 0., 0., 0., 1., 2., 0., 0., 0., 0., 4. }); - auto exp = NDArrayFactory::create('c', {1, 5, 5}, {0.5, -2.0, -13.0, 54.0, -6.75, 0.0, 1.0, -1.0, 1.0, 0.0, 0, 0, 0.5, -2.0, 0.25, 0, 0, 0, 1.0, -0.5, 0, 0, 0, 0, 0.25 }); + auto exp = NDArrayFactory::create('c', {1, 5, 5}, {0.5, -2.0, -13.0, 54.0, -6.75, 0.0, 1.0, -1.0, 1.0, 0.0, 0, 0, 0.5, -2.0, 0.25, 0, 0, 0, 1.0, -0.5, 0, 0, 0, 0, 0.25 }); nd4j::ops::matrix_inverse op; auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); @@ -1658,8 +1658,8 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_01) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_02) { - auto x = NDArrayFactory::create('c', {1, 5, 5}, {1., 0., 0., 0., 0., 2., 1., 0., 0., 0., 30., 2., 1., 0., 0., 4., 3., 2., 1., 0., 5., 4., 3., 2., 1. }); - auto exp = NDArrayFactory::create('c', {1, 5, 5}, {1.0, 0.0, 0.0, 0.0, 0., -2.0, 1.0, 0., 0., 0., -26.0, -2.0, 1, 0, 0., 54.0, 1.0, -2.0, 1, 0., -27.0, 0.0, 1.0, -2.0, 1. }); + auto x = NDArrayFactory::create('c', {1, 5, 5}, {1., 0., 0., 0., 0., 2., 1., 0., 0., 0., 30., 2., 1., 0., 0., 4., 3., 2., 1., 0., 5., 4., 3., 2., 1. }); + auto exp = NDArrayFactory::create('c', {1, 5, 5}, {1.0, 0.0, 0.0, 0.0, 0., -2.0, 1.0, 0., 0., 0., -26.0, -2.0, 1, 0, 0., 54.0, 1.0, -2.0, 1, 0., -27.0, 0.0, 1.0, -2.0, 1. 
}); nd4j::ops::matrix_inverse op; auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); @@ -1724,19 +1724,19 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_2) { TEST_F(DeclarableOpsTests6, MatrixInverse_03) { auto x = NDArrayFactory::create('c', {5, 5}, { - 4., 0., 0., 0., 0., - 4., 2., 0., 0., 0., - 30., 2., 1., 0., 0., - 8., 6., 4., 2., 0., - 15., 12., 9., 6., 3., + 4.f, 0.f, 0.f, 0.f, 0.f, + 4.f, 2.f, 0.f, 0.f, 0.f, + 30.f, 2.f, 1.f, 0.f, 0.f, + 8.f, 6.f, 4.f, 2.f, 0.f, + 15.f, 12.f, 9.f, 6.f, 3.f, }); auto exp = NDArrayFactory::create('c', {5, 5}, { - 0.25, 0.0, 0.0, 0.0, 0.0, - -0.50, 0.5, 0.0, 0.0, 0.0, - -6.50, -1.0, 1.0, 0.0, 0.0, - 13.50, 0.5, -2.0, 0.5, 0.0, - -6.75, 0.0, 1.0, -1.0, 0.33333333 + 0.25f, 0.0f, 0.0f, 0.0f, 0.0f, + -0.50f, 0.5f, 0.0f, 0.0f, 0.0f, + -6.50f, -1.0f, 1.0f, 0.0f, 0.0f, + 13.50f, 0.5f, -2.0f, 0.5f, 0.0f, + -6.75f, 0.0f, 1.0f, -1.0f, 0.33333333f }); nd4j::ops::matrix_inverse op; @@ -1758,19 +1758,19 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_03) { TEST_F(DeclarableOpsTests6, MatrixInverse_3) { auto x = NDArrayFactory::create('c', {5, 5}, { - 4., 0., 0., 0., 0., - 4., 2., 0., 0., 0., - 30., 2., 1., 0., 0., - 8., 6., 4., 2., 0., - 15., 12., 9., 6., 3., + 4.f, 0.f, 0.f, 0.f, 0.f, + 4.f, 2.f, 0.f, 0.f, 0.f, + 30.f, 2.f, 1.f, 0.f, 0.f, + 8.f, 6.f, 4.f, 2.f, 0.f, + 15.f, 12.f, 9.f, 6.f, 3.f, }); auto exp = NDArrayFactory::create('c', {5, 5}, { - 0.25, 0.0, 0.0, 0.0, 0.0, - -0.50, 0.5, 0.0, 0.0, 0.0, - -6.50, -1.0, 1.0, 0.0, 0.0, - 13.50, 0.5, -2.0, 0.5, 0.0, - -6.75, 0.0, 1.0, -1.0, 0.33333333 + 0.25f, 0.0f, 0.0f, 0.0f, 0.0f, + -0.50f, 0.5f, 0.0f, 0.0f, 0.0f, + -6.50f, -1.0f, 1.0f, 0.0f, 0.0f, + 13.50f, 0.5f, -2.0f, 0.5f, 0.0f, + -6.75f, 0.0f, 1.0f, -1.0f, 0.33333333f }); nd4j::ops::matrix_inverse op; @@ -1792,19 +1792,19 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_3) { TEST_F(DeclarableOpsTests6, MatrixInverse_4) { auto x = NDArrayFactory::create('c', {5, 5}, { - 1., 2., 30., 4., 5., - 0., 1., 2., 3., 4., - 0., 0., 1., 2., 3., - 0., 0., 0., 1., 2., - 0., 0., 0., 0., 1. + 1.f, 2.f, 30.f, 4.f, 5.f, + 0.f, 1.f, 2.f, 3.f, 4.f, + 0.f, 0.f, 1.f, 2.f, 3.f, + 0.f, 0.f, 0.f, 1.f, 2.f, + 0.f, 0.f, 0.f, 0.f, 1.f }); auto exp = NDArrayFactory::create('c', {5, 5}, { - 1.0, -2.0, -26.0, 54.0, -27.0, - 0.0, 1.0, -2.0, 1.0, 0.0, - 0.0, 0.0, 1.0, -2.0, 1.0, - 0.0, 0.0, 0.0, 1.0, -2.0, - 0.0, 0.0, 0.0, 0.0, 1.0 + 1.0f, -2.0f, -26.0f, 54.0f, -27.0f, + 0.0f, 1.0f, -2.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 1.0f, -2.0f, 1.0f, + 0.0f, 0.0f, 0.0f, 1.0f, -2.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 1.0f }); nd4j::ops::matrix_inverse op; @@ -1826,19 +1826,19 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_4) { TEST_F(DeclarableOpsTests6, MatrixInverse_04) { auto x = NDArrayFactory::create('c', {5, 5}, { - 1., 2., 30., 4., 5., - 0., 1., 2., 3., 4., - 0., 0., 1., 2., 3., - 0., 0., 0., 1., 2., - 0., 0., 0., 0., 1. 
+        1.f, 2.f, 30.f, 4.f, 5.f,
+        0.f, 1.f, 2.f, 3.f, 4.f,
+        0.f, 0.f, 1.f, 2.f, 3.f,
+        0.f, 0.f, 0.f, 1.f, 2.f,
+        0.f, 0.f, 0.f, 0.f, 1.f
     });
     auto exp = NDArrayFactory::create('c', {5, 5}, {
-        1.0, -2.0, -26.0, 54.0, -27.0,
-        0.0, 1.0, -2.0, 1.0, 0.0,
-        0.0, 0.0, 1.0, -2.0, 1.0,
-        0.0, 0.0, 0.0, 1.0, -2.0,
-        0.0, 0.0, 0.0, 0.0, 1.0
+        1.0f, -2.0f, -26.0f, 54.0f, -27.0f,
+        0.0f, 1.0f, -2.0f, 1.0f, 0.0f,
+        0.0f, 0.0f, 1.0f, -2.0f, 1.0f,
+        0.0f, 0.0f, 0.0f, 1.0f, -2.0f,
+        0.0f, 0.0f, 0.0f, 0.0f, 1.0f
     });
 
     nd4j::ops::matrix_inverse op;
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp
index a79633c03..f232411b2 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp
@@ -1097,9 +1097,9 @@ TEST_F(DeclarableOpsTests7, TestSegmentMin_01) {
 }
 ////////////////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests7, TestSegmentMin_02) {
-    auto x = NDArrayFactory::create({1.8, -2.5,4., -9., 2.1, 2.4,-3.,-9., 2.1, 2.1,0.7, 0.1, 3., -4.2, 2.2, 1.});
+    auto x = NDArrayFactory::create({1.8f, -2.5f, 4.f, -9.f, 2.1f, 2.4f, -3.f, -9.f, 2.1f, 2.1f,0.7f, 0.1f, 3.f, -4.2f, 2.2f, 1.f});
     auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4});
-    auto exp = NDArrayFactory::create({-2.5, -9, -3., -9, -4.2});
+    auto exp = NDArrayFactory::create({-2.5f, -9.f, -3.f, -9.f, -4.2f});
 
     nd4j::ops::segment_min op;
@@ -1432,7 +1432,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMean_02) {
 TEST_F(DeclarableOpsTests7, TestSegmentMean_021) {
     auto x = NDArrayFactory::create('c', {6, 3});//, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.});
     auto idx = NDArrayFactory::create({0, 0, 1, 1, 2,2});
-    auto exp = NDArrayFactory::create('c', {3, 3}, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5});
+    auto exp = NDArrayFactory::create('c', {3, 3}, { 2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f});
 
     nd4j::ops::segment_mean op;
     x.linspace(1.);
@@ -1448,7 +1448,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMean_022) {
     auto x = NDArrayFactory::create('c', {6, 3});//, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.});
     auto idx = NDArrayFactory::create({0, 0, 1, 1, 2,2});
     auto z = NDArrayFactory::create('c', {3, 3}); //, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5});
-    auto exp = NDArrayFactory::create('c', {3, 3}, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5});
+    auto exp = NDArrayFactory::create('c', {3, 3}, { 2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f});
 
     nd4j::ops::segment_mean op;
     x.linspace(1.);
@@ -3897,9 +3897,9 @@ TEST_F(DeclarableOpsTests7, rectifiedtanh_test2) {
 
 TEST_F(DeclarableOpsTests7, RealDiv_1) {
-    NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4});
-    NDArray y = NDArrayFactory::create('c', {1, 2}, {1,2});
-    NDArray e = NDArrayFactory::create('c', {1, 2, 2}, {2, 1, 4, 2});
+    NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f});
+    NDArray y = NDArrayFactory::create('c', {1, 2}, {1.f,2.f});
+    NDArray e = NDArrayFactory::create('c', {1, 2, 2}, {2.f, 1.f, 4.f, 2.f});
 
     nd4j::ops::realdiv op;
     auto result = op.execute({&x, &y}, {}, {});
@@ -3917,11 +3917,11 @@ TEST_F(DeclarableOpsTests7, RealDiv_1) {
 ////////////////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests7, RealDiv_BP_1) {
-    NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4});
-    NDArray y = NDArrayFactory::create('c', {1, 2}, {1,2});
-    NDArray e0 = NDArrayFactory::create('c', {1, 2, 1}, {2, 5});
-    NDArray e1 = NDArrayFactory::create('c', {1, 2}, {-14, -5});
-    NDArray eps = NDArrayFactory::create('c', {1, 2, 2}, {1, 2, 3, 4});
+    NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f});
+    NDArray y = NDArrayFactory::create('c', {1, 2}, {1.f, 2.f});
+    NDArray e0 = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 5.f});
+    NDArray e1 = NDArrayFactory::create('c', {1, 2}, {-14.f, -5.f});
+    NDArray eps = NDArrayFactory::create('c', {1, 2, 2}, {1.f, 2.f, 3.f, 4.f});
 
     nd4j::ops::realdiv_bp op;
     auto result = op.execute({&x, &y, &eps}, {}, {});
@@ -3944,7 +3944,7 @@ TEST_F(DeclarableOpsTests7, RealDiv_BP_1) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests7, ShapesOf_1) {
-    NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4});
+    NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f});
 //    NDArray y = NDArrayFactory::create('c', {1, 2}, {1,2});
     NDArray e = NDArrayFactory::create({1, 2, 1});
@@ -3964,8 +3964,8 @@ TEST_F(DeclarableOpsTests7, ShapesOf_2) {
-    NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4});
-    NDArray y = NDArrayFactory::create('c', {1, 2}, {1,2});
+    NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f});
+    NDArray y = NDArrayFactory::create('c', {1, 2}, {1.f, 2.f});
     NDArray e0 = NDArrayFactory::create({1, 2, 1});
     NDArray e1 = NDArrayFactory::create({1, 2});
@@ -3987,8 +3987,8 @@ TEST_F(DeclarableOpsTests7, Size_1) {
-    NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4});
-    NDArray y = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
+    NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f});
+    NDArray y = NDArrayFactory::create('c', {5, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 7.f, 9.f, 10.f, 10.f, 11.f});
     NDArray e = NDArrayFactory::create(2);
 
     nd4j::ops::size op;
@@ -4006,8 +4006,8 @@ TEST_F(DeclarableOpsTests7, Size_2) {
-    NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4});
-    NDArray y = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
+    NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4});
+    NDArray y = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
     NDArray e = NDArrayFactory::create(10);
 
     nd4j::ops::size op;
@@ -4025,8 +4025,8 @@ TEST_F(DeclarableOpsTests7, Softplus_1) {
-    NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
-    NDArray e = NDArrayFactory::create('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016});
+    NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
+    NDArray e = NDArrayFactory::create('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016});
 
     nd4j::ops::softplus op;
     auto result = op.execute({&x}, {}, {});
@@ -4065,8 +4065,8 @@ TEST_F(DeclarableOpsTests7, Softplus_BP_1) {
 
 TEST_F(DeclarableOpsTests7, Softsign_1) {
-    NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
-    NDArray e = NDArrayFactory::create('c', {5, 2}, {0.5, 0.6666667, 0.75, 0.8, 0.8333333, 0.875, 0.9, 0.90909094, 0.90909094, 0.9166667});
+    NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
+    NDArray e = NDArrayFactory::create('c', {5, 2}, {0.5, 0.6666667, 0.75, 0.8, 0.8333333, 0.875, 0.9, 0.90909094, 0.90909094, 0.9166667});
 
     nd4j::ops::softsign op;
     auto result = op.execute({&x}, {}, {});
@@ -4213,7 +4213,7 @@ TEST_F(DeclarableOpsTests7, TypesConversion_test1) {
     NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
     NDArray expI = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
     NDArray expL = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
-    NDArray expF = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
+    NDArray expF = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f});
     NDArray expF16 = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f});
 
     nd4j::ops::to_int32 op32;
@@ -4239,7 +4239,7 @@ TEST_F(DeclarableOpsTests7, TypesConversion_test1) {
 ////////////////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests7, TypesConversion_test2) {
     NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
-    NDArray expF = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
+    NDArray expF = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f});
     NDArray expH = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f});
 
     nd4j::ops::to_float32 op32;
@@ -4291,7 +4291,7 @@ TEST_F(DeclarableOpsTests7, TypesConversion_test3) {
 ////////////////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests7, TypesConversion_test4) {
     NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
-    NDArray exp32 = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
+    NDArray exp32 = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f});
     NDArray exp64 = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
 
     nd4j::ops::to_float32 op32;
@@ -4968,8 +4968,8 @@ TEST_F(DeclarableOpsTests7, Test_Matmul_Once_Again) {
 }
 
 TYPED_TEST(TypedDeclarableOpsTests7, Test_Pnorm_Once_Again) {
-    auto input = NDArrayFactory::create('c', {1, 1, 5, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0});
-    auto exp = NDArrayFactory::create('c', {1, 1, 5, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0});
+    auto input = NDArrayFactory::create('c', {1, 1, 5, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f});
+    auto exp = NDArrayFactory::create('c', {1, 1, 5, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f});
 
     nd4j::ops::pnormpool2d op;
     auto result = op.execute({&input}, {}, {1,1, 1,1, 0,0, 1,1,1, 3, 0});
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp
index 9f98ab3a1..ebe1f8e18 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp
@@ -3614,7 +3614,7 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_1) {
eps.linspace(1); // auto exp = NDArrayFactory::create('c', {3,3,5,5}, { - 0.238337, 0.309664, 0.334077, 0.376534, 0.342926, 0.370734, 0.362017, 0.354182, 0.379140, 0.376275, 0.380027, 0.368347, 0.356401, 0.378316, 0.381315, 0.382465, 0.370592, 0.357055, 0.377670, 0.382950, 0.383445, 0.371718, 0.357332, 0.377217, 0.383677, 0.383933, 0.372391, 0.357475, 0.376891, 0.384062, 0.384212, 0.372837, 0.357557, 0.376646, 0.384290, 0.384385, 0.373153, 0.357610, 0.376457, 0.384436, 0.384500, 0.373389, 0.357645, 0.376306, 0.384536, 0.384581, 0.373572, 0.357670, 0.376184, 0.384606, 0.384639, 0.373718, 0.357688, 0.376082, 0.384658, 0.384683, 0.373837, 0.357702, 0.375996, 0.384698, 0.384717, 0.373935, 0.357712, 0.375923, 0.384728, 0.384743, 0.374019, 0.357721, 0.375860, 0.384752, 0.384764, 0.374090, 0.357727, 0.375804, 0.384771, 0.384781, 0.374152, 0.357733, 0.375756, 0.384787, 0.384795, 0.374205, 0.357737, 0.375713, 0.384800, 0.384807, 0.374253, 0.357741, 0.375674, 0.384811, 0.384817, 0.374295, 0.357744, 0.375640, 0.384820, 0.384825, 0.374333, 0.357747, 0.375609, 0.384828, 0.384832, 0.374366, 0.357749, 0.375581, 0.384835, 0.384839, 0.374397, 0.357751, 0.375555, 0.384841, 0.384844, 0.374425, 0.357753, 0.375531, 0.384846, 0.384849, 0.374450, 0.357754, 0.375510, 0.384850, 0.384853, 0.374473, 0.357756, 0.375490, 0.384854, 0.384856, 0.374494, 0.357757, 0.375471, 0.384858, 0.384860, 0.374514, 0.357758, 0.375454, 0.384861, 0.384863, 0.374532, 0.357759, 0.375438, 0.384864, 0.384865, 0.374549, 0.357760, 0.375423, 0.384866, 0.384868, 0.374565, 0.357760, 0.375410, 0.384868, 0.384870, 0.374579, 0.357761, 0.375397, 0.384870, 0.384872, 0.374593, 0.357762, 0.375384, 0.384872, 0.384873, 0.374606, 0.357762, 0.375373, 0.384874, 0.384875, 0.374618, 0.357763, 0.375362, 0.384875, 0.384876, 0.374629, 0.357763, 0.375352, 0.384877, 0.384878, 0.374640, 0.357764, 0.375342, 0.384878, 0.384879, 0.374650, 0.357764, 0.375333, 0.384879, 0.384880, 0.374660, 0.357764, 0.375325, 0.384880, 0.384881, 0.374669, 0.357765, 0.375316, 0.384881, 0.384882, 0.374677, 0.357765, 0.375309, 0.384882, 0.384883, 0.374685, 0.357765, 0.375301, 0.384883, 0.384884, 0.374693, 0.357765, 0.375294, 0.384884, 0.384884, 0.374700, 0.357766, 0.375287, 0.384885, 0.384885, 0.374707, 0.357766, 0.375281, 0.384885, 0.384886, 0.374714, 0.357766, 0.375275, 0.384886} + 0.238337f, 0.309664f, 0.334077f, 0.376534f, 0.342926f, 0.370734f, 0.362017f, 0.354182f, 0.379140f, 0.376275f, 0.380027f, 0.368347f, 0.356401f, 0.378316f, 0.381315f, 0.382465f, 0.370592f, 0.357055f, 0.377670f, 0.382950f, 0.383445f, 0.371718f, 0.357332f, 0.377217f, 0.383677f, 0.383933f, 0.372391f, 0.357475f, 0.376891f, 0.384062f, 0.384212f, 0.372837f, 0.357557f, 0.376646f, 0.384290f, 0.384385f, 0.373153f, 0.357610f, 0.376457f, 0.384436f, 0.384500f, 0.373389f, 0.357645f, 0.376306f, 0.384536f, 0.384581f, 0.373572f, 0.357670f, 0.376184f, 0.384606f, 0.384639f, 0.373718f, 0.357688f, 0.376082f, 0.384658f, 0.384683f, 0.373837f, 0.357702f, 0.375996f, 0.384698f, 0.384717f, 0.373935f, 0.357712f, 0.375923f, 0.384728f, 0.384743f, 0.374019f, 0.357721f, 0.375860f, 0.384752f, 0.384764f, 0.374090f, 0.357727f, 0.375804f, 0.384771f, 0.384781f, 0.374152f, 0.357733f, 0.375756f, 0.384787f, 0.384795f, 0.374205f, 0.357737f, 0.375713f, 0.384800f, 0.384807f, 0.374253f, 0.357741f, 0.375674f, 0.384811f, 0.384817f, 0.374295f, 0.357744f, 0.375640f, 0.384820f, 0.384825f, 0.374333f, 0.357747f, 0.375609f, 0.384828f, 0.384832f, 0.374366f, 0.357749f, 0.375581f, 0.384835f, 0.384839f, 0.374397f, 0.357751f, 0.375555f, 0.384841f, 0.384844f, 
0.374425f, 0.357753f, 0.375531f, 0.384846f, 0.384849f, 0.374450f, 0.357754f, 0.375510f, 0.384850f, 0.384853f, 0.374473f, 0.357756f, 0.375490f, 0.384854f, 0.384856f, 0.374494f, 0.357757f, 0.375471f, 0.384858f, 0.384860f, 0.374514f, 0.357758f, 0.375454f, 0.384861f, 0.384863f, 0.374532f, 0.357759f, 0.375438f, 0.384864f, 0.384865f, 0.374549f, 0.357760f, 0.375423f, 0.384866f, 0.384868f, 0.374565f, 0.357760f, 0.375410f, 0.384868f, 0.384870f, 0.374579f, 0.357761f, 0.375397f, 0.384870f, 0.384872f, 0.374593f, 0.357762f, 0.375384f, 0.384872f, 0.384873f, 0.374606f, 0.357762f, 0.375373f, 0.384874f, 0.384875f, 0.374618f, 0.357763f, 0.375362f, 0.384875f, 0.384876f, 0.374629f, 0.357763f, 0.375352f, 0.384877f, 0.384878f, 0.374640f, 0.357764f, 0.375342f, 0.384878f, 0.384879f, 0.374650f, 0.357764f, 0.375333f, 0.384879f, 0.384880f, 0.374660f, 0.357764f, 0.375325f, 0.384880f, 0.384881f, 0.374669f, 0.357765f, 0.375316f, 0.384881f, 0.384882f, 0.374677f, 0.357765f, 0.375309f, 0.384882f, 0.384883f, 0.374685f, 0.357765f, 0.375301f, 0.384883f, 0.384884f, 0.374693f, 0.357765f, 0.375294f, 0.384884f, 0.384884f, 0.374700f, 0.357766f, 0.375287f, 0.384885f, 0.384885f, 0.374707f, 0.357766f, 0.375281f, 0.384885f, 0.384886f, 0.374714f, 0.357766f, 0.375275f, 0.384886f} ); /// nd4j::ops::lrn_bp op; @@ -3636,65 +3636,65 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_2) { auto x = NDArrayFactory::create( 'c', {3, 3, 5, 5}); x.linspace(1); - auto eps = NDArrayFactory::create('c', {3, 3, 5, 5}, { 0.2581989 ,0.3592106 , 0.40089184, 0.53935987, 0.70014, - 0.4898979 ,0.46056613, 0.43971977, 0.5240002 , 0.6375767, - 0.5274096 ,0.47771242, 0.4443308 , 0.5163977 , 0.61701745, - 0.5424508 ,0.48452914, 0.44570294, 0.5123918 , 0.6068971, - 0.5505386 ,0.4881662 , 0.4462865 , 0.5099462 , 0.60088515, + auto eps = NDArrayFactory::create('c', {3, 3, 5, 5}, { 0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, + 0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f, + 0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f, + 0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f, + 0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f, - 0.5555859 , 0.49042296, 0.44658744, 0.5083028 , 0.59690416, - 0.55903524, 0.4919585 , 0.44676256, 0.5071239 , 0.59407425, - 0.5615412 , 0.49307042, 0.44687328, 0.50623745, 0.5919596 , - 0.56344414, 0.49391258, 0.4469477 , 0.5055468 , 0.59031945, - 0.56493837, 0.49457246, 0.4470002 , 0.5049936 , 0.5890103 , + 0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f, + 0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f, + 0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f, + 0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f, + 0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f, - 0.56614274, 0.49510333, 0.44703856, 0.50454074, 0.5879411 , - 0.567134 , 0.49553978, 0.4470674 , 0.504163 , 0.5870515 , - 0.5679643 , 0.4959048 , 0.44708967, 0.5038433 , 0.5862998 , - 0.56866974, 0.4962146 , 0.44710726, 0.5035692 , 0.58565617, - 0.56927663, 0.49648085, 0.4471213 , 0.5033315 , 0.5850988 , + 0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f, + 0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f, + 0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f, + 0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f, + 0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f, - 0.56980413, 0.49671215, 0.44713274, 0.50312346, 0.58461165, - 0.57026696, 0.49691492, 0.4471422 , 0.50293994, 0.58418214, - 0.5706764 , 
0.49709415, 0.44715008, 0.5027767 , 0.5838005 , - 0.571041 , 0.4972537 , 0.44715673, 0.50263065, 0.58345926, - 0.57136786, 0.49739665, 0.44716236, 0.5024992 , 0.58315235, + 0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f, + 0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f, + 0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f, + 0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f, + 0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f, - 0.5716625 , 0.49752548, 0.4471672 , 0.5023803, 0.5828747 , - 0.5719295 , 0.49764213, 0.44717142, 0.5022721, 0.5826225 , - 0.57217246, 0.49774826, 0.44717506, 0.5021734, 0.58239233, - 0.5723947 , 0.4978453 , 0.44717824, 0.5020829, 0.58218133, - 0.57259864, 0.49793428, 0.44718108, 0.5019997, 0.5819874 , + 0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f, + 0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f, + 0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f, + 0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f, + 0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f, - 0.5727864 , 0.49801624, 0.44718358, 0.5019227, 0.5818083 , - 0.57296 , 0.49809194, 0.44718578, 0.5018515, 0.5816426 , - 0.5731208 , 0.49816203, 0.44718775, 0.5017854, 0.58148885, - 0.57327026, 0.49822718, 0.4471895 , 0.5017239, 0.5813457 , - 0.57340944, 0.49828786, 0.44719115, 0.5016664, 0.581212 , + 0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f, + 0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f, + 0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f, + 0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f, + 0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f, - 0.57353944, 0.4983446 , 0.44719255, 0.50161266, 0.58108705, - 0.5736612 , 0.49839762, 0.4471939 , 0.50156236, 0.5809699 , - 0.5737754 , 0.4984474 , 0.44719502, 0.501515 , 0.58085984, - 0.5738828 , 0.49849418, 0.4471962 , 0.50147045, 0.5807564 , - 0.5739839 , 0.49853817, 0.44719717, 0.5014284 , 0.5806588 , + 0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f, + 0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f, + 0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f, + 0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f, + 0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f, - 0.5740793 , 0.49857965, 0.4471981 , 0.5013887 , 0.5805666 , - 0.5741694 , 0.49861887, 0.44719887, 0.50135124, 0.58047944, - 0.57425463, 0.49865603, 0.44719967, 0.5013157 , 0.5803969 , - 0.5743354 , 0.4986912 , 0.44720036, 0.5012819 , 0.5803186 , - 0.57441217, 0.49872455, 0.44720104, 0.5012499 , 0.58024424, + 0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f, + 0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f, + 0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f, + 0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f, + 0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f, - 0.57448506, 0.4987563 , 0.44720164, 0.5012194 , 0.58017343, - 0.57455444, 0.4987865 , 0.4472022 , 0.5011904 , 0.5801061, - 0.57462054, 0.49881527, 0.44720277, 0.5011627 , 0.5800419, - 0.57468355, 0.49884263, 0.44720328, 0.50113624, 0.5799805, - 0.57474375, 0.49886885, 0.44720373, 0.50111103, 0.5799219 }); + 0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f, + 0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f, + 0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f, + 0.57468355f, 0.49884263f, 0.44720328f, 
0.50113624f, 0.5799805f, + 0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f }); // auto exp = NDArrayFactory::create('c', {3,3,5,5}, { - 0.061538, 0.055617, 0.044643, 0.050772, 0.048019, 0.030270, 0.023819, 0.019468, 0.022074, 0.023990, 0.018221, 0.014664, 0.012182, 0.013954, 0.015685, 0.012967, 0.010563, 0.008841, 0.010185, 0.011621, 0.010052, 0.008248, 0.006934, 0.008015, 0.009222, 0.008204, 0.006764, 0.005702, 0.006606, 0.007642, 0.006929, 0.005732, 0.004841, 0.005618, 0.006523, 0.005996, 0.004973, 0.004205, 0.004887, 0.005689, 0.005284, 0.004391, 0.003717, 0.004324, 0.005044, 0.004723, 0.003931, 0.003331, 0.003877, 0.004531, 0.004270, 0.003558, 0.003017, 0.003514, 0.004112, 0.003896, 0.003250, 0.002757, 0.003213, 0.003764, 0.003582, 0.002991, 0.002539, 0.002959, 0.003470, 0.003315, 0.002770, 0.002352, 0.002743, 0.003219, 0.003085, 0.002580, 0.002191, 0.002556, 0.003002, 0.002885, 0.002414, 0.002051, 0.002393, 0.002812, 0.002709, 0.002268, 0.001927, 0.002250, 0.002645, 0.002553, 0.002138, 0.001818, 0.002122, 0.002496, 0.002415, 0.002023, 0.001720, 0.002009, 0.002363, 0.002290, 0.001920, 0.001632, 0.001906, 0.002244, 0.002178, 0.001826, 0.001553, 0.001814, 0.002136, 0.002076, 0.001741, 0.001481, 0.001731, 0.002038, 0.001984, 0.001664, 0.001416, 0.001654, 0.001949, 0.001899, 0.001593, 0.001356, 0.001584, 0.001867, 0.001821, 0.001528, 0.001301, 0.001520, 0.001792, 0.001750, 0.001469, 0.001250, 0.001461, 0.001722, 0.001683, 0.001413, 0.001203, 0.001406, 0.001658, 0.001622, 0.001362, 0.001159, 0.001355, 0.001599, 0.001565, 0.001314, 0.001119, 0.001308, 0.001543, 0.001512, 0.001270, 0.001081, 0.001264, 0.001491, 0.001462, 0.001228, 0.001046, 0.001223, 0.001443, 0.001415, 0.001189, 0.001013, 0.001184, 0.001397, 0.001372, 0.001153, 0.000982, 0.001148, 0.001355, 0.001331, 0.001118, 0.000952, 0.001114, 0.001315, 0.001292, 0.001086, 0.000925, 0.001082, 0.001277, 0.001255, 0.001055, 0.000899, 0.001051, 0.001241, 0.001221, 0.001026, 0.000874, 0.001023, 0.001208, 0.001188, 0.000999, 0.000851, 0.000996, 0.001176, 0.001157, 0.000973, 0.000829, 0.000970, 0.001145, 0.001128, 0.000949, 0.000808, 0.000945, 0.001117, 0.001100, 0.000925, 0.000788, 0.000922, 0.001089, 0.001073, 0.000903, 0.000769, 0.000900, 0.001063, 0.001048, 0.000882, 0.000751, 0.000879, 0.001038, 0.001024, 0.000861, 0.000734, 0.000859, 0.001015, 0.001001, 0.000842, 0.000717, 0.000840, 0.000992} - // 0.009859, 0.013075, 0.013874, 0.017893, 0.022344, 0.014551, 0.012859, 0.011511, 0.013311, 0.015834, 0.012025, 0.010047, 0.008601, 0.009920, 0.011885, 0.009505, 0.007636, 0.006299, 0.007413, 0.009095, 0.007446, 0.005743, 0.004540, 0.005533, 0.007033, 0.005821, 0.004282, 0.003209, 0.004123, 0.005491, 0.004577, 0.003198, 0.002247, 0.003097, 0.004355, 0.003652, 0.002412, 0.001565, 0.002357, 0.003517, 0.002965, 0.001844, 0.001084, 0.001821, 0.002893, 0.002451, 0.001430, 0.000741, 0.001428, 0.002422, -0.111434, -0.105946, -0.100351, -0.091868, -0.083323, -0.078775, -0.076222, -0.073291, -0.067635, -0.061692, -0.058943, -0.057832, -0.056263, -0.052198, -0.047768, -0.046002, -0.045655, -0.044839, -0.041748, -0.038271, -0.037084, -0.037161, -0.036786, -0.034331, -0.031495, 0.000077, -0.000673, -0.001181, -0.000667, 0.000079, -0.000089, -0.000802, -0.001285, -0.000793, -0.000079, -0.000228, -0.000908, -0.001368, -0.000896, -0.000212, -0.000345, -0.000996, -0.001434, -0.000981, -0.000325, -0.000444, -0.001067, -0.001487, -0.001051, -0.000421, 0.000697, 0.000188, -0.000152, 0.000210, 0.000731, 0.000650, 0.000165, -0.000161, 0.000185, 
0.000683, 0.000610, 0.000145, -0.000168, 0.000164, 0.000641, 0.000574, 0.000128, -0.000172, 0.000146, 0.000604, 0.000542, 0.000113, -0.000175, 0.000131, 0.000571, -0.009490, -0.010070, -0.010409, -0.009734, -0.008834, -0.008785, -0.009351, -0.009687, -0.009054, -0.008207, -0.008167, -0.008718, -0.009050, -0.008455, -0.007654, -0.007622, -0.008159, -0.008485, -0.007924, -0.007164, -0.007138, -0.007661, -0.007981, -0.007450, -0.006728, -0.000901, -0.001327, -0.001614, -0.001310, -0.000869, -0.000913, -0.001328, -0.001607, -0.001310, -0.000882, -0.000922, -0.001326, -0.001598, -0.001309, -0.000892, -0.000930, -0.001323, -0.001588, -0.001306, -0.000900, -0.000936, -0.001319, -0.001577, -0.001302, -0.000906, 0.000339, 0.000038, -0.000164, 0.000048, 0.000355, 0.000328, 0.000035, -0.000162, 0.000045, 0.000343, 0.000318, 0.000033, -0.000159, 0.000041, 0.000332, 0.000308, 0.000030, -0.000157, 0.000039, 0.000322, 0.000299, 0.000028, -0.000155, 0.000036, 0.000312, -0.004085, -0.004479, -0.004733, -0.004396, -0.003925, -0.003925, -0.004309, -0.004558, -0.004232, -0.003775, -0.003776, -0.004151, -0.004395, -0.004079, -0.003636, -0.003637, -0.004004, -0.004242, -0.003936, -0.003505, -0.003507, -0.003866, -0.004100, -0.003802, -0.003383} + 0.061538f, 0.055617f, 0.044643f, 0.050772f, 0.048019f, 0.030270f, 0.023819f, 0.019468f, 0.022074f, 0.023990f, 0.018221f, 0.014664f, 0.012182f, 0.013954f, 0.015685f, 0.012967f, 0.010563f, 0.008841f, 0.010185f, 0.011621f, 0.010052f, 0.008248f, 0.006934f, 0.008015f, 0.009222f, 0.008204f, 0.006764f, 0.005702f, 0.006606f, 0.007642f, 0.006929f, 0.005732f, 0.004841f, 0.005618f, 0.006523f, 0.005996f, 0.004973f, 0.004205f, 0.004887f, 0.005689f, 0.005284f, 0.004391f, 0.003717f, 0.004324f, 0.005044f, 0.004723f, 0.003931f, 0.003331f, 0.003877f, 0.004531f, 0.004270f, 0.003558f, 0.003017f, 0.003514f, 0.004112f, 0.003896f, 0.003250f, 0.002757f, 0.003213f, 0.003764f, 0.003582f, 0.002991f, 0.002539f, 0.002959f, 0.003470f, 0.003315f, 0.002770f, 0.002352f, 0.002743f, 0.003219f, 0.003085f, 0.002580f, 0.002191f, 0.002556f, 0.003002f, 0.002885f, 0.002414f, 0.002051f, 0.002393f, 0.002812f, 0.002709f, 0.002268f, 0.001927f, 0.002250f, 0.002645f, 0.002553f, 0.002138f, 0.001818f, 0.002122f, 0.002496f, 0.002415f, 0.002023f, 0.001720f, 0.002009f, 0.002363f, 0.002290f, 0.001920f, 0.001632f, 0.001906f, 0.002244f, 0.002178f, 0.001826f, 0.001553f, 0.001814f, 0.002136f, 0.002076f, 0.001741f, 0.001481f, 0.001731f, 0.002038f, 0.001984f, 0.001664f, 0.001416f, 0.001654f, 0.001949f, 0.001899f, 0.001593f, 0.001356f, 0.001584f, 0.001867f, 0.001821f, 0.001528f, 0.001301f, 0.001520f, 0.001792f, 0.001750f, 0.001469f, 0.001250f, 0.001461f, 0.001722f, 0.001683f, 0.001413f, 0.001203f, 0.001406f, 0.001658f, 0.001622f, 0.001362f, 0.001159f, 0.001355f, 0.001599f, 0.001565f, 0.001314f, 0.001119f, 0.001308f, 0.001543f, 0.001512f, 0.001270f, 0.001081f, 0.001264f, 0.001491f, 0.001462f, 0.001228f, 0.001046f, 0.001223f, 0.001443f, 0.001415f, 0.001189f, 0.001013f, 0.001184f, 0.001397f, 0.001372f, 0.001153f, 0.000982f, 0.001148f, 0.001355f, 0.001331f, 0.001118f, 0.000952f, 0.001114f, 0.001315f, 0.001292f, 0.001086f, 0.000925f, 0.001082f, 0.001277f, 0.001255f, 0.001055f, 0.000899f, 0.001051f, 0.001241f, 0.001221f, 0.001026f, 0.000874f, 0.001023f, 0.001208f, 0.001188f, 0.000999f, 0.000851f, 0.000996f, 0.001176f, 0.001157f, 0.000973f, 0.000829f, 0.000970f, 0.001145f, 0.001128f, 0.000949f, 0.000808f, 0.000945f, 0.001117f, 0.001100f, 0.000925f, 0.000788f, 0.000922f, 0.001089f, 0.001073f, 0.000903f, 0.000769f, 0.000900f, 0.001063f, 
0.001048f, 0.000882f, 0.000751f, 0.000879f, 0.001038f, 0.001024f, 0.000861f, 0.000734f, 0.000859f, 0.001015f, 0.001001f, 0.000842f, 0.000717f, 0.000840f, 0.000992f} + // 0.009859f, 0.013075f, 0.013874f, 0.017893f, 0.022344f, 0.014551f, 0.012859f, 0.011511f, 0.013311f, 0.015834f, 0.012025f, 0.010047f, 0.008601f, 0.009920f, 0.011885f, 0.009505f, 0.007636f, 0.006299f, 0.007413f, 0.009095f, 0.007446f, 0.005743f, 0.004540f, 0.005533f, 0.007033f, 0.005821f, 0.004282f, 0.003209f, 0.004123f, 0.005491f, 0.004577f, 0.003198f, 0.002247f, 0.003097f, 0.004355f, 0.003652f, 0.002412f, 0.001565f, 0.002357f, 0.003517f, 0.002965f, 0.001844f, 0.001084f, 0.001821f, 0.002893f, 0.002451f, 0.001430f, 0.000741f, 0.001428f, 0.002422f, -0.111434f, -0.105946f, -0.100351f, -0.091868f, -0.083323f, -0.078775f, -0.076222f, -0.073291f, -0.067635f, -0.061692f, -0.058943f, -0.057832f, -0.056263f, -0.052198f, -0.047768f, -0.046002f, -0.045655f, -0.044839f, -0.041748f, -0.038271f, -0.037084f, -0.037161f, -0.036786f, -0.034331f, -0.031495f, 0.000077f, -0.000673f, -0.001181f, -0.000667f, 0.000079f, -0.000089f, -0.000802f, -0.001285f, -0.000793f, -0.000079f, -0.000228f, -0.000908f, -0.001368f, -0.000896f, -0.000212f, -0.000345f, -0.000996f, -0.001434f, -0.000981f, -0.000325f, -0.000444f, -0.001067f, -0.001487f, -0.001051f, -0.000421f, 0.000697f, 0.000188f, -0.000152f, 0.000210f, 0.000731f, 0.000650f, 0.000165f, -0.000161f, 0.000185f, 0.000683f, 0.000610f, 0.000145f, -0.000168f, 0.000164f, 0.000641f, 0.000574f, 0.000128f, -0.000172f, 0.000146f, 0.000604f, 0.000542f, 0.000113f, -0.000175f, 0.000131f, 0.000571f, -0.009490f, -0.010070f, -0.010409f, -0.009734f, -0.008834f, -0.008785f, -0.009351f, -0.009687f, -0.009054f, -0.008207f, -0.008167f, -0.008718f, -0.009050f, -0.008455f, -0.007654f, -0.007622f, -0.008159f, -0.008485f, -0.007924f, -0.007164f, -0.007138f, -0.007661f, -0.007981f, -0.007450f, -0.006728f, -0.000901f, -0.001327f, -0.001614f, -0.001310f, -0.000869f, -0.000913f, -0.001328f, -0.001607f, -0.001310f, -0.000882f, -0.000922f, -0.001326f, -0.001598f, -0.001309f, -0.000892f, -0.000930f, -0.001323f, -0.001588f, -0.001306f, -0.000900f, -0.000936f, -0.001319f, -0.001577f, -0.001302f, -0.000906f, 0.000339f, 0.000038f, -0.000164f, 0.000048f, 0.000355f, 0.000328f, 0.000035f, -0.000162f, 0.000045f, 0.000343f, 0.000318f, 0.000033f, -0.000159f, 0.000041f, 0.000332f, 0.000308f, 0.000030f, -0.000157f, 0.000039f, 0.000322f, 0.000299f, 0.000028f, -0.000155f, 0.000036f, 0.000312f, -0.004085f, -0.004479f, -0.004733f, -0.004396f, -0.003925f, -0.003925f, -0.004309f, -0.004558f, -0.004232f, -0.003775f, -0.003776f, -0.004151f, -0.004395f, -0.004079f, -0.003636f, -0.003637f, -0.004004f, -0.004242f, -0.003936f, -0.003505f, -0.003507f, -0.003866f, -0.004100f, -0.003802f, -0.003383f} ); nd4j::ops::lrn_bp op; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index f88d6e930..dfbfc90a8 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -1903,13 +1903,13 @@ TEST_F(DeclarableOpsTests9, cumprod_2) { NDArray exp0 = exp(0, {0}); NDArray exp1 = exp(1, {0}); - exp0.p(0, 1.); - exp1.p(0, 1.); + exp0.p(0, 1.f); + exp1.p(0, 1.f); for (int i = 1; i < 1500; ++i) { const auto prev = exp0.e(i-1); - exp0.p(i, prev * x0.e(i)); - exp1.p(i, prev * x1.e(i)); + exp0.p(i, prev * x0.e(i)); + exp1.p(i, prev * x1.e(i)); } nd4j::ops::cumprod op; @@ -3331,8 +3331,8 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_2) { 
 ////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests9, Cholesky_Test_3) {
-    NDArray x = NDArrayFactory::create('c', {2, 3, 3}, {4, 12,-16, 12 ,37,-43, -16, -43, 98, 1, 1, 1, 1, 2, 2, 1, 2., 6});
-    NDArray exp = NDArrayFactory::create('c', {2, 3, 3}, {2., 0., 0., 6., 1., 0., -8., 5., 3., 1., 0., 0., 1., 1., 0,1., 1., 2.});
+    NDArray x = NDArrayFactory::create('c', {2, 3, 3}, {4.f, 12.f, -16.f, 12.f, 37.f, -43.f, -16.f, -43.f, 98.f, 1.f, 1.f, 1.f, 1.f, 2.f, 2.f, 1.f, 2.f, 6.f});
+    NDArray exp = NDArrayFactory::create('c', {2, 3, 3}, {2.f, 0.f, 0.f, 6.f, 1.f, 0.f, -8.f, 5.f, 3.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 1.f, 1.f, 2.f});
     nd4j::ops::cholesky op;
diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp
index a6ca56fd4..c89a989a9 100644
--- a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp
@@ -124,9 +124,9 @@ TEST_F(JavaInteropTests, TestShapeExposure3) {
 }
 TEST_F(JavaInteropTests, Test_Squeeze_1) {
-    auto x = NDArrayFactory::create('c', {1, 6}, {1, 2, 3, 4, 5, 6});
+    auto x = NDArrayFactory::create('c', {1, 6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
     auto z = NDArrayFactory::create('c', {6});
-    auto e = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6});
+    auto e = NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
     nd4j::ops::squeeze op;
@@ -683,8 +683,8 @@ TEST_F(JavaInteropTests, Test_Greater_1) {
 TEST_F(JavaInteropTests, Test_Greater_2) {
-    auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 1, 2});
-    auto y = NDArrayFactory::create('c', {2, 2}, {1, 2, 0, 0});
+    auto x = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 1.f, 2.f});
+    auto y = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 0.f, 0.f});
     auto o = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1});
     auto exp = NDArrayFactory::create('c', {2, 2}, {0, 0, 1, 1});
@@ -710,7 +710,7 @@ TEST_F(JavaInteropTests, Test_Boolean_Op_1) {
     nd4j::ops::is_non_decreasing op;
-    auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5});
+    auto x = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f});
     auto o = NDArrayFactory::create(false);
     auto exp = NDArrayFactory::create(1);
@@ -787,10 +787,10 @@ TEST_F(JavaInteropTests, Test_Inplace_Outputs_2) {
 }
 TEST_F(JavaInteropTests, Test_Inplace_Outputs_3) {
-    auto input = NDArrayFactory::create('c', {2, 3, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24});
     auto indices = NDArrayFactory::create('c', {1, 6}, {0,1, 2,2, 1,2});
-    auto output = NDArrayFactory::create('f', {2, 1, 6, 4});
-    auto e = NDArrayFactory::create('c', {2, 1, 6, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 9,10,11,12, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16, 17,18,19,20, 21,22,23,24, 21,22,23,24, 17,18,19,20, 21,22,23,24});
+    auto input = NDArrayFactory::create('c', {2, 3, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24});
+    auto output = NDArrayFactory::create('f', {2, 1, 6, 4});
+    auto e = NDArrayFactory::create('c', {2, 1, 6, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 9,10,11,12, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16, 17,18,19,20, 21,22,23,24, 21,22,23,24, 17,18,19,20, 21,22,23,24});
     nd4j::ops::gather op;
@@ -864,9 +864,9 @@ TEST_F(JavaInteropTests, Test_SimpleIf_Output) {
 TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_double) {
-    auto input = NDArrayFactory::create('c', {4, 10, 10, 3}, {9.37125111, 2.20166993, 2.91434479, 5.43639755, -2.10573769, 4.08528662, 5.86908436, -4.46203756, 2.21057916, 5.35849190, 0.01394637,
4.40566349, 7.07982206, -0.09633455, 2.42429352, 3.97301817, -1.89553940, 1.99690318, 6.33141708, 0.55401880, 1.70707977, 5.55204201, -0.03513752, 1.60011971, 2.62700319, -2.74582434, 3.06697464, 1.06277943, -1.16075921, -0.78095782, 9.72352791, -1.22686064, 1.99644792, 7.35571337, 1.40607321, 0.11390255, 9.53334427, 2.28303599, -1.66728830, 6.16678810, -0.04532295, -1.97708666, 9.74906158, 1.46223176, -1.46734393, 4.30761862, -1.23790228, 1.24823606, 6.13938427, -3.83689475, -1.19625473, 7.91535568, 6.05868721, -3.22946382, 8.81633949, -0.19967777, 0.66053957, 2.30919123, 0.74543846, -0.39347672, 11.11058044, 0.53720862, 1.52645731, 5.70012379, -1.15213466, 1.16451406, 7.00526333, 1.57362783, -2.44384766, 5.54213285, -1.98828590, -0.70483637, 7.88281822, -3.59875536, 0.80745387, 13.41578484, -1.55507684, -0.65855008, 9.32583523, -0.14544789, 0.73436141, 3.61176538, -1.71268058, -2.58490300, 9.09280205, -3.27405524, -2.04569697, 4.44761324, -0.62955856, -2.61917663, 8.04890442, 0.54579324, 0.85929775, 9.82259560, -1.93825579, 0.77703512, 4.67090321, -4.79267597, -2.38906908, 9.31265545, 0.96026313, -1.14109385, 11.54231834, -0.01417295, -0.39500344, 8.49191666, 0.55300158, 2.79490185, 6.92466164, 1.72254205, 2.82222271, 8.83112717, 2.95033407, 2.18054962, 6.73509789, -2.22272944, 0.51127720, -1.04563558, 2.15747333, -2.30959272, 9.55441570, 1.50396204, 1.77370787, 7.38146257, -1.79076433, 3.20961165, 7.18864202, 2.91217351, 0.43018937, 7.11078024, -1.17386127, -0.16817921, 6.12327290, -2.82205725, 3.30696845, 13.51291752, -1.30856836, -2.38332748, 11.09487438, -1.47190213, -0.53050828, 4.38285351, -5.07309771, 1.50714362, 5.72274446, -2.85825086, -0.89673209, 3.73791552, -0.67708802, -4.13149452, -0.00671843, -0.26566532, 0.32961160, 7.14501762, -1.41608179, -4.96590328, 12.26205540, -0.65158135, -0.88641000, 6.95777559, -0.79058206, -0.10260171, 7.87169170, 1.35921454, 1.11759663, 5.46187401, -2.57214499, 2.48484039, 4.04043484, -2.07137156, -1.42709637, 9.25487137, -0.12605135, -2.66949964, 2.89412403, 0.74451172, -2.96250391, 3.99258423, 0.27084303, 0.32213116, 5.42332172, -0.44414216, 1.70881832, 6.69346905, 0.53058422, -4.73146200, 4.22051668, 2.24834967, 0.66996074, 4.30173683, 0.11849818, -4.07520294, 8.27318478, -2.54398274, -2.86705542, 10.11775303, -0.99382895, 0.65881538, 7.93556786, -1.27934420, -1.69343162, 9.68042564, -1.02609646, -1.18189347, 5.75370646, -1.67888868, -4.48871994, 4.79537392, -0.79212248, -0.19855022, 6.15060997, -0.01081491, 3.64454579, 10.82562447, 1.58859253, -2.65847278, 8.60093212, -1.59196103, 0.07635692, 11.76175690, -1.17453325, 0.10122013, 6.86458445, -2.18891335, -2.74004745, 8.07066154, 0.71818852, -2.03035975, 6.31053686, 0.51509416, 1.39789927, 9.43515587, 2.04256630, 0.13985133, 4.65010691, 2.40911126, -0.36255789, -3.06867862, -0.45225358, -1.56778407, 6.05917358, -1.09891272, 1.77184200, 6.46248102, 0.96042323, -0.24346280, 4.63436460, -4.69907761, 1.25187206, 11.46173859, -2.21917558, 1.28007793, 6.92173195, 2.11268163, -3.47389889, 5.08722782, -3.03950930, -4.17154264, 11.30568314, 0.80361372, 2.53214502, 7.18707085, -4.49114513, 2.85449266, 10.14906883, -0.31974933, -0.84472644, -0.52459574, 0.12921631, -1.81390119, 2.76170087, 1.03982210, 2.91744232, -0.29048753, 5.87453508, -1.53684759, 1.85800636, -0.91404629, 1.28954852, 5.11354685, -2.47475505, -1.33179152, 2.58552408, 1.37316465, -3.32339454, 1.54122913, 3.24953628, -0.29758382, 2.82391763, -1.51142192, -1.22699404, 6.75745535, 0.65452754, -3.29385471, 2.06008053, 2.53172946, 
-4.23532820, -1.53909743, -0.07010663, -1.42173731, 7.29031610, -0.18448229, 4.59496164, 6.73027277, 0.73441899, 0.14426160, 4.14915276, -2.97010231, 6.05851364, 4.95218086, -2.39145470, 2.40494704, 2.10288811, 0.53503096, 1.44511235, 6.66344261, -3.05803776, 7.21418667, 3.30303526, -0.24163735, 3.47409391, 3.64520788, 2.15189481, -3.11243272, 3.62310791, 0.37379482, 0.40865007, -0.83132005, -4.78246069, 2.07030797, 6.51765442, 3.16178989, 5.06180477, 3.78434467, -0.96689719, 0.35965276, 5.89967585, 1.40294051, 1.11952639, 10.59778214, 0.26739889, -1.61297631, 6.24801159, -0.93914318, -0.57812452, 9.92604542, -0.73025000, -3.38530874, 2.45646000, -2.47949195, 0.51638460, 10.65636063, 1.97816694, -3.00407791, 2.66914415, -0.81951088, -0.23316640, 2.40737987, -2.70007610, 1.51531935, 4.08860207, -0.27552786, -1.31721711, 7.11568260, -3.33498216, -4.02545023, 7.22675610, -0.81690705, -2.52689576, 1.04016697, -0.79291463, -0.34875512, 10.00498390, -4.24167728, 1.46162593, 11.82569408, -1.70359993, -0.30161047, 16.44085884, -0.82253462, -0.09435523, 6.13080597, -0.20259480, 0.68308711, 6.15663004, -6.61776876, 0.33295766, 2.55449438, -0.17819691, -1.14892209, 5.56776142, 1.99279118, 1.33035934, 4.45823956, 3.34916544, -2.59905386, 6.16164446, -2.03881931, -2.45273542, 12.46793365, -2.22743297, 2.83738565, 8.48628139, -1.39347959, -1.30867767, 11.08041477, -4.00363779, 2.09183025, 11.30395889, -2.20504737, 1.37426853, 8.98735619, 1.04676604, -0.72757077, 8.28050232, -6.70741081, -0.65798020, 5.68592072, -0.60760021, 0.35854483, 6.26852131, 1.94100165, 1.32112014, 0.80987954, -1.74617672, -0.25434083, 7.16045523, 1.58884013, -2.64847064, 13.14820385, 1.21393633, -2.47258949, 9.41650105, -0.79384226, 2.48954105, 10.95629311, 0.47723705, 4.02126694, 8.02593136, -2.20726371, -1.18794477, 1.50836647, 0.93118095, -1.73513174, 8.85493565, -2.99670315, -0.79055870, 2.39473820, 2.05046916, -2.38055134, 11.82299423, 0.15609655, 0.68744308, 5.66401434, -0.69281673, 2.09855556, 7.74626589, -0.34283102, 1.00542057, 9.95838642, 0.80161905, 2.33455157, 9.80057335, -0.93561798, 2.56991577, 8.29711342, 0.94213426, 0.44209945, 11.70259857, 0.92710167, 2.60957146, 0.24971688, -0.86529571, 3.78628922, 6.80884457, -0.68178189, 2.21103406, 3.18895817, 0.60283208, -2.92716241, 6.72060776, -1.06625068, 2.56543374, 9.97404480, 3.58080721, -0.94936347, 10.16736984, -1.38464379, 1.18191063, 6.66179037, -3.56115270, 0.32329530, 10.90870762, 2.20638227, 0.19653285, 7.34650040, -3.63859272, -1.03027737, 5.98829985, -3.66606474, -3.89746714, 8.63469028, 1.22569811, 1.63240814, 3.74385309, 0.58243257, -0.56981975, 3.69260955, 1.00979900, -1.44030499, 8.57058144, -1.10648811, 1.20474911, 5.43133020, -2.14822555, -0.07928789, 11.25825310, 0.19645604, -5.49546146, 10.41917038, -0.68178523, -2.99639869, 6.50054455, 0.46488351, -5.42328453, 9.09500027, -2.82107449, 0.05601966, 15.34610748, -0.06820253, 3.86699796, 10.73316956, -3.04795432, -0.14702171, 5.64813185, 1.44028485, -2.47596145, 0.07280898, -3.03187990, -1.35183525, 9.35835648, 2.72966957, 1.88199532, 10.36187744, -0.22834805, -3.26738238, 6.92025137, -2.34061313, 4.77379704, 5.28559113, -2.96323752, -1.76186585, 5.94436455, 0.38647744, -5.73869514, 6.76849556, 1.40892124, -1.19068217, 5.37919092, -6.65328646, 3.62782669, 12.34744644, 2.44762444, -4.19242620, 6.14906216, 0.08121119, 0.61355996, 2.69666457, -1.88962626, -0.55314136, 1.84937525, 1.56048691, 1.17460012, 3.75674725, 1.06198275, -5.74625874, 5.41645575, -1.28946674, -1.51689398, 4.32400894, -0.05222082, 
-4.83948946, 1.80747867, 1.63144708, -2.73887825, 1.63975775, -2.02163982, -0.16210437, 2.93518686, 1.14427686, -2.83246303, 4.79283667, 2.69697428, -3.12678456, -1.19225168, -2.37022972, -3.09429741, 1.94225383, -1.13747168, -2.55048585, 5.40242243, 1.12777328, 3.43713188, 3.62658787, -2.16878843, 0.30164462, 2.97407579, -0.07275413, -1.31149673, 4.70066261, -2.01323795, 4.85255766, 4.59128904, 1.68084168, 1.60336494, 6.58138466, -1.04759812, 2.69906545, 3.55769277, -0.74327278, 2.65819693, 5.39528131, 2.11248922, -1.06446671, 5.24546766, -2.43146014, 4.58907509, 0.06521678, -2.24503994, 2.45722699, 6.94863081, 0.35258654, 2.83396196, 9.92525196, -1.12225175, -0.34365177, 7.19116688, -4.39813757, 0.46517885, 13.22028065, -2.57483673, -6.37226963, 7.58046293, -2.74600363, 0.42231262, 8.04881668, 0.17289802, -0.53447008, 16.55157471, -5.63614368, 0.39288223, 3.37079263, 1.26484549, -0.12820500, 8.46440125, -4.39304399, 2.97676420, 0.65650189, 0.83158541, -1.11556435, 6.32885838, -0.36087769, 2.80724382, 9.90292645, 1.15936041, 0.20947981, 6.91249275, -2.67404819, 2.93782163, 6.65656614, -2.30828357, 2.98214006, 6.80611229, -4.93821478, -7.66555262, 7.59763002, -0.54159302, 3.87403512, 12.42607784, 2.59284401, -0.23375344, 8.95293331, -0.71807784, 0.61873478, 8.66713524, 1.24289191, -2.37835455, 2.08071637, -0.88315344, -3.41891551, 6.85245323, 1.73007369, 1.02169311, 7.69170332, -2.85411978, 2.69790673, 8.12906551, -1.19351399, -2.26442742, 12.26104450, -0.75579089, -1.73274946, 10.68729019, 2.20655656, -0.90522075, 12.42165184, -1.67929137, 2.44851565, 9.31565762, -0.06645700, 1.52762020, 6.18427515, -1.68882596, 3.70261097, 3.02252960, -3.44125366, -1.31575799, 2.84617424, -0.96849400, -4.52356243, 9.95027161, 0.19966406, -0.78874779, 8.18595028, -4.08300209, 1.75126517, 0.96418417, -4.04913044, -0.95200396, 12.03637886, -0.03041124, 0.41642749, 8.88267422, -3.24985337, -2.24919462, 7.32566118, 0.16964148, -2.74123430, 7.05264473, -3.30191112, 0.17163286, 4.81851053, -1.64463484, -0.85933101, 7.29276276, 2.34066939, -2.14860010, 3.46148157, -0.01782012, 1.51504040, 4.79304934, 1.85281146, -1.70663762, 6.93470192, -4.15440845, -1.25983095, 10.52491760, 0.42930329, -1.85146868, 11.70042324, -0.41704914, 3.83796859, 9.21148491, -2.79719448, 0.79470479, 6.26926661, -5.85230207, 3.95105338, 7.84790897, -1.38680744, -1.78099084, 11.95235348, -2.99841452, -1.34507811, 6.15714645, -1.07552516, -2.81228638, 1.66234732, -4.55166149, -1.92601109, 8.64634514, -0.48158705, 3.31595659, 7.67371941, 2.56964207, 0.12107098, 4.56467867, -0.93541539, 1.39432955, 11.99714088, 1.05353570, -2.13099813, 3.67617917, 3.45895386, 1.37365830, 8.74344158, -4.17585802, 1.43908918, 6.28764772, 3.97346330, -0.69144285, 9.07983303, -0.41635889, -0.14965028, 8.85469818, 1.11306190, 2.59440994, 5.38982344, -1.07948279, 1.37252975, 10.26984596, -0.09318046, 2.73104119, 12.45902252, -1.55446684, -2.76124811, 12.19395065, -0.51846564, 1.02764034, 11.42673588, -0.95940983, -0.04781032, 8.78379822, -4.88957930, 0.32534006, 11.97696400, -3.35108662, 1.95104563, 4.46915388, -2.32061648, 3.45230985, 8.29983711, 2.81034684, -2.35529327, 6.07801294, -0.98105043, -0.05359888, 2.52291036, -0.01986909, -2.35321999, 10.51954269, 2.11145401, 3.53506470, 7.29093266, 0.03721160, -1.13496494, 7.43886709, -5.84201956, 2.50796294, 12.14647675, 2.77490377, -2.18896222, 6.05641937, 5.32617044, 1.04221284, 10.79106712, -2.95749092, -2.75414610, 11.30037117, -3.40654182, -2.24673963, 7.49126101, 0.70811015, -6.18003702, 13.83951187, -1.01204085, 
1.36298490, -1.04451632, 2.42435336, -0.02346706, -0.85528886, 1.04731262, 0.22192979, 4.15708160, 0.34933877, 0.04814529, 2.24107265, 0.49676740, -1.47752666, 0.45040059, -0.70471478, -1.19759345, 0.21711677, 0.88461423, -2.76830935, 5.52066898, 1.97664857, -1.75381601, 3.45877838, 1.52617192, -1.61350942, 0.85337949, 1.97610760, -3.40310287, 3.40319014, -3.38691044, -0.71319139, 1.65463758, -0.60680127, -1.80700517, 8.02592373, 2.59627104, 2.65895891, 5.93043184, -4.48425817, 3.92670918, 4.19496679, -2.28286791, 6.41634607, 5.72330523, 1.16269672, -0.28753027, 2.46342492, 0.36693189, 0.26712441, 6.37652683, -2.50139046, 2.43923736, 5.56310415, 0.98065847, 1.04267502, 4.16403675, -0.04966142, 4.40897894, 3.72905660, -3.46129870, 3.59962773, 1.34830284, -1.76661730, 0.47943926, 5.29946661, -1.12711561, 1.26970029, 15.17655945, -1.50971997, 5.81345224, 8.48562050, -4.36049604, 2.48144460, 8.23780441, -3.46030426, -0.84656560, 5.94946814, 1.12747943, -2.65683913, 8.69085693, 1.31309867, -2.79958344, 8.76840591, -1.56444156, 1.62710834, 2.41177034, -0.72804940, 5.70619011, 4.67169666, -0.86167198, -1.83803177, 2.96346045, 2.82692933, -2.81557131, 7.11113358, -1.90071094, 2.54244423, 11.19284058, -0.06298946, -1.71517313, 12.98388577, 0.84510714, 3.00816894, 2.57200313, 0.03899818, -1.49330592, 9.60099125, -3.59513044, -1.30045319, 7.09241819, -0.65233821, -2.33627677, 8.81366920, 0.84154201, 1.03312039, 9.85289097, 0.19351870, 1.78496623, 7.34631205, -2.16530800, -0.65016162, 2.46842360, 0.24016285, -1.24308395, 4.78175163, -0.97682536, 2.20942235, 6.68382788, 3.76786447, -1.44454038, 6.26453733, -3.23575711, -2.30137897, 9.53092670, -5.55222607, 3.25999236, 9.37559509, 1.86339056, -0.23551451, 10.23400211, 3.93031883, -0.52629089, 7.85724449, -2.91549587, 4.46612740, 5.66530371, -2.70820427, 4.81359577, 10.31247330, 1.92230141, 2.53931546, 0.74986327, 1.70303428, 0.48063779, 5.31099129, -0.78976244, 3.75864220, 4.23051405, 2.34042454, -7.98193836, 9.83987141, -1.46722627, 3.54497814, 10.36455154, -4.51249075, 0.77715248, 7.78694630, -4.59989023, -2.49585629, 9.90296268, 1.38535416, 1.17441154, 10.10452843, -0.98628229, 0.60194463, 9.12639141, -3.90754628, 2.88526392, 7.24123430, -0.15283313, -0.75728363, -1.15116858, -2.53791571, 0.77229571, 6.44114161, 0.02646767, 4.95463037, 7.21066380, 1.79384065, 0.73250306, 8.04447937, 0.32576546, -0.79447043, 10.12717724, 2.33392906, 1.30716443, 12.36073112, -0.36694977, -1.20438910, 7.03105593, 0.59557682, 0.69267452, 10.18113136, 2.49944925, -0.42229167, 8.83143330, -1.18805945, -2.87509322, 4.53596449, 4.09732771, -3.39088297, -1.02536607, 0.82119560, -3.47302604, 9.29991817, 0.21001509, 4.97036457, 9.50018406, 1.04420102, 1.96560478, 10.74769592, -6.22709799, 3.11690164, 5.06759691, -1.23724771, -3.05831861, 8.12925529, -1.93435478, -1.10151744, 9.32263088, -0.04249470, -5.98547363, 10.49398136, 0.26400441, -0.78915191, 13.28219604, 2.99276900, 0.74853164, 2.49364305, -3.43529654, 4.05278301, 2.13498688, -2.35444307, -0.79900265, 4.66968822, -0.31095147, 3.60674143, 12.37222099, -0.07855003, -3.30292702, 12.15215874, 0.60886210, 2.87075138, 7.75271845, 0.38044083, 3.34402204, 6.40583277, -0.87888050, 0.67438459, 6.91080809, 1.98332930, -0.08303714, 8.08630371, -0.16772588, -2.74058914, 7.17253590, -2.69122696, 1.48173678, 8.99470139, -1.43302310, -0.88651133, 2.66944790, -0.29186964, 2.00838661, 5.09587479, -0.76676071, -2.88322186, 8.31110573, -0.14550979, -1.37726915, 10.28355122, -1.60575438, -0.04118848, 9.97510815, 0.14440438, -3.24632120, 
9.00034523, 4.14319563, -1.31023729, 7.16950464, -0.70428526, 2.01559544, 7.26155043, 2.40816474, 2.09847403, 7.31264496, -0.75401551, 2.13392544, 7.03648758, 1.04036045, -1.15636516, 1.09634531, -0.06340861, -0.58107805, -0.65623116, 1.18972754, -0.80717683, 1.40118241, -0.61932516, -3.60596156, 1.59904599, -2.23774099, -1.13721037, 3.89620137, -0.09115922, -7.51356888, 2.36975193, -1.42520905, -2.34173775, 3.33830214, -2.74016523, -3.04115510, 6.00119495, -1.36084354, -2.45065260, 4.56992292, -3.02825928,-3.74182844,5.11069250,-0.91531068,-2.31385994,1.83399653,3.39370203,-3.60886002}); - auto z = NDArrayFactory::create('c', {4, 4, 4, 3}); - auto exp = NDArrayFactory::create('c', {4, 4, 4, 3}, {7.97172260, 0.06878620, 2.27749538, 7.29276514, -0.14074677, 0.65480286, 5.70313978, -0.06546132, 0.35443667, 3.70382833, -0.84020567, 0.63826996, 8.60301399, -0.38236514, 1.55177069, 7.37542057, -0.99374938, -0.29971302, 8.84352493, -0.67121059, 0.43132120, 4.78175592, -1.25070143, -1.91523600, 6.03855371, -0.00292124, -1.11214364, 7.90158176, -0.57949901, -0.96735370, 7.81192017, -0.53255427, -0.48009714, 3.16953635, 0.08353355, -1.54299748, 3.74821687, 1.69396687, 0.72724354, 5.42915201, -1.13686812, -0.71793109, 5.78376389, -0.72239977, -0.60055625, 2.53636408, 0.56777251, -2.07892323, 6.08064651, 0.68620735, 2.54017019, 5.65828180, -0.68255502, 1.47283304, 6.10842514, -0.39655915, 0.28380761, 1.96707797, -1.98206317, 0.94027776, 4.71811438, 0.32104525, -0.92409706, 8.34588146, -1.05581069, -0.55217457, 9.58440876, -0.96549922, 0.45820439, 5.65453672, -2.50953507, -0.71441835, 8.03059578, -0.21281289, 0.92125505, 9.26900673, -0.35963219, -0.70039093, 8.59924412, -1.22358346, 0.81318003, 3.85920119, -0.01305223, -1.09234154, 6.33158875, 1.28094780, -1.48926139, 4.94969177, -0.77126902, -1.97033751, 5.64381838, -0.16285487, -1.31277227, 2.39893222, -1.32902908, -1.39609122, 6.47572327, -0.45267010, 1.55727172, 6.70965624, -1.68735468, -0.05672536, 7.25092363, -0.64613032, 0.67050058, 3.60789680, -2.05948973, 2.22687531, 8.15202713, -0.70148355, 1.28314006, 8.14842319, -1.88807654, -1.04808438, 8.45500565, -0.76425624, 0.94542569, 4.56179953, -0.28786001, -2.04502511, 8.46278095, -0.31019822, 0.07339200, 9.34214592, -0.61948007, 0.52481830, 8.32515621, -1.52418160, 0.49678251, 5.11082315, -1.09908783, -0.52969611, 5.27806664, 0.88632923, 0.66754371, 4.75839233, 0.48928693, -0.68036932, 6.56925392, -0.02949905, -2.99189186, 4.46320581, -0.64534980, -0.29516968, 8.60809517, -1.13120568, 3.41720533, 5.84243155, -1.24109328, 0.89566326, 5.99578333, -0.42496428, 2.07076764, 3.17812920, -0.81566459, -0.14363396, 6.55184317, 0.39633346, -0.43852386, 8.70214558, -2.24613595, 0.30708700, 8.73882294, -0.53545928, 1.54409575, 4.49452257, -0.16509305, 0.19028664, 8.24897003, 0.44750381, 2.15448594, 8.97640514, -0.77728152, 0.57272542, 9.03467560, 0.47173575, -1.10807717, 3.30056310, -0.43268481, -0.41470885, 3.53798294, -0.08546703, -2.16840744, 6.18733406, -0.17871059, -2.59837723, 5.94218683, -1.02990067, -0.49760687, 3.76938033, 0.86383581, -1.91504073}); + auto input = NDArrayFactory::create('c', {4, 10, 10, 3}, {9.37125111, 2.20166993, 2.91434479, 5.43639755, -2.10573769, 4.08528662, 5.86908436, -4.46203756, 2.21057916, 5.35849190, 0.01394637, 4.40566349, 7.07982206, -0.09633455, 2.42429352, 3.97301817, -1.89553940, 1.99690318, 6.33141708, 0.55401880, 1.70707977, 5.55204201, -0.03513752, 1.60011971, 2.62700319, -2.74582434, 3.06697464, 1.06277943, -1.16075921, -0.78095782, 9.72352791, -1.22686064, 
1.99644792, 7.35571337, 1.40607321, 0.11390255, 9.53334427, 2.28303599, -1.66728830, 6.16678810, -0.04532295, -1.97708666, 9.74906158, 1.46223176, -1.46734393, 4.30761862, -1.23790228, 1.24823606, 6.13938427, -3.83689475, -1.19625473, 7.91535568, 6.05868721, -3.22946382, 8.81633949, -0.19967777, 0.66053957, 2.30919123, 0.74543846, -0.39347672, 11.11058044, 0.53720862, 1.52645731, 5.70012379, -1.15213466, 1.16451406, 7.00526333, 1.57362783, -2.44384766, 5.54213285, -1.98828590, -0.70483637, 7.88281822, -3.59875536, 0.80745387, 13.41578484, -1.55507684, -0.65855008, 9.32583523, -0.14544789, 0.73436141, 3.61176538, -1.71268058, -2.58490300, 9.09280205, -3.27405524, -2.04569697, 4.44761324, -0.62955856, -2.61917663, 8.04890442, 0.54579324, 0.85929775, 9.82259560, -1.93825579, 0.77703512, 4.67090321, -4.79267597, -2.38906908, 9.31265545, 0.96026313, -1.14109385, 11.54231834, -0.01417295, -0.39500344, 8.49191666, 0.55300158, 2.79490185, 6.92466164, 1.72254205, 2.82222271, 8.83112717, 2.95033407, 2.18054962, 6.73509789, -2.22272944, 0.51127720, -1.04563558, 2.15747333, -2.30959272, 9.55441570, 1.50396204, 1.77370787, 7.38146257, -1.79076433, 3.20961165, 7.18864202, 2.91217351, 0.43018937, 7.11078024, -1.17386127, -0.16817921, 6.12327290, -2.82205725, 3.30696845, 13.51291752, -1.30856836, -2.38332748, 11.09487438, -1.47190213, -0.53050828, 4.38285351, -5.07309771, 1.50714362, 5.72274446, -2.85825086, -0.89673209, 3.73791552, -0.67708802, -4.13149452, -0.00671843, -0.26566532, 0.32961160, 7.14501762, -1.41608179, -4.96590328, 12.26205540, -0.65158135, -0.88641000, 6.95777559, -0.79058206, -0.10260171, 7.87169170, 1.35921454, 1.11759663, 5.46187401, -2.57214499, 2.48484039, 4.04043484, -2.07137156, -1.42709637, 9.25487137, -0.12605135, -2.66949964, 2.89412403, 0.74451172, -2.96250391, 3.99258423, 0.27084303, 0.32213116, 5.42332172, -0.44414216, 1.70881832, 6.69346905, 0.53058422, -4.73146200, 4.22051668, 2.24834967, 0.66996074, 4.30173683, 0.11849818, -4.07520294, 8.27318478, -2.54398274, -2.86705542, 10.11775303, -0.99382895, 0.65881538, 7.93556786, -1.27934420, -1.69343162, 9.68042564, -1.02609646, -1.18189347, 5.75370646, -1.67888868, -4.48871994, 4.79537392, -0.79212248, -0.19855022, 6.15060997, -0.01081491, 3.64454579, 10.82562447, 1.58859253, -2.65847278, 8.60093212, -1.59196103, 0.07635692, 11.76175690, -1.17453325, 0.10122013, 6.86458445, -2.18891335, -2.74004745, 8.07066154, 0.71818852, -2.03035975, 6.31053686, 0.51509416, 1.39789927, 9.43515587, 2.04256630, 0.13985133, 4.65010691, 2.40911126, -0.36255789, -3.06867862, -0.45225358, -1.56778407, 6.05917358, -1.09891272, 1.77184200, 6.46248102, 0.96042323, -0.24346280, 4.63436460, -4.69907761, 1.25187206, 11.46173859, -2.21917558, 1.28007793, 6.92173195, 2.11268163, -3.47389889, 5.08722782, -3.03950930, -4.17154264, 11.30568314, 0.80361372, 2.53214502, 7.18707085, -4.49114513, 2.85449266, 10.14906883, -0.31974933, -0.84472644, -0.52459574, 0.12921631, -1.81390119, 2.76170087, 1.03982210, 2.91744232, -0.29048753, 5.87453508, -1.53684759, 1.85800636, -0.91404629, 1.28954852, 5.11354685, -2.47475505, -1.33179152, 2.58552408, 1.37316465, -3.32339454, 1.54122913, 3.24953628, -0.29758382, 2.82391763, -1.51142192, -1.22699404, 6.75745535, 0.65452754, -3.29385471, 2.06008053, 2.53172946, -4.23532820, -1.53909743, -0.07010663, -1.42173731, 7.29031610, -0.18448229, 4.59496164, 6.73027277, 0.73441899, 0.14426160, 4.14915276, -2.97010231, 6.05851364, 4.95218086, -2.39145470, 2.40494704, 2.10288811, 0.53503096, 1.44511235, 6.66344261, -3.05803776, 
7.21418667, 3.30303526, -0.24163735, 3.47409391, 3.64520788, 2.15189481, -3.11243272, 3.62310791, 0.37379482, 0.40865007, -0.83132005, -4.78246069, 2.07030797, 6.51765442, 3.16178989, 5.06180477, 3.78434467, -0.96689719, 0.35965276, 5.89967585, 1.40294051, 1.11952639, 10.59778214, 0.26739889, -1.61297631, 6.24801159, -0.93914318, -0.57812452, 9.92604542, -0.73025000, -3.38530874, 2.45646000, -2.47949195, 0.51638460, 10.65636063, 1.97816694, -3.00407791, 2.66914415, -0.81951088, -0.23316640, 2.40737987, -2.70007610, 1.51531935, 4.08860207, -0.27552786, -1.31721711, 7.11568260, -3.33498216, -4.02545023, 7.22675610, -0.81690705, -2.52689576, 1.04016697, -0.79291463, -0.34875512, 10.00498390, -4.24167728, 1.46162593, 11.82569408, -1.70359993, -0.30161047, 16.44085884, -0.82253462, -0.09435523, 6.13080597, -0.20259480, 0.68308711, 6.15663004, -6.61776876, 0.33295766, 2.55449438, -0.17819691, -1.14892209, 5.56776142, 1.99279118, 1.33035934, 4.45823956, 3.34916544, -2.59905386, 6.16164446, -2.03881931, -2.45273542, 12.46793365, -2.22743297, 2.83738565, 8.48628139, -1.39347959, -1.30867767, 11.08041477, -4.00363779, 2.09183025, 11.30395889, -2.20504737, 1.37426853, 8.98735619, 1.04676604, -0.72757077, 8.28050232, -6.70741081, -0.65798020, 5.68592072, -0.60760021, 0.35854483, 6.26852131, 1.94100165, 1.32112014, 0.80987954, -1.74617672, -0.25434083, 7.16045523, 1.58884013, -2.64847064, 13.14820385, 1.21393633, -2.47258949, 9.41650105, -0.79384226, 2.48954105, 10.95629311, 0.47723705, 4.02126694, 8.02593136, -2.20726371, -1.18794477, 1.50836647, 0.93118095, -1.73513174, 8.85493565, -2.99670315, -0.79055870, 2.39473820, 2.05046916, -2.38055134, 11.82299423, 0.15609655, 0.68744308, 5.66401434, -0.69281673, 2.09855556, 7.74626589, -0.34283102, 1.00542057, 9.95838642, 0.80161905, 2.33455157, 9.80057335, -0.93561798, 2.56991577, 8.29711342, 0.94213426, 0.44209945, 11.70259857, 0.92710167, 2.60957146, 0.24971688, -0.86529571, 3.78628922, 6.80884457, -0.68178189, 2.21103406, 3.18895817, 0.60283208, -2.92716241, 6.72060776, -1.06625068, 2.56543374, 9.97404480, 3.58080721, -0.94936347, 10.16736984, -1.38464379, 1.18191063, 6.66179037, -3.56115270, 0.32329530, 10.90870762, 2.20638227, 0.19653285, 7.34650040, -3.63859272, -1.03027737, 5.98829985, -3.66606474, -3.89746714, 8.63469028, 1.22569811, 1.63240814, 3.74385309, 0.58243257, -0.56981975, 3.69260955, 1.00979900, -1.44030499, 8.57058144, -1.10648811, 1.20474911, 5.43133020, -2.14822555, -0.07928789, 11.25825310, 0.19645604, -5.49546146, 10.41917038, -0.68178523, -2.99639869, 6.50054455, 0.46488351, -5.42328453, 9.09500027, -2.82107449, 0.05601966, 15.34610748, -0.06820253, 3.86699796, 10.73316956, -3.04795432, -0.14702171, 5.64813185, 1.44028485, -2.47596145, 0.07280898, -3.03187990, -1.35183525, 9.35835648, 2.72966957, 1.88199532, 10.36187744, -0.22834805, -3.26738238, 6.92025137, -2.34061313, 4.77379704, 5.28559113, -2.96323752, -1.76186585, 5.94436455, 0.38647744, -5.73869514, 6.76849556, 1.40892124, -1.19068217, 5.37919092, -6.65328646, 3.62782669, 12.34744644, 2.44762444, -4.19242620, 6.14906216, 0.08121119, 0.61355996, 2.69666457, -1.88962626, -0.55314136, 1.84937525, 1.56048691, 1.17460012, 3.75674725, 1.06198275, -5.74625874, 5.41645575, -1.28946674, -1.51689398, 4.32400894, -0.05222082, -4.83948946, 1.80747867, 1.63144708, -2.73887825, 1.63975775, -2.02163982, -0.16210437, 2.93518686, 1.14427686, -2.83246303, 4.79283667, 2.69697428, -3.12678456, -1.19225168, -2.37022972, -3.09429741, 1.94225383, -1.13747168, -2.55048585, 5.40242243, 1.12777328, 
3.43713188, 3.62658787, -2.16878843, 0.30164462, 2.97407579, -0.07275413, -1.31149673, 4.70066261, -2.01323795, 4.85255766, 4.59128904, 1.68084168, 1.60336494, 6.58138466, -1.04759812, 2.69906545, 3.55769277, -0.74327278, 2.65819693, 5.39528131, 2.11248922, -1.06446671, 5.24546766, -2.43146014, 4.58907509, 0.06521678, -2.24503994, 2.45722699, 6.94863081, 0.35258654, 2.83396196, 9.92525196, -1.12225175, -0.34365177, 7.19116688, -4.39813757, 0.46517885, 13.22028065, -2.57483673, -6.37226963, 7.58046293, -2.74600363, 0.42231262, 8.04881668, 0.17289802, -0.53447008, 16.55157471, -5.63614368, 0.39288223, 3.37079263, 1.26484549, -0.12820500, 8.46440125, -4.39304399, 2.97676420, 0.65650189, 0.83158541, -1.11556435, 6.32885838, -0.36087769, 2.80724382, 9.90292645, 1.15936041, 0.20947981, 6.91249275, -2.67404819, 2.93782163, 6.65656614, -2.30828357, 2.98214006, 6.80611229, -4.93821478, -7.66555262, 7.59763002, -0.54159302, 3.87403512, 12.42607784, 2.59284401, -0.23375344, 8.95293331, -0.71807784, 0.61873478, 8.66713524, 1.24289191, -2.37835455, 2.08071637, -0.88315344, -3.41891551, 6.85245323, 1.73007369, 1.02169311, 7.69170332, -2.85411978, 2.69790673, 8.12906551, -1.19351399, -2.26442742, 12.26104450, -0.75579089, -1.73274946, 10.68729019, 2.20655656, -0.90522075, 12.42165184, -1.67929137, 2.44851565, 9.31565762, -0.06645700, 1.52762020, 6.18427515, -1.68882596, 3.70261097, 3.02252960, -3.44125366, -1.31575799, 2.84617424, -0.96849400, -4.52356243, 9.95027161, 0.19966406, -0.78874779, 8.18595028, -4.08300209, 1.75126517, 0.96418417, -4.04913044, -0.95200396, 12.03637886, -0.03041124, 0.41642749, 8.88267422, -3.24985337, -2.24919462, 7.32566118, 0.16964148, -2.74123430, 7.05264473, -3.30191112, 0.17163286, 4.81851053, -1.64463484, -0.85933101, 7.29276276, 2.34066939, -2.14860010, 3.46148157, -0.01782012, 1.51504040, 4.79304934, 1.85281146, -1.70663762, 6.93470192, -4.15440845, -1.25983095, 10.52491760, 0.42930329, -1.85146868, 11.70042324, -0.41704914, 3.83796859, 9.21148491, -2.79719448, 0.79470479, 6.26926661, -5.85230207, 3.95105338, 7.84790897, -1.38680744, -1.78099084, 11.95235348, -2.99841452, -1.34507811, 6.15714645, -1.07552516, -2.81228638, 1.66234732, -4.55166149, -1.92601109, 8.64634514, -0.48158705, 3.31595659, 7.67371941, 2.56964207, 0.12107098, 4.56467867, -0.93541539, 1.39432955, 11.99714088, 1.05353570, -2.13099813, 3.67617917, 3.45895386, 1.37365830, 8.74344158, -4.17585802, 1.43908918, 6.28764772, 3.97346330, -0.69144285, 9.07983303, -0.41635889, -0.14965028, 8.85469818, 1.11306190, 2.59440994, 5.38982344, -1.07948279, 1.37252975, 10.26984596, -0.09318046, 2.73104119, 12.45902252, -1.55446684, -2.76124811, 12.19395065, -0.51846564, 1.02764034, 11.42673588, -0.95940983, -0.04781032, 8.78379822, -4.88957930, 0.32534006, 11.97696400, -3.35108662, 1.95104563, 4.46915388, -2.32061648, 3.45230985, 8.29983711, 2.81034684, -2.35529327, 6.07801294, -0.98105043, -0.05359888, 2.52291036, -0.01986909, -2.35321999, 10.51954269, 2.11145401, 3.53506470, 7.29093266, 0.03721160, -1.13496494, 7.43886709, -5.84201956, 2.50796294, 12.14647675, 2.77490377, -2.18896222, 6.05641937, 5.32617044, 1.04221284, 10.79106712, -2.95749092, -2.75414610, 11.30037117, -3.40654182, -2.24673963, 7.49126101, 0.70811015, -6.18003702, 13.83951187, -1.01204085, 1.36298490, -1.04451632, 2.42435336, -0.02346706, -0.85528886, 1.04731262, 0.22192979, 4.15708160, 0.34933877, 0.04814529, 2.24107265, 0.49676740, -1.47752666, 0.45040059, -0.70471478, -1.19759345, 0.21711677, 0.88461423, -2.76830935, 5.52066898, 1.97664857, 
-1.75381601, 3.45877838, 1.52617192, -1.61350942, 0.85337949, 1.97610760, -3.40310287, 3.40319014, -3.38691044, -0.71319139, 1.65463758, -0.60680127, -1.80700517, 8.02592373, 2.59627104, 2.65895891, 5.93043184, -4.48425817, 3.92670918, 4.19496679, -2.28286791, 6.41634607, 5.72330523, 1.16269672, -0.28753027, 2.46342492, 0.36693189, 0.26712441, 6.37652683, -2.50139046, 2.43923736, 5.56310415, 0.98065847, 1.04267502, 4.16403675, -0.04966142, 4.40897894, 3.72905660, -3.46129870, 3.59962773, 1.34830284, -1.76661730, 0.47943926, 5.29946661, -1.12711561, 1.26970029, 15.17655945, -1.50971997, 5.81345224, 8.48562050, -4.36049604, 2.48144460, 8.23780441, -3.46030426, -0.84656560, 5.94946814, 1.12747943, -2.65683913, 8.69085693, 1.31309867, -2.79958344, 8.76840591, -1.56444156, 1.62710834, 2.41177034, -0.72804940, 5.70619011, 4.67169666, -0.86167198, -1.83803177, 2.96346045, 2.82692933, -2.81557131, 7.11113358, -1.90071094, 2.54244423, 11.19284058, -0.06298946, -1.71517313, 12.98388577, 0.84510714, 3.00816894, 2.57200313, 0.03899818, -1.49330592, 9.60099125, -3.59513044, -1.30045319, 7.09241819, -0.65233821, -2.33627677, 8.81366920, 0.84154201, 1.03312039, 9.85289097, 0.19351870, 1.78496623, 7.34631205, -2.16530800, -0.65016162, 2.46842360, 0.24016285, -1.24308395, 4.78175163, -0.97682536, 2.20942235, 6.68382788, 3.76786447, -1.44454038, 6.26453733, -3.23575711, -2.30137897, 9.53092670, -5.55222607, 3.25999236, 9.37559509, 1.86339056, -0.23551451, 10.23400211, 3.93031883, -0.52629089, 7.85724449, -2.91549587, 4.46612740, 5.66530371, -2.70820427, 4.81359577, 10.31247330, 1.92230141, 2.53931546, 0.74986327, 1.70303428, 0.48063779, 5.31099129, -0.78976244, 3.75864220, 4.23051405, 2.34042454, -7.98193836, 9.83987141, -1.46722627, 3.54497814, 10.36455154, -4.51249075, 0.77715248, 7.78694630, -4.59989023, -2.49585629, 9.90296268, 1.38535416, 1.17441154, 10.10452843, -0.98628229, 0.60194463, 9.12639141, -3.90754628, 2.88526392, 7.24123430, -0.15283313, -0.75728363, -1.15116858, -2.53791571, 0.77229571, 6.44114161, 0.02646767, 4.95463037, 7.21066380, 1.79384065, 0.73250306, 8.04447937, 0.32576546, -0.79447043, 10.12717724, 2.33392906, 1.30716443, 12.36073112, -0.36694977, -1.20438910, 7.03105593, 0.59557682, 0.69267452, 10.18113136, 2.49944925, -0.42229167, 8.83143330, -1.18805945, -2.87509322, 4.53596449, 4.09732771, -3.39088297, -1.02536607, 0.82119560, -3.47302604, 9.29991817, 0.21001509, 4.97036457, 9.50018406, 1.04420102, 1.96560478, 10.74769592, -6.22709799, 3.11690164, 5.06759691, -1.23724771, -3.05831861, 8.12925529, -1.93435478, -1.10151744, 9.32263088, -0.04249470, -5.98547363, 10.49398136, 0.26400441, -0.78915191, 13.28219604, 2.99276900, 0.74853164, 2.49364305, -3.43529654, 4.05278301, 2.13498688, -2.35444307, -0.79900265, 4.66968822, -0.31095147, 3.60674143, 12.37222099, -0.07855003, -3.30292702, 12.15215874, 0.60886210, 2.87075138, 7.75271845, 0.38044083, 3.34402204, 6.40583277, -0.87888050, 0.67438459, 6.91080809, 1.98332930, -0.08303714, 8.08630371, -0.16772588, -2.74058914, 7.17253590, -2.69122696, 1.48173678, 8.99470139, -1.43302310, -0.88651133, 2.66944790, -0.29186964, 2.00838661, 5.09587479, -0.76676071, -2.88322186, 8.31110573, -0.14550979, -1.37726915, 10.28355122, -1.60575438, -0.04118848, 9.97510815, 0.14440438, -3.24632120, 9.00034523, 4.14319563, -1.31023729, 7.16950464, -0.70428526, 2.01559544, 7.26155043, 2.40816474, 2.09847403, 7.31264496, -0.75401551, 2.13392544, 7.03648758, 1.04036045, -1.15636516, 1.09634531, -0.06340861, -0.58107805, -0.65623116, 1.18972754, -0.80717683, 
1.40118241, -0.61932516, -3.60596156, 1.59904599, -2.23774099, -1.13721037, 3.89620137, -0.09115922, -7.51356888, 2.36975193, -1.42520905, -2.34173775, 3.33830214, -2.74016523, -3.04115510, 6.00119495, -1.36084354, -2.45065260, 4.56992292, -3.02825928,-3.74182844,5.11069250,-0.91531068,-2.31385994,1.83399653,3.39370203,-3.60886002}); + auto z = NDArrayFactory::create('c', {4, 4, 4, 3}); + auto exp = NDArrayFactory::create('c', {4, 4, 4, 3}, {7.97172260, 0.06878620, 2.27749538, 7.29276514, -0.14074677, 0.65480286, 5.70313978, -0.06546132, 0.35443667, 3.70382833, -0.84020567, 0.63826996, 8.60301399, -0.38236514, 1.55177069, 7.37542057, -0.99374938, -0.29971302, 8.84352493, -0.67121059, 0.43132120, 4.78175592, -1.25070143, -1.91523600, 6.03855371, -0.00292124, -1.11214364, 7.90158176, -0.57949901, -0.96735370, 7.81192017, -0.53255427, -0.48009714, 3.16953635, 0.08353355, -1.54299748, 3.74821687, 1.69396687, 0.72724354, 5.42915201, -1.13686812, -0.71793109, 5.78376389, -0.72239977, -0.60055625, 2.53636408, 0.56777251, -2.07892323, 6.08064651, 0.68620735, 2.54017019, 5.65828180, -0.68255502, 1.47283304, 6.10842514, -0.39655915, 0.28380761, 1.96707797, -1.98206317, 0.94027776, 4.71811438, 0.32104525, -0.92409706, 8.34588146, -1.05581069, -0.55217457, 9.58440876, -0.96549922, 0.45820439, 5.65453672, -2.50953507, -0.71441835, 8.03059578, -0.21281289, 0.92125505, 9.26900673, -0.35963219, -0.70039093, 8.59924412, -1.22358346, 0.81318003, 3.85920119, -0.01305223, -1.09234154, 6.33158875, 1.28094780, -1.48926139, 4.94969177, -0.77126902, -1.97033751, 5.64381838, -0.16285487, -1.31277227, 2.39893222, -1.32902908, -1.39609122, 6.47572327, -0.45267010, 1.55727172, 6.70965624, -1.68735468, -0.05672536, 7.25092363, -0.64613032, 0.67050058, 3.60789680, -2.05948973, 2.22687531, 8.15202713, -0.70148355, 1.28314006, 8.14842319, -1.88807654, -1.04808438, 8.45500565, -0.76425624, 0.94542569, 4.56179953, -0.28786001, -2.04502511, 8.46278095, -0.31019822, 0.07339200, 9.34214592, -0.61948007, 0.52481830, 8.32515621, -1.52418160, 0.49678251, 5.11082315, -1.09908783, -0.52969611, 5.27806664, 0.88632923, 0.66754371, 4.75839233, 0.48928693, -0.68036932, 6.56925392, -0.02949905, -2.99189186, 4.46320581, -0.64534980, -0.29516968, 8.60809517, -1.13120568, 3.41720533, 5.84243155, -1.24109328, 0.89566326, 5.99578333, -0.42496428, 2.07076764, 3.17812920, -0.81566459, -0.14363396, 6.55184317, 0.39633346, -0.43852386, 8.70214558, -2.24613595, 0.30708700, 8.73882294, -0.53545928, 1.54409575, 4.49452257, -0.16509305, 0.19028664, 8.24897003, 0.44750381, 2.15448594, 8.97640514, -0.77728152, 0.57272542, 9.03467560, 0.47173575, -1.10807717, 3.30056310, -0.43268481, -0.41470885, 3.53798294, -0.08546703, -2.16840744, 6.18733406, -0.17871059, -2.59837723, 5.94218683, -1.02990067, -0.49760687, 3.76938033, 0.86383581, -1.91504073}); nd4j::ops::avgpool2d op; @@ -1255,7 +1255,7 @@ TEST_F(JavaInteropTests, test_ismax_view) { } TEST_F(JavaInteropTests, test_size_dtype_1) { - auto x = NDArrayFactory::create('c', {3}, {1, 1, 1}); + auto x = NDArrayFactory::create('c', {3}, {1.f, 1.f, 1.f}); auto z = NDArrayFactory::create(0.0f); auto e = NDArrayFactory::create(3.0f); diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp b/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp index a497cd9e6..4f8d38e76 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp @@ -94,7 +94,7 @@ TEST_F(NDArrayTest2, Test_Reshape_Scalar_2) { } TEST_F(NDArrayTest2, Test_IndexReduce_1) { - auto x = 
NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); ExtraArguments extras({3.0, 0.0, 10.0}); int idx = x.indexReduceNumber(indexreduce::FirstIndex, &extras).e(0); @@ -160,7 +160,7 @@ TEST_F(NDArrayTest2, SetIdentity_test_5) { TEST_F(NDArrayTest2, SetIdentity_test_6) { auto x = NDArrayFactory::create('c', {3, 2}); - auto xExp = NDArrayFactory::create('c', {3, 2}, {1,0,0,1,0,0}); + auto xExp = NDArrayFactory::create('c', {3, 2}, {1.f, 0.f, 0.f, 1.f, 0.f, 0.f}); x.setIdentity(); @@ -171,7 +171,7 @@ TEST_F(NDArrayTest2, SetIdentity_test_6) { TEST_F(NDArrayTest2, SetIdentity_test_7) { auto x = NDArrayFactory::create('c', {3, 4}); - auto xExp = NDArrayFactory::create('c', {3, 4}, {1.,0.,0.,0.,0.,1.,0.,0.,0.,0.,1.,0.}); + auto xExp = NDArrayFactory::create('c', {3, 4}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f}); x.setIdentity(); @@ -192,9 +192,9 @@ TEST_F(NDArrayTest2, SetIdentity_test_8) { //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_AllReduce3_1) { - auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); - auto y = NDArrayFactory::create('c', {2, 3}, {2, 3, 4, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {2, 2}, {1.73205, 1.73205, 1.73205, 1.73205}); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); + auto y = NDArrayFactory::create('c', {2, 3}, {2, 3, 4, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {2, 2}, {1.73205, 1.73205, 1.73205, 1.73205}); auto z = x.applyAllReduce3(reduce3::EuclideanDistance, &y, {1}, nullptr); @@ -206,9 +206,9 @@ TEST_F(NDArrayTest2, Test_AllReduce3_1) { //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_AllReduce3_2) { - auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 2, 3, 4 }); - auto y = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {2, 2}, {0., 1.73205, 1.73205, 0.}); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 2, 3, 4 }); + auto y = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {2, 2}, {0., 1.73205, 1.73205, 0.}); auto z = x.applyAllReduce3(reduce3::EuclideanDistance, &y, {1}, nullptr); @@ -221,9 +221,9 @@ TEST_F(NDArrayTest2, Test_AllReduce3_2) { //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, mmul_test1) { - auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16}); + auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16}); auto result = mmul(x, y); ASSERT_TRUE(exp.isSameShape(&result)); @@ -234,9 +234,9 @@ TEST_F(NDArrayTest2, mmul_test1) { //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, mmul_test2) { - auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1}, {30}); + auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1}, {30}); auto result = mmul(y ,x); @@ -248,10 +248,10 @@ TEST_F(NDArrayTest2, 
mmul_test2) { //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, mmul_test3) { - auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4, 4}, {1. ,0.2 ,0.3 ,0.4 ,0.2,0.04,0.06,0.08,0.3,0.06,0.09,0.12,0.4,0.08,0.12,0.16}); - auto w = NDArrayFactory::create( x.ordering(), {(int)x.lengthOf(), 1}, x.getContext()); // column-vector - auto wT = NDArrayFactory::create(x.ordering(), {1, (int)x.lengthOf()}, x.getContext()); // row-vector (transposed w) + auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4, 4}, {1. ,0.2 ,0.3 ,0.4 ,0.2,0.04,0.06,0.08,0.3,0.06,0.09,0.12,0.4,0.08,0.12,0.16}); + auto w = NDArrayFactory::create( x.ordering(), {(int)x.lengthOf(), 1}, x.getContext()); // column-vector + auto wT = NDArrayFactory::create(x.ordering(), {1, (int)x.lengthOf()}, x.getContext()); // row-vector (transposed w) w = x / (float)10.; w.p(0, 1.); @@ -311,9 +311,9 @@ TEST_F(NDArrayTest2, Test_Enforce_1) { } TEST_F(NDArrayTest2, TestVector_1) { - auto x = NDArrayFactory::create('c', {2, 3}); - auto row = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); + auto x = NDArrayFactory::create('c', {2, 3}); + auto row = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); x.addiRowVector(&row); @@ -341,9 +341,9 @@ TEST_F(NDArrayTest2, Operator_Plus_Test_5) ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Operator_Plus_Test_6) { - auto x = NDArrayFactory::create('c', {3, 3, 3}); - auto y = NDArrayFactory::create('c', {3, 1, 3}); - auto expected = NDArrayFactory::create('c', {3, 3, 3}, {2., 4., 6., 5., 7., 9., 8.,10.,12., 14.,16.,18.,17.,19.,21.,20.,22.,24., 26.,28.,30.,29.,31.,33.,32.,34.,36.}); + auto x = NDArrayFactory::create('c', {3, 3, 3}); + auto y = NDArrayFactory::create('c', {3, 1, 3}); + auto expected = NDArrayFactory::create('c', {3, 3, 3}, {2., 4., 6., 5., 7., 9., 8.,10.,12., 14.,16.,18.,17.,19.,21.,20.,22.,24., 26.,28.,30.,29.,31.,33.,32.,34.,36.}); x.linspace(1); y.linspace(1); @@ -356,8 +356,8 @@ TEST_F(NDArrayTest2, Operator_Plus_Test_6) { ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, tileToShape_test1) { - auto x = NDArrayFactory::create('c', {2, 2}, {1,2,3,4}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1,2,3,4,1,2,3,4}); + auto x = NDArrayFactory::create('c', {2, 2}, {1,2,3,4}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1,2,3,4,1,2,3,4}); x.tileToShape({2,2,2}); @@ -368,8 +368,8 @@ TEST_F(NDArrayTest2, tileToShape_test1) { ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, tileToShape_test2) { - auto x = NDArrayFactory::create('c', {2, 1, 2}, {1,2,3,4}); - auto exp = NDArrayFactory::create('c', {2, 3, 2}, {1,2,1,2,1,2,3,4,3,4,3,4}); + auto x = NDArrayFactory::create('c', {2, 1, 2}, {1,2,3,4}); + auto exp = NDArrayFactory::create('c', {2, 3, 2}, {1,2,1,2,1,2,3,4,3,4,3,4}); x.tileToShape({2,3,2}); @@ -380,9 +380,9 @@ TEST_F(NDArrayTest2, tileToShape_test2) { ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, tileToShape_test3) { - auto x = NDArrayFactory::create('c', {2, 2}, {1,2,3,4}); - auto result = NDArrayFactory::create('c', {2, 2, 2}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1,2,3,4,1,2,3,4}); + auto x = 
-    auto x = NDArrayFactory::create('c', {2, 2}, {1,2,3,4});
-    auto result = NDArrayFactory::create('c', {2, 2, 2});
-    auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1,2,3,4,1,2,3,4});
+    auto x = NDArrayFactory::create('c', {2, 2}, {1,2,3,4});
+    auto result = NDArrayFactory::create('c', {2, 2, 2});
+    auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1,2,3,4,1,2,3,4});

     x.tileToShape({2,2,2}, &result);
     // result.printIndexedBuffer();
@@ -394,9 +394,9 @@ TEST_F(NDArrayTest2, tileToShape_test3) {

 //////////////////////////////////////////////////////////////////////
 TEST_F(NDArrayTest2, tileToShape_test4) {
-    auto x = NDArrayFactory::create('c', {2, 1, 2}, {1,2,3,4});
-    auto result = NDArrayFactory::create('c', {2, 3, 2});
-    auto exp = NDArrayFactory::create('c', {2, 3, 2}, {1,2,1,2,1,2,3,4,3,4,3,4});
+    auto x = NDArrayFactory::create('c', {2, 1, 2}, {1,2,3,4});
+    auto result = NDArrayFactory::create('c', {2, 3, 2});
+    auto exp = NDArrayFactory::create('c', {2, 3, 2}, {1,2,1,2,1,2,3,4,3,4,3,4});

     x.tileToShape({2,3,2}, &result);
@@ -407,50 +407,50 @@ TEST_F(NDArrayTest2, tileToShape_test4) {
 #ifndef __CUDABLAS__

 TEST_F(NDArrayTest2, Test_TriplewiseLambda_1) {
-    auto t = NDArrayFactory::create<float>('c', {3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1});
-    auto u = NDArrayFactory::create<float>('c', {3, 3}, {2, 2, 2, 2, 2, 2, 2, 2, 2});
-    auto v = NDArrayFactory::create<float>('c', {3, 3}, {3, 3, 3, 3, 3, 3, 3, 3, 3});
-    auto exp = NDArrayFactory::create<float>('c', {3, 3}, {7, 7, 7, 7, 7, 7, 7, 7, 7});
+    auto t = NDArrayFactory::create<double>('c', {3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1});
+    auto u = NDArrayFactory::create<double>('c', {3, 3}, {2, 2, 2, 2, 2, 2, 2, 2, 2});
+    auto v = NDArrayFactory::create<double>('c', {3, 3}, {3, 3, 3, 3, 3, 3, 3, 3, 3});
+    auto exp = NDArrayFactory::create<double>('c', {3, 3}, {7, 7, 7, 7, 7, 7, 7, 7, 7});

     float extra = 1.0f;

-    auto la = LAMBDA_FFF(_t, _u, _v, extra) {
+    auto la = LAMBDA_DDD(_t, _u, _v, extra) {
         return _t + _u + _v + extra;
     };

-    t.applyTriplewiseLambda<float>(&u, &v, la);
+    t.applyTriplewiseLambda<double>(&u, &v, la);

     ASSERT_TRUE(t.equalsTo(&exp));
 }

 TEST_F(NDArrayTest2, Test_TriplewiseLambda_2) {
-    auto t = NDArrayFactory::create<float>('c', {3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1});
-    auto u = NDArrayFactory::create<float>('f', {3, 3}, {2, 2, 2, 2, 2, 2, 2, 2, 2});
-    auto v = NDArrayFactory::create<float>('c', {3, 3}, {3, 3, 3, 3, 3, 3, 3, 3, 3});
-    auto exp = NDArrayFactory::create<float>('c', {3, 3}, {7, 7, 7, 7, 7, 7, 7, 7, 7});
+    auto t = NDArrayFactory::create<double>('c', {3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1});
+    auto u = NDArrayFactory::create<double>('f', {3, 3}, {2, 2, 2, 2, 2, 2, 2, 2, 2});
+    auto v = NDArrayFactory::create<double>('c', {3, 3}, {3, 3, 3, 3, 3, 3, 3, 3, 3});
+    auto exp = NDArrayFactory::create<double>('c', {3, 3}, {7, 7, 7, 7, 7, 7, 7, 7, 7});

     float extra = 1.0f;

-    auto la = LAMBDA_FFF(_t, _u, _v, extra) {
+    auto la = LAMBDA_DDD(_t, _u, _v, extra) {
         return _t + _u + _v + extra;
     };

-    t.applyTriplewiseLambda<float>(&u, &v, la);
+    t.applyTriplewiseLambda<double>(&u, &v, la);

     ASSERT_TRUE(t.equalsTo(&exp));
 }

 ////////////////////////////////////////////////////////////////////
 TEST_F(NDArrayTest2, Test_Indexed_Lambda) {
-    auto x = NDArrayFactory::create<float>('c', {2, 2});
-    auto exp = NDArrayFactory::create<float>('c', {2, 2}, {0, 1, 2, 3});
+    auto x = NDArrayFactory::create<double>('c', {2, 2});
+    auto exp = NDArrayFactory::create<double>('c', {2, 2}, {0, 1, 2, 3});

-    auto lambda = ILAMBDA_F(_x) {
+    auto lambda = ILAMBDA_D(_x) {
         return (float) _idx;
     };

-    x.applyIndexedLambda<float>(lambda);
+    x.applyIndexedLambda<double>(lambda);

     ASSERT_TRUE(exp.equalsTo(&x));
 }
@@ -458,8 +458,8 @@ TEST_F(NDArrayTest2, Test_Indexed_Lambda) {
 #endif

 TEST_F(NDArrayTest2, Test_PermuteEquality_1) {
-    auto x = NDArrayFactory::create('c', {1, 60});
-    auto exp = NDArrayFactory::create('c', {3, 5, 4}, {1.0, 6.0, 11.0, 16.0, 2.0, 7.0, 12.0, 17.0, 3.0, 8.0, 13.0, 18.0, 4.0, 9.0, 14.0, 19.0, 5.0, 10.0, 15.0, 20.0, 21.0, 26.0, 31.0, 36.0, 22.0, 27.0, 32.0, 37.0, 23.0, 28.0, 33.0, 38.0, 24.0, 29.0, 34.0, 39.0, 25.0, 30.0, 35.0, 40.0, 41.0, 46.0, 51.0, 56.0, 42.0, 47.0, 52.0, 57.0, 43.0, 48.0, 53.0, 58.0, 44.0, 49.0, 54.0, 59.0, 45.0, 50.0, 55.0, 60.0});
+    auto x = NDArrayFactory::create('c', {1, 60});
+    auto exp = NDArrayFactory::create('c', {3, 5, 4}, {1.0, 6.0, 11.0, 16.0, 2.0, 7.0, 12.0, 17.0, 3.0, 8.0, 13.0, 18.0, 4.0, 9.0, 14.0, 19.0, 5.0, 10.0, 15.0, 20.0, 21.0, 26.0, 31.0, 36.0, 22.0, 27.0, 32.0, 37.0, 23.0, 28.0, 33.0, 38.0, 24.0, 29.0, 34.0, 39.0, 25.0, 30.0, 35.0, 40.0, 41.0, 46.0, 51.0, 56.0, 42.0, 47.0, 52.0, 57.0, 43.0, 48.0, 53.0, 58.0, 44.0, 49.0, 54.0, 59.0, 45.0, 50.0, 55.0, 60.0});

     x.linspace(1);
     x.reshapei('c', {3, 4, 5});
@@ -474,9 +474,9 @@ TEST_F(NDArrayTest2, Test_PermuteEquality_1) {
 }

 TEST_F(NDArrayTest2, Test_PermuteEquality_0) {
-    auto x = NDArrayFactory::create('c', {1, 60});
+    auto x = NDArrayFactory::create('c', {1, 60});
     x.linspace(1);
-    auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0});
+    auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0});

     x.reshapei('c', {3, 4, 5});
     x.permutei({0, 1, 2});
@@ -491,9 +491,9 @@ TEST_F(NDArrayTest2, Test_PermuteEquality_0) {

 TEST_F(NDArrayTest2, Test_PermuteEquality_2) {
-    auto x = NDArrayFactory::create('c', {1, 60});
+    auto x = NDArrayFactory::create('c', {1, 60});
     x.linspace(1);
-    auto exp = NDArrayFactory::create('c', {4, 3, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 21.0, 22.0, 23.0, 24.0, 25.0, 41.0, 42.0, 43.0, 44.0, 45.0, 6.0, 7.0, 8.0, 9.0, 10.0, 26.0, 27.0, 28.0, 29.0, 30.0, 46.0, 47.0, 48.0, 49.0, 50.0, 11.0, 12.0, 13.0, 14.0, 15.0, 31.0, 32.0, 33.0, 34.0, 35.0, 51.0, 52.0, 53.0, 54.0, 55.0, 16.0, 17.0, 18.0, 19.0, 20.0, 36.0, 37.0, 38.0, 39.0, 40.0, 56.0, 57.0, 58.0, 59.0, 60.0});
+    auto exp = NDArrayFactory::create('c', {4, 3, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 21.0, 22.0, 23.0, 24.0, 25.0, 41.0, 42.0, 43.0, 44.0, 45.0, 6.0, 7.0, 8.0, 9.0, 10.0, 26.0, 27.0, 28.0, 29.0, 30.0, 46.0, 47.0, 48.0, 49.0, 50.0, 11.0, 12.0, 13.0, 14.0, 15.0, 31.0, 32.0, 33.0, 34.0, 35.0, 51.0, 52.0, 53.0, 54.0, 55.0, 16.0, 17.0, 18.0, 19.0, 20.0, 36.0, 37.0, 38.0, 39.0, 40.0, 56.0, 57.0, 58.0, 59.0, 60.0});

     x.reshapei('c', {3, 4, 5});
     x.permutei({1, 0, 2});
@@ -507,9 +507,9 @@ TEST_F(NDArrayTest2, Test_PermuteEquality_2) {
 }

 TEST_F(NDArrayTest2, Test_PermuteEquality_3) {
-    auto x = NDArrayFactory::create('c', {1, 60});
+    auto x = NDArrayFactory::create('c', {1, 60});
     x.linspace(1);
-    auto exp = NDArrayFactory::create('c', {4, 5, 3}, {1.0, 21.0, 41.0, 2.0, 22.0, 42.0, 3.0, 23.0, 43.0, 4.0, 24.0, 44.0, 5.0, 25.0, 45.0, 6.0, 26.0, 46.0, 7.0, 27.0, 47.0, 8.0, 28.0, 48.0, 9.0, 29.0, 49.0, 10.0, 30.0, 50.0, 11.0, 31.0, 51.0, 12.0, 32.0, 52.0, 13.0, 33.0, 53.0, 14.0, 34.0, 54.0, 15.0, 35.0, 55.0, 16.0, 36.0, 56.0, 17.0, 37.0, 57.0, 18.0, 38.0, 58.0, 19.0, 39.0, 59.0, 20.0, 40.0, 60.0});
+    auto exp = NDArrayFactory::create('c', {4, 5, 3}, {1.0, 21.0, 41.0, 2.0, 22.0, 42.0, 3.0, 23.0, 43.0, 4.0, 24.0, 44.0, 5.0, 25.0, 45.0, 6.0, 26.0, 46.0, 7.0, 27.0, 47.0, 8.0, 28.0, 48.0, 9.0, 29.0, 49.0, 10.0, 30.0, 50.0, 11.0, 31.0, 51.0, 12.0, 32.0, 52.0, 13.0, 33.0, 53.0, 14.0, 34.0, 54.0, 15.0, 35.0, 55.0, 16.0, 36.0, 56.0, 17.0, 37.0, 57.0, 18.0, 38.0, 58.0, 19.0, 39.0, 59.0, 20.0, 40.0, 60.0});

     x.reshapei('c', {3, 4, 5});
     x.permutei({1, 2, 0});
@@ -523,9 +523,9 @@ TEST_F(NDArrayTest2, Test_PermuteEquality_3) {
 }

 TEST_F(NDArrayTest2, Test_PermuteEquality_4) {
-    auto x = NDArrayFactory::create('c', {1, 60});
+    auto x = NDArrayFactory::create('c', {1, 60});
     x.linspace(1);
-    auto exp = NDArrayFactory::create('c', {5, 3, 4}, {1.0, 6.0, 11.0, 16.0, 21.0, 26.0, 31.0, 36.0, 41.0, 46.0, 51.0, 56.0, 2.0, 7.0, 12.0, 17.0, 22.0, 27.0, 32.0, 37.0, 42.0, 47.0, 52.0, 57.0, 3.0, 8.0, 13.0, 18.0, 23.0, 28.0, 33.0, 38.0, 43.0, 48.0, 53.0, 58.0, 4.0, 9.0, 14.0, 19.0, 24.0, 29.0, 34.0, 39.0, 44.0, 49.0, 54.0, 59.0, 5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0});
+    auto exp = NDArrayFactory::create('c', {5, 3, 4}, {1.0, 6.0, 11.0, 16.0, 21.0, 26.0, 31.0, 36.0, 41.0, 46.0, 51.0, 56.0, 2.0, 7.0, 12.0, 17.0, 22.0, 27.0, 32.0, 37.0, 42.0, 47.0, 52.0, 57.0, 3.0, 8.0, 13.0, 18.0, 23.0, 28.0, 33.0, 38.0, 43.0, 48.0, 53.0, 58.0, 4.0, 9.0, 14.0, 19.0, 24.0, 29.0, 34.0, 39.0, 44.0, 49.0, 54.0, 59.0, 5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0});

     x.reshapei('c', {3, 4, 5});
     x.permutei({2, 0, 1});
@@ -539,9 +539,9 @@ TEST_F(NDArrayTest2, Test_PermuteEquality_4) {
 }

 TEST_F(NDArrayTest2, Test_PermuteEquality_5) {
-    auto x = NDArrayFactory::create('c', {1, 60});
+    auto x = NDArrayFactory::create('c', {1, 60});
     x.linspace(1);
-    auto exp = NDArrayFactory::create('c', {5, 4, 3},
+    auto exp = NDArrayFactory::create('c', {5, 4, 3},
                                       {1.0, 21.0, 41.0, 6.0, 26.0, 46.0, 11.0, 31.0, 51.0, 16.0, 36.0, 56.0,
                                        2.0, 22.0, 42.0, 7.0, 27.0, 47.0, 12.0, 32.0, 52.0, 17.0, 37.0, 57.0,
                                        3.0, 23.0, 43.0, 8.0, 28.0, 48.0, 13.0, 33.0, 53.0, 18.0, 38.0, 58.0,
                                        4.0, 24.0, 44.0, 9.0, 29.0, 49.0, 14.0, 34.0, 54.0, 19.0, 39.0, 59.0,
@@ -562,10 +562,10 @@

 ////////////////////////////////////////////////////////////////////
 TEST_F(NDArrayTest2, fillAsTriangular_test1) {
-    auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16});
-    auto exp = NDArrayFactory::create('c', {4, 4}, {1,0,0,0,5,6,0,0,9,10,11,0 ,13,14,15,16});
+    auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16});
+    auto exp = NDArrayFactory::create('c', {4, 4}, {1,0,0,0,5,6,0,0,9,10,11,0 ,13,14,15,16});

-    x.fillAsTriangular(0., 0, 0, 'u');
+    x.fillAsTriangular(0., 0, 0, 'u');

     ASSERT_TRUE(exp.isSameShape(&x));
     ASSERT_TRUE(exp.equalsTo(&x));
@@ -575,10 +575,10 @@ TEST_F(NDArrayTest2, fillAsTriangular_test1) {

 ////////////////////////////////////////////////////////////////////
 TEST_F(NDArrayTest2, fillAsTriangular_test2) {
-    auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16});
-    auto exp = NDArrayFactory::create('c', {4, 4}, {0,0,0,0,5,0,0,0,9,10,0 ,0 ,13,14,15,0});
+    auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16});
+    auto exp = NDArrayFactory::create('c', {4, 4}, {0,0,0,0,5,0,0,0,9,10,0 ,0 ,13,14,15,0});

-    x.fillAsTriangular(0., 0, -1, 'u');
+    x.fillAsTriangular(0., 0, -1, 'u');

     ASSERT_TRUE(exp.isSameShape(&x));
     ASSERT_TRUE(exp.equalsTo(&x));
@@ -588,10 +588,10 @@ TEST_F(NDArrayTest2, fillAsTriangular_test2) {

 ////////////////////////////////////////////////////////////////////
 TEST_F(NDArrayTest2, fillAsTriangular_test3) {
-    auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16});
-    auto exp = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,0,6,7,8,0,0 ,11,12,0 ,0 , 0,16});
+    auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16});
+    auto exp = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,0,6,7,8,0,0 ,11,12,0 ,0 , 0,16});

-    x.fillAsTriangular(0., 0, 0, 'l');
+    x.fillAsTriangular(0., 0, 0, 'l');

     ASSERT_TRUE(exp.isSameShape(&x));
     ASSERT_TRUE(exp.equalsTo(&x));
@@ -601,10 +601,10 @@ TEST_F(NDArrayTest2, fillAsTriangular_test3) {

 ////////////////////////////////////////////////////////////////////
 TEST_F(NDArrayTest2, fillAsTriangular_test4) {
-    auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16});
-    auto exp = NDArrayFactory::create('c', {4, 4}, {0,2,3,4,0,0,7,8,0,0 , 0,12, 0, 0, 0, 0});
+    auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16});
+    auto exp = NDArrayFactory::create('c', {4, 4}, {0,2,3,4,0,0,7,8,0,0 , 0,12, 0, 0, 0, 0});

-    x.fillAsTriangular(0., 1, 0, 'l');
+    x.fillAsTriangular(0., 1, 0, 'l');

     ASSERT_TRUE(exp.isSameShape(&x));
     ASSERT_TRUE(exp.equalsTo(&x));
@@ -612,11 +612,11 @@

 ////////////////////////////////////////////////////////////////////
 TEST_F(NDArrayTest2, Test_DType_Conversion_1) {
-    auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6});
+    auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6});

     auto xd = x.template asT();

-    auto xf = xd->template asT();
+    auto xf = xd->template asT();

     ASSERT_TRUE(x.isSameShape(xf));
     ASSERT_TRUE(x.equalsTo(xf));
@@ -766,8 +766,8 @@ TEST_F(NDArrayTest2, Test_Linspace_5) {

 ////////////////////////////////////////////////////////////////////
 TEST_F(NDArrayTest2, allTensorsAlongDimension_test1) {
-    auto x = NDArrayFactory::create('c', {4}, {1, 2, 3, 4});
-    auto exp = NDArrayFactory::create('c', {4}, {1, 2, 3, 4});
+    auto x = NDArrayFactory::create('c', {4}, {1, 2, 3, 4});
+    auto exp = NDArrayFactory::create('c', {4}, {1, 2, 3, 4});

     auto set = x.allTensorsAlongDimension({0});
     // set->at(0)->printShapeInfo();
@@ -836,8 +836,8 @@ TEST_F(NDArrayTest2, scalar_set_test2) {
 }

 TEST_F(NDArrayTest2, big_dup_test) {
-    // auto arr = NDArrayFactory::linspace(1.0f, 10000000.0f, 100000000);
-    auto arr = NDArrayFactory::linspace(1.0f, 1000.0f, 10000);
+    // auto arr = NDArrayFactory::linspace(1.0f, 10000000.0f, 100000000);
+    auto arr = NDArrayFactory::linspace(1.0f, 1000.0f, 10000);

     auto dup = arr->dup('c');
     ASSERT_EQ(*arr, *dup);
diff --git a/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp
index 1846fc397..e426eeb1f 100644
--- a/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp
@@ -682,7 +682,7 @@ TEST_F(NativeOpsTests, ScalarTest_1) {

 TEST_F(NativeOpsTests, ScalarTest_2) {
     auto x = NDArrayFactory::create('c', {5, 5});
-    auto y = NDArrayFactory::create(10.);
+    auto y = NDArrayFactory::create(10.f);
     auto exp = NDArrayFactory::create('c', {5,5});
     auto z = NDArrayFactory::create('c', {5,5});
@@ -714,9 +714,9 @@
 }

 TEST_F(NativeOpsTests, SummaryStatsScalarTest_1) {
-    auto x = NDArrayFactory::create('c', {5, 5}, {0.1, 0.2, 0.3, -0.3, -0.5, 0.5, 0.7, 0.9, 0.8, 0.1, 0.11, 0.12, 0.5, -0.8, -0.9, 0.4, 0.1, 0.2, 0.3, -0.3, -0.5, 0.2, 0.3, -0.3, -0.5});
-    auto exp = NDArrayFactory::create(0.9);
-    auto z = NDArrayFactory::create(0.21587136);
+    auto x = NDArrayFactory::create('c', {5, 5}, {0.1f, 0.2f, 0.3f, -0.3f, -0.5f, 0.5f, 0.7f, 0.9f, 0.8f, 0.1f, 0.11f, 0.12f, 0.5f, -0.8f, -0.9f, 0.4f, 0.1f, 0.2f, 0.3f, -0.3f, -0.5f, 0.2f, 0.3f, -0.3f, -0.5f});
+    auto exp = NDArrayFactory::create(0.9f);
+    auto z = NDArrayFactory::create(0.21587136f);

     Nd4jPointer extra[6];
 #ifdef __CUDABLAS__
@@ -739,9 +739,9 @@
 }

 TEST_F(NativeOpsTests, SummaryStatsScalarTest_2) {
-    auto x = NDArrayFactory::create('c', {5, 5}, {0.1, 0.2, 0.3, -0.3, -0.5, 0.5, 0.7, 0.9, 0.8, 0.1, 0.11, 0.12, 0.5, -0.8, -0.9, 0.4, 0.1, 0.2, 0.3, -0.3, -0.5, 0.2, 0.3, -0.3, -0.5});
-    auto exp = NDArrayFactory::create(0.9);
-    auto z = NDArrayFactory::create(0.21587136);
+    auto x = NDArrayFactory::create('c', {5, 5}, {0.1, 0.2, 0.3, -0.3, -0.5, 0.5, 0.7, 0.9, 0.8, 0.1, 0.11, 0.12, 0.5, -0.8, -0.9, 0.4, 0.1, 0.2, 0.3, -0.3, -0.5, 0.2, 0.3, -0.3, -0.5});
+    auto exp = NDArrayFactory::create(0.9);
+    auto z = NDArrayFactory::create(0.21587136);

     Nd4jPointer extra[6];
 #ifdef __CUDABLAS__
@@ -764,9 +764,9 @@
 }

 TEST_F(NativeOpsTests, SummaryStatsScalarTest_3) {
-    auto x = NDArrayFactory::create('c', {5, 5}, {0.1, 0.2, 0.3, -0.3, -0.5, 0.5, 0.7, 0.9, 0.8, 0.1, 0.11, 0.12, 0.5, -0.8, -0.9, 0.4, 0.1, 0.2, 0.3, -0.3, -0.5, 0.2, 0.3, -0.3, -0.5});
-    auto exp = NDArrayFactory::create(0.9);
-    auto z = NDArrayFactory::create(0.21587136);
+    auto x = NDArrayFactory::create('c', {5, 5}, {0.1, 0.2, 0.3, -0.3, -0.5, 0.5, 0.7, 0.9, 0.8, 0.1, 0.11, 0.12, 0.5, -0.8, -0.9, 0.4, 0.1, 0.2, 0.3, -0.3, -0.5, 0.2, 0.3, -0.3, -0.5});
+    auto exp = NDArrayFactory::create(0.9);
+    auto z = NDArrayFactory::create(0.21587136);

     Nd4jPointer extra[6];
 #ifdef __CUDABLAS__
@@ -794,9 +794,9 @@
 }

 TEST_F(NativeOpsTests, TransformTest_1) {
-    auto x = NDArrayFactory::create('c', {5, 5}, {1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529, 576, 625});
-    auto exp = NDArrayFactory::create('c', {5, 5});
-    auto z = NDArrayFactory::create('c', {5,5});
+    auto x = NDArrayFactory::create('c', {5, 5}, {1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529, 576, 625});
+    auto exp = NDArrayFactory::create('c', {5, 5});
+    auto z = NDArrayFactory::create('c', {5,5});

     Nd4jPointer extra[6];
 #ifdef __CUDABLAS__
@@ -821,7 +821,7 @@
 }

 TEST_F(NativeOpsTests, TransformTest_2) {
-    auto x = NDArrayFactory::create('c', {5, 5}, {1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529, 576, 625});
+    auto x = NDArrayFactory::create('c', {5, 5}, {1.f, 4.f, 9.f, 16.f, 25.f, 36.f, 49.f, 64.f, 81.f, 100.f, 121.f, 144.f, 169.f, 196.f, 225.f, 256.f, 289.f, 324.f, 361.f, 400.f, 441.f, 484.f, 529.f, 576.f, 625.f});
     auto exp = NDArrayFactory::create('c', {5, 5});
     auto z = NDArrayFactory::create('c', {5,5});
@@ -878,10 +878,10 @@
 }

 TEST_F(NativeOpsTests, TransformTest_4) {
-    auto x = NDArrayFactory::create('c', {5, 5}, {0, 1, 2, 3, 2, 1, 0, 1.57, 1.57, 1.57, 3.141592, 3.141592,
+    auto x = NDArrayFactory::create('c', {5, 5}, {0, 1, 2, 3, 2, 1, 0, 1.57, 1.57, 1.57, 3.141592, 3.141592,
                                                  3.141592, 0, 0, 0, 0, 1, 1, 2, 2, 2, 1, 0, 0});
-    auto exp = NDArrayFactory::create('c', {5, 5});
-    auto z = NDArrayFactory::create('c', {5,5}, {1., 0.540302, -0.416147, -0.989992, -0.416147, 0.540302, 1.0,
+    auto exp = NDArrayFactory::create('c', {5, 5});
+    auto z = NDArrayFactory::create('c', {5,5}, {1., 0.540302, -0.416147, -0.989992, -0.416147, 0.540302, 1.0,
                                                 0.000796, 0.000796, 0.000796, -1, -1, -1, 1., 1., 1.0, 1.0, 0.540302, 0.540302, -0.416147, -0.416147, -0.416147, 0.540302, 1., 1.});
@@ -909,7 +909,7 @@

 TEST_F(NativeOpsTests, ScalarTadTest_1) {
     auto x = NDArrayFactory::create('c', {5, 5});
-    auto y = NDArrayFactory::create(10.);
+    auto y = NDArrayFactory::create(10.f);
     auto exp = NDArrayFactory::create('c', {5,5});
     auto z = NDArrayFactory::create('c', {5,5});
@@ -1433,9 +1433,9 @@ TEST_F(NativeOpsTests, MapTests_1) {
 }

 TEST_F(NativeOpsTests, CustomOpTest_1) {
-    auto x = NDArrayFactory::create('c', {1, 6}, {1, 2, 3, 4, 5, 6});
+    auto x = NDArrayFactory::create('c', {1, 6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
     auto z = NDArrayFactory::create('c', {6});
-    auto e = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6});
+    auto e = NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});

     nd4j::ops::squeeze op;
diff --git a/libnd4j/tests_cpu/layers_tests/PairwiseTests.cpp b/libnd4j/tests_cpu/layers_tests/PairwiseTests.cpp
index e4c28e9ba..ec0388a0b 100644
--- a/libnd4j/tests_cpu/layers_tests/PairwiseTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/PairwiseTests.cpp
@@ -23,11 +23,11 @@ class EqualsTest : public testing::Test {
 public:
     Nd4jLong firstShapeBuffer[8] = {2,1,2,1,1,0,1,102};
-    float data[2] = {1.0,7.0};
+    float data[2] = {1.0f, 7.0f};
     Nd4jLong secondShapeBuffer[8] = {2,2,1,6,1,0,6,99};
-    float dataSecond[12] = {1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0};
+    float dataSecond[12] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
     int opNum = 4;
-    float extraArgs[1] = {1e-6};
+    float extraArgs[1] = {1e-6f};
     int dimension[1] = {2147483647};
     int dimensionLength = 1;
 };
diff --git a/libnd4j/tests_cpu/layers_tests/ReduceTests.cpp b/libnd4j/tests_cpu/layers_tests/ReduceTests.cpp
index 8bf12f58b..b91730954 100644
--- a/libnd4j/tests_cpu/layers_tests/ReduceTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/ReduceTests.cpp
@@ -26,14 +26,14 @@ class ReduceTest : public testing::Test {
 public:
     Nd4jLong shape[2] = {500,3};
-    float x[1500] =
{4.0,2.0,3.0,8.0,4.0,6.0,12.0,6.0,9.0,16.0,8.0,12.0,20.0,10.0,15.0,24.0,12.0,18.0,28.0,14.0,21.0,32.0,16.0,24.0,36.0,18.0,27.0,40.0,20.0,30.0,44.0,22.0,33.0,48.0,24.0,36.0,52.0,26.0,39.0,56.0,28.0,42.0,60.0,30.0,45.0,64.0,32.0,48.0,68.0,34.0,51.0,72.0,36.0,54.0,76.0,38.0,57.0,80.0,40.0,60.0,84.0,42.0,63.0,88.0,44.0,66.0,92.0,46.0,69.0,96.0,48.0,72.0,100.0,50.0,75.0,104.0,52.0,78.0,108.0,54.0,81.0,112.0,56.0,84.0,116.0,58.0,87.0,120.0,60.0,90.0,124.0,62.0,93.0,128.0,64.0,96.0,132.0,66.0,99.0,136.0,68.0,102.0,140.0,70.0,105.0,144.0,72.0,108.0,148.0,74.0,111.0,152.0,76.0,114.0,156.0,78.0,117.0,160.0,80.0,120.0,164.0,82.0,123.0,168.0,84.0,126.0,172.0,86.0,129.0,176.0,88.0,132.0,180.0,90.0,135.0,184.0,92.0,138.0,188.0,94.0,141.0,192.0,96.0,144.0,196.0,98.0,147.0,200.0,100.0,150.0,204.0,102.0,153.0,208.0,104.0,156.0,212.0,106.0,159.0,216.0,108.0,162.0,220.0,110.0,165.0,224.0,112.0,168.0,228.0,114.0,171.0,232.0,116.0,174.0,236.0,118.0,177.0,240.0,120.0,180.0,244.0,122.0,183.0,248.0,124.0,186.0,252.0,126.0,189.0,256.0,128.0,192.0,260.0,130.0,195.0,264.0,132.0,198.0,268.0,134.0,201.0,272.0,136.0,204.0,276.0,138.0,207.0,280.0,140.0,210.0,284.0,142.0,213.0,288.0,144.0,216.0,292.0,146.0,219.0,296.0,148.0,222.0,300.0,150.0,225.0,304.0,152.0,228.0,308.0,154.0,231.0,312.0,156.0,234.0,316.0,158.0,237.0,320.0,160.0,240.0,324.0,162.0,243.0,328.0,164.0,246.0,332.0,166.0,249.0,336.0,168.0,252.0,340.0,170.0,255.0,344.0,172.0,258.0,348.0,174.0,261.0,352.0,176.0,264.0,356.0,178.0,267.0,360.0,180.0,270.0,364.0,182.0,273.0,368.0,184.0,276.0,372.0,186.0,279.0,376.0,188.0,282.0,380.0,190.0,285.0,384.0,192.0,288.0,388.0,194.0,291.0,392.0,196.0,294.0,396.0,198.0,297.0,400.0,200.0,300.0,404.0,202.0,303.0,408.0,204.0,306.0,412.0,206.0,309.0,416.0,208.0,312.0,420.0,210.0,315.0,424.0,212.0,318.0,428.0,214.0,321.0,432.0,216.0,324.0,436.0,218.0,327.0,440.0,220.0,330.0,444.0,222.0,333.0,448.0,224.0,336.0,452.0,226.0,339.0,456.0,228.0,342.0,460.0,230.0,345.0,464.0,232.0,348.0,468.0,234.0,351.0,472.0,236.0,354.0,476.0,238.0,357.0,480.0,240.0,360.0,484.0,242.0,363.0,488.0,244.0,366.0,492.0,246.0,369.0,496.0,248.0,372.0,500.0,250.0,375.0,504.0,252.0,378.0,508.0,254.0,381.0,512.0,256.0,384.0,516.0,258.0,387.0,520.0,260.0,390.0,524.0,262.0,393.0,528.0,264.0,396.0,532.0,266.0,399.0,536.0,268.0,402.0,540.0,270.0,405.0,544.0,272.0,408.0,548.0,274.0,411.0,552.0,276.0,414.0,556.0,278.0,417.0,560.0,280.0,420.0,564.0,282.0,423.0,568.0,284.0,426.0,572.0,286.0,429.0,576.0,288.0,432.0,580.0,290.0,435.0,584.0,292.0,438.0,588.0,294.0,441.0,592.0,296.0,444.0,596.0,298.0,447.0,600.0,300.0,450.0,604.0,302.0,453.0,608.0,304.0,456.0,612.0,306.0,459.0,616.0,308.0,462.0,620.0,310.0,465.0,624.0,312.0,468.0,628.0,314.0,471.0,632.0,316.0,474.0,636.0,318.0,477.0,640.0,320.0,480.0,644.0,322.0,483.0,648.0,324.0,486.0,652.0,326.0,489.0,656.0,328.0,492.0,660.0,330.0,495.0,664.0,332.0,498.0,668.0,334.0,501.0,672.0,336.0,504.0,676.0,338.0,507.0,680.0,340.0,510.0,684.0,342.0,513.0,688.0,344.0,516.0,692.0,346.0,519.0,696.0,348.0,522.0,700.0,350.0,525.0,704.0,352.0,528.0,708.0,354.0,531.0,712.0,356.0,534.0,716.0,358.0,537.0,720.0,360.0,540.0,724.0,362.0,543.0,728.0,364.0,546.0,732.0,366.0,549.0,736.0,368.0,552.0,740.0,370.0,555.0,744.0,372.0,558.0,748.0,374.0,561.0,752.0,376.0,564.0,756.0,378.0,567.0,760.0,380.0,570.0,764.0,382.0,573.0,768.0,384.0,576.0,772.0,386.0,579.0,776.0,388.0,582.0,780.0,390.0,585.0,784.0,392.0,588.0,788.0,394.0,591.0,792.0,396.0,594.0,796.0,398.0,597.0,800.0,400.0,600.0,804.0,402.0,603.0,808.0,404.0,606.0,812.0,406.0,609.0,816.0,408.0,612
.0,820.0,410.0,615.0,824.0,412.0,618.0,828.0,414.0,621.0,832.0,416.0,624.0,836.0,418.0,627.0,840.0,420.0,630.0,844.0,422.0,633.0,848.0,424.0,636.0,852.0,426.0,639.0,856.0,428.0,642.0,860.0,430.0,645.0,864.0,432.0,648.0,868.0,434.0,651.0,872.0,436.0,654.0,876.0,438.0,657.0,880.0,440.0,660.0,884.0,442.0,663.0,888.0,444.0,666.0,892.0,446.0,669.0,896.0,448.0,672.0,900.0,450.0,675.0,904.0,452.0,678.0,908.0,454.0,681.0,912.0,456.0,684.0,916.0,458.0,687.0,920.0,460.0,690.0,924.0,462.0,693.0,928.0,464.0,696.0,932.0,466.0,699.0,936.0,468.0,702.0,940.0,470.0,705.0,944.0,472.0,708.0,948.0,474.0,711.0,952.0,476.0,714.0,956.0,478.0,717.0,960.0,480.0,720.0,964.0,482.0,723.0,968.0,484.0,726.0,972.0,486.0,729.0,976.0,488.0,732.0,980.0,490.0,735.0,984.0,492.0,738.0,988.0,494.0,741.0,992.0,496.0,744.0,996.0,498.0,747.0,1000.0,500.0,750.0,1004.0,502.0,753.0,1008.0,504.0,756.0,1012.0,506.0,759.0,1016.0,508.0,762.0,1020.0,510.0,765.0,1024.0,512.0,768.0,1028.0,514.0,771.0,1032.0,516.0,774.0,1036.0,518.0,777.0,1040.0,520.0,780.0,1044.0,522.0,783.0,1048.0,524.0,786.0,1052.0,526.0,789.0,1056.0,528.0,792.0,1060.0,530.0,795.0,1064.0,532.0,798.0,1068.0,534.0,801.0,1072.0,536.0,804.0,1076.0,538.0,807.0,1080.0,540.0,810.0,1084.0,542.0,813.0,1088.0,544.0,816.0,1092.0,546.0,819.0,1096.0,548.0,822.0,1100.0,550.0,825.0,1104.0,552.0,828.0,1108.0,554.0,831.0,1112.0,556.0,834.0,1116.0,558.0,837.0,1120.0,560.0,840.0,1124.0,562.0,843.0,1128.0,564.0,846.0,1132.0,566.0,849.0,1136.0,568.0,852.0,1140.0,570.0,855.0,1144.0,572.0,858.0,1148.0,574.0,861.0,1152.0,576.0,864.0,1156.0,578.0,867.0,1160.0,580.0,870.0,1164.0,582.0,873.0,1168.0,584.0,876.0,1172.0,586.0,879.0,1176.0,588.0,882.0,1180.0,590.0,885.0,1184.0,592.0,888.0,1188.0,594.0,891.0,1192.0,596.0,894.0,1196.0,598.0,897.0,1200.0,600.0,900.0,1204.0,602.0,903.0,1208.0,604.0,906.0,1212.0,606.0,909.0,1216.0,608.0,912.0,1220.0,610.0,915.0,1224.0,612.0,918.0,1228.0,614.0,921.0,1232.0,616.0,924.0,1236.0,618.0,927.0,1240.0,620.0,930.0,1244.0,622.0,933.0,1248.0,624.0,936.0,1252.0,626.0,939.0,1256.0,628.0,942.0,1260.0,630.0,945.0,1264.0,632.0,948.0,1268.0,634.0,951.0,1272.0,636.0,954.0,1276.0,638.0,957.0,1280.0,640.0,960.0,1284.0,642.0,963.0,1288.0,644.0,966.0,1292.0,646.0,969.0,1296.0,648.0,972.0,1300.0,650.0,975.0,1304.0,652.0,978.0,1308.0,654.0,981.0,1312.0,656.0,984.0,1316.0,658.0,987.0,1320.0,660.0,990.0,1324.0,662.0,993.0,1328.0,664.0,996.0,1332.0,666.0,999.0,1336.0,668.0,1002.0,1340.0,670.0,1005.0,1344.0,672.0,1008.0,1348.0,674.0,1011.0,1352.0,676.0,1014.0,1356.0,678.0,1017.0,1360.0,680.0,1020.0,1364.0,682.0,1023.0,1368.0,684.0,1026.0,1372.0,686.0,1029.0,1376.0,688.0,1032.0,1380.0,690.0,1035.0,1384.0,692.0,1038.0,1388.0,694.0,1041.0,1392.0,696.0,1044.0,1396.0,698.0,1047.0,1400.0,700.0,1050.0,1404.0,702.0,1053.0,1408.0,704.0,1056.0,1412.0,706.0,1059.0,1416.0,708.0,1062.0,1420.0,710.0,1065.0,1424.0,712.0,1068.0,1428.0,714.0,1071.0,1432.0,716.0,1074.0,1436.0,718.0,1077.0,1440.0,720.0,1080.0,1444.0,722.0,1083.0,1448.0,724.0,1086.0,1452.0,726.0,1089.0,1456.0,728.0,1092.0,1460.0,730.0,1095.0,1464.0,732.0,1098.0,1468.0,734.0,1101.0,1472.0,736.0,1104.0,1476.0,738.0,1107.0,1480.0,740.0,1110.0,1484.0,742.0,1113.0,1488.0,744.0,1116.0,1492.0,746.0,1119.0,1496.0,748.0,1122.0,1500.0,750.0,1125.0,1504.0,752.0,1128.0,1508.0,754.0,1131.0,1512.0,756.0,1134.0,1516.0,758.0,1137.0,1520.0,760.0,1140.0,1524.0,762.0,1143.0,1528.0,764.0,1146.0,1532.0,766.0,1149.0,1536.0,768.0,1152.0,1540.0,770.0,1155.0,1544.0,772.0,1158.0,1548.0,774.0,1161.0,1552.0,776.0,1164.0,1556.0,778.0,1167.0,1560.0,780.0,1170.0,1564.0
,782.0,1173.0,1568.0,784.0,1176.0,1572.0,786.0,1179.0,1576.0,788.0,1182.0,1580.0,790.0,1185.0,1584.0,792.0,1188.0,1588.0,794.0,1191.0,1592.0,796.0,1194.0,1596.0,798.0,1197.0,1600.0,800.0,1200.0,1604.0,802.0,1203.0,1608.0,804.0,1206.0,1612.0,806.0,1209.0,1616.0,808.0,1212.0,1620.0,810.0,1215.0,1624.0,812.0,1218.0,1628.0,814.0,1221.0,1632.0,816.0,1224.0,1636.0,818.0,1227.0,1640.0,820.0,1230.0,1644.0,822.0,1233.0,1648.0,824.0,1236.0,1652.0,826.0,1239.0,1656.0,828.0,1242.0,1660.0,830.0,1245.0,1664.0,832.0,1248.0,1668.0,834.0,1251.0,1672.0,836.0,1254.0,1676.0,838.0,1257.0,1680.0,840.0,1260.0,1684.0,842.0,1263.0,1688.0,844.0,1266.0,1692.0,846.0,1269.0,1696.0,848.0,1272.0,1700.0,850.0,1275.0,1704.0,852.0,1278.0,1708.0,854.0,1281.0,1712.0,856.0,1284.0,1716.0,858.0,1287.0,1720.0,860.0,1290.0,1724.0,862.0,1293.0,1728.0,864.0,1296.0,1732.0,866.0,1299.0,1736.0,868.0,1302.0,1740.0,870.0,1305.0,1744.0,872.0,1308.0,1748.0,874.0,1311.0,1752.0,876.0,1314.0,1756.0,878.0,1317.0,1760.0,880.0,1320.0,1764.0,882.0,1323.0,1768.0,884.0,1326.0,1772.0,886.0,1329.0,1776.0,888.0,1332.0,1780.0,890.0,1335.0,1784.0,892.0,1338.0,1788.0,894.0,1341.0,1792.0,896.0,1344.0,1796.0,898.0,1347.0,1800.0,900.0,1350.0,1804.0,902.0,1353.0,1808.0,904.0,1356.0,1812.0,906.0,1359.0,1816.0,908.0,1362.0,1820.0,910.0,1365.0,1824.0,912.0,1368.0,1828.0,914.0,1371.0,1832.0,916.0,1374.0,1836.0,918.0,1377.0,1840.0,920.0,1380.0,1844.0,922.0,1383.0,1848.0,924.0,1386.0,1852.0,926.0,1389.0,1856.0,928.0,1392.0,1860.0,930.0,1395.0,1864.0,932.0,1398.0,1868.0,934.0,1401.0,1872.0,936.0,1404.0,1876.0,938.0,1407.0,1880.0,940.0,1410.0,1884.0,942.0,1413.0,1888.0,944.0,1416.0,1892.0,946.0,1419.0,1896.0,948.0,1422.0,1900.0,950.0,1425.0,1904.0,952.0,1428.0,1908.0,954.0,1431.0,1912.0,956.0,1434.0,1916.0,958.0,1437.0,1920.0,960.0,1440.0,1924.0,962.0,1443.0,1928.0,964.0,1446.0,1932.0,966.0,1449.0,1936.0,968.0,1452.0,1940.0,970.0,1455.0,1944.0,972.0,1458.0,1948.0,974.0,1461.0,1952.0,976.0,1464.0,1956.0,978.0,1467.0,1960.0,980.0,1470.0,1964.0,982.0,1473.0,1968.0,984.0,1476.0,1972.0,986.0,1479.0,1976.0,988.0,1482.0,1980.0,990.0,1485.0,1984.0,992.0,1488.0,1988.0,994.0,1491.0,1992.0,996.0,1494.0,1996.0,998.0,1497.0,2000.0,1000.0,1500.0}; - float result[1500] = {0}; + float x[1500] = {4.0f, 2.0f, 3.0f, 8.0f, 4.0f, 6.0f, 12.0f, 6.0f, 9.0f, 16.0f, 8.0f, 12.0f, 20.0f, 10.0f, 15.0f, 24.0f, 12.0f, 18.0f, 28.0f, 14.0f, 21.0f, 32.0f, 16.0f, 24.0f, 36.0f, 18.0f, 27.0f, 40.0f, 20.0f, 30.0f, 44.0f, 22.0f, 33.0f, 48.0f, 24.0f, 36.0f, 52.0f, 26.0f, 39.0f, 56.0f, 28.0f, 42.0f, 60.0f, 30.0f, 45.0f, 64.0f, 32.0f, 48.0f, 68.0f, 34.0f, 51.0f, 72.0f, 36.0f, 54.0f, 76.0f, 38.0f, 57.0f, 80.0f, 40.0f, 60.0f, 84.0f, 42.0f, 63.0f, 88.0f, 44.0f, 66.0f, 92.0f, 46.0f, 69.0f, 96.0f, 48.0f, 72.0f, 100.0f, 50.0f, 75.0f, 104.0f, 52.0f, 78.0f, 108.0f, 54.0f, 81.0f, 112.0f, 56.0f, 84.0f, 116.0f, 58.0f, 87.0f, 120.0f, 60.0f, 90.0f, 124.0f, 62.0f, 93.0f, 128.0f, 64.0f, 96.0f, 132.0f, 66.0f, 99.0f, 136.0f, 68.0f, 102.0f, 140.0f, 70.0f, 105.0f, 144.0f, 72.0f, 108.0f, 148.0f, 74.0f, 111.0f, 152.0f, 76.0f, 114.0f, 156.0f, 78.0f, 117.0f, 160.0f, 80.0f, 120.0f, 164.0f, 82.0f, 123.0f, 168.0f, 84.0f, 126.0f, 172.0f, 86.0f, 129.0f, 176.0f, 88.0f, 132.0f, 180.0f, 90.0f, 135.0f, 184.0f, 92.0f, 138.0f, 188.0f, 94.0f, 141.0f, 192.0f, 96.0f, 144.0f, 196.0f, 98.0f, 147.0f, 200.0f, 100.0f, 150.0f, 204.0f, 102.0f, 153.0f, 208.0f, 104.0f, 156.0f, 212.0f, 106.0f, 159.0f, 216.0f, 108.0f, 162.0f, 220.0f, 110.0f, 165.0f, 224.0f, 112.0f, 168.0f, 228.0f, 114.0f, 171.0f, 232.0f, 116.0f, 174.0f, 236.0f, 118.0f, 177.0f, 240.0f, 
120.0f, 180.0f, 244.0f, 122.0f, 183.0f, 248.0f, 124.0f, 186.0f, 252.0f, 126.0f, 189.0f, 256.0f, 128.0f, 192.0f, 260.0f, 130.0f, 195.0f, 264.0f, 132.0f, 198.0f, 268.0f, 134.0f, 201.0f, 272.0f, 136.0f, 204.0f, 276.0f, 138.0f, 207.0f, 280.0f, 140.0f, 210.0f, 284.0f, 142.0f, 213.0f, 288.0f, 144.0f, 216.0f, 292.0f, 146.0f, 219.0f, 296.0f, 148.0f, 222.0f, 300.0f, 150.0f, 225.0f, 304.0f, 152.0f, 228.0f, 308.0f, 154.0f, 231.0f, 312.0f, 156.0f, 234.0f, 316.0f, 158.0f, 237.0f, 320.0f, 160.0f, 240.0f, 324.0f, 162.0f, 243.0f, 328.0f, 164.0f, 246.0f, 332.0f, 166.0f, 249.0f, 336.0f, 168.0f, 252.0f, 340.0f, 170.0f, 255.0f, 344.0f, 172.0f, 258.0f, 348.0f, 174.0f, 261.0f, 352.0f, 176.0f, 264.0f, 356.0f, 178.0f, 267.0f, 360.0f, 180.0f, 270.0f, 364.0f, 182.0f, 273.0f, 368.0f, 184.0f, 276.0f, 372.0f, 186.0f, 279.0f, 376.0f, 188.0f, 282.0f, 380.0f, 190.0f, 285.0f, 384.0f, 192.0f, 288.0f, 388.0f, 194.0f, 291.0f, 392.0f, 196.0f, 294.0f, 396.0f, 198.0f, 297.0f, 400.0f, 200.0f, 300.0f, 404.0f, 202.0f, 303.0f, 408.0f, 204.0f, 306.0f, 412.0f, 206.0f, 309.0f, 416.0f, 208.0f, 312.0f, 420.0f, 210.0f, 315.0f, 424.0f, 212.0f, 318.0f, 428.0f, 214.0f, 321.0f, 432.0f, 216.0f, 324.0f, 436.0f, 218.0f, 327.0f, 440.0f, 220.0f, 330.0f, 444.0f, 222.0f, 333.0f, 448.0f, 224.0f, 336.0f, 452.0f, 226.0f, 339.0f, 456.0f, 228.0f, 342.0f, 460.0f, 230.0f, 345.0f, 464.0f, 232.0f, 348.0f, 468.0f, 234.0f, 351.0f, 472.0f, 236.0f, 354.0f, 476.0f, 238.0f, 357.0f, 480.0f, 240.0f, 360.0f, 484.0f, 242.0f, 363.0f, 488.0f, 244.0f, 366.0f, 492.0f, 246.0f, 369.0f, 496.0f, 248.0f, 372.0f, 500.0f, 250.0f, 375.0f, 504.0f, 252.0f, 378.0f, 508.0f, 254.0f, 381.0f, 512.0f, 256.0f, 384.0f, 516.0f, 258.0f, 387.0f, 520.0f, 260.0f, 390.0f, 524.0f, 262.0f, 393.0f, 528.0f, 264.0f, 396.0f, 532.0f, 266.0f, 399.0f, 536.0f, 268.0f, 402.0f, 540.0f, 270.0f, 405.0f, 544.0f, 272.0f, 408.0f, 548.0f, 274.0f, 411.0f, 552.0f, 276.0f, 414.0f, 556.0f, 278.0f, 417.0f, 560.0f, 280.0f, 420.0f, 564.0f, 282.0f, 423.0f, 568.0f, 284.0f, 426.0f, 572.0f, 286.0f, 429.0f, 576.0f, 288.0f, 432.0f, 580.0f, 290.0f, 435.0f, 584.0f, 292.0f, 438.0f, 588.0f, 294.0f, 441.0f, 592.0f, 296.0f, 444.0f, 596.0f, 298.0f, 447.0f, 600.0f, 300.0f, 450.0f, 604.0f, 302.0f, 453.0f, 608.0f, 304.0f, 456.0f, 612.0f, 306.0f, 459.0f, 616.0f, 308.0f, 462.0f, 620.0f, 310.0f, 465.0f, 624.0f, 312.0f, 468.0f, 628.0f, 314.0f, 471.0f, 632.0f, 316.0f, 474.0f, 636.0f, 318.0f, 477.0f, 640.0f, 320.0f, 480.0f, 644.0f, 322.0f, 483.0f, 648.0f, 324.0f, 486.0f, 652.0f, 326.0f, 489.0f, 656.0f, 328.0f, 492.0f, 660.0f, 330.0f, 495.0f, 664.0f, 332.0f, 498.0f, 668.0f, 334.0f, 501.0f, 672.0f, 336.0f, 504.0f, 676.0f, 338.0f, 507.0f, 680.0f, 340.0f, 510.0f, 684.0f, 342.0f, 513.0f, 688.0f, 344.0f, 516.0f, 692.0f, 346.0f, 519.0f, 696.0f, 348.0f, 522.0f, 700.0f, 350.0f, 525.0f, 704.0f, 352.0f, 528.0f, 708.0f, 354.0f, 531.0f, 712.0f, 356.0f, 534.0f, 716.0f, 358.0f, 537.0f, 720.0f, 360.0f, 540.0f, 724.0f, 362.0f, 543.0f, 728.0f, 364.0f, 546.0f, 732.0f, 366.0f, 549.0f, 736.0f, 368.0f, 552.0f, 740.0f, 370.0f, 555.0f, 744.0f, 372.0f, 558.0f, 748.0f, 374.0f, 561.0f, 752.0f, 376.0f, 564.0f, 756.0f, 378.0f, 567.0f, 760.0f, 380.0f, 570.0f, 764.0f, 382.0f, 573.0f, 768.0f, 384.0f, 576.0f, 772.0f, 386.0f, 579.0f, 776.0f, 388.0f, 582.0f, 780.0f, 390.0f, 585.0f, 784.0f, 392.0f, 588.0f, 788.0f, 394.0f, 591.0f, 792.0f, 396.0f, 594.0f, 796.0f, 398.0f, 597.0f, 800.0f, 400.0f, 600.0f, 804.0f, 402.0f, 603.0f, 808.0f, 404.0f, 606.0f, 812.0f, 406.0f, 609.0f, 816.0f, 408.0f, 612.0f, 820.0f, 410.0f, 615.0f, 824.0f, 412.0f, 618.0f, 828.0f, 414.0f, 621.0f, 832.0f, 
416.0f, 624.0f, 836.0f, 418.0f, 627.0f, 840.0f, 420.0f, 630.0f, 844.0f, 422.0f, 633.0f, 848.0f, 424.0f, 636.0f, 852.0f, 426.0f, 639.0f, 856.0f, 428.0f, 642.0f, 860.0f, 430.0f, 645.0f, 864.0f, 432.0f, 648.0f, 868.0f, 434.0f, 651.0f, 872.0f, 436.0f, 654.0f, 876.0f, 438.0f, 657.0f, 880.0f, 440.0f, 660.0f, 884.0f, 442.0f, 663.0f, 888.0f, 444.0f, 666.0f, 892.0f, 446.0f, 669.0f, 896.0f, 448.0f, 672.0f, 900.0f, 450.0f, 675.0f, 904.0f, 452.0f, 678.0f, 908.0f, 454.0f, 681.0f, 912.0f, 456.0f, 684.0f, 916.0f, 458.0f, 687.0f, 920.0f, 460.0f, 690.0f, 924.0f, 462.0f, 693.0f, 928.0f, 464.0f, 696.0f, 932.0f, 466.0f, 699.0f, 936.0f, 468.0f, 702.0f, 940.0f, 470.0f, 705.0f, 944.0f, 472.0f, 708.0f, 948.0f, 474.0f, 711.0f, 952.0f, 476.0f, 714.0f, 956.0f, 478.0f, 717.0f, 960.0f, 480.0f, 720.0f, 964.0f, 482.0f, 723.0f, 968.0f, 484.0f, 726.0f, 972.0f, 486.0f, 729.0f, 976.0f, 488.0f, 732.0f, 980.0f, 490.0f, 735.0f, 984.0f, 492.0f, 738.0f, 988.0f, 494.0f, 741.0f, 992.0f, 496.0f, 744.0f, 996.0f, 498.0f, 747.0f, 1000.0f, 500.0f, 750.0f, 1004.0f, 502.0f, 753.0f, 1008.0f, 504.0f, 756.0f, 1012.0f, 506.0f, 759.0f, 1016.0f, 508.0f, 762.0f, 1020.0f, 510.0f, 765.0f, 1024.0f, 512.0f, 768.0f, 1028.0f, 514.0f, 771.0f, 1032.0f, 516.0f, 774.0f, 1036.0f, 518.0f, 777.0f, 1040.0f, 520.0f, 780.0f, 1044.0f, 522.0f, 783.0f, 1048.0f, 524.0f, 786.0f, 1052.0f, 526.0f, 789.0f, 1056.0f, 528.0f, 792.0f, 1060.0f, 530.0f, 795.0f, 1064.0f, 532.0f, 798.0f, 1068.0f, 534.0f, 801.0f, 1072.0f, 536.0f, 804.0f, 1076.0f, 538.0f, 807.0f, 1080.0f, 540.0f, 810.0f, 1084.0f, 542.0f, 813.0f, 1088.0f, 544.0f, 816.0f, 1092.0f, 546.0f, 819.0f, 1096.0f, 548.0f, 822.0f, 1100.0f, 550.0f, 825.0f, 1104.0f, 552.0f, 828.0f, 1108.0f, 554.0f, 831.0f, 1112.0f, 556.0f, 834.0f, 1116.0f, 558.0f, 837.0f, 1120.0f, 560.0f, 840.0f, 1124.0f, 562.0f, 843.0f, 1128.0f, 564.0f, 846.0f, 1132.0f, 566.0f, 849.0f, 1136.0f, 568.0f, 852.0f, 1140.0f, 570.0f, 855.0f, 1144.0f, 572.0f, 858.0f, 1148.0f, 574.0f, 861.0f, 1152.0f, 576.0f, 864.0f, 1156.0f, 578.0f, 867.0f, 1160.0f, 580.0f, 870.0f, 1164.0f, 582.0f, 873.0f, 1168.0f, 584.0f, 876.0f, 1172.0f, 586.0f, 879.0f, 1176.0f, 588.0f, 882.0f, 1180.0f, 590.0f, 885.0f, 1184.0f, 592.0f, 888.0f, 1188.0f, 594.0f, 891.0f, 1192.0f, 596.0f, 894.0f, 1196.0f, 598.0f, 897.0f, 1200.0f, 600.0f, 900.0f, 1204.0f, 602.0f, 903.0f, 1208.0f, 604.0f, 906.0f, 1212.0f, 606.0f, 909.0f, 1216.0f, 608.0f, 912.0f, 1220.0f, 610.0f, 915.0f, 1224.0f, 612.0f, 918.0f, 1228.0f, 614.0f, 921.0f, 1232.0f, 616.0f, 924.0f, 1236.0f, 618.0f, 927.0f, 1240.0f, 620.0f, 930.0f, 1244.0f, 622.0f, 933.0f, 1248.0f, 624.0f, 936.0f, 1252.0f, 626.0f, 939.0f, 1256.0f, 628.0f, 942.0f, 1260.0f, 630.0f, 945.0f, 1264.0f, 632.0f, 948.0f, 1268.0f, 634.0f, 951.0f, 1272.0f, 636.0f, 954.0f, 1276.0f, 638.0f, 957.0f, 1280.0f, 640.0f, 960.0f, 1284.0f, 642.0f, 963.0f, 1288.0f, 644.0f, 966.0f, 1292.0f, 646.0f, 969.0f, 1296.0f, 648.0f, 972.0f, 1300.0f, 650.0f, 975.0f, 1304.0f, 652.0f, 978.0f, 1308.0f, 654.0f, 981.0f, 1312.0f, 656.0f, 984.0f, 1316.0f, 658.0f, 987.0f, 1320.0f, 660.0f, 990.0f, 1324.0f, 662.0f, 993.0f, 1328.0f, 664.0f, 996.0f, 1332.0f, 666.0f, 999.0f, 1336.0f, 668.0f, 1002.0f, 1340.0f, 670.0f, 1005.0f, 1344.0f, 672.0f, 1008.0f, 1348.0f, 674.0f, 1011.0f, 1352.0f, 676.0f, 1014.0f, 1356.0f, 678.0f, 1017.0f, 1360.0f, 680.0f, 1020.0f, 1364.0f, 682.0f, 1023.0f, 1368.0f, 684.0f, 1026.0f, 1372.0f, 686.0f, 1029.0f, 1376.0f, 688.0f, 1032.0f, 1380.0f, 690.0f, 1035.0f, 1384.0f, 692.0f, 1038.0f, 1388.0f, 694.0f, 1041.0f, 1392.0f, 696.0f, 1044.0f, 1396.0f, 698.0f, 1047.0f, 1400.0f, 700.0f, 1050.0f, 1404.0f, 
702.0f, 1053.0f, 1408.0f, 704.0f, 1056.0f, 1412.0f, 706.0f, 1059.0f, 1416.0f, 708.0f, 1062.0f, 1420.0f, 710.0f, 1065.0f, 1424.0f, 712.0f, 1068.0f, 1428.0f, 714.0f, 1071.0f, 1432.0f, 716.0f, 1074.0f, 1436.0f, 718.0f, 1077.0f, 1440.0f, 720.0f, 1080.0f, 1444.0f, 722.0f, 1083.0f, 1448.0f, 724.0f, 1086.0f, 1452.0f, 726.0f, 1089.0f, 1456.0f, 728.0f, 1092.0f, 1460.0f, 730.0f, 1095.0f, 1464.0f, 732.0f, 1098.0f, 1468.0f, 734.0f, 1101.0f, 1472.0f, 736.0f, 1104.0f, 1476.0f, 738.0f, 1107.0f, 1480.0f, 740.0f, 1110.0f, 1484.0f, 742.0f, 1113.0f, 1488.0f, 744.0f, 1116.0f, 1492.0f, 746.0f, 1119.0f, 1496.0f, 748.0f, 1122.0f, 1500.0f, 750.0f, 1125.0f, 1504.0f, 752.0f, 1128.0f, 1508.0f, 754.0f, 1131.0f, 1512.0f, 756.0f, 1134.0f, 1516.0f, 758.0f, 1137.0f, 1520.0f, 760.0f, 1140.0f, 1524.0f, 762.0f, 1143.0f, 1528.0f, 764.0f, 1146.0f, 1532.0f, 766.0f, 1149.0f, 1536.0f, 768.0f, 1152.0f, 1540.0f, 770.0f, 1155.0f, 1544.0f, 772.0f, 1158.0f, 1548.0f, 774.0f, 1161.0f, 1552.0f, 776.0f, 1164.0f, 1556.0f, 778.0f, 1167.0f, 1560.0f, 780.0f, 1170.0f, 1564.0f, 782.0f, 1173.0f, 1568.0f, 784.0f, 1176.0f, 1572.0f, 786.0f, 1179.0f, 1576.0f, 788.0f, 1182.0f, 1580.0f, 790.0f, 1185.0f, 1584.0f, 792.0f, 1188.0f, 1588.0f, 794.0f, 1191.0f, 1592.0f, 796.0f, 1194.0f, 1596.0f, 798.0f, 1197.0f, 1600.0f, 800.0f, 1200.0f, 1604.0f, 802.0f, 1203.0f, 1608.0f, 804.0f, 1206.0f, 1612.0f, 806.0f, 1209.0f, 1616.0f, 808.0f, 1212.0f, 1620.0f, 810.0f, 1215.0f, 1624.0f, 812.0f, 1218.0f, 1628.0f, 814.0f, 1221.0f, 1632.0f, 816.0f, 1224.0f, 1636.0f, 818.0f, 1227.0f, 1640.0f, 820.0f, 1230.0f, 1644.0f, 822.0f, 1233.0f, 1648.0f, 824.0f, 1236.0f, 1652.0f, 826.0f, 1239.0f, 1656.0f, 828.0f, 1242.0f, 1660.0f, 830.0f, 1245.0f, 1664.0f, 832.0f, 1248.0f, 1668.0f, 834.0f, 1251.0f, 1672.0f, 836.0f, 1254.0f, 1676.0f, 838.0f, 1257.0f, 1680.0f, 840.0f, 1260.0f, 1684.0f, 842.0f, 1263.0f, 1688.0f, 844.0f, 1266.0f, 1692.0f, 846.0f, 1269.0f, 1696.0f, 848.0f, 1272.0f, 1700.0f, 850.0f, 1275.0f, 1704.0f, 852.0f, 1278.0f, 1708.0f, 854.0f, 1281.0f, 1712.0f, 856.0f, 1284.0f, 1716.0f, 858.0f, 1287.0f, 1720.0f, 860.0f, 1290.0f, 1724.0f, 862.0f, 1293.0f, 1728.0f, 864.0f, 1296.0f, 1732.0f, 866.0f, 1299.0f, 1736.0f, 868.0f, 1302.0f, 1740.0f, 870.0f, 1305.0f, 1744.0f, 872.0f, 1308.0f, 1748.0f, 874.0f, 1311.0f, 1752.0f, 876.0f, 1314.0f, 1756.0f, 878.0f, 1317.0f, 1760.0f, 880.0f, 1320.0f, 1764.0f, 882.0f, 1323.0f, 1768.0f, 884.0f, 1326.0f, 1772.0f, 886.0f, 1329.0f, 1776.0f, 888.0f, 1332.0f, 1780.0f, 890.0f, 1335.0f, 1784.0f, 892.0f, 1338.0f, 1788.0f, 894.0f, 1341.0f, 1792.0f, 896.0f, 1344.0f, 1796.0f, 898.0f, 1347.0f, 1800.0f, 900.0f, 1350.0f, 1804.0f, 902.0f, 1353.0f, 1808.0f, 904.0f, 1356.0f, 1812.0f, 906.0f, 1359.0f, 1816.0f, 908.0f, 1362.0f, 1820.0f, 910.0f, 1365.0f, 1824.0f, 912.0f, 1368.0f, 1828.0f, 914.0f, 1371.0f, 1832.0f, 916.0f, 1374.0f, 1836.0f, 918.0f, 1377.0f, 1840.0f, 920.0f, 1380.0f, 1844.0f, 922.0f, 1383.0f, 1848.0f, 924.0f, 1386.0f, 1852.0f, 926.0f, 1389.0f, 1856.0f, 928.0f, 1392.0f, 1860.0f, 930.0f, 1395.0f, 1864.0f, 932.0f, 1398.0f, 1868.0f, 934.0f, 1401.0f, 1872.0f, 936.0f, 1404.0f, 1876.0f, 938.0f, 1407.0f, 1880.0f, 940.0f, 1410.0f, 1884.0f, 942.0f, 1413.0f, 1888.0f, 944.0f, 1416.0f, 1892.0f, 946.0f, 1419.0f, 1896.0f, 948.0f, 1422.0f, 1900.0f, 950.0f, 1425.0f, 1904.0f, 952.0f, 1428.0f, 1908.0f, 954.0f, 1431.0f, 1912.0f, 956.0f, 1434.0f, 1916.0f, 958.0f, 1437.0f, 1920.0f, 960.0f, 1440.0f, 1924.0f, 962.0f, 1443.0f, 1928.0f, 964.0f, 1446.0f, 1932.0f, 966.0f, 1449.0f, 1936.0f, 968.0f, 1452.0f, 1940.0f, 970.0f, 1455.0f, 1944.0f, 972.0f, 1458.0f, 1948.0f, 974.0f, 1461.0f, 
1952.0f, 976.0f, 1464.0f, 1956.0f, 978.0f, 1467.0f, 1960.0f, 980.0f, 1470.0f, 1964.0f, 982.0f, 1473.0f, 1968.0f, 984.0f, 1476.0f, 1972.0f, 986.0f, 1479.0f, 1976.0f, 988.0f, 1482.0f, 1980.0f, 990.0f, 1485.0f, 1984.0f, 992.0f, 1488.0f, 1988.0f, 994.0f, 1491.0f, 1992.0f, 996.0f, 1494.0f, 1996.0f, 998.0f, 1497.0f, 2000.0f, 1000.0f, 1500.0f}; + float result[1500] = {0.f}; int dimension[1] = {0}; std::vector dim = {0}; int dimensionLength = 1; - float theoreticalMin[3] = {4,2,3}; - float theoreticalMax[3] = {2000.00, 1000.00, 1500.00}; - float theoreticalRange[3] = {1996.00, 998.00, 1497.00}; + float theoreticalMin[3] = {4.f, 2.f, 3.f}; + float theoreticalMax[3] = {2000.00f, 1000.00f, 1500.00f}; + float theoreticalRange[3] = {1996.00f, 998.00f, 1497.00f}; }; class StdTest : public testing::Test { @@ -44,19 +44,19 @@ public: int dimensionLength = 3; //standard deviation int opNum = 1; - float x[7500] ={0.5786382,0.16236664,0.069020785,0.9840061,0.941816,0.76720303,0.7794372,0.46979624,0.73381734,0.9957244,0.6167372,0.53088397,0.28015637,0.826945,0.83352476,0.66504276,0.5793391,0.47484478,0.7076381,0.49456358,0.62396896,0.53332835,0.6388812,0.68836075,0.26663998,0.0014623206,0.19409843,0.56639415,0.98213744,0.68497056,0.867037,0.76840234,0.318186,0.28759065,0.11965875,0.53291357,0.53767395,0.55705845,0.7467155,0.1575149,0.18076386,0.8174763,0.22883898,0.5071535,0.86735153,0.9635827,0.24558435,0.15767147,0.458882,0.71102697,0.21914826,0.16241662,0.27248728,0.89015275,0.71070856,0.55088985,0.98992974,0.70927286,0.9261268,0.50781846,0.62151235,0.4590896,0.7487442,0.21744072,0.2636398,0.084352165,0.46951914,0.383644,0.6749645,0.24111961,0.83259743,0.05546627,0.4790621,0.68884027,0.90992177,0.23907907,0.5342047,0.221003,0.29615387,0.43343517,0.16554528,0.73144174,0.52923626,0.10688303,0.78197056,0.39259177,0.43832788,0.052234255,0.5795483,0.97033966,0.7392455,0.086584255,0.9092887,0.9402065,0.9126419,0.44749174,0.20514569,0.8749829,0.30917913,0.10170506,0.37034252,0.7427814,0.5497875,0.3116048,0.12112484,0.07918618,0.6003074,0.6188079,0.6292188,0.26580265,0.42029652,0.9863358,0.41489154,0.23757206,0.30395788,0.75231904,0.76751274,0.6324773,0.3231405,0.5016677,0.86029065,0.575702,0.7473972,0.118974194,0.115586124,0.62481487,0.91101325,0.6137756,0.71462154,0.995567,0.93439484,0.37260458,0.6033152,0.3444346,0.91579247,0.7452442,0.97466874,0.6299154,0.35426098,0.50121397,0.14155711,0.78726757,0.028531995,0.8435531,0.6444501,0.8826095,0.25354537,0.5547923,0.99555415,0.8430975,246.29712,253.4231,282.26755,215.6161,251.57019,239.20515,296.2021,234.32518,278.9852,235.4248,238.70155,256.9956,212.62695,288.38763,231.21237,284.80396,261.86835,223.92522,205.86221,234.742,262.11407,298.1942,242.60652,238.83704,251.6588,267.23315,294.4865,223.47488,259.24976,251.82695,265.01166,234.65732,265.1853,202.15352,244.42313,253.90427,212.09233,227.62961,237.77951,261.36838,234.32147,240.81522,273.62595,221.19333,284.11353,216.00859,284.36948,243.90376,282.61584,256.97165,275.08722,253.8055,265.1405,298.87567,223.393,288.02148,287.26102,276.36237,290.52777,299.57062,224.73566,290.82623,231.3513,238.51828,230.74028,224.97539,290.11844,238.00816,290.39606,291.32538,272.94766,211.88446,291.66742,210.34077,285.62628,246.31918,283.68738,282.34418,223.43613,245.08679,235.22693,246.01146,224.03375,280.5359,226.01413,262.18884,237.87335,238.89404,259.04294,202.59842,294.69302,209.01956,244.75763,264.3232,293.4627,287.69165,236.79088,282.37012,222.24211,293.5885,249.6388,273.91916,215.40356,255.45584,268.4702,275.81577,259.25064,224.9
5108,250.37906,267.89093,256.31766,227.89124,204.10915,263.38596,213.62708,218.84116,289.00494,216.93646,200.29439,284.1103,216.20671,260.57642,248.57745,241.73776,244.7205,286.86218,206.42664,204.06395,216.60626,224.02377,219.4697,287.2509,246.91132,289.83777,292.73767,202.73048,206.4165,294.0605,276.23276,288.51318,279.45175,253.69833,281.3311,249.44318,287.76288,262.2878,238.2247,203.41438,208.8359,274.0062,-9.999092,-9.99934,-9.999794,-9.999654,-9.999987,-9.999574,-9.99965,-9.999892,-9.999203,-9.999798,-9.999658,-9.999974,-9.999982,-9.999003,-9.999369,-9.999311,-9.999708,-9.999327,-9.999302,-9.999419,-9.999553,-9.9991665,-9.999842,-9.9991665,-9.999702,-9.999081,-9.9993725,-9.999735,-9.999399,-9.999073,-9.999045,-9.999458,-9.99971,-9.999414,-9.999165,-9.999782,-9.999417,-9.999513,-9.999398,-9.999933,-9.999367,-9.999933,-9.999302,-9.999572,-9.999926,-9.999371,-9.999746,-9.999628,-9.9995165,-9.999816,-9.9998255,-9.999983,-9.999482,-9.99976,-9.999302,-9.999825,-9.999026,-9.999029,-9.999147,-9.9995,-9.999214,-9.999216,-9.999818,-9.999334,-9.999354,-9.999414,-9.999564,-9.99962,-9.999615,-9.999496,-9.999803,-9.999454,-9.999789,-9.999615,-9.999473,-9.999701,-9.999164,-9.999112,-9.9991865,-9.999779,-9.999639,-9.999739,-9.999949,-9.999005,-9.999157,-9.999394,-9.999148,-9.999729,-9.999721,-9.999721,-9.999678,-9.999215,-9.99921,-9.999848,-9.999702,-9.999167,-9.999995,-9.999203,-9.999381,-9.999537,-9.999643,-9.999887,-9.999234,-9.999761,-9.999863,-9.9999275,-9.99965,-9.999459,-9.999674,-9.999408,-9.999761,-9.999802,-9.999465,-9.999648,-9.999447,-9.999051,-9.999212,-9.999952,-9.999188,-9.999153,-9.999513,-9.999785,-9.999538,-9.999458,-9.999802,-9.999176,-9.999821,-9.999529,-9.999089,-9.999206,-9.999853,-9.999218,-9.999763,-9.999283,-9.999687,-9.999333,-9.9996195,-9.999563,-9.99978,-9.999214,-9.999417,-9.999161,-9.999615,-9.999529,-9.999715,-9.99965,-9.999793,-9.999159,-9.999804,-9.999826,0.25581473,0.011998488,0.19125576,0.26596868,0.21618238,0.7962773,0.8030581,0.7543603,0.37575766,0.764879,0.10974313,0.06437898,0.26072952,0.30300763,0.029973997,0.025493756,0.21206349,0.7668091,0.53181326,0.36343664,0.5012292,0.17466855,0.188394,0.73864985,0.4810524,0.42596745,0.17328279,0.2649388,0.5691122,0.6979966,0.40108117,0.680846,0.8891427,0.36562127,0.5258834,0.02162829,0.34679192,0.51932955,0.5934363,0.8976068,0.17759448,0.84487504,0.08563967,0.8079017,0.53375924,0.5292685,0.7386051,0.84675163,0.52025354,0.402771,0.25339442,0.020660425,0.8532977,0.26857603,0.08696012,0.30953142,0.05712433,0.52134746,0.668039,0.8811842,0.84066904,0.5784957,0.13710192,0.25812075,0.12778813,0.6114538,0.68826395,0.6296169,0.050615292,0.60265064,0.59383374,0.50250226,0.5533876,0.80024,0.15964289,0.44098398,0.3639451,0.9836441,0.59009975,0.42786047,0.66358715,0.77674544,0.96205765,0.30722687,0.07275952,0.8073388,0.8589582,0.1655514,0.942791,0.7421209,0.33589354,0.031047517,0.2333922,0.32696965,0.06680667,0.43655157,0.60084665,0.924222,0.5181169,0.8633322,0.07042168,0.3576994,0.23789743,0.98523647,0.35718223,0.09434685,0.7895948,0.6365413,0.7331945,0.8172492,0.2427676,0.23792028,0.7375947,0.72343403,0.47277793,0.53527576,0.30485073,0.64892334,0.15171374,0.8003455,0.9694175,0.3611101,0.8037058,0.7925937,0.18575527,0.81588566,0.094868064,0.9775748,0.6791609,0.26662946,0.18830737,0.595805,0.49300948,0.9033739,0.663468,0.3000145,0.57594025,0.8624458,0.18944798,0.65868706,0.35742447,0.099066,0.2832066,0.6912541,0.24243657,0.9277832,0.64250916,0.9440414,0.2378183,0.055244252,0.76272976,0.67200613,0.49664533,0.5904184,0.17577513,0.782279
2,0.61906105,0.6896018,0.873862,0.9968526,0.4556378,0.87811166,0.86004007,0.41853464,0.5995596,0.40827745,0.28851208,0.5202819,0.19265123,0.92939705,0.70689267,0.11201124,0.98409003,0.18970507,0.7182739,0.5939693,0.05994234,0.021280153,0.14513102,0.40208468,0.22757782,0.23340172,0.3629895,0.13855931,0.78980845,0.8154337,0.9686873,0.03149764,0.027852392,0.7822175,0.3670333,0.78024536,0.44308364,0.7551719,0.7001006,0.99656695,0.7096177,0.6460425,0.3090078,0.3817309,0.75382084,0.24751845,0.9919141,0.8101396,0.72690064,0.58389014,0.13931125,0.4260997,0.19920675,0.29389992,0.22849065,0.054567583,0.0286403,0.68753535,0.6393382,0.83747303,0.43944475,0.16854768,0.659512,0.25002992,0.015794016,0.9449101,0.7541057,0.945847,0.97127223,0.59012526,0.04557803,0.114047214,0.7673727,0.4418709,0.1393514,0.41973236,0.5081946,0.282509,0.30676988,0.2546641,0.6687642,0.31170198,0.43019253,0.81878066,0.9186455,0.787344,0.119964,0.48843786,0.26080957,0.43372,0.7264191,0.7316731,0.52168936,0.3228819,0.5850103,0.58188486,0.5764724,0.85721606,0.0048306463,0.9518531,0.51219267,0.9845728,0.72086376,0.21577734,0.14109355,0.16697218,0.70463514,0.54204077,0.5187638,0.08548192,0.021048365,0.8778848,0.19857538,0.04883652,0.7117264,0.10805124,0.49904156,0.22152025,0.6800811,0.17553183,0.637131,0.4801609,0.5453409,0.25295126,0.48752138,0.5394039,0.7378793,0.89846796,0.30146414,0.21664028,0.27394173,0.022367671,0.9892407,0.19886415,0.41262844,0.30491787,0.49006933,0.81182134,0.673692,0.2412966,0.17482981,0.5432391,0.8450185,0.69215244,0.70803803,0.04421597,0.29316452,0.21701345,0.111889146,0.85679144,0.92165715,0.093697235,0.3446256,0.46299627,0.4249108,0.7948484,0.19556557,0.7571282,0.01646797,0.8894279,0.19658394,0.26087877,0.70531607,0.6966002,0.5969214,0.5227917,0.36881882,0.9858828,0.23796275,0.4213183,0.48533306,0.44627303,0.15690878,0.6434008,0.41254497,0.99109685,0.20189007,0.5941583,0.18635221,0.6158875,0.42995065,0.027945405,0.8306056,0.3877798,0.982836,0.49713424,0.91654354,0.6155134,0.814247,0.3077533,0.22847779,0.88966215,0.8747604,0.41640446,0.9716281,0.18517044,0.033389226,0.026901966,0.41404715,0.7838385,0.9055906,0.63307714,0.6555554,0.61210406,0.8100642,0.7994826,0.50656956,0.7002863,0.122354865,0.73366094,0.92528874,0.50401425,0.3586611,0.3649591,0.8697877,0.09153776,0.56987906,0.4228477,0.72918344,0.21651368,0.273237,0.1320687,0.256684,0.3676141,0.1802598,0.8279442,0.5993243,0.99537796,0.70956576,0.6580005,0.9079618,0.06857852,0.33703786,0.42991522,0.46704793,0.30789334,0.97041386,0.067041285,0.48089835,0.23312177,0.09135661,0.6173484,0.47475886,0.9562112,0.99144304,0.50248766,0.5567772,0.6791836,0.5094131,0.5138229,0.9128905,0.5559054,0.28739175,0.5442868,0.1325101,0.039360367,0.9252663,0.30213857,0.5769297,0.24732989,0.7464911,0.16295283,0.22247133,0.6684257,0.30283514,0.31917402,0.2872067,0.41503724,0.81451225,0.03269196,0.820269,0.5588804,0.26527935,0.6293965,0.40942776,0.6733743,0.5519464,0.7554137,0.28561452,0.19815777,0.14119685,0.8302559,0.47257373,0.45373413,0.26654762,0.51656854,0.16259237,0.8570836,0.6660475,0.9988463,0.2234983,0.29011694,0.19929285,0.87688833,288.208,299.0334,234.06802,288.59332,285.71396,208.14828,243.33327,263.37518,222.83241,267.64508,236.68651,240.05948,241.17122,227.03455,229.1796,231.68953,267.16785,205.02823,264.77625,237.24646,249.54239,232.01376,208.56255,210.85419,239.4313,285.38928,207.99615,219.70026,286.46414,259.6215,264.591,240.25525,212.3435,223.9664,258.98178,278.75095,267.05542,200.13255,271.41925,235.1554,277.16098,235.27489,218.60641,299.13928,237.70187,218.
95384,233.26817,239.93466,210.01537,237.0251,236.5253,272.3498,248.93144,249.78705,202.80908,296.07632,248.54794,228.7884,238.64236,214.01402,231.23134,243.41833,254.53098,229.02164,210.59755,268.93982,277.32697,297.97763,259.46844,229.38896,288.10034,251.99005,273.70062,277.30673,212.11809,205.43094,270.62506,244.42522,280.7068,252.17372,221.36655,231.1006,224.59811,239.97418,257.73175,290.97693,205.1341,217.40971,275.88208,201.61108,280.00003,289.00586,267.0944,231.31201,211.03806,213.06203,269.1713,265.57556,248.42055,209.8977,286.6746,221.91562,215.06145,229.53949,269.93027,276.57254,250.9029,288.37958,228.52266,267.0228,297.99734,214.70332,253.89653,231.25943,204.15068,276.6967,213.42561,222.77573,246.64607,206.99153,251.96185,275.08154,218.24387,211.39914,266.65384,298.70865,287.00455,227.15556,247.37427,213.96188,272.59308,224.01898,235.20276,253.20197,209.47455,210.07729,261.2526,239.28952,219.84111,211.5859,263.7782,225.82002,209.55066,225.2778,276.13922,208.97437,274.6557,297.25998,287.32483,205.43816,-9.999689,-9.999144,-9.999799,-9.999373,-9.999519,-9.9993925,-9.999233,-9.999142,-9.99984,-9.999262,-9.999546,-9.999872,-9.999391,-9.999968,-9.999606,-9.999656,-9.999715,-9.99956,-9.999932,-9.999743,-9.999814,-9.999712,-9.999522,-9.999528,-9.999384,-9.999094,-9.999038,-9.999751,-9.999586,-9.99945,-9.999128,-9.999073,-9.999791,-9.999677,-9.9991865,-9.99909,-9.999762,-9.999218,-9.9995575,-9.999647,-9.999325,-9.999892,-9.999989,-9.999758,-9.999248,-9.999668,-9.999531,-9.999084,-9.999631,-9.999403,-9.999865,-9.999935,-9.9991,-9.999564,-9.99925,-9.9990425,-9.999887,-9.999345,-9.999006,-9.999103,-9.999717,-9.99932,-9.999787,-9.999386,-9.999753,-9.999903,-9.999105,-9.999969,-9.999686,-9.999083,-9.99972,-9.999545,-9.999551,-9.999687,-9.999285,-9.999309,-9.999812,-9.99978,-9.999336,-9.999835,-9.999004,-9.999377,-9.999526,-9.999481,-9.999829,-9.999929,-9.999993,-9.999933,-9.999451,-9.999956,-9.999661,-9.999863,-9.9993305,-9.999771,-9.999426,-9.999976,-9.999994,-9.999831,-9.99988,-9.999162,-9.999056,-9.999193,-9.999941,-9.999949,-9.999971,-9.999258,-9.999011,-9.999707,-9.999535,-9.999201,-9.9995985,-9.999823,-9.999531,-9.999698,-9.999328,-9.999958,-9.999032,-9.999576,-9.999392,-9.999067,-9.99902,-9.999045,-9.99983,-9.999011,-9.999783,-9.999335,-9.999907,-9.999681,-9.999122,-9.999256,-9.999235,-9.999991,-9.999099,-9.999523,-9.999284,-9.999148,-9.999722,-9.999268,-9.999101,-9.99915,-9.999277,-9.999724,-9.999198,-9.999702,-9.999371,-9.999346,-9.999348,-9.999846,-9.99938,-9.999386,0.9152095,0.9171647,0.8286799,0.06623944,0.4663288,0.6674705,0.88702863,0.26388377,0.38012853,0.22043897,0.34161663,0.7549241,0.89839345,0.57267684,0.46196744,0.40692735,0.63130325,0.46858534,0.25790846,0.5064126,0.6745789,0.815519,0.3279563,0.06752282,0.32830805,0.9456376,0.99969417,0.33946416,0.09058472,0.80821294,0.4096069,0.04731839,0.1274211,0.26724407,0.0013231506,0.89294916,0.14734322,0.3986316,0.44342554,0.37137577,0.55341625,0.49281976,0.7313272,0.2879761,0.20376818,0.9424636,0.21195652,0.22167233,0.5677064,0.36845347,0.079733446,0.6180234,0.52336746,0.2760374,0.07769606,0.637682,0.085176565,0.16043824,0.6679482,0.8272858,0.6635249,0.28023627,0.9216744,0.5184493,0.33986536,0.83903545,0.6198479,0.7963929,0.63605565,0.41838124,0.26928508,0.05648084,0.6071852,0.3672051,0.54514945,0.46253535,0.595289,0.2197304,0.56575435,0.33570454,0.12949312,0.009017748,0.82104915,0.31175017,0.46786937,0.9008307,0.059177548,0.21651942,0.58483404,0.13534085,0.2563066,0.98585606,0.3444204,0.30529618,0.9550007,0.010194158,0.44460547,0.
[... very long run of hard-wrapped comma-separated floating-point test data elided: repeating groups of pseudo-random values in (0, 1), values in roughly the 200-300 range, and values near -9.999; no diff context markers survive for this span ...]
.47523862}; + float x[7500] ={0.5786382f, 0.16236664f, 0.069020785f, 0.9840061f, 0.941816f, 0.76720303f, 0.7794372f, 0.46979624f, 0.73381734f, 0.9957244f, 0.6167372f, 0.53088397f, 0.28015637f, 0.826945f, 0.83352476f, 0.66504276f, 0.5793391f, 0.47484478f, 0.7076381f, 0.49456358f, 0.62396896f, 0.53332835f, 0.6388812f, 0.68836075f, 0.26663998f, 0.0014623206f, 0.19409843f, 0.56639415f, 0.98213744f, 0.68497056f, 0.867037f, 0.76840234f, 0.318186f, 0.28759065f, 0.11965875f, 0.53291357f, 0.53767395f, 0.55705845f, 0.7467155f, 0.1575149f, 0.18076386f, 0.8174763f, 0.22883898f, 0.5071535f, 0.86735153f, 0.9635827f, 0.24558435f, 0.15767147f, 0.458882f, 0.71102697f, 0.21914826f, 0.16241662f, 0.27248728f, 0.89015275f, 0.71070856f, 0.55088985f, 0.98992974f, 0.70927286f, 0.9261268f, 0.50781846f, 0.62151235f, 0.4590896f, 0.7487442f, 0.21744072f, 0.2636398f, 0.084352165f, 0.46951914f, 0.383644f, 0.6749645f, 0.24111961f, 0.83259743f, 0.05546627f, 0.4790621f, 0.68884027f, 0.90992177f, 0.23907907f, 0.5342047f, 0.221003f, 0.29615387f, 0.43343517f, 0.16554528f, 0.73144174f, 0.52923626f, 0.10688303f, 0.78197056f, 0.39259177f, 0.43832788f, 0.052234255f, 0.5795483f, 0.97033966f, 0.7392455f, 0.086584255f, 0.9092887f, 0.9402065f, 0.9126419f, 0.44749174f, 0.20514569f, 0.8749829f, 0.30917913f, 0.10170506f, 0.37034252f, 0.7427814f, 0.5497875f, 0.3116048f, 0.12112484f, 0.07918618f, 0.6003074f, 0.6188079f, 0.6292188f, 0.26580265f, 0.42029652f, 0.9863358f, 0.41489154f, 0.23757206f, 0.30395788f, 0.75231904f, 0.76751274f, 0.6324773f, 0.3231405f, 0.5016677f, 0.86029065f, 0.575702f, 0.7473972f, 0.118974194f, 0.115586124f, 0.62481487f, 0.91101325f, 0.6137756f, 0.71462154f, 0.995567f, 0.93439484f, 0.37260458f, 0.6033152f, 0.3444346f, 0.91579247f, 0.7452442f, 0.97466874f, 0.6299154f, 0.35426098f, 0.50121397f, 0.14155711f, 0.78726757f, 0.028531995f, 0.8435531f, 0.6444501f, 0.8826095f, 0.25354537f, 0.5547923f, 0.99555415f, 0.8430975f, 246.29712f, 253.4231f, 282.26755f, 215.6161f, 251.57019f, 239.20515f, 296.2021f, 234.32518f, 278.9852f, 235.4248f, 238.70155f, 256.9956f, 212.62695f, 288.38763f, 231.21237f, 284.80396f, 261.86835f, 223.92522f, 205.86221f, 234.742f, 262.11407f, 298.1942f, 242.60652f, 238.83704f, 251.6588f, 267.23315f, 294.4865f, 223.47488f, 259.24976f, 251.82695f, 265.01166f, 234.65732f, 265.1853f, 202.15352f, 244.42313f, 253.90427f, 212.09233f, 227.62961f, 237.77951f, 261.36838f, 234.32147f, 240.81522f, 273.62595f, 221.19333f, 284.11353f, 216.00859f, 284.36948f, 243.90376f, 282.61584f, 256.97165f, 275.08722f, 253.8055f, 265.1405f, 298.87567f, 223.393f, 288.02148f, 287.26102f, 276.36237f, 290.52777f, 299.57062f, 224.73566f, 290.82623f, 231.3513f, 238.51828f, 230.74028f, 224.97539f, 290.11844f, 238.00816f, 290.39606f, 291.32538f, 272.94766f, 211.88446f, 291.66742f, 210.34077f, 285.62628f, 246.31918f, 283.68738f, 282.34418f, 223.43613f, 245.08679f, 235.22693f, 246.01146f, 224.03375f, 280.5359f, 226.01413f, 262.18884f, 237.87335f, 238.89404f, 259.04294f, 202.59842f, 294.69302f, 209.01956f, 244.75763f, 264.3232f, 293.4627f, 287.69165f, 236.79088f, 282.37012f, 222.24211f, 293.5885f, 249.6388f, 273.91916f, 215.40356f, 255.45584f, 268.4702f, 275.81577f, 259.25064f, 224.95108f, 250.37906f, 267.89093f, 256.31766f, 227.89124f, 204.10915f, 263.38596f, 213.62708f, 218.84116f, 289.00494f, 216.93646f, 200.29439f, 284.1103f, 216.20671f, 260.57642f, 248.57745f, 241.73776f, 244.7205f, 286.86218f, 206.42664f, 204.06395f, 216.60626f, 224.02377f, 219.4697f, 287.2509f, 246.91132f, 289.83777f, 292.73767f, 202.73048f, 206.4165f, 294.0605f, 
276.23276f, 288.51318f, 279.45175f, 253.69833f, 281.3311f, 249.44318f, 287.76288f, 262.2878f, 238.2247f, 203.41438f, 208.8359f, 274.0062f, -9.999092f, -9.99934f, -9.999794f, -9.999654f, -9.999987f, -9.999574f, -9.99965f, -9.999892f, -9.999203f, -9.999798f, -9.999658f, -9.999974f, -9.999982f, -9.999003f, -9.999369f, -9.999311f, -9.999708f, -9.999327f, -9.999302f, -9.999419f, -9.999553f, -9.9991665f, -9.999842f, -9.9991665f, -9.999702f, -9.999081f, -9.9993725f, -9.999735f, -9.999399f, -9.999073f, -9.999045f, -9.999458f, -9.99971f, -9.999414f, -9.999165f, -9.999782f, -9.999417f, -9.999513f, -9.999398f, -9.999933f, -9.999367f, -9.999933f, -9.999302f, -9.999572f, -9.999926f, -9.999371f, -9.999746f, -9.999628f, -9.9995165f, -9.999816f, -9.9998255f, -9.999983f, -9.999482f, -9.99976f, -9.999302f, -9.999825f, -9.999026f, -9.999029f, -9.999147f, -9.9995f, -9.999214f, -9.999216f, -9.999818f, -9.999334f, -9.999354f, -9.999414f, -9.999564f, -9.99962f, -9.999615f, -9.999496f, -9.999803f, -9.999454f, -9.999789f, -9.999615f, -9.999473f, -9.999701f, -9.999164f, -9.999112f, -9.9991865f, -9.999779f, -9.999639f, -9.999739f, -9.999949f, -9.999005f, -9.999157f, -9.999394f, -9.999148f, -9.999729f, -9.999721f, -9.999721f, -9.999678f, -9.999215f, -9.99921f, -9.999848f, -9.999702f, -9.999167f, -9.999995f, -9.999203f, -9.999381f, -9.999537f, -9.999643f, -9.999887f, -9.999234f, -9.999761f, -9.999863f, -9.9999275f, -9.99965f, -9.999459f, -9.999674f, -9.999408f, -9.999761f, -9.999802f, -9.999465f, -9.999648f, -9.999447f, -9.999051f, -9.999212f, -9.999952f, -9.999188f, -9.999153f, -9.999513f, -9.999785f, -9.999538f, -9.999458f, -9.999802f, -9.999176f, -9.999821f, -9.999529f, -9.999089f, -9.999206f, -9.999853f, -9.999218f, -9.999763f, -9.999283f, -9.999687f, -9.999333f, -9.9996195f, -9.999563f, -9.99978f, -9.999214f, -9.999417f, -9.999161f, -9.999615f, -9.999529f, -9.999715f, -9.99965f, -9.999793f, -9.999159f, -9.999804f, -9.999826f, 0.25581473f, 0.011998488f, 0.19125576f, 0.26596868f, 0.21618238f, 0.7962773f, 0.8030581f, 0.7543603f, 0.37575766f, 0.764879f, 0.10974313f, 0.06437898f, 0.26072952f, 0.30300763f, 0.029973997f, 0.025493756f, 0.21206349f, 0.7668091f, 0.53181326f, 0.36343664f, 0.5012292f, 0.17466855f, 0.188394f, 0.73864985f, 0.4810524f, 0.42596745f, 0.17328279f, 0.2649388f, 0.5691122f, 0.6979966f, 0.40108117f, 0.680846f, 0.8891427f, 0.36562127f, 0.5258834f, 0.02162829f, 0.34679192f, 0.51932955f, 0.5934363f, 0.8976068f, 0.17759448f, 0.84487504f, 0.08563967f, 0.8079017f, 0.53375924f, 0.5292685f, 0.7386051f, 0.84675163f, 0.52025354f, 0.402771f, 0.25339442f, 0.020660425f, 0.8532977f, 0.26857603f, 0.08696012f, 0.30953142f, 0.05712433f, 0.52134746f, 0.668039f, 0.8811842f, 0.84066904f, 0.5784957f, 0.13710192f, 0.25812075f, 0.12778813f, 0.6114538f, 0.68826395f, 0.6296169f, 0.050615292f, 0.60265064f, 0.59383374f, 0.50250226f, 0.5533876f, 0.80024f, 0.15964289f, 0.44098398f, 0.3639451f, 0.9836441f, 0.59009975f, 0.42786047f, 0.66358715f, 0.77674544f, 0.96205765f, 0.30722687f, 0.07275952f, 0.8073388f, 0.8589582f, 0.1655514f, 0.942791f, 0.7421209f, 0.33589354f, 0.031047517f, 0.2333922f, 0.32696965f, 0.06680667f, 0.43655157f, 0.60084665f, 0.924222f, 0.5181169f, 0.8633322f, 0.07042168f, 0.3576994f, 0.23789743f, 0.98523647f, 0.35718223f, 0.09434685f, 0.7895948f, 0.6365413f, 0.7331945f, 0.8172492f, 0.2427676f, 0.23792028f, 0.7375947f, 0.72343403f, 0.47277793f, 0.53527576f, 0.30485073f, 0.64892334f, 0.15171374f, 0.8003455f, 0.9694175f, 0.3611101f, 0.8037058f, 0.7925937f, 0.18575527f, 0.81588566f, 0.094868064f, 0.9775748f, 
0.6791609f, 0.26662946f, 0.18830737f, 0.595805f, 0.49300948f, 0.9033739f, 0.663468f, 0.3000145f, 0.57594025f, 0.8624458f, 0.18944798f, 0.65868706f, 0.35742447f, 0.099066f, 0.2832066f, 0.6912541f, 0.24243657f, 0.9277832f, 0.64250916f, 0.9440414f, 0.2378183f, 0.055244252f, 0.76272976f, 0.67200613f, 0.49664533f, 0.5904184f, 0.17577513f, 0.7822792f, 0.61906105f, 0.6896018f, 0.873862f, 0.9968526f, 0.4556378f, 0.87811166f, 0.86004007f, 0.41853464f, 0.5995596f, 0.40827745f, 0.28851208f, 0.5202819f, 0.19265123f, 0.92939705f, 0.70689267f, 0.11201124f, 0.98409003f, 0.18970507f, 0.7182739f, 0.5939693f, 0.05994234f, 0.021280153f, 0.14513102f, 0.40208468f, 0.22757782f, 0.23340172f, 0.3629895f, 0.13855931f, 0.78980845f, 0.8154337f, 0.9686873f, 0.03149764f, 0.027852392f, 0.7822175f, 0.3670333f, 0.78024536f, 0.44308364f, 0.7551719f, 0.7001006f, 0.99656695f, 0.7096177f, 0.6460425f, 0.3090078f, 0.3817309f, 0.75382084f, 0.24751845f, 0.9919141f, 0.8101396f, 0.72690064f, 0.58389014f, 0.13931125f, 0.4260997f, 0.19920675f, 0.29389992f, 0.22849065f, 0.054567583f, 0.0286403f, 0.68753535f, 0.6393382f, 0.83747303f, 0.43944475f, 0.16854768f, 0.659512f, 0.25002992f, 0.015794016f, 0.9449101f, 0.7541057f, 0.945847f, 0.97127223f, 0.59012526f, 0.04557803f, 0.114047214f, 0.7673727f, 0.4418709f, 0.1393514f, 0.41973236f, 0.5081946f, 0.282509f, 0.30676988f, 0.2546641f, 0.6687642f, 0.31170198f, 0.43019253f, 0.81878066f, 0.9186455f, 0.787344f, 0.119964f, 0.48843786f, 0.26080957f, 0.43372f, 0.7264191f, 0.7316731f, 0.52168936f, 0.3228819f, 0.5850103f, 0.58188486f, 0.5764724f, 0.85721606f, 0.0048306463f, 0.9518531f, 0.51219267f, 0.9845728f, 0.72086376f, 0.21577734f, 0.14109355f, 0.16697218f, 0.70463514f, 0.54204077f, 0.5187638f, 0.08548192f, 0.021048365f, 0.8778848f, 0.19857538f, 0.04883652f, 0.7117264f, 0.10805124f, 0.49904156f, 0.22152025f, 0.6800811f, 0.17553183f, 0.637131f, 0.4801609f, 0.5453409f, 0.25295126f, 0.48752138f, 0.5394039f, 0.7378793f, 0.89846796f, 0.30146414f, 0.21664028f, 0.27394173f, 0.022367671f, 0.9892407f, 0.19886415f, 0.41262844f, 0.30491787f, 0.49006933f, 0.81182134f, 0.673692f, 0.2412966f, 0.17482981f, 0.5432391f, 0.8450185f, 0.69215244f, 0.70803803f, 0.04421597f, 0.29316452f, 0.21701345f, 0.111889146f, 0.85679144f, 0.92165715f, 0.093697235f, 0.3446256f, 0.46299627f, 0.4249108f, 0.7948484f, 0.19556557f, 0.7571282f, 0.01646797f, 0.8894279f, 0.19658394f, 0.26087877f, 0.70531607f, 0.6966002f, 0.5969214f, 0.5227917f, 0.36881882f, 0.9858828f, 0.23796275f, 0.4213183f, 0.48533306f, 0.44627303f, 0.15690878f, 0.6434008f, 0.41254497f, 0.99109685f, 0.20189007f, 0.5941583f, 0.18635221f, 0.6158875f, 0.42995065f, 0.027945405f, 0.8306056f, 0.3877798f, 0.982836f, 0.49713424f, 0.91654354f, 0.6155134f, 0.814247f, 0.3077533f, 0.22847779f, 0.88966215f, 0.8747604f, 0.41640446f, 0.9716281f, 0.18517044f, 0.033389226f, 0.026901966f, 0.41404715f, 0.7838385f, 0.9055906f, 0.63307714f, 0.6555554f, 0.61210406f, 0.8100642f, 0.7994826f, 0.50656956f, 0.7002863f, 0.122354865f, 0.73366094f, 0.92528874f, 0.50401425f, 0.3586611f, 0.3649591f, 0.8697877f, 0.09153776f, 0.56987906f, 0.4228477f, 0.72918344f, 0.21651368f, 0.273237f, 0.1320687f, 0.256684f, 0.3676141f, 0.1802598f, 0.8279442f, 0.5993243f, 0.99537796f, 0.70956576f, 0.6580005f, 0.9079618f, 0.06857852f, 0.33703786f, 0.42991522f, 0.46704793f, 0.30789334f, 0.97041386f, 0.067041285f, 0.48089835f, 0.23312177f, 0.09135661f, 0.6173484f, 0.47475886f, 0.9562112f, 0.99144304f, 0.50248766f, 0.5567772f, 0.6791836f, 0.5094131f, 0.5138229f, 0.9128905f, 0.5559054f, 0.28739175f, 0.5442868f, 0.1325101f, 
0.039360367f, 0.9252663f, 0.30213857f, 0.5769297f, 0.24732989f, 0.7464911f, 0.16295283f, 0.22247133f, 0.6684257f, 0.30283514f, 0.31917402f, 0.2872067f, 0.41503724f, 0.81451225f, 0.03269196f, 0.820269f, 0.5588804f, 0.26527935f, 0.6293965f, 0.40942776f, 0.6733743f, 0.5519464f, 0.7554137f, 0.28561452f, 0.19815777f, 0.14119685f, 0.8302559f, 0.47257373f, 0.45373413f, 0.26654762f, 0.51656854f, 0.16259237f, 0.8570836f, 0.6660475f, 0.9988463f, 0.2234983f, 0.29011694f, 0.19929285f, 0.87688833f, 288.208f, 299.0334f, 234.06802f, 288.59332f, 285.71396f, 208.14828f, 243.33327f, 263.37518f, 222.83241f, 267.64508f, 236.68651f, 240.05948f, 241.17122f, 227.03455f, 229.1796f, 231.68953f, 267.16785f, 205.02823f, 264.77625f, 237.24646f, 249.54239f, 232.01376f, 208.56255f, 210.85419f, 239.4313f, 285.38928f, 207.99615f, 219.70026f, 286.46414f, 259.6215f, 264.591f, 240.25525f, 212.3435f, 223.9664f, 258.98178f, 278.75095f, 267.05542f, 200.13255f, 271.41925f, 235.1554f, 277.16098f, 235.27489f, 218.60641f, 299.13928f, 237.70187f, 218.95384f, 233.26817f, 239.93466f, 210.01537f, 237.0251f, 236.5253f, 272.3498f, 248.93144f, 249.78705f, 202.80908f, 296.07632f, 248.54794f, 228.7884f, 238.64236f, 214.01402f, 231.23134f, 243.41833f, 254.53098f, 229.02164f, 210.59755f, 268.93982f, 277.32697f, 297.97763f, 259.46844f, 229.38896f, 288.10034f, 251.99005f, 273.70062f, 277.30673f, 212.11809f, 205.43094f, 270.62506f, 244.42522f, 280.7068f, 252.17372f, 221.36655f, 231.1006f, 224.59811f, 239.97418f, 257.73175f, 290.97693f, 205.1341f, 217.40971f, 275.88208f, 201.61108f, 280.00003f, 289.00586f, 267.0944f, 231.31201f, 211.03806f, 213.06203f, 269.1713f, 265.57556f, 248.42055f, 209.8977f, 286.6746f, 221.91562f, 215.06145f, 229.53949f, 269.93027f, 276.57254f, 250.9029f, 288.37958f, 228.52266f, 267.0228f, 297.99734f, 214.70332f, 253.89653f, 231.25943f, 204.15068f, 276.6967f, 213.42561f, 222.77573f, 246.64607f, 206.99153f, 251.96185f, 275.08154f, 218.24387f, 211.39914f, 266.65384f, 298.70865f, 287.00455f, 227.15556f, 247.37427f, 213.96188f, 272.59308f, 224.01898f, 235.20276f, 253.20197f, 209.47455f, 210.07729f, 261.2526f, 239.28952f, 219.84111f, 211.5859f, 263.7782f, 225.82002f, 209.55066f, 225.2778f, 276.13922f, 208.97437f, 274.6557f, 297.25998f, 287.32483f, 205.43816f, -9.999689f, -9.999144f, -9.999799f, -9.999373f, -9.999519f, -9.9993925f, -9.999233f, -9.999142f, -9.99984f, -9.999262f, -9.999546f, -9.999872f, -9.999391f, -9.999968f, -9.999606f, -9.999656f, -9.999715f, -9.99956f, -9.999932f, -9.999743f, -9.999814f, -9.999712f, -9.999522f, -9.999528f, -9.999384f, -9.999094f, -9.999038f, -9.999751f, -9.999586f, -9.99945f, -9.999128f, -9.999073f, -9.999791f, -9.999677f, -9.9991865f, -9.99909f, -9.999762f, -9.999218f, -9.9995575f, -9.999647f, -9.999325f, -9.999892f, -9.999989f, -9.999758f, -9.999248f, -9.999668f, -9.999531f, -9.999084f, -9.999631f, -9.999403f, -9.999865f, -9.999935f, -9.9991f, -9.999564f, -9.99925f, -9.9990425f, -9.999887f, -9.999345f, -9.999006f, -9.999103f, -9.999717f, -9.99932f, -9.999787f, -9.999386f, -9.999753f, -9.999903f, -9.999105f, -9.999969f, -9.999686f, -9.999083f, -9.99972f, -9.999545f, -9.999551f, -9.999687f, -9.999285f, -9.999309f, -9.999812f, -9.99978f, -9.999336f, -9.999835f, -9.999004f, -9.999377f, -9.999526f, -9.999481f, -9.999829f, -9.999929f, -9.999993f, -9.999933f, -9.999451f, -9.999956f, -9.999661f, -9.999863f, -9.9993305f, -9.999771f, -9.999426f, -9.999976f, -9.999994f, -9.999831f, -9.99988f, -9.999162f, -9.999056f, -9.999193f, -9.999941f, -9.999949f, -9.999971f, -9.999258f, -9.999011f, -9.999707f, 
-9.999535f, -9.999201f, -9.9995985f, -9.999823f, -9.999531f, -9.999698f, -9.999328f, -9.999958f, -9.999032f, -9.999576f, -9.999392f, -9.999067f, -9.99902f, -9.999045f, -9.99983f, -9.999011f, -9.999783f, -9.999335f, -9.999907f, -9.999681f, -9.999122f, -9.999256f, -9.999235f, -9.999991f, -9.999099f, -9.999523f, -9.999284f, -9.999148f, -9.999722f, -9.999268f, -9.999101f, -9.99915f, -9.999277f, -9.999724f, -9.999198f, -9.999702f, -9.999371f, -9.999346f, -9.999348f, -9.999846f, -9.99938f, -9.999386f, 0.9152095f, 0.9171647f, 0.8286799f, 0.06623944f, 0.4663288f, 0.6674705f, 0.88702863f, 0.26388377f, 0.38012853f, 0.22043897f, 0.34161663f, 0.7549241f, 0.89839345f, 0.57267684f, 0.46196744f, 0.40692735f, 0.63130325f, 0.46858534f, 0.25790846f, 0.5064126f, 0.6745789f, 0.815519f, 0.3279563f, 0.06752282f, 0.32830805f, 0.9456376f, 0.99969417f, 0.33946416f, 0.09058472f, 0.80821294f, 0.4096069f, 0.04731839f, 0.1274211f, 0.26724407f, 0.0013231506f, 0.89294916f, 0.14734322f, 0.3986316f, 0.44342554f, 0.37137577f, 0.55341625f, 0.49281976f, 0.7313272f, 0.2879761f, 0.20376818f, 0.9424636f, 0.21195652f, 0.22167233f, 0.5677064f, 0.36845347f, 0.079733446f, 0.6180234f, 0.52336746f, 0.2760374f, 0.07769606f, 0.637682f, 0.085176565f, 0.16043824f, 0.6679482f, 0.8272858f, 0.6635249f, 0.28023627f, 0.9216744f, 0.5184493f, 0.33986536f, 0.83903545f, 0.6198479f, 0.7963929f, 0.63605565f, 0.41838124f, 0.26928508f, 0.05648084f, 0.6071852f, 0.3672051f, 0.54514945f, 0.46253535f, 0.595289f, 0.2197304f, 0.56575435f, 0.33570454f, 0.12949312f, 0.009017748f, 0.82104915f, 0.31175017f, 0.46786937f, 0.9008307f, 0.059177548f, 0.21651942f, 0.58483404f, 0.13534085f, 0.2563066f, 0.98585606f, 0.3444204f, 0.30529618f, 0.9550007f, 0.010194158f, 0.44460547f, 0.4293112f, 0.020983648f, 0.83968806f, 0.5455774f, 0.9872851f, 0.27159318f, 0.16667603f, 0.3916389f, 0.10710736f, 0.70841914f, 0.23437801f, 0.78563285f, 0.25137436f, 0.61097264f, 0.41494665f, 0.20036837f, 0.26286733f, 0.5676644f, 0.2662849f, 0.80940986f, 0.7974582f, 0.5003222f, 0.29910246f, 0.1976132f, 0.30444196f, 0.073145f, 0.68550193f, 0.28199244f, 0.7541317f, 0.11088511f, 0.34996328f, 0.7452604f, 0.42252555f, 0.21781512f, 0.96444f, 0.15884762f, 0.99850196f, 0.5329689f, 0.33807343f, 0.2701225f, 0.6472552f, 0.18246143f, 0.32816347f, 0.81063986f, 0.90712345f, 0.69261926f, 0.44346964f, 0.08311381f, 0.019193182f, 0.3513845f, 0.38967726f, 0.68732834f, 0.45974445f, 0.79513454f, 0.92073804f, 0.61770153f, 0.15796295f, 0.34206834f, 0.61403716f, 0.50911576f, 0.09764764f, 0.4105753f, 0.4610053f, 0.23835297f, 0.7583669f, 0.26223376f, 0.76859593f, 0.82576513f, 0.91628957f, 0.95209956f, 0.34038633f, 0.2481594f, 0.5448205f, 0.94344336f, 0.5867557f, 0.44679952f, 0.35732326f, 0.15309544f, 0.83495915f, 0.8223747f, 0.7383799f, 0.2723741f, 0.37363288f, 0.32874116f, 0.5468127f, 0.5836204f, 0.680963f, 0.28229877f, 0.440675f, 0.058448013f, 0.26188472f, 0.8043764f, 0.92689526f, 0.26310128f, 0.6354866f, 0.915084f, 0.45643163f, 0.87117124f, 0.9500249f, 0.1889253f, 0.5461343f, 0.47915125f, 0.43820933f, 0.13977474f, 0.8290898f, 0.30484903f, 0.5062122f, 0.33160135f, 0.62606835f, 0.65262437f, 0.23008808f, 0.4257683f, 0.13102946f, 0.21824555f, 0.8722663f, 0.26695797f, 0.028245918f, 0.77160543f, 0.10392295f, 0.06169725f, 0.9943042f, 0.8000285f, 0.34662995f, 0.3909258f, 0.6586493f, 0.9920871f, 0.80688536f, 0.84350026f, 0.86506003f, 0.9833786f, 0.1113381f, 0.058909472f, 0.36759707f, 0.1351905f, 0.08711318f, 0.17150986f, 0.97114897f, 0.10649935f, 0.917866f, 0.56674695f, 0.99736273f, 0.6040517f, 0.92105764f, 0.38094944f, 
0.48367384f, 0.14886507f, 0.380281f, 0.41597223f, 0.11372275f, 0.9531382f, 0.67997587f, 0.15792394f, 0.3364488f, 0.021841977f, 0.07619969f, 0.7798327f, 0.19889046f, 0.67756367f, 0.50971586f, 0.52456796f, 0.5036354f, 0.7753575f, 0.34809372f, 0.6398678f, 0.4031053f, 0.32557586f, 0.9053469f, 0.8064988f, 0.017155945f, 0.6316684f, 0.45066175f, 0.4873005f, 0.19287354f, 0.57614934f, 0.83062655f, 0.78713834f, 0.68235135f, 0.87318754f, 0.59281385f, 0.064060956f, 0.9382655f, 0.84566283f, 0.5540783f, 0.17840536f, 0.61837703f, 0.60292286f, 0.6568771f, 0.8471286f, 0.17995848f, 0.49391183f, 0.58517873f, 0.5330186f, 0.5795362f, 0.23409952f, 0.5289169f, 0.3746643f, 0.3180484f, 0.5622743f, 0.036257476f, 0.43180978f, 1.3171679E-4f, 0.63862574f, 0.5848303f, 0.94060403f, 0.5878032f, 0.6252845f, 0.18924952f, 0.39612424f, 0.7757128f, 0.9900665f, 0.86055374f, 0.18927997f, 0.84641314f, 0.8975901f, 0.89157784f, 0.57380813f, 0.94526875f, 0.501755f, 0.42647004f, 0.20386614f, 0.4966745f, 0.7561392f, 0.24496855f, 0.13073194f, 0.41784236f, 0.70873123f, 0.7233561f, 0.96866304f, 0.13634546f, 0.049341034f, 0.71949446f, 0.26208475f, 0.5635493f, 0.27563098f, 0.69374204f, 0.078678265f, 0.03588799f, 0.39408693f, 0.7788656f, 0.94594073f, 0.92669946f, 0.41283527f, 0.62035376f, 0.281576f, 0.89905745f, 0.9558993f, 0.0892733f, 0.43785354f, 0.37643972f, 0.23148632f, 0.17041226f, 0.35524517f, 0.88507247f, 0.3892006f, 0.387216f, 0.15375885f, 0.21120822f, 0.24968858f, 0.44297022f, 0.2895735f, 0.15732966f, 0.07728944f, 0.71204036f, 0.6714093f, 0.053016555f, 0.75036585f, 0.23313028f, 0.56734544f, 0.7048986f, 0.8168968f, 0.06141414f, 0.35583347f, 0.07237186f, 0.12143032f, 0.83158904f, 0.6737841f, 0.53340894f, 0.13451897f, 0.24459034f, 0.96684134f, 0.30125558f, 0.39460337f, 0.07498105f, 0.6020688f, 0.11102765f, 0.3656724f, 0.4939227f, 0.21076858f, 0.13569292f, 0.6039172f, 0.08439329f, 0.30890274f, 0.22699659f, 0.64184964f, 0.2754223f, 0.7049345f, 0.63606584f, 0.9549267f, 0.80815446f, 0.17538197f, 0.05759198f, 0.43693244f, 0.26000643f, 0.6929544f, 0.7537442f, 0.61757445f, 0.19318241f, 0.034338124f, 0.8184448f, 0.92103f, 0.97425944f, 0.8894058f, 0.4300163f, 0.88676697f, 0.3483852f, 0.13178374f, 0.95866996f, 0.6248255f, 0.93648285f, 0.08839288f, 0.14454809f, 0.035382055f, 0.3209607f, 0.16345672f, 0.12934527f, 0.3662055f, 0.25347614f, 0.22039147f, 0.07854195f, 0.7695641f, 0.45950922f, 0.093585685f, 0.35322717f, 0.5360373f, 0.6071155f, 0.9050337f, 0.8356653f, 0.55022f, 0.8330065f, 0.92175573f, 0.93212676f, 0.79578835f, 0.44477537f, 0.14613354f, 0.6763672f, 0.27782786f, 0.9030046f, 0.8203768f, 0.6832867f, 0.24530792f, 0.7274624f, 0.3142183f, 0.022943567f, 238.253f, 220.45427f, 267.66333f, 238.0088f, 271.58243f, 273.22388f, 211.78992f, 289.42252f, 217.21829f, 208.85757f, 217.32358f, 207.44218f, 259.48422f, 208.71153f, 268.2896f, 297.33484f, 254.15167f, 232.80293f, 254.54332f, 232.60858f, 238.36755f, 270.21686f, 279.47226f, 282.7281f, 212.87875f, 212.81602f, 277.39685f, 293.25415f, 220.63031f, 259.65414f, 257.0341f, 286.7428f, 202.3495f, 251.0628f, 268.4925f, 237.58267f, 214.1937f, 219.69623f, 294.32617f, 293.98544f, 271.97043f, 277.1976f, 208.15645f, 285.3982f, 275.2406f, 253.17255f, 280.30792f, 210.3171f, 262.86252f, 211.56f, 201.4514f, 237.41928f, 204.32811f, 291.4109f, 246.54733f, 278.7369f, 226.24847f, 262.70038f, 207.41508f, 274.15656f, 250.72443f, 259.09497f, 278.62515f, 298.87927f, 271.1042f, 265.95636f, 228.53195f, 264.95953f, 231.45522f, 238.10721f, 201.05338f, 299.04672f, 203.31392f, 280.5685f, 207.49594f, 288.41803f, 259.77884f, 
289.5286f, 212.903f, 232.62526f, 273.2359f, 274.92944f, 228.19473f, 292.2021f, 244.35541f, 235.74893f, 281.4144f, 255.78027f, 261.2293f, 219.03902f, 240.27055f, 210.33026f, 250.7247f, 281.74927f, 296.55548f, 224.49033f, 224.96393f, 219.88365f, 294.07227f, 223.65594f, 273.98865f, 279.8825f, 262.97278f, 269.57916f, 284.82678f, 205.99402f, 230.71436f, 245.10574f, 291.90387f, 221.07706f, 285.6493f, 236.25264f, 225.34695f, 210.36287f, 288.40872f, 299.56335f, 259.16122f, 220.4013f, 235.9941f, 213.55952f, 286.5168f, 261.12793f, 230.74602f, 268.31143f, 226.09164f, 217.6272f, 203.38873f, 240.80707f, 255.07602f, 283.92712f, 218.6427f, 278.5974f, 272.98724f, 211.10165f, 230.14198f, 217.64426f, 228.90018f, 266.22888f, 227.51234f, 218.84616f, 247.46571f, 259.92053f, 212.12146f, 248.02554f, 236.08237f, 277.90137f, 263.06485f, 207.07365f, 275.89902f, 264.8849f, -9.9997225f, -9.9999695f, -9.999966f, -9.9999895f, -9.999834f, -9.999596f, -9.999333f, -9.999578f, -9.99955f, -9.999539f, -9.99926f, -9.999182f, -9.999128f, -9.999777f, -9.999337f, -9.999904f, -9.999079f, -9.99941f, -9.999122f, -9.999788f, -9.999136f, -9.9995165f, -9.999043f, -9.999407f, -9.999571f, -9.999437f, -9.999941f, -9.999134f, -9.999198f, -9.999579f, -9.999475f, -9.999036f, -9.999713f, -9.999731f, -9.999678f, -9.999174f, -9.999507f, -9.999201f, -9.999245f, -9.999307f, -9.999488f, -9.999016f, -9.999532f, -9.999287f, -9.999413f, -9.999584f, -9.99978f, -9.999425f, -9.999651f, -9.999136f, -9.999289f, -9.999958f, -9.9991665f, -9.99916f, -9.999886f, -9.999217f, -9.99971f, -9.999494f, -9.999177f, -9.999025f, -9.999024f, -9.999849f, -9.999718f, -9.99997f, -9.999352f, -9.999563f, -9.999284f, -9.999314f, -9.999419f, -9.999329f, -9.99949f, -9.9992075f, -9.999859f, -9.999224f, -9.999656f, -9.999043f, -9.99958f, -9.999525f, -9.999985f, -9.999004f, -9.999768f, -9.999181f, -9.999919f, -9.999416f, -9.999452f, -9.999608f, -9.999645f, -9.999955f, -9.999919f, -9.999946f, -9.999472f, -9.999145f, -9.999147f, -9.99935f, -9.999072f, -9.999628f, -9.999188f, -9.999702f, -9.999313f, -9.999205f, -9.999878f, -9.999991f, -9.999111f, -9.9991f, -9.999404f, -9.999437f, -9.999719f, -9.999646f, -9.999839f, -9.999222f, -9.999134f, -9.999098f, -9.999538f, -9.999294f, -9.999013f, -9.999872f, -9.99908f, -9.999922f, -9.999595f, -9.999158f, -9.999308f, -9.9995f, -9.99924f, -9.999744f, -9.999338f, -9.999049f, -9.999883f, -9.999513f, -9.999893f, -9.999218f, -9.999468f, -9.999204f, -9.999081f, -9.9994335f, -9.999555f, -9.999373f, -9.999073f, -9.999382f, -9.999415f, -9.999362f, -9.999137f, -9.999514f, -9.999781f, -9.999969f, -9.999229f, -9.999295f, -9.999149f, -9.999783f, -9.999437f, -9.999201f, 0.8368316f, 0.95952296f, 0.7187136f, 0.6472035f, 0.7200239f, 0.82257813f, 0.13384113f, 0.91812044f, 0.9440362f, 0.23334092f, 0.3562596f, 0.20390894f, 0.47781035f, 0.56394255f, 0.8770303f, 0.84794813f, 0.92716575f, 0.3591966f, 0.006163279f, 0.34427875f, 0.30020186f, 0.035439115f, 0.36127335f, 0.1666844f, 0.65421695f, 0.752802f, 0.8639191f, 0.7162624f, 0.10528788f, 0.3911885f, 0.6361361f, 0.33739233f, 0.45225555f, 0.04712947f, 0.9509385f, 0.08811871f, 0.6489793f, 0.563957f, 0.8571504f, 0.47839713f, 0.86719155f, 0.7297759f, 0.9265764f, 0.86381954f, 0.2705895f, 0.80873495f, 0.69725907f, 0.4615118f, 0.98845094f, 0.38829336f, 0.5021872f, 0.051559158f, 0.4416545f, 0.84030825f, 0.028471855f, 0.8019141f, 0.4764789f, 0.73308647f, 0.24829985f, 0.28266567f, 0.1642818f, 0.497284f, 0.9761126f, 0.8595787f, 0.61120987f, 0.48310366f, 0.45415315f, 0.4246855f, 0.35486698f, 0.4365935f, 0.6768876f, 0.36493155f, 
0.96304077f, 0.49552417f, 0.8761381f, 0.7559321f, 0.46201146f, 0.50861555f, 0.023068247f, 0.551351f, 0.45992744f, 0.069025f, 0.9549169f, 0.9121757f, 0.35455093f, 0.32405618f, 0.6669353f, 0.16085483f, 0.9973096f, 0.81469834f, 0.47871014f, 0.009814576f, 0.9915644f, 0.4212253f, 0.18318938f, 0.5728494f, 0.3666718f, 0.78813976f, 0.48231423f, 0.723981f, 0.7495278f, 0.7334672f, 0.31657055f, 0.29471073f, 0.2991272f, 0.17905454f, 0.25772056f, 0.04573023f, 0.9155821f, 0.9855648f, 0.9641909f, 0.49942952f, 0.32687747f, 0.3305897f, 0.5485675f, 0.6368628f, 0.09610839f, 0.91397697f, 0.99097943f, 0.7983881f, 0.7839146f, 0.13756526f, 0.058954984f, 0.2574425f, 0.7659589f, 0.8970627f, 0.8955351f, 0.24972673f, 0.3770009f, 0.5416225f, 0.42023486f, 0.4635182f, 0.040502504f, 0.20716274f, 0.08657944f, 0.13138548f, 0.8770457f, 0.6316995f, 0.0990857f, 0.732918f, 0.4953378f, 0.30765584f, 0.21265133f, 0.008900259f, 0.42015043f, 0.25701198f, 0.26232395f, 0.59503317f, 0.37619093f, 0.059471674f, 0.96380097f, 0.6594173f, 0.74392956f, 0.80542815f, 0.5856752f, 0.4709212f, 0.07911475f, 0.8975309f, 0.76675755f, 0.026576402f, 0.012588193f, 0.9571294f, 0.14971007f, 0.42658392f, 0.4339528f, 0.40636125f, 0.418213f, 0.19980216f, 0.8942122f, 0.995247f, 0.026640382f, 0.8785028f, 0.48940244f, 0.3919287f, 0.0862845f, 0.5089264f, 0.17742826f, 0.10345855f, 0.5513259f, 0.7041969f, 0.78375727f, 0.34573317f, 0.34970793f, 0.61609524f, 0.9967575f, 0.19738163f, 0.4390408f, 0.49108744f, 0.5759808f, 0.39300266f, 0.84470737f, 0.3280776f, 0.41459507f, 0.0031824266f, 0.3248213f, 0.21955715f, 0.8830681f, 0.6528493f, 0.7155801f, 0.18756945f, 0.038407642f, 0.048247315f, 0.06908089f, 0.96183145f, 0.8542427f, 0.45350936f, 0.3367257f, 0.26481515f, 0.06306089f, 0.3728015f, 0.4432045f, 0.7682931f, 0.34411287f, 0.018815735f, 0.60152483f, 0.06271082f, 0.30780053f, 0.15404528f, 0.777356f, 0.9382987f, 0.03425807f, 0.74410313f, 0.050881404f, 0.106018655f, 0.9237955f, 0.40959543f, 0.44272372f, 0.42992854f, 0.40163797f, 0.9774989f, 0.7284286f, 0.96605545f, 0.073630586f, 0.7020174f, 0.9556004f, 0.4899371f, 0.2590087f, 0.7959899f, 0.8613244f, 0.7109668f, 0.68005985f, 0.18156524f, 0.68875915f, 0.89809185f, 0.26884466f, 0.46794668f, 0.78001046f, 0.6469185f, 0.03375709f, 0.83638656f, 0.19561735f, 0.72300714f, 0.4323585f, 0.6666231f, 0.6944045f, 0.5573255f, 0.94807935f, 0.40593168f, 0.16260563f, 0.2516181f, 0.5295202f, 0.8144355f, 0.63592476f, 0.40705463f, 0.41550696f, 0.046603993f, 0.23649848f, 0.72142303f, 0.86540526f, 0.9812862f, 0.12677868f, 0.7740198f, 0.028188271f, 0.05125889f, 0.25654867f, 0.7408246f, 0.9826668f, 0.75396377f, 0.6689209f, 0.8002577f, 0.3877432f, 0.83123654f, 0.5672896f, 0.8960579f, 0.39333224f, 0.14590047f, 0.7893236f, 0.38733613f, 0.77125305f, 0.9827144f, 0.014167471f, 0.49262884f, 0.21413602f, 0.67211145f, 0.27530655f, 0.76538646f, 0.5841506f, 0.9951677f, 0.29803824f, 0.024221342f, 0.6438744f, 0.43844396f, 0.35386777f, 0.39374486f, 0.9667755f, 0.26405483f, 0.29369798f, 6.263968E-5f, 0.40577433f, 0.014699541f, 0.8506516f, 0.82061505f, 0.04640132f, 0.38329712f, 0.23627418f, 0.01457501f, 0.920022f, 0.36586156f, 0.54100925f, 0.4094f, 0.9525085f, 0.7759392f, 0.38271114f, 0.9372709f, 0.4954011f, 0.90372294f, 0.5493134f, 0.79789823f, 0.215295f, 0.18560563f, 0.52747923f, 0.015467339f, 0.25793558f, 0.9574369f, 0.8208537f, 0.21616516f, 0.80089974f, 0.4464337f, 0.37760806f, 0.31725752f, 0.07363392f, 0.5414981f, 0.5969112f, 0.6802155f, 0.08681603f, 0.748899f, 0.8132425f, 0.6588185f, 0.7527277f, 0.22249526f, 0.48485887f, 0.52951264f, 0.9087715f, 
0.0022171019f, 0.3312975f, 0.70355535f, 0.9905531f, 0.18766245f, 0.8428444f, 0.9489218f, 0.75968647f, 0.16918193f, 0.5090402f, 0.57815427f, 0.41849396f, 0.3353734f, 0.5701858f, 0.59971434f, 0.037876863f, 0.30670634f, 0.08724593f, 0.51724964f, 0.44608638f, 0.8887655f, 0.23586161f, 0.54564106f, 0.17055021f, 0.65770286f, 0.36355573f, 0.11598958f, 0.98736215f, 0.39781153f, 0.8273148f, 0.099607535f, 0.9095583f, 0.63183874f, 0.6119373f, 0.023166118f, 0.42524394f, 0.3938052f, 0.78907496f, 0.7087274f, 0.4950751f, 0.27278492f, 0.36101273f, 0.9821936f, 0.7951266f, 0.8089244f, 0.7677898f, 0.506932f, 0.6540132f, 0.45168075f, 0.82436436f, 0.6100174f, 0.50495255f, 0.95378387f, 0.15670867f, 0.3659073f, 0.34792703f, 0.22730303f, 0.41741064f, 0.5464127f, 0.12390941f, 0.38427374f, 0.64032775f, 0.77376515f, 0.8658444f, 0.7240665f, 0.43486324f, 0.12049561f, 0.8539374f, 0.08333132f, 0.97497743f, 0.09330166f, 0.44820398f, 0.6796943f, 0.48456368f, 0.9055214f, 0.26348707f, 0.658894f, 0.0733997f, 0.1792219f, 0.54822993f, 0.08548857f, 0.6243975f, 0.14298357f, 0.034526028f, 0.094718255f, 0.039160337f, 0.24803995f, 0.7548811f, 0.81707966f, 0.55264014f, 0.4717769f, 0.8132233f, 0.08796681f, 0.46675965f, 0.21120757f, 0.84116185f, 0.02198596f, 233.08963f, 284.46478f, 228.92946f, 299.10284f, 252.34494f, 270.3675f, 247.62338f, 259.12375f, 293.7792f, 292.25543f, 287.2373f, 261.2933f, 234.23328f, 242.85649f, 246.06302f, 211.33946f, 262.4088f, 288.57184f, 280.21918f, 205.70305f, 216.75426f, 287.24652f, 233.86952f, 253.43048f, 228.54883f, 297.02246f, 219.41966f, 230.32181f, 211.07607f, 201.58842f, 255.04857f, 276.64703f, 226.55725f, 285.53146f, 230.61176f, 277.40143f, 217.56476f, 214.18044f, 253.52425f, 286.49228f, 280.64703f, 216.87614f, 229.96323f, 272.0548f, 287.85236f, 209.3926f, 271.86664f, 240.23541f, 299.9867f, 214.53423f, 273.7356f, 253.11342f, 205.02061f, 222.24791f, 242.70433f, 245.3724f, 298.40033f, 289.42432f, 282.7867f, 229.05533f, 289.985f, 271.32953f, 206.18881f, 285.04318f, 280.12766f, 215.771f, 233.6232f, 204.17224f, 242.84424f, 286.33337f, 254.11534f, 209.9334f, 243.23608f, 272.5159f, 205.16878f, 276.64346f, 244.62245f, 294.27008f, 290.36227f, 216.88017f, 298.44403f, 298.37915f, 214.64677f, 255.04266f, 280.10626f, 281.35904f, 236.9879f, 257.5684f, 280.48505f, 238.83212f, 253.65378f, 291.90552f, 228.50763f, 205.08888f, 281.95593f, 252.75293f, 290.4546f, 287.56818f, 210.91739f, 256.31198f, 232.79715f, 269.6927f, 235.58183f, 276.23233f, 227.1755f, 276.03674f, 292.6508f, 285.0999f, 287.64133f, 234.23032f, 296.60068f, 277.18442f, 257.54352f, 254.5871f, 298.60168f, 202.64233f, 255.38023f, 248.32083f, 260.9433f, 205.4068f, 247.34087f, 208.5292f, 202.0934f, 216.09306f, 221.08582f, 257.41556f, 247.06735f, 266.92804f, 210.08488f, 249.02866f, 204.24144f, 263.3803f, 222.9913f, 251.80115f, 218.99036f, 290.71286f, 227.41696f, 204.93797f, 231.20157f, 292.14478f, 297.73837f, 280.12753f, 297.94702f, 228.16396f, 256.27838f, 280.33307f, 205.8249f, 279.23096f, 268.9643f, 231.75375f, -9.999341f, -9.999257f, -9.999949f, -9.999035f, -9.999831f, -9.99975f, -9.999811f, -9.999584f, -9.999827f, -9.999112f, -9.999565f, -9.999383f, -9.999329f, -9.999119f, -9.999867f, -9.999806f, -9.999535f, -9.99903f, -9.99938f, -9.9991255f, -9.999031f, -9.999938f, -9.999783f, -9.999634f, -9.999506f, -9.999364f, -9.999014f, -9.999437f, -9.999991f, -9.999617f, -9.999323f, -9.9991f, -9.999098f, -9.999426f, -9.999119f, -9.999553f, -9.9994545f, -9.999403f, -9.99964f, -9.999833f, -9.99963f, -9.999753f, -9.999862f, -9.999563f, -9.999861f, -9.999462f, 
-9.99921f, -9.99975f, -9.999412f, -9.99969f, -9.999759f, -9.999703f, -9.999666f, -9.999825f, -9.999146f, -9.999077f, -9.999142f, -9.999701f, -9.999502f, -9.999564f, -9.9995165f, -9.9997835f, -9.999195f, -9.999329f, -9.999829f, -9.999427f, -9.999484f, -9.999804f, -9.999084f, -9.999392f, -9.999105f, -9.999679f, -9.999752f, -9.999843f, -9.999609f, -9.999379f, -9.99906f, -9.999004f, -9.99919f, -9.9998665f, -9.999223f, -9.999334f, -9.999842f, -9.999544f, -9.999025f, -9.999718f, -9.999823f, -9.999554f, -9.99945f, -9.999082f, -9.999171f, -9.999058f, -9.999519f, -9.9995365f, -9.999272f, -9.999615f, -9.999609f, -9.999498f, -9.999642f, -9.999337f, -9.999279f, -9.999857f, -9.999663f, -9.999423f, -9.9990635f, -9.999101f, -9.9993f, -9.999743f, -9.999616f, -9.999779f, -9.99996f, -9.999366f, -9.999638f, -9.999791f, -9.999472f, -9.999714f, -9.999069f, -9.999222f, -9.999011f, -9.999037f, -9.999066f, -9.99982f, -9.999337f, -9.999344f, -9.9998455f, -9.999567f, -9.999952f, -9.9990635f, -9.9993515f, -9.999747f, -9.999756f, -9.999433f, -9.999954f, -9.999456f, -9.999391f, -9.999602f, -9.999213f, -9.999057f, -9.999885f, -9.999203f, -9.999455f, -9.999208f, -9.999754f, -9.99941f, -9.9997015f, -9.999528f, -9.999968f, -9.999105f, -9.999052f, -9.999117f, 0.07731749f, 0.9572599f, 0.2881733f, 0.34789458f, 0.12208096f, 0.3989875f, 0.23046659f, 0.07561615f, 0.7311842f, 0.24280672f, 0.13743502f, 0.32029906f, 0.26720718f, 0.6435275f, 0.71581525f, 0.25040102f, 0.07968058f, 0.9510946f, 0.16737682f, 0.5338542f, 0.96112233f, 0.12613547f, 0.71407163f, 0.017653665f, 0.5663055f, 0.9523341f, 0.66330385f, 0.43527827f, 0.21753095f, 0.6377421f, 0.0820664f, 0.5563942f, 0.105712675f, 0.06655064f, 0.8044171f, 0.6876928f, 0.97473025f, 0.47098678f, 0.23313597f, 0.46495864f, 0.13682419f, 0.19020991f, 0.6946199f, 0.58204114f, 0.008083445f, 0.21409632f, 0.90480167f, 0.06497669f, 0.3296087f, 0.51603156f, 0.49303642f, 0.3029305f, 0.5821996f, 0.5105462f, 0.51879376f, 0.108761f, 0.13990402f, 0.44722676f, 0.8695498f, 0.014239418f, 0.5745597f, 0.52994305f, 0.8318035f, 0.7634822f, 0.677615f, 0.09214777f, 0.705199f, 0.47799557f, 0.24047466f, 0.3105237f, 0.89669865f, 0.6427869f, 0.59037143f, 0.2127864f, 0.27039096f, 0.09363014f, 0.7930851f, 0.58145946f, 0.058050785f, 0.74635893f, 0.34254172f, 0.942883f, 0.8463423f, 0.49698228f, 0.1885729f, 0.2511439f, 0.87867934f, 0.028224535f, 0.7651291f, 0.49802932f, 0.21640365f, 0.69269353f, 0.25175697f, 0.76805496f, 0.75059545f, 0.05755356f, 0.7005975f, 0.9643457f, 0.59199476f, 0.15058741f, 0.8211338f, 0.50831884f, 0.9554822f, 0.10171006f, 0.5546305f, 0.28822696f, 0.8995881f, 0.96590596f, 0.76544195f, 0.23609895f, 0.5093231f, 0.29946357f, 0.44045478f, 0.5974459f, 0.24198511f, 0.13976322f, 0.30026865f, 0.6117198f, 0.54420567f, 0.83931947f, 0.9591503f, 0.055750016f, 0.015446019f, 0.34988365f, 0.6788849f, 0.8000394f, 0.34461623f, 0.8884854f, 0.11765242f, 0.6764313f, 0.70610297f, 0.7528662f, 0.6234379f, 0.95549244f, 0.48107228f, 0.57657474f, 0.35293803f, 0.53558505f, 0.90731245f, 0.6388894f, 0.9061205f, 0.9068154f, 0.82560843f, 0.48359713f, 0.6093791f, 0.25128087f, 0.58313656f, 0.10119824f, 0.14279248f, 0.8000816f, 0.89156765f, 0.12725733f, 0.052655865f, 0.09217951f, 0.20653115f, 0.34572187f, 0.34771374f, 0.30589288f, 0.06053133f, 0.41077146f, 0.9258966f, 0.31344774f, 0.66711676f, 0.04113631f, 0.9229566f, 0.008368838f, 0.5903627f, 0.84122473f, 0.11545232f, 0.7868713f, 0.9680761f, 0.23150893f, 0.4704689f, 0.5499954f, 0.43753204f, 0.7121286f, 0.61013496f, 0.59720284f, 0.92617583f, 0.7834906f, 0.027650753f, 0.8977211f, 
0.15754606f, 0.54239666f, 0.18633401f, 0.5662742f, 0.2190944f, 0.59521663f, 0.6435355f, 0.71627194f, 0.037149042f, 0.6100622f, 0.61836076f, 0.1470259f, 0.36966816f, 0.90360576f, 0.5119274f, 0.7205386f, 0.39034662f, 0.62984717f, 0.01017152f, 0.64599174f, 0.15090384f, 0.36933318f, 0.19484489f, 0.09027873f, 0.58042485f, 0.14514206f, 0.036732975f, 0.54077417f, 0.43008235f, 0.15875153f, 0.34932455f, 0.37410876f, 0.8042535f, 0.7739999f, 0.28807458f, 0.97715217f, 0.117083825f, 0.17992087f, 0.9757363f, 0.18320304f, 0.015741833f, 0.9748695f, 0.65635973f, 0.14705919f, 0.037058447f, 0.8968405f, 0.021620478f, 0.5633058f, 0.767505f, 0.12037435f, 0.44985265f, 0.26535192f, 0.22633725f, 0.5835013f, 0.42530164f, 0.6948082f, 0.7116804f, 0.6978424f, 0.82452023f, 0.23771845f, 0.99683344f, 0.70071405f, 0.12593275f, 0.7764756f, 0.36999762f, 0.3072223f, 0.09792935f, 0.43981078f, 0.8204207f, 0.14809668f, 0.7569628f, 0.8288626f, 0.15944423f, 0.21987063f, 0.5351478f, 0.11639127f, 0.9450276f, 0.657273f, 0.48179442f, 0.6428968f, 0.07266802f, 0.54417425f, 0.8990355f, 0.36724177f, 0.4083636f, 0.2944423f, 0.9782087f, 0.15691185f, 0.39151284f, 0.56013423f, 0.049810167f, 0.906521f, 0.9659634f, 0.921944f, 0.30070534f, 0.9883118f, 0.95775986f, 0.13003021f, 0.8573852f, 0.1918365f, 0.10604336f, 0.19914377f, 0.40675613f, 0.024324145f, 0.23431449f, 0.72297823f, 0.7580914f, 0.20346278f, 0.82810277f, 0.32680357f, 0.10711087f, 0.590452f, 0.5469826f, 0.18557824f, 0.51672226f, 0.9832008f, 0.7936118f, 0.5308729f, 0.37090248f, 0.7742029f, 0.4481485f, 0.5493372f, 0.50338376f, 0.43103522f, 0.53751975f, 0.70061314f, 0.021088583f, 0.3308669f, 0.8162114f, 0.5326165f, 0.35944003f, 0.9206047f, 0.6406876f, 0.50699484f, 0.8470867f, 0.9593492f, 0.7875809f, 0.9962247f, 0.23328215f, 0.7006755f, 0.5442194f, 0.6375928f, 0.33889383f, 0.9687761f, 0.5783294f, 0.9320834f, 0.88320315f, 0.7495404f, 0.5102735f, 0.22573441f, 0.51124907f, 0.9721347f, 0.44289282f, 0.37883982f, 0.33592433f, 0.40807053f, 0.7348208f, 0.059953105f, 0.020652194f, 0.373106f, 0.35336265f, 0.029604226f, 0.6272284f, 0.6029403f, 0.49051753f, 0.398493f, 0.4539566f, 0.2655247f, 0.9981165f, 0.75446373f, 0.46822912f, 0.648188f, 0.324949f, 0.9306804f, 0.8809041f, 0.42844233f, 0.38464552f, 0.76389503f, 0.7626695f, 0.63432926f, 0.33961716f, 0.61165744f, 0.7148871f, 0.4873704f, 0.49829185f, 0.5820676f, 0.40672466f, 0.51494414f, 0.883497f, 0.78602934f, 0.24558222f, 0.5361903f, 0.69763577f, 0.26757947f, 0.4059913f, 0.862289f, 0.7588195f, 0.18907034f, 0.42610446f, 0.08498969f, 0.02107262f, 0.2888108f, 0.90481687f, 0.03300186f, 0.61184776f, 0.41099504f, 0.27365708f, 0.27691156f, 0.01747882f, 0.71713996f, 0.40858844f, 0.7091915f, 0.2785737f, 0.87971973f, 0.015822828f, 0.058852635f, 0.54861325f, 0.4243099f, 0.07972601f, 0.7242567f, 0.3915925f, 0.85279524f, 0.5510232f, 0.88121253f, 0.55209786f, 0.9690384f, 0.910818f, 0.4399193f, 0.08753263f, 0.25317103f, 0.28638893f, 0.08940263f, 0.62953717f, 0.13840295f, 0.6593923f, 0.27087918f, 0.54218894f, 0.7974436f, 0.03127277f, 0.13191597f, 0.3672008f, 0.45645824f, 0.50062525f, 0.59150535f, 0.53669804f, 0.87231857f, 0.083159134f, 0.30086067f, 0.57798487f, 0.6605887f, 0.46329933f, 0.7809135f, 0.3256513f, 0.42846498f, 0.43590286f, 0.7588255f, 0.112232044f, 0.45630154f, 0.85721415f, 0.36618492f, 0.3291177f, 0.3065707f, 0.258635f, 0.93674284f, 0.267144f, 0.94944286f, 0.03034833f, 0.43545058f, 277.44568f, 293.30225f, 290.0967f, 226.36577f, 263.3507f, 233.65721f, 271.0456f, 201.33302f, 244.87222f, 248.06546f, 283.55505f, 273.16003f, 273.43265f, 248.35196f, 
261.96664f, 252.17625f, 213.653f, 268.57755f, 241.37634f, 275.69666f, 231.28116f, 238.647f, 267.70135f, 270.0771f, 278.84747f, 232.92476f, 227.37221f, 290.46814f, 282.7081f, 210.15854f, 275.31555f, 260.04895f, 283.80142f, 227.62625f, 267.77484f, 245.33005f, 251.6941f, 232.47691f, 220.30089f, 292.46063f, 252.57907f, 262.54684f, 254.58533f, 239.21768f, 246.7902f, 254.07513f, 230.66675f, 288.9232f, 216.71547f, 214.78873f, 279.40067f, 210.46289f, 269.7311f, 258.03143f, 220.68816f, 220.33643f, 290.5327f, 217.04453f, 203.5228f, 236.82892f, 271.18365f, 253.44327f, 206.32324f, 243.99203f, 285.42123f, 208.0186f, 235.3223f, 215.7981f, 281.17578f, 258.11807f, 235.2606f, 226.48712f, 280.93256f, 280.83173f, 243.42778f, 266.36462f, 236.26477f, 295.47427f, 273.871f, 293.18738f, 276.67422f, 232.46318f, 218.5724f, 278.0185f, 260.68582f, 216.33072f, 202.01517f, 256.0112f, 260.35217f, 285.29895f, 282.32895f, 204.90137f, 202.91895f, 201.99902f, 234.42209f, 232.87006f, 296.0879f, 282.7151f, 260.2f, 263.00598f, 245.1402f, 220.98419f, 227.66153f, 298.27438f, 288.2768f, 246.6337f, 247.41647f, 229.84933f, 200.41792f, 256.62027f, 207.03185f, 235.04187f, 269.5741f, 279.07892f, 279.92096f, 266.31543f, 277.62415f, 282.93802f, 244.6243f, 261.97354f, 287.40088f, 285.73053f, 210.00949f, 235.31769f, 267.29855f, 256.89893f, 225.80467f, 241.72736f, 243.78555f, 230.197f, 220.44577f, 286.22617f, 295.29068f, 248.73352f, 271.84897f, 295.86597f, 274.50906f, 285.53323f, 254.3574f, 246.36845f, 232.46686f, 202.37822f, 232.31885f, 284.55515f, 281.44986f, 288.22656f, 224.62955f, 257.4739f, 277.62314f, 233.47943f, -9.999561f, -9.999684f, -9.999829f, -9.999858f, -9.999566f, -9.999728f, -9.999245f, -9.999897f, -9.999244f, -9.999921f, -9.999919f, -9.999612f, -9.999473f, -9.9995575f, -9.999303f, -9.999789f, -9.999555f, -9.999162f, -9.999468f, -9.999969f, -9.999672f, -9.999807f, -9.999847f, -9.99909f, -9.999817f, -9.999831f, -9.999489f, -9.999215f, -9.999848f, -9.9998455f, -9.999323f, -9.999817f, -9.999044f, -9.999408f, -9.999863f, -9.999365f, -9.99908f, -9.99931f, -9.99933f, -9.99975f, -9.999039f, -9.99978f, -9.999931f, -9.99974f, -9.999948f, -9.999952f, -9.999335f, -9.999389f, -9.999414f, -9.999315f, -9.999753f, -9.999389f, -9.99995f, -9.999082f, -9.999573f, -9.999592f, -9.9998f, -9.999939f, -9.999826f, -9.999052f, -9.99905f, -9.999516f, -9.999568f, -9.999664f, -9.999201f, -9.9993f, -9.999386f, -9.999858f, -9.999468f, -9.99966f, -9.999665f, -9.999242f, -9.9997425f, -9.99912f, -9.999361f, -9.999368f, -9.999324f, -9.999566f, -9.999074f, -9.99973f, -9.99977f, -9.999092f, -9.99947f, -9.999531f, -9.999189f, -9.99918f, -9.999814f, -9.999811f, -9.999523f, -9.999692f, -9.999746f, -9.999281f, -9.999508f, -9.999807f, -9.999763f, -9.999359f, -9.999442f, -9.999778f, -9.999925f, -9.999119f, -9.999002f, -9.999579f, -9.999089f, -9.999878f, -9.9991865f, -9.999503f, -9.99901f, -9.9991865f, -9.999055f, -9.999055f, -9.9990225f, -9.999116f, -9.999345f, -9.999241f, -9.999561f, -9.999711f, -9.999534f, -9.999722f, -9.999037f, -9.99902f, -9.999436f, -9.999547f, -9.9997425f, -9.999701f, -9.999172f, -9.99957f, -9.99917f, -9.999358f, -9.999515f, -9.9994545f, -9.999549f, -9.99922f, -9.999552f, -9.999457f, -9.999204f, -9.999363f, -9.99935f, -9.999776f, -9.999162f, -9.999254f, -9.99992f, -9.999504f, -9.9991f, -9.999846f, -9.99928f, -9.99955f, -9.999984f, -9.999683f, -9.999582f, -9.999975f, 0.4054413f, 0.49212277f, 0.9723238f, 0.72839403f, 0.6485173f, 0.11651259f, 0.10785521f, 0.032620244f, 0.023706913f, 0.3086147f, 0.47183102f, 0.992096f, 0.99172103f, 0.34033036f, 
0.95944905f, 0.22414577f, 0.06989748f, 0.5614623f, 0.97281843f, 0.52306736f, 0.053522028f, 0.50254625f, 0.51301396f, 0.5985718f, 0.0371569f, 0.8265822f, 0.4661505f, 0.4922629f, 0.81253344f, 0.9696686f, 0.60658884f, 0.8239178f, 0.15269178f, 0.939187f, 0.14531301f, 0.37456673f, 0.779733f, 0.418844f, 0.66610193f, 0.5676376f, 0.8005674f, 0.31309485f, 0.03271992f, 0.36289623f, 0.5230104f, 0.9365938f, 0.54856783f, 0.38090333f, 0.677641f, 0.98534113f, 0.6625885f, 0.9755095f, 0.078554325f, 0.018032718f, 0.8922824f, 0.9402988f, 0.7797243f, 0.5073222f, 0.8464975f, 0.7056091f, 0.49532133f, 0.42082825f, 0.39204183f, 0.7350382f, 0.7106082f, 0.7145868f, 0.7029236f, 0.22454071f, 0.9618653f, 0.4929038f, 0.58743435f, 0.22425091f, 0.52113986f, 0.29244232f, 0.58773226f, 0.17996566f, 0.16191864f, 0.8782989f, 0.6559272f, 0.45498922f, 0.109633766f, 0.29422963f, 0.28020766f, 0.45128867f, 0.34663188f, 0.011857478f, 0.13049418f, 0.39511293f, 0.15442526f, 0.98196644f, 0.74726933f, 0.20202826f, 0.066193216f, 0.6910641f, 0.91542566f, 0.36986846f, 0.36708114f, 0.7992493f, 0.66625875f, 0.9589232f, 0.58173925f, 0.2632916f, 0.8744973f, 0.869903f, 0.27612343f, 0.43633205f, 0.0069335676f, 0.46793646f, 0.6261623f, 0.8301051f, 0.4103617f, 0.583117f, 0.9595133f, 0.092884764f, 0.6108136f, 0.9563768f, 0.13297999f, 0.9781464f, 0.1866522f, 0.6501296f, 0.940671f, 0.5299086f, 0.9236821f, 0.8280376f, 0.5605807f, 0.08746594f, 0.99765533f, 0.9831952f, 0.3346039f, 0.45981014f, 0.16059282f, 0.898296f, 0.24069251f, 0.84168667f, 0.42612913f, 0.840821f, 0.06970532f, 0.6529262f, 0.21027155f, 0.6587761f, 0.8506848f, 0.23469605f, 0.8375965f, 0.6650027f, 0.6900568f, 0.03741631f, 0.90703416f, 0.60072684f, 0.041207824f, 0.20454895f, 0.13258597f, 0.38379464f, 0.5782676f, 0.37454012f, 0.788924f, 0.6553679f, 0.6696084f, 0.194304f, 0.18800853f, 0.42950943f, 0.70689565f, 0.837481f, 0.14751653f, 0.56871074f, 0.7577148f, 0.7652816f, 0.19738932f, 0.9059352f, 0.97273886f, 0.51461357f, 0.1711977f, 0.5120307f, 0.22731306f, 0.5407244f, 0.2804785f, 0.05774873f, 0.80988765f, 0.7796792f, 0.31191307f, 0.39822164f, 0.5347025f, 0.07349863f, 0.21531169f, 0.07873698f, 0.8192433f, 0.722044f, 0.40318736f, 0.8964449f, 0.49459186f, 0.9010825f, 0.45778024f, 0.80724466f, 0.38512704f, 0.38782215f, 0.13246128f, 0.7218372f, 0.7401796f, 0.84869057f, 0.56868243f, 0.3278968f, 0.019229556f, 0.43221912f, 0.693255f, 0.43167397f, 0.78483266f, 0.09825686f, 0.5116548f, 0.1271103f, 0.18708695f, 0.95848906f, 0.23714672f, 0.52546054f, 0.5915945f, 0.7894098f, 0.8593355f, 0.31078282f, 0.28504592f, 0.85881007f, 0.29736793f, 0.50781727f, 0.65514153f, 0.44968098f, 0.9075563f, 0.7546295f, 0.45364478f, 0.29375777f, 0.94780463f, 0.6616151f, 0.01726944f, 0.9249832f, 0.9179415f, 0.6749661f, 0.43883613f, 0.37391648f, 0.65078586f, 0.21732111f, 0.02359236f, 0.007791354f, 0.30327088f, 0.31245363f, 0.84185934f, 0.49694976f, 0.93794364f, 0.8528437f, 0.7000397f, 0.5224565f, 0.8105422f, 0.99443287f, 0.847529f, 0.15470129f, 0.8077305f, 0.5341055f, 0.23147497f, 0.40932575f, 0.96443266f, 0.09061932f, 0.05683991f, 0.99754393f, 0.11661421f, 0.19272684f, 0.3620329f, 0.45262036f, 0.03901034f, 0.06041548f, 0.0075550857f, 0.27494353f, 0.67014945f, 0.2957977f, 0.2216069f, 0.6506188f, 0.45587075f, 0.28567624f, 0.5888963f, 0.98453754f, 0.8699843f, 0.9340606f, 0.0642961f, 0.14302005f, 0.7717978f, 0.75930613f, 0.6141049f, 0.4101332f, 0.27772737f, 0.28117037f, 0.8098905f, 0.5942f, 0.7786375f, 0.4493845f, 0.5141761f, 0.744234f, 0.34754843f, 0.9057713f, 0.29356617f, 0.41850287f, 0.25478244f, 0.78619635f, 0.70232016f, 
0.7863453f, 0.57700616f, 0.3423882f, 0.11562478f, 0.6069529f, 0.7797115f, 0.2574891f, 0.51921356f, 0.2538803f, 0.670748f, 0.82137585f, 0.47364834f, 0.9369771f, 0.1801538f, 0.5134379f, 0.3520003f, 0.38112086f, 0.29870084f, 0.55816495f, 0.95891315f, 0.3729329f, 0.7877428f, 0.029987516f, 0.37669265f, 0.10563303f, 0.14064822f, 0.4556408f, 0.86550975f, 0.73312205f, 0.09095184f, 0.9431056f, 0.372078f, 0.4691022f, 0.72663444f, 0.5589779f, 0.98812455f, 0.1695335f, 0.8314304f, 0.7852622f, 0.61309403f, 0.10439321f, 0.76670945f, 0.5409888f, 0.9157445f, 0.57858527f, 0.14883776f, 0.20041484f, 0.30621874f, 0.9036323f, 0.9339205f, 0.9151604f, 0.12393201f, 0.929967f, 0.35930997f, 0.2358306f, 0.6697985f, 0.31414795f, 0.30049297f, 0.89661825f, 0.27027792f, 0.17256655f, 0.9318595f, 0.81196785f, 0.38976404f, 0.293463f, 0.2512547f, 0.81138444f, 0.988779f, 0.27900514f, 0.4261041f, 0.61765677f, 0.8339683f, 0.25210267f, 0.51324797f, 0.92285997f, 0.0889822f, 0.5169889f, 0.3989031f, 0.6554801f, 0.9353766f, 0.544529f, 0.123369224f, 0.34246746f, 0.2115331f, 0.26744205f, 0.71749866f, 0.22343503f, 0.64539504f, 0.67429143f, 0.41868812f, 0.40186298f, 0.098477215f, 0.88132435f, 0.07625152f, 0.043012597f, 0.6452063f, 0.2102687f, 0.22173183f, 0.10345679f, 0.7434575f, 0.7126712f, 0.76721144f, 0.6512526f, 0.15990873f, 0.11895295f, 0.77731425f, 0.5243528f, 0.694658f, 0.86524415f, 0.75635976f, 0.057310082f, 0.16338252f, 0.78290933f, 0.7817539f, 0.8036517f, 0.33238873f, 0.676157f, 0.6762056f, 0.16322272f, 0.87960654f, 0.36118373f, 0.32454377f, 0.763408f, 0.506997f, 0.6956684f, 0.9279813f, 0.20323144f, 0.5839603f, 0.5633559f, 0.6701542f, 0.25721762f, 0.9896909f, 0.95511895f, 0.9082311f, 0.29406747f, 0.60026234f, 0.93644714f, 0.61788774f, 0.66341126f, 0.20749137f, 0.52809435f, 0.30916053f, 0.59821826f, 0.42163637f, 0.8293481f, 0.9711802f, 0.7839911f, 0.7657031f, 0.5351135f, 0.6362381f, 0.5429735f, 0.29129192f, 0.74155486f, 256.6196f, 299.92203f, 283.1842f, 257.95f, 242.67941f, 283.13525f, 297.3768f, 209.21597f, 298.94897f, 272.28577f, 208.13962f, 224.24684f, 215.7119f, 289.45593f, 248.60497f, 291.094f, 261.66168f, 291.05728f, 280.15112f, 246.94473f, 281.08008f, 221.38707f, 231.09238f, 220.10115f, 219.70961f, 273.52057f, 298.6576f, 250.59302f, 203.40039f, 227.90755f, 208.1463f, 211.84389f, 251.76518f, 275.46594f, 292.12732f, 277.5088f, 281.66544f, 274.27924f, 291.94995f, 282.94733f, 231.35228f, 229.87643f, 226.04532f, 246.81201f, 285.92133f, 211.72032f, 265.00046f, 292.0401f, 217.145f, 258.9742f, 241.07838f, 297.71396f, 265.03607f, 293.78973f, 215.46487f, 271.7528f, 297.20273f, 234.13841f, 253.58505f, 252.52872f, 224.75195f, 218.48878f, 204.55463f, 293.8269f, 283.58505f, 264.1618f, 226.64536f, 280.69232f, 218.0678f, 219.11906f, 209.70735f, 215.2419f, 227.23471f, 226.22966f, 292.78833f, 250.87213f, 220.66672f, 292.0923f, 214.3262f, 220.62033f, 292.90533f, 294.61047f, 210.68884f, 260.9642f, 262.28113f, 255.0517f, 232.66026f, 294.8312f, 206.05696f, 289.73633f, 235.66345f, 232.93633f, 263.52408f, 256.7292f, 210.22684f, 229.51805f, 282.41776f, 211.0127f, 239.21553f, 235.43231f, 278.32697f, 299.7943f, 247.10483f, 219.1755f, 224.00432f, 263.2412f, 276.8183f, 291.88232f, 233.7261f, 241.75543f, 261.45193f, 296.58963f, 203.90746f, 277.9264f, 245.81134f, 261.24277f, 212.32646f, 242.76822f, 241.22888f, 224.0751f, 267.85315f, 232.49553f, 272.37656f, 253.20465f, 206.93951f, 201.29115f, 257.55444f, 296.3969f, 259.25177f, 292.10406f, 267.9734f, 253.28792f, 210.03741f, 272.03717f, 284.04358f, 292.52087f, 253.26274f, 207.37628f, 263.50598f, 
228.07819f, 237.00746f, 241.3014f, 278.94174f, 214.41554f, 270.15442f, 264.77567f, 206.68633f, 229.17867f, 238.87085f, 254.12152f, -9.999742f, -9.999057f, -9.999062f, -9.999852f, -9.999382f, -9.999388f, -9.999354f, -9.999587f, -9.999273f, -9.999814f, -9.999888f, -9.999484f, -9.999295f, -9.999065f, -9.999623f, -9.999145f, -9.999381f, -9.999056f, -9.99943f, -9.999615f, -9.999143f, -9.999795f, -9.999838f, -9.999658f, -9.999616f, -9.9998f, -9.999448f, -9.999215f, -9.999058f, -9.999626f, -9.999816f, -9.99952f, -9.999158f, -9.999308f, -9.999545f, -9.999357f, -9.999205f, -9.999506f, -9.999683f, -9.999209f, -9.9999895f, -9.999543f, -9.999428f, -9.999628f, -9.999103f, -9.9991455f, -9.999936f, -9.999467f, -9.999748f, -9.99912f, -9.999807f, -9.999134f, -9.999681f, -9.999262f, -9.999087f, -9.999329f, -9.999385f, -9.999264f, -9.999793f, -9.999045f, -9.9995985f, -9.999204f, -9.999249f, -9.999444f, -9.9992075f, -9.9998455f, -9.999957f, -9.999949f, -9.999563f, -9.999786f, -9.999491f, -9.999651f, -9.999318f, -9.999416f, -9.999064f, -9.999325f, -9.9996f, -9.999902f, -9.999786f, -9.99952f, -9.999172f, -9.999215f, -9.999257f, -9.9991865f, -9.999605f, -9.999594f, -9.999224f, -9.999279f, -9.999259f, -9.999697f, -9.9996195f, -9.999134f, -9.999058f, -9.999047f, -9.999575f, -9.999919f, -9.999645f, -9.999633f, -9.999902f, -9.999141f, -9.999885f, -9.999965f, -9.999505f, -9.99982f, -9.999797f, -9.99964f, -9.999083f, -9.9995775f, -9.9999695f, -9.999383f, -9.999018f, -9.999117f, -9.99926f, -9.99911f, -9.999243f, -9.999118f, -9.99911f, -9.999486f, -9.99909f, -9.999861f, -9.999171f, -9.9999275f, -9.999972f, -9.999925f, -9.999671f, -9.999307f, -9.9994955f, -9.999324f, -9.999028f, -9.999182f, -9.999585f, -9.999082f, -9.999469f, -9.999043f, -9.999628f, -9.9994335f, -9.999068f, -9.999732f, -9.999809f, -9.999425f, -9.99959f, -9.999719f, -9.999516f, -9.999942f, -9.999832f, -9.999641f, -9.999447f, -9.99934f, -9.999968f, -9.999992f, 0.639171f, 0.47615534f, 0.1366003f, 0.4112621f, 0.543977f, 0.6301188f, 0.72094375f, 0.41664115f, 0.6702276f, 0.2662457f, 0.34709758f, 0.0047021024f, 0.19731691f, 0.3105783f, 0.35764986f, 0.6188618f, 0.55722684f, 0.014176953f, 0.28426266f, 0.55528253f, 0.9861382f, 0.59125423f, 0.91971123f, 0.50413203f, 0.71612626f, 0.37045076f, 0.16731057f, 0.8361767f, 0.20203081f, 0.46268502f, 0.54416966f, 0.82547253f, 0.70076334f, 0.19353609f, 0.7197332f, 0.7577992f, 0.15850778f, 0.09100532f, 0.8406752f, 0.4743588f, 0.14548168f, 0.91383964f, 0.31233132f, 0.057911392f, 0.38550714f, 0.788842f, 0.45663434f, 0.87255025f, 0.6822182f, 0.27235323f, 0.8781251f, 0.8971649f, 0.6117316f, 0.5027711f, 0.7707731f, 0.8171592f, 0.99433446f, 0.3228524f, 0.10424189f, 0.9995735f, 0.07680203f, 0.16278757f, 0.87946606f, 0.8840557f, 0.45882654f, 0.5382355f, 0.17185123f, 0.19348888f, 0.08070494f, 0.8351659f, 0.59116447f, 0.3656219f, 0.38914752f, 0.8038363f, 0.21394636f, 0.6494243f, 0.2923405f, 0.096409395f, 0.81489897f, 0.2177272f, 0.5156461f, 0.28180742f, 0.15846203f, 0.38402006f, 0.6799602f, 0.0992625f, 0.42167094f, 0.5157946f, 0.5737303f, 0.61967856f, 0.27188474f, 0.33863726f, 0.8381059f, 0.9284707f, 0.81110543f, 0.14615615f, 0.5137047f, 0.4068576f, 0.27341366f, 0.6371842f, 0.46284974f, 0.6114867f, 0.71931726f, 0.91663635f, 0.60304374f, 0.14932536f, 0.88403726f, 0.54094154f, 0.1467738f, 0.97935086f, 0.7863954f, 0.2147064f, 0.012224621f, 0.14325804f, 0.65899223f, 0.5648787f, 0.65609366f, 0.8197612f, 0.6399177f, 0.8468733f, 0.76479703f, 0.25536442f, 0.5532024f, 0.95500815f, 0.39078063f, 0.5678974f, 0.21131837f, 0.987159f, 0.27899948f, 
0.45318067f, 0.052973147f, 0.22060722f, 0.13576879f, 0.22578368f, 0.4504141f, 0.81624466f, 0.6962496f, 0.38475657f, 0.5542052f, 0.040127296f, 0.7824744f, 0.7515341f, 0.2940618f, 0.45921704f, 0.74931914f, 0.4590101f, 0.1761703f, 0.76585937f, 0.3804439f, 0.20216002f, 0.79364806f, 0.48445576f, 0.9997787f, 0.07572355f, 0.9185397f, 0.43292367f, 0.6824889f, 0.57344544f, 0.45387882f, 0.61218095f, 0.001530312f, 0.36701044f, 0.3732282f, 0.21642086f, 0.0032335173f, 0.9757738f, 0.6631197f, 0.84142756f, 0.23562978f, 0.8842848f, 0.24768245f, 0.6896844f, 0.093373105f, 0.47206926f, 0.018847544f, 0.3574926f, 0.7817249f, 0.3901984f, 0.37762666f, 0.60320383f, 0.5876514f, 0.8498338f, 0.6137263f, 0.64150596f, 0.8912183f, 0.18202206f, 0.07165835f, 0.54631984f, 0.14491297f, 0.46619728f, 0.5531275f, 0.9730491f, 0.3560192f, 0.5463067f, 0.9498098f, 0.6082786f, 0.12641688f, 0.27168056f, 0.449438f, 0.2710077f, 0.059393216f, 0.47376275f, 0.3349298f, 0.8534693f, 0.24378222f, 0.27263063f, 0.31725782f, 0.027660795f, 0.36858514f, 0.31543452f, 0.32232106f, 0.7514354f, 0.7665531f, 0.93814677f, 0.94667625f, 0.7495306f, 0.07630936f, 0.07085721f, 0.09998243f, 0.14326382f, 0.3722598f, 0.8195573f, 0.88503057f, 0.64455885f, 0.9708746f, 0.574863f, 0.7547003f, 0.663569f, 0.62627494f, 0.66573906f, 0.88241595f, 0.5472183f, 0.10965517f, 0.086363465f, 0.03911088f, 0.43472022f, 0.282755f, 0.81878805f, 0.7069662f, 0.6482738f, 0.7889657f, 0.13123439f, 0.5466046f, 0.9870477f, 0.65994346f, 0.044764873f, 0.2590037f, 0.21607089f, 0.7882748f, 0.030434562f, 0.7240241f, 0.24359426f, 0.24925096f, 0.50715107f, 0.8548116f, 0.5778587f, 0.81658524f, 0.8406002f, 0.26860788f, 0.308281f, 0.40139812f, 0.27045614f, 0.681128f, 0.55732554f, 0.77117866f, 0.025454784f, 0.045293983f, 0.27430618f, 0.24866389f, 0.9072126f, 0.21633524f, 0.986974f, 0.91918707f, 0.86734384f, 0.5860722f, 0.8918684f, 0.86775124f, 0.24765202f, 0.7032609f, 0.4580694f, 0.6150063f, 0.12584582f, 0.13061108f, 0.11944151f, 0.27304602f, 0.08538959f, 0.2935459f, 0.6501564f, 0.6911091f, 0.79428184f, 0.19728307f, 0.9433592f, 0.98402375f, 0.278235f, 0.6931662f, 0.32246152f, 0.7604209f, 0.323686f, 0.4490462f, 0.21253695f, 0.37495488f, 0.095260054f, 0.5237899f, 0.9992169f, 0.36044437f, 0.5078252f, 0.5861082f, 0.64059675f, 0.03762793f, 0.49785113f, 0.38858363f, 0.69295675f, 0.2873984f, 0.32729995f, 0.59859157f, 0.73461634f, 0.25285175f, 0.5567667f, 0.71841735f, 0.69814867f, 0.77477485f, 0.16508374f, 0.15479185f, 0.48362815f, 0.37302348f, 0.7408702f, 0.11581469f, 0.08464117f, 0.029988535f, 0.34612563f, 0.45165575f, 0.68815565f, 0.008550999f, 0.09454897f, 0.8842033f, 0.471434f, 0.16433838f, 0.5935435f, 0.8646248f, 0.57239705f, 0.65469956f, 0.5863223f, 0.4796355f, 0.59167236f, 0.54985625f, 0.39255446f, 0.61727005f, 0.50840545f, 0.3316757f, 0.74857223f, 0.35827267f, 0.8872402f, 0.8038483f, 0.3931879f, 0.70447254f, 0.16417824f, 0.42719653f, 0.7534679f, 0.57123446f, 0.34724474f, 0.54931104f, 0.39288715f, 0.42828634f, 0.8222923f, 0.8765563f, 0.94212073f, 0.12068056f, 0.70422703f, 0.2824587f, 0.027603716f, 0.52777815f, 0.5066046f, 0.5769824f, 0.07630827f, 0.103958726f, 0.1505021f, 0.24175929f, 0.50438327f, 0.6733676f, 0.35198468f, 0.0752788f, 0.7415916f, 0.42589715f, 0.761479f, 0.0033971865f, 0.91897255f, 0.9319753f, 0.81370807f, 0.79544336f, 0.23588327f, 0.9587119f, 0.71191025f, 0.42136034f, 0.19574885f, 0.54185784f, 0.008105425f, 0.14255908f, 0.63592f, 0.3044852f, 0.6324764f, 0.6508548f, 0.08161495f, 0.65241224f, 0.8424147f, 0.97779244f, 0.72876996f, 0.61530423f, 0.94752645f, 0.6066642f, 0.10435986f, 
0.18537253f, 0.30024627f, 0.8787194f, 0.06873524f, 0.91032326f, 0.84761214f, 0.12825106f, 0.22760965f, 0.70036477f, 0.09428674f, 0.9861057f, 0.13853452f, 0.8474568f, 0.057899747f, 0.060172286f, 0.37916803f, 0.15240528f, 0.77621406f, 0.26485768f, 0.1740309f, 0.29064766f, 0.7386373f, 0.5348933f, 0.26158985f, 0.43255532f, 0.59368885f, 0.61983097f, 0.13413209f, 0.32573816f, 0.43871734f, 0.7316835f, 0.7375361f, 0.8791016f, 0.46889958f, 0.8362294f, 0.56079483f, 0.78738517f, 0.12909074f, 0.19669758f, 0.3654093f, 257.23004f, 205.25952f, 256.3495f, 287.5462f, 248.0553f, 279.42828f, 252.23164f, 293.8083f, 244.82593f, 241.14514f, 264.60312f, 242.02669f, 265.36676f, 285.9313f, 276.8894f, 264.85254f, 204.56178f, 216.75874f, 245.4952f, 212.06345f, 205.75478f, 284.3255f, 291.17203f, 219.69725f, 203.70792f, 225.91046f, 230.73822f, 262.73547f, 201.7526f, 212.36281f, 283.3116f, 294.07062f, 249.66954f, 283.85126f, 246.5827f, 207.68987f, 272.6758f, 240.09421f, 275.82172f, 225.84433f, 232.80176f, 201.71077f, 252.89136f, 240.62161f, 259.20868f, 247.87543f, 218.64772f, 248.03424f, 202.67117f, 238.984f, 290.77563f, 293.03915f, 289.35855f, 289.96945f, 286.17395f, 231.49643f, 251.10532f, 225.1938f, 206.88234f, 256.4651f, 239.51657f, 245.26834f, 247.59836f, 204.23398f, 203.37993f, 225.53943f, 267.85843f, 297.7295f, 265.553f, 295.24786f, 242.70523f, 286.44165f, 283.38336f, 251.81482f, 208.90456f, 257.36407f, 229.28513f, 290.7318f, 258.70337f, 223.44356f, 264.08783f, 275.03732f, 251.59811f, 292.53107f, 251.5335f, 244.22394f, 213.89952f, 236.25047f, 211.8138f, 220.5794f, 216.87543f, 233.37456f, 224.4222f, 295.09964f, 214.58566f, 281.3576f, 256.06107f, 241.79654f, 291.32068f, 239.49226f, 228.46638f, 218.16322f, 203.63048f, 299.67514f, 282.89703f, 265.6753f, 287.9343f, 239.81447f, 209.17609f, 262.6297f, 295.4711f, 205.0095f, 223.62189f, 286.34204f, 243.34543f, 237.4936f, 249.12177f, 232.68518f, 229.49867f, 224.16684f, 203.26491f, 272.76715f, 294.89102f, 286.48096f, 273.26846f, 273.41534f, 204.2877f, 210.98381f, 206.86124f, 265.20584f, 244.88943f, 266.12534f, 239.2653f, 286.19138f, 271.75153f, 267.04507f, 210.73386f, 233.14261f, 220.80898f, 273.75244f, 298.48633f, 268.37622f, 204.67131f, 289.64368f, 276.43658f, 290.26245f, 279.004f, 201.35966f, 207.23166f, 280.78134f, -9.999485f, -9.999401f, -9.99988f, -9.99983f, -9.999996f, -9.999282f, -9.999148f, -9.999958f, -9.999139f, -9.999945f, -9.999827f, -9.999956f, -9.999576f, -9.999011f, -9.99982f, -9.999912f, -9.999579f, -9.9990425f, -9.999927f, -9.999287f, -9.999705f, -9.999723f, -9.999244f, -9.999403f, -9.999639f, -9.999259f, -9.999532f, -9.999533f, -9.999703f, -9.999582f, -9.999963f, -9.99968f, -9.999428f, -9.999266f, -9.999494f, -9.999798f, -9.999454f, -9.999226f, -9.99951f, -9.999481f, -9.999743f, -9.99988f, -9.999303f, -9.999975f, -9.999095f, -9.99945f, -9.999369f, -9.999166f, -9.99957f, -9.999976f, -9.999418f, -9.999267f, -9.99994f, -9.999312f, -9.999308f, -9.999992f, -9.9999f, -9.999182f, -9.9991665f, -9.999685f, -9.999133f, -9.999587f, -9.999473f, -9.999556f, -9.999567f, -9.999451f, -9.999944f, -9.999353f, -9.999919f, -9.999077f, -9.99981f, -9.999687f, -9.999805f, -9.999417f, -9.999404f, -9.999712f, -9.99989f, -9.999068f, -9.999573f, -9.999242f, -9.99952f, -9.999031f, -9.999762f, -9.999584f, -9.999476f, -9.999041f, -9.999508f, -9.999519f, -9.999463f, -9.999605f, -9.999481f, -9.99913f, -9.999719f, -9.99981f, -9.999058f, -9.99957f, -9.999909f, -9.99912f, -9.999596f, -9.999688f, -9.999179f, -9.999336f, -9.999998f, -9.999264f, -9.999145f, -9.99914f, -9.999104f, -9.999027f, 
-9.999755f, -9.999626f, -9.999572f, -9.999876f, -9.999124f, -9.9998865f, -9.999168f, -9.999185f, -9.9995575f, -9.999532f, -9.999246f, -9.999302f, -9.999073f, -9.999327f, -9.9998045f, -9.999645f, -9.999669f, -9.999047f, -9.999023f, -9.999354f, -9.999763f, -9.999772f, -9.999175f, -9.999568f, -9.999145f, -9.999254f, -9.999511f, -9.999705f, -9.999031f, -9.999324f, -9.999718f, -9.999497f, -9.99974f, -9.999597f, -9.999909f, -9.999239f, -9.999544f, -9.999691f, -9.999259f, -9.999239f, -9.999568f, -9.999504f, 0.03882216f, 0.8428897f, 0.74364215f, 0.23163715f, 0.49048677f, 0.22178552f, 0.6055793f, 0.4489804f, 0.9163623f, 0.9438124f, 0.1631071f, 0.6749212f, 0.7188561f, 0.32485962f, 0.8829685f, 0.20882395f, 0.60495543f, 0.47757575f, 0.6093003f, 0.84457403f, 0.7257506f, 0.17652789f, 0.025987253f, 0.9859064f, 0.6156289f, 0.73053515f, 0.76787066f, 0.5010675f, 0.40560544f, 0.07712759f, 0.9088255f, 0.07926025f, 0.24527292f, 0.27416497f, 0.74946845f, 0.24720564f, 0.07141664f, 0.43434754f, 0.4136174f, 0.869559f, 0.22436135f, 0.31195417f, 0.12554419f, 0.7383186f, 0.48795158f, 0.52957517f, 0.623028f, 0.036754537f, 0.56178623f, 0.32868809f, 0.9017316f, 0.09641818f, 0.9912348f, 0.92983764f, 0.4863829f, 0.2328445f, 0.72820157f, 0.5609035f, 0.5382467f, 0.21526214f, 0.2952519f, 0.391415f, 0.32775486f, 0.7910391f, 0.04752018f, 0.3907967f, 0.24044213f, 0.62969697f, 0.86658025f, 0.550671f, 0.6625566f, 0.7994618f, 0.12169334f, 0.21295948f, 0.4997118f, 0.98608136f, 0.67981267f, 0.5607458f, 0.20580857f, 0.59258527f, 0.74313295f, 0.504703f, 0.34825593f, 0.88810426f, 0.375232f, 0.9950801f, 0.6716571f, 0.43368435f, 0.13610889f, 0.7123607f, 0.5050985f, 0.31398848f, 0.6695705f, 0.12510324f, 0.18162547f, 0.61493284f, 0.816849f, 0.9648539f, 0.37662333f, 0.03039601f, 0.8444544f, 0.3708865f, 0.24754128f, 0.33466703f, 0.96997195f, 0.4863897f, 0.425792f, 0.5019443f, 0.3766153f, 0.37071276f, 0.30467907f, 0.5455875f, 0.47557223f, 0.99561185f, 0.82659286f, 0.50989014f, 0.8268076f, 0.32439554f, 0.90867627f, 0.523794f, 0.91507274f, 0.3708023f, 0.67873424f, 0.6258858f, 0.7507315f, 0.6253023f, 0.62942946f, 0.5893559f, 0.30942422f, 0.2114435f, 0.022920458f, 0.044418756f, 0.61610794f, 0.8113304f, 0.35662258f, 0.41705018f, 0.46921277f, 0.86777097f, 0.95223355f, 0.40362936f, 0.9437976f, 0.18228506f, 0.6360729f, 0.33576652f, 0.031274755f, 0.21817888f, 0.36112952f, 0.7787455f, 0.42273897f, 0.25281885f, 0.33198494f, 0.7785485f, 0.788286f, 0.16736427f, 0.0092501305f, 0.09297396f, 0.28935695f, 0.34107473f, 0.30980217f, 0.53143716f, 0.52857065f, 0.8409118f, 0.4052178f, 0.69706166f, 0.64710814f, 0.026039753f, 0.98393834f, 0.37317148f, 0.2896904f, 0.9887286f, 0.26908764f, 0.9406588f, 0.5261725f, 0.9049269f, 0.56662345f, 0.6709716f, 0.68239623f, 0.49234113f, 0.97048306f, 0.33545634f, 0.23616292f, 0.21654218f, 0.25211942f, 0.024790008f, 0.6374578f, 0.38915554f, 0.9337675f, 0.9430794f, 0.4695175f, 0.7804938f, 0.536538f, 0.9851012f, 0.19607964f, 0.3125924f, 0.55515915f, 0.85639995f, 0.76419586f, 0.19247372f, 0.8593474f, 0.65614396f, 0.8763346f, 0.5008372f, 0.75938493f, 0.30444136f, 0.8475765f, 0.2756218f, 0.7643892f, 0.10603409f, 0.4270085f, 0.40084615f, 0.094159424f, 0.28666124f, 0.907423f, 0.59824944f, 0.13585345f, 0.7766466f, 0.8080405f, 0.6886941f, 0.019375224f, 0.8924157f, 0.8251331f, 0.78726494f, 0.91793686f, 0.30526364f, 0.75136036f, 0.5101915f, 0.0959181f, 0.64297056f, 0.16485944f, 0.7552983f, 0.5024531f, 0.29433584f, 0.99849665f, 0.4194633f, 0.3247048f, 0.6200598f, 0.10172686f, 0.5053654f, 0.2359409f, 0.7552459f, 0.8971784f, 0.044323962f, 
0.52423203f, 0.67628855f, 0.36866117f, 0.99563f, 0.2329034f, 0.27227026f, 0.76375973f, 0.79602706f, 0.5184415f, 0.10457488f, 0.0819885f, 0.90606177f, 0.052181873f, 0.6621527f, 0.92458886f, 0.24737877f, 0.04191045f, 0.34999782f, 0.08424192f, 0.29925734f, 0.24015819f, 0.5147704f, 0.42221153f, 0.99205357f, 0.54271156f, 0.79544294f, 0.5694224f, 0.37800944f, 0.5500707f, 0.09987821f, 0.40123457f, 0.7795467f, 0.8094248f, 0.5604407f, 0.34524485f, 0.56357986f, 0.6901132f, 0.2526902f, 0.46615395f, 0.24697252f, 0.5420497f, 0.18665877f, 0.6566352f, 0.2777055f, 0.9320998f, 0.89702964f, 0.022678716f, 0.1815973f, 0.09005783f, 0.51381236f, 0.6743502f, 0.6247244f, 0.8565416f, 0.87987f, 0.6732118f, 0.00460204f, 0.27535322f, 0.7455861f, 0.15749842f, 0.9247148f, 0.03532768f, 0.08851064f, 0.23502532f, 0.752143f, 0.21853413f, 0.6609476f, 0.28531924f, 0.18054475f, 0.029035527f, 0.67236483f, 0.2241403f, 0.28975555f, 0.99908245f, 0.43963638f, 0.59023327f, 0.30457687f, 0.16792373f, 0.7709499f, 0.6859642f, 0.69117963f, 0.86467695f, 0.5084144f, 0.7589203f, 0.4828981f, 0.07482473f, 0.48116097f, 0.53940266f, 0.5052822f, 0.22626108f, 0.7467059f, 0.41369334f, 0.031238595f, 0.028987564f, 0.66039693f, 0.22867519f, 0.8922084f, 0.23077016f, 0.49657655f, 0.12957393f, 0.5363605f, 0.4044849f, 0.44835f, 0.35317385f, 0.9867398f, 0.92447424f, 0.8969754f, 0.12785867f, 0.34567907f, 0.37078106f, 0.33044818f, 0.5057445f, 0.7683958f, 0.59161294f, 0.3239813f, 0.345188f, 0.5798496f, 0.64173394f, 0.8413601f, 0.47511417f, 0.835949f, 0.9396055f, 0.26686642f, 0.23109126f, 0.69826096f, 0.80957353f, 0.3445376f, 0.30203474f, 0.45118847f, 0.21602394f, 0.59850556f, 0.4789453f, 0.4077335f, 0.5152989f, 0.33034822f, 0.68474686f, 0.85391724f, 0.48057246f, 0.2998755f, 0.90360653f, 0.65591294f, 0.8092372f, 0.7287787f, 0.59123766f, 0.6105523f, 0.15701269f, 0.9201797f, 0.22071724f, 0.44657114f, 0.85324067f, 0.74536175f, 0.92492616f, 0.67641914f, 0.5987662f, 0.81729543f, 0.8069455f, 0.6891773f, 0.8835294f, 0.8892519f, 0.8500076f, 0.857101f, 0.6734726f, 0.9874815f, 0.46896955f, 0.9641137f, 0.47160545f, 0.8463774f, 0.30557284f, 0.9699319f, 0.06608189f, 0.055327572f, 0.93581414f, 0.9587841f, 0.058981307f, 0.92397076f, 0.010058546f, 0.34675553f, 0.6533823f, 0.5349482f, 0.46875533f, 0.5844002f, 0.5102338f, 0.26537207f, 0.19412437f, 0.07258324f, 0.38117927f, 0.1528994f, 0.056126937f, 0.7896892f, 0.3633707f, 0.5028834f, 0.15584666f, 0.43396717f, 0.7498128f, 0.17068368f, 0.8056127f, 0.83374524f, 0.7477155f, 0.8996221f, 0.53976667f, 0.9230572f, 0.19246647f, 0.6391656f, 0.4030687f, 0.7643678f, 0.019256072f, 0.59730285f, 0.309159f, 0.7264034f, 256.18292f, 247.5509f, 241.8322f, 221.72641f, 247.00475f, 289.95996f, 204.75641f, 299.0052f, 222.08545f, 249.15363f, 277.1748f, 222.7599f, 219.53043f, 259.93314f, 290.20483f, 264.3145f, 203.74707f, 269.35193f, 270.35507f, 233.42912f, 209.86781f, 292.96222f, 238.48882f, 256.7762f, 211.95813f, 255.83502f, 271.98605f, 276.92862f, 244.43182f, 219.40994f, 250.76295f, 294.04694f, 226.60033f, 258.7823f, 224.29234f, 289.13776f, 284.96054f, 215.06387f, 284.33295f, 255.14339f, 249.39714f, 298.0097f, 206.93636f, 207.78658f, 210.90904f, 237.74179f, 227.25084f, 248.60242f, 241.76729f, 289.64044f, 257.6767f, 223.0866f, 249.12407f, 201.15231f, 275.7378f, 262.39612f, 268.82336f, 262.55298f, 269.66827f, 237.66492f, 211.21674f, 246.47617f, 200.1591f, 228.94618f, 286.93787f, 224.82498f, 282.6982f, 216.67554f, 299.76526f, 211.74054f, 258.6674f, 282.2848f, 242.32083f, 244.45291f, 261.59262f, 257.17282f, 230.43474f, 219.33755f, 239.1705f, 
229.16939f, 229.4628f, 227.99637f, 278.22507f, 207.49443f, 232.81923f, 250.38698f, 255.53925f, 201.98932f, 279.6214f, 245.52f, 216.7771f, 238.63602f, 204.19614f, 258.92218f, 230.05328f, 267.0341f, 256.95154f, 293.94968f, 251.7791f, 249.71518f, 268.04617f, 243.68118f, 239.60608f, 291.69824f, 255.33287f, 247.66194f, 210.42975f, 272.79053f, 251.49638f, 270.4292f, 266.5404f, 223.91647f, 227.0489f, 217.59396f, 202.26263f, 234.13164f, 282.81702f, 241.44751f, 237.6629f, 254.03835f, 276.81006f, 253.21158f, 290.75342f, 299.60394f, 252.36249f, 207.7176f, 293.0687f, 224.40785f, 254.29674f, 210.75064f, 251.1633f, 265.51978f, 292.73917f, 268.97003f, 213.86755f, 280.26193f, 236.59819f, 261.9136f, 271.9696f, 260.67432f, 225.67659f, 279.94318f, 244.74088f, 205.70877f, 236.24387f, 266.11798f, 234.5054f, 227.88277f, 212.92162f, 281.1429f, -9.9995f, -9.999907f, -9.999015f, -9.99986f, -9.999811f, -9.99916f, -9.9994335f, -9.999082f, -9.999476f, -9.999472f, -9.999309f, -9.999354f, -9.999964f, -9.999819f, -9.999472f, -9.999187f, -9.999328f, -9.999281f, -9.999373f, -9.999825f, -9.999259f, -9.999581f, -9.999256f, -9.999902f, -9.999506f, -9.999213f, -9.999032f, -9.999097f, -9.999959f, -9.999018f, -9.999999f, -9.999964f, -9.99983f, -9.999462f, -9.999094f, -9.999825f, -9.999322f, -9.999475f, -9.999018f, -9.999352f, -9.999122f, -9.999426f, -9.999498f, -9.999934f, -9.9994545f, -9.99973f, -9.999741f, -9.999373f, -9.99933f, -9.999706f, -9.999398f, -9.999283f, -9.999558f, -9.999604f, -9.999935f, -9.999592f, -9.999328f, -9.999943f, -9.999334f, -9.99971f, -9.999961f, -9.999668f, -9.9997835f, -9.999137f, -9.999606f, -9.999959f, -9.99975f, -9.999391f, -9.999501f, -9.999959f, -9.999507f, -9.999104f, -9.999123f, -9.999664f, -9.99954f, -9.999395f, -9.99991f, -9.999099f, -9.999796f, -9.999523f, -9.999298f, -9.999127f, -9.99933f, -9.999529f, -9.999645f, -9.999581f, -9.999803f, -9.999978f, -9.999745f, -9.999099f, -9.999732f, -9.999282f, -9.999186f, -9.999484f, -9.9994545f, -9.999736f, -9.999692f, -9.999638f, -9.999521f, -9.999184f, -9.999315f, -9.999997f, -9.999688f, -9.999604f, -9.999361f, -9.999519f, -9.999438f, -9.999516f, -9.999867f, -9.999932f, -9.99967f, -9.999632f, -9.999027f, -9.999614f, -9.999386f, -9.999235f, -9.99902f, -9.999881f, -9.999402f, -9.999828f, -9.999898f, -9.999556f, -9.9999485f, -9.99902f, -9.999726f, -9.99967f, -9.999689f, -9.999588f, -9.999742f, -9.999436f, -9.999829f, -9.999895f, -9.999559f, -9.999202f, -9.999972f, -9.999332f, -9.999621f, -9.999881f, -9.999916f, -9.999846f, -9.999947f, -9.999159f, -9.999294f, -9.999025f, -9.999374f, -9.999594f, -9.999471f, -9.999263f, -9.999252f, -9.999847f, 0.8405395f, 0.4899531f, 0.15557215f, 0.053656846f, 0.9073092f, 0.07903749f, 0.49019513f, 0.46704555f, 0.2108235f, 0.59149706f, 0.06908697f, 0.91793466f, 0.19079898f, 0.54947394f, 0.052311927f, 0.77982026f, 0.5299146f, 0.17064495f, 0.56645525f, 0.8840749f, 0.042285662f, 0.8682272f, 0.028326662f, 0.09698481f, 0.12325795f, 0.4347101f, 0.37012324f, 0.7913993f, 0.9993339f, 0.75977063f, 0.36460763f, 0.3775515f, 0.51856863f, 0.95555836f, 0.49067768f, 0.04478922f, 0.71699315f, 0.097812556f, 0.45841676f, 0.773683f, 0.75010455f, 0.42993996f, 0.9079247f, 0.017453227f, 0.44864193f, 0.672689f, 0.28056568f, 0.19584337f, 0.37550166f, 0.8117075f, 0.7120219f, 0.5780687f, 0.44134927f, 0.42259568f, 0.7511653f, 0.5891905f, 0.67056227f, 0.11231151f, 0.6758219f, 0.22908887f, 0.37498733f, 0.41971782f, 0.055803128f, 0.59144944f, 0.9299475f, 0.12942357f, 0.95274854f, 0.32053652f, 0.20608023f, 0.16834818f, 0.57836413f, 0.055714697f, 
0.06392813f, 0.29768264f, 0.09972937f, 0.8983277f, 0.97463375f, 0.1341327f, 0.65210474f, 0.35204768f, 0.014110221f, 0.80327654f, 0.6689872f, 0.9037585f, 0.90981257f, 0.86295295f, 0.3795516f, 0.0062070885f, 0.5173644f, 0.20474744f, 0.86028427f, 0.15545785f, 0.3484738f, 0.48408556f, 0.28058404f, 0.75635433f, 0.5704764f, 0.80539626f, 0.8308685f, 0.7464902f, 0.12689869f, 0.89151156f, 0.37369293f, 0.36895418f, 0.5450234f, 0.1559311f, 0.2432725f, 0.38309494f, 0.27770162f, 0.56394845f, 0.72261786f, 0.5332152f, 0.49045795f, 0.88231075f, 0.6032768f, 0.6665413f, 0.857885f, 0.31463873f, 0.9153665f, 0.37640592f, 0.58912075f, 0.24793272f, 0.7373741f, 0.8440094f, 0.015947558f, 0.58805275f, 0.3667698f, 0.46238968f, 0.8334069f, 0.81946284f, 0.19397281f, 0.92121077f, 0.964989f, 0.24575949f, 0.0900369f, 0.6689977f, 0.23726216f, 0.601819f, 0.16691278f, 0.47163498f, 0.03375374f, 0.36948392f, 0.08575206f, 0.9858967f, 0.7306862f, 0.21772163f, 0.39309397f, 0.7458295f, 0.7629526f, 0.3144869f, 0.94122046f, 0.20584162f, 0.83637947f, 0.7726502f, 0.9049252f, 0.36524808f, 0.7137413f, 0.8284559f, 0.22519512f, 0.30139557f, 0.8169721f, 0.5312386f, 0.8956069f, 0.66213816f, 0.58457166f, 0.45457113f, 0.5169665f, 0.6269637f, 0.26091218f, 0.7560391f, 0.7980105f, 0.3960119f, 0.08781406f, 0.10958682f, 0.12124728f, 0.4373948f, 0.031676244f, 0.55287856f, 0.7805502f, 0.56280786f, 0.25152865f, 0.566051f, 0.7870067f, 0.759523f, 0.45281285f, 0.62631804f, 0.989187f, 0.26606834f, 0.39388546f, 0.87392044f, 0.583776f, 0.654467f, 0.49633527f, 0.39479604f, 0.63170516f, 0.62530655f, 0.9021866f, 0.13965032f, 0.35174674f, 0.79825306f, 0.7204604f, 0.8848764f, 0.43971986f, 0.7367297f, 0.71475625f, 0.07822404f, 0.42548487f, 0.11135407f, 0.80643165f, 0.83326644f, 0.8646103f, 0.89960915f, 0.46280593f, 0.8834037f, 0.2807901f, 0.68196964f, 0.3704893f, 0.4120405f, 0.82667f, 0.02957211f, 0.16348517f, 0.528726f, 0.36919758f, 0.22145572f, 0.43879473f, 0.09656078f, 0.5824419f, 0.0181659f, 0.25570688f, 0.7642685f, 0.19078839f, 0.70748967f, 0.5835414f, 0.92161185f, 0.8213292f, 0.046582457f, 0.85949063f, 0.15103385f, 0.74723977f, 0.39284366f, 0.5726992f, 0.07368804f, 0.3426399f, 0.17463133f, 0.24858418f, 0.31684884f, 0.49405006f, 0.37952894f, 0.33315596f, 0.8640441f, 0.57182634f, 0.25183997f, 0.7026268f, 0.37704948f, 0.17044407f, 0.27955136f, 0.96993434f, 0.09108966f, 0.6897659f, 0.19774762f, 0.6693781f, 0.12952057f, 0.89581305f, 0.21900262f, 0.1147024f, 0.29112664f, 0.06916158f, 0.22942513f, 0.42038745f, 0.7651415f, 0.45440084f, 0.17078096f, 0.07726187f, 0.4274913f, 0.86462736f, 0.06414275f, 0.9592153f, 0.16050456f, 0.88035154f, 0.9545343f, 0.8513476f, 0.2491725f, 0.7261043f, 0.5407395f, 0.22621076f, 0.31755584f, 0.75632083f, 0.7962324f, 0.50990444f, 0.61564916f, 0.76425743f, 0.70222944f, 0.73869663f, 0.29614443f, 0.021682443f, 0.5887306f, 0.31215057f, 0.10243766f, 0.9339864f, 0.23341663f, 0.7255635f, 0.4185125f, 0.5641563f, 0.0210989f, 0.31937757f, 0.77237654f, 0.055116564f, 0.31758264f, 0.35916016f, 0.5235203f, 0.15846917f, 0.5410007f, 0.3291817f, 0.14069794f, 0.90887386f, 0.259237f, 0.93863297f, 0.75447625f, 0.6713672f, 0.5048135f, 0.7174148f, 0.52741486f, 0.92290014f, 0.0805213f, 0.70555705f, 0.8765804f, 0.21684085f, 0.059146658f, 0.52307314f, 0.24510364f, 0.73993003f, 0.081979565f, 0.76904917f, 0.57904243f, 0.4695278f, 0.016590666f, 0.7074726f, 0.03675281f, 0.05884536f, 0.8561499f, 0.7090553f, 0.86932564f, 0.31001756f, 0.7310781f, 0.7902563f, 0.4690628f, 0.5504265f, 0.99635744f, 0.8836126f, 0.49213162f, 0.4428661f, 0.88994193f, 0.35176337f, 
0.4958119f, 0.5913544f, 0.4187957f, 0.27758822f, 0.28339785f, 0.7841562f, 0.30195132f, 0.752634f, 0.3137563f, 0.4315457f, 0.44653264f, 0.5451809f, 0.44049335f, 0.8987003f, 0.5640792f, 0.5874427f, 0.47600824f, 0.5928f, 0.80064255f, 0.20061128f, 0.37571868f, 0.8139443f, 0.62335235f, 0.8047332f, 0.31274527f, 0.30714568f, 0.035397593f, 0.69739f, 0.2944578f, 0.34834376f, 0.5873635f, 0.9606469f, 0.5618423f, 0.6756651f, 0.03466902f, 0.27137738f, 0.59027666f, 0.8357776f, 0.425116f, 0.50365347f, 0.4515947f, 0.4932688f, 0.005631942f, 0.57952595f, 0.47525176f, 0.6249525f, 0.086651884f, 0.89189065f, 0.6617942f, 0.9442606f, 0.27843753f, 0.44292933f, 0.38660362f, 0.07765346f, 0.50435954f, 0.83211386f, 0.9370695f, 0.39374778f, 0.08252517f, 0.20432696f, 0.9130672f, 0.6829529f, 0.4023203f, 0.18018572f, 0.7534347f, 0.42706057f, 0.42672646f, 0.47151735f, 0.22955406f, 0.9152989f, 0.08499177f, 0.21106064f, 0.81278425f, 0.4464995f, 0.9721553f, 0.5701927f, 0.5504968f, 0.33792228f, 0.97337884f, 0.1806469f, 0.09640216f, 0.163271f, 0.42888898f, 0.778335f, 0.8884757f, 0.79867357f, 0.7878421f, 0.07889473f, 0.35902497f, 0.56884366f, 0.4541578f, 0.85038835f, 0.5382435f, 0.09464303f, 0.9107641f, 0.94099534f, 0.5400446f, 266.79602f, 274.32846f, 213.67004f, 233.85674f, 243.74121f, 250.29242f, 241.2762f, 246.10477f, 210.67426f, 209.43724f, 229.85814f, 280.7868f, 272.1595f, 250.896f, 203.6569f, 224.5947f, 228.5461f, 250.31659f, 259.0063f, 207.73958f, 214.5609f, 227.4157f, 288.49915f, 258.5862f, 237.1694f, 260.80396f, 253.53038f, 216.46973f, 200.73683f, 276.59747f, 218.64984f, 277.839f, 211.7889f, 278.14984f, 276.74042f, 224.4895f, 237.72171f, 253.24715f, 202.98746f, 237.59871f, 204.87325f, 239.43521f, 295.81796f, 299.5604f, 222.03635f, 228.79982f, 266.0576f, 239.92245f, 268.24426f, 238.24408f, 298.47308f, 288.47458f, 215.21046f, 248.30959f, 290.8601f, 287.38885f, 209.855f, 220.54123f, 251.46211f, 269.38593f, 215.89407f, 249.74835f, 233.35129f, 259.1078f, 247.44966f, 203.68665f, 295.11304f, 298.9008f, 216.80823f, 265.98523f, 250.68268f, 259.11737f, 224.44098f, 201.49985f, 265.72772f, 291.2741f, 291.02527f, 205.01653f, 225.3552f, 230.4449f, 205.90791f, 236.37225f, 234.94302f, 227.96848f, 293.9239f, 200.43617f, 261.1322f, 246.37569f, 206.33258f, 230.6332f, 275.16974f, 226.53664f, 253.74765f, 201.92174f, 277.2812f, 279.80594f, 269.5651f, 215.83727f, 290.79214f, 209.25894f, 240.69214f, 259.45502f, 221.35303f, 245.88794f, 233.58676f, 278.87738f, 268.62115f, 238.47983f, 288.8792f, 284.89505f, 235.00497f, 242.7936f, 236.64014f, 252.04784f, 205.45514f, 290.40726f, 232.52823f, 259.1132f, 290.73474f, 227.57782f, 216.67067f, 294.74762f, 217.73929f, 209.24208f, 256.90912f, 240.18433f, 257.794f, 282.8988f, 208.77882f, 297.82245f, 299.72125f, 298.86118f, 282.77133f, 299.69577f, 298.43073f, 299.66992f, 206.1796f, 239.80862f, 245.31291f, 207.94046f, 256.93558f, 210.00853f, 297.19482f, 258.61487f, 298.00143f, 247.14326f, 220.11229f, 299.13562f, 289.7299f, 244.51624f, -9.999632f, -9.999593f, -9.999801f, -9.999819f, -9.999018f, -9.999244f, -9.999898f, -9.999155f, -9.999041f, -9.999333f, -9.999995f, -9.999601f, -9.999369f, -9.999678f, -9.99932f, -9.999411f, -9.999675f, -9.999204f, -9.999888f, -9.999743f, -9.999049f, -9.999095f, -9.9994955f, -9.999148f, -9.999902f, -9.999157f, -9.999642f, -9.999242f, -9.999449f, -9.99954f, -9.999594f, -9.999917f, -9.999246f, -9.999855f, -9.999591f, -9.999358f, -9.999842f, -9.999382f, -9.999745f, -9.999809f, -9.999109f, -9.999151f, -9.999462f, -9.999784f, -9.999753f, -9.999547f, -9.999858f, -9.999641f, 
-9.999331f, -9.999973f, -9.999725f, -9.999956f, -9.999523f, -9.999478f, -9.999359f, -9.999043f, -9.999455f, -9.999254f, -9.999494f, -9.999362f, -9.999646f, -9.999454f, -9.999153f, -9.99971f, -9.99948f, -9.999924f, -9.999973f, -9.9990425f, -9.999157f, -9.999034f, -9.999135f, -9.999451f, -9.99927f, -9.999871f, -9.999655f, -9.999354f, -9.999864f, -9.999408f, -9.999447f, -9.999032f, -9.999453f, -9.999718f, -9.999415f, -9.999358f, -9.999691f, -9.99945f, -9.999504f, -9.999244f, -9.999987f, -9.999557f, -9.999052f, -9.999141f, -9.999237f, -9.999049f, -9.99919f, -9.999888f, -9.999757f, -9.999621f, -9.999702f, -9.999411f, -9.999203f, -9.999174f, -9.999015f, -9.999339f, -9.999034f, -9.999728f, -9.99976f, -9.999317f, -9.999367f, -9.999866f, -9.999091f, -9.999755f, -9.999178f, -9.999553f, -9.999263f, -9.999655f, -9.999423f, -9.999304f, -9.999814f, -9.999966f, -9.999977f, -9.9992075f, -9.999666f, -9.999204f, -9.999895f, -9.999059f, -9.99907f, -9.9995575f, -9.999523f, -9.999056f, -9.999571f, -9.999786f, -9.999026f, -9.999145f, -9.999575f, -9.999738f, -9.99979f, -9.999363f, -9.999586f, -9.999727f, -9.999086f, -9.999402f, -9.999158f, -9.999252f, -9.999179f, -9.999597f, -9.999156f, -9.99936f, -9.999807f, -9.999261f, 0.5652288f, 0.9339315f, 0.55770487f, 0.7478212f, 0.33771703f, 0.28125492f, 0.51592994f, 0.5532214f, 0.58044416f, 0.66528046f, 0.669034f, 0.16671883f, 0.67413294f, 0.036051773f, 0.108843535f, 0.7993396f, 0.1639013f, 0.6568752f, 0.122072175f, 0.70342636f, 0.5444655f, 0.5812534f, 0.4522436f, 0.2419f, 0.07067616f, 0.8879451f, 0.60514754f, 0.14282055f, 0.70217454f, 0.10503953f, 0.39604086f, 0.60164565f, 0.5446685f, 0.07094606f, 0.5559759f, 0.014643576f, 0.9885768f, 0.45798954f, 0.80507016f, 0.46793476f, 0.91752577f, 0.04094297f, 0.60369307f, 0.8747373f, 0.5086575f, 0.7004933f, 0.2251465f, 0.35307238f, 0.27597564f, 0.94157344f, 0.65179616f, 0.20595148f, 0.27256346f, 0.20036213f, 0.67921185f, 0.15910614f, 0.52645075f, 0.6180527f, 0.09315563f, 0.4282912f, 0.3796773f, 0.55366653f, 0.8087156f, 0.989089f, 0.81570625f, 0.36953965f, 0.29338685f, 0.8806224f, 0.40907812f, 0.99581677f, 0.031810474f, 0.9831273f, 0.21194534f, 0.6745432f, 0.38136473f, 0.2702163f, 0.6385419f, 0.29438227f, 0.12847719f, 0.27120438f, 0.30660692f, 0.5424479f, 0.92706877f, 0.9079774f, 0.22223541f, 0.3657775f, 0.25447527f, 0.81911993f, 0.30269873f, 0.74017876f, 0.92759985f, 0.70151937f, 0.7640615f, 0.8949204f, 0.79928416f, 0.77783567f, 0.6940916f, 0.2910855f, 0.97654736f, 0.2973309f, 0.5588422f, 0.6462096f, 0.30760437f, 0.18172295f, 0.7695246f, 0.34731266f, 0.19734544f, 0.029608455f, 0.37696892f, 0.111436665f, 0.50183326f, 0.28445065f, 0.68564844f, 0.44779962f, 0.9736052f, 0.51790065f, 0.983022f, 0.52825344f, 0.41285545f, 0.9967343f, 0.6162969f, 0.37753683f, 0.17138597f, 0.07175013f, 0.81368434f, 0.9612253f, 0.9045651f, 0.84745973f, 0.36729226f, 0.98037714f, 0.20115525f, 0.12099608f, 0.96984464f, 0.37242016f, 0.29363927f, 0.39158085f, 0.27558497f, 0.66305256f, 0.10113714f, 0.76193494f, 0.45118755f, 0.4488773f, 0.93012637f, 0.31139725f, 0.0031577414f, 0.22718209f, 0.29718128f, 0.71752393f, 0.14526285f, 0.18364605f, 0.37547293f, 0.9685261f, 0.9378056f, 0.27025697f, 0.8536382f, 0.40919214f, 0.6247997f, 0.020774715f, 0.2789666f, 0.6214883f, 0.28909984f, 0.4459083f, 0.22759606f, 0.16503142f, 0.12913509f, 0.76620036f, 0.31722352f, 0.31122422f, 0.14058389f, 0.3711774f, 0.2540991f, 0.92829734f, 0.31982893f, 0.58990836f, 0.7611616f, 0.94479626f, 0.77106464f, 0.98198724f, 0.045493614f, 0.5808194f, 0.044766188f, 0.028754123f, 0.6398209f, 
0.5149536f, 0.6159741f, 0.38356403f, 0.3443942f, 0.8204024f, 0.16429621f, 0.45349202f, 0.9345274f, 0.6689286f, 0.46520096f, 0.5479114f, 0.50660115f, 0.030693837f, 0.14807424f, 0.0025167174f, 0.04072329f, 0.06662837f, 0.19923986f, 0.31228405f, 0.26450446f, 0.5282875f, 0.32404247f, 0.3938328f, 0.028723368f, 0.53065664f, 0.84379214f, 0.84157664f, 0.37586623f, 0.15792112f, 0.20647834f, 0.024251468f, 0.3573017f, 0.37901312f, 0.6181092f, 0.76309824f, 0.7608666f, 0.3481646f, 0.34048688f, 0.47856995f, 0.31012326f, 0.23520178f, 0.45539266f, 0.92912894f, 0.4204687f, 0.92543155f, 0.5307048f, 0.27608588f, 0.7496653f, 0.6049889f, 0.36525294f, 0.14689086f, 0.51323116f, 0.12193437f, 0.59619224f, 0.60478336f, 0.9294276f, 0.249309f, 0.74476606f, 0.92789376f, 0.043751504f, 0.5309229f, 0.3062958f, 0.31674966f, 0.14777556f, 0.52924913f, 0.9668007f, 0.20873389f, 0.3279674f, 0.7965414f, 0.37618962f, 0.89503884f, 0.46796778f, 0.0799155f, 0.13676843f, 0.99596673f, 0.5959752f, 0.82745814f, 0.19763403f, 0.45169583f, 0.034008075f, 0.51954156f, 0.5263711f, 0.32014525f, 0.053273566f, 0.81357837f, 0.97085255f, 0.07153194f, 0.9582462f, 0.64213526f, 0.32651472f, 0.60837305f, 0.9404863f, 0.06993771f, 0.7587776f, 0.7886673f, 0.41194588f, 0.78207874f, 0.7781359f, 0.3276002f, 0.33506534f, 0.28078383f, 0.12973906f, 0.399713f, 0.62760603f, 0.75171447f, 0.80802286f, 0.5050624f, 0.33723688f, 0.23653711f, 0.22387893f, 0.3570362f, 0.05210913f, 0.8889524f, 0.49352857f, 0.4521699f, 0.9740411f, 0.7144635f, 0.4756838f, 0.331589f, 0.068503655f, 0.97924995f, 0.41867498f, 0.31639704f, 0.7069934f, 0.81501675f, 0.5386601f, 0.4093507f, 0.707298f, 0.9774356f, 0.72752196f, 0.1570271f, 0.9423814f, 0.9732382f, 0.71725017f, 0.3946321f, 0.62860346f, 0.06245658f, 0.90315664f, 0.5143768f, 0.8708286f, 0.84123635f, 0.92691624f, 0.639396f, 0.2552601f, 0.37173754f, 0.7914776f, 0.91429204f, 0.4736561f, 0.15064463f, 0.7540974f, 0.2862515f, 0.48185065f, 0.13227704f, 0.32188603f, 0.63464296f, 0.8106472f, 0.94166034f, 0.17569262f, 0.19304337f, 0.29407963f, 0.587708f, 0.97985137f, 0.93614686f, 0.8405717f, 0.02620014f, 0.35624048f, 0.59463245f, 0.011628275f, 0.66693187f, 0.74045765f, 0.8160365f, 0.84104806f, 0.88261247f, 0.0711487f, 0.8989867f, 0.97475845f, 0.4168518f, 0.13669337f, 0.28926903f, 0.49182004f, 0.41090083f, 0.276433f, 0.09197279f, 0.68734396f, 0.3883402f, 0.90047145f, 0.11048286f, 0.15737055f, 0.21775864f, 0.9536175f, 0.076466806f, 0.24726667f, 0.103641525f, 0.0413075f, 0.27288043f, 0.3405656f, 0.14998767f, 0.51837134f, 0.16329993f, 0.3755023f, 0.9497281f, 0.8958037f, 0.98416775f, 0.34084278f, 0.18396701f, 0.8870497f, 0.11773594f, 0.7778607f, 0.5278507f, 0.9345038f, 0.12104616f, 0.3192234f, 0.026860172f, 0.71437854f, 0.8270822f, 0.34825006f, 0.39791596f, 0.62681943f, 0.27854878f, 0.519083f, 0.9585388f, 0.9732782f, 0.24999642f, 0.18574189f, 0.92319125f, 0.2299785f, 0.78481007f, 0.4593966f, 0.18952563f, 0.4418934f, 0.75275475f, 0.47553676f, 0.47977385f, 0.516905f, 0.6218342f, 0.986334f, 0.6328223f, 0.87600803f, 0.23837951f, 0.29930744f, 0.5477805f, 0.17647119f, 0.3403492f, 0.79772884f, 0.12769036f, 0.8723695f, 0.1560829f, 0.75527936f, 0.41855234f, 0.66972154f, 0.3795148f, 0.75438255f, 0.45185962f, 0.64733654f, 0.83693033f, 0.7853063f, 0.52869916f, 0.44457012f, 0.031068115f, 0.995698f, 0.86542577f, 0.29396066f, 0.3056323f, 0.7761462f, 0.5815433f, 0.4590591f, 0.6379277f, 203.08049f, 242.811f, 200.0787f, 248.54701f, 240.53275f, 206.88977f, 264.96545f, 215.722f, 207.14218f, 248.2029f, 260.38293f, 246.59158f, 255.92654f, 290.20236f, 282.13013f, 
255.587f, 289.51746f, 250.55061f, 256.14774f, 212.82437f, 283.77695f, 234.53087f, 295.53558f, 263.51688f, 262.4394f, 295.93118f, 249.12567f, 230.53714f, 244.58417f, 212.62454f, 222.62276f, 202.04688f, 220.03893f, 219.85342f, 298.00995f, 225.98215f, 237.55687f, 233.73161f, 277.78552f, 292.03333f, 241.16255f, 239.44547f, 269.768f, 208.34856f, 223.83221f, 247.22945f, 220.80157f, 225.7253f, 267.53107f, 219.36331f, 263.37506f, 292.40854f, 238.76868f, 248.44582f, 284.12405f, 266.40955f, 297.5755f, 221.04996f, 205.62082f, 256.34137f, 216.44402f, 236.91107f, 213.73282f, 215.86444f, 256.87595f, 251.31393f, 216.1751f, 265.14798f, 213.08633f, 254.30765f, 244.74179f, 278.06122f, 262.01956f, 248.49234f, 205.56573f, 285.15247f, 291.18823f, 246.23334f, 286.69305f, 297.73892f, 222.13132f, 274.70645f, 272.9896f, 218.96129f, 263.71072f, 289.10516f, 210.93655f, 235.38228f, 240.58383f, 289.90942f, 238.94185f, 276.05884f, 239.10864f, 254.86401f, 282.10757f, 204.39113f, 238.20418f, 291.72028f, 279.3937f, 255.42195f, 223.81288f, 201.32336f, 262.53845f, 218.35716f, 291.38098f, 248.38783f, 276.37997f, 251.07683f, 295.05258f, 210.5348f, 252.41638f, 265.33124f, 294.82996f, 279.9688f, 295.2437f, 275.68787f, 202.7976f, 207.2586f, 262.63266f, 295.0467f, 288.30432f, 231.05023f, 298.57654f, 286.71002f, 222.34149f, 209.956f, 297.5865f, 204.87299f, 243.4733f, 242.39302f, 209.53899f, 221.00655f, 211.91463f, 266.0036f, 223.22115f, 266.37555f, 278.43994f, 214.11813f, 254.79947f, 234.70715f, 294.82663f, 267.89825f, 282.26373f, 285.57803f, 216.04143f, 222.16176f, 264.46344f, 216.57985f, 208.0961f, 251.9738f, -9.999269f, -9.999741f, -9.999561f, -9.999911f, -9.999339f, -9.999749f, -9.999292f, -9.999522f, -9.999454f, -9.9992895f, -9.999531f, -9.99933f, -9.999341f, -9.99938f, -9.999905f, -9.999054f, -9.999979f, -9.999243f, -9.999734f, -9.999235f, -9.999104f, -9.999684f, -9.999259f, -9.999619f, -9.999497f, -9.999474f, -9.999353f, -9.999263f, -9.999088f, -9.999558f, -9.999322f, -9.999186f, -9.9993925f, -9.9999075f, -9.999958f, -9.999795f, -9.999834f, -9.999768f, -9.999121f, -9.999825f, -9.999527f, -9.999656f, -9.999941f, -9.999142f, -9.999984f, -9.999141f, -9.999887f, -9.9990835f, -9.999148f, -9.9991665f, -9.999867f, -9.999421f, -9.999081f, -9.999978f, -9.999075f, -9.999531f, -9.999142f, -9.999553f, -9.999812f, -9.999398f, -9.999295f, -9.9992285f, -9.999865f, -9.999482f, -9.999524f, -9.999773f, -9.999741f, -9.999358f, -9.999916f, -9.999248f, -9.999274f, -9.999893f, -9.999962f, -9.999569f, -9.9997225f, -9.999103f, -9.999036f, -9.999721f, -9.999645f, -9.999536f, -9.999113f, -9.9998455f, -9.999898f, -9.999262f, -9.999967f, -9.999528f, -9.9996195f, -9.999813f, -9.99977f, -9.999597f, -9.999661f, -9.999434f, -9.999925f, -9.999199f, -9.999759f, -9.999627f, -9.999813f, -9.999361f, -9.999325f, -9.999499f, -9.999843f, -9.999769f, -9.999987f, -9.999241f, -9.999264f, -9.999075f, -9.9998665f, -9.99927f, -9.999766f, -9.999045f, -9.999036f, -9.999232f, -9.999256f, -9.999415f, -9.999601f, -9.999707f, -9.999876f, -9.999688f, -9.999064f, -9.999532f, -9.99921f, -9.99905f, -9.999712f, -9.999656f, -9.999218f, -9.999016f, -9.999569f, -9.999398f, -9.999709f, -9.999183f, -9.999058f, -9.999427f, -9.999155f, -9.999367f, -9.999406f, -9.99968f, -9.999578f, -9.999454f, -9.999143f, -9.999611f, -9.999365f, -9.999709f, -9.9992285f, -9.9998255f, -9.999111f, -9.999831f, -9.999511f, -9.999469f, -9.99995f, -9.999711f, 0.5344577f, 0.28066808f, 0.56196564f, 0.5902792f, 0.8473387f, 0.24633567f, 0.92718124f, 0.17364842f, 0.31536132f, 0.22439669f, 0.46772173f, 0.23150134f, 
0.13030241f, 0.7544915f, 0.32698f, 0.59160626f, 0.5460109f, 0.84683007f, 0.23899049f, 0.8182671f, 0.7197824f, 0.8125036f, 0.8256115f, 0.40416914f, 0.66582596f, 0.0867179f, 0.0084044915f, 0.49205506f, 0.721172f, 0.40177187f, 0.29393357f, 0.015860511f, 0.93151456f, 0.4811004f, 0.54983306f, 0.9995074f, 0.27758396f, 0.22854643f, 0.5583765f, 0.6666239f, 0.85158247f, 0.21441942f, 0.6990569f, 0.017201606f, 0.530989f, 0.21839866f, 0.08578203f, 0.10198945f, 0.039713096f, 0.7290501f, 0.6342606f, 0.51234406f, 0.12498403f, 0.25547478f, 0.8394662f, 0.8280061f, 0.81155413f, 0.012060473f, 0.057682104f, 0.7739566f, 0.08708117f, 0.5193988f, 0.8415829f, 0.7520876f, 0.007182941f, 0.7731886f, 0.33688733f, 0.19361727f, 0.84651196f, 0.22044875f, 0.54851544f, 0.6421493f, 0.58298194f, 0.6989305f, 0.4031829f, 0.41380137f, 0.20955233f, 0.47619122f, 0.65416205f, 0.44766036f, 0.7429968f, 0.47871348f, 0.36874366f, 0.76017255f, 0.63620025f, 0.6808348f, 0.8399061f, 0.72613007f, 0.97575134f, 0.4643534f, 0.7247778f, 0.04549828f, 0.5940095f, 0.5128606f, 0.5878437f, 0.46860144f, 0.6618377f, 0.83293724f, 0.26350665f, 0.24366878f, 0.7788333f, 0.74646133f, 0.5429722f, 0.26375026f, 0.3656472f, 0.12205635f, 0.7138406f, 0.7608406f, 0.60281974f, 0.33415812f, 0.16791728f, 0.68858635f, 0.4469567f, 0.04436514f, 0.5672564f, 0.89869404f, 0.6294232f, 0.9793584f, 0.092907295f, 0.51271373f, 0.3846658f, 0.79488826f, 0.30746242f, 0.9191275f, 0.9108379f, 0.78182805f, 0.97138745f, 0.9847524f, 0.8531674f, 0.022702204f, 0.621023f, 0.7043253f, 0.22311302f, 0.6966194f, 0.36192545f, 0.8646154f, 0.94498384f, 0.8819606f, 0.39050183f, 0.66352f, 0.9537454f, 0.9776376f, 0.07475392f, 0.14165574f, 0.9068708f, 0.07851684f, 0.098995164f, 0.4659044f, 0.94835365f, 0.8669782f, 0.47114196f, 0.24303971f, 0.36649755f, 0.38048944f, 0.3541504f, 0.3041829f, 0.04842617f, 0.5725111f, 0.68421566f, 0.18098183f, 0.96466625f, 0.32582006f, 0.47631285f, 0.17308696f, 0.5422008f, 0.43860963f, 0.94000804f, 0.90531296f, 0.24555893f, 0.15075591f, 0.8892247f, 0.80251575f, 0.43217945f, 0.5427292f, 0.58730876f, 0.9010511f, 0.75740033f, 0.16942962f, 0.77507013f, 0.7471421f, 0.18903506f, 0.96626693f, 0.43212372f, 0.9690648f, 0.31306309f, 0.62832534f, 0.7866172f, 0.79370797f, 0.32908842f, 0.5066318f, 0.34556115f, 0.1002444f, 0.90521127f, 0.3832993f, 0.3292787f, 0.9103993f, 0.17307699f, 0.36895168f, 0.7688117f, 0.7769159f, 0.7559714f, 0.7624208f, 0.4072027f, 0.6700012f, 0.10266004f, 0.46105045f, 0.8847699f, 0.3703581f, 0.79471564f, 0.18433845f, 0.26636884f, 0.5759068f, 0.025358567f, 0.6020128f, 0.85619676f, 0.77020776f, 0.8782154f, 0.605358f, 0.82230324f, 0.3943509f, 0.10723012f, 0.23251477f, 0.41980323f, 0.44982743f, 0.3976f, 0.24261324f, 0.09185766f, 0.9083403f, 0.8951799f, 0.93775445f, 0.4116088f, 0.8328249f, 0.060170095f, 0.23731631f, 0.043149915f, 0.8760627f, 0.9832404f, 0.8160704f, 0.35087004f, 0.99301636f, 0.58498734f, 0.31982517f, 0.28746068f, 0.10150419f, 0.64765805f, 0.93925524f, 0.6288832f, 0.5287214f, 0.6787367f, 0.7280878f, 0.8089835f, 0.45152652f, 0.28626585f, 0.37735057f, 0.84606636f, 0.17912877f, 0.1262947f, 0.93639624f, 0.74632484f, 0.10586514f, 0.2034781f, 0.3999192f, 0.6237884f, 0.58933526f, 0.11924875f, 0.16451561f, 0.5822025f, 0.3976624f, 0.9056206f, 0.66830647f, 0.801052f, 0.6321766f, 0.47481045f, 0.6505067f, 0.5119758f, 0.8057609f, 0.059799645f, 0.014172987f, 0.637021f, 0.878043f, 0.19765095f, 0.7158634f, 0.6288858f, 0.41249686f, 0.2579455f, 0.32608235f, 0.153792f, 0.030521471f, 0.5082303f, 0.33682522f, 0.5155604f, 0.8285316f, 0.7492474f, 0.56472075f, 
0.7964325f, 0.8807934f, 0.21563967f, 0.67301345f, 0.32791767f, 0.47523862f};
};

class EuclideanDistanceTest : public testing::Test {
public:
-    float x[16] = {1,2,3,4,5,6,7,8,1,2,3,4,5,6,7,8};
-    float y[16] = {2,3,4,5,6,7,8,9,2,3,4,5,6,7,8,9};
-    float result[9] = {0};
+    float x[16] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f};
+    float y[16] = {2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f};
+    float result[9] = {0.f};
     Nd4jLong shapeBuffer[12] = {4,2,2,2,2,8,4,2,1,0,1,99};
     int dimensionLength = 3;
     int dimension[3] = {1,2,3};
-    float extraVals[2] = {0,0};
+    float extraVals[2] = {0.f, 0.f};
     int opNum = 1;
     std::vector<int> dim = {1, 2, 3};
@@ -91,7 +91,7 @@ TEST_F(StdTest,MultiDimTest) {
     ASSERT_EQ(resultLengthAssertion,shape::length(resultShapeInfo));
     shape::TAD *tad = new shape::TAD;
     tad->init(xShapeInfo,dimensionsForStd,dimensionLength);
-    float none[1] = {0};
+    float none[1] = {0.f};
     tad->createTadOnlyShapeInfo();
     tad->createOffsets();
     int tadElementWiseStride = shape::elementWiseStride(tad->tadOnlyShapeInfo);
@@ -130,7 +130,7 @@ TEST_F(ReduceTest,MatrixTest) {
     ASSERT_EQ(resultLengthAssertion,shape::length(resultShapeInfo));
     shape::TAD *tad = new shape::TAD;
     tad->init(xShapeInfo,dimension,dimensionLength);
-    float none[1] = {0};
+    float none[1] = {0.f};
     tad->createTadOnlyShapeInfo();
     tad->createOffsets();
     auto tadElementWiseStride = shape::elementWiseStride(tad->tadOnlyShapeInfo);

From 8123d9fa9bf662de618ee80e7070e90ce983779a Mon Sep 17 00:00:00 2001
From: Alex Black
Date: Mon, 2 Dec 2019 18:07:54 +1100
Subject: [PATCH 18/30] SameDiff: Add Java-level assertion check/exception (#96)

Signed-off-by: Alex Black
---
 .../samediff/internal/InferenceSession.java | 20 +++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java
index 32a1cc362..4a6a5ce53 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java
@@ -36,6 +36,7 @@ import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
 import org.nd4j.linalg.api.ops.impl.shape.Concat;
 import org.nd4j.linalg.api.ops.impl.shape.Stack;
 import org.nd4j.linalg.api.ops.impl.shape.tensorops.*;
+import org.nd4j.linalg.api.ops.impl.transforms.Assert;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
 import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;
 import org.nd4j.linalg.api.shape.LongShapeDescriptor;
@@ -458,6 +459,25 @@ public class InferenceSession extends AbstractSession {
             INDArray out = mmgr.allocate(false, arr.dataType(), arr.shape());
             out.assign(arr);
             return new INDArray[]{out};
+        } else if (op instanceof Assert) {
+            Assert a = (Assert)op;
+            boolean condition = a.getInputArgument(0).getDouble(0) != 0.0;
+            if(!condition){
+                //Assertion failed
+                String s = "Assertion failed for operation \"" + op.getOwnName() + "\" during execution";
+                if(a.numInputArguments() >= 3) {
+                    INDArray msg = a.getInputArgument(2);
+                    if (msg != null && msg.dataType() == DataType.UTF8) {
+                        s += ": " + msg.getString(0);
+                    }
+                }
+                if(a.numInputArguments() >= 5){
+                    INDArray arr = a.getInputArgument(4);
+                    s += "\n" + arr;
+                }
+                throw new IllegalStateException(s);
+            }
+            return ((Assert) op).outputArguments();
         } else if (op instanceof CustomOp) {
             CustomOp c = (CustomOp) op;
             Nd4j.exec(c);

From 1adc25919c2620575e1e15c4a9337ae43a02ff20 Mon Sep 17 00:00:00 2001
From: Fariz Rahman
Date: Mon, 2 Dec 2019 13:50:23 +0530
Subject: [PATCH 19/30] Python updates (#86)

* python updates
* fix cyclic deps
* konduit updates
* konduit updates
* fix list
* fixes
* sync pyvars test
* setuprun comments
* Version fix, other module test fixes

Signed-off-by: Alex Black

* bug fix using advanced hacking skillzz
---
 .../transforms/transform/ExecutionTest.java   |    8 +-
 .../TestPythonTransformProcess.java           |  157 ++-
 datavec/datavec-python/pom.xml                |   12 +-
 .../java/org/datavec/python/NumpyArray.java   |   31 +-
 .../org/datavec/python/PythonCondition.java   |  123 +-
 .../org/datavec/python/PythonExecutioner.java | 1221 +++++++++++++----
 .../org/datavec/python/PythonTransform.java   |  273 ++--
 .../java/org/datavec/python/PythonUtils.java  |  306 +++++
 .../org/datavec/python/PythonVariables.java   |  511 +++++--
 .../src/main/resources/pythonexec/__init__.py |    0
 .../main/resources/pythonexec/clear_vars.py   |    5 +
 .../main/resources/pythonexec/input_code.py   |    1 +
 .../main/resources/pythonexec/outputcode.py   |   20 +
 .../src/main/resources/pythonexec/patch0.py   |  202 +++
 .../src/main/resources/pythonexec/patch1.py   |  172 +++
 .../main/resources/pythonexec/pythonexec.py   |   20 +
 .../resources/pythonexec/serialize_array.py   |   50 +
 .../python/TestPythonExecutionSandbox.java    |   75 +
 .../datavec/python/TestPythonExecutioner.java |   74 +-
 .../datavec/python/TestPythonSetupAndRun.java |   27 +
 .../datavec/python/TestPythonVariables.java   |  102 ++
 .../java/org/datavec/python/TestSerde.java    |    9 +-
 .../spark/transform/ExecutionTest.java        |   14 +-
 pom.xml                                       |    2 +
 24 files changed, 2724 insertions(+), 691 deletions(-)
 rename datavec/{datavec-python/src/test/java/org/datavec/python => datavec-local/src/test/java/org/datavec/local/transforms/transform}/TestPythonTransformProcess.java (64%)
 create mode 100644 datavec/datavec-python/src/main/java/org/datavec/python/PythonUtils.java
 create mode 100644 datavec/datavec-python/src/main/resources/pythonexec/__init__.py
 create mode 100644 datavec/datavec-python/src/main/resources/pythonexec/clear_vars.py
 create mode 100644 datavec/datavec-python/src/main/resources/pythonexec/input_code.py
 create mode 100644 datavec/datavec-python/src/main/resources/pythonexec/outputcode.py
 create mode 100644 datavec/datavec-python/src/main/resources/pythonexec/patch0.py
 create mode 100644 datavec/datavec-python/src/main/resources/pythonexec/patch1.py
 create mode 100644 datavec/datavec-python/src/main/resources/pythonexec/pythonexec.py
 create mode 100644 datavec/datavec-python/src/main/resources/pythonexec/serialize_array.py
 create mode 100644 datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutionSandbox.java
 create mode 100644 datavec/datavec-python/src/test/java/org/datavec/python/TestPythonSetupAndRun.java
 create mode 100644 datavec/datavec-python/src/test/java/org/datavec/python/TestPythonVariables.java

diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java
index 2f508f09e..19733f297 100644
--- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java
+++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java
@@ -256,11 +256,9 @@ public class ExecutionTest {
         TransformProcess transformProcess = new TransformProcess.Builder(schema)
                 .transform(
-                        new PythonTransform(
-                                "first = np.sin(first)\nsecond = np.cos(second)",
-                                schema
-                        )
-                )
+                        PythonTransform.builder().code(
+                                "first = np.sin(first)\nsecond = np.cos(second)")
+                                .outputSchema(schema).build())
                 .build();
 
         List<List<Writable>> functions = new ArrayList<>();

diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonTransformProcess.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java
similarity index 64%
rename from datavec/datavec-python/src/test/java/org/datavec/python/TestPythonTransformProcess.java
rename to datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java
index 77ba53e26..37df8ae52 100644
--- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonTransformProcess.java
+++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java
@@ -14,35 +14,40 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/
 
-package org.datavec.python;
+package org.datavec.local.transforms.transform;
 
 import org.datavec.api.transform.TransformProcess;
 import org.datavec.api.transform.condition.Condition;
 import org.datavec.api.transform.filter.ConditionFilter;
 import org.datavec.api.transform.filter.Filter;
-import org.datavec.api.writable.*;
 import org.datavec.api.transform.schema.Schema;
-import org.junit.Ignore;
+import org.datavec.local.transforms.LocalTransformExecutor;
+
+import org.datavec.api.writable.*;
+import org.datavec.python.PythonCondition;
+import org.datavec.python.PythonTransform;
 import org.junit.Test;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
-
+import javax.annotation.concurrent.NotThreadSafe;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 
-import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
-@Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
+import static junit.framework.TestCase.assertTrue;
+import static org.datavec.api.transform.schema.Schema.Builder;
+import static org.junit.Assert.*;
+
+@NotThreadSafe
 public class TestPythonTransformProcess {
 
-    @Test(timeout = 60000L)
+
+    @Test()
     public void testStringConcat() throws Exception{
-        Schema.Builder schemaBuilder = new Schema.Builder();
+        Builder schemaBuilder = new Builder();
         schemaBuilder
                 .addColumnString("col1")
                 .addColumnString("col2");
@@ -54,10 +59,12 @@ public class TestPythonTransformProcess {
         String pythonCode = "col3 = col1 + col2";
         TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
-                new PythonTransform(pythonCode, finalSchema)
+                PythonTransform.builder().code(pythonCode)
+                        .outputSchema(finalSchema)
+                        .build()
         ).build();
 
-        List<Writable> inputs = Arrays.asList((Writable) new Text("Hello "), new Text("World!"));
+        List<Writable> inputs = Arrays.asList((Writable)new Text("Hello "), new Text("World!"));
 
         List<Writable> outputs = tp.execute(inputs);
         assertEquals((outputs.get(0)).toString(), "Hello ");
@@ -68,7 +75,7 @@ public class TestPythonTransformProcess {
 
     @Test(timeout = 60000L)
     public void testMixedTypes() throws Exception{
-        Schema.Builder schemaBuilder = new Schema.Builder();
+        Builder schemaBuilder = new Builder();
         schemaBuilder
                 .addColumnInteger("col1")
                 .addColumnFloat("col2")
@@ -83,11 +90,12 @@
         String pythonCode = "col5 = (int(col3) + col1 + int(col2)) * int(col4)";
         TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
-                new PythonTransform(pythonCode, finalSchema)
-        ).build();
+                PythonTransform.builder().code(pythonCode)
+                        .outputSchema(finalSchema)
+                        .inputSchema(initialSchema)
+                        .build() ).build();
 
-        List<Writable> inputs = Arrays.asList((Writable)
-                new IntWritable(10),
+        List<Writable> inputs = Arrays.asList((Writable)new IntWritable(10),
                 new FloatWritable(3.5f),
                 new Text("5"),
                 new DoubleWritable(2.0)
@@ -105,7 +113,7 @@
 
         INDArray expectedOutput = arr1.add(arr2);
 
-        Schema.Builder schemaBuilder = new Schema.Builder();
+        Builder schemaBuilder = new Builder();
         schemaBuilder
                 .addColumnNDArray("col1", shape)
                 .addColumnNDArray("col2", shape);
@@ -116,12 +124,14 @@
         String pythonCode = "col3 = col1 + col2";
         TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
-                new PythonTransform(pythonCode, finalSchema)
-        ).build();
+                PythonTransform.builder().code(pythonCode)
+                        .outputSchema(finalSchema)
+                        .build() ).build();
 
         List<Writable> inputs = Arrays.asList(
-                (Writable) new NDArrayWritable(arr1),
-                new NDArrayWritable(arr2)
+                (Writable)
+                        new NDArrayWritable(arr1),
+                        new NDArrayWritable(arr2)
         );
 
         List<Writable> outputs = tp.execute(inputs);
@@ -139,7 +149,7 @@
 
         INDArray expectedOutput = arr1.add(arr2);
 
-        Schema.Builder schemaBuilder = new Schema.Builder();
+        Builder schemaBuilder = new Builder();
         schemaBuilder
                 .addColumnNDArray("col1", shape)
                 .addColumnNDArray("col2", shape);
@@ -150,11 +160,13 @@
         String pythonCode = "col3 = col1 + col2";
         TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
-                new PythonTransform(pythonCode, finalSchema)
-        ).build();
+                PythonTransform.builder().code(pythonCode)
+                        .outputSchema(finalSchema)
+                        .build() ).build();
 
         List<Writable> inputs = Arrays.asList(
-                (Writable) new NDArrayWritable(arr1),
+                (Writable)
+                        new NDArrayWritable(arr1),
                 new NDArrayWritable(arr2)
         );
 
@@ -172,7 +184,7 @@
         INDArray arr2 = Nd4j.rand(DataType.DOUBLE, shape);
         INDArray expectedOutput = arr1.add(arr2.castTo(DataType.DOUBLE));
 
-        Schema.Builder schemaBuilder = new Schema.Builder();
+        Builder schemaBuilder = new Builder();
         schemaBuilder
                 .addColumnNDArray("col1", shape)
                 .addColumnNDArray("col2", shape);
@@ -183,11 +195,14 @@
         String pythonCode = "col3 = col1 + col2";
         TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
-                new PythonTransform(pythonCode, finalSchema)
+                PythonTransform.builder().code(pythonCode)
+                        .outputSchema(finalSchema)
+                        .build()
         ).build();
 
         List<Writable> inputs = Arrays.asList(
-                (Writable) new NDArrayWritable(arr1),
+                (Writable)
+                        new NDArrayWritable(arr1),
                 new NDArrayWritable(arr2)
         );
 
@@ -199,8 +214,8 @@
     }
 
     @Test(timeout = 60000L)
-    public void testPythonFilter(){
-        Schema schema = new Schema.Builder().addColumnInteger("column").build();
+    public void testPythonFilter() {
+        Schema schema = new Builder().addColumnInteger("column").build();
 
         Condition condition = new PythonCondition(
                 "f = lambda: column < 0"
         );
 
         condition.setInputSchema(schema);
 
         Filter filter = new ConditionFilter(condition);
 
-        assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(10))));
-        assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(1))));
-        assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(0))));
-        assertTrue(filter.removeExample(Collections.singletonList((Writable) new IntWritable(-1))));
-        assertTrue(filter.removeExample(Collections.singletonList((Writable) new IntWritable(-10))));
+        assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(10))));
+        assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(1))));
+        assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(0))));
+        assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-1))));
+        assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-10))));
     }
 
     @Test(timeout = 60000L)
     public void testPythonFilterAndTransform() throws Exception{
-        Schema.Builder schemaBuilder = new Schema.Builder();
+        Builder schemaBuilder = new Builder();
         schemaBuilder
                 .addColumnInteger("col1")
                 .addColumnFloat("col2")
@@ -241,33 +256,85 @@
         String pythonCode = "col6 = str(col1 + col2)";
         TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
-                new PythonTransform(
-                        pythonCode,
-                        finalSchema
-                )
+                PythonTransform.builder().code(pythonCode)
+                        .outputSchema(finalSchema)
+                        .build()
         ).filter(
                 filter
         ).build();
 
         List<List<Writable>> inputs = new ArrayList<>();
         inputs.add(
-                Arrays.asList((Writable) new IntWritable(5),
+                Arrays.asList(
+                        (Writable)
+                                new IntWritable(5),
                         new FloatWritable(3.0f),
                         new Text("abcd"),
                         new DoubleWritable(2.1))
         );
         inputs.add(
-                Arrays.asList((Writable) new IntWritable(-3),
+                Arrays.asList(
+                        (Writable)
+                                new IntWritable(-3),
                         new FloatWritable(3.0f),
                         new Text("abcd"),
                         new DoubleWritable(2.1))
         );
         inputs.add(
-                Arrays.asList((Writable) new IntWritable(5),
+                Arrays.asList(
+                        (Writable)
+                                new IntWritable(5),
                         new FloatWritable(11.2f),
                         new Text("abcd"),
                         new DoubleWritable(2.1))
         );
 
+        LocalTransformExecutor.execute(inputs,tp);
     }
-}
+
+
+    @Test
+    public void testPythonTransformNoOutputSpecified() throws Exception {
+        PythonTransform pythonTransform = PythonTransform.builder()
+                .code("a += 2; b = 'hello world'")
+                .returnAllInputs(true)
+                .build();
+        List<List<Writable>> inputs = new ArrayList<>();
+        inputs.add(Arrays.asList((Writable)new IntWritable(1)));
+        Schema inputSchema = new Builder()
+                .addColumnInteger("a")
+                .build();
+
+        TransformProcess tp = new TransformProcess.Builder(inputSchema)
+                .transform(pythonTransform)
+                .build();
+        List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
+        assertEquals(3,execute.get(0).get(0).toInt());
+        assertEquals("hello world",execute.get(0).get(1).toString());
+
+    }
+
+    @Test
+    public void testNumpyTransform() throws Exception {
+        PythonTransform pythonTransform = PythonTransform.builder()
+                .code("a += 2; b = 'hello world'")
+                .returnAllInputs(true)
+                .build();
+
+        List<List<Writable>> inputs = new ArrayList<>();
+        inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1))));
+        Schema inputSchema = new Builder()
+                .addColumnNDArray("a",new long[]{1,1})
+                .build();
+
+        TransformProcess tp = new TransformProcess.Builder(inputSchema)
+                .transform(pythonTransform)
+                .build();
+        List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
+        assertFalse(execute.isEmpty());
+        assertNotNull(execute.get(0));
+        assertNotNull(execute.get(0).get(0));
+        assertEquals("hello world",execute.get(0).get(0).toString());
world",execute.get(0).get(0).toString()); + } + +} \ No newline at end of file diff --git a/datavec/datavec-python/pom.xml b/datavec/datavec-python/pom.xml index 449364207..55cf6c5da 100644 --- a/datavec/datavec-python/pom.xml +++ b/datavec/datavec-python/pom.xml @@ -28,15 +28,21 @@ - com.googlecode.json-simple - json-simple - 1.1 + org.json + json + 20190722 org.bytedeco cpython-platform ${cpython-platform.version} + + org.bytedeco + numpy-platform + ${numpy.javacpp.version} + + com.google.code.findbugs jsr305 diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java b/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java index a6ccc3036..ab49cf5ea 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java @@ -16,10 +16,13 @@ package org.datavec.python; +import lombok.Builder; import lombok.Getter; +import lombok.NoArgsConstructor; import org.bytedeco.javacpp.Pointer; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOpsHolder; @@ -33,19 +36,27 @@ import org.nd4j.linalg.api.buffer.DataType; * @author Fariz Rahman */ @Getter +@NoArgsConstructor public class NumpyArray { - private static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); + private static NativeOps nativeOps; private long address; private long[] shape; private long[] strides; - private DataType dtype = DataType.FLOAT; + private DataType dtype; private INDArray nd4jArray; + static { + //initialize + Nd4j.scalar(1.0); + nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); + } - public NumpyArray(long address, long[] shape, long strides[], boolean copy){ + @Builder + public NumpyArray(long address, long[] shape, long strides[], boolean copy,DataType dtype) { this.address = address; this.shape = shape; this.strides = strides; + this.dtype = dtype; setND4JArray(); if (copy){ nd4jArray = nd4jArray.dup(); @@ -57,8 +68,9 @@ public class NumpyArray { public NumpyArray copy(){ return new NumpyArray(nd4jArray.dup()); } + public NumpyArray(long address, long[] shape, long strides[]){ - this(address, shape, strides, false); + this(address, shape, strides, false,DataType.FLOAT); } public NumpyArray(long address, long[] shape, long strides[], DataType dtype){ @@ -77,9 +89,9 @@ public class NumpyArray { } } - private void setND4JArray(){ + private void setND4JArray() { long size = 1; - for(long d: shape){ + for(long d: shape) { size *= d; } Pointer ptr = nativeOps.pointerForAddress(address); @@ -88,10 +100,11 @@ public class NumpyArray { DataBuffer buff = Nd4j.createBuffer(ptr, size, dtype); int elemSize = buff.getElementSize(); long[] nd4jStrides = new long[strides.length]; - for (int i=0; i= 1,"Python code must not be empty!"); code = pythonCode; } - private PythonVariables schemaToPythonVariables(Schema schema) throws Exception{ - PythonVariables pyVars = new PythonVariables(); - int numCols = schema.numColumns(); - for (int i=0; i writables){ - PythonVariables ret = new PythonVariables(); - for (String name: pyInputs.getVariables()){ - int colIdx = inputSchema.getIndexOfColumn(name); - Writable w = writables.get(colIdx); - PythonVariables.Type pyType = pyInputs.getType(name); - switch (pyType){ - case INT: - if (w instanceof LongWritable){ - ret.addInt(name, 
-                    }
-                    else{
-                        ret.addInt(name, ((IntWritable)w).get());
-                    }
-                    break;
-                case FLOAT:
-                    ret.addFloat(name, ((DoubleWritable)w).get());
-                    break;
-                case STR:
-                    ret.addStr(name, ((Text)w).toString());
-                    break;
-                case NDARRAY:
-                    ret.addNDArray(name,((NDArrayWritable)w).get());
-                    break;
-            }
-
-        }
-        return ret;
-    }
 
     @Override
-    public void setInputSchema(Schema inputSchema){
+    public void setInputSchema(Schema inputSchema) {
         this.inputSchema = inputSchema;
         try{
             pyInputs = schemaToPythonVariables(inputSchema);
             PythonVariables pyOuts = new PythonVariables();
             pyOuts.addInt("out");
-            pythonTransform = new PythonTransform(
-                    code + "\n\nout=f()\nout=0 if out is None else int(out)", // TODO: remove int conversion after boolean support is covered
-                    pyInputs,
-                    pyOuts
-            );
+            pythonTransform = PythonTransform.builder()
+                    .code(code + "\n\nout=f()\nout=0 if out is None else int(out)")
+                    .inputs(pyInputs)
+                    .outputs(pyOuts)
+                    .build();
+
         }
         catch (Exception e){
             throw new RuntimeException(e);
@@ -127,41 +76,47 @@
         return inputSchema;
     }
 
-    public String[] outputColumnNames(){
+    @Override
+    public String[] outputColumnNames() {
         String[] columnNames = new String[inputSchema.numColumns()];
         inputSchema.getColumnNames().toArray(columnNames);
         return columnNames;
     }
 
+    @Override
     public String outputColumnName(){
         return outputColumnNames()[0];
     }
 
+    @Override
     public String[] columnNames(){
         return outputColumnNames();
     }
 
+    @Override
     public String columnName(){
         return outputColumnName();
     }
 
+    @Override
     public Schema transform(Schema inputSchema){
         return inputSchema;
     }
 
-    public boolean condition(List<Writable> list){
+    @Override
+    public boolean condition(List<Writable> list) {
         PythonVariables inputs = getPyInputsFromWritables(list);
         try{
             PythonExecutioner.exec(pythonTransform.getCode(), inputs, pythonTransform.getOutputs());
             boolean ret = pythonTransform.getOutputs().getIntValue("out") != 0;
             return ret;
         }
-        catch (Exception e){
+        catch (Exception e) {
             throw new RuntimeException(e);
         }
-
     }
 
+    @Override
     public boolean condition(Object input){
         return condition(input);
     }
@@ -177,5 +132,37 @@
         throw new UnsupportedOperationException("not supported");
     }
 
+    private PythonVariables getPyInputsFromWritables(List<Writable> writables) {
+        PythonVariables ret = new PythonVariables();
 
-}
+        for (int i = 0; i < inputSchema.numColumns(); i++){
+            String name = inputSchema.getName(i);
+            Writable w = writables.get(i);
+            PythonVariables.Type pyType = pyInputs.getType(inputSchema.getName(i));
+            switch (pyType){
+                case INT:
+                    if (w instanceof LongWritable) {
+                        ret.addInt(name, ((LongWritable)w).get());
+                    }
+                    else {
+                        ret.addInt(name, ((IntWritable)w).get());
+                    }
+
+                    break;
+                case FLOAT:
+                    ret.addFloat(name, ((DoubleWritable)w).get());
+                    break;
+                case STR:
+                    ret.addStr(name, w.toString());
+                    break;
+                case NDARRAY:
+                    ret.addNDArray(name,((NDArrayWritable)w).get());
+                    break;
+            }
+        }
+
+        return ret;
+    }
+
+
+}
\ No newline at end of file

diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java
index c46d0d710..c6272e7ad 100644
--- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java
+++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java
@@ -17,132 +17,504 @@
 package org.datavec.python;
 
-import java.io.File;
-import java.io.FileInputStream;
-import java.util.HashMap;
+import java.io.*;
+import java.nio.charset.Charset;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Map;
-import java.util.regex.Pattern;
+
 import lombok.extern.slf4j.Slf4j;
-import org.json.simple.JSONArray;
-import org.json.simple.JSONObject;
-import org.json.simple.parser.JSONParser;
+import org.apache.commons.io.FileUtils;
+import org.apache.commons.io.IOUtils;
+import org.json.JSONObject;
+import org.json.JSONArray;
 import org.bytedeco.javacpp.*;
 import org.bytedeco.cpython.*;
 import static org.bytedeco.cpython.global.python.*;
+import org.bytedeco.numpy.global.numpy;
+
+import static org.datavec.python.PythonUtils.*;
+
+import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.io.ClassPathResource;
+
 /**
- * Python executioner
+ * Allows execution of python scripts managed by
+ * an internal interpreter.
+ * An end user may specify a python script to run
+ * via any of the execution methods available in this class.
+ *
+ * At static initialization time (when the class is first initialized)
+ * a number of components are set up:
+ * 1. The python path. A user may override this with the system property {@link #DEFAULT_PYTHON_PATH_PROPERTY}
+ *
+ * 2. Since this executioner uses javacpp to manage and run python interpreters underneath the covers,
+ * a user may also override the system property {@link #JAVACPP_PYTHON_APPEND_TYPE} with one of the {@link JavaCppPathType}
+ * values. This lets the user determine whether the javacpp default python path is used at all, and if so,
+ * whether it is appended, prepended, or not used. This is useful when you need to use an external
+ * python distribution such as anaconda.
+ *
+ * 3. A main interpreter: this is the default interpreter to be used with the main thread.
+ * We may initialize one or more further interpreters relative to the thread invoking the python code.
+ *
+ * 4. A proper numpy import for use with javacpp: we run the numpy import ourselves to ensure the
+ * native libraries numpy needs are loaded in the proper order. If we don't do this,
+ * it causes a variety of issues with running numpy.
+ *
+ * 5. Various python scripts predefined on the classpath, shipped right alongside the java code.
+ * These are auxiliary python scripts used for loading classes and predefining certain kinds of behavior,
+ * so that we can manipulate values within python memory and pull them back out of memory
+ * for integration with the internal python executioner. You can see this behavior in {@link #_readOutputs(PythonVariables)}
+ * as an example.
+ *
+ * For more information on how this works, please take a look at the {@link #init()}
+ * method.
+ *
+ * Generally, a user defining a python script for use by the python executioner
+ * will have a set of defined target input values and output values.
+ * These values should not be present when actually running the script, but just referenced.
+ * In order to test your python script for execution outside the engine,
+ * we recommend commenting out a few default values as dummy input values.
+ * This will allow an end user to test their script before trying to use the server.
+ *
+ * In order to get output values out of a python script, all a user has to do
+ * is define the output variables they want used in the final output of the actual pipeline.
+ * For example, if a user wants to return a dictionary, they just have to create a dictionary with that name,
+ * and based on the configured {@link PythonVariables} passed as outputs
+ * to one of the execution methods, we can pull the values out automatically.
+ * + * For input definitions, it is similar. You just define the values you want used in + * {@link PythonVariables} and we will automatically generate code for defining those values + * as desired for running. This allows the user to customize values dynamically + * at runtime but reference them by name in a python script. + * + * + * @author Fariz Rahman + * @author Adam Gibson */ @Slf4j public class PythonExecutioner { - private static PyObject module; - private static PyObject globals; - private static JSONParser parser = new JSONParser(); - private static Map gilStates = new HashMap<>(); + + private final static String fileVarName = "_f" + Nd4j.getRandom().nextInt(); + private static boolean init; + public final static String DEFAULT_PYTHON_PATH_PROPERTY = "org.datavec.python.path"; + public final static String JAVACPP_PYTHON_APPEND_TYPE = "org.datavec.python.javacpp.path.append"; + public final static String DEFAULT_APPEND_TYPE = "before"; + private static Map interpreters = new java.util.concurrent.ConcurrentHashMap<>(); + private static PyThreadState currentThreadState; + private static PyThreadState mainThreadState; + public final static String ALL_VARIABLES_KEY = "allVariables"; + public final static String MAIN_INTERPRETER_NAME = "main"; + private static String clearVarsCode; + + private static String currentInterpreter = MAIN_INTERPRETER_NAME; + + /** + * One of a few desired values + * for how we should handle + * using javacpp's python path. + * BEFORE: Prepend the python path alongside a defined one + * AFTER: Append the javacpp python path alongside the defined one + * NONE: Don't use javacpp's python path at all + */ + public enum JavaCppPathType { + BEFORE,AFTER,NONE + } + + /** + * Set the python path. + * Generally you can just use the PYTHONPATH environment variable, + * but if you need to set it from code, this can work as well. + */ + public static synchronized void setPythonPath() { + if(!init) { + try { + String path = System.getProperty(DEFAULT_PYTHON_PATH_PROPERTY); + if(path == null) { + log.info("Setting python default path"); + File[] packages = numpy.cachePackages(); + Py_SetPath(packages); + } + else { + log.info("Setting python path " + path); + StringBuffer sb = new StringBuffer(); + File[] packages = numpy.cachePackages(); + + JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE,DEFAULT_APPEND_TYPE).toUpperCase()); + switch(pathAppendValue) { + case BEFORE: + for(File cacheDir : packages) { + sb.append(cacheDir); + sb.append(java.io.File.pathSeparator); + } + + sb.append(path); + + log.info("Prepending javacpp python path " + sb.toString()); + break; + case AFTER: + sb.append(path); + + for(File cacheDir : packages) { + sb.append(cacheDir); + sb.append(java.io.File.pathSeparator); + } + + log.info("Appending javacpp python path " + sb.toString()); + break; + case NONE: + log.info("Not appending javacpp path"); + sb.append(path); + break; + } + + //prepend the javacpp packages + log.info("Final python path " + sb.toString()); + + Py_SetPath(sb.toString()); + } + } catch (IOException e) { + log.error("Failed to set python path.", e); + } + } + else { + throw new IllegalStateException("Unable to reset python path. 
Already initialized."); + } + } + + /** + * Initialize the name space and the python execution + * Calling this method more than once will be a no op + */ + public static synchronized void init() { + if(init) { + return; + } + + try(InputStream is = new org.nd4j.linalg.io.ClassPathResource("pythonexec/clear_vars.py").getInputStream()) { + clearVarsCode = IOUtils.toString(new java.io.InputStreamReader(is)); + } catch (java.io.IOException e) { + throw new IllegalStateException("Unable to read pythonexec/clear_vars.py"); + } + + log.info("CPython: PyEval_InitThreads()"); + PyEval_InitThreads(); + log.info("CPython: Py_InitializeEx()"); + Py_InitializeEx(0); + log.info("CPython: PyGILState_Release()"); + init = true; + interpreters.put(MAIN_INTERPRETER_NAME, PyThreadState_Get()); + numpy._import_array(); + applyPatches(); + } + + + /** + * Run {@link #resetInterpreter(String)} + * on all interpreters. + */ + public static void resetAllInterpreters() { + for(String interpreter : interpreters.keySet()) { + resetInterpreter(interpreter); + } + } + + /** + * Reset the main interpreter. + * For more information see {@link #resetInterpreter(String)} + */ + public static void resetMainInterpreter() { + resetInterpreter(MAIN_INTERPRETER_NAME); + } + + /** + * Reset the interpreter with the given name. + * Runs pythonexec/clear_vars.py + * For more information see: + * https://stackoverflow.com/questions/3543833/how-do-i-clear-all-variables-in-the-middle-of-a-python-script + * @param interpreterName the interpreter name to + * reset + */ + public static synchronized void resetInterpreter(String interpreterName) { + Preconditions.checkState(hasInterpreter(interpreterName)); + log.info("Resetting interpreter " + interpreterName); + String oldInterpreter = currentInterpreter; + setInterpreter(interpreterName); + exec("pass"); + //exec(interpreterName); // ?? + setInterpreter(oldInterpreter); + } + + /** + * Clear the non main intrepreters. + */ + public static void clearNonMainInterpreters() { + for(String key : interpreters.keySet()) { + if(!key.equals(MAIN_INTERPRETER_NAME)) { + deleteInterpreter(key); + } + } + } + + public static PythonVariables defaultPythonVariableOutput() { + PythonVariables ret = new PythonVariables(); + ret.add(ALL_VARIABLES_KEY, PythonVariables.Type.DICT); + return ret; + } + + /** + * Return the python path being used. + * @return a string specifying the python path in use + */ + public static String getPythonPath() { + return new BytePointer(Py_GetPath()).getString(); + } + static { + setPythonPath(); init(); } - public static void init(){ - log.info("CPython: Py_InitializeEx()"); - Py_InitializeEx(1); - log.info("CPython: PyEval_InitThreads()"); - PyEval_InitThreads(); - log.info("CPython: PyImport_AddModule()"); - module = PyImport_AddModule("__main__"); - log.info("CPython: PyModule_GetDict()"); - globals = PyModule_GetDict(module); - log.info("CPython: PyThreadState_Get()"); + + /* ---------sub-interpreter and gil management-----------*/ + public static void setInterpreter(String interpreterName) { + if (!hasInterpreter(interpreterName)){ + PyThreadState main = PyThreadState_Get(); + PyThreadState ts = Py_NewInterpreter(); + + interpreters.put(interpreterName, ts); + PyThreadState_Swap(main); + } + + currentInterpreter = interpreterName; + } + + /** + * Returns the current interpreter. 
+ * @return + */ + public static String getInterpreter() { + return currentInterpreter; + } + + + public static boolean hasInterpreter(String interpreterName){ + return interpreters.containsKey(interpreterName); + } + + public static void deleteInterpreter(String interpreterName) { + if (interpreterName.equals("main")){ + throw new IllegalArgumentException("Can not delete main interpreter"); + } + + Py_EndInterpreter(interpreters.remove(interpreterName)); + } + + private static synchronized void acquireGIL() { + log.info("acquireGIL()"); + log.info("CPython: PyEval_SaveThread()"); + mainThreadState = PyEval_SaveThread(); + log.info("CPython: PyThreadState_New()"); + currentThreadState = PyThreadState_New(interpreters.get(currentInterpreter).interp()); + log.info("CPython: PyEval_RestoreThread()"); + PyEval_RestoreThread(currentThreadState); + log.info("CPython: PyThreadState_Swap()"); + PyThreadState_Swap(currentThreadState); + + } + + private static synchronized void releaseGIL() { + log.info("CPython: PyEval_SaveThread()"); PyEval_SaveThread(); + log.info("CPython: PyEval_RestoreThread()"); + PyEval_RestoreThread(mainThreadState); } - public static void free(){ - Py_Finalize(); + /* -------------------*/ + /** + * Print the python version to standard out. + */ + public static void printPythonVersion() { + exec("import sys; print(sys.version) sys.stdout.flush();"); } - private static String inputCode(PythonVariables pyInputs)throws Exception{ - String inputCode = "loc={};"; + + + private static String inputCode(PythonVariables pyInputs)throws Exception { + String inputCode = ""; if (pyInputs == null){ return inputCode; } + Map strInputs = pyInputs.getStrVariables(); Map intInputs = pyInputs.getIntVariables(); Map floatInputs = pyInputs.getFloatVariables(); - Map ndInputs = pyInputs.getNDArrayVariables(); + Map ndInputs = pyInputs.getNdVars(); Map listInputs = pyInputs.getListVariables(); Map fileInputs = pyInputs.getFileVariables(); + Map> dictInputs = pyInputs.getDictVariables(); - String[] VarNames; + String[] varNames; - VarNames = strInputs.keySet().toArray(new String[strInputs.size()]); - for(Object varName: VarNames){ + varNames = strInputs.keySet().toArray(new String[strInputs.size()]); + for(String varName: varNames) { + Preconditions.checkNotNull(varName,"Var name is null!"); + Preconditions.checkNotNull(varName.isEmpty(),"Var name can not be empty!"); String varValue = strInputs.get(varName); - inputCode += varName + " = \"\"\"" + escapeStr(varValue) + "\"\"\"\n"; - inputCode += "loc['" + varName + "']=" + varName + "\n"; + //inputCode += varName + "= {}\n"; + if(varValue != null) + inputCode += varName + " = \"\"\"" + escapeStr(varValue) + "\"\"\"\n"; + else { + inputCode += varName + " = ''\n"; + } } - VarNames = intInputs.keySet().toArray(new String[intInputs.size()]); - for(String varName: VarNames){ + varNames = intInputs.keySet().toArray(new String[intInputs.size()]); + for(String varName: varNames) { Long varValue = intInputs.get(varName); - inputCode += varName + " = " + varValue.toString() + "\n"; - inputCode += "loc['" + varName + "']=" + varName + "\n"; + if(varValue != null) + inputCode += varName + " = " + varValue.toString() + "\n"; + else { + inputCode += " = 0\n"; + } } - VarNames = floatInputs.keySet().toArray(new String[floatInputs.size()]); - for(String varName: VarNames){ + varNames = dictInputs.keySet().toArray(new String[dictInputs.size()]); + for(String varName: varNames) { + Map varValue = dictInputs.get(varName); + if(varValue != null) { + throw new 
IllegalArgumentException("Unable to generate input code for dictionaries."); + } + else { + inputCode += " = {}\n"; + } + } + + varNames = floatInputs.keySet().toArray(new String[floatInputs.size()]); + for(String varName: varNames){ Double varValue = floatInputs.get(varName); - inputCode += varName + " = " + varValue.toString() + "\n"; - inputCode += "loc['" + varName + "']=" + varName + "\n"; + if(varValue != null) + inputCode += varName + " = " + varValue.toString() + "\n"; + else { + inputCode += varName + " = 0.0\n"; + } } - VarNames = listInputs.keySet().toArray(new String[listInputs.size()]); - for (String varName: VarNames){ + varNames = listInputs.keySet().toArray(new String[listInputs.size()]); + for (String varName: varNames) { Object[] varValue = listInputs.get(varName); - String listStr = jArrayToPyString(varValue); - inputCode += varName + " = " + listStr + "\n"; - inputCode += "loc['" + varName + "']=" + varName + "\n"; + if(varValue != null) { + String listStr = jArrayToPyString(varValue); + inputCode += varName + " = " + listStr + "\n"; + } + else { + inputCode += varName + " = []\n"; + } + } - VarNames = fileInputs.keySet().toArray(new String[fileInputs.size()]); - for(Object varName: VarNames){ + varNames = fileInputs.keySet().toArray(new String[fileInputs.size()]); + for(String varName: varNames) { String varValue = fileInputs.get(varName); - inputCode += varName + " = \"\"\"" + escapeStr(varValue) + "\"\"\"\n"; - inputCode += "loc['" + varName + "']=" + varName + "\n"; + if(varValue != null) + inputCode += varName + " = \"\"\"" + escapeStr(varValue) + "\"\"\"\n"; + else { + inputCode += varName + " = ''\n"; + } } - if (ndInputs.size()> 0){ - inputCode += "import ctypes; import numpy as np;"; - VarNames = ndInputs.keySet().toArray(new String[ndInputs.size()]); + if (!ndInputs.isEmpty()) { + inputCode += "import ctypes\n\nimport sys\nimport numpy as np\n"; + varNames = ndInputs.keySet().toArray(new String[ndInputs.size()]); - String converter = "__arr_converter = lambda addr, shape, type: np.ctypeslib.as_array(ctypes.cast(addr, ctypes.POINTER(type)), shape);"; + String converter = "__arr_converter = lambda addr, shape, type: np.ctypeslib.as_array(ctypes.cast(addr, ctypes.POINTER(type)), shape)\n"; inputCode += converter; - for(String varName: VarNames){ + for(String varName: varNames) { NumpyArray npArr = ndInputs.get(varName); + if(npArr == null) + continue; + npArr = npArr.copy(); String shapeStr = "("; for (long d: npArr.getShape()){ - shapeStr += String.valueOf(d) + ","; + shapeStr += d + ","; } shapeStr += ")"; String code; String ctype; - if (npArr.getDtype() == DataType.FLOAT){ + if (npArr.getDtype() == DataType.FLOAT) { ctype = "ctypes.c_float"; } - else if (npArr.getDtype() == DataType.DOUBLE){ + else if (npArr.getDtype() == DataType.DOUBLE) { ctype = "ctypes.c_double"; } - else if (npArr.getDtype() == DataType.SHORT){ + else if (npArr.getDtype() == DataType.SHORT) { ctype = "ctypes.c_int16"; } - else if (npArr.getDtype() == DataType.INT){ + else if (npArr.getDtype() == DataType.INT) { ctype = "ctypes.c_int32"; } else if (npArr.getDtype() == DataType.LONG){ @@ -152,10 +524,9 @@ public class PythonExecutioner { throw new Exception("Unsupported data type: " + npArr.getDtype().toString() + "."); } - code = "__arr_converter(" + String.valueOf(npArr.getAddress()) + "," + shapeStr + "," + ctype + ")"; - code = varName + "=" + code + "\n"; + code = "__arr_converter(" + npArr.getAddress() + "," + shapeStr + "," + ctype + ")"; + code = varName + "=" + code + "\n"; 
inputCode += code; - inputCode += "loc['" + varName + "']=" + varName + "\n"; } } @@ -163,49 +534,62 @@ public class PythonExecutioner { } - private static void _readOutputs(PythonVariables pyOutputs){ - String json = read(getTempFile()); + private static synchronized void _readOutputs(PythonVariables pyOutputs) throws IOException { File f = new File(getTempFile()); + Preconditions.checkState(f.exists(),"File " + f.getAbsolutePath() + " failed to get written for reading outputs!"); + String json = FileUtils.readFileToString(f, Charset.defaultCharset()); + log.info("Executioner output: "); + log.info(json); f.delete(); - JSONParser p = new JSONParser(); - try{ - JSONObject jobj = (JSONObject) p.parse(json); - for (String varName: pyOutputs.getVariables()){ + + if(json.isEmpty()) { + log.warn("No json found fore reading outputs. Returning."); + return; + } + + try { + JSONObject jobj = new JSONObject(json); + for (String varName: pyOutputs.getVariables()) { PythonVariables.Type type = pyOutputs.getType(varName); - if (type == PythonVariables.Type.NDARRAY){ + if (type == PythonVariables.Type.NDARRAY) { JSONObject varValue = (JSONObject)jobj.get(varName); - long address = (Long)varValue.get("address"); - JSONArray shapeJson = (JSONArray)varValue.get("shape"); - JSONArray stridesJson = (JSONArray)varValue.get("strides"); + long address = (Long) varValue.getLong("address"); + JSONArray shapeJson = (JSONArray) varValue.get("shape"); + JSONArray stridesJson = (JSONArray) varValue.get("strides"); long[] shape = jsonArrayToLongArray(shapeJson); long[] strides = jsonArrayToLongArray(stridesJson); String dtypeName = (String)varValue.get("dtype"); DataType dtype; - if (dtypeName.equals("float64")){ + if (dtypeName.equals("float64")) { dtype = DataType.DOUBLE; } - else if (dtypeName.equals("float32")){ + else if (dtypeName.equals("float32")) { dtype = DataType.FLOAT; } - else if (dtypeName.equals("int16")){ + else if (dtypeName.equals("int16")) { dtype = DataType.SHORT; } - else if (dtypeName.equals("int32")){ + else if (dtypeName.equals("int32")) { dtype = DataType.INT; } - else if (dtypeName.equals("int64")){ + else if (dtypeName.equals("int64")) { dtype = DataType.LONG; } else{ throw new Exception("Unsupported array type " + dtypeName + "."); } + pyOutputs.setValue(varName, new NumpyArray(address, shape, strides, dtype, true)); - } - else if (type == PythonVariables.Type.LIST){ - JSONArray varValue = (JSONArray)jobj.get(varName); - pyOutputs.setValue(varName, varValue.toArray()); + else if (type == PythonVariables.Type.LIST) { + JSONArray varValue = (JSONArray) jobj.get(varName); + pyOutputs.setValue(varName, varValue); + } + else if (type == PythonVariables.Type.DICT) { + Map map = toMap((JSONObject) jobj.get(varName)); + pyOutputs.setValue(varName, map); + } else{ pyOutputs.setValue(varName, jobj.get(varName)); @@ -217,266 +601,422 @@ public class PythonExecutioner { } } - private static void acquireGIL(){ - log.info("---_enterSubInterpreter()---"); - if (PyGILState_Check() != 1){ - gilStates.put(Thread.currentThread().getId(), PyGILState_Ensure()); - log.info("GIL ensured"); + + + + private static synchronized void _exec(String code) { + log.info(code); + log.info("CPython: PyRun_SimpleStringFlag()"); + + int result = PyRun_SimpleStringFlags(code, null); + if (result != 0) { + log.info("CPython: PyErr_Print"); + PyErr_Print(); + throw new RuntimeException("exec failed"); } } - private static void releaseGIL(){ - if (PyGILState_Check() == 1){ - log.info("Releasing gil..."); - 
PyGILState_Release(gilStates.get(Thread.currentThread().getId())); - log.info("Gil released."); - } - + private static synchronized void _exec_wrapped(String code) { + _exec(getWrappedCode(code)); } /** * Executes python code. Also manages python thread state. - * @param code + * @param code the code to run */ - public static void exec(String code){ - code = getFunctionalCode("__f_" + Thread.currentThread().getId(), code); + public static void exec(String code) { + code = getWrappedCode(code); + if(code.contains("import numpy") && !getInterpreter().equals("main")) {// FIXME + throw new IllegalArgumentException("Unable to execute numpy on sub interpreter. See https://mail.python.org/pipermail/python-dev/2019-January/156095.html for the reasons."); + } + acquireGIL(); - log.info("CPython: PyRun_SimpleStringFlag()"); - log.info(code); - int result = PyRun_SimpleStringFlags(code, null); - if (result != 0){ - PyErr_Print(); - throw new RuntimeException("exec failed"); + _exec(code); + log.info("Exec done"); + releaseGIL(); + } + + private static boolean _hasGlobalVariable(String varName){ + PyObject mainModule = PyImport_AddModule("__main__"); + PyObject var = PyObject_GetAttrString(mainModule, varName); + boolean hasVar = var != null; + Py_DecRef(var); + return hasVar; + } + + /** + * Executes python code and looks for methods setup() and run() + * If both setup() and run() are found, both are executed for the first + * time and for subsequent calls only run() is executed. + */ + public static void execWithSetupAndRun(String code) { + code = getWrappedCode(code); + if(code.contains("import numpy") && !getInterpreter().equals("main")) { // FIXME + throw new IllegalArgumentException("Unable to execute numpy on sub interpreter. See https://mail.python.org/pipermail/python-dev/2019-January/156095.html for the reasons."); + } + + acquireGIL(); + _exec(code); + if (_hasGlobalVariable("setup") && _hasGlobalVariable("run")){ + log.debug("setup() and run() methods found."); + if (!_hasGlobalVariable("__setup_done__")){ + log.debug("Calling setup()..."); + _exec("setup()"); + _exec("__setup_done__ = True"); + } + log.debug("Calling run()..."); + _exec("run()"); } log.info("Exec done"); releaseGIL(); } - public static void exec(String code, PythonVariables pyOutputs){ - exec(code + '\n' + outputCode(pyOutputs)); - _readOutputs(pyOutputs); + /** + * Executes python code and looks for methods setup() and run() + * If both setup() and run() are found, both are executed for the first + * time and for subsequent calls only run() is executed. + */ + public static void execWithSetupAndRun(String code, PythonVariables pyOutputs) { + code = getWrappedCode(code); + if(code.contains("import numpy") && !getInterpreter().equals("main")) { // FIXME + throw new IllegalArgumentException("Unable to execute numpy on sub interpreter. 
See https://mail.python.org/pipermail/python-dev/2019-January/156095.html for the reasons."); + } + + acquireGIL(); + _exec(code); + if (_hasGlobalVariable("setup") && _hasGlobalVariable("run")){ + log.debug("setup() and run() methods found."); + if (!_hasGlobalVariable("__setup_done__")){ + log.debug("Calling setup()..."); + _exec("setup()"); + _exec("__setup_done__ = True"); + } + log.debug("Calling run()..."); + _exec("__out = run();for (k,v) in __out.items(): globals()[k]=v"); + } + log.info("Exec done"); + try { + + _readOutputs(pyOutputs); + + } catch (IOException e) { + log.error("Failed to read outputs", e); + } + + releaseGIL(); } - public static void exec(String code, PythonVariables pyInputs, PythonVariables pyOutputs) throws Exception{ + /** + * Run the given code with the given python outputs + * @param code the code to run + * @param pyOutputs the outputs to run + */ + public static void exec(String code, PythonVariables pyOutputs) { + + exec(code + '\n' + outputCode(pyOutputs)); + try { + + _readOutputs(pyOutputs); + + } catch (IOException e) { + log.error("Failed to read outputs", e); + } + + releaseGIL(); + } + + + /** + * Execute the given python code with the given + * {@link PythonVariables} as inputs and outputs + * @param code the code to run + * @param pyInputs the inputs to the code + * @param pyOutputs the outputs to the code + * @throws Exception + */ + public static void exec(String code, PythonVariables pyInputs, PythonVariables pyOutputs) throws Exception { String inputCode = inputCode(pyInputs); exec(inputCode + code, pyOutputs); } - - public static PythonVariables exec(PythonTransform transform) throws Exception{ - if (transform.getInputs() != null && transform.getInputs().getVariables().length > 0){ - throw new Exception("Required inputs not provided."); + /** + * Execute the given python code + * with the {@link PythonVariables} + * inputs and outputs for storing the values + * specified by the user and needed by the user + * as output + * @param code the python code to execute + * @param pyInputs the python variables input in to the python script + * @param pyOutputs the python variables output returned by the python script + * @throws Exception + */ + public static void execWithSetupAndRun(String code, PythonVariables pyInputs, PythonVariables pyOutputs) throws Exception { + String inputCode = inputCode(pyInputs); + code = inputCode +code; + code = getWrappedCode(code); + if(code.contains("import numpy") && !getInterpreter().equals("main")) { // FIXME + throw new IllegalArgumentException("Unable to execute numpy on sub interpreter. 
See https://mail.python.org/pipermail/python-dev/2019-January/156095.html for the reasons."); } - exec(transform.getCode(), null, transform.getOutputs()); - return transform.getOutputs(); + acquireGIL(); + _exec(code); + if (_hasGlobalVariable("setup") && _hasGlobalVariable("run")){ + log.debug("setup() and run() methods found."); + if (!_hasGlobalVariable("__setup_done__")){ + releaseGIL(); // required + acquireGIL(); + log.debug("Calling setup()..."); + _exec("setup()"); + _exec("__setup_done__ = True"); + }else{ + log.debug("setup() already called once."); + } + log.debug("Calling run()..."); + releaseGIL(); // required + acquireGIL(); + _exec("import inspect\n"+ + "__out = run(**{k:globals()[k]for k in inspect.getfullargspec(run).args})\n"+ + "globals().update(__out)"); + } + releaseGIL(); // required + acquireGIL(); + _exec(outputCode(pyOutputs)); + log.info("Exec done"); + try { + + _readOutputs(pyOutputs); + + } catch (IOException e) { + log.error("Failed to read outputs", e); + } + + releaseGIL(); } - public static PythonVariables exec(PythonTransform transform, PythonVariables inputs)throws Exception{ + + + private static String interpreterNameFromTransform(PythonTransform transform){ + return transform.getName().replace("-", "_"); + } + + + /** + * Run a {@link PythonTransform} with the given inputs + * @param transform the transform to run + * @param inputs the inputs to the transform + * @return the output variables + * @throws Exception + */ + public static PythonVariables exec(PythonTransform transform, PythonVariables inputs)throws Exception { + String name = interpreterNameFromTransform(transform); + setInterpreter(name); + Preconditions.checkNotNull(transform.getOutputs(),"Transform outputs were null!"); exec(transform.getCode(), inputs, transform.getOutputs()); return transform.getOutputs(); } - - - public static String evalSTRING(String varName){ - log.info("CPython: PyImport_AddModule()"); - module = PyImport_AddModule("__main__"); - log.info("CPython: PyModule_GetDict()"); - globals = PyModule_GetDict(module); - PyObject xObj = PyDict_GetItemString(globals, varName); - PyObject bytes = PyUnicode_AsEncodedString(xObj, "UTF-8", "strict"); - BytePointer bp = PyBytes_AsString(bytes); - String ret = bp.getString(); - Py_DecRef(xObj); - Py_DecRef(bytes); - return ret; + public static PythonVariables execWithSetupAndRun(PythonTransform transform, PythonVariables inputs)throws Exception { + String name = interpreterNameFromTransform(transform); + setInterpreter(name); + Preconditions.checkNotNull(transform.getOutputs(),"Transform outputs were null!"); + execWithSetupAndRun(transform.getCode(), inputs, transform.getOutputs()); + return transform.getOutputs(); } - public static long evalINTEGER(String varName){ - log.info("CPython: PyImport_AddModule()"); - module = PyImport_AddModule("__main__"); - log.info("CPython: PyModule_GetDict()"); - globals = PyModule_GetDict(module); - PyObject xObj = PyDict_GetItemString(globals, varName); - long ret = PyLong_AsLongLong(xObj); - return ret; + + /** + * Run the code and return the outputs + * @param code the code to run + * @return all python variables + */ + public static PythonVariables execAndReturnAllVariables(String code) { + exec(code + '\n' + outputCodeForAllVariables()); + PythonVariables allVars = new PythonVariables(); + allVars.addDict(ALL_VARIABLES_KEY); + try { + _readOutputs(allVars); + }catch (IOException e) { + log.error("Failed to read outputs", e); + } + + return expandInnerDict(allVars, ALL_VARIABLES_KEY); + } + 
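+    /**
+     * As {@link #execAndReturnAllVariables(String)}, but honoring the setup()/run()
+     * convention used by {@link #execWithSetupAndRun(String)}. The all-variables style
+     * of call, sketched with illustrative variable names:
+     * <pre>{@code
+     * PythonVariables all = PythonExecutioner.execAndReturnAllVariables("a = 1\nb = 'hello'");
+     * long a = all.getIntValue("a");      // 1
+     * String b = all.getStrValue("b");    // "hello"
+     * }</pre>
+     */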
public static PythonVariables execWithSetupRunAndReturnAllVariables(String code) { + execWithSetupAndRun(code + '\n' + outputCodeForAllVariables()); + PythonVariables allVars = new PythonVariables(); + allVars.addDict(ALL_VARIABLES_KEY); + try { + _readOutputs(allVars); + }catch (IOException e) { + log.error("Failed to read outputs", e); + } + + return expandInnerDict(allVars, ALL_VARIABLES_KEY); } - public static double evalFLOAT(String varName){ - log.info("CPython: PyImport_AddModule()"); - module = PyImport_AddModule("__main__"); - log.info("CPython: PyModule_GetDict()"); - globals = PyModule_GetDict(module); - PyObject xObj = PyDict_GetItemString(globals, varName); - double ret = PyFloat_AsDouble(xObj); - return ret; + /** + * + * @param code code string to run + * @param pyInputs python input variables + * @return all python variables + * @throws Exception throws when there's an issue while execution of python code + */ + public static PythonVariables execAndReturnAllVariables(String code, PythonVariables pyInputs) throws Exception { + String inputCode = inputCode(pyInputs); + return execAndReturnAllVariables(inputCode + code); + } + public static PythonVariables execWithSetupRunAndReturnAllVariables(String code, PythonVariables pyInputs) throws Exception { + String inputCode = inputCode(pyInputs); + return execWithSetupRunAndReturnAllVariables(inputCode + code); } - public static Object[] evalLIST(String varName) throws Exception{ - log.info("CPython: PyImport_AddModule()"); - module = PyImport_AddModule("__main__"); - log.info("CPython: PyModule_GetDict()"); - globals = PyModule_GetDict(module); - PyObject xObj = PyDict_GetItemString(globals, varName); - PyObject strObj = PyObject_Str(xObj); - PyObject bytes = PyUnicode_AsEncodedString(strObj, "UTF-8", "strict"); - BytePointer bp = PyBytes_AsString(bytes); - String listStr = bp.getString(); - Py_DecRef(xObj); - Py_DecRef(bytes); - JSONArray jsonArray = (JSONArray)parser.parse(listStr.replace("\'", "\"")); - return jsonArray.toArray(); + + /** + * Evaluate a string based on the + * current variable name. + * This variable named needs to be present + * or defined earlier in python code + * in order to pull out the values. 
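+ *
+ * For example:
+ * <pre>{@code
+ * PythonExecutioner.exec("greeting = 'hello'");
+ * String s = PythonExecutioner.evalString("greeting");   // "hello"
+ * }</pre>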
+ * + * @param varName the variable name to evaluate + * @return the evaluated value + */ + public static String evalString(String varName) { + PythonVariables vars = new PythonVariables(); + vars.addStr(varName); + exec("print('')", vars); + return vars.getStrValue(varName); } - public static NumpyArray evalNDARRAY(String varName) throws Exception{ - log.info("CPython: PyImport_AddModule()"); - module = PyImport_AddModule("__main__"); - log.info("CPython: PyModule_GetDict()"); - globals = PyModule_GetDict(module); - PyObject xObj = PyDict_GetItemString(globals, varName); - PyObject arrayInterface = PyObject_GetAttrString(xObj, "__array_interface__"); - PyObject data = PyDict_GetItemString(arrayInterface, "data"); - PyObject zero = PyLong_FromLong(0); - PyObject addressObj = PyObject_GetItem(data, zero); - long address = PyLong_AsLongLong(addressObj); - PyObject shapeObj = PyObject_GetAttrString(xObj, "shape"); - int ndim = (int)PyObject_Size(shapeObj); - PyObject iObj; - long shape[] = new long[ndim]; - for (int i=0; i 0) + outputCode = outputCode.substring(0, outputCode.length() - 1); + outputCode += "})"; + outputCode += "\nwith open('" + getTempFile() + "', 'w') as " + fileVarName + ":" + fileVarName + ".write(" + outputVarName() + ")"; + + return outputCode; } - private static String read(String path){ - try{ - File file = new File(path); - FileInputStream fis = new FileInputStream(file); - byte[] data = new byte[(int) file.length()]; - fis.read(data); - fis.close(); - String str = new String(data, "UTF-8"); - return str; - } - catch (Exception e){ - return ""; - } - } - private static String jArrayToPyString(Object[] array){ + private static String jArrayToPyString(Object[] array) { String str = "["; - for (int i=0; i < array.length; i++){ + for (int i = 0; i < array.length; i++){ Object obj = array[i]; if (obj instanceof Object[]){ str += jArrayToPyString((Object[])obj); @@ -496,32 +1036,109 @@ public class PythonExecutioner { return str; } - private static String escapeStr(String str){ + private static String escapeStr(String str) { + if(str == null) + return null; str = str.replace("\\", "\\\\"); str = str.replace("\"\"\"", "\\\"\\\"\\\""); return str; } - private static String getFunctionalCode(String functionName, String code){ - String out = String.format("def %s():\n", functionName); - for(String line: code.split(Pattern.quote("\n"))){ - out += " " + line + "\n"; + private static String getWrappedCode(String code) { + try(InputStream is = new ClassPathResource("pythonexec/pythonexec.py").getInputStream()) { + String base = IOUtils.toString(is, Charset.defaultCharset()); + StringBuffer indentedCode = new StringBuffer(); + for(String split : code.split("\n")) { + indentedCode.append(" " + split + "\n"); + + } + + String out = base.replace(" pass",indentedCode); + return out; + } catch (IOException e) { + throw new IllegalStateException("Unable to read python code!",e); } - return out + "\n\n" + functionName + "()\n"; + } - private static String getTempFile(){ - String ret = "temp_" + Thread.currentThread().getId() + ".json"; + + private static String getTempFile() { + String ret = "temp_" + Thread.currentThread().getId() + "_" + currentInterpreter + ".json"; log.info(ret); return ret; } - private static long[] jsonArrayToLongArray(JSONArray jsonArray){ - long[] longs = new long[jsonArray.size()]; - for (int i=0; i _getPatches() { + exec("import numpy as np"); + exec( "__overrides_path = np.core.overrides.__file__"); + exec("__random_path = np.random.__file__"); + + List patches 
= new ArrayList<>(); + + patches.add(new String[]{ + "pythonexec/patch0.py", + evalString("__overrides_path") + }); + patches.add(new String[]{ + "pythonexec/patch1.py", + evalString("__random_path") + }); + + return patches; + } + + private static void _applyPatch(String src, String dest){ + try(InputStream is = new ClassPathResource(src).getInputStream()) { + String patch = IOUtils.toString(is, Charset.defaultCharset()); + FileUtils.write(new File(dest), patch, "utf-8"); + } + catch(IOException e){ + throw new RuntimeException("Error reading resource."); + } + } + + private static boolean _checkPatchApplied(String dest) { + try { + return FileUtils.readFileToString(new File(dest), "utf-8").startsWith("#patch"); + } catch (IOException e) { + throw new RuntimeException("Error patching numpy"); + + } + } + + private static void applyPatches() { + for (String[] patch : _getPatches()){ + if (_checkPatchApplied(patch[1])){ + log.info("Patch already applied for " + patch[1]); + } + else{ + _applyPatch(patch[0], patch[1]); + log.info("Applied patch for " + patch[1]); + } + } + for (String[] patch: _getPatches()){ + if (!_checkPatchApplied(patch[1])){ + throw new RuntimeException("Error patching numpy"); + } + } + } +} \ No newline at end of file diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java index e3b3fb2bf..8f2460035 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java @@ -16,16 +16,29 @@ package org.datavec.python; +import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; +import org.apache.commons.io.IOUtils; import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.Transform; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.*; +import org.nd4j.base.Preconditions; +import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder; +import org.nd4j.linalg.io.ClassPathResource; +import org.nd4j.shade.jackson.core.JsonProcessingException; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.Charset; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.UUID; +import static org.datavec.python.PythonUtils.schemaToPythonVariables; + /** * Row-wise Transform that applies arbitrary python code on each row * @@ -34,31 +47,87 @@ import java.util.UUID; @NoArgsConstructor @Data -public class PythonTransform implements Transform{ +public class PythonTransform implements Transform { + private String code; - private PythonVariables pyInputs; - private PythonVariables pyOutputs; - private String name; + private PythonVariables inputs; + private PythonVariables outputs; + private String name = UUID.randomUUID().toString(); private Schema inputSchema; private Schema outputSchema; + private String outputDict; + private boolean returnAllVariables; + private boolean setupAndRun = false; - public PythonTransform(String code, PythonVariables pyInputs, PythonVariables pyOutputs) throws Exception{ + @Builder + public PythonTransform(String code, + PythonVariables inputs, + PythonVariables outputs, + String name, + Schema inputSchema, + Schema outputSchema, + String outputDict, + boolean returnAllInputs, + boolean setupAndRun) { + Preconditions.checkNotNull(code,"No code found to run!"); this.code = code; - this.pyInputs = pyInputs; - 
this.pyOutputs = pyOutputs; - this.name = UUID.randomUUID().toString(); + this.returnAllVariables = returnAllInputs; + this.setupAndRun = setupAndRun; + if(inputs != null) + this.inputs = inputs; + if(outputs != null) + this.outputs = outputs; + + if(name != null) + this.name = name; + if (outputDict != null) { + this.outputDict = outputDict; + this.outputs = new PythonVariables(); + this.outputs.addDict(outputDict); + + String helpers; + try(InputStream is = new ClassPathResource("pythonexec/serialize_array.py").getInputStream()) { + helpers = IOUtils.toString(is, Charset.defaultCharset()); + + }catch (IOException e){ + throw new RuntimeException("Error reading python code"); + } + this.code += "\n\n" + helpers; + this.code += "\n" + outputDict + " = __recursive_serialize_dict(" + outputDict + ")"; + } + + try { + if(inputSchema != null) { + this.inputSchema = inputSchema; + if(inputs == null || inputs.isEmpty()) { + this.inputs = schemaToPythonVariables(inputSchema); + } + } + + if(outputSchema != null) { + this.outputSchema = outputSchema; + if(outputs == null || outputs.isEmpty()) { + this.outputs = schemaToPythonVariables(outputSchema); + } + } + }catch(Exception e) { + throw new IllegalStateException(e); + } + } + @Override - public void setInputSchema(Schema inputSchema){ + public void setInputSchema(Schema inputSchema) { + Preconditions.checkNotNull(inputSchema,"No input schema found!"); this.inputSchema = inputSchema; try{ - pyInputs = schemaToPythonVariables(inputSchema); + inputs = schemaToPythonVariables(inputSchema); }catch (Exception e){ throw new RuntimeException(e); } - if (outputSchema == null){ + if (outputSchema == null && outputDict == null){ outputSchema = inputSchema; } @@ -88,12 +157,42 @@ public class PythonTransform implements Transform{ throw new UnsupportedOperationException("Not yet implemented"); } + + + @Override - public List map(List writables){ + public List map(List writables) { PythonVariables pyInputs = getPyInputsFromWritables(writables); + Preconditions.checkNotNull(pyInputs,"Inputs must not be null!"); + + try{ - PythonExecutioner.exec(code, pyInputs, pyOutputs); - return getWritablesFromPyOutputs(pyOutputs); + if (returnAllVariables) { + if (setupAndRun){ + return getWritablesFromPyOutputs(PythonExecutioner.execWithSetupRunAndReturnAllVariables(code, pyInputs)); + } + return getWritablesFromPyOutputs(PythonExecutioner.execAndReturnAllVariables(code, pyInputs)); + } + + if (outputDict != null) { + if (setupAndRun) { + PythonExecutioner.execWithSetupAndRun(this, pyInputs); + }else{ + PythonExecutioner.exec(this, pyInputs); + } + PythonVariables out = PythonUtils.expandInnerDict(outputs, outputDict); + return getWritablesFromPyOutputs(out); + } + else { + if (setupAndRun) { + PythonExecutioner.execWithSetupAndRun(code, pyInputs, outputs); + }else{ + PythonExecutioner.exec(code, pyInputs, outputs); + } + + return getWritablesFromPyOutputs(outputs); + } + } catch (Exception e){ throw new RuntimeException(e); @@ -102,7 +201,7 @@ public class PythonTransform implements Transform{ @Override public String[] outputColumnNames(){ - return pyOutputs.getVariables(); + return outputs.getVariables(); } @Override @@ -111,7 +210,7 @@ public class PythonTransform implements Transform{ } @Override public String[] columnNames(){ - return pyOutputs.getVariables(); + return outputs.getVariables(); } @Override @@ -124,14 +223,13 @@ public class PythonTransform implements Transform{ } - private PythonVariables getPyInputsFromWritables(List writables){ - + private 
PythonVariables getPyInputsFromWritables(List writables) { PythonVariables ret = new PythonVariables(); - for (String name: pyInputs.getVariables()){ + for (String name: inputs.getVariables()) { int colIdx = inputSchema.getIndexOfColumn(name); Writable w = writables.get(colIdx); - PythonVariables.Type pyType = pyInputs.getType(name); + PythonVariables.Type pyType = inputs.getType(name); switch (pyType){ case INT: if (w instanceof LongWritable){ @@ -143,7 +241,7 @@ public class PythonTransform implements Transform{ break; case FLOAT: - if (w instanceof DoubleWritable){ + if (w instanceof DoubleWritable) { ret.addFloat(name, ((DoubleWritable)w).get()); } else{ @@ -151,96 +249,99 @@ public class PythonTransform implements Transform{ } break; case STR: - ret.addStr(name, ((Text)w).toString()); + ret.addStr(name, w.toString()); break; case NDARRAY: ret.addNDArray(name,((NDArrayWritable)w).get()); break; + default: + throw new RuntimeException("Unsupported input type:" + pyType); } } return ret; } - private List getWritablesFromPyOutputs(PythonVariables pyOuts){ + private List getWritablesFromPyOutputs(PythonVariables pyOuts) { List out = new ArrayList<>(); - for (int i=0; i dictValue = pyOuts.getDictValue(name); + Map noNullValues = new java.util.HashMap<>(); + for(Map.Entry entry : dictValue.entrySet()) { + if(entry.getValue() != org.json.JSONObject.NULL) { + noNullValues.put(entry.getKey(), entry.getValue()); + } + } + + try { + out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(noNullValues))); + } catch (JsonProcessingException e) { + throw new IllegalStateException("Unable to serialize dictionary " + name + " to json!"); + } + break; + case LIST: + Object[] listValue = pyOuts.getListValue(name); + try { + out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(listValue))); + } catch (JsonProcessingException e) { + throw new IllegalStateException("Unable to serialize list vlaue " + name + " to json!"); + } + break; + default: + throw new IllegalStateException("Unable to support type " + pyType.name()); } } return out; } - public PythonTransform(String code) throws Exception{ - this.code = code; - this.name = UUID.randomUUID().toString(); - } - private PythonVariables schemaToPythonVariables(Schema schema) throws Exception{ - PythonVariables pyVars = new PythonVariables(); - int numCols = schema.numColumns(); - for (int i=0; i 0,"Input must have variables. Found none."); + for(Map.Entry entry : input.getVars().entrySet()) { + switch(entry.getValue()) { + case INT: + schemaBuilder.addColumnInteger(entry.getKey()); + break; + case STR: + schemaBuilder.addColumnString(entry.getKey()); + break; + case FLOAT: + schemaBuilder.addColumnFloat(entry.getKey()); + break; + case NDARRAY: + schemaBuilder.addColumnNDArray(entry.getKey(),null); + break; + case BOOL: + schemaBuilder.addColumn(new BooleanMetaData(entry.getKey())); + } + } + + return schemaBuilder.build(); + } + + /** + * Create a {@link Schema} from an input + * {@link PythonVariables} + * Types are mapped to types of the same name + * @param input the input schema + * @return the output python variables. 
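+ *
+ * For example (column names illustrative):
+ * <pre>{@code
+ * Schema schema = new Schema.Builder()
+ *         .addColumnInteger("x")
+ *         .addColumnString("s")
+ *         .build();
+ * PythonVariables vars = PythonUtils.fromSchema(schema);   // x -> INT, s -> STR
+ * }</pre>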
+ */ + public static PythonVariables fromSchema(Schema input) { + PythonVariables ret = new PythonVariables(); + for(int i = 0; i < input.numColumns(); i++) { + String currColumnName = input.getName(i); + ColumnType columnType = input.getType(i); + switch(columnType) { + case NDArray: + ret.add(currColumnName, PythonVariables.Type.NDARRAY); + break; + case Boolean: + ret.add(currColumnName, PythonVariables.Type.BOOL); + break; + case Categorical: + case String: + ret.add(currColumnName, PythonVariables.Type.STR); + break; + case Double: + case Float: + ret.add(currColumnName, PythonVariables.Type.FLOAT); + break; + case Integer: + case Long: + ret.add(currColumnName, PythonVariables.Type.INT); + break; + case Bytes: + break; + case Time: + throw new UnsupportedOperationException("Unable to process dates with python yet."); + } + } + + return ret; + } + /** + * Convert a {@link Schema} + * to {@link PythonVariables} + * @param schema the input schema + * @return the output {@link PythonVariables} where each + * name in the map is associated with a column name in the schema. + * A proper type is also chosen based on the schema + * @throws Exception + */ + public static PythonVariables schemaToPythonVariables(Schema schema) throws Exception { + PythonVariables pyVars = new PythonVariables(); + int numCols = schema.numColumns(); + for (int i = 0; i < numCols; i++) { + String colName = schema.getName(i); + ColumnType colType = schema.getType(i); + switch (colType){ + case Long: + case Integer: + pyVars.addInt(colName); + break; + case Double: + case Float: + pyVars.addFloat(colName); + break; + case String: + pyVars.addStr(colName); + break; + case NDArray: + pyVars.addNDArray(colName); + break; + default: + throw new Exception("Unsupported python input type: " + colType.toString()); + } + } + + return pyVars; + } + + + public static NumpyArray mapToNumpyArray(Map map){ + String dtypeName = (String)map.get("dtype"); + DataType dtype; + if (dtypeName.equals("float64")){ + dtype = DataType.DOUBLE; + } + else if (dtypeName.equals("float32")){ + dtype = DataType.FLOAT; + } + else if (dtypeName.equals("int16")){ + dtype = DataType.SHORT; + } + else if (dtypeName.equals("int32")){ + dtype = DataType.INT; + } + else if (dtypeName.equals("int64")){ + dtype = DataType.LONG; + } + else{ + throw new RuntimeException("Unsupported array type " + dtypeName + "."); + } + List shapeList = (List)map.get("shape"); + long[] shape = new long[shapeList.size()]; + for (int i = 0; i < shape.length; i++) { + shape[i] = (Long)shapeList.get(i); + } + + List strideList = (List)map.get("shape"); + long[] stride = new long[strideList.size()]; + for (int i = 0; i < stride.length; i++) { + stride[i] = (Long)strideList.get(i); + } + long address = (Long)map.get("address"); + NumpyArray numpyArray = new NumpyArray(address, shape, stride, true,dtype); + return numpyArray; + } + + public static PythonVariables expandInnerDict(PythonVariables pyvars, String key){ + Map dict = pyvars.getDictValue(key); + String[] keys = (String[])dict.keySet().toArray(new String[dict.keySet().size()]); + PythonVariables pyvars2 = new PythonVariables(); + for (String subkey: keys){ + Object value = dict.get(subkey); + if (value instanceof Map){ + Map map = (Map)value; + if (map.containsKey("_is_numpy_array")){ + pyvars2.addNDArray(subkey, mapToNumpyArray(map)); + + } + else{ + pyvars2.addDict(subkey, (Map)value); + } + + } + else if (value instanceof List){ + pyvars2.addList(subkey, ((List) value).toArray()); + } + else if (value instanceof 
String){ + System.out.println((String)value); + pyvars2.addStr(subkey, (String) value); + } + else if (value instanceof Integer || value instanceof Long) { + Number number = (Number) value; + pyvars2.addInt(subkey, number.intValue()); + } + else if (value instanceof Float || value instanceof Double) { + Number number = (Number) value; + pyvars2.addFloat(subkey, number.doubleValue()); + } + else if (value instanceof NumpyArray){ + pyvars2.addNDArray(subkey, (NumpyArray)value); + } + else if (value == null){ + pyvars2.addStr(subkey, "None"); // FixMe + } + else{ + throw new RuntimeException("Unsupported type!" + value); + } + } + return pyvars2; + } + + public static long[] jsonArrayToLongArray(JSONArray jsonArray){ + long[] longs = new long[jsonArray.length()]; + for (int i=0; i toMap(JSONObject jsonobj) { + Map map = new HashMap<>(); + String[] keys = (String[])jsonobj.keySet().toArray(new String[jsonobj.keySet().size()]); + for (String key: keys){ + Object value = jsonobj.get(key); + if (value instanceof JSONArray) { + value = toList((JSONArray) value); + } else if (value instanceof JSONObject) { + JSONObject jsonobj2 = (JSONObject)value; + if (jsonobj2.has("_is_numpy_array")){ + value = jsonToNumpyArray(jsonobj2); + } + else{ + value = toMap(jsonobj2); + } + + } + + map.put(key, value); + } return map; + } + + + public static List toList(JSONArray array) { + List list = new ArrayList<>(); + for (int i = 0; i < array.length(); i++) { + Object value = array.get(i); + if (value instanceof JSONArray) { + value = toList((JSONArray) value); + } else if (value instanceof JSONObject) { + JSONObject jsonobj2 = (JSONObject) value; + if (jsonobj2.has("_is_numpy_array")) { + value = jsonToNumpyArray(jsonobj2); + } else { + value = toMap(jsonobj2); + } + } + list.add(value); + } + return list; + } + + + private static NumpyArray jsonToNumpyArray(JSONObject map){ + String dtypeName = (String)map.get("dtype"); + DataType dtype; + if (dtypeName.equals("float64")){ + dtype = DataType.DOUBLE; + } + else if (dtypeName.equals("float32")){ + dtype = DataType.FLOAT; + } + else if (dtypeName.equals("int16")){ + dtype = DataType.SHORT; + } + else if (dtypeName.equals("int32")){ + dtype = DataType.INT; + } + else if (dtypeName.equals("int64")){ + dtype = DataType.LONG; + } + else{ + throw new RuntimeException("Unsupported array type " + dtypeName + "."); + } + List shapeList = (List)map.get("shape"); + long[] shape = new long[shapeList.size()]; + for (int i = 0; i < shape.length; i++) { + shape[i] = (Long)shapeList.get(i); + } + + List strideList = (List)map.get("shape"); + long[] stride = new long[strideList.size()]; + for (int i = 0; i < stride.length; i++) { + stride[i] = (Long)strideList.get(i); + } + long address = (Long)map.get("address"); + NumpyArray numpyArray = new NumpyArray(address, shape, stride, true,dtype); + return numpyArray; + } + + +} diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonVariables.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonVariables.java index fb05e7052..4d04f1d87 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonVariables.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonVariables.java @@ -17,8 +17,8 @@ package org.datavec.python; import lombok.Data; -import org.json.simple.JSONArray; -import org.json.simple.JSONObject; +import org.json.JSONObject; +import org.json.JSONArray; import org.nd4j.linalg.api.ndarray.INDArray; import java.io.Serializable; @@ -31,8 +31,8 @@ import 
java.util.*; * @author Fariz Rahman */ -@Data -public class PythonVariables implements Serializable{ +@lombok.Data +public class PythonVariables implements java.io.Serializable { public enum Type{ BOOL, @@ -41,23 +41,29 @@ public class PythonVariables implements Serializable{ FLOAT, NDARRAY, LIST, - FILE + FILE, + DICT } - private Map strVars = new HashMap(); - private Map intVars = new HashMap(); - private Map floatVars = new HashMap(); - private Map boolVars = new HashMap(); - private Map ndVars = new HashMap(); - private Map listVars = new HashMap(); - private Map fileVars = new HashMap(); - - private Map vars = new HashMap(); - - private Map maps = new HashMap(); + private java.util.Map strVariables = new java.util.LinkedHashMap<>(); + private java.util.Map intVariables = new java.util.LinkedHashMap<>(); + private java.util.Map floatVariables = new java.util.LinkedHashMap<>(); + private java.util.Map boolVariables = new java.util.LinkedHashMap<>(); + private java.util.Map ndVars = new java.util.LinkedHashMap<>(); + private java.util.Map listVariables = new java.util.LinkedHashMap<>(); + private java.util.Map fileVariables = new java.util.LinkedHashMap<>(); + private java.util.Map> dictVariables = new java.util.LinkedHashMap<>(); + private java.util.Map vars = new java.util.LinkedHashMap<>(); + private java.util.Map maps = new java.util.LinkedHashMap<>(); + /** + * Returns a copy of the variable + * schema in this array without the values + * @return an empty variables clone + * with no values + */ public PythonVariables copySchema(){ PythonVariables ret = new PythonVariables(); for (String varName: getVariables()){ @@ -66,15 +72,30 @@ public class PythonVariables implements Serializable{ } return ret; } - public PythonVariables(){ - maps.put(Type.BOOL, boolVars); - maps.put(Type.STR, strVars); - maps.put(Type.INT, intVars); - maps.put(Type.FLOAT, floatVars); - maps.put(Type.NDARRAY, ndVars); - maps.put(Type.LIST, listVars); - maps.put(Type.FILE, fileVars); + /** + * + */ + public PythonVariables() { + maps.put(PythonVariables.Type.BOOL, boolVariables); + maps.put(PythonVariables.Type.STR, strVariables); + maps.put(PythonVariables.Type.INT, intVariables); + maps.put(PythonVariables.Type.FLOAT, floatVariables); + maps.put(PythonVariables.Type.NDARRAY, ndVars); + maps.put(PythonVariables.Type.LIST, listVariables); + maps.put(PythonVariables.Type.FILE, fileVariables); + maps.put(PythonVariables.Type.DICT, dictVariables); + + } + + + + /** + * + * @return true if there are no variables. 
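+ *
+ * For example:
+ * <pre>{@code
+ * PythonVariables vars = new PythonVariables();
+ * boolean empty = vars.isEmpty();   // true
+ * vars.addInt("x", 5);
+ * empty = vars.isEmpty();           // false
+ * }</pre>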
+ */ + public boolean isEmpty() { + return getVariables().length < 1; } @@ -105,6 +126,9 @@ public class PythonVariables implements Serializable{ break; case FILE: addFile(name); + break; + case DICT: + addDict(name); } } @@ -113,252 +137,463 @@ public class PythonVariables implements Serializable{ * @param name name of the variable * @param type type of the variable * @param value value of the variable (must be instance of expected type) - * @throws Exception */ - public void add (String name, Type type, Object value) throws Exception{ + public void add(String name, Type type, Object value) { add(name, type); setValue(name, value); } + + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * @param name the field to add + */ + public void addDict(String name) { + vars.put(name, PythonVariables.Type.DICT); + dictVariables.put(name,null); + } + + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * @param name the field to add + */ public void addBool(String name){ - vars.put(name, Type.BOOL); - boolVars.put(name, null); + vars.put(name, PythonVariables.Type.BOOL); + boolVariables.put(name, null); } + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * @param name the field to add + */ public void addStr(String name){ - vars.put(name, Type.STR); - strVars.put(name, null); + vars.put(name, PythonVariables.Type.STR); + strVariables.put(name, null); } + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * @param name the field to add + */ public void addInt(String name){ - vars.put(name, Type.INT); - intVars.put(name, null); + vars.put(name, PythonVariables.Type.INT); + intVariables.put(name, null); } + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * @param name the field to add + */ public void addFloat(String name){ - vars.put(name, Type.FLOAT); - floatVars.put(name, null); + vars.put(name, PythonVariables.Type.FLOAT); + floatVariables.put(name, null); } + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * @param name the field to add + */ public void addNDArray(String name){ - vars.put(name, Type.NDARRAY); + vars.put(name, PythonVariables.Type.NDARRAY); ndVars.put(name, null); } + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * @param name the field to add + */ public void addList(String name){ - vars.put(name, Type.LIST); - listVars.put(name, null); + vars.put(name, PythonVariables.Type.LIST); + listVariables.put(name, null); } + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * @param name the field to add + */ public void addFile(String name){ - vars.put(name, Type.FILE); - fileVars.put(name, null); - } - public void addBool(String name, boolean value){ - vars.put(name, Type.BOOL); - boolVars.put(name, value); + vars.put(name, PythonVariables.Type.FILE); + fileVariables.put(name, null); } - public void addStr(String name, String value){ - vars.put(name, Type.STR); - strVars.put(name, value); + /** + * Add a boolean variable to + * the set of variables + * @param name the field to add + * @param value the value to add + */ + public void addBool(String name, boolean value) { + vars.put(name, PythonVariables.Type.BOOL); + boolVariables.put(name, value); } - public void addInt(String name, int value){ - vars.put(name, 
Type.INT); - intVars.put(name, (long)value); + /** + * Add a string variable to + * the set of variables + * @param name the field to add + * @param value the value to add + */ + public void addStr(String name, String value) { + vars.put(name, PythonVariables.Type.STR); + strVariables.put(name, value); } - public void addInt(String name, long value){ - vars.put(name, Type.INT); - intVars.put(name, value); + /** + * Add an int variable to + * the set of variables + * @param name the field to add + * @param value the value to add + */ + public void addInt(String name, int value) { + vars.put(name, PythonVariables.Type.INT); + intVariables.put(name, (long)value); } - public void addFloat(String name, double value){ - vars.put(name, Type.FLOAT); - floatVars.put(name, value); + /** + * Add a long variable to + * the set of variables + * @param name the field to add + * @param value the value to add + */ + public void addInt(String name, long value) { + vars.put(name, PythonVariables.Type.INT); + intVariables.put(name, value); } - public void addFloat(String name, float value){ - vars.put(name, Type.FLOAT); - floatVars.put(name, (double)value); + /** + * Add a double variable to + * the set of variables + * @param name the field to add + * @param value the value to add + */ + public void addFloat(String name, double value) { + vars.put(name, PythonVariables.Type.FLOAT); + floatVariables.put(name, value); } - public void addNDArray(String name, NumpyArray value){ - vars.put(name, Type.NDARRAY); + /** + * Add a float variable to + * the set of variables + * @param name the field to add + * @param value the value to add + */ + public void addFloat(String name, float value) { + vars.put(name, PythonVariables.Type.FLOAT); + floatVariables.put(name, (double)value); + } + + /** + * Add an NDArray variable to + * the set of variables + * @param name the field to add + * @param value the value to add + */ + public void addNDArray(String name, NumpyArray value) { + vars.put(name, PythonVariables.Type.NDARRAY); ndVars.put(name, value); } - public void addNDArray(String name, INDArray value){ - vars.put(name, Type.NDARRAY); + /** + * Add an NDArray variable to + * the set of variables + * @param name the field to add + * @param value the value to add + */ + public void addNDArray(String name, org.nd4j.linalg.api.ndarray.INDArray value) { + vars.put(name, PythonVariables.Type.NDARRAY); ndVars.put(name, new NumpyArray(value)); } - public void addList(String name, Object[] value){ - vars.put(name, Type.LIST); - listVars.put(name, value); + /** + * Add a list variable to + * the set of variables + * @param name the field to add + * @param value the value to add + */ + public void addList(String name, Object[] value) { + vars.put(name, PythonVariables.Type.LIST); + listVariables.put(name, value); } - public void addFile(String name, String value){ - vars.put(name, Type.FILE); - fileVars.put(name, value); + /** + * Add a file variable to + * the set of variables + * @param name the field to add + * @param value the value to add + */ + public void addFile(String name, String value) { + vars.put(name, PythonVariables.Type.FILE); + fileVariables.put(name, value); } + + /** + * Add a dict variable to + * the set of variables + * @param name the field to add + * @param value the value to add + */ + public void addDict(String name, 
java.util.Map<String, Object> value) { + vars.put(name, PythonVariables.Type.DICT); + dictVariables.put(name, value); + } /** * * @param name name of the variable * @param value new value for the variable - * @throws Exception */ public void setValue(String name, Object value) { Type type = vars.get(name); - if (type == Type.BOOL){ - boolVars.put(name, (Boolean)value); + if (type == PythonVariables.Type.BOOL){ + boolVariables.put(name, (Boolean)value); } - else if (type == Type.INT){ - if (value instanceof Long){ - intVars.put(name, ((Long)value)); - } - else if (value instanceof Integer){ - intVars.put(name, ((Integer)value).longValue()); - - } + else if (type == PythonVariables.Type.INT){ + Number number = (Number) value; + intVariables.put(name, number.longValue()); } - else if (type == Type.FLOAT){ - floatVars.put(name, (Double)value); + else if (type == PythonVariables.Type.FLOAT){ + Number number = (Number) value; + floatVariables.put(name, number.doubleValue()); } - else if (type == Type.NDARRAY){ + else if (type == PythonVariables.Type.NDARRAY){ if (value instanceof NumpyArray){ ndVars.put(name, (NumpyArray)value); } - else if (value instanceof INDArray){ - ndVars.put(name, new NumpyArray((INDArray) value)); + else if (value instanceof org.nd4j.linalg.api.ndarray.INDArray) { + ndVars.put(name, new NumpyArray((org.nd4j.linalg.api.ndarray.INDArray) value)); } else{ throw new RuntimeException("Unsupported type: " + value.getClass().toString()); } } - else if (type == Type.LIST){ - listVars.put(name, (Object[]) value); + else if (type == PythonVariables.Type.LIST) { + if (value instanceof java.util.List) { + value = ((java.util.List) value).toArray(); + listVariables.put(name, (Object[]) value); + } + else if(value instanceof org.json.JSONArray) { + org.json.JSONArray jsonArray = (org.json.JSONArray) value; + Object[] copyArr = new Object[jsonArray.length()]; + for(int i = 0; i < copyArr.length; i++) { + copyArr[i] = jsonArray.get(i); + } + listVariables.put(name, copyArr); + + } + else { + listVariables.put(name, (Object[]) value); + } } - else if (type == Type.FILE){ - fileVars.put(name, (String)value); + else if(type == PythonVariables.Type.DICT) { + dictVariables.put(name, (java.util.Map<String, Object>) value); + } + else if (type == PythonVariables.Type.FILE){ + fileVariables.put(name, (String)value); } else{ - strVars.put(name, (String)value); + strVariables.put(name, (String)value); } } - public Object getValue(String name){ + /** + * Do a general object lookup. + * The lookup happens relative to the {@link Type} + * of the variable, as declared when it was added. + * @param name the name of the variable to get + * @return the value for the variable with the given name + */ + public Object getValue(String name) { Type type = vars.get(name); - Map map = maps.get(type); + java.util.Map map = maps.get(type); return map.get(name); } + + /** + * Returns a boolean variable with the given name. 
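+ * <p>A minimal illustrative sketch, assuming a hypothetical variable "flag" was added first: + * <pre>{@code + * PythonVariables vars = new PythonVariables(); + * vars.addBool("flag", true); + * boolean flag = vars.getBooleanValue("flag"); // true + * }</pre> 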
+ * @param name the variable name to get the value for + * @return the retrieved boolean value + */ + public boolean getBooleanValue(String name) { + return boolVariables.get(name); + } + + /** + * + * @param name the variable name + * @return the dictionary value + */ + public java.util.Map<String, Object> getDictValue(String name) { + return dictVariables.get(name); + } + + /** + * + * @param name the variable name + * @return the string value + */ public String getStrValue(String name){ - return strVars.get(name); + return strVariables.get(name); } - public long getIntValue(String name){ - return intVars.get(name); + /** + * + * @param name the variable name + * @return the long value + */ + public Long getIntValue(String name){ + return intVariables.get(name); } - public double getFloatValue(String name){ - return floatVars.get(name); + /** + * + * @param name the variable name + * @return the float value + */ + public Double getFloatValue(String name){ + return floatVariables.get(name); } + /** + * + * @param name the variable name + * @return the numpy array value + */ public NumpyArray getNDArrayValue(String name){ return ndVars.get(name); } + /** + * + * @param name the variable name + * @return the list value as an object array + */ public Object[] getListValue(String name){ - return listVars.get(name); + return listVariables.get(name); } + /** + * + * @param name the variable name + * @return the value of the given file name + */ public String getFileValue(String name){ - return fileVars.get(name); + return fileVariables.get(name); } + /** + * Returns the type for the given variable name + * @param name the name of the variable to get the type for + * @return the type for the given variable + */ public Type getType(String name){ return vars.get(name); } + /** + * Get all the variables present as a string array + * @return the variable names for this variable set + */ public String[] getVariables() { String[] strArr = new String[vars.size()]; return vars.keySet().toArray(strArr); } - public Map getBoolVariables(){ - return boolVars; - } - public Map getStrVariables(){ - return strVars; - } - - public Map getIntVariables(){ - return intVars; - } - - public Map getFloatVariables(){ - return floatVars; - } - - public Map getNDArrayVariables(){ - return ndVars; - } - - public Map getListVariables(){ - return listVars; - } - - public Map getFileVariables(){ - return fileVars; - } - - public JSONArray toJSON(){ - JSONArray arr = new JSONArray(); + /** + * Returns this variable set as its JSON representation (an array of JSON objects) + * @return the json array output + */ + public org.json.JSONArray toJSON(){ + org.json.JSONArray arr = new org.json.JSONArray(); for (String varName: getVariables()){ - JSONObject var = new JSONObject(); + org.json.JSONObject var = new org.json.JSONObject(); var.put("name", varName); String varType = getType(varName).toString(); var.put("type", varType); - arr.add(var); + arr.put(var); } return arr; } - public static PythonVariables fromJSON(JSONArray jsonArray){ + /** + * Create a schema from a map. + * This is an empty PythonVariables + * that just contains names and types with no values 
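+ * <p>A minimal illustrative sketch of the expected input (keys are variable + * names, values are {@link Type} names; the entries here are hypothetical): + * <pre>{@code + * java.util.Map<String, String> types = new java.util.LinkedHashMap<>(); + * types.put("x", "INT"); + * types.put("arr", "NDARRAY"); + * PythonVariables schema = PythonVariables.schemaFromMap(types); + * }</pre> 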
+ * @param inputTypes the input types to convert + * @return the schema from the given map + */ + public static PythonVariables schemaFromMap(java.util.Map<String, String> inputTypes) { + PythonVariables ret = new PythonVariables(); + for(java.util.Map.Entry<String, String> entry : inputTypes.entrySet()) { + ret.add(entry.getKey(), PythonVariables.Type.valueOf(entry.getValue())); + } + + return ret; + } + + /** + * Get the python variable state relative to the + * input json array + * @param jsonArray the input json array + * @return the python variables based on the input json array + */ + public static PythonVariables fromJSON(org.json.JSONArray jsonArray){ PythonVariables pyvars = new PythonVariables(); - for (int i=0; i<jsonArray.size(); i++){ - JSONObject input = (JSONObject) jsonArray.get(i); - pyvars.add((String)input.get("name"), Type.valueOf((String)input.get("type"))); - } + for (int i=0; i<jsonArray.length(); i++){ + org.json.JSONObject input = jsonArray.getJSONObject(i); + pyvars.add(input.getString("name"), PythonVariables.Type.valueOf(input.getString("type"))); + } return pyvars; } } diff --git a/datavec/datavec-python/src/main/resources/pythonexec/outputcode.py b/datavec/datavec-python/src/main/resources/pythonexec/outputcode.py new file mode 100644 --- /dev/null +++ b/datavec/datavec-python/src/main/resources/pythonexec/outputcode.py +def __is_numpy_array(x): + return str(type(x))== "<class 'numpy.ndarray'>" + +def maybe_serialize_ndarray_metadata(x): + return serialize_ndarray_metadata(x) if __is_numpy_array(x) else x + + +def serialize_ndarray_metadata(x): + return {"address": x.__array_interface__['data'][0], + "shape": x.shape, + "strides": x.strides, + "dtype": str(x.dtype), + "_is_numpy_array": True} if __is_numpy_array(x) else x + + +def is_json_ready(key, value): + return key != 'f2' and not inspect.ismodule(value) \ + and not hasattr(value, '__call__') + diff --git a/datavec/datavec-python/src/main/resources/pythonexec/patch0.py b/datavec/datavec-python/src/main/resources/pythonexec/patch0.py new file mode 100644 index 000000000..d2ed3d5e5 --- /dev/null +++ b/datavec/datavec-python/src/main/resources/pythonexec/patch0.py @@ -0,0 +1,202 @@ +#patch + +"""Implementation of __array_function__ overrides from NEP-18.""" +import collections +import functools +import os + +from numpy.core._multiarray_umath import ( + add_docstring, implement_array_function, _get_implementing_args) +from numpy.compat._inspect import getargspec + + +ENABLE_ARRAY_FUNCTION = bool( + int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0))) + + +ARRAY_FUNCTION_ENABLED = ENABLE_ARRAY_FUNCTION # backward compat + + +_add_docstring = add_docstring + + +def add_docstring(*args): + try: + _add_docstring(*args) + except: + pass + + +add_docstring( + implement_array_function, + """ + Implement a function with checks for __array_function__ overrides. + + All arguments are required, and can only be passed by position. + + Arguments + --------- + implementation : function + Function that implements the operation on NumPy array without + overrides when called like ``implementation(*args, **kwargs)``. + public_api : function + Function exposed by NumPy's public API originally called like + ``public_api(*args, **kwargs)`` on which arguments are now being + checked. + relevant_args : iterable + Iterable of arguments to check for __array_function__ methods. + args : tuple + Arbitrary positional arguments originally passed into ``public_api``. + kwargs : dict + Arbitrary keyword arguments originally passed into ``public_api``. + + Returns + ------- + Result from calling ``implementation()`` or an ``__array_function__`` + method, as appropriate. + + Raises + ------ + TypeError : if no implementation is found. + """) + + +# exposed for testing purposes; used internally by implement_array_function +add_docstring( + _get_implementing_args, + """ + Collect arguments on which to call __array_function__. + + Parameters + ---------- + relevant_args : iterable of array-like + Iterable of possibly array-like arguments to check for + __array_function__ methods. 
+ + Returns + ------- + Sequence of arguments with __array_function__ methods, in the order in + which they should be called. + """) + + +ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults') + + +def verify_matching_signatures(implementation, dispatcher): + """Verify that a dispatcher function has the right signature.""" + implementation_spec = ArgSpec(*getargspec(implementation)) + dispatcher_spec = ArgSpec(*getargspec(dispatcher)) + + if (implementation_spec.args != dispatcher_spec.args or + implementation_spec.varargs != dispatcher_spec.varargs or + implementation_spec.keywords != dispatcher_spec.keywords or + (bool(implementation_spec.defaults) != + bool(dispatcher_spec.defaults)) or + (implementation_spec.defaults is not None and + len(implementation_spec.defaults) != + len(dispatcher_spec.defaults))): + raise RuntimeError('implementation and dispatcher for %s have ' + 'different function signatures' % implementation) + + if implementation_spec.defaults is not None: + if dispatcher_spec.defaults != (None,) * len(dispatcher_spec.defaults): + raise RuntimeError('dispatcher functions can only use None for ' + 'default argument values') + + +def set_module(module): + """Decorator for overriding __module__ on a function or class. + + Example usage:: + + @set_module('numpy') + def example(): + pass + + assert example.__module__ == 'numpy' + """ + def decorator(func): + if module is not None: + func.__module__ = module + return func + return decorator + + +def array_function_dispatch(dispatcher, module=None, verify=True, + docs_from_dispatcher=False): + """Decorator for adding dispatch with the __array_function__ protocol. + + See NEP-18 for example usage. + + Parameters + ---------- + dispatcher : callable + Function that when called like ``dispatcher(*args, **kwargs)`` with + arguments from the NumPy function call returns an iterable of + array-like arguments to check for ``__array_function__``. + module : str, optional + __module__ attribute to set on new function, e.g., ``module='numpy'``. + By default, module is copied from the decorated function. + verify : bool, optional + If True, verify that the signatures of the dispatcher and decorated + function match exactly: all required and optional arguments + should appear in order with the same names, but the default values for + all optional arguments should be ``None``. Only disable verification + if the dispatcher's signature needs to deviate for some particular + reason, e.g., because the function has a signature like + ``func(*args, **kwargs)``. + docs_from_dispatcher : bool, optional + If True, copy docs from the dispatcher function onto the dispatched + function, rather than from the implementation. This is useful for + functions defined in C, which otherwise don't have docstrings. + + Returns + ------- + Function suitable for decorating the implementation of a NumPy function. 
+ """ + + if not ENABLE_ARRAY_FUNCTION: + # __array_function__ requires an explicit opt-in for now + def decorator(implementation): + if module is not None: + implementation.__module__ = module + if docs_from_dispatcher: + add_docstring(implementation, dispatcher.__doc__) + return implementation + return decorator + + def decorator(implementation): + if verify: + verify_matching_signatures(implementation, dispatcher) + + if docs_from_dispatcher: + add_docstring(implementation, dispatcher.__doc__) + + @functools.wraps(implementation) + def public_api(*args, **kwargs): + relevant_args = dispatcher(*args, **kwargs) + return implement_array_function( + implementation, public_api, relevant_args, args, kwargs) + + if module is not None: + public_api.__module__ = module + + # TODO: remove this when we drop Python 2 support (functools.wraps) + # adds __wrapped__ automatically in later versions) + public_api.__wrapped__ = implementation + + return public_api + + return decorator + + +def array_function_from_dispatcher( + implementation, module=None, verify=True, docs_from_dispatcher=True): + """Like array_function_dispatcher, but with function arguments flipped.""" + + def decorator(dispatcher): + return array_function_dispatch( + dispatcher, module, verify=verify, + docs_from_dispatcher=docs_from_dispatcher)(implementation) + return decorator diff --git a/datavec/datavec-python/src/main/resources/pythonexec/patch1.py b/datavec/datavec-python/src/main/resources/pythonexec/patch1.py new file mode 100644 index 000000000..890852bbc --- /dev/null +++ b/datavec/datavec-python/src/main/resources/pythonexec/patch1.py @@ -0,0 +1,172 @@ +#patch 1 + +""" +======================== +Random Number Generation +======================== + +==================== ========================================================= +Utility functions +============================================================================== +random_sample Uniformly distributed floats over ``[0, 1)``. +random Alias for `random_sample`. +bytes Uniformly distributed random bytes. +random_integers Uniformly distributed integers in a given range. +permutation Randomly permute a sequence / generate a random sequence. +shuffle Randomly permute a sequence in place. +seed Seed the random number generator. +choice Random sample from 1-D array. + +==================== ========================================================= + +==================== ========================================================= +Compatibility functions +============================================================================== +rand Uniformly distributed values. +randn Normally distributed values. +ranf Uniformly distributed floating point numbers. +randint Uniformly distributed integers in a given range. +==================== ========================================================= + +==================== ========================================================= +Univariate distributions +============================================================================== +beta Beta distribution over ``[0, 1]``. +binomial Binomial distribution. +chisquare :math:`\\chi^2` distribution. +exponential Exponential distribution. +f F (Fisher-Snedecor) distribution. +gamma Gamma distribution. +geometric Geometric distribution. +gumbel Gumbel distribution. +hypergeometric Hypergeometric distribution. +laplace Laplace distribution. +logistic Logistic distribution. +lognormal Log-normal distribution. +logseries Logarithmic series distribution. 
+negative_binomial Negative binomial distribution. +noncentral_chisquare Non-central chi-square distribution. +noncentral_f Non-central F distribution. +normal Normal / Gaussian distribution. +pareto Pareto distribution. +poisson Poisson distribution. +power Power distribution. +rayleigh Rayleigh distribution. +triangular Triangular distribution. +uniform Uniform distribution. +vonmises Von Mises circular distribution. +wald Wald (inverse Gaussian) distribution. +weibull Weibull distribution. +zipf Zipf's distribution over ranked data. +==================== ========================================================= + +==================== ========================================================= +Multivariate distributions +============================================================================== +dirichlet Multivariate generalization of Beta distribution. +multinomial Multivariate generalization of the binomial distribution. +multivariate_normal Multivariate generalization of the normal distribution. +==================== ========================================================= + +==================== ========================================================= +Standard distributions +============================================================================== +standard_cauchy Standard Cauchy-Lorentz distribution. +standard_exponential Standard exponential distribution. +standard_gamma Standard Gamma distribution. +standard_normal Standard normal distribution. +standard_t Standard Student's t-distribution. +==================== ========================================================= + +==================== ========================================================= +Internal functions +============================================================================== +get_state Get tuple representing internal state of generator. +set_state Set state of generator. +==================== ========================================================= + +""" +from __future__ import division, absolute_import, print_function + +import warnings + +__all__ = [ + 'beta', + 'binomial', + 'bytes', + 'chisquare', + 'choice', + 'dirichlet', + 'exponential', + 'f', + 'gamma', + 'geometric', + 'get_state', + 'gumbel', + 'hypergeometric', + 'laplace', + 'logistic', + 'lognormal', + 'logseries', + 'multinomial', + 'multivariate_normal', + 'negative_binomial', + 'noncentral_chisquare', + 'noncentral_f', + 'normal', + 'pareto', + 'permutation', + 'poisson', + 'power', + 'rand', + 'randint', + 'randn', + 'random_integers', + 'random_sample', + 'rayleigh', + 'seed', + 'set_state', + 'shuffle', + 'standard_cauchy', + 'standard_exponential', + 'standard_gamma', + 'standard_normal', + 'standard_t', + 'triangular', + 'uniform', + 'vonmises', + 'wald', + 'weibull', + 'zipf' +] + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="numpy.ndarray size changed") + try: + from .mtrand import * + # Some aliases: + ranf = random = sample = random_sample + __all__.extend(['ranf', 'random', 'sample']) + except: + warnings.warn("numpy.random is not available when using multiple interpreters!") + + + +def __RandomState_ctor(): + """Return a RandomState instance. + + This function exists solely to assist (un)pickling. + + Note that the state of the RandomState returned here is irrelevant, as this function's + entire purpose is to return a newly allocated RandomState whose state pickle can set. + Consequently the RandomState returned by this function is a freshly allocated copy + with a seed=0. 
+ + See https://github.com/numpy/numpy/issues/4763 for a detailed discussion + + """ + return RandomState(seed=0) + +from numpy._pytesttester import PytestTester +test = PytestTester(__name__) +del PytestTester diff --git a/datavec/datavec-python/src/main/resources/pythonexec/pythonexec.py b/datavec/datavec-python/src/main/resources/pythonexec/pythonexec.py new file mode 100644 index 000000000..dbdceff0e --- /dev/null +++ b/datavec/datavec-python/src/main/resources/pythonexec/pythonexec.py @@ -0,0 +1,20 @@ +import sys +import traceback +import json +import inspect + + +try: + + pass + sys.stdout.flush() + sys.stderr.flush() +except Exception as ex: + try: + exc_info = sys.exc_info() + finally: + print(ex) + traceback.print_exception(*exc_info) + sys.stdout.flush() + sys.stderr.flush() + diff --git a/datavec/datavec-python/src/main/resources/pythonexec/serialize_array.py b/datavec/datavec-python/src/main/resources/pythonexec/serialize_array.py new file mode 100644 index 000000000..ac6f5b1c1 --- /dev/null +++ b/datavec/datavec-python/src/main/resources/pythonexec/serialize_array.py @@ -0,0 +1,50 @@ +def __is_numpy_array(x): + return str(type(x))== "<class 'numpy.ndarray'>" + +def __maybe_serialize_ndarray_metadata(x): + return __serialize_ndarray_metadata(x) if __is_numpy_array(x) else x + + +def __serialize_ndarray_metadata(x): + return {"address": x.__array_interface__['data'][0], + "shape": x.shape, + "strides": x.strides, + "dtype": str(x.dtype), + "_is_numpy_array": True} if __is_numpy_array(x) else x + + +def __serialize_list(x): + import json + return json.dumps(__recursive_serialize_list(x)) + + +def __serialize_dict(x): + import json + return json.dumps(__recursive_serialize_dict(x)) + +def __recursive_serialize_list(x): + out = [] + for i in x: + if __is_numpy_array(i): + out.append(__serialize_ndarray_metadata(i)) + elif isinstance(i, (list, tuple)): + out.append(__recursive_serialize_list(i)) + elif isinstance(i, dict): + out.append(__recursive_serialize_dict(i)) + else: + out.append(i) + return out + +def __recursive_serialize_dict(x): + out = {} + for k in x: + v = x[k] + if __is_numpy_array(v): + out[k] = __serialize_ndarray_metadata(v) + elif isinstance(v, (list, tuple)): + out[k] = __recursive_serialize_list(v) + elif isinstance(v, dict): + out[k] = __recursive_serialize_dict(v) + else: + out[k] = v + return out \ No newline at end of file diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutionSandbox.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutionSandbox.java new file mode 100644 index 000000000..435babf7c --- /dev/null +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutionSandbox.java @@ -0,0 +1,75 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.datavec.python; + + +import org.junit.Assert; +import org.junit.Test; + +@javax.annotation.concurrent.NotThreadSafe +public class TestPythonExecutionSandbox { + + @Test + public void testInt(){ + PythonExecutioner.setInterpreter("interp1"); + PythonExecutioner.exec("a = 1"); + PythonExecutioner.setInterpreter("interp2"); + PythonExecutioner.exec("a = 2"); + PythonExecutioner.setInterpreter("interp3"); + PythonExecutioner.exec("a = 3"); + + + PythonExecutioner.setInterpreter("interp1"); + Assert.assertEquals(1, PythonExecutioner.evalInteger("a")); + + PythonExecutioner.setInterpreter("interp2"); + Assert.assertEquals(2, PythonExecutioner.evalInteger("a")); + + PythonExecutioner.setInterpreter("interp3"); + Assert.assertEquals(3, PythonExecutioner.evalInteger("a")); + } + + @Test + public void testNDArray(){ + PythonExecutioner.setInterpreter("main"); + PythonExecutioner.exec("import numpy as np"); + PythonExecutioner.exec("a = np.zeros(5)"); + + PythonExecutioner.setInterpreter("main"); + //PythonExecutioner.exec("import numpy as np"); + PythonExecutioner.exec("a = np.zeros(5)"); + + PythonExecutioner.setInterpreter("main"); + PythonExecutioner.exec("a += 2"); + + PythonExecutioner.setInterpreter("main"); + PythonExecutioner.exec("a += 3"); + + PythonExecutioner.setInterpreter("main"); + //PythonExecutioner.exec("import numpy as np"); + // PythonExecutioner.exec("a = np.zeros(5)"); + + PythonExecutioner.setInterpreter("main"); + Assert.assertEquals(25, PythonExecutioner.evalNdArray("a").getNd4jArray().sum().getDouble(), 1e-5); + } + + @Test + public void testNumpyRandom(){ + PythonExecutioner.setInterpreter("main"); + PythonExecutioner.exec("import numpy as np; print(np.random.randint(5))"); + } +} diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java index 791950043..c8e67febb 100644 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java @@ -15,17 +15,25 @@ ******************************************************************************/ package org.datavec.python; -import org.junit.Ignore; +import org.junit.Assert; import org.junit.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; + import static org.junit.Assert.assertEquals; -@Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771") + +@javax.annotation.concurrent.NotThreadSafe public class TestPythonExecutioner { - @Test(timeout = 60000L) + + @org.junit.Test + public void testPythonSysVersion() { + PythonExecutioner.exec("import sys; print(sys.version)"); + } + + @Test public void testStr() throws Exception{ PythonVariables pyInputs = new PythonVariables(); @@ -47,7 +55,7 @@ public class TestPythonExecutioner { assertEquals("Hello World", z); } - @Test(timeout = 60000L) + @Test public void testInt()throws Exception{ PythonVariables pyInputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables(); @@ -55,7 +63,7 @@ public class TestPythonExecutioner { pyInputs.addInt("x", 10); pyInputs.addInt("y", 20); - String code = "z = x + y"; + String code = "z = x + y"; pyOutputs.addInt("z"); @@ -64,11 +72,11 @@ public class 
TestPythonExecutioner { long z = pyOutputs.getIntValue("z"); - assertEquals(30, z); + Assert.assertEquals(30, z); } - @Test(timeout = 60000L) + @Test public void testList() throws Exception{ PythonVariables pyInputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables(); @@ -88,18 +96,35 @@ public class TestPythonExecutioner { Object[] z = pyOutputs.getListValue("z"); - assertEquals(z.length, x.length + y.length); + Assert.assertEquals(z.length, x.length + y.length); + + for (int i = 0; i < x.length; i++) { + if(x[i] instanceof Number) { + Number xNum = (Number) x[i]; + Number zNum = (Number) z[i]; + Assert.assertEquals(xNum.intValue(), zNum.intValue()); + } + else { + Assert.assertEquals(x[i], z[i]); + } + } - for (int i=0; i < x.length; i++){ - assertEquals(x[i], z[i]); - } - for (int i=0; i < y.length; i++){ - assertEquals(y[i], z[x.length + i]); - } + for (int i = 0; i < y.length; i++) { + if(y[i] instanceof Number) { + Number yNum = (Number) y[i]; + Number zNum = (Number) z[x.length + i]; + Assert.assertEquals(yNum.intValue(), zNum.intValue()); + } + else { + Assert.assertEquals(y[i], z[x.length + i]); + } + } 
* Some fixes * Merge master * bitcast fix Signed-off-by: raver119 * Bitcast fixed --- .../DifferentialFunctionFactory.java | 30 ++ .../nd4j/autodiff/samediff/ops/SDBitwise.java | 12 + .../nd4j/autodiff/samediff/ops/SDImage.java | 63 +++- .../nd4j/autodiff/samediff/ops/SDMath.java | 52 ++++ .../org/nd4j/autodiff/samediff/ops/SDNN.java | 32 ++ .../converters/ImportClassMapping.java | 14 +- .../linalg/api/ops/custom/AdjustContrast.java | 21 +- .../api/ops/custom/AdjustContrastV2.java | 20 +- .../nd4j/linalg/api/ops/custom/AdjustHue.java | 69 +++++ .../api/ops/custom/AdjustSaturation.java | 68 +++++ .../api/ops/custom/BaseAdjustContrast.java | 22 +- .../nd4j/linalg/api/ops/custom/BetaInc.java | 67 +++++ .../nd4j/linalg/api/ops/custom/BitCast.java | 24 +- .../api/ops/custom/CompareAndBitpack.java | 21 +- .../linalg/api/ops/custom/DivideNoNan.java | 21 +- .../api/ops/custom/DrawBoundingBoxes.java | 23 +- .../FakeQuantWithMinMaxVarsPerChannel.java | 23 +- .../linalg/api/ops/custom/FusedBatchNorm.java | 64 ++++ .../linalg/api/ops/custom/MatrixBandPart.java | 65 ++++ .../nd4j/linalg/api/ops/custom/Polygamma.java | 66 ++++ .../linalg/api/ops/custom/RandomCrop.java | 61 ++++ .../org/nd4j/linalg/api/ops/custom/Roll.java | 64 ++++ .../linalg/api/ops/custom/ToggleBits.java | 64 ++++ .../api/ops/impl/image/NonMaxSuppression.java | 9 +- .../layers/convolution/MaxPoolWithArgmax.java | 284 ++++++++++++++++++ .../impl/layers/convolution/MaxPooling2D.java | 4 +- .../ops/impl/transforms/clip/ClipByValue.java | 2 +- .../impl/transforms/custom/RShiftBits.java | 8 +- .../ops/impl/transforms/custom/ShiftBits.java | 8 +- .../transforms/custom/UniqueWithCounts.java | 4 +- .../pairwise/arithmetic/CopyOp.java | 4 +- .../transforms/pairwise/arithmetic/ModOp.java | 2 +- .../impl/transforms/pairwise/bool/Not.java | 8 +- .../api/ops/impl/transforms/strict/GELU.java | 13 - .../random/custom/DistributionUniform.java | 14 +- .../linalg/api/ops/random/impl/DropOut.java | 12 - .../opvalidation/LayerOpValidation.java | 29 ++ .../TFGraphs/TFGraphTestAllSameDiff.java | 30 +- .../nd4j/linalg/custom/CustomOpsTests.java | 234 +++++++++++++++ 39 files changed, 1545 insertions(+), 86 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustHue.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustSaturation.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BetaInc.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/MatrixBandPart.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Polygamma.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RandomCrop.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Roll.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ToggleBits.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index abea31459..7f59d24e4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -2616,6 +2616,36 @@ public class DifferentialFunctionFactory { return new FakeQuantWithMinMaxVarsPerChannel(sameDiff,x,min,max).outputVariable(); } + public SDVariable betainc( SDVariable a, SDVariable b, SDVariable x) { + return new BetaInc(sameDiff, a, b, x).outputVariable(); + } + + public SDVariable[] fusedBatchNorm(SDVariable x, SDVariable scale, SDVariable offset, + SDVariable dataFormat, SDVariable isTraining) { + return new FusedBatchNorm(sameDiff,x,scale,offset,dataFormat,isTraining).outputVariables(); + } + + public SDVariable matrixBandPart(SDVariable input, SDVariable minLower, SDVariable maxUpper) { + return new MatrixBandPart(sameDiff,input,minLower,maxUpper).outputVariable(); + } + + public SDVariable[] maxPoolWithArgmaxs(SDVariable x, Pooling2DConfig pooling2DConfig) { + return new MaxPoolWithArgmax(sameDiff, x, pooling2DConfig).outputVariables(); + } + + public SDVariable polygamma(SDVariable n, SDVariable x) { + return new Polygamma(sameDiff, n,x).outputVariable(); + } + + public SDVariable roll(SDVariable input, SDVariable shift) { + return new Roll(sameDiff, input, shift).outputVariable(); + } + + public SDVariable toggleBits(SDVariable x) { + return new ToggleBits(sameDiff, x).outputVariable(); + } + + public String toString() { return "DifferentialFunctionFactory{methodNames=" + methodNames + "}"; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java index 0857b2b42..a255afbc3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java @@ -202,4 +202,16 @@ public class SDBitwise extends SDOps { SDVariable ret = f().bitwiseXor(x, y); return updateVariableNameAndReference(ret, name); } + + /** + * Flip bits + * + * @param name Name of the output variable + * @param x input array + * @return array after flipping each input bit + */ + public SDVariable toggleBits(String name, SDVariable x) { + SDVariable res = f().toggleBits(x); + return updateVariableNameAndReference(res, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java index f7166ab5e..bf71a665e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java @@ -3,6 +3,10 @@ package org.nd4j.autodiff.samediff.ops; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ops.custom.AdjustContrast; +import org.nd4j.linalg.api.ops.custom.AdjustHue; 
+import org.nd4j.linalg.api.ops.custom.AdjustSaturation; +import org.nd4j.linalg.api.ops.custom.RandomCrop; import org.nd4j.linalg.api.ops.impl.image.CropAndResize; import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches; import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression; @@ -52,10 +56,67 @@ public class SDImage extends SDOps { return updateVariableNameAndReference(out, name); } - + /** + * Greedily selects a subset of bounding boxes in descending order of score + * @param name Might be null. Name for the output variable + * @param boxes 2D array of shape [num_boxes,4] + * @param scores vector of shape [num_boxes] + * @param maxOutSize scalar representing the maximum number of boxes to be selected + * @param iouThreshold float - threshold for deciding whether boxes overlap too much with respect to IOU + * @param scoreThreshold float - threshold for deciding when to remove boxes based on score + * @return vector of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size + */ public SDVariable nonMaxSuppression(String name, @NonNull SDVariable boxes, @NonNull SDVariable scores, @NonNull SDVariable maxOutSize, @NonNull SDVariable iouThreshold, @NonNull SDVariable scoreThreshold){ SDVariable out = new NonMaxSuppression(sd, boxes, scores, maxOutSize, iouThreshold, scoreThreshold).outputVariable(); return updateVariableNameAndReference(out, name); } + + /** + * Adjusts contrast of RGB or grayscale images. + * @param name name for the output variable + * @param in images to adjust. 3D shape or higher. + * @param factor float multiplier for adjusting contrast. + * @return Contrast-adjusted image + */ + public SDVariable adjustContrast(String name, @NonNull SDVariable in, @NonNull SDVariable factor) { + SDVariable out = new AdjustContrast(sd, in, factor).outputVariable(); + return updateVariableNameAndReference(out, name); + } + + /** + * Adjust saturation of RGB images + * @param name name for the output variable + * @param in RGB image as 3D array + * @param factor factor for saturation + * @return adjusted image + */ + public SDVariable adjustSaturation(String name, @NonNull SDVariable in, @NonNull SDVariable factor) { + SDVariable out = new AdjustSaturation(sd, in, factor).outputVariable(); + return updateVariableNameAndReference(out, name); + } + + /** + * Adjust hue of RGB image + * @param name name for the output variable + * @param in RGB image as 3D array + * @param delta value to add to hue channel + * @return adjusted image + */ + public SDVariable adjustHue(String name, @NonNull SDVariable in, @NonNull SDVariable delta) { + SDVariable out = new AdjustHue(sd, in, delta).outputVariable(); + return updateVariableNameAndReference(out, name); + } + + /** + * Randomly crops image + * @param name name for the output variable + * @param input input array + * @param shape shape for crop + * @return cropped array + */ + public SDVariable randomCrop(String name, @NonNull SDVariable input, @NonNull SDVariable shape) { + SDVariable out = new RandomCrop(sd, input, shape).outputVariable(); + return updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java index 10fc0b44a..0d0da022e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java @@ -2496,5 +2496,57 @@ public class SDMath extends SDOps { return updateVariableNameAndReference(res, name); } + /** + * Compute the regularized incomplete beta integral + * + * @param name Name of the output variable + * @param a input array + * @param b input array + * @param x input array + * @return array + */ + public SDVariable betainc(String name, SDVariable a, SDVariable b, SDVariable x) { + SDVariable res = f().betainc(a, b, x); + return updateVariableNameAndReference(res, name); + } + /** + * Copy a tensor setting everything outside a central band in each innermost matrix to zero. + * + * @param name Name of the output variable + * @param input Rank k array + * @param minLower Number of subdiagonals to keep. + * @param maxUpper Number of superdiagonals to keep. + * @return Rank k array of the same shape as input. + */ + public SDVariable matrixBandPart(String name, SDVariable input, SDVariable minLower, SDVariable maxUpper) { + SDVariable res = f().matrixBandPart(input,minLower,maxUpper); + return updateVariableNameAndReference(res, name); + } + + /** + * Polygamma function + * + * @param name Name of the output variable + * @param n array + * @param x array + * @return array + */ + public SDVariable polygamma(String name, SDVariable n, SDVariable x) { + SDVariable res = f().polygamma(n,x); + return updateVariableNameAndReference(res, name); + } + + /** + * Rolls the elements of input + * + * @param name Name of the output variable + * @param input array + * @param shift number of places to shift elements + * @return array + */ + public SDVariable roll(String name, SDVariable input, SDVariable shift) { + SDVariable res = f().roll(input,shift); + return updateVariableNameAndReference(res, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java index 7b1cc5768..63aab3f33 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java @@ -19,6 +19,7 @@ package org.nd4j.autodiff.samediff.ops; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; import org.nd4j.linalg.api.ops.impl.transforms.Pad; import org.nd4j.linalg.factory.Nd4j; @@ -1032,4 +1033,35 @@ public class SDNN extends SDOps { ); } } + + /** + * Max pooling on the input and outputs both max values and indices + * + * @param names Names for the two output variables (max values and argmax indices) + * @param x input array + * @param pooling2DConfig the pooling configuration + * @return output array and argmax array + */ + public SDVariable[] maxPoolWithArgmax(String[] names, SDVariable x, Pooling2DConfig pooling2DConfig) { + SDVariable[] res = f().maxPoolWithArgmaxs(x, pooling2DConfig); + return sd.updateVariableNamesAndReferences(res, names); + } + + /** + * Batch normalization + * + * @param names Names for the three output variables + * @param x 4D array + * @param scale vector for scaling factor of normalized x + * @param offset vector to shift to the normalized x + * @param dataFormat integer scalar - data format + * @param isTraining boolean scalar - is training mode + * @return y: 4D array + * batch_mean: vector + * batch_var: vector + */ + public SDVariable[] fusedBatchNorm(String[] names, SDVariable x, SDVariable scale, SDVariable offset, + SDVariable dataFormat, SDVariable isTraining) { + SDVariable[] res = f().fusedBatchNorm(x,scale,offset,dataFormat,isTraining); + return sd.updateVariableNamesAndReferences(res, names); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index b1de641b4..5b60ac0b4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -46,7 +46,6 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.custom.BarnesEdgeForces.class, org.nd4j.linalg.api.ops.custom.BarnesHutGains.class, org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize.class, - org.nd4j.linalg.api.ops.custom.DrawBoundingBoxes.class, org.nd4j.linalg.api.ops.custom.KnnMinDistance.class, org.nd4j.linalg.api.ops.custom.SpTreeCell.class, org.nd4j.linalg.api.ops.custom.Flatten.class, @@ -122,6 +121,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalizationDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D.class, org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling3D.class, + org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3DDerivative.class, @@ -589,7 +589,17 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.custom.BitCast.class, org.nd4j.linalg.api.ops.custom.CompareAndBitpack.class, org.nd4j.linalg.api.ops.custom.DivideNoNan.class, - org.nd4j.linalg.api.ops.custom.FakeQuantWithMinMaxVarsPerChannel.class + org.nd4j.linalg.api.ops.custom.DrawBoundingBoxes.class, + org.nd4j.linalg.api.ops.custom.FakeQuantWithMinMaxVarsPerChannel.class, + org.nd4j.linalg.api.ops.custom.AdjustSaturation.class, + org.nd4j.linalg.api.ops.custom.AdjustHue.class, + org.nd4j.linalg.api.ops.custom.FusedBatchNorm.class, + org.nd4j.linalg.api.ops.custom.BetaInc.class, + org.nd4j.linalg.api.ops.custom.MatrixBandPart.class, + org.nd4j.linalg.api.ops.custom.Polygamma.class, + org.nd4j.linalg.api.ops.custom.RandomCrop.class, + org.nd4j.linalg.api.ops.custom.Roll.class, + org.nd4j.linalg.api.ops.custom.ToggleBits.class ); static { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java index 2d0ac235f..68daf6788 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java @@ -1,5 +1,22 @@ + +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.linalg.api.ops.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -14,11 +31,11 @@ public class AdjustContrast extends BaseAdjustContrast { public AdjustContrast() {super();} - public AdjustContrast(INDArray in, double factor, INDArray out) { + public AdjustContrast(@NonNull INDArray in, double factor, INDArray out) { super(in, factor, out); } - public AdjustContrast(SameDiff sameDiff, SDVariable in, SDVariable factor) { + public AdjustContrast(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable factor) { super(sameDiff,new SDVariable[]{in,factor}); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java index 9ebb3ea6f..71c752485 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java @@ -1,5 +1,21 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.linalg.api.ops.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -14,11 +30,11 @@ public class AdjustContrastV2 extends BaseAdjustContrast { public AdjustContrastV2() {super();} - public AdjustContrastV2(INDArray in, double factor, INDArray out) { + public AdjustContrastV2(@NonNull INDArray in, double factor, INDArray out) { super(in, factor, out); } - public AdjustContrastV2(SameDiff sameDiff, SDVariable in, SDVariable factor) { + public AdjustContrastV2(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable factor) { super( sameDiff,new SDVariable[]{in,factor}); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustHue.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustHue.java new file mode 100644 index 000000000..e1a5b0a7a --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustHue.java @@ -0,0 +1,69 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.linalg.api.ops.custom;
+
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+import java.util.Collections;
+import java.util.List;
+
+public class AdjustHue extends DynamicCustomOp {
+    public AdjustHue() {}
+
+    public AdjustHue(@NonNull INDArray in, double delta, INDArray out) {
+        this(in, delta);
+        if (out != null) {
+            outputArguments.add(out);
+        }
+    }
+
+    public AdjustHue(@NonNull INDArray in, double delta) {
+        Preconditions.checkArgument(in.rank() >= 3,
+                "AdjustHue: op expects rank of input array to be >= 3, but got %s instead", in.rank());
+        Preconditions.checkArgument(-1.0 <= delta && delta <= 1.0, "AdjustHue: parameter delta must be within [-1, 1] interval," +
+                " but got %s instead", delta);
+        inputArguments.add(in);
+
+        addTArgument(delta);
+    }
+
+    public AdjustHue(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable factor) {
+        super(sameDiff, new SDVariable[]{in, factor});
+    }
+
+    @Override
+    public String opName() {
+        return "adjust_hue";
+    }
+
+    @Override
+    public String tensorflowName() {
+        return "AdjustHue";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+        int n = args().length;
+        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
+        return Collections.singletonList(inputDataTypes.get(0));
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustSaturation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustSaturation.java
new file mode 100644
index 000000000..e9f1f90c8
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustSaturation.java
@@ -0,0 +1,68 @@
+/* ******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
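A matching sketch for AdjustHue, assuming TF-style hue rotation where delta is a fraction of a full turn in [-1, 1] (compare the testAdjustHue case added to CustomOpsTests below):

    INDArray rgb = Nd4j.createFromArray(new double[]{0, 100, 56}).reshape(1, 1, 3); // single RGB pixel
    INDArray shifted = Nd4j.exec(new AdjustHue(rgb, 0.5))[0]; // rotate hue by half a turn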
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.linalg.api.ops.custom;
+
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+import java.util.Collections;
+import java.util.List;
+
+public class AdjustSaturation extends DynamicCustomOp {
+
+    public AdjustSaturation() {}
+
+    public AdjustSaturation(@NonNull INDArray in, double factor, INDArray out) {
+        this(in, factor);
+        if (out != null) {
+            outputArguments.add(out);
+        }
+    }
+
+    public AdjustSaturation(@NonNull INDArray in, double factor) {
+        Preconditions.checkArgument(in.rank() >= 3,
+                "AdjustSaturation: op expects rank of input array to be >= 3, but got %s instead", in.rank());
+        inputArguments.add(in);
+
+        addTArgument(factor);
+    }
+
+    public AdjustSaturation(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable factor) {
+        super(sameDiff, new SDVariable[]{in, factor});
+    }
+
+    @Override
+    public String opName() {
+        return "adjust_saturation";
+    }
+
+    @Override
+    public String tensorflowName() {
+        return "AdjustSaturation";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+        int n = args().length;
+        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
+        return Collections.singletonList(inputDataTypes.get(0));
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java
index 25cddd741..a5e296043 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java
@@ -1,5 +1,21 @@
+/* ******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
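And likewise for AdjustSaturation, where the factor is assumed to scale saturation in HSV space (compare testAdjustSaturation below):

    INDArray rgb = Nd4j.createFromArray(new double[]{50, 100, 78}).reshape(1, 1, 3); // single RGB pixel
    INDArray saturated = Nd4j.exec(new AdjustSaturation(rgb, 2.0))[0]; // factor 2.0 doubles saturation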
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.linalg.api.ops.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -14,16 +30,16 @@ public abstract class BaseAdjustContrast extends DynamicCustomOp { public BaseAdjustContrast() { } - public BaseAdjustContrast(INDArray in, double factor, INDArray out) { + public BaseAdjustContrast(@NonNull INDArray in, double factor, INDArray out) { Preconditions.checkArgument(in.rank() >= 3, - String.format("AdjustContrast: op expects rank of input array to be >= 3, but got %d instead", in.rank())); + "AdjustContrast: op expects rank of input array to be >= 3, but got %s instead", in.rank()); inputArguments.add(in); outputArguments.add(out); addTArgument(factor); } - public BaseAdjustContrast(SameDiff sameDiff, SDVariable[] vars) { + public BaseAdjustContrast(@NonNull SameDiff sameDiff, @NonNull SDVariable[] vars) { super("", sameDiff, vars); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BetaInc.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BetaInc.java new file mode 100644 index 000000000..ce45869cc --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BetaInc.java @@ -0,0 +1,67 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +public class BetaInc extends DynamicCustomOp { + + public BetaInc() {} + + public BetaInc(@NonNull INDArray a_input, @NonNull INDArray b_input, @NonNull INDArray x_input, + INDArray output) { + addInputArgument(a_input, b_input, x_input); + if (output != null) { + addOutputArgument(output); + } + } + + public BetaInc(@NonNull INDArray a_input, @NonNull INDArray b_input, @NonNull INDArray x_input) { + inputArguments.add(a_input); + inputArguments.add(b_input); + inputArguments.add(x_input); + } + + public BetaInc(@NonNull SameDiff sameDiff, @NonNull SDVariable a, @NonNull SDVariable b, @NonNull SDVariable x) { + super(sameDiff, new SDVariable[]{a,b,x}); + } + + @Override + public String opName() { + return "betainc"; + } + + @Override + public String tensorflowName() { + return "Betainc"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java index ebae33fce..cafc228f2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java @@ -1,3 +1,18 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
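A usage sketch for BetaInc, assuming it evaluates the regularized incomplete beta function I_x(a, b) elementwise over equal-shaped inputs (values illustrative):

    INDArray a = Nd4j.createFromArray(new float[]{0.5f, 1.0f});
    INDArray b = Nd4j.createFromArray(new float[]{0.5f, 1.0f});
    INDArray x = Nd4j.createFromArray(new float[]{0.25f, 0.75f});
    INDArray ix = Nd4j.exec(new BetaInc(a, b, x))[0]; // I_x(a, b) for each (a, b, x) triple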
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.linalg.api.ops.custom; import lombok.val; @@ -20,6 +35,8 @@ import java.util.Map; public class BitCast extends DynamicCustomOp { public BitCast() {} + private DataType dtype; + public BitCast(INDArray in, DataType dataType, INDArray out) { this(in, dataType.toInt(), out); } @@ -28,6 +45,8 @@ public class BitCast extends DynamicCustomOp { inputArguments.add(in); outputArguments.add(out); iArguments.add(Long.valueOf(dataType)); + + dtype = DataType.fromInt(dataType); } public BitCast(INDArray in, DataType dataType) { @@ -37,6 +56,7 @@ public class BitCast extends DynamicCustomOp { public BitCast(INDArray in, int dataType) { inputArguments.add(in); iArguments.add(Long.valueOf(dataType)); + dtype = DataType.fromInt(dataType); } public BitCast(SameDiff sameDiff, SDVariable in, SDVariable dataType) { @@ -49,6 +69,8 @@ public class BitCast extends DynamicCustomOp { val t = nodeDef.getAttrOrDefault("type", null); val type = ArrayOptionsHelper.convertToDataType(t.getType()); addIArgument(type.toInt()); + + dtype = type; } @Override @@ -65,6 +87,6 @@ public class BitCast extends DynamicCustomOp { public List calculateOutputDataTypes(List inputDataTypes){ int n = args().length; Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); - return Collections.singletonList(inputDataTypes.get(0)); + return Collections.singletonList(dtype); } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java index d69c73da4..e8285fe9b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java @@ -1,3 +1,18 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
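The new dtype field above is what makes calculateOutputDataTypes report the target type instead of echoing the input type; a sketch of the intended behaviour, mirroring testBitCast below:

    INDArray in = Nd4j.linspace(DataType.FLOAT, 1.0f, 1.0f, 8).reshape(2, 2, 2);
    INDArray out = Nd4j.exec(new BitCast(in, DataType.INT))[0];
    // out holds the same raw bits as in, but out.dataType() is now INT rather than FLOAT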
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.linalg.api.ops.custom; import org.nd4j.autodiff.samediff.SDVariable; @@ -9,9 +24,13 @@ import org.nd4j.linalg.factory.Nd4j; public class CompareAndBitpack extends DynamicCustomOp { public CompareAndBitpack() {} - public CompareAndBitpack(INDArray in, double threshold, INDArray out) { + public CompareAndBitpack(INDArray in, double threshold) { inputArguments.add(in); inputArguments.add(Nd4j.scalar(threshold)); + } + + public CompareAndBitpack(INDArray in, double threshold, INDArray out) { + this(in, threshold); outputArguments.add(out); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java index 801384bfd..af62c8443 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java @@ -1,3 +1,18 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.linalg.api.ops.custom; import org.apache.commons.math3.analysis.function.Divide; @@ -16,9 +31,13 @@ public class DivideNoNan extends DynamicCustomOp { public DivideNoNan() { } - public DivideNoNan(INDArray in1, INDArray in2, INDArray out) { + public DivideNoNan(INDArray in1, INDArray in2) { inputArguments.add(in1); inputArguments.add(in2); + } + + public DivideNoNan(INDArray in1, INDArray in2, INDArray out) { + this(in1,in2); outputArguments.add(out); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java index 57551c84c..b92a6f8f9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java @@ -1,3 +1,18 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.linalg.api.ops.custom; import org.nd4j.autodiff.samediff.SDVariable; @@ -13,11 +28,15 @@ import java.util.List; public class DrawBoundingBoxes extends DynamicCustomOp { public DrawBoundingBoxes() {} - public DrawBoundingBoxes(INDArray images, INDArray boxes, INDArray colors, - INDArray output) { + public DrawBoundingBoxes(INDArray images, INDArray boxes, INDArray colors) { inputArguments.add(images); inputArguments.add(boxes); inputArguments.add(colors); + } + + public DrawBoundingBoxes(INDArray images, INDArray boxes, INDArray colors, + INDArray output) { + this(images, boxes, colors); outputArguments.add(output); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java index ef150843d..c63cd3b56 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java @@ -1,3 +1,18 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.linalg.api.ops.custom; import org.nd4j.autodiff.samediff.SDVariable; @@ -13,14 +28,18 @@ import java.util.List; public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp { public FakeQuantWithMinMaxVarsPerChannel() {} - public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max, - INDArray output) { + public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max) { Preconditions.checkArgument(min.isVector() && max.isVector() && min.length() == max.length(), "FakeQuantWithMinMaxVarsPerChannel: min and max should be 1D tensors with the same length"); inputArguments.add(x); inputArguments.add(min); inputArguments.add(max); + } + + public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max, + INDArray output) { + this(x,min,max); outputArguments.add(output); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java new file mode 100644 index 000000000..691e5d43f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java @@ -0,0 +1,64 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +public class FusedBatchNorm extends DynamicCustomOp { + + public FusedBatchNorm() {} + + public FusedBatchNorm(@NonNull INDArray x, @NonNull INDArray scale, @NonNull INDArray offset, + int dataFormat, int isTraining, + INDArray yOut, INDArray batchMeanOut, INDArray batchMeanVar) { + addInputArgument(x, scale, offset); + addIArgument(dataFormat, isTraining); + if (yOut != null && batchMeanOut != null && batchMeanVar != null) { + addOutputArgument(yOut, batchMeanOut, batchMeanVar); + } + } + + public FusedBatchNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable scale, @NonNull SDVariable offset, + @NonNull SDVariable dataFormat, @NonNull SDVariable isTraining) { + super("", sameDiff, new SDVariable[]{x, scale, offset, dataFormat, isTraining}); + } + + @Override + public String opName() { + return "fused_batch_norm"; + } + + @Override + public String tensorflowName() { + return "FusedBatchNormV2"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/MatrixBandPart.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/MatrixBandPart.java new file mode 100644 index 000000000..46d29608e --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/MatrixBandPart.java @@ -0,0 +1,65 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
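A usage sketch for FusedBatchNorm with preallocated outputs, mirroring the testFusedBatchNorm case below (dataFormat 0 is assumed to mean NHWC, and isTraining 1 to compute batch statistics):

    INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2 * 2 * 3 * 4).reshape(2, 2, 3, 4);
    INDArray scale = Nd4j.create(DataType.DOUBLE, 4).assign(0.5);
    INDArray offset = Nd4j.create(DataType.DOUBLE, 4).assign(2.0);
    INDArray y = Nd4j.createUninitialized(DataType.DOUBLE, x.shape());
    INDArray batchMean = Nd4j.create(4);
    INDArray batchVar = Nd4j.create(4);
    Nd4j.exec(new FusedBatchNorm(x, scale, offset, 0, 1, y, batchMean, batchVar));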
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +public class MatrixBandPart extends DynamicCustomOp { + + public MatrixBandPart() {} + + public MatrixBandPart(@NonNull INDArray input, int minLower, int maxUpper) { + Preconditions.checkArgument(input.rank() >= 2, "MatrixBandPart: Input rank should be 2 or higher"); + long N = input.size(-2); + long M = input.size(-1); + Preconditions.checkArgument(minLower > -N && minLower < N, "MatrixBandPart: lower diagonal count %s should be less than %s", + minLower, N); + Preconditions.checkArgument(maxUpper > -M && maxUpper < M, "MatrixBandPart: upper diagonal count %s should be less than %s.", + maxUpper, M); + addInputArgument(input); + addIArgument(minLower, maxUpper); + } + + public MatrixBandPart(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable minLower, SDVariable maxUpper) { + super("", sameDiff, new SDVariable[]{input, minLower, maxUpper}); + } + + @Override + public String opName() { + return "matrix_band_part"; + } + + @Override + public String tensorflowName() { + return "MatrixBandPart"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Polygamma.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Polygamma.java new file mode 100644 index 000000000..3b528eb62 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Polygamma.java @@ -0,0 +1,66 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
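A sketch for MatrixBandPart, assuming TF semantics: entries outside the band of minLower sub-diagonals and maxUpper super-diagonals are zeroed:

    INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 9).reshape(3, 3);
    INDArray band = Nd4j.exec(new MatrixBandPart(x, 1, 1))[0]; // keeps the tridiagonal part of x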
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.linalg.api.ops.custom;
+
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+import java.util.Collections;
+import java.util.List;
+
+public class Polygamma extends DynamicCustomOp {
+
+    public Polygamma() {}
+
+    public Polygamma(@NonNull INDArray n, @NonNull INDArray x) {
+        Preconditions.checkArgument(n.equalShapes(x),
+                "Polygamma: n and x must have the same shapes");
+        addInputArgument(n, x);
+    }
+
+    public Polygamma(@NonNull INDArray n, @NonNull INDArray x, INDArray output) {
+        this(n, x);
+        if (output != null) {
+            addOutputArgument(output);
+        }
+    }
+
+    public Polygamma(@NonNull SameDiff sameDiff, @NonNull SDVariable n, @NonNull SDVariable x) {
+        super("", sameDiff, new SDVariable[]{n, x});
+    }
+
+    @Override
+    public String opName() {
+        return "polygamma";
+    }
+
+    @Override
+    public String tensorflowName() {
+        return "Polygamma";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+        int n = args().length;
+        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
+        return Collections.singletonList(inputDataTypes.get(0));
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RandomCrop.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RandomCrop.java
new file mode 100644
index 000000000..1f3f2e3ea
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RandomCrop.java
@@ -0,0 +1,61 @@
+/* ******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
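A sketch for Polygamma, assuming it evaluates psi_n(x), the n-th derivative of the digamma function, elementwise over equal-shaped n and x:

    INDArray n = Nd4j.createFromArray(new double[]{1, 2, 3, 4}).reshape(2, 2);
    INDArray x = Nd4j.createFromArray(new double[]{0.5, 1.0, 1.5, 2.0}).reshape(2, 2);
    INDArray psi = Nd4j.exec(new Polygamma(n, x))[0];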
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.linalg.api.ops.custom;
+
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.rng.Random;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+public class RandomCrop extends DynamicCustomOp {
+
+    public RandomCrop() {}
+
+    public RandomCrop(@NonNull INDArray input, @NonNull INDArray shape) {
+        Preconditions.checkArgument(shape.isVector(), "RandomCrop: shape tensor should be a vector");
+        Preconditions.checkArgument(input.rank() == shape.length(), "RandomCrop: length of the shape vector must match the input rank");
+        addInputArgument(input, shape);
+    }
+
+    public RandomCrop(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable shape) {
+        super("", sameDiff, new SDVariable[]{input, shape});
+    }
+
+    @Override
+    public String opName() {
+        return "random_crop";
+    }
+
+    @Override
+    public String tensorflowName() {
+        return "RandomCrop";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+        Preconditions.checkState(inputDataTypes != null /*&& inputDataTypes.size() == 2*/,
+                "Expected input datatypes for %s, got %s", getClass(), inputDataTypes);
+        return Collections.singletonList(DataType.FLOAT); //TF import: always returns float32...
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Roll.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Roll.java
new file mode 100644
index 000000000..9ce7aa641
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Roll.java
@@ -0,0 +1,64 @@
+/* ******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
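A sketch for RandomCrop: the shape vector needs one entry per input dimension, and per the TF-import note above the result is always float32:

    INDArray img = Nd4j.linspace(DataType.FLOAT, 1.0f, 1.0f, 4 * 4 * 3).reshape(4, 4, 3);
    INDArray cropShape = Nd4j.createFromArray(new int[]{2, 2, 3});
    INDArray crop = Nd4j.exec(new RandomCrop(img, cropShape))[0]; // a random 2x2x3 window of img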
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +public class Roll extends DynamicCustomOp { + + public Roll() {} + + public Roll(@NonNull INDArray input, @NonNull INDArray axes, @NonNull INDArray shifts) { + Preconditions.checkArgument(axes.rank() == shifts.rank(), "Roll: shifts and axes should be the same rank"); + Preconditions.checkArgument(axes.length() == shifts.length(), "Roll: shifts and axes should be the same length"); + addInputArgument(input, axes, shifts); + } + + public Roll(@NonNull INDArray input, int shift) { + addInputArgument(input); + addIArgument(shift); + } + + public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable shift) { + super("", sameDiff, new SDVariable[]{input,shift}); + } + + @Override + public String opName() { + return "roll"; + } + + @Override + public String tensorflowName() { + return "Roll"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ToggleBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ToggleBits.java new file mode 100644 index 000000000..641cb4117 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ToggleBits.java @@ -0,0 +1,64 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
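A sketch for the single-shift Roll constructor, assuming TF roll semantics where a positive shift rotates elements toward larger indices over the flattened array:

    INDArray in = Nd4j.createFromArray(new double[]{1, 2, 3, 4, 5, 6});
    INDArray rolled = Nd4j.exec(new Roll(in, 2))[0]; // expected [5, 6, 1, 2, 3, 4]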
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +public class ToggleBits extends DynamicCustomOp { + + public ToggleBits() {} + + public ToggleBits(@NonNull INDArray input, INDArray output) { + this(input); + if (output != null) { + addOutputArgument(output); + } + } + + public ToggleBits(@NonNull INDArray input) { + addInputArgument(input); + } + + public ToggleBits(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + super("", sameDiff, new SDVariable[]{input}); + } + + @Override + public String opName() { + return "toggle_bits"; + } + + @Override + public String tensorflowName() { + return "Invert"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java index d7161cf5f..75b82dc29 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; @@ -41,6 +42,12 @@ public class NonMaxSuppression extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{boxes, scores, maxOutSize, iouThreshold, scoreThreshold}, false); } + public NonMaxSuppression(INDArray boxes, INDArray scores, int maxOutSize, double iouThreshold, double scoreThreshold) { + addInputArgument(boxes,scores); + addIArgument(maxOutSize); + addTArgument(iouThreshold, scoreThreshold); + } + @Override public String onnxName() { throw new NoOpNameFoundException("No onnx name found for shape " + opName()); @@ -53,7 +60,7 @@ public class NonMaxSuppression extends DynamicCustomOp { @Override public String[] tensorflowNames() { - return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2"}; + return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2","NonMaxSuppressionV3","NonMaxSuppressionV4"}; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java new file mode 100644 index 000000000..58602d85e --- /dev/null +++ 
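A sketch for ToggleBits, which maps to TF's Invert, i.e. a bitwise NOT of each element of an integer array:

    INDArray in = Nd4j.createFromArray(new int[]{0, 1, 7});
    INDArray flipped = Nd4j.exec(new ToggleBits(in))[0]; // two's complement int32: [-1, -2, -8]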
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java @@ -0,0 +1,284 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.layers.convolution; + +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import onnx.Onnx; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.descriptors.properties.PropertyMapping; +import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; +import org.nd4j.linalg.util.ArrayUtil; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; + +import java.lang.reflect.Field; +import java.util.*; + +@Slf4j +@Getter +public class MaxPoolWithArgmax extends DynamicCustomOp { + + protected Pooling2DConfig config; + protected DataType outputType; + + public MaxPoolWithArgmax() { + } + + @Builder(builderMethodName = "sameDiffBuilder") + @SuppressWarnings("Used in lombok") + public MaxPoolWithArgmax(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) { + super(null, sameDiff, new SDVariable[]{input}, false); + + config.setType(Pooling2D.Pooling2DType.MAX); + this.config = config; + addArgs(); + } + + public MaxPoolWithArgmax(INDArray input, INDArray output,INDArray outArgMax, @NonNull Pooling2DConfig config){ + super(null, new INDArray[]{input}, new INDArray[]{output, outArgMax}); + config.setType(Pooling2D.Pooling2DType.MAX); + + this.config = config; + addArgs(); + } + + @Override + public boolean isConfigProperties() { + return true; + } + + @Override + public String configFieldName() { + return "config"; + } + + + @Override + public Map propertiesForFunction() { + if(config == null && iArguments.size() > 0){ + //Perhaps loaded from FlatBuffers - hence we have IArgs but not Config object + config = Pooling2DConfig.builder() + .kH(iArguments.get(0)) + .kW(iArguments.get(1)) + .sH(iArguments.get(2)) + .sW(iArguments.get(3)) + .pH(iArguments.get(4)) + .pW(iArguments.get(5)) + .dH(iArguments.get(6)) + .dW(iArguments.get(7)) + .isSameMode(iArguments.get(8) == 1) + .extra(iArguments.get(9)) + .isNHWC(iArguments.get(10) == 1) + .type(Pooling2D.Pooling2DType.MAX) + .build(); + } + return config.toProperties(); + } + + private void addArgs() { + addIArgument(config.getKH(), + config.getKW(), + config.getSH(), + config.getSW(), + config.getPH(), + config.getPW(), + config.getDH(), + config.getDW(), + 
ArrayUtil.fromBoolean(config.isSameMode()),
+                (int) config.getExtra(),
+                ArrayUtil.fromBoolean(config.isNHWC())
+        );
+
+    }
+
+
+    public String getPoolingPrefix() {
+        return "max";
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> f1) {
+        List<SDVariable> ret = new ArrayList<>();
+        List<SDVariable> inputs = new ArrayList<>();
+        inputs.addAll(Arrays.asList(args()));
+        inputs.add(f1.get(0));
+        Pooling2DDerivative pooling2DDerivative = Pooling2DDerivative.derivativeBuilder()
+                .inputs(inputs.toArray(new SDVariable[inputs.size()]))
+                .sameDiff(sameDiff)
+                .config(config)
+                .build();
+        ret.addAll(Arrays.asList(pooling2DDerivative.outputVariables()));
+        return ret;
+    }
+
+    @Override
+    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
+        val aStrides = nodeDef.getAttrOrThrow("strides");
+        val tfStrides = aStrides.getList().getIList();
+
+        val aKernels = nodeDef.getAttrOrThrow("ksize");
+        val tfKernels = aKernels.getList().getIList();
+
+        int sH = 0;
+        int sW = 0;
+
+        int pH = 0;
+        int pW = 0;
+
+        int kH = 0;
+        int kW = 0;
+
+        val aPadding = nodeDef.getAttrOrThrow("padding");
+        val padding = aPadding.getList().getIList();
+
+        val paddingMode = aPadding.getS().toStringUtf8().replaceAll("\"", "");
+
+        boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");
+
+        String data_format = "nhwc";
+        if (nodeDef.containsAttr("data_format")) {
+            val attr = nodeDef.getAttrOrThrow("data_format");
+
+            data_format = attr.getS().toStringUtf8().toLowerCase();
+        }
+
+        if (data_format.equalsIgnoreCase("nhwc")) {
+            sH = tfStrides.get(1).intValue();
+            sW = tfStrides.get(2).intValue();
+
+            kH = tfKernels.get(1).intValue();
+            kW = tfKernels.get(2).intValue();
+
+            pH = padding.size() > 0 ? padding.get(1).intValue() : 0;
+            pW = padding.size() > 0 ? padding.get(2).intValue() : 0;
+        } else {
+            sH = tfStrides.get(2).intValue();
+            sW = tfStrides.get(3).intValue();
+
+            kH = tfKernels.get(2).intValue();
+            kW = tfKernels.get(3).intValue();
+
+            pH = padding.size() > 0 ? padding.get(2).intValue() : 0;
+            pW = padding.size() > 0 ? padding.get(3).intValue() : 0;
+        }
+
+        Pooling2DConfig pooling2DConfig = Pooling2DConfig.builder()
+                .sH(sH)
+                .sW(sW)
+                .type(Pooling2D.Pooling2DType.MAX)
+                .isSameMode(isSameMode)
+                .kH(kH)
+                .kW(kW)
+                .pH(pH)
+                .pW(pW)
+                .isNHWC(data_format.equalsIgnoreCase("nhwc"))
+                .extra(1.0) // averaging only for non-padded values
+                .build();
+        this.config = pooling2DConfig;
+        addArgs();
+        if(attributesForNode.containsKey("argmax")) {
+            outputType = TFGraphMapper.convertType(attributesForNode.get("argmax").getType());
+        } else {
+            outputType = DataType.UINT32;
+        }
+    }
+
+    @Override
+    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
+        Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();
+        Map<String, PropertyMapping> map = new HashMap<>();
+        val strideMapping = PropertyMapping.builder()
+                .tfAttrName("strides")
+                .onnxAttrName("strides")
+                .propertyNames(new String[]{"sW", "sH"})
+                .build();
+
+        val paddingMapping = PropertyMapping.builder()
+                .onnxAttrName("padding")
+                .tfAttrName("padding")
+                .propertyNames(new String[]{"pH", "pW"})
+                .build();
+
+        val kernelMapping = PropertyMapping.builder()
+                .propertyNames(new String[]{"kH", "kW"})
+                .tfInputPosition(1)
+                .onnxAttrName("ksize")
+                .build();
+
+        val dilationMapping = PropertyMapping.builder()
+                .onnxAttrName("dilations")
+                .propertyNames(new String[]{"dW", "dH"})
+                .tfAttrName("rates")
+                .build();
+
+
+        //data_format
+        val dataFormatMapping = PropertyMapping.builder()
+                .propertyNames(new String[]{"isNHWC"})
+                .tfAttrName("data_format")
+                .build();
+
+        map.put("sW", strideMapping);
+        map.put("sH", strideMapping);
+        map.put("kH", kernelMapping);
+        map.put("kW", kernelMapping);
+        map.put("dW", dilationMapping);
+        map.put("dH", dilationMapping);
+        map.put("pH", paddingMapping);
+        map.put("pW", paddingMapping);
+        map.put("isNHWC", dataFormatMapping);
+
+        ret.put(onnxName(), map);
+        ret.put(tensorflowName(), map);
+        return ret;
+    }
+
+    @Override
+    public String opName() {
+        return "max_pool_with_argmax";
+    }
+
+    @Override
+    public String onnxName() {
+        return "MaxPoolWithArgmax";
+    }
+
+    @Override
+    public String tensorflowName() {
+        return "MaxPoolWithArgmax";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes);
+        List<DataType> result = new ArrayList<>();
+        result.add(inputDataTypes.get(0));
+        result.add(outputType == null ?
DataType.UINT32 : outputType); + return result; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java index 09e928d2f..ad7984f2b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java @@ -293,8 +293,8 @@ public class MaxPooling2D extends DynamicCustomOp { } @Override - public String tensorflowName() { - return "MaxPool"; + public String[] tensorflowNames() { + return new String[]{"MaxPool","MaxPoolV2"}; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java index 3927ba2bc..99f65f46d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java @@ -68,7 +68,7 @@ public class ClipByValue extends DynamicCustomOp { @Override public String opName() { - return "clipbyvalue"; + return "ClipByValue"; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java index 3cc03d12b..6e87a05c0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java @@ -53,15 +53,9 @@ public class RShiftBits extends BaseDynamicTransformOp { return "rshift_bits"; } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - @Override public String tensorflowName() { - throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName()); + return "RightShift"; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java index a9eebb14e..038cca54b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java @@ -53,15 +53,9 @@ public class ShiftBits extends BaseDynamicTransformOp { return "shift_bits"; } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - @Override public String tensorflowName() { - throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName()); + return "LeftShift"; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/UniqueWithCounts.java 
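A sketch of the INDArray constructor for MaxPoolWithArgmax, reusing the Pooling2DConfig settings of the testMaxPoolingArgMax case below (NCHW input; the argmax output is assumed UINT32, the default when no TF attribute is present):

    INDArray in = Nd4j.rand(new int[]{1, 3, 8, 8});
    INDArray out = Nd4j.create(in.shape());
    INDArray argmax = Nd4j.create(DataType.UINT32, in.shape());
    Pooling2DConfig cfg = Pooling2DConfig.builder()
            .kH(2).kW(2).pH(0).pW(0).sH(1).sW(1).dH(1).dW(1)
            .isSameMode(true)
            .build();
    Nd4j.exec(new MaxPoolWithArgmax(in, out, argmax, cfg));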
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/UniqueWithCounts.java index 3f3bdbe74..74a7397a4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/UniqueWithCounts.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/UniqueWithCounts.java @@ -46,8 +46,8 @@ public class UniqueWithCounts extends DynamicCustomOp { } @Override - public String tensorflowName() { - return "UniqueWithCounts"; + public String[] tensorflowNames() { + return new String[]{"UniqueWithCounts","UniqueWithCountsV2"}; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/CopyOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/CopyOp.java index 5397108c6..3ee75d23d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/CopyOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/CopyOp.java @@ -77,8 +77,8 @@ public class CopyOp extends BaseTransformSameOp { } @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); + public String[] tensorflowNames() { + return new String[]{"Copy","DeepCopy","CopyHost"}; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/ModOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/ModOp.java index 289333f96..46d477310 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/ModOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/ModOp.java @@ -57,7 +57,7 @@ public class ModOp extends BaseDynamicTransformOp { @Override public String tensorflowName() { - return "mod"; + return "Mod"; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java index b49a89200..95bd0bf41 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java @@ -66,13 +66,7 @@ public class Not extends BaseTransformBoolOp { public String onnxName() { return "Not"; } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("Tensorflow name not found for " + opName()); - //return "Not"; - } - + @Override public List doDiff(List f1) { return Collections.singletonList(f().zerosLike(arg())); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java index ec91a98e6..b33ea8b8f 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java @@ -59,19 +59,6 @@ public class GELU extends BaseTransformStrictOp { return "gelu"; } - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() - { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - //return "GELU"; - } - - @Override public List doDiff(List i_v) { SDVariable ret = f().geluDerivative(arg(), false).mul(i_v.get(0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java index ecc76a1b2..682d7c230 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/DistributionUniform.java @@ -24,6 +24,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -71,7 +72,12 @@ public class DistributionUniform extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - AttrValue v = attributesForNode.get("dtype"); + AttrValue vDtype = attributesForNode.get("dtype"); + AttrValue vTout = attributesForNode.get("Tout"); + if (vDtype == null && vTout == null) { + throw new ND4JIllegalStateException("Unable to find output data type for node " + nodeDef.getName()); + } + AttrValue v = vDtype == null ? 
vTout : vDtype; dataType = TFGraphMapper.convertType(v.getType()); addIArgument(dataType.toInt()); addTArgument(0.0, 1.0); //TF version is hardcoded 0 to 1 @@ -92,8 +98,8 @@ public class DistributionUniform extends DynamicCustomOp { } @Override - public String tensorflowName() { - return "RandomUniform"; + public String[] tensorflowNames() { + return new String[]{"RandomUniform","RandomUniformInt"}; } @Override @@ -103,7 +109,7 @@ public class DistributionUniform extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null /*&& inputDataTypes.size() == 1*/, "Expected input datatypes for %s, got %s", getClass(), inputDataTypes); //Input data type specifies the shape if(dataType != null){ return Collections.singletonList(dataType); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/DropOut.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/DropOut.java index 32a823ac1..742e28113 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/DropOut.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/DropOut.java @@ -65,18 +65,6 @@ public class DropOut extends BaseRandomOp { return "dropout"; } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op name found for: " + getClass().getName()); - //return opName(); - } - @Override public Type opType() { return Type.RANDOM ; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index 6f4acd079..9dd529399 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -736,6 +736,35 @@ public class LayerOpValidation extends BaseOpValidation { // sd.execBackwards(); // TODO: test failing here } + @Test + public void testMaxPoolingArgMax() { + Nd4j.getRandom().setSeed(12345); + int nIn = 3; + int kH = 2; + int kW = 2; + + int mb = 3; + int imgH = 8; + int imgW = 8; + + SameDiff sd = SameDiff.create(); + INDArray inArr = Nd4j.rand(new int[]{mb, nIn, imgH, imgW}); + + SDVariable in = sd.var("in", inArr); + + Pooling2DConfig pooling2DConfig = Pooling2DConfig.builder() + .kH(kH).kW(kW) + .pH(0).pW(0) + .sH(1).sW(1) + .dH(1).dW(1) + .isSameMode(true) + .build(); + + SDVariable[] results = sd.nn().maxPoolWithArgmax(new String[]{"",""}, in, pooling2DConfig); + assertArrayEquals(inArr.shape(), results[0].eval().shape()); + assertArrayEquals(inArr.shape(), results[1].eval().shape()); + } + @Test public void testMaxPooling2dBasic() { Nd4j.getRandom().setSeed(12345); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java index 277bb8a83..ec65d71df 100644 --- 
a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java @@ -76,8 +76,6 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a "adjust_contrast/.*", //Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965 "bincount/.*", - // Failing 2019/11/15 https://github.com/eclipse/deeplearning4j/issues/8400 - "bitcast/.*", // Failing 2019/11/14 https://github.com/eclipse/deeplearning4j/issues/8393 "is_strictly_increasing/emptyArrayTest/.*", @@ -116,20 +114,32 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a // 2019/11/15 - missing dtype argument in nd4j, tests are useless https://github.com/eclipse/deeplearning4j/issues/8398 "zeros_like/rank2_float32_dtype_int.*", - // 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8399 - "crop_and_resize.*", - - // 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8401 - "draw_bounding_boxes.*", - // 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8402 "fake_quant/min_max_args_per_channel.*", // 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8403 "resize_bilinear/int32.*", - // Suggesting TF 1.15 bug - see https://github.com/eclipse/deeplearning4j/issues/8449 - "non_max_suppression_v2/float16.*" + // Suggesting TF 1.15 bug + "non_max_suppression_v2/float16.*", + + // 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8450 + "betainc.*", + + // 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8452 + "polygamma.*", + + // 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8453 + "roll/.*", + + // 11.26.2019 failing https://github.com/eclipse/deeplearning4j/issues/8455 + "matrix_band_part/.*", + + // 11.28.2019 failing https://github.com/eclipse/deeplearning4j/issues/8458 + "adjust_hue/.*", + + // 11.28.2019 failing https://github.com/eclipse/deeplearning4j/issues/8459 + "adjust_saturation/.*" }; /* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index c01cf1942..fbb1ddb85 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -32,7 +32,10 @@ import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpStatus; import org.nd4j.linalg.api.ops.impl.controlflow.Where; import org.nd4j.linalg.api.ops.impl.image.CropAndResize; +import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression; import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear; +import org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; import org.nd4j.linalg.api.ops.impl.reduce.MmulBp; import org.nd4j.linalg.api.ops.impl.shape.Create; import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; @@ -53,6 +56,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import static java.lang.Float.NaN; import static org.junit.Assert.*; /** @@ -867,6 +871,26 @@ public class CustomOpsTests extends 
BaseNd4jTest { assertArrayEquals(new long[]{1,10}, lsd.get(0).getShape()); } + @Test + public void testAdjustSaturation() { + INDArray in = Nd4j.createFromArray(new double[]{50,100,78, 118.5,220,112.5,190,163.5,230, 255,128.5,134}).reshape(2,2,3); + INDArray out = Nd4j.create(in.shape()); + INDArray expected = Nd4j.createFromArray(new double[]{0,100,56, 17,220,5, 150,97,230, 255,2,13}).reshape(2,2,3); + + Nd4j.exec(new AdjustSaturation(in, 2.0, out)); + assertEquals(expected, out); + } + + @Test + public void testAdjustHue() { + INDArray in = Nd4j.createFromArray(new double[]{0,100,56, 17,220,5, 150,97,230, 255,2,13}).reshape(2,2,3); + INDArray out = Nd4j.create(in.shape()); + INDArray expected = Nd4j.createFromArray(new double[]{100,0,44, 208,5,220, 177,230,97, 2,255,244}).reshape(2,2,3); + + Nd4j.exec(new AdjustHue(in, 0.5, out)); + assertEquals(expected, out); + } + @Test public void testBitCast() { INDArray in = Nd4j.linspace(DataType.FLOAT, 1.0f, 1.0f, 8).reshape(2,2,2); @@ -1088,6 +1112,216 @@ public class CustomOpsTests extends BaseNd4jTest { assertArrayEquals(new long[]{1,10, 2}, lsd.get(0).getShape()); } + @Test + public void testBetaInc() { + Nd4j.getRandom().setSeed(10); + INDArray a = Nd4j.linspace(DataType.BFLOAT16, 0.1, 0.1, 9).reshape(3,3); + INDArray b = Nd4j.linspace(DataType.BFLOAT16, 0.1, 0.1, 9).reshape(3,3); + INDArray x = Nd4j.linspace(DataType.BFLOAT16, 0.1, 0.1, 9).reshape(3,3); + INDArray expected = Nd4j.createFromArray(new float[]{0.4121f, 0.3926f, 0.4082f, + 0.4414f, 0.5000f, 0.5703f, + 0.6562f, 0.7656f, 0.8828f}).reshape(3,3); + + BetaInc op = new BetaInc(a,b,x); + INDArray[] out = Nd4j.exec(op); + assertArrayEquals(expected.shape(), out[0].shape()); + for (int i = 0; i < 3; ++i) + assertArrayEquals(expected.toDoubleMatrix()[i], out[0].toDoubleMatrix()[i], 1e-4); + } + + @Test + public void testFusedBatchNorm() { + INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*2*3*4).reshape(2,2,3,4); + INDArray scale = Nd4j.create(DataType.DOUBLE, 4); + scale.assign(0.5); + INDArray offset = Nd4j.create(DataType.DOUBLE, 4); + offset.assign(2.0); + + INDArray y = Nd4j.createUninitialized(DataType.DOUBLE, x.shape()); + INDArray batchMean = Nd4j.create(4); + INDArray batchVar = Nd4j.create(4); + + FusedBatchNorm op = new FusedBatchNorm(x,scale,offset,0,1, + y, batchMean, batchVar); + + INDArray expectedY = Nd4j.createFromArray(new double[]{1.20337462, 1.20337462, 1.20337462, + 1.20337462, 1.34821558, 1.34821558, 1.34821558, 1.34821558, 1.49305654, 1.49305654, + 1.49305654, 1.49305654, 1.63789749, 1.63789749, 1.63789749, 1.63789749, 1.78273857, + 1.78273857, 1.78273857, 1.78273857, 1.92757952, 1.92757952, 1.92757952, 1.92757952, + 2.0724206 , 2.0724206 , 2.0724206 , 2.0724206 , 2.21726155, 2.21726155, 2.21726155, + 2.21726155, 2.36210251, 2.36210251, 2.36210251, 2.36210251, 2.50694346, 2.50694346, + 2.50694346, 2.50694346, 2.65178442, 2.65178442, 2.65178442, 2.65178442, 2.79662538, + 2.79662538, 2.79662538, 2.79662538}).reshape(x.shape()); + INDArray expectedBatchMean = Nd4j.createFromArray(new double[]{23., 24., 25., 26.}); + INDArray expectedBatchVar = Nd4j.createFromArray(new double[]{208.00001526, 208.00001526, 208.00001526, 208.00001526}); + Nd4j.exec(op); + assertArrayEquals(expectedY.shape(), y.shape()); + assertArrayEquals(expectedBatchMean.shape(), batchMean.shape()); + assertArrayEquals(expectedBatchVar.shape(), batchVar.shape()); + } + + @Test + public void testMatrixBandPart() { + INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3); + 
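// matrix_band_part(x, 1, 1) keeps only the tridiagonal band (elements with |m - n| <= 1);
+        // in each 3x3 slice the (0,2) and (2,0) corners fall outside the band and are zeroed.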
val op = new MatrixBandPart(x,1,1);
+        INDArray expected = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3);
+        expected.putScalar(0, 0, 2, 0.);
+        expected.putScalar(1, 0, 2, 0.);
+        expected.putScalar(0, 2, 0, 0.);
+        expected.putScalar(1, 2, 0, 0.);
+
+        INDArray[] out = Nd4j.exec(op);
+        assertEquals(expected, out[0]);
+    }
+
+    @Test
+    public void testPolygamma() {
+        INDArray n = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 9).reshape(3,3);
+        INDArray x = Nd4j.create(DataType.FLOAT, 3,3);
+        x.assign(0.5);
+        INDArray expected = Nd4j.createFromArray(new float[]{4.934802f, -16.828796f, 97.409088f, -771.474243f,
+                7691.113770f, -92203.460938f, 1290440.250000f, -20644900.000000f, 3.71595e+08f}).reshape(3,3);
+        INDArray output = Nd4j.create(DataType.FLOAT, expected.shape());
+        val op = new Polygamma(x,n,output);
+        Nd4j.exec(op);
+        assertEquals(expected, output);
+    }
+
+    @Test
+    public void testRandomCrop() {
+        INDArray x = Nd4j.createFromArray(new double[]{1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}).reshape(2,2,4);
+        INDArray shape = Nd4j.createFromArray(new int[] {1,2,3});
+        val op = new RandomCrop(x, shape);
+        INDArray[] res = Nd4j.exec(op);
+        // Values are random; the requested crop shape is the only deterministic output.
+        assertArrayEquals(new long[]{1,2,3}, res[0].shape());
+    }
+
+    @Test
+    public void testRoll() {
+        INDArray x = Nd4j.createFromArray(new double[]{ 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42,
+                21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}).
+                reshape(2,2,4,2);
+
+        INDArray expected = Nd4j.createFromArray(new double[]{ 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42,
+                12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32,
+                21.41, 21.42, 22.11, 22.12
+        }).reshape(x.shape());
+        val op = new Roll(x, 6);
+        INDArray[] res = Nd4j.exec(op);
+        assertEquals(expected, res[0]);
+    }
+
+    @Test
+    public void testToggleBits() {
+        INDArray input = Nd4j.createFromArray(new int[]{2,2});
+        INDArray expected = Nd4j.createFromArray(new int[]{-3,-3});
+        ToggleBits op = new ToggleBits(input);
+        val result = Nd4j.exec(op);
+        assertEquals(expected, result[0]);
+    }
+
+    @Ignore("AS 11.28.2019 - https://github.com/eclipse/deeplearning4j/issues/8449")
+    @Test
+    public void testNonMaxSuppression() {
+        INDArray boxes = Nd4j.createFromArray(new float[] {0.8115f, 0.4121f, 0.0771f, 0.4863f,
+                0.7412f, 0.7607f, 0.1543f, 0.5479f,
+                0.8223f, 0.2246f, 0.0049f, 0.6465f}).reshape(3,4);
+        INDArray scores = Nd4j.createFromArray(new float[]{0.0029f, 0.8135f, 0.4873f});
+        val op = new NonMaxSuppression(boxes,scores,2,0.5,0.5);
+        val res = Nd4j.exec(op);
+        assertArrayEquals(new long[]{1}, res[0].shape());
+    }
+
+    @Test
+    public void testMatrixBand() {
+        INDArray input = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,0.2309f,
+                0.7271f,0.1804f,0.5056f,0.8925f,
+                0.5461f,0.9234f,0.0856f,0.7938f}).reshape(3,4);
+        MatrixBandPart op = new MatrixBandPart(input,1,-1);
+        List<LongShapeDescriptor> lsd = op.calculateOutputShape();
+        assertEquals(1, lsd.size());
+    }
+
+    @Ignore("Failed AS 11.26.2019 - https://github.com/eclipse/deeplearning4j/issues/8450")
+    @Test
+    public void testBetaInc1() {
+        INDArray a = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f});
+        INDArray b = Nd4j.createFromArray(new float[]{0.7717f, 0.9281f, 0.9846f, 0.4838f});
+        INDArray c = Nd4j.createFromArray(new float[]{0.9441f, 0.5957f, 0.8669f, 0.3502f});
+        BetaInc op = new BetaInc(a,b,c);
+        INDArray[] ret = Nd4j.exec(op);
+        INDArray
expected = Nd4j.createFromArray(new float[]{0.9122f, 0.6344f, 0.8983f, 0.6245f}); + assertEquals(expected, ret[0]); + } + + @Ignore("Failure AS 11.28.2019 - https://github.com/eclipse/deeplearning4j/issues/8452") + @Test + public void testPolygamma1() { + INDArray a = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f, + 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f}).reshape(3,4); + INDArray b = Nd4j.createFromArray(new float[]{0.7717f, 0.9281f, 0.9846f, 0.4838f, + 0.6433f, 0.6041f, 0.6501f, 0.7612f, + 0.7605f, 0.3948f, 0.9493f, 0.8600f}).reshape(3,4); + INDArray expected = Nd4j.createFromArray(new float[]{NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN, }).reshape(3,4); + Polygamma op = new Polygamma(a,b); + INDArray[] ret = Nd4j.exec(op); + assertEquals(expected, ret[0]); + } + + @Ignore("Failure AS 11.28.2019 - https://github.com/eclipse/deeplearning4j/issues/8453") + @Test + public void testRoll1() { + INDArray a = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f}); + Roll op = new Roll(a,Nd4j.scalar(2),Nd4j.scalar(0)); + INDArray[] ret = Nd4j.exec(op); + INDArray expected = Nd4j.createFromArray(new float[]{0.7244f, 0.2309f, 0.7788f, 0.8012f}); + assertEquals(expected, ret[0]); + } + + @Test + public void testAdjustHueShape(){ + INDArray image = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, + 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, 0.5461f, + 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f, + 0.3087f, 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f, + 0.1800f, 0.6750f, 0.2246f, 0.0509f, 0.4601f, 0.8284f, + 0.2354f, 0.9752f, 0.8361f, 0.2585f, 0.4189f, 0.7028f, + 0.7679f, 0.5373f, 0.7234f, 0.2690f, 0.0062f, 0.0327f, + 0.0644f, 0.8428f, 0.7494f, 0.0755f, 0.6245f, 0.3491f, + 0.5793f, 0.5730f, 0.1822f, 0.6420f, 0.9143f, 0.3019f, + 0.3574f, 0.1704f, 0.8395f, 0.5468f, 0.0744f, 0.9011f, + 0.6574f, 0.4124f, 0.2445f, 0.4248f, 0.5219f, 0.6952f, + 0.4900f, 0.2158f, 0.9549f, 0.1386f, 0.1544f, 0.5365f, + 0.0134f, 0.4163f, 0.1456f, 0.4109f, 0.2484f, 0.3330f, + 0.2974f, 0.6636f, 0.3808f, 0.8664f, 0.1896f, 0.7530f, + 0.7215f, 0.6612f, 0.7270f, 0.5704f, 0.2666f, 0.7453f, + 0.0444f, 0.3024f, 0.4850f, 0.7982f, 0.0965f, 0.7843f, + 0.5075f, 0.0844f, 0.8370f, 0.6103f, 0.4604f, 0.6087f, + 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, 0.4143f, + 0.7821f, 0.3505f, 0.5040f, 0.1180f, 0.8307f, 0.1817f, + 0.8442f, 0.5074f, 0.4471f, 0.5105f, 0.6666f, 0.2576f, + 0.2341f, 0.6801f, 0.2652f, 0.5394f, 0.4690f, 0.6146f, + 0.1210f, 0.2576f, 0.0769f, 0.4643f, 0.1628f, 0.2026f, + 0.3774f, 0.0506f, 0.3462f, 0.5720f, 0.0838f, 0.4228f, + 0.0588f, 0.5362f, 0.4756f, 0.2530f, 0.1778f, 0.0751f, + 0.8977f, 0.3648f, 0.3065f, 0.4739f, 0.7014f, 0.4473f, + 0.5171f, 0.1744f, 0.3487f, 0.7759f, 0.9491f, 0.2072f, + 0.2182f, 0.6520f, 0.3092f, 0.9545f, 0.1881f, 0.9579f, + 0.1785f, 0.9636f, 0.4830f, 0.6569f, 0.3353f, 0.9997f, + 0.5869f, 0.5747f, 0.0238f, 0.2943f, 0.5248f, 0.5879f, + 0.7266f, 0.1965f, 0.9167f, 0.9726f, 0.9206f, 0.0519f, + 0.2997f, 0.0039f, 0.7652f, 0.5498f, 0.3794f, 0.3791f, + 0.3528f, 0.2873f, 0.8082f, 0.4732f, 0.4399f, 0.6606f, + 0.5991f, 0.0034f, 0.4874f}).reshape(8,8,3); + + AdjustHue op = new AdjustHue(image, 0.2f); + INDArray[] res = Nd4j.exec(op); + System.out.println(res[0]); + List lsd = op.calculateOutputShape(); + assertEquals(1, lsd.size()); + assertArrayEquals(new long[]{8, 8, 3}, lsd.get(0).getShape()); + } + @Test public void testBitCastShape_3(){ val x = Nd4j.createFromArray(new int[]{1, 2, 3, 4, 5, 6, 7, 8}).reshape(1, 4, 2); 
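A note on the Roll expectations in the CustomOpsTests diff above: Roll is a circular shift, so element i of the output equals input[(i - shift) mod n], taken over the flattened buffer when no axis is given (testRoll shifts all 32 elements by 6) or along the supplied axis (testRoll1 shifts axis 0 by 2). Below is a minimal plain-Java sketch of the flattened rule; the RollSketch class and roll helper are illustrative only, not part of the nd4j API:

import java.util.Arrays;

public class RollSketch {
    // Circular shift of a flattened buffer: out[i] = in[(i - shift) mod n].
    static double[] roll(double[] in, int shift) {
        int n = in.length;
        double[] out = new double[n];
        for (int i = 0; i < n; i++) {
            out[i] = in[Math.floorMod(i - shift, n)]; // floorMod keeps the index non-negative
        }
        return out;
    }

    public static void main(String[] args) {
        // Mirrors testRoll1 above: shifting {0.7788, 0.8012, 0.7244, 0.2309} by 2
        // yields {0.7244, 0.2309, 0.7788, 0.8012}.
        System.out.println(Arrays.toString(roll(new double[]{0.7788, 0.8012, 0.7244, 0.2309}, 2)));
    }
}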
From 25b3cd9b80595cfde09d1d9a124ce63e955e881f Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 2 Dec 2019 21:37:21 +0300 Subject: [PATCH 21/30] [WIP] CUDA tests (#95) * one more CI test Signed-off-by: raver119 * export additional symbols Signed-off-by: raver119 * few more tweaks Signed-off-by: raver119 * one more tweak for linux Signed-off-by: raver119 * fix dtype in few tests Signed-off-by: raver119 * missing sync and memset in couple of tests Signed-off-by: raver119 * copy step for libnd4j cuda Signed-off-by: raver119 * no-op on empty for adjust hue/contrast/saturation Signed-off-by: raver119 * CUDA_VERBOSE Off Signed-off-by: raver119 * BroadcastBool fix + few tests Signed-off-by: raver119 * trigger jenkins Signed-off-by: raver119 * trigger jenkins Signed-off-by: raver119 * - ignore couple of warnings - remove redundant compiler options Signed-off-by: raver119 --- libnd4j/CMakeLists.txt | 4 +- libnd4j/blas/CMakeLists.txt | 26 +- libnd4j/blas/NDArray.hpp | 352 +++++++++--------- libnd4j/blas/cpu/NDArrayFactory.cpp | 311 ++++++++-------- libnd4j/blas/cuda/NDArray.cu | 2 +- libnd4j/buildnativeoperations.sh | 169 ++++----- libnd4j/include/array/impl/ExtraArguments.cpp | 4 +- .../include/exceptions/allocation_exception.h | 10 +- libnd4j/include/exceptions/cuda_exception.h | 10 +- .../include/exceptions/datatype_exception.h | 10 +- libnd4j/include/exceptions/graph_exception.h | 10 +- .../exceptions/graph_execution_exception.h | 10 +- .../exceptions/graph_exists_exception.h | 10 +- .../include/exceptions/no_results_exception.h | 10 +- .../exceptions/unknown_graph_exception.h | 10 +- libnd4j/include/execution/ThreadPool.h | 2 +- libnd4j/include/execution/Threads.h | 10 +- libnd4j/include/execution/Ticket.h | 2 +- libnd4j/include/graph/Variable.h | 2 +- libnd4j/include/graph/impl/Node.cpp | 2 +- libnd4j/include/graph/impl/Variable.cpp | 2 +- libnd4j/include/helpers/AttentionHelper.h | 2 +- libnd4j/include/helpers/BenchmarkHelper.h | 2 +- libnd4j/include/helpers/BitwiseUtils.h | 2 +- libnd4j/include/helpers/CudaLaunchHelper.h | 2 +- libnd4j/include/helpers/DebugHelper.h | 2 +- libnd4j/include/helpers/GradCheck.h | 2 +- libnd4j/include/helpers/MmulHelper.h | 2 +- libnd4j/include/helpers/OmpLaunchHelper.h | 2 +- libnd4j/include/helpers/PointersManager.h | 2 +- libnd4j/include/helpers/RandomLauncher.h | 2 +- libnd4j/include/helpers/ShapeUtils.h | 2 +- libnd4j/include/helpers/SimpleReadWriteLock.h | 3 +- libnd4j/include/helpers/StringUtils.h | 2 +- libnd4j/include/helpers/cublasHelper.h | 2 +- libnd4j/include/memory/MemoryRegistrator.h | 3 +- libnd4j/include/memory/MemoryReport.h | 3 +- libnd4j/include/memory/MemoryUtils.h | 3 +- libnd4j/include/ops/BroadcastBoolOpsTuple.h | 3 +- libnd4j/include/ops/BroadcastIntOpsTuple.h | 3 +- libnd4j/include/ops/BroadcastOpsTuple.h | 3 +- .../generic/helpers/BroadcastHelper.h | 3 + .../generic/parity_ops/adjust_contrast.cpp | 8 + .../generic/parity_ops/adjust_hue.cpp | 4 + .../generic/parity_ops/adjust_saturation.cpp | 4 + .../ops/declarable/generic/shape/create.cpp | 1 + .../ops/declarable/helpers/activations.h | 18 +- .../include/ops/declarable/helpers/col2im.h | 2 +- .../ops/declarable/helpers/convolutions.h | 3 +- .../include/ops/declarable/helpers/im2col.h | 2 +- .../ops/declarable/helpers/lstmLayer.h | 4 +- .../ops/declarable/helpers/multiUnique.h | 2 +- .../layers_tests/BroadcastableOpsTests.cpp | 37 +- libnd4j/tests_cpu/layers_tests/CMakeLists.txt | 2 +- .../layers_tests/CudaBasicsTests1.cu | 5 + .../layers_tests/DeclarableOpsTests6.cpp | 56 +-- 
.../linalg/broadcast/BasicBroadcastTests.java | 25 ++ 57 files changed, 673 insertions(+), 518 deletions(-) diff --git a/libnd4j/CMakeLists.txt b/libnd4j/CMakeLists.txt index 50c6b9b8a..d8b0439b4 100755 --- a/libnd4j/CMakeLists.txt +++ b/libnd4j/CMakeLists.txt @@ -25,8 +25,8 @@ elseif (APPLE) elseif(WIN32) set(X86_BUILD true) if (CUDA_BLAS) - set(CMAKE_CXX_FLAGS_RELEASE "-D_RELEASE=true /wd4804") - set(CMAKE_CXX_FLAGS_DEBUG " /FS /EHsc /wd4661 /wd4804 /wd4267 /wd4244 /wd4251 /wd4305") + set(CMAKE_CXX_FLAGS_RELEASE "-D_RELEASE=true") + set(CMAKE_CXX_FLAGS_DEBUG " /FS /EHsc") else() set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -D_RELEASE=true") set(CMAKE_CXX_FLAGS_DEBUG " -g -O2 -fPIC -std=c++11 -fmax-errors=2") diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt index 9674e28cd..c86bdc13a 100755 --- a/libnd4j/blas/CMakeLists.txt +++ b/libnd4j/blas/CMakeLists.txt @@ -111,7 +111,7 @@ elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Intel") elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") # using Visual Studio C++ - set( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc ${ARCH_TUNE}") + set( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE}") elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") # using GCC SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE}") @@ -158,7 +158,7 @@ if(CUDA_BLAS) include_directories(${CUDA_INCLUDE_DIRS}) message("CUDA found!") set( CUDA_ARCHITECTURE_MINIMUM "3.0" CACHE STRING "Minimum required CUDA compute capability" ) - SET(CUDA_VERBOSE_BUILD ON) + SET(CUDA_VERBOSE_BUILD OFF) SET(CUDA_SEPARABLE_COMPILATION OFF) #set(CUDA_COMPUTE_CAPABILITY "61") set(CUDA_COMPUTE_CAPABILITY "35") @@ -264,24 +264,13 @@ if(CUDA_BLAS) file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/impl/*.cpp ../include/loops/*.h) file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu) - if (NOT BUILD_TESTS) - CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA} + + CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA} ${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES}) - else() - set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_TESTS=true") - - CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA} - ${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} - ../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h - cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp - Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} - ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES}) - endif() - if(WIN32) message("CUDA on Windows: enabling /EHsc") @@ -289,11 +278,16 @@ if(CUDA_BLAS) SET_TARGET_PROPERTIES(${LIBND4J_NAME} PROPERTIES COMPILER_FLAGS "/EHsc /bigobj /std:c++14") endif() - target_link_libraries(${LIBND4J_NAME} ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusolver_LIBRARY}) 
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/cuda) install(TARGETS ${LIBND4J_NAME} DESTINATION .) + + add_custom_command( + TARGET ${LIBND4J_NAME} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + $ + ${PROJECT_BINARY_DIR}/../../tests_cpu/) endif(CUDA_FOUND) elseif(CPU_BLAS) diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index 00a984d45..df358b64f 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -31,9 +31,9 @@ namespace nd4j { template <> -utf8string NDArray::e(const Nd4jLong i) const; +ND4J_EXPORT utf8string NDArray::e(const Nd4jLong i) const; template <> -std::string NDArray::e(const Nd4jLong i) const; +ND4J_EXPORT std::string NDArray::e(const Nd4jLong i) const; ////////////////////////////////////////////////////////////////////////// template @@ -48,7 +48,7 @@ NDArray* NDArray::asT() const{ return result; } -BUILD_SINGLE_TEMPLATE(template NDArray* NDArray::asT, () const, LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray* NDArray::asT, () const, LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// // copy constructor @@ -435,7 +435,7 @@ std::vector NDArray::getBufferAsVector() { vector[e] = this->e(e); return vector; } -BUILD_SINGLE_TEMPLATE(template std::vector, NDArray::getBufferAsVector(), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::getBufferAsVector(), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// std::vector NDArray::getShapeAsFlatVector() { @@ -813,7 +813,7 @@ void NDArray::templatedSet(void *buffer, const Nd4jLong *indices, const void *va auto xOffset = shape::getOffset(getShapeInfo(), indices); t[xOffset] = static_cast(y); } -BUILD_DOUBLE_TEMPLATE(template void NDArray::templatedSet, (void *buffer, const Nd4jLong *indices, const void *value), LIBND4J_TYPES, LIBND4J_TYPES); +BUILD_DOUBLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong *indices, const void *value), LIBND4J_TYPES, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// template @@ -823,7 +823,7 @@ void NDArray::templatedSet(void *buffer, const Nd4jLong offset, const void *valu t[offset] = static_cast(y); } -BUILD_DOUBLE_TEMPLATE(template void NDArray::templatedSet, (void *buffer, const Nd4jLong offset, const void *value), LIBND4J_TYPES, LIBND4J_TYPES); +BUILD_DOUBLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong offset, const void *value), LIBND4J_TYPES, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// void NDArray::setContext(nd4j::LaunchContext *context) { @@ -1301,7 +1301,7 @@ template void* NDArray::templatedPointerShift(const Nd4jLong offset) const { return reinterpret_cast(getBuffer()) + offset; } -BUILD_SINGLE_TEMPLATE(template void* NDArray::templatedPointerShift, (const Nd4jLong offset) const, LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void* NDArray::templatedPointerShift, (const Nd4jLong offset) const, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// // method makes copy of this array and applies to the copy transpose operation, this array remains unaffected @@ -1608,7 +1608,7 @@ bool NDArray::isUnitary() { ////////////////////////////////////////////////////////////////////////// template <> -std::string* NDArray::bufferAsT() const { +std::string* ND4J_EXPORT NDArray::bufferAsT() const { throw 
std::runtime_error("This method is NOT supposed to be used"); } @@ -1620,7 +1620,7 @@ T* NDArray::bufferAsT() const { return reinterpret_cast(getBuffer()); } -BUILD_SINGLE_UNCHAINED_TEMPLATE(template, * NDArray::bufferAsT() const, LIBND4J_TYPES); +BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , * NDArray::bufferAsT() const, LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// NDArray* NDArray::subarray(IndicesList& idx) const { @@ -1797,16 +1797,16 @@ NDArray NDArray::operator+(const T& scalar) const { return result; } -template NDArray NDArray::operator+(const double& scalar) const; -template NDArray NDArray::operator+(const float& scalar) const; -template NDArray NDArray::operator+(const float16& scalar) const; -template NDArray NDArray::operator+(const bfloat16& scalar) const; -template NDArray NDArray::operator+(const Nd4jLong& scalar) const; -template NDArray NDArray::operator+(const int& scalar) const; -template NDArray NDArray::operator+(const int16_t& scalar) const; -template NDArray NDArray::operator+(const int8_t& scalar) const; -template NDArray NDArray::operator+(const uint8_t& scalar) const; -template NDArray NDArray::operator+(const bool& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator+(const double& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator+(const float& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator+(const float16& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator+(const bfloat16& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator+(const Nd4jLong& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator+(const int& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator+(const int16_t& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator+(const int8_t& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator+(const uint8_t& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator+(const bool& scalar) const; //////////////////////////////////////////////////////////////////////// // subtraction operator array - scalar @@ -1824,16 +1824,16 @@ NDArray NDArray::operator-(const T& scalar) const { return result; } -template NDArray NDArray::operator-(const double& scalar) const; -template NDArray NDArray::operator-(const float& scalar) const; -template NDArray NDArray::operator-(const float16& scalar) const; -template NDArray NDArray::operator-(const bfloat16& scalar) const; -template NDArray NDArray::operator-(const Nd4jLong& scalar) const; -template NDArray NDArray::operator-(const int& scalar) const; -template NDArray NDArray::operator-(const int16_t& scalar) const; -template NDArray NDArray::operator-(const int8_t& scalar) const; -template NDArray NDArray::operator-(const uint8_t& scalar) const; -template NDArray NDArray::operator-(const bool& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator-(const double& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator-(const float& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator-(const float16& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator-(const bfloat16& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator-(const Nd4jLong& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator-(const int& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator-(const int16_t& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator-(const int8_t& scalar) const; +template ND4J_EXPORT NDArray 
NDArray::operator-(const uint8_t& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator-(const bool& scalar) const; //////////////////////////////////////////////////////////////////////// // multiplication operator array*scalar @@ -1851,16 +1851,16 @@ NDArray NDArray::operator*(const T& scalar) const { return result; } -template NDArray NDArray::operator*(const double& scalar) const; -template NDArray NDArray::operator*(const float& scalar) const; -template NDArray NDArray::operator*(const float16& scalar) const; -template NDArray NDArray::operator*(const bfloat16& scalar) const; -template NDArray NDArray::operator*(const Nd4jLong& scalar) const; -template NDArray NDArray::operator*(const int& scalar) const; -template NDArray NDArray::operator*(const int16_t& scalar) const; -template NDArray NDArray::operator*(const int8_t& scalar) const; -template NDArray NDArray::operator*(const uint8_t& scalar) const; -template NDArray NDArray::operator*(const bool& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator*(const double& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator*(const float& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator*(const float16& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator*(const bfloat16& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator*(const Nd4jLong& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator*(const int& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator*(const int16_t& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator*(const int8_t& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator*(const uint8_t& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator*(const bool& scalar) const; //////////////////////////////////////////////////////////////////////// // division operator array / scalar @@ -1881,16 +1881,16 @@ NDArray NDArray::operator/(const T& scalar) const { return result; } -template NDArray NDArray::operator/(const double& scalar) const; -template NDArray NDArray::operator/(const float& scalar) const; -template NDArray NDArray::operator/(const float16& scalar) const; -template NDArray NDArray::operator/(const bfloat16& scalar) const; -template NDArray NDArray::operator/(const Nd4jLong& scalar) const; -template NDArray NDArray::operator/(const int& scalar) const; -template NDArray NDArray::operator/(const int16_t& scalar) const; -template NDArray NDArray::operator/(const int8_t& scalar) const; -template NDArray NDArray::operator/(const uint8_t& scalar) const; -template NDArray NDArray::operator/(const bool& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator/(const double& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator/(const float& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator/(const float16& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator/(const bfloat16& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator/(const Nd4jLong& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator/(const int& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator/(const int16_t& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator/(const int8_t& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator/(const uint8_t& scalar) const; +template ND4J_EXPORT NDArray NDArray::operator/(const bool& scalar) const; //////////////////////////////////////////////////////////////////////// // addition operator scalar + array @@ -2260,13 +2260,13 @@ void 
NDArray::operator+=(const T value) { NDArray::registerSpecialUse({this}, {}); } -template void NDArray::operator+=(const double value); -template void NDArray::operator+=(const float value); -template void NDArray::operator+=(const float16 value); -template void NDArray::operator+=(const bfloat16 value); -template void NDArray::operator+=(const Nd4jLong value); -template void NDArray::operator+=(const int value); -template void NDArray::operator+=(const bool value); +template ND4J_EXPORT void NDArray::operator+=(const double value); +template ND4J_EXPORT void NDArray::operator+=(const float value); +template ND4J_EXPORT void NDArray::operator+=(const float16 value); +template ND4J_EXPORT void NDArray::operator+=(const bfloat16 value); +template ND4J_EXPORT void NDArray::operator+=(const Nd4jLong value); +template ND4J_EXPORT void NDArray::operator+=(const int value); +template ND4J_EXPORT void NDArray::operator+=(const bool value); //////////////////////////////////////////////////////////////////////// template @@ -2282,13 +2282,13 @@ void NDArray::operator-=(const T value) { NDArray::registerSpecialUse({this}, {}); } -template void NDArray::operator-=(const double value); -template void NDArray::operator-=(const float value); -template void NDArray::operator-=(const float16 value); -template void NDArray::operator-=(const bfloat16 value); -template void NDArray::operator-=(const Nd4jLong value); -template void NDArray::operator-=(const int value); -template void NDArray::operator-=(const bool value); +template ND4J_EXPORT void NDArray::operator-=(const double value); +template ND4J_EXPORT void NDArray::operator-=(const float value); +template ND4J_EXPORT void NDArray::operator-=(const float16 value); +template ND4J_EXPORT void NDArray::operator-=(const bfloat16 value); +template ND4J_EXPORT void NDArray::operator-=(const Nd4jLong value); +template ND4J_EXPORT void NDArray::operator-=(const int value); +template ND4J_EXPORT void NDArray::operator-=(const bool value); //////////////////////////////////////////////////////////////////////// template @@ -2302,16 +2302,16 @@ void NDArray::operator*=(const T scalar) { NDArray::registerSpecialUse({this}, {}); } -template void NDArray::operator*=(const double scalar); -template void NDArray::operator*=(const float scalar); -template void NDArray::operator*=(const float16 scalar); -template void NDArray::operator*=(const bfloat16 scalar); -template void NDArray::operator*=(const Nd4jLong scalar); -template void NDArray::operator*=(const int scalar); -template void NDArray::operator*=(const int16_t scalar); -template void NDArray::operator*=(const int8_t scalar); -template void NDArray::operator*=(const uint8_t scalar); -template void NDArray::operator*=(const bool scalar); +template ND4J_EXPORT void NDArray::operator*=(const double scalar); +template ND4J_EXPORT void NDArray::operator*=(const float scalar); +template ND4J_EXPORT void NDArray::operator*=(const float16 scalar); +template ND4J_EXPORT void NDArray::operator*=(const bfloat16 scalar); +template ND4J_EXPORT void NDArray::operator*=(const Nd4jLong scalar); +template ND4J_EXPORT void NDArray::operator*=(const int scalar); +template ND4J_EXPORT void NDArray::operator*=(const int16_t scalar); +template ND4J_EXPORT void NDArray::operator*=(const int8_t scalar); +template ND4J_EXPORT void NDArray::operator*=(const uint8_t scalar); +template ND4J_EXPORT void NDArray::operator*=(const bool scalar); //////////////////////////////////////////////////////////////////////// template @@ -2324,16 
+2324,16 @@ void NDArray::operator/=(const T scalar) { NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Divide, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr); NDArray::registerSpecialUse({this}, {}); } -template void NDArray::operator/=(const double scalar); -template void NDArray::operator/=(const float scalar); -template void NDArray::operator/=(const float16 scalar); -template void NDArray::operator/=(const bfloat16 scalar); -template void NDArray::operator/=(const Nd4jLong scalar); -template void NDArray::operator/=(const int scalar); -template void NDArray::operator/=(const int16_t scalar); -template void NDArray::operator/=(const int8_t scalar); -template void NDArray::operator/=(const uint8_t scalar); -template void NDArray::operator/=(const bool scalar); +template ND4J_EXPORT void NDArray::operator/=(const double scalar); +template ND4J_EXPORT void NDArray::operator/=(const float scalar); +template ND4J_EXPORT void NDArray::operator/=(const float16 scalar); +template ND4J_EXPORT void NDArray::operator/=(const bfloat16 scalar); +template ND4J_EXPORT void NDArray::operator/=(const Nd4jLong scalar); +template ND4J_EXPORT void NDArray::operator/=(const int scalar); +template ND4J_EXPORT void NDArray::operator/=(const int16_t scalar); +template ND4J_EXPORT void NDArray::operator/=(const int8_t scalar); +template ND4J_EXPORT void NDArray::operator/=(const uint8_t scalar); +template ND4J_EXPORT void NDArray::operator/=(const bool scalar); //////////////////////////////////////////////////////////////////////// // subtraction operator array - array @@ -2929,7 +2929,7 @@ std::vector NDArray::asVectorT() { return result; } -BUILD_SINGLE_TEMPLATE(template std::vector, NDArray::asVectorT(), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::asVectorT(), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// // set new order and shape in case of suitable array length @@ -3046,7 +3046,7 @@ template void NDArray::templatedSet(void *buffer, const Nd4jLong xOfsset, nd4j::DataType dtype, const void *value) { BUILD_SINGLE_PARTIAL_SELECTOR(dtype, templatedSet< , T>(buffer, xOfsset, value), LIBND4J_TYPES); } -BUILD_SINGLE_TEMPLATE(template void NDArray::templatedSet, (void *buffer, const Nd4jLong xOfsset, nd4j::DataType dtype, const void *value), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong xOfsset, nd4j::DataType dtype, const void *value), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// void NDArray::applyPairwiseTransform(nd4j::pairwise::Ops op, const NDArray* other, NDArray *target, ExtraArguments *extraParams) const{ @@ -3109,7 +3109,7 @@ void NDArray::templatedDoubleAssign(void *xBuffer, const Nd4jLong xOffset, const const auto y = reinterpret_cast(yBuffer); x[xOffset] = static_cast(y[yOffset]); } -BUILD_DOUBLE_TEMPLATE(template void NDArray::templatedDoubleAssign, (void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const, LIBND4J_TYPES, LIBND4J_TYPES); +BUILD_DOUBLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedDoubleAssign, (void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const, LIBND4J_TYPES, LIBND4J_TYPES); 
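+// Explicit template instantiations do not inherit the class-level export
+// attribute on every toolchain (MSVC DLL builds in particular), so each
+// instantiation above and below is tagged ND4J_EXPORT to keep its symbol
+// visible to consumers of the shared library, such as the standalone test binaries.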
//////////////////////////////////////////////////////////////////////// void NDArray::varianceAlongDimension(nd4j::variance::Ops op, NDArray *target, const bool biasCorrected, const std::vector& dimensions) const { @@ -3356,7 +3356,7 @@ T NDArray::e(const Nd4jLong i) const { BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), rp), LIBND4J_TYPES); } -BUILD_SINGLE_UNCHAINED_TEMPLATE(template , NDArray::e(const Nd4jLong) const, LIBND4J_TYPES); +BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong) const, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// // Returns value from 2D matrix by coordinates/indexes @@ -3376,7 +3376,7 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j) const { return static_cast(119); } -BUILD_SINGLE_UNCHAINED_TEMPLATE(template , NDArray::e(const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); +BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// // returns value from 3D tensor by coordinates @@ -3396,7 +3396,7 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const { return static_cast(119); } -BUILD_SINGLE_UNCHAINED_TEMPLATE(template , NDArray::e(const Nd4jLong, const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); +BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// // returns value from 3D tensor by coordinates @@ -3416,7 +3416,7 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLon return static_cast(119); } -BUILD_SINGLE_UNCHAINED_TEMPLATE(template , NDArray::e(const Nd4jLong, const Nd4jLong, const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); +BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong, const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// NDArray NDArray::e(const Nd4jLong i) const { @@ -3591,17 +3591,17 @@ void NDArray::applyScalar(nd4j::scalar::Ops op, const T scalar, NDArray *target, applyScalarArr(op, &scalarArr, target, extraParams); } -template <> void NDArray::applyScalar(nd4j::scalar::Ops op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} -template void NDArray::applyScalar(nd4j::scalar::Ops op, const double scalar, NDArray *target, ExtraArguments *extraParams); -template void NDArray::applyScalar(nd4j::scalar::Ops op, const float scalar, NDArray *target, ExtraArguments *extraParams); -template void NDArray::applyScalar(nd4j::scalar::Ops op, const float16 scalar, NDArray *target, ExtraArguments *extraParams); -template void NDArray::applyScalar(nd4j::scalar::Ops op, const bfloat16 scalar, NDArray *target, ExtraArguments *extraParams); -template void NDArray::applyScalar(nd4j::scalar::Ops op, const Nd4jLong scalar, NDArray *target, ExtraArguments *extraParams); -template void NDArray::applyScalar(nd4j::scalar::Ops op, const int scalar, NDArray *target, ExtraArguments *extraParams); -template void NDArray::applyScalar(nd4j::scalar::Ops op, const int16_t scalar, NDArray *target, ExtraArguments *extraParams); -template void NDArray::applyScalar(nd4j::scalar::Ops op, const int8_t scalar, 
NDArray *target, ExtraArguments *extraParams); -template void NDArray::applyScalar(nd4j::scalar::Ops op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams); -template void NDArray::applyScalar(nd4j::scalar::Ops op, const bool scalar, NDArray *target, ExtraArguments *extraParams); +template <> ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const double scalar, NDArray *target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const float scalar, NDArray *target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const float16 scalar, NDArray *target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const bfloat16 scalar, NDArray *target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const Nd4jLong scalar, NDArray *target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const int scalar, NDArray *target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const int16_t scalar, NDArray *target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const int8_t scalar, NDArray *target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams); +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::Ops op, const bool scalar, NDArray *target, ExtraArguments *extraParams); ////////////////////////////////////////////////////////////////////////// void NDArray::applyScalarArr(nd4j::scalar::BoolOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { @@ -3627,17 +3627,17 @@ void NDArray::applyScalar(nd4j::scalar::BoolOps op, const T scalar, NDArray *tar applyScalarArr(op, &scalarArr, target, extraParams); } -template <> void NDArray::applyScalar(nd4j::scalar::BoolOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} -template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const double scalar, NDArray *target, ExtraArguments *extraParams) const; -template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const float scalar, NDArray *target, ExtraArguments *extraParams) const; -template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const float16 scalar, NDArray *target, ExtraArguments *extraParams) const; -template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const bfloat16 scalar, NDArray *target, ExtraArguments *extraParams) const; -template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const Nd4jLong scalar, NDArray *target, ExtraArguments *extraParams) const; -template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int scalar, NDArray *target, ExtraArguments *extraParams) const; -template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int16_t scalar, NDArray *target, ExtraArguments *extraParams) const; -template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int8_t scalar, NDArray *target, ExtraArguments 
*extraParams) const; -template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams) const; -template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const bool scalar, NDArray *target, ExtraArguments *extraParams) const; +template <> ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const double scalar, NDArray *target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const float scalar, NDArray *target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const float16 scalar, NDArray *target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const bfloat16 scalar, NDArray *target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const Nd4jLong scalar, NDArray *target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int scalar, NDArray *target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int16_t scalar, NDArray *target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int8_t scalar, NDArray *target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams) const; +template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::BoolOps op, const bool scalar, NDArray *target, ExtraArguments *extraParams) const; ////////////////////////////////////////////////////////////////////////// @@ -3665,17 +3665,17 @@ template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const bool sc applyScalarArr(op, &scalarArr, target, extraParams); } - template <> void NDArray::applyScalar(nd4j::scalar::IntOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} - template void NDArray::applyScalar(nd4j::scalar::IntOps op, const double scalar, NDArray *target, ExtraArguments *extraParams) const; - template void NDArray::applyScalar(nd4j::scalar::IntOps op, const float scalar, NDArray *target, ExtraArguments *extraParams) const; - template void NDArray::applyScalar(nd4j::scalar::IntOps op, const float16 scalar, NDArray *target, ExtraArguments *extraParams) const; - template void NDArray::applyScalar(nd4j::scalar::IntOps op, const bfloat16 scalar, NDArray *target, ExtraArguments *extraParams) const; - template void NDArray::applyScalar(nd4j::scalar::IntOps op, const Nd4jLong scalar, NDArray *target, ExtraArguments *extraParams) const; - template void NDArray::applyScalar(nd4j::scalar::IntOps op, const int scalar, NDArray *target, ExtraArguments *extraParams) const; - template void NDArray::applyScalar(nd4j::scalar::IntOps op, const int16_t scalar, NDArray *target, ExtraArguments *extraParams) const; - template void NDArray::applyScalar(nd4j::scalar::IntOps op, const int8_t scalar, NDArray *target, ExtraArguments *extraParams) const; - template void 
NDArray::applyScalar(nd4j::scalar::IntOps op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams) const; - template void NDArray::applyScalar(nd4j::scalar::IntOps op, const bool scalar, NDArray *target, ExtraArguments *extraParams) const; + template <> ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} + template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const double scalar, NDArray *target, ExtraArguments *extraParams) const; + template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const float scalar, NDArray *target, ExtraArguments *extraParams) const; + template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const float16 scalar, NDArray *target, ExtraArguments *extraParams) const; + template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const bfloat16 scalar, NDArray *target, ExtraArguments *extraParams) const; + template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const Nd4jLong scalar, NDArray *target, ExtraArguments *extraParams) const; + template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const int scalar, NDArray *target, ExtraArguments *extraParams) const; + template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const int16_t scalar, NDArray *target, ExtraArguments *extraParams) const; + template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const int8_t scalar, NDArray *target, ExtraArguments *extraParams) const; + template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams) const; + template ND4J_EXPORT void NDArray::applyScalar(nd4j::scalar::IntOps op, const bool scalar, NDArray *target, ExtraArguments *extraParams) const; //////////////////////////////////////////////////////////////////////// @@ -3966,19 +3966,19 @@ void NDArray::p(const Nd4jLong i, const T value) { NDArray::registerPrimaryUse({this}, {}); } -template void NDArray::p(const Nd4jLong i, const double value); -template void NDArray::p(const Nd4jLong i, const float value); -template void NDArray::p(const Nd4jLong i, const float16 value); -template void NDArray::p(const Nd4jLong i, const bfloat16 value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong value); -template void NDArray::p(const Nd4jLong i, const int value); -template void NDArray::p(const Nd4jLong i, const int8_t value); -template void NDArray::p(const Nd4jLong i, const uint8_t value); -template void NDArray::p(const Nd4jLong i, const uint16_t value); -template void NDArray::p(const Nd4jLong i, const uint32_t value); -template void NDArray::p(const Nd4jLong i, const uint64_t value); -template void NDArray::p(const Nd4jLong i, const int16_t value); -template void NDArray::p(const Nd4jLong i, const bool value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const double value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const float value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const float16 value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const bfloat16 value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const int value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const int8_t value); 
+template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint8_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint16_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint32_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint64_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const int16_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const bool value); ////////////////////////////////////////////////////////////////////////// // This method sets value in 2D matrix to position i, j @@ -3996,19 +3996,19 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const T value) { BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES); NDArray::registerPrimaryUse({this}, {}); } -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const double value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float16 value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const bfloat16 value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int8_t value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint8_t value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint16_t value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint32_t value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint64_t value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int16_t value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const bool value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const double value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float16 value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const bfloat16 value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int8_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint8_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint16_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint32_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint64_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int16_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const bool value); ////////////////////////////////////////////////////////////////////////// // This method sets value in 3D matrix to position i,j,k @@ -4026,19 +4026,19 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const T va BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES); NDArray::registerPrimaryUse({this}, 
{}); } -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const double value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float16 value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const bfloat16 value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int8_t value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint8_t value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint16_t value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint32_t value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint64_t value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int16_t value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const bool value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const double value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float16 value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const bfloat16 value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int8_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint8_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint16_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint32_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint64_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int16_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const bool value); ////////////////////////////////////////////////////////////////////////// template @@ -4055,19 +4055,19 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4j BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES); NDArray::registerPrimaryUse({this}, {}); } -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const double value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float16 value); -template void NDArray::p(const Nd4jLong i, const 
Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const bfloat16 value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const Nd4jLong value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int8_t value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint8_t value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint16_t value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint32_t value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint64_t value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int16_t value); -template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const bool value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const double value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float16 value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const bfloat16 value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const Nd4jLong value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int8_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint8_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint16_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint32_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint64_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int16_t value); +template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const bool value); //////////////////////////////////////////////////////////////////////// void NDArray::p(const Nd4jLong i, const NDArray& scalar) { @@ -4256,7 +4256,7 @@ void NDArray::templatedAssign(void *xBuffer, Nd4jLong xOffset, const void *yBuff if (xBuffer != nullptr && yBuffer != nullptr) *(reinterpret_cast(xBuffer) + xOffset) = *(reinterpret_cast(yBuffer) + yOffset); } -BUILD_SINGLE_TEMPLATE(template void NDArray::templatedAssign, (void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const, LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedAssign, (void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) 
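Many of these instantiation walls are generated by BUILD_SINGLE_TEMPLATE over LIBND4J_TYPES rather than written out by hand, which is why the export macro has to travel inside the macro invocation itself. An illustrative reduction of that X-macro style, using a hypothetical MY_TYPES list and Array class rather than libnd4j's actual macros:

#define MY_EXPORT   // left empty in this sketch; see the export macro above

struct Array {
    template <typename T>
    static void assign(void* dst, const void* src) {
        *static_cast<T*>(dst) = *static_cast<const T*>(src);
    }
};

// Type list as an X-macro: applies the callback CB once per supported type.
#define MY_TYPES(CB) CB(float) CB(double) CB(int)

// One explicit, exported instantiation per type in the list.
#define INSTANTIATE_ASSIGN(T) \
    template MY_EXPORT void Array::assign<T>(void*, const void*);

MY_TYPES(INSTANTIATE_ASSIGN)
#undef INSTANTIATE_ASSIGN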
const, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/blas/cpu/NDArrayFactory.cpp b/libnd4j/blas/cpu/NDArrayFactory.cpp index b091f13b7..54cc6bba8 100644 --- a/libnd4j/blas/cpu/NDArrayFactory.cpp +++ b/libnd4j/blas/cpu/NDArrayFactory.cpp @@ -29,7 +29,7 @@ namespace nd4j { //////////////////////////////////////////////////////////////////////// template <> - NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context) { + ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context) { if ((int) shape.size() > MAX_RANK) throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !"); @@ -71,8 +71,19 @@ namespace nd4j { NDArray result(buffer, descriptor, context); return result; - } + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, nd4j::LaunchContext * context); NDArray NDArrayFactory::string(const char *str, nd4j::LaunchContext * context) { std::string s(str); @@ -118,7 +129,7 @@ template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, nd4j::LaunchContext * context) { return create_(order, shape, DataTypeUtils::fromT(), context); } -BUILD_SINGLE_TEMPLATE(template NDArray* NDArrayFactory::create_, (const char order, const std::vector &shape, nd4j::LaunchContext * context), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray* NDArrayFactory::create_, (const char order, const std::vector &shape, nd4j::LaunchContext * context), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// template @@ -128,20 +139,20 @@ void 
NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector) { } template <> -void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector) { +void ND4J_EXPORT NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector) { auto p = reinterpret_cast(ptr); for (Nd4jLong e = 0; e < vector.size(); e++) p[e] = vector[e]; } -template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); +template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); +template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); +template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); +template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); +template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); +template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); +template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); +template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); #ifndef __JAVACPP_HACK__ @@ -150,16 +161,16 @@ template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector& shape, const T value, const char order, nd4j::LaunchContext * context) { return valueOf(std::vector(shape), value, order); } - template NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const double value, const char order, nd4j::LaunchContext * context); - template NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const float value, const char order, nd4j::LaunchContext * context); - template NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const float16 value, const char order, nd4j::LaunchContext * context); - template NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const bfloat16 value, const char order, nd4j::LaunchContext * context); - template NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const Nd4jLong value, const char order, nd4j::LaunchContext * context); - template NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int value, const char order, nd4j::LaunchContext * context); - template NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const uint8_t value, const char order, nd4j::LaunchContext * context); - template NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int8_t value, const char order, nd4j::LaunchContext * context); - template NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int16_t value, const char order, nd4j::LaunchContext * context); - template NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const bool value, const char 
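One detail worth calling out in the memcpyFromVector hunk: the generic version can copy straight out of the vector's contiguous buffer, but std::vector<bool> is bit-packed and exposes no element pointer, so the bool specialization that carries the new export macro copies element by element. A standalone sketch of the same split, as a hypothetical free function rather than the factory member:

#include <cstring>
#include <vector>

template <typename T>
void copyFromVector(void* ptr, const std::vector<T>& v) {
    // Contiguous storage: one bulk copy is enough.
    std::memcpy(ptr, v.data(), v.size() * sizeof(T));
}

template <>
void copyFromVector<bool>(void* ptr, const std::vector<bool>& v) {
    // vector<bool> is a packed bitset; unpack one element at a time.
    auto p = static_cast<bool*>(ptr);
    for (std::size_t e = 0; e < v.size(); e++)
        p[e] = v[e];
}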
order, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const double value, const char order, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const float value, const char order, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const float16 value, const char order, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const bfloat16 value, const char order, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const Nd4jLong value, const char order, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int value, const char order, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const uint8_t value, const char order, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int8_t value, const char order, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int16_t value, const char order, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const bool value, const char order, nd4j::LaunchContext * context); //////////////////////////////////////////////////////////////////////// template @@ -167,18 +178,18 @@ template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector vec(data); return create(order, shape, vec, context); } - template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); - template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); - template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); - template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); - template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); - template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); - template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); - template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); - template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); - template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); - template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, 
nd4j::LaunchContext * context); - template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); #endif @@ -197,19 +208,19 @@ template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector NDArray NDArrayFactory::create(nd4j::DataType type, const T scalar, nd4j::LaunchContext * context) { @@ -223,20 +234,20 @@ template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector NDArray NDArrayFactory::create(const T scalar, nd4j::LaunchContext * context) { @@ -252,19 +263,19 @@ template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector & return new NDArray(NDArrayFactory::create(order, shape, data, context)); } -template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); -template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); -template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); -template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); -template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); -template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector 
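The initializer_list overloads above are thin wrappers: they materialize a std::vector and forward to the vector-based factory, so only the vector overload needs a real implementation (and an exported instantiation) per type. A non-template sketch of that delegation, with a stand-in Array type:

#include <initializer_list>
#include <vector>

struct Array { std::vector<float> data; };   // stand-in for NDArray

Array create(char order, const std::vector<long>& shape, const std::vector<float>& data) {
    // Real code would also honor order/shape; the sketch just keeps the values.
    return Array{data};
}

Array create(char order, const std::vector<long>& shape, std::initializer_list<float> data) {
    std::vector<float> vec(data);        // copy the literal values once
    return create(order, shape, vec);    // reuse the vector-based implementation
}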
&data, nd4j::LaunchContext * context); -template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); -template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); -template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); -template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); -template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); -template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); -template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, nd4j::LaunchContext * context); //////////////////////////////////////////////////////////////////////// template <> - NDArray* NDArrayFactory::valueOf(const std::vector& shape, NDArray* value, const char order, nd4j::LaunchContext * context) { + ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, NDArray* value, const char order, nd4j::LaunchContext * context) { auto result = create_(order, shape, value->dataType(), context); result->assign(*value); return result; } template <> - 
NDArray* NDArrayFactory::valueOf(const std::vector& shape, NDArray& value, const char order, nd4j::LaunchContext * context) { + ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, NDArray& value, const char order, nd4j::LaunchContext * context) { auto result = create_(order, shape, value.dataType(), context); result->assign(value); return result; @@ -309,16 +320,16 @@ template NDArray* NDArrayFactory::create_(const char order, const std::vectorassign(value); return result; } - template NDArray* NDArrayFactory::valueOf(const std::vector& shape, const double value, const char order, nd4j::LaunchContext * context); - template NDArray* NDArrayFactory::valueOf(const std::vector& shape, const float value, const char order, nd4j::LaunchContext * context); - template NDArray* NDArrayFactory::valueOf(const std::vector& shape, const float16 value, const char order, nd4j::LaunchContext * context); - template NDArray* NDArrayFactory::valueOf(const std::vector& shape, const bfloat16 value, const char order, nd4j::LaunchContext * context); - template NDArray* NDArrayFactory::valueOf(const std::vector& shape, const Nd4jLong value, const char order, nd4j::LaunchContext * context); - template NDArray* NDArrayFactory::valueOf(const std::vector& shape, const int value, const char order, nd4j::LaunchContext * context); - template NDArray* NDArrayFactory::valueOf(const std::vector& shape, const int16_t value, const char order, nd4j::LaunchContext * context); - template NDArray* NDArrayFactory::valueOf(const std::vector& shape, const int8_t value, const char order, nd4j::LaunchContext * context); - template NDArray* NDArrayFactory::valueOf(const std::vector& shape, const uint8_t value, const char order, nd4j::LaunchContext * context); - template NDArray* NDArrayFactory::valueOf(const std::vector& shape, const bool value, const char order, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const double value, const char order, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const float value, const char order, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const float16 value, const char order, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const bfloat16 value, const char order, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const Nd4jLong value, const char order, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const int value, const char order, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const int16_t value, const char order, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const int8_t value, const char order, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const uint8_t value, const char order, nd4j::LaunchContext * context); + template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const bool value, const char order, nd4j::LaunchContext * context); //////////////////////////////////////////////////////////////////////// @@ -334,19 +345,19 @@ template NDArray* 
NDArrayFactory::create_(const char order, const std::vector @@ -363,19 +374,19 @@ template NDArray* NDArrayFactory::create_(const char order, const std::vector @@ -383,14 +394,14 @@ template NDArray* NDArrayFactory::create_(const char order, const std::vector vec(shape); return create(order, vec, context); } - BUILD_SINGLE_TEMPLATE(template NDArray NDArrayFactory::create, (const char, const std::initializer_list&, nd4j::LaunchContext * context), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArrayFactory::create, (const char, const std::initializer_list&, nd4j::LaunchContext * context), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// template NDArray NDArrayFactory::create(const char order, const std::vector &shape, nd4j::LaunchContext * context) { return create(order, shape, DataTypeUtils::fromT(), context); } - BUILD_SINGLE_TEMPLATE(template NDArray NDArrayFactory::create, (const char order, const std::vector &shape, nd4j::LaunchContext * context), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArrayFactory::create, (const char order, const std::vector &shape, nd4j::LaunchContext * context), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// NDArray NDArrayFactory::create(const char order, const std::vector &shape, nd4j::DataType dtype, nd4j::LaunchContext* context) { @@ -443,17 +454,17 @@ NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext return res; } -template NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext * context); -template NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext * context); -template NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext * context); -template NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext * context); -template NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext * context); -template NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext * context); -template NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext * context); -template NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext * context); -template NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext * context); -template NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext * context); -template NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext * context); +template ND4J_EXPORT 
NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, nd4j::LaunchContext * context); //////////////////////////////////////////////////////////////////////// template @@ -466,7 +477,7 @@ template NDArray NDArrayFactory::create(const std::vector &values, nd4j::L return result; } - BUILD_SINGLE_TEMPLATE(template NDArray* NDArrayFactory::empty_, (nd4j::LaunchContext * context), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray* NDArrayFactory::empty_, (nd4j::LaunchContext * context), LIBND4J_TYPES); NDArray* NDArrayFactory::empty_(nd4j::DataType dataType, nd4j::LaunchContext * context) { if (context == nullptr) @@ -486,7 +497,7 @@ template NDArray NDArrayFactory::create(const std::vector &values, nd4j::L NDArray NDArrayFactory::empty(nd4j::LaunchContext * context) { return empty(DataTypeUtils::fromT(), context); } - BUILD_SINGLE_TEMPLATE(template NDArray NDArrayFactory::empty, (nd4j::LaunchContext * context), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArrayFactory::empty, (nd4j::LaunchContext * context), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// NDArray NDArrayFactory::empty(nd4j::DataType dataType, nd4j::LaunchContext * context) { @@ -529,16 +540,16 @@ NDArray NDArrayFactory::create(T* buffer, const char order, const std::initializ return result; } -template NDArray NDArrayFactory::create(double* buffer, const char order, const std::initializer_list& shape, nd4j::LaunchContext * context); -template NDArray NDArrayFactory::create(float* buffer, const char order, const std::initializer_list& shape, nd4j::LaunchContext * context); -template NDArray NDArrayFactory::create(float16* buffer, const char order, const std::initializer_list& shape, nd4j::LaunchContext * context); -template NDArray NDArrayFactory::create(bfloat16* buffer, const char order, const std::initializer_list& shape, nd4j::LaunchContext * context); -template NDArray NDArrayFactory::create(Nd4jLong * buffer, const char order, const std::initializer_list& shape, nd4j::LaunchContext * context); -template NDArray NDArrayFactory::create(int* buffer, const char order, const std::initializer_list& shape, nd4j::LaunchContext * context); -template NDArray NDArrayFactory::create(bool* buffer, const char order, const std::initializer_list& shape, nd4j::LaunchContext * context); -template NDArray NDArrayFactory::create(uint8_t * buffer, const char order, const std::initializer_list& shape, nd4j::LaunchContext * context); -template NDArray NDArrayFactory::create(int8_t* buffer, const char order, const std::initializer_list& shape, nd4j::LaunchContext * context); -template NDArray NDArrayFactory::create(int16_t* buffer, const char order, const std::initializer_list& shape, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray NDArrayFactory::create(double* buffer, const char order, const std::initializer_list& shape, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray NDArrayFactory::create(float* buffer, const char order, const std::initializer_list& shape, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray NDArrayFactory::create(float16* buffer, const char order, const std::initializer_list& shape, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray NDArrayFactory::create(bfloat16* 
buffer, const char order, const std::initializer_list& shape, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray NDArrayFactory::create(Nd4jLong * buffer, const char order, const std::initializer_list& shape, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray NDArrayFactory::create(int* buffer, const char order, const std::initializer_list& shape, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray NDArrayFactory::create(bool* buffer, const char order, const std::initializer_list& shape, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray NDArrayFactory::create(uint8_t * buffer, const char order, const std::initializer_list& shape, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray NDArrayFactory::create(int8_t* buffer, const char order, const std::initializer_list& shape, nd4j::LaunchContext * context); +template ND4J_EXPORT NDArray NDArrayFactory::create(int16_t* buffer, const char order, const std::initializer_list& shape, nd4j::LaunchContext * context); NDArray NDArrayFactory::string(char order, const std::vector &shape, const std::initializer_list &strings, nd4j::LaunchContext * context) { diff --git a/libnd4j/blas/cuda/NDArray.cu b/libnd4j/blas/cuda/NDArray.cu index f70760f9a..be90a22ae 100644 --- a/libnd4j/blas/cuda/NDArray.cu +++ b/libnd4j/blas/cuda/NDArray.cu @@ -150,7 +150,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, const char manager.synchronize(); } -BUILD_SINGLE_TEMPLATE(template void NDArray::fillAsTriangular, (const float val, int lower, int upper, const char direction, NDArray* target), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::fillAsTriangular, (const float val, int lower, int upper, const char direction, NDArray* target), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// template diff --git a/libnd4j/buildnativeoperations.sh b/libnd4j/buildnativeoperations.sh index 56e225a5d..351a4f8e2 100755 --- a/libnd4j/buildnativeoperations.sh +++ b/libnd4j/buildnativeoperations.sh @@ -168,140 +168,133 @@ fi case "$OS" in linux-armhf) - export RPI_BIN=$RPI_HOME/tools/arm-bcm2708/arm-rpi-4.9.3-linux-gnueabihf/bin/arm-linux-gnueabihf - export CMAKE_COMMAND="$CMAKE_COMMAND -D CMAKE_TOOLCHAIN_FILE=cmake/rpi.cmake" - if [ -z "$ARCH" ]; then + export RPI_BIN=$RPI_HOME/tools/arm-bcm2708/arm-rpi-4.9.3-linux-gnueabihf/bin/arm-linux-gnueabihf + export CMAKE_COMMAND="$CMAKE_COMMAND -D CMAKE_TOOLCHAIN_FILE=cmake/rpi.cmake" + if [ -z "$ARCH" ]; then ARCH="armv7-r" - fi + fi ;; linux-arm64) - if [ -z "$ARCH" ]; then + if [ -z "$ARCH" ]; then ARCH="armv8-a" - fi + fi ;; android-arm) - if [ -z "$ARCH" ]; then + if [ -z "$ARCH" ]; then ARCH="armv7-a" - fi - export ANDROID_BIN="$ANDROID_NDK/toolchains/arm-linux-androideabi-4.9/prebuilt/$KERNEL/" - export ANDROID_CPP="$ANDROID_NDK/sources/cxx-stl/llvm-libc++/" - export ANDROID_LLVM="$ANDROID_NDK/toolchains/llvm/prebuilt/$KERNEL/" - export ANDROID_ROOT="$ANDROID_NDK/platforms/android-14/arch-arm/" - export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/android-arm.cmake -DANDROID_BUILD=true" + fi + export ANDROID_BIN="$ANDROID_NDK/toolchains/arm-linux-androideabi-4.9/prebuilt/$KERNEL/" + export ANDROID_CPP="$ANDROID_NDK/sources/cxx-stl/llvm-libc++/" + export ANDROID_LLVM="$ANDROID_NDK/toolchains/llvm/prebuilt/$KERNEL/" + export ANDROID_ROOT="$ANDROID_NDK/platforms/android-14/arch-arm/" + export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/android-arm.cmake 
-DANDROID_BUILD=true" ;; android-arm64) - if [ -z "$ARCH" ]; then + if [ -z "$ARCH" ]; then ARCH="armv8-a" - fi - export ANDROID_BIN="$ANDROID_NDK/toolchains/aarch64-linux-android-4.9/prebuilt/$KERNEL/" - export ANDROID_CPP="$ANDROID_NDK/sources/cxx-stl/llvm-libc++/" - export ANDROID_LLVM="$ANDROID_NDK/toolchains/llvm/prebuilt/$KERNEL/" - export ANDROID_ROOT="$ANDROID_NDK/platforms/android-21/arch-arm64/" - export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/android-arm64.cmake -DANDROID_BUILD=true" + fi + export ANDROID_BIN="$ANDROID_NDK/toolchains/aarch64-linux-android-4.9/prebuilt/$KERNEL/" + export ANDROID_CPP="$ANDROID_NDK/sources/cxx-stl/llvm-libc++/" + export ANDROID_LLVM="$ANDROID_NDK/toolchains/llvm/prebuilt/$KERNEL/" + export ANDROID_ROOT="$ANDROID_NDK/platforms/android-21/arch-arm64/" + export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/android-arm64.cmake -DANDROID_BUILD=true" ;; android-x86) - if [ -z "$ARCH" ]; then + if [ -z "$ARCH" ]; then ARCH="i686" - fi - export ANDROID_BIN="$ANDROID_NDK/toolchains/x86-4.9/prebuilt/$KERNEL/" - export ANDROID_CPP="$ANDROID_NDK/sources/cxx-stl/llvm-libc++/" - export ANDROID_LLVM="$ANDROID_NDK/toolchains/llvm/prebuilt/$KERNEL/" - export ANDROID_ROOT="$ANDROID_NDK/platforms/android-14/arch-x86/" - export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/android-x86.cmake -DANDROID_BUILD=true" + fi + export ANDROID_BIN="$ANDROID_NDK/toolchains/x86-4.9/prebuilt/$KERNEL/" + export ANDROID_CPP="$ANDROID_NDK/sources/cxx-stl/llvm-libc++/" + export ANDROID_LLVM="$ANDROID_NDK/toolchains/llvm/prebuilt/$KERNEL/" + export ANDROID_ROOT="$ANDROID_NDK/platforms/android-14/arch-x86/" + export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/android-x86.cmake -DANDROID_BUILD=true" ;; android-x86_64) - if [ -z "$ARCH" ]; then + if [ -z "$ARCH" ]; then ARCH="x86-64" - fi - export ANDROID_BIN="$ANDROID_NDK/toolchains/x86_64-4.9/prebuilt/$KERNEL/" - export ANDROID_CPP="$ANDROID_NDK/sources/cxx-stl/llvm-libc++/" - export ANDROID_LLVM="$ANDROID_NDK/toolchains/llvm/prebuilt/$KERNEL/" - export ANDROID_ROOT="$ANDROID_NDK/platforms/android-21/arch-x86_64/" - export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/android-x86_64.cmake -DANDROID_BUILD=true" + fi + export ANDROID_BIN="$ANDROID_NDK/toolchains/x86_64-4.9/prebuilt/$KERNEL/" + export ANDROID_CPP="$ANDROID_NDK/sources/cxx-stl/llvm-libc++/" + export ANDROID_LLVM="$ANDROID_NDK/toolchains/llvm/prebuilt/$KERNEL/" + export ANDROID_ROOT="$ANDROID_NDK/platforms/android-21/arch-x86_64/" + export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/android-x86_64.cmake -DANDROID_BUILD=true" ;; ios-x86_64) - LIBTYPE="static" - ARCH="x86-64" - if xcrun --sdk iphoneos --show-sdk-version &> /dev/null; then - export IOS_VERSION="$(xcrun --sdk iphoneos --show-sdk-version)" - else + LIBTYPE="static" + ARCH="x86-64" + if xcrun --sdk iphoneos --show-sdk-version &> /dev/null; then + export IOS_VERSION="$(xcrun --sdk iphoneos --show-sdk-version)" + else export IOS_VERSION="10.3" - fi - XCODE_PATH="$(xcode-select --print-path)" - export IOS_SDK="$XCODE_PATH/Platforms/iPhoneSimulator.platform/Developer/SDKs/iPhoneSimulator$IOS_VERSION.sdk" - export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/ios-x86_64.cmake --debug-trycompile -DIOS_BUILD=true" + fi + XCODE_PATH="$(xcode-select --print-path)" + export IOS_SDK="$XCODE_PATH/Platforms/iPhoneSimulator.platform/Developer/SDKs/iPhoneSimulator$IOS_VERSION.sdk" + export CMAKE_COMMAND="$CMAKE_COMMAND 
-DCMAKE_TOOLCHAIN_FILE=cmake/ios-x86_64.cmake --debug-trycompile -DIOS_BUILD=true" ;; ios-x86) - LIBTYPE="static" - ARCH="i386" - if xcrun --sdk iphoneos --show-sdk-version &> /dev/null; then - export IOS_VERSION="$(xcrun --sdk iphoneos --show-sdk-version)" - else + LIBTYPE="static" + ARCH="i386" + if xcrun --sdk iphoneos --show-sdk-version &> /dev/null; then + export IOS_VERSION="$(xcrun --sdk iphoneos --show-sdk-version)" + else export IOS_VERSION="10.3" - fi - XCODE_PATH="$(xcode-select --print-path)" - export IOS_SDK="$XCODE_PATH/Platforms/iPhoneSimulator.platform/Developer/SDKs/iPhoneSimulator$IOS_VERSION.sdk" - export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/ios-x86.cmake --debug-trycompile -DIOS_BUILD=true" + fi + XCODE_PATH="$(xcode-select --print-path)" + export IOS_SDK="$XCODE_PATH/Platforms/iPhoneSimulator.platform/Developer/SDKs/iPhoneSimulator$IOS_VERSION.sdk" + export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/ios-x86.cmake --debug-trycompile -DIOS_BUILD=true" ;; ios-arm64) - LIBTYPE="static" - ARCH="arm64" - if xcrun --sdk iphoneos --show-sdk-version &> /dev/null; then - export IOS_VERSION="$(xcrun --sdk iphoneos --show-sdk-version)" - else + LIBTYPE="static" + ARCH="arm64" + if xcrun --sdk iphoneos --show-sdk-version &> /dev/null; then + export IOS_VERSION="$(xcrun --sdk iphoneos --show-sdk-version)" + else export IOS_VERSION="10.3" - fi - XCODE_PATH="$(xcode-select --print-path)" - export IOS_SDK="$XCODE_PATH/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS$IOS_VERSION.sdk" - export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/ios-arm64.cmake --debug-trycompile -DIOS_BUILD=true" + fi + XCODE_PATH="$(xcode-select --print-path)" + export IOS_SDK="$XCODE_PATH/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS$IOS_VERSION.sdk" + export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/ios-arm64.cmake --debug-trycompile -DIOS_BUILD=true" ;; ios-arm) - LIBTYPE="static" - ARCH="armv7" - if xcrun --sdk iphoneos --show-sdk-version &> /dev/null; then - export IOS_VERSION="$(xcrun --sdk iphoneos --show-sdk-version)" - else + LIBTYPE="static" + ARCH="armv7" + if xcrun --sdk iphoneos --show-sdk-version &> /dev/null; then + export IOS_VERSION="$(xcrun --sdk iphoneos --show-sdk-version)" + else export IOS_VERSION="10.3" - fi - XCODE_PATH="$(xcode-select --print-path)" - export IOS_SDK="$XCODE_PATH/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS$IOS_VERSION.sdk" - export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/ios-arm.cmake --debug-trycompile -DIOS_BUILD=true" + fi + XCODE_PATH="$(xcode-select --print-path)" + export IOS_SDK="$XCODE_PATH/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS$IOS_VERSION.sdk" + export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/ios-arm.cmake --debug-trycompile -DIOS_BUILD=true" ;; ios-armv7) - # change those 2 parameters and make sure the IOS_SDK exists - export iPhoneOS="iPhoneOS" - export IOS_VERSION="10.3" - LIBTYPE="static" - ARCH="armv7" - export IOS_SDK="/Applications/Xcode.app/Contents/Developer/Platforms/${iPhoneOS}.platform/Developer/SDKs/${iPhoneOS}${IOS_VERSION}.sdk" - export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/ios-armv7.cmake --debug-trycompile -DIOS_BUILD=true" + # change those 2 parameters and make sure the IOS_SDK exists + export iPhoneOS="iPhoneOS" + export IOS_VERSION="10.3" + LIBTYPE="static" + ARCH="armv7" + export 
IOS_SDK="/Applications/Xcode.app/Contents/Developer/Platforms/${iPhoneOS}.platform/Developer/SDKs/${iPhoneOS}${IOS_VERSION}.sdk" + export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_TOOLCHAIN_FILE=cmake/ios-armv7.cmake --debug-trycompile -DIOS_BUILD=true" ;; linux*) ;; macosx*) - # Do something under Mac OS X platform - #if [ "$CHIP" == "cuda" ]; then - export CC=clang - export CXX=clang++ - PARALLEL="true" - #else - # export CC="$(ls -1 /usr/local/bin/gcc-? | head -n 1)" - # export CXX="$(ls -1 /usr/local/bin/g++-? | head -n 1)" - # PARALLEL="true" - #fi - export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_MACOSX_RPATH=ON -DAPPLE_BUILD=true" + export CC=clang + export CXX=clang++ + PARALLEL="true" + export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_MACOSX_RPATH=ON -DAPPLE_BUILD=true" ;; windows*) diff --git a/libnd4j/include/array/impl/ExtraArguments.cpp b/libnd4j/include/array/impl/ExtraArguments.cpp index 55cda66b0..f9174ea0f 100644 --- a/libnd4j/include/array/impl/ExtraArguments.cpp +++ b/libnd4j/include/array/impl/ExtraArguments.cpp @@ -89,7 +89,7 @@ namespace nd4j { delete[] target; #endif } - BUILD_SINGLE_TEMPLATE(template void ExtraArguments::convertAndCopy, (Nd4jPointer pointer, Nd4jLong offset), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void ExtraArguments::convertAndCopy, (Nd4jPointer pointer, Nd4jLong offset), LIBND4J_TYPES); void* ExtraArguments::allocate(size_t length, size_t elementSize) { #ifdef __CUDABLAS__ @@ -119,7 +119,7 @@ namespace nd4j { void* ExtraArguments::argumentsAsT(Nd4jLong offset) { return argumentsAsT(DataTypeUtils::fromT(), offset); } - BUILD_SINGLE_TEMPLATE(template void *ExtraArguments::argumentsAsT, (Nd4jLong offset), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void *ExtraArguments::argumentsAsT, (Nd4jLong offset), LIBND4J_TYPES); void* ExtraArguments::argumentsAsT(nd4j::DataType dataType, Nd4jLong offset) { diff --git a/libnd4j/include/exceptions/allocation_exception.h b/libnd4j/include/exceptions/allocation_exception.h index 23c53d166..29756d253 100644 --- a/libnd4j/include/exceptions/allocation_exception.h +++ b/libnd4j/include/exceptions/allocation_exception.h @@ -24,9 +24,17 @@ #include #include #include +#include + +#if defined(_MSC_VER) + +// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library +#pragma warning( disable : 4275 ) + +#endif namespace nd4j { - class allocation_exception : public std::runtime_error { + class ND4J_EXPORT allocation_exception : public std::runtime_error { public: allocation_exception(std::string message); ~allocation_exception() = default; diff --git a/libnd4j/include/exceptions/cuda_exception.h b/libnd4j/include/exceptions/cuda_exception.h index 3f6fce4d5..5150033e8 100644 --- a/libnd4j/include/exceptions/cuda_exception.h +++ b/libnd4j/include/exceptions/cuda_exception.h @@ -23,9 +23,17 @@ #include #include +#include + +#if defined(_MSC_VER) + +// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library +#pragma warning( disable : 4275 ) + +#endif namespace nd4j { - class cuda_exception : public std::runtime_error { + class ND4J_EXPORT cuda_exception : public std::runtime_error { public: cuda_exception(std::string message); ~cuda_exception() = default; diff --git a/libnd4j/include/exceptions/datatype_exception.h b/libnd4j/include/exceptions/datatype_exception.h index 05e8ae14a..171a2b13b 100644 --- a/libnd4j/include/exceptions/datatype_exception.h +++ 
b/libnd4j/include/exceptions/datatype_exception.h @@ -24,9 +24,17 @@ #include #include #include +#include + +#if defined(_MSC_VER) + +// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library +#pragma warning( disable : 4275 ) + +#endif namespace nd4j { - class datatype_exception : public std::runtime_error { + class ND4J_EXPORT datatype_exception : public std::runtime_error { public: datatype_exception(std::string message); ~datatype_exception() = default; diff --git a/libnd4j/include/exceptions/graph_exception.h b/libnd4j/include/exceptions/graph_exception.h index 6daf833cf..440fa5aa4 100644 --- a/libnd4j/include/exceptions/graph_exception.h +++ b/libnd4j/include/exceptions/graph_exception.h @@ -24,9 +24,17 @@ #include #include #include +#include + +#if defined(_MSC_VER) + +// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library +#pragma warning( disable : 4275 ) + +#endif namespace nd4j { - class graph_exception : public std::runtime_error { + class ND4J_EXPORT graph_exception : public std::runtime_error { protected: Nd4jLong _graphId; std::string _message; diff --git a/libnd4j/include/exceptions/graph_execution_exception.h b/libnd4j/include/exceptions/graph_execution_exception.h index 03f0a37e4..92b02e2ee 100644 --- a/libnd4j/include/exceptions/graph_execution_exception.h +++ b/libnd4j/include/exceptions/graph_execution_exception.h @@ -25,9 +25,17 @@ #include #include #include +#include + +#if defined(_MSC_VER) + +// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library +#pragma warning( disable : 4275 ) + +#endif namespace nd4j { - class graph_execution_exception: public graph_exception { + class ND4J_EXPORT graph_execution_exception: public graph_exception { public: explicit graph_execution_exception(Nd4jLong graphId); }; diff --git a/libnd4j/include/exceptions/graph_exists_exception.h b/libnd4j/include/exceptions/graph_exists_exception.h index 355518d02..985770ad3 100644 --- a/libnd4j/include/exceptions/graph_exists_exception.h +++ b/libnd4j/include/exceptions/graph_exists_exception.h @@ -25,9 +25,17 @@ #include #include #include +#include + +#if defined(_MSC_VER) + +// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library +#pragma warning( disable : 4275 ) + +#endif namespace nd4j { - class graph_exists_exception: public graph_exception { + class ND4J_EXPORT graph_exists_exception: public graph_exception { public: explicit graph_exists_exception(Nd4jLong graphId); }; diff --git a/libnd4j/include/exceptions/no_results_exception.h b/libnd4j/include/exceptions/no_results_exception.h index f7673ed0c..0fa1bb167 100644 --- a/libnd4j/include/exceptions/no_results_exception.h +++ b/libnd4j/include/exceptions/no_results_exception.h @@ -25,9 +25,17 @@ #include #include #include +#include + +#if defined(_MSC_VER) + +// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library +#pragma warning( disable : 4275 ) + +#endif namespace nd4j { - class no_results_exception: public graph_exception { + class ND4J_EXPORT no_results_exception: public graph_exception { public: explicit no_results_exception(Nd4jLong graphId); }; diff --git a/libnd4j/include/exceptions/unknown_graph_exception.h b/libnd4j/include/exceptions/unknown_graph_exception.h index 90d9d8e2e..83efc9dcf 100644 --- 
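Every exception header in this patch gets the same two-part treatment: ND4J_EXPORT on the class, plus #pragma warning(disable: 4275) for MSVC. C4275 fires because an exported class here derives from std::runtime_error (or from graph_exception, which in turn derives from it) and the base is not itself dll-exported; since the base comes from the standard library shipped with the toolchain, suppressing the warning is the conventional fix. The shape of each header, reduced to a sketch with hypothetical names:

#include <stdexcept>
#include <string>

#if defined(_MSC_VER)
// The base class (std::runtime_error) is not dll-exported; that is fine
// for standard-library types, so silence the non-exported-base warning.
#pragma warning(disable : 4275)
#define MY_EXPORT __declspec(dllexport)
#else
#define MY_EXPORT
#endif

class MY_EXPORT my_exception : public std::runtime_error {
public:
    explicit my_exception(const std::string& message)
        : std::runtime_error(message) {}
};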
a/libnd4j/include/exceptions/unknown_graph_exception.h +++ b/libnd4j/include/exceptions/unknown_graph_exception.h @@ -25,9 +25,17 @@ #include #include #include +#include + +#if defined(_MSC_VER) + +// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library +#pragma warning( disable : 4275 ) + +#endif namespace nd4j { - class unknown_graph_exception: public graph_exception { + class ND4J_EXPORT unknown_graph_exception: public graph_exception { public: explicit unknown_graph_exception(Nd4jLong graphId); }; diff --git a/libnd4j/include/execution/ThreadPool.h b/libnd4j/include/execution/ThreadPool.h index e17b4b540..6811f1b1c 100644 --- a/libnd4j/include/execution/ThreadPool.h +++ b/libnd4j/include/execution/ThreadPool.h @@ -33,7 +33,7 @@ #include namespace samediff { - class ThreadPool { + class ND4J_EXPORT ThreadPool { private: static ThreadPool* _INSTANCE; diff --git a/libnd4j/include/execution/Threads.h b/libnd4j/include/execution/Threads.h index be12a311a..14467883f 100644 --- a/libnd4j/include/execution/Threads.h +++ b/libnd4j/include/execution/Threads.h @@ -27,7 +27,7 @@ #include namespace samediff { - class ThreadsHelper { + class ND4J_EXPORT ThreadsHelper { public: static int numberOfThreads(int maxThreads, uint64_t numberOfElements); static int numberOfThreads2d(int maxThreads, uint64_t iters_x, uint64_t iters_y); @@ -36,7 +36,7 @@ namespace samediff { static int pickLoop3d(int numThreads, uint64_t iters_x, uint64_t iters_y, uint64_t iters_z); }; - class Span { + class ND4J_EXPORT Span { private: int64_t _startX, _stopX, _incX; public: @@ -50,7 +50,7 @@ namespace samediff { static Span build(uint64_t thread_id, uint64_t num_threads, int64_t start_x, int64_t stop_x, int64_t inc_x); }; - class Span2 { + class ND4J_EXPORT Span2 { private: int64_t _startX, _stopX, _incX; int64_t _startY, _stopY, _incY; @@ -70,7 +70,7 @@ namespace samediff { static Span2 build(int loop, uint64_t thread_id, uint64_t num_threads, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y); }; - class Span3 { + class ND4J_EXPORT Span3 { private: int64_t _startX, _stopX, _incX; int64_t _startY, _stopY, _incY; @@ -94,7 +94,7 @@ namespace samediff { static Span3 build(int loop, uint64_t thread_id, uint64_t num_threads, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z); }; - class Threads { + class ND4J_EXPORT Threads { public: /** * This function executes 1 dimensional loop for a given number of threads diff --git a/libnd4j/include/execution/Ticket.h b/libnd4j/include/execution/Ticket.h index e4152b66a..80bf54145 100644 --- a/libnd4j/include/execution/Ticket.h +++ b/libnd4j/include/execution/Ticket.h @@ -29,7 +29,7 @@ #include namespace samediff { - class Ticket { + class ND4J_EXPORT Ticket { private: bool _acquired = false; std::vector*> _queues; diff --git a/libnd4j/include/graph/Variable.h b/libnd4j/include/graph/Variable.h index 2e0053176..60f977e97 100644 --- a/libnd4j/include/graph/Variable.h +++ b/libnd4j/include/graph/Variable.h @@ -64,7 +64,7 @@ namespace nd4j { Variable* clone(); template - Variable* asT(); + ND4J_EXPORT Variable* asT(); bool hasNDArray(); nd4j::NDArray* getNDArray(); diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 795d9b7f0..9d2224d2f 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -311,7 +311,7 @@ 
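Threads.h exports the Span/Span2/Span3 helpers, whose job is to carve a strided loop range into per-thread chunks. A simplified, hypothetical version of the 1-D case, showing the arithmetic such a splitter performs (not libnd4j's actual implementation):

struct MySpan {
    long start, stop, inc;

    // Split the iteration space [start, stop) with stride inc into
    // num_threads nearly equal contiguous chunks; earlier threads absorb
    // the remainder so no iteration is dropped.
    static MySpan build(long thread_id, long num_threads,
                        long start, long stop, long inc) {
        long iters = (stop - start + inc - 1) / inc;   // total iterations
        long per   = iters / num_threads;              // base chunk size
        long rem   = iters % num_threads;              // leftover iterations
        long first = thread_id * per + (thread_id < rem ? thread_id : rem);
        long count = per + (thread_id < rem ? 1 : 0);
        return { start + first * inc, start + (first + count) * inc, inc };
    }
};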
namespace nd4j { node->_dataType = DataTypeUtils::fromT(); return node; } - BUILD_SINGLE_TEMPLATE(template Node* Node::asT, (), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT Node* Node::asT, (), LIBND4J_TYPES); nd4j::graph::Node::Node(nd4j::ops::DeclarableOp *customOp, int id, std::initializer_list input, std::initializer_list output, std::initializer_list dimensions, float scalar, std::initializer_list tArgs, std::initializer_list iArgs) { this->_opType = OpType_CUSTOM; diff --git a/libnd4j/include/graph/impl/Variable.cpp b/libnd4j/include/graph/impl/Variable.cpp index e54112783..d77bded2e 100644 --- a/libnd4j/include/graph/impl/Variable.cpp +++ b/libnd4j/include/graph/impl/Variable.cpp @@ -50,7 +50,7 @@ namespace nd4j { return result; } - BUILD_SINGLE_TEMPLATE(template Variable* Variable::asT, (), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT Variable* Variable::asT, (), LIBND4J_TYPES); nd4j::graph::Variable* nd4j::graph::Variable::clone() { auto result = new Variable(this->isPlaceholder()); diff --git a/libnd4j/include/helpers/AttentionHelper.h b/libnd4j/include/helpers/AttentionHelper.h index a04b26ac8..186f959fd 100644 --- a/libnd4j/include/helpers/AttentionHelper.h +++ b/libnd4j/include/helpers/AttentionHelper.h @@ -24,7 +24,7 @@ #include "NDArray.h" namespace nd4j { - class AttentionHelper { + class ND4J_EXPORT AttentionHelper { public: static nd4j::NDArray multiHeadProject(const nd4j::NDArray* input, const nd4j::NDArray* projectionMatrix, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); diff --git a/libnd4j/include/helpers/BenchmarkHelper.h b/libnd4j/include/helpers/BenchmarkHelper.h index 58ed7e1b7..8dc946a2a 100644 --- a/libnd4j/include/helpers/BenchmarkHelper.h +++ b/libnd4j/include/helpers/BenchmarkHelper.h @@ -44,7 +44,7 @@ namespace nd4j { - class BenchmarkHelper { + class ND4J_EXPORT BenchmarkHelper { private: unsigned int _wIterations; unsigned int _rIterations; diff --git a/libnd4j/include/helpers/BitwiseUtils.h b/libnd4j/include/helpers/BitwiseUtils.h index e8990a34d..6defc4c49 100644 --- a/libnd4j/include/helpers/BitwiseUtils.h +++ b/libnd4j/include/helpers/BitwiseUtils.h @@ -28,7 +28,7 @@ #include namespace nd4j { - class BitwiseUtils { + class ND4J_EXPORT BitwiseUtils { public: diff --git a/libnd4j/include/helpers/CudaLaunchHelper.h b/libnd4j/include/helpers/CudaLaunchHelper.h index c8c22383c..9fec14764 100644 --- a/libnd4j/include/helpers/CudaLaunchHelper.h +++ b/libnd4j/include/helpers/CudaLaunchHelper.h @@ -28,7 +28,7 @@ #include namespace nd4j { - class CudaLaunchHelper { + class ND4J_EXPORT CudaLaunchHelper { public: static Triple getFlatLaunchParams(Nd4jLong length, int SM, int CORES, int SHARED_MEMORY); static int getReductionBlocks(Nd4jLong xLength, int blockSize = 512); diff --git a/libnd4j/include/helpers/DebugHelper.h b/libnd4j/include/helpers/DebugHelper.h index a932ac759..945bebe8e 100644 --- a/libnd4j/include/helpers/DebugHelper.h +++ b/libnd4j/include/helpers/DebugHelper.h @@ -40,7 +40,7 @@ #include namespace nd4j { class NDArray; - class DebugHelper { + class ND4J_EXPORT DebugHelper { public: // cuda-specific debug functions diff --git a/libnd4j/include/helpers/GradCheck.h b/libnd4j/include/helpers/GradCheck.h index cda0b5eae..32f66109a 100644 --- a/libnd4j/include/helpers/GradCheck.h +++ b/libnd4j/include/helpers/GradCheck.h @@ -27,7 +27,7 @@ namespace nd4j { -class GradCheck { +class ND4J_EXPORT GradCheck { public: enum LossFunc {MEAN = 0, SUM = 1}; diff --git a/libnd4j/include/helpers/MmulHelper.h 
b/libnd4j/include/helpers/MmulHelper.h index ff0a7d1b2..76244d050 100644 --- a/libnd4j/include/helpers/MmulHelper.h +++ b/libnd4j/include/helpers/MmulHelper.h @@ -25,7 +25,7 @@ #include "NDArray.h" namespace nd4j { - class MmulHelper { + class ND4J_EXPORT MmulHelper { private: diff --git a/libnd4j/include/helpers/OmpLaunchHelper.h b/libnd4j/include/helpers/OmpLaunchHelper.h index 1001d6163..dac93cbe2 100644 --- a/libnd4j/include/helpers/OmpLaunchHelper.h +++ b/libnd4j/include/helpers/OmpLaunchHelper.h @@ -28,7 +28,7 @@ namespace nd4j { -class OmpLaunchHelper { +class ND4J_EXPORT OmpLaunchHelper { public: diff --git a/libnd4j/include/helpers/PointersManager.h b/libnd4j/include/helpers/PointersManager.h index b0cc931ff..50fdbccf9 100644 --- a/libnd4j/include/helpers/PointersManager.h +++ b/libnd4j/include/helpers/PointersManager.h @@ -30,7 +30,7 @@ namespace nd4j { -class PointersManager { +class ND4J_EXPORT PointersManager { private: diff --git a/libnd4j/include/helpers/RandomLauncher.h b/libnd4j/include/helpers/RandomLauncher.h index 24921dc21..2e477e079 100644 --- a/libnd4j/include/helpers/RandomLauncher.h +++ b/libnd4j/include/helpers/RandomLauncher.h @@ -24,7 +24,7 @@ #include namespace nd4j { - class RandomLauncher { + class ND4J_EXPORT RandomLauncher { public: static void applyDropOut(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z = nullptr); static void applyInvertedDropOut(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z = nullptr); diff --git a/libnd4j/include/helpers/ShapeUtils.h b/libnd4j/include/helpers/ShapeUtils.h index 74719dabb..5f76c11b5 100644 --- a/libnd4j/include/helpers/ShapeUtils.h +++ b/libnd4j/include/helpers/ShapeUtils.h @@ -26,7 +26,7 @@ namespace nd4j { - class ShapeUtils { + class ND4J_EXPORT ShapeUtils { public: diff --git a/libnd4j/include/helpers/SimpleReadWriteLock.h b/libnd4j/include/helpers/SimpleReadWriteLock.h index cb82e7348..b7637f355 100644 --- a/libnd4j/include/helpers/SimpleReadWriteLock.h +++ b/libnd4j/include/helpers/SimpleReadWriteLock.h @@ -23,6 +23,7 @@ #include #include +#include /** * This class provides PRIMITIVE read-write lock, and should NOT be used outside of GraphServer due to its inefficiency. 
@@ -31,7 +32,7 @@ * Basic idea: write lock won't be obtained before all read requests served */ namespace nd4j { - class SimpleReadWriteLock { + class ND4J_EXPORT SimpleReadWriteLock { private: std::atomic _read_locks; std::atomic _write_locks; diff --git a/libnd4j/include/helpers/StringUtils.h b/libnd4j/include/helpers/StringUtils.h index 9891661ad..1a450450f 100644 --- a/libnd4j/include/helpers/StringUtils.h +++ b/libnd4j/include/helpers/StringUtils.h @@ -27,7 +27,7 @@ #include namespace nd4j { - class StringUtils { + class ND4J_EXPORT StringUtils { public: template static FORCEINLINE std::string valueToString(T value) { diff --git a/libnd4j/include/helpers/cublasHelper.h b/libnd4j/include/helpers/cublasHelper.h index 94cd2446b..53d30abf6 100644 --- a/libnd4j/include/helpers/cublasHelper.h +++ b/libnd4j/include/helpers/cublasHelper.h @@ -27,7 +27,7 @@ #include namespace nd4j { - class CublasHelper { + class ND4J_EXPORT CublasHelper { private: static CublasHelper *_INSTANCE; static std::mutex _mutex; diff --git a/libnd4j/include/memory/MemoryRegistrator.h b/libnd4j/include/memory/MemoryRegistrator.h index 8bf5918a6..53e97d35e 100644 --- a/libnd4j/include/memory/MemoryRegistrator.h +++ b/libnd4j/include/memory/MemoryRegistrator.h @@ -24,10 +24,11 @@ #include "Workspace.h" #include #include +#include namespace nd4j { namespace memory { - class MemoryRegistrator { + class ND4J_EXPORT MemoryRegistrator { protected: static MemoryRegistrator* _INSTANCE; Workspace* _workspace; diff --git a/libnd4j/include/memory/MemoryReport.h b/libnd4j/include/memory/MemoryReport.h index 863c439ee..636178d45 100644 --- a/libnd4j/include/memory/MemoryReport.h +++ b/libnd4j/include/memory/MemoryReport.h @@ -22,10 +22,11 @@ #define LIBND4J_MEMORYREPORT_H #include +#include namespace nd4j { namespace memory { - class MemoryReport { + class ND4J_EXPORT MemoryReport { private: Nd4jLong _vm = 0; Nd4jLong _rss = 0; diff --git a/libnd4j/include/memory/MemoryUtils.h b/libnd4j/include/memory/MemoryUtils.h index 985ca466d..5fe27898c 100644 --- a/libnd4j/include/memory/MemoryUtils.h +++ b/libnd4j/include/memory/MemoryUtils.h @@ -22,10 +22,11 @@ #define LIBND4J_MEMORYUTILS_H #include "MemoryReport.h" +#include namespace nd4j { namespace memory { - class MemoryUtils { + class ND4J_EXPORT MemoryUtils { public: static bool retrieveMemoryStatistics(MemoryReport& report); }; diff --git a/libnd4j/include/ops/BroadcastBoolOpsTuple.h b/libnd4j/include/ops/BroadcastBoolOpsTuple.h index 9bffc1198..7b0f96505 100644 --- a/libnd4j/include/ops/BroadcastBoolOpsTuple.h +++ b/libnd4j/include/ops/BroadcastBoolOpsTuple.h @@ -22,9 +22,10 @@ #define DEV_TESTS_BROADCASTBOOLOPSTUPLE_H #include +#include namespace nd4j { - class BroadcastBoolOpsTuple { + class ND4J_EXPORT BroadcastBoolOpsTuple { private: public: diff --git a/libnd4j/include/ops/BroadcastIntOpsTuple.h b/libnd4j/include/ops/BroadcastIntOpsTuple.h index df40907a9..c96244b1a 100644 --- a/libnd4j/include/ops/BroadcastIntOpsTuple.h +++ b/libnd4j/include/ops/BroadcastIntOpsTuple.h @@ -22,9 +22,10 @@ #define DEV_TESTS_BROADCASTINTOPSTUPLE_H #include +#include namespace nd4j { - class BroadcastIntOpsTuple { + class ND4J_EXPORT BroadcastIntOpsTuple { private: public: diff --git a/libnd4j/include/ops/BroadcastOpsTuple.h b/libnd4j/include/ops/BroadcastOpsTuple.h index 0450e50ab..256e37341 100644 --- a/libnd4j/include/ops/BroadcastOpsTuple.h +++ b/libnd4j/include/ops/BroadcastOpsTuple.h @@ -22,9 +22,10 @@ #define DEV_TESTS_BROADCASTOPSTUPLE_H #include +#include namespace nd4j { - class 
BroadcastOpsTuple { + class ND4J_EXPORT BroadcastOpsTuple { private: public: diff --git a/libnd4j/include/ops/declarable/generic/helpers/BroadcastHelper.h b/libnd4j/include/ops/declarable/generic/helpers/BroadcastHelper.h index 997857bf3..5e91641ca 100644 --- a/libnd4j/include/ops/declarable/generic/helpers/BroadcastHelper.h +++ b/libnd4j/include/ops/declarable/generic/helpers/BroadcastHelper.h @@ -108,6 +108,9 @@ namespace nd4j { if (!x->isScalar() && !y->isScalar() && x->isSameShape(y)) { x->applyPairwiseTransform(op.p, y, z, nullptr); + } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { + x->applyTrueBroadcast(op, y, z, true, extraArgs); + return z; } else if (!x->isScalar() && y->isScalar()) { x->applyScalarArr(op.s, const_cast(y), z); } else if (x->isScalar() && !y->isScalar()) { diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp index cc11eedca..d790dd9c2 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp @@ -32,6 +32,10 @@ CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, -2, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); + // just skip op if input is empty + if (input->isEmpty()) + return Status::OK(); + REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST: Scale factor required"); const double factor = block.width() > 1 ? INPUT_VARIABLE(1)->e(0) : T_ARG(0); @@ -70,6 +74,10 @@ DECLARE_TYPES(adjust_contrast) { const double factor = block.width() > 1 ? INPUT_VARIABLE(1)->e(0) : T_ARG(0); + // just skip op if input is empty + if (input->isEmpty()) + return Status::OK(); + REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST_V2: op expects rank of input array to be >= 3, but got %i instead", input->rankOf()); REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST_V2: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1)); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_hue.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_hue.cpp index 16062769a..d1d81acf8 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_hue.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_hue.cpp @@ -35,6 +35,10 @@ CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 1, -2) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); + // just skip op if input is empty + if (input->isEmpty()) + return Status::OK(); + const int rank = input->rankOf(); const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; const double delta = T_ARG(0); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_saturation.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_saturation.cpp index b4472bef5..5030e5952 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_saturation.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_saturation.cpp @@ -33,6 +33,10 @@ CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 1, -2) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); + // just skip op if input is empty + if (input->isEmpty()) + return Status::OK(); + const int rank = input->rankOf(); const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? 
INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; const double factor = T_ARG(0); diff --git a/libnd4j/include/ops/declarable/generic/shape/create.cpp b/libnd4j/include/ops/declarable/generic/shape/create.cpp index e743a5cad..c87f63a56 100644 --- a/libnd4j/include/ops/declarable/generic/shape/create.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/create.cpp @@ -25,6 +25,7 @@ namespace nd4j { namespace ops { + CUSTOM_OP_IMPL(create, 1, 1, false, 0, 1) { auto init = block.numB() > 0 ? B_ARG(0) : true; diff --git a/libnd4j/include/ops/declarable/helpers/activations.h b/libnd4j/include/ops/declarable/helpers/activations.h index 67d80f3c2..331170369 100644 --- a/libnd4j/include/ops/declarable/helpers/activations.h +++ b/libnd4j/include/ops/declarable/helpers/activations.h @@ -27,23 +27,23 @@ namespace nd4j { namespace ops { namespace helpers { - void softMaxForVector(nd4j::LaunchContext * context, const NDArray &input, NDArray &output); + ND4J_EXPORT void softMaxForVector(nd4j::LaunchContext * context, const NDArray &input, NDArray &output); - void logSoftMaxForVector(nd4j::LaunchContext * context, const NDArray &input, NDArray &output); + ND4J_EXPORT void logSoftMaxForVector(nd4j::LaunchContext * context, const NDArray &input, NDArray &output); - void softmax(nd4j::LaunchContext * context, const NDArray &input, NDArray &output, const int dimension); + ND4J_EXPORT void softmax(nd4j::LaunchContext * context, const NDArray &input, NDArray &output, const int dimension); - void logSoftmax(nd4j::LaunchContext * context, const NDArray &input, NDArray &output, const int dimension); + ND4J_EXPORT void logSoftmax(nd4j::LaunchContext * context, const NDArray &input, NDArray &output, const int dimension); - void softmaxDerivative(nd4j::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension); + ND4J_EXPORT void softmaxDerivative(nd4j::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension); - void prelu(nd4j::LaunchContext * context, const NDArray &input, const NDArray &alpha, NDArray &output); + ND4J_EXPORT void prelu(nd4j::LaunchContext * context, const NDArray &input, const NDArray &alpha, NDArray &output); - void preluBP(nd4j::LaunchContext * context, const NDArray &input, const NDArray &alpha, const NDArray &dLdO, NDArray &dLdI, NDArray &dLdA); + ND4J_EXPORT void preluBP(nd4j::LaunchContext * context, const NDArray &input, const NDArray &alpha, const NDArray &dLdO, NDArray &dLdI, NDArray &dLdA); - void thresholdRelu(nd4j::LaunchContext * context, const NDArray &input, double threshold, NDArray &output); + ND4J_EXPORT void thresholdRelu(nd4j::LaunchContext * context, const NDArray &input, double threshold, NDArray &output); - void thresholdReluDerivative(nd4j::LaunchContext * context, NDArray *input, double threshold, NDArray* dLdO, NDArray *output); + ND4J_EXPORT void thresholdReluDerivative(nd4j::LaunchContext * context, NDArray *input, double threshold, NDArray* dLdO, NDArray *output); } } } diff --git a/libnd4j/include/ops/declarable/helpers/col2im.h b/libnd4j/include/ops/declarable/helpers/col2im.h index 793da4798..66d7a684a 100644 --- a/libnd4j/include/ops/declarable/helpers/col2im.h +++ b/libnd4j/include/ops/declarable/helpers/col2im.h @@ -27,7 +27,7 @@ namespace nd4j { namespace ops { namespace helpers { - void col2im(nd4j::LaunchContext & context, const NDArray& input, NDArray& output, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW); + ND4J_EXPORT void 
col2im(nd4j::LaunchContext & context, const NDArray& input, NDArray& output, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW); } diff --git a/libnd4j/include/ops/declarable/helpers/convolutions.h b/libnd4j/include/ops/declarable/helpers/convolutions.h index 65544960a..68b39cfd5 100644 --- a/libnd4j/include/ops/declarable/helpers/convolutions.h +++ b/libnd4j/include/ops/declarable/helpers/convolutions.h @@ -23,6 +23,7 @@ #include #include +#include #include @@ -35,7 +36,7 @@ namespace nd4j { PNORM_POOL = 2, }; - class ConvolutionUtils { + class ND4J_EXPORT ConvolutionUtils { public: static inline void calcOutSizePool2D(int& oH, int& oW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int iH, const int iW, const int paddingMode) { diff --git a/libnd4j/include/ops/declarable/helpers/im2col.h b/libnd4j/include/ops/declarable/helpers/im2col.h index 04559e494..f484c9bc4 100644 --- a/libnd4j/include/ops/declarable/helpers/im2col.h +++ b/libnd4j/include/ops/declarable/helpers/im2col.h @@ -27,7 +27,7 @@ namespace nd4j { namespace ops { namespace helpers { - void im2col(nd4j::LaunchContext & context, const NDArray& im, NDArray& col, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal); + ND4J_EXPORT void im2col(nd4j::LaunchContext & context, const NDArray& im, NDArray& col, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal); } } } diff --git a/libnd4j/include/ops/declarable/helpers/lstmLayer.h b/libnd4j/include/ops/declarable/helpers/lstmLayer.h index 7d94c32e0..d0bc16b66 100644 --- a/libnd4j/include/ops/declarable/helpers/lstmLayer.h +++ b/libnd4j/include/ops/declarable/helpers/lstmLayer.h @@ -29,13 +29,13 @@ namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, +void ND4J_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, const std::vector& params, NDArray* h, NDArray* c); ////////////////////////////////////////////////////////////////////////// -void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, +void ND4J_EXPORT lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp, const std::vector& params, const bool forward, diff --git a/libnd4j/include/ops/declarable/helpers/multiUnique.h b/libnd4j/include/ops/declarable/helpers/multiUnique.h index 587ce44f0..12fa6db10 100644 --- a/libnd4j/include/ops/declarable/helpers/multiUnique.h +++ b/libnd4j/include/ops/declarable/helpers/multiUnique.h @@ -26,7 +26,7 @@ namespace nd4j { namespace ops { namespace helpers { - bool multiUnique(std::vector const& inputList, nd4j::memory::Workspace* workspace = nullptr); + ND4J_EXPORT bool multiUnique(std::vector const& inputList, nd4j::memory::Workspace* workspace = nullptr); } } diff --git a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp index 33a8fa10a..ffa19412a 100644 --- a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp +++ 
b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp @@ -774,8 +774,27 @@ TEST_F(BroadcastableOpsTests, broadcast_bool_2) { ASSERT_TRUE(z.equalsTo(e)); } -TEST_F(BroadcastableOpsTests, broadcast_2) { +TEST_F(BroadcastableOpsTests, broadcast_bool_3) { + auto x = NDArrayFactory::create(0); + auto y = NDArrayFactory::create('c', {3}, {2, 1, 2}); + NDArray z('c', {3}, nd4j::DataType::BOOL); + NDArray e('c', {3}, nd4j::DataType::BOOL); + + e.assign(true); + + nd4j::ops::less op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + // z.printIndexedBuffer("Z"); + + ASSERT_TRUE(z.isSameShape(e)); + ASSERT_TRUE(z.equalsTo(e)); +} + +TEST_F(BroadcastableOpsTests, broadcast_2) { NDArray x('c', {3, 1, 2}, nd4j::DataType::FLOAT32); NDArray y('c', {2, 2}, nd4j::DataType::FLOAT32); NDArray z('c', {3, 2, 2}, nd4j::DataType::FLOAT32); @@ -797,3 +816,19 @@ TEST_F(BroadcastableOpsTests, broadcast_2) { ASSERT_TRUE(z.equalsTo(e)); } +TEST_F(BroadcastableOpsTests, broadcast_3) { + auto x = NDArrayFactory::create(0); + auto y = NDArrayFactory::create('c', {3}, {2, 1, 2}); + NDArray z('c', {3}, nd4j::DataType::INT32); + auto e = NDArrayFactory::create('c', {3}, {2, 1, 2}); + + nd4j::ops::add op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + // z.printIndexedBuffer("Z"); + + ASSERT_TRUE(z.isSameShape(e)); + ASSERT_TRUE(z.equalsTo(e)); +} diff --git a/libnd4j/tests_cpu/layers_tests/CMakeLists.txt b/libnd4j/tests_cpu/layers_tests/CMakeLists.txt index 1d5a1df98..52fa0ca17 100644 --- a/libnd4j/tests_cpu/layers_tests/CMakeLists.txt +++ b/libnd4j/tests_cpu/layers_tests/CMakeLists.txt @@ -29,7 +29,7 @@ if (CUDA_BLAS) if(WIN32) message("CUDA on Windows: enabling /EHsc") - SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /FS /w") + SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /FS") SET_TARGET_PROPERTIES(${LIBND4J_NAME} PROPERTIES COMPILER_FLAGS "/EHsc") endif() diff --git a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu index c8b6fa1d9..593d47bb5 100644 --- a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu +++ b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu @@ -97,6 +97,8 @@ TEST_F(CudaBasicsTests1, TestPairwise_1) { cudaMemcpyAsync(devBufferPtrX, x.buffer(), x.lengthOf() * x.sizeOfT(), cudaMemcpyHostToDevice, *stream); cudaMemcpyAsync(devShapePtrX, x.shapeInfo(), shape::shapeInfoByteLength(x.shapeInfo()), cudaMemcpyHostToDevice, *stream); + res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); LaunchContext lc(stream, nullptr, nullptr); NativeOpExecutioner::execPairwiseTransform(&lc, pairwise::Add, nullptr, x.shapeInfo(), devBufferPtrX, reinterpret_cast(devShapePtrX), nullptr, x.shapeInfo(), devBufferPtrX, reinterpret_cast(devShapePtrX), nullptr, z.shapeInfo(), devBufferPtrZ, reinterpret_cast(devShapePtrX), nullptr); @@ -117,6 +119,7 @@ TEST_F(CudaBasicsTests1, TestPairwise_1) { z.tickWriteHost(); for (int e = 0; e < z.lengthOf(); e++) { + nd4j_printf("step %i\n", e); ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); } } @@ -169,6 +172,8 @@ TEST_F(CudaBasicsTests1, execIndexReduceScalar_1) { void* reductionPointer = nullptr; cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); + cudaResult = cudaMemset(reductionPointer, 0, 1024 * 1024); + ASSERT_EQ(0, cudaResult); LaunchContext lc(&stream, LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getScalarPointer(), 
LaunchContext::defaultContext()->getAllocationPointer()); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp index 2fbd42af7..67cd56d5e 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp @@ -1576,32 +1576,32 @@ TEST_F(DeclarableOpsTests6, LogDet_3) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_1) { - auto x = NDArrayFactory::create('c', {2, 5, 5}, { - 2., 4., 60., 8., 10., - 0., 1., 2., 3., 4., - 0., 0., 2., 4., 6., - 0., 0., 0., 1., 2., - 0., 0., 0., 0., 4., + auto x = NDArrayFactory::create('c', {2, 5, 5}, { + 2.f, 4.f, 60.f, 8.f, 10.f, + 0.f, 1.f, 2.f, 3.f, 4.f, + 0.f, 0.f, 2.f, 4.f, 6.f, + 0.f, 0.f, 0.f, 1.f, 2.f, + 0.f, 0.f, 0.f, 0.f, 4.f, - 1., 0., 0., 0., 0., - 2., 1., 0., 0., 0., - 30., 2., 1., 0., 0., - 4., 3., 2., 1., 0., - 5., 4., 3., 2., 1., + 1.f, 0.f, 0.f, 0.f, 0.f, + 2.f, 1.f, 0.f, 0.f, 0.f, + 30.f, 2.f, 1.f, 0.f, 0.f, + 4.f, 3.f, 2.f, 1.f, 0.f, + 5.f, 4.f, 3.f, 2.f, 1.f }); - auto exp = NDArrayFactory::create('c', {2, 5, 5}, { - 0.5, -2.0, -13.0, 54.0, -6.75, - 0.0, 1.0, -1.0, 1.0, 0.0, - 0, 0, 0.5, -2.0, 0.25, - 0, 0, 0, 1.0, -0.5, - 0, 0, 0, 0, 0.25, + auto exp = NDArrayFactory::create('c', {2, 5, 5}, { + 0.5f, -2.0f, -13.0f, 54.0f, -6.75f, + 0.0f, 1.0f, -1.0f, 1.0f, 0.0f, + 0.f, 0.f, 0.5f, -2.0f, 0.25f, + 0.f, 0.f, 0.f, 1.0f, -0.5f, + 0.f, 0.f, 0.f, 0.f, 0.25f, - 1.0, 0.0, 0.0, 0.0, 0., - -2.0, 1.0, 0., 0., 0., - -26.0, -2.0, 1, 0, 0., - 54.0, 1.0, -2.0, 1, 0., - -27.0, 0.0, 1.0, -2.0, 1. + 1.0f, 0.0f, 0.0f, 0.0f, 0.f, + -2.0f, 1.0f, 0.f, 0.f, 0.f, + -26.0f, -2.0f, 1.f, 0.f, 0.f, + 54.0f, 1.0f, -2.0f, 1.f, 0.f, + -27.0f, 0.0f, 1.0f, -2.0f, 1.f, }); nd4j::ops::matrix_inverse op; @@ -1620,8 +1620,8 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_010) { - auto x = NDArrayFactory::create('c', {1, 5, 5}, {1., 0., 0., 0., 0.,2., 1., 0., 0., 0.,30., 2., 1., 0., 0.,4., 3., 2., 1., 0.,5., 4., 3., 2., 1.,}); - auto exp = NDArrayFactory::create('c', {1, 5, 5}, {1.0, 0.0, 0.0, 0.0, 0.,-2.0, 1.0, 0., 0., 0.,-26.0, -2.0, 1, 0, 0.,54.0, 1.0, -2.0, 1, 0.,-27.0, 0.0, 1.0, -2.0, 1.}); + auto x = NDArrayFactory::create('c', {1, 5, 5}, {1.f, 0.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, 0.f, 30.f, 2.f, 1.f, 0.f, 0.f, 4.f, 3.f, 2.f, 1.f, 0.f, 5.f, 4.f, 3.f, 2.f, 1.f, }); + auto exp = NDArrayFactory::create('c', {1, 5, 5}, {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f}); nd4j::ops::matrix_inverse op; auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); @@ -1639,9 +1639,9 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_010) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_01) { - auto x = NDArrayFactory::create('c', {1, 5, 5}, {2., 4., 60., 8., 10., 0., 1., 2., 3., 4., 0., 0., 2., 4., 6., 0., 0., 0., 1., 2., 0., 0., 0., 0., 4. 
}); + auto x = NDArrayFactory::create('c', {1, 5, 5}, {2.f, 4.f, 60.f, 8.f, 10.f, 0.f, 1.f, 2.f, 3.f, 4.f, 0.f, 0.f, 2.f, 4.f, 6.f, 0.f, 0.f, 0.f, 1.f, 2.f, 0.f, 0.f, 0.f, 0.f, 4.f }); - auto exp = NDArrayFactory::create('c', {1, 5, 5}, {0.5, -2.0, -13.0, 54.0, -6.75, 0.0, 1.0, -1.0, 1.0, 0.0, 0, 0, 0.5, -2.0, 0.25, 0, 0, 0, 1.0, -0.5, 0, 0, 0, 0, 0.25 }); + auto exp = NDArrayFactory::create('c', {1, 5, 5}, {0.5f, -2.0f, -13.0f, 54.0f, -6.75f, 0.0f, 1.0f, -1.0f, 1.0f, 0.0f, 0.f, 0.f, 0.5f, -2.0f, 0.25f, 0.f, 0.f, 0.f, 1.0f, -0.5f, 0.f, 0.f, 0.f, 0.f, 0.25f }); nd4j::ops::matrix_inverse op; auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); @@ -1658,8 +1658,8 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_01) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_02) { - auto x = NDArrayFactory::create('c', {1, 5, 5}, {1., 0., 0., 0., 0., 2., 1., 0., 0., 0., 30., 2., 1., 0., 0., 4., 3., 2., 1., 0., 5., 4., 3., 2., 1. }); - auto exp = NDArrayFactory::create('c', {1, 5, 5}, {1.0, 0.0, 0.0, 0.0, 0., -2.0, 1.0, 0., 0., 0., -26.0, -2.0, 1, 0, 0., 54.0, 1.0, -2.0, 1, 0., -27.0, 0.0, 1.0, -2.0, 1. }); + auto x = NDArrayFactory::create('c', {1, 5, 5}, {1.f, 0.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, 0.f, 30.f, 2.f, 1.f, 0.f, 0.f, 4.f, 3.f, 2.f, 1.f, 0.f, 5.f, 4.f, 3.f, 2.f, 1.f }); + auto exp = NDArrayFactory::create('c', {1, 5, 5}, {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f }); nd4j::ops::matrix_inverse op; auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java index d9057c95a..7b72f4bae 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java @@ -24,6 +24,7 @@ import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RealDivOp; import org.nd4j.linalg.factory.Nd4j; @@ -288,6 +289,30 @@ public class BasicBroadcastTests extends BaseNd4jTest { } } + @Test + public void testLt(){ + INDArray x = Nd4j.scalar(0); + INDArray y = Nd4j.createFromArray(2,1,2); + + INDArray result = Nd4j.create(DataType.BOOL, 3); + INDArray lt = Nd4j.exec(new LessThan(x,y,result))[0]; + + INDArray exp = Nd4j.createFromArray(true, true, true); + assertEquals(exp, lt); + } + + @Test + public void testAdd(){ + INDArray x = Nd4j.scalar(0); + INDArray y = Nd4j.createFromArray(2,1,2); + + INDArray result = Nd4j.create(DataType.INT, 3); + INDArray sum = Nd4j.exec(new AddOp(x,y,result))[0]; + + INDArray exp = Nd4j.createFromArray(2, 1, 2); + assertEquals(exp, sum); + } + @Override public char ordering() { return 'c'; From 1e9ff114aa7d51617478a21a36d66074a74426d1 Mon Sep 17 00:00:00 2001 From: shugeo Date: Mon, 2 Dec 2019 20:40:54 +0200 Subject: [PATCH 22/30] Shugeo atomic tests (#97) * Added atomic tests for atomicAdd, atomicSub and atomicDiv. 
* Fixed atomicAdd for 16bit ints. * Fixed atomicMul for 16bit floats. * Eliminated leftover debug prints. * Fixed problems with double type in matrix inverse helpers. * Removed commented-out incorrect code. * Refactored atomicMul for 16bit types. * Few more minor tweaks Signed-off-by: raver119 * Fixed fake_quant_with_min_max_vars_per_channel args processing. --- ...ke_quant_with_min_max_vars_per_channel.cpp | 15 +- .../ops/declarable/helpers/cuda/lup.cu | 32 +++- libnd4j/include/templatemath.h | 81 ++++++--- libnd4j/tests_cpu/layers_tests/AtomicTests.cu | 157 +++++++++++++++++- .../layers_tests/DeclarableOpsTests10.cpp | 94 +++++++++++ 5 files changed, 338 insertions(+), 41 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp index 5874d2f81..8f379911b 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp @@ -25,13 +25,12 @@ #include namespace nd4j { namespace ops { - CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars_per_channel, 1, 1, true, 0, 0) { + CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars_per_channel, 3, 1, true, 0, 0) { auto x = INPUT_VARIABLE(0); auto min = INPUT_VARIABLE(1); auto max = INPUT_VARIABLE(2); - REQUIRE_TRUE(block.width() == 3 || block.getTArguments()->size() == 2, 0, "fake_quant_with_min_max_vars_per_channel: No minimum/maximum values provided by either input arrays or TArgs"); auto depth = x->sizeAt(-1); REQUIRE_TRUE(min->rankOf() == 1 && max->rankOf() == 1 && min->lengthOf() == max->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Min and Max should be 1D tensors with the same length"); @@ -49,13 +48,13 @@ namespace nd4j { numBits = INT_ARG(0); bool narrowed = false; //INT_ARG(1); - if (block.getIArguments()->size() == 2) { - numBits = INT_ARG(0); - narrowed = INT_ARG(1); - REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars_per_channel: Number of bits" - " for quatization should be in between 2 and 16, but %i " - "was given.", numBits); + if (block.getBArguments() && block.getBArguments()->size()) { + narrowed = B_ARG(0); } + + REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars_per_channel: Number of bits" + " for quantization should be between 2 and 16, but %i "
k = threadIdx.x; k < i; k += blockDim.x) { + for (int k = 0; k < i; k++) { Nd4jLong posZ[] = {i, j}; Nd4jLong posY[] = {k, j}; Nd4jLong posX[] = {i, k}; @@ -144,10 +153,12 @@ namespace helpers { input = reinterpret_cast(inputBuf); } __syncthreads(); + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; - for (int i = (int)n - blockIdx.x - 2; i >= 0; i -= gridDim.x) { + for (int i = (int)n - tid - 2; i >= 0; i -= step) { for (int j = i + 2; j < (int)n; j++) - for (int k = i + threadIdx.x; k < (int)n; k += blockDim.x) { + for (int k = i; k < (int)n; k++) { Nd4jLong posZ[] = {i, j}; Nd4jLong posY[] = {k, j}; Nd4jLong posX[] = {i, k}; @@ -498,8 +509,6 @@ namespace helpers { fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); // else // fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); - -// if (matrix.dataType() == input->dataType()) lup_(context, &matrix, nullptr, nullptr); // else // lup_(context, &matrix, nullptr, nullptr); @@ -627,9 +636,14 @@ namespace helpers { for (auto i = 0LL; i < packX.numberOfTads(); i++) { fillMatrix<<<1, n2, 1024, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n); matrix.tickWriteDevice(); - compound.assign(matrix); - lup_(context, &compound, nullptr, nullptr); - fillLowerUpperKernel<<>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n); + //compound.assign(matrix); +// if (matrix.dataType() == input->dataType()) + lup_(context, &matrix, nullptr, nullptr); + fillLowerUpperKernel<<>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), matrix.specialBuffer(), matrix.specialShapeInfo(), n); + lower.tickWriteDevice(); + upper.tickWriteDevice(); +// lower.printIndexedBuffer("LOWER"); +// upper.printIndexedBuffer("UPPER"); matrix.assign(0); invertUpperMatrix(context, &upper, &matrix); // U^{-1} matrix.tickWriteDevice(); diff --git a/libnd4j/include/templatemath.h b/libnd4j/include/templatemath.h index 7aa4dbbe6..b412befd8 100644 --- a/libnd4j/include/templatemath.h +++ b/libnd4j/include/templatemath.h @@ -1305,15 +1305,65 @@ inline __device__ bfloat16 nd4j_atomicAdd(bfloat16* address, bfloat16 else return old.B.L; } +template +static inline __device__ T internal_16bit_atomicAdd(T* address, T val) { + size_t shift = ((size_t)address & 2); + int *base_address = (int *)((char*)address - shift); + + union I16PAIR { + struct { + T H; + T L; + } B; + int W; + + __host__ __device__ + I16PAIR() {}; + + __host__ __device__ + ~I16PAIR() {}; + }; + + I16PAIR pairNew, pairOld, pairAssumed; + + if (reinterpret_cast(address) == base_address) { + pairOld.B.L = val; + do { + + pairNew.B.L = pairOld.B.L; + pairNew.B.H = pairOld.B.H + val; + pairAssumed.W = pairOld.W; + + pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W); + } while (pairAssumed.W != pairOld.W); + + return (T) pairOld.B.H; + } else { + pairOld.B.H = val; + do { + + pairNew.B.H = pairOld.B.H; + pairNew.B.L = pairOld.B.L + val; + pairAssumed.W = pairOld.W; + pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W); + + } while (pairAssumed.W != pairOld.W); + + return (T) pairOld.B.L; + } + +} + template <> inline __device__ int16_t nd4j_atomicAdd(int16_t* address, int16_t val) { - return 
nd4j_atomicAdd((bfloat16*)address, (bfloat16)val); + return internal_16bit_atomicAdd(address, val); } template <> inline __device__ uint16_t nd4j_atomicAdd(uint16_t* address, uint16_t val) { - return nd4j_atomicAdd((bfloat16*)address, (bfloat16)val); + return internal_16bit_atomicAdd(address, val); } + template <> inline __device__ int8_t nd4j_atomicAdd(int8_t* address, int8_t val) { int res = *address; @@ -1447,7 +1497,7 @@ inline __device__ unsigned char nd4j_atomicMul(unsigned char* add } template -static inline __device__ T internal_16bit_atomicMul(T* address, int16_t val) { +static inline __device__ T internal_16bit_atomicMul(T* address, T val) { size_t shift = ((size_t)address & 2); int *base_address = (int *)((char*)address - shift); @@ -1467,10 +1517,9 @@ static inline __device__ T internal_16bit_atomicMul(T* address, int16_t val) { I16PAIR pairNew, pairOld, pairAssumed; - pairOld.W = (int) val; if (reinterpret_cast(address) == base_address) { + pairOld.B.L = val; do { - pairNew.B.L = pairOld.B.L; pairNew.B.H = pairOld.B.H * val; pairAssumed.W = pairOld.W; @@ -1480,8 +1529,8 @@ static inline __device__ T internal_16bit_atomicMul(T* address, int16_t val) { return (T) pairOld.B.H; } else { + pairOld.B.H = val; do { - pairNew.B.H = pairOld.B.H; pairNew.B.L = pairOld.B.L * val; pairAssumed.W = pairOld.W; @@ -1491,10 +1540,8 @@ static inline __device__ T internal_16bit_atomicMul(T* address, int16_t val) { return (T) pairOld.B.L; } - } - template <> inline __device__ int16_t nd4j_atomicMul(int16_t* address, int16_t val) { return internal_16bit_atomicMul(address, val); @@ -1549,17 +1596,6 @@ inline __device__ uint64_t nd4j_atomicMul(uint64_t* address, uint64_t return (uint64_t)old; } -//template <> -//inline __device__ unsigned long long nd4j_atomicMul(unsigned long long* address, unsigned long long val) { -// unsigned long long int* res_address = address; -// unsigned long long int old = *res_address, assumed; -// do { -// assumed = old; -// old = atomicCAS(res_address, assumed, val * assumed); -// } while (assumed != old); -// return old; -//} - #if !defined(_WIN32) && !defined(_WIN64) template <> inline __device__ Nd4jLong nd4j_atomicMul(Nd4jLong* address, Nd4jLong val) { @@ -1585,22 +1621,21 @@ inline __device__ float16 nd4j_atomicMul(float16* address, float16 val) template <> inline __device__ float nd4j_atomicDiv(float* address, float val) { - return nd4j_atomicMul(address, (float) 1.f / val); + return nd4j_atomicMul(address, 1.f / val); } template <> inline __device__ float16 nd4j_atomicDiv(float16* address, float16 val) { - return nd4j_atomicMul(address, (float16) 1.f / val); + return internal_16bit_atomicMul(address, (float16) 1.f / val); } template <> inline __device__ bfloat16 nd4j_atomicDiv(bfloat16* address, bfloat16 val) { - return nd4j_atomicMul(address, (bfloat16) 1.f / val); + return internal_16bit_atomicMul(address, (bfloat16) 1 / val); } } #endif } - } #ifdef _OPENMP diff --git a/libnd4j/tests_cpu/layers_tests/AtomicTests.cu b/libnd4j/tests_cpu/layers_tests/AtomicTests.cu index 0ede6398c..fdf543026 100644 --- a/libnd4j/tests_cpu/layers_tests/AtomicTests.cu +++ b/libnd4j/tests_cpu/layers_tests/AtomicTests.cu @@ -60,12 +60,93 @@ static void multiplyLauncher(void *vbuffer, uint64_t length, void *vresult) { nd4j::cuda_exception::build("multiply failed", err); } +template +static _CUDA_G void sumKernel(void *vbuffer, uint64_t length, void *vresult) { + auto buffer = reinterpret_cast(vbuffer); + auto result = reinterpret_cast(vresult); + + auto tid = blockIdx.x * blockDim.x + 
threadIdx.x; + + for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { + auto rem = e % 4; + auto i = (e - rem) / 4; + + nd4j::math::atomics::nd4j_atomicAdd(&result[i], buffer[e]); + } +} + +template +static void sumLauncher(void *vbuffer, uint64_t length, void *vresult) { + sumKernel<<<256, 256, 1024, *nd4j::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult); + auto err = cudaStreamSynchronize(*nd4j::LaunchContext::defaultContext()->getCudaStream()); + if (err != 0) + nd4j::cuda_exception::build("sum failed", err); +} + +template +static _CUDA_G void subKernel(void *vbuffer, uint64_t length, void *vresult) { + auto buffer = reinterpret_cast(vbuffer); + auto result = reinterpret_cast(vresult); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { + auto rem = e % 4; + auto i = (e - rem) / 4; + + nd4j::math::atomics::nd4j_atomicSub(&result[i], buffer[e]); + } +} + +template +static void subLauncher(void *vbuffer, uint64_t length, void *vresult) { + subKernel<<<256, 256, 1024, *nd4j::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult); + auto err = cudaStreamSynchronize(*nd4j::LaunchContext::defaultContext()->getCudaStream()); + if (err != 0) + nd4j::cuda_exception::build("sub failed", err); +} + +template +static _CUDA_G void divKernel(void *vbuffer, uint64_t length, void *vresult) { + auto buffer = reinterpret_cast(vbuffer); + auto result = reinterpret_cast(vresult); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { + auto rem = e % 4; + auto i = (e - rem) / 4; + + nd4j::math::atomics::nd4j_atomicDiv(&result[i], buffer[e]); + } +} + +template +static void divLauncher(void *vbuffer, uint64_t length, void *vresult) { + divKernel<<<256, 256, 1024, *nd4j::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult); + auto err = cudaStreamSynchronize(*nd4j::LaunchContext::defaultContext()->getCudaStream()); + if (err != 0) + nd4j::cuda_exception::build("div failed", err); +} + static void multiplyHost(NDArray &input, NDArray &output) { BUILD_SINGLE_SELECTOR(input.dataType(), multiplyLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), NUMERIC_TYPES); } +static void sumHost(NDArray &input, NDArray &output) { + BUILD_SINGLE_SELECTOR(input.dataType(), sumLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), NUMERIC_TYPES); +} + +static void subHost(NDArray &input, NDArray &output) { + BUILD_SINGLE_SELECTOR(input.dataType(), subLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), FLOAT_TYPES); +} + +static void divHost(NDArray &input, NDArray &output) { + BUILD_SINGLE_SELECTOR(input.dataType(), divLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), FLOAT_TYPES); +} + TEST_F(AtomicTests, test_multiply) { - std::vector dtypes = {nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::INT16}; + std::vector dtypes = {nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::INT16, nd4j::DataType::HALF}; for (auto t:dtypes) { nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); @@ -80,7 +161,81 @@ TEST_F(AtomicTests, test_multiply) { multiplyHost(input, output); ASSERT_EQ(exp, output); } +} +TEST_F(AtomicTests, test_multiply_2) { + std::vector dtypes = {nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::HALF, nd4j::DataType::BFLOAT16}; + for 
(auto t:dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); + input.assign(1.5); + output.assign(2); + exp.assign(10.125); + + multiplyHost(input, output); +// output.printBuffer("multiply 2"); + ASSERT_EQ(exp, output); + } +} + +TEST_F(AtomicTests, test_sum) { + std::vector dtypes = {nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::BFLOAT16, nd4j::DataType::HALF, nd4j::DataType::INT16}; + + for (auto t:dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); + + input.assign(1); + output.assign(1); + exp.assign(5); + + sumHost(input, output); +// output.printIndexedBuffer("Sum"); + ASSERT_EQ(exp, output); + } +} + +TEST_F(AtomicTests, test_sub) { + std::vector dtypes = {nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::HALF}; + + for (auto t:dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); + + input.assign(1); + output.assign(5); + exp.assign(1); + + subHost(input, output); +// output.printBuffer("Sub"); + + ASSERT_EQ(exp, output); + } +} + +TEST_F(AtomicTests, test_div) { + std::vector dtypes = {nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::BFLOAT16, nd4j::DataType::HALF}; + + for (auto t:dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); + + input.assign(2); + output.assign(32); + exp.assign(2); + + divHost(input, output); +// output.printBuffer("Div"); + ASSERT_EQ(exp, output); + } } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 7bea1e820..6375d935c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -2785,6 +2785,100 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_3) { delete results; } +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03) { + NDArray x = NDArrayFactory::create('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, + 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, + 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + NDArray exp = NDArrayFactory::create('c', {3,5}, { + 0.777002f, 0.596913f, 0.72314f, 0.231040f, 0.509824f, + 0.179308f, 0.505282f, 0.86846f, 0.349958f, 0.509824f, + 0.087355f, 0.596913f, 0.65740f, 0.349958f, 0.159745f}); + NDArray min = NDArrayFactory::create({-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + nd4j::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.execute({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto result = results->at(0); +// result->printIndexedBuffer("Quantized03"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); + + delete results; +} +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_1) { + NDArray x = NDArrayFactory::create('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, + 
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, + 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + NDArray exp = NDArrayFactory::create('c', {3,5}, { + 0.780061f, 0.596635f, 0.725987f, 0.231950f, 0.508419f, + 0.180014f, 0.504643f, 0.868406f, 0.351335f, 0.508419f, + 0.087699f, 0.596635f, 0.659988f, 0.351335f, 0.160374f}); + NDArray min = NDArrayFactory::create({-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + nd4j::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.execute({&x, &min, &max}, {}, {8}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto result = results->at(0); +// result->printIndexedBuffer("Quantized03_1"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); + + delete results; +} + +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_2) { + NDArray x = NDArrayFactory::create('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, + 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, + 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + NDArray exp = NDArrayFactory::create('c', {3,5}, { + 0.775297f, 0.592226f, 0.725763f, 0.237561f, 0.503245f, + 0.189097f, 0.506084f, 0.868069f, 0.349355f, 0.503245f, + 0.094548f, 0.592226f, 0.654610f, 0.349355f, 0.153769f}); + NDArray min = NDArrayFactory::create({-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + nd4j::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.execute({&x, &min, &max}, {}, {6}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto result = results->at(0); + result->printIndexedBuffer("Quantized03_2"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); + + delete results; +} + +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_3) { + NDArray x = NDArrayFactory::create('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, + 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, + 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + NDArray exp = NDArrayFactory::create('c', {3,5}, { + 0.781600f, 0.593422f, 0.728248f, 0.233790f, 0.509014f, 0.186095f, 0.508648f, 0.868295f, 0.343809f, + 0.509014f, 0.093048f, 0.593422f, 0.658224f, 0.343809f, 0.165086f}); + NDArray min = NDArrayFactory::create({-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + nd4j::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.execute({&x, &min, &max}, {}, {6}, {false}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto result = results->at(0); + result->printIndexedBuffer("Quantized03_3"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); + + delete results; +} + //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) { From 1f5e15b541ab6e16c9807d7f08674a8af97cb56f Mon Sep 17 00:00:00 2001 From: Yurii Shyrma Date: Tue, 3 Dec 2019 08:40:45 +0200 Subject: [PATCH 23/30] Shyrma adjust (#98) * - add possibility of passing scalar-array as input parameter for scale factor in adjust hue/contrast/saturation ops - correct typo in function which calculates regularized incomplete beta integral Signed-off-by: Yurii * - fix bug in betainc cuda kernel Signed-off-by: Yurii * - start working on implementation of digamma function 
Signed-off-by: Yurii * - further work on digamma function (cpu) Signed-off-by: Yurii * - testing and fixing bugs in digamma op Signed-off-by: Yurii * - make correction in cuda kernel for polyGamma Signed-off-by: Yurii * - remove unnecessary stuff from betaInc cuda kernel Signed-off-by: Yurii * - resolve conflicts in DeclarableOpsTests3.cpp after master branch has been merged Signed-off-by: Yurii * - restore id number of Not operation in legacy_ops.h Signed-off-by: Yurii * - correct padding calculation in mkl dnn conv1d causal Signed-off-by: Yurii * restore empty check in adjust_contrast_v2 Signed-off-by: raver119 --- libnd4j/include/loops/legacy_ops.h | 6 +- .../generic/parity_ops/adjust_contrast.cpp | 102 +++--- .../generic/parity_ops/adjust_hue.cpp | 22 +- .../generic/parity_ops/adjust_saturation.cpp | 22 +- .../parity_ops/digamma.cpp} | 35 +- .../generic/parity_ops/polygamma.cpp | 14 +- .../ops/declarable/headers/parity_ops.h | 34 +- .../ops/declarable/helpers/cpu/betaInc.cpp | 6 +- .../ops/declarable/helpers/cpu/diGamma.cpp | 53 +++ .../ops/declarable/helpers/cpu/polyGamma.cpp | 31 +- .../ops/declarable/helpers/cuda/betaInc.cu | 55 +-- .../ops/declarable/helpers/cuda/diGamma.cu | 78 +++++ .../ops/declarable/helpers/cuda/polyGamma.cu | 35 +- .../ops/declarable/helpers/gammaMathFunc.h | 100 ++++++ .../ops/declarable/platform/mkldnn/conv2d.cpp | 54 ++- libnd4j/include/ops/ops.h | 30 +- .../layers_tests/ConvolutionTests1.cpp | 314 ++++++++++-------- .../layers_tests/DeclarableOpsTests13.cpp | 6 +- .../layers_tests/DeclarableOpsTests15.cpp | 17 +- .../layers_tests/DeclarableOpsTests3.cpp | 62 +++- 20 files changed, 750 insertions(+), 326 deletions(-) rename libnd4j/include/ops/declarable/{helpers/polyGamma.h => generic/parity_ops/digamma.cpp} (57%) create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/diGamma.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/diGamma.cu create mode 100644 libnd4j/include/ops/declarable/helpers/gammaMathFunc.h diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h index 5108ba4a7..7de54a858 100644 --- a/libnd4j/include/loops/legacy_ops.h +++ b/libnd4j/include/loops/legacy_ops.h @@ -112,7 +112,8 @@ (4, IsInfOrNan), \ (5, MatchConditionBool), \ (6, IsPositive) , \ - (7, Not) + (7, Not), \ + (8, IsNegative) #define TRANSFORM_STRICT_OPS \ @@ -279,7 +280,8 @@ (3, IsInfOrNan), \ (4, IsNan), \ (5, IsInf), \ - (6, IsPositive) + (6, IsPositive), \ + (7, IsNegative) #define REDUCE_SAME_OPS \ (0, Sum), \ diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp index d790dd9c2..1aa0c5249 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp @@ -27,7 +27,8 @@ namespace nd4j { namespace ops { -CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, -2, 0) { +//////////////////////////////////////////////////////////////////// +CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, 0, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); @@ -37,23 +38,31 @@ CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, -2, 0) { return Status::OK(); REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST: Scale factor required"); - - const double factor = block.width() > 1 ?
INPUT_VARIABLE(1)->e(0) : T_ARG(0); - REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST: op expects rank of input array to be >= 3, but got %i instead", input->rankOf()); REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1)); - // compute mean before + + NDArray* factor = nullptr; + + if(block.width() > 1) + factor = INPUT_VARIABLE(1); + else { + factor = new NDArray(output->dataType(), block.launchContext()); + factor->p(0, T_ARG(0)); + } + // fill up axes vector first std::vector axes(input->rankOf() - 1); for (auto i = 0; i < axes.size(); ++i) axes[i] = i; + // mean as reduction for last dimension set auto mean = input->reduceAlongDims(reduce::Mean, axes); - NDArray factorT(output->dataType(), block.launchContext()); // = NDArrayFactory::create(factor, block.launchContext()); - factorT.p(0, factor); // this is contrast calculation - output->assign((*input - mean) * factorT + mean); + output->assign((*input - mean) * (*factor) + mean); + + if(block.width() == 1) + delete factor; return Status::OK(); } @@ -64,45 +73,54 @@ DECLARE_TYPES(adjust_contrast) { ->setSameMode(true); } +//////////////////////////////////////////////////////////////////// +CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, 0, 0) { - CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, -2, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST_V2: Scale factor required"); - - const double factor = block.width() > 1 ? INPUT_VARIABLE(1)->e(0) : T_ARG(0); - - // just skip op if input is empty - if (input->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST_V2: op expects rank of input array to be >= 3, but got %i instead", input->rankOf()); - REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST_V2: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1)); - - // compute mean before - std::vector axes(input->rankOf() - 1); - for (auto i = 0; i < axes.size(); ++i) - axes[i] = i; - - // mean as reduction for last dimension set - auto mean = input->reduceAlongDims(reduce::Mean, axes); - - // result as (x - mean) * factor + mean - auto temp = input->ulike(); - input->applyTrueBroadcast(BroadcastOpsTuple::Subtract(), &mean, &temp); - temp.applyScalar(scalar::Multiply, factor); - temp.applyTrueBroadcast(BroadcastOpsTuple::Add(), &mean, output); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + // just skip op if input is empty + if (input->isEmpty()) return Status::OK(); + + REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST_V2: op expects rank of input array to be >= 3, but got %i instead", input->rankOf()); + REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST_V2: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1)); + REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST_V2: Scale factor required"); + + NDArray* factor = nullptr; + + if(block.width() > 1) + factor = INPUT_VARIABLE(1); + else { + factor = new NDArray(output->dataType(), block.launchContext()); + factor->p(0, T_ARG(0)); } - DECLARE_TYPES(adjust_contrast_v2) { - getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(true); - } + // compute mean before + std::vector axes(input->rankOf() - 1); + for (auto i = 0; i < axes.size(); ++i) + axes[i] 
= i; + + // mean as reduction for last dimension set + auto mean = input->reduceAlongDims(reduce::Mean, axes); + + // result as (x - mean) * factor + mean + auto temp = input->ulike(); + input->applyTrueBroadcast(BroadcastOpsTuple::Subtract(), &mean, &temp); + temp.applyScalarArr(scalar::Multiply, factor); + temp.applyTrueBroadcast(BroadcastOpsTuple::Add(), &mean, output); + + if(block.width() == 1) + delete factor; + + return Status::OK(); +} + +DECLARE_TYPES(adjust_contrast_v2) { + getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(true); +} } } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_hue.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_hue.cpp index d1d81acf8..32e51bdb9 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_hue.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_hue.cpp @@ -24,13 +24,12 @@ #include #include -#include namespace nd4j { namespace ops { -CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 1, -2) { +CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 0, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); @@ -41,15 +40,26 @@ CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 1, -2) { const int rank = input->rankOf(); const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; - const double delta = T_ARG(0); + REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_HUE: delta factor is required !"); REQUIRE_TRUE(rank >= 3, 0, "ADJUST_HUE: op expects rank of input array to be >= 3, but got %i instead", rank); REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_HUE: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); - REQUIRE_TRUE(-1. <= delta && delta <= 1., 0, "ADJUST_HUE: parameter delta must be within [-1, 1] interval, but got %f instead", delta); - NDArray deltaScalarArr = NDArrayFactory::create(delta, block.launchContext()); + NDArray* delta = nullptr; - helpers::adjustHue(block.launchContext(), input, &deltaScalarArr, output, dimC); + if(block.width() > 1) + delta = INPUT_VARIABLE(1); + else { + delta = new NDArray(output->dataType(), block.launchContext()); + delta->p(0, T_ARG(0)); + } + + REQUIRE_TRUE(-1. <= delta->e(0) && delta->e(0) <= 1., 0, "ADJUST_HUE: parameter delta must be within [-1, 1] interval, but got %f instead", delta); + + helpers::adjustHue(block.launchContext(), input, delta, output, dimC); + + if(block.width() == 1) + delete delta; return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_saturation.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_saturation.cpp index 5030e5952..de947c9ae 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_saturation.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_saturation.cpp @@ -28,7 +28,7 @@ namespace nd4j { namespace ops { -CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 1, -2) { +CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 0, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); @@ -37,16 +37,26 @@ CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 1, -2) { if (input->isEmpty()) return Status::OK(); - const int rank = input->rankOf(); - const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? 
INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; - const double factor = T_ARG(0); + const int rank = input->rankOf(); + const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; REQUIRE_TRUE(rank >= 3, 0, "ADJUST_SATURATION: op expects rank of input array to be >= 3, but got %i instead", rank); REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_SATURATION: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); + REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_SATURATION: scale factor is required !"); - NDArray factorScalarArr = NDArrayFactory::create(factor, block.launchContext()); + NDArray* factor = nullptr; - helpers::adjustSaturation(block.launchContext(), input, &factorScalarArr, output, dimC); + if(block.width() > 1) + factor = INPUT_VARIABLE(1); + else { + factor = new NDArray(output->dataType(), block.launchContext()); + factor->p(0, T_ARG(0)); + } + + helpers::adjustSaturation(block.launchContext(), input, factor, output, dimC); + + if(block.width() == 1) + delete factor; return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/helpers/polyGamma.h b/libnd4j/include/ops/declarable/generic/parity_ops/digamma.cpp similarity index 57% rename from libnd4j/include/ops/declarable/helpers/polyGamma.h rename to libnd4j/include/ops/declarable/generic/parity_ops/digamma.cpp index 5681d68b7..8a2894be7 100644 --- a/libnd4j/include/ops/declarable/helpers/polyGamma.h +++ b/libnd4j/include/ops/declarable/generic/parity_ops/digamma.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -15,27 +16,35 @@ ******************************************************************************/ // -// Created by Yurii Shyrma on 13.12.2017. 
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //

-#ifndef LIBND4J_POLYGAMMA_H
-#define LIBND4J_POLYGAMMA_H
+#include <op_boilerplate.h>
+#if NOT_EXCLUDED(OP_digamma)

-#include
-#include "NDArray.h"
+#include <ops/declarable/CustomOperations.h>
+#include <ops/declarable/helpers/gammaMathFunc.h>

 namespace nd4j {
-namespace ops {
-namespace helpers {
+namespace ops {

+CONFIGURABLE_OP_IMPL(digamma, 1, 1, false, 0, 0) {

-    // calculate the polygamma function
-    void polyGamma(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output);
-
+    auto x = INPUT_VARIABLE(0);
+    auto z = OUTPUT_VARIABLE(0);
+
+    helpers::diGamma(block.launchContext(), *x, *z);
+
+    return Status::OK();
+}
+
+DECLARE_TYPES(digamma) {
+    getOpDescriptor()
+        ->setAllowedInputTypes({ALL_FLOATS, ALL_INTS})
+        ->setSameMode(true);
+}
 }
 }
-}
-
-#endif //LIBND4J_POLYGAMMA_H
+#endif
\ No newline at end of file
diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/polygamma.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/polygamma.cpp
index 0f850cd4b..1cfd86a26 100644
--- a/libnd4j/include/ops/declarable/generic/parity_ops/polygamma.cpp
+++ b/libnd4j/include/ops/declarable/generic/parity_ops/polygamma.cpp
@@ -15,14 +15,14 @@
 ******************************************************************************/

 //
-// @author Yurii Shyrma (iuriish@yahoo.com), created on 13.12.2017
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //

 #include <op_boilerplate.h>
 #if NOT_EXCLUDED(OP_polygamma)

 #include <ops/declarable/CustomOperations.h>
-#include <ops/declarable/helpers/polyGamma.h>
+#include <ops/declarable/helpers/gammaMathFunc.h>

 namespace nd4j {
 namespace ops {
@@ -37,11 +37,11 @@ CONFIGURABLE_OP_IMPL(polygamma, 2, 1, false, 0, 0) {

     Nd4jLong arrLen = n->lengthOf();
     // FIXME: this should be a single op call, not a loop!
-    auto nPositive = n->reduceNumber(nd4j::reduce::IsPositive, nullptr);
-    auto xPositive = x->reduceNumber(nd4j::reduce::IsPositive, nullptr);
-    bool nPositiveFlag = nPositive.e<bool>(0);
-    bool xPositiveFlag = xPositive.e<bool>(0);
-    REQUIRE_TRUE(nPositiveFlag, 0, "POLYGAMMA op: all elements of n array must be > 0 !");
+    auto nNegative = n->reduceNumber(nd4j::reduce::IsNegative, nullptr);
+    auto xPositive = x->reduceNumber(nd4j::reduce::IsPositive, nullptr);
+    bool nPositiveFlag = !nNegative.e<bool>(0);    // require all n >= 0
+    bool xPositiveFlag = xPositive.e<bool>(0);     // require all x > 0
+    REQUIRE_TRUE(nPositiveFlag, 0, "POLYGAMMA op: all elements of n array must be >= 0 !");
     REQUIRE_TRUE(xPositiveFlag, 0, "POLYGAMMA op: all elements of x array must be > 0 !");

     helpers::polyGamma(block.launchContext(), *n, *x, *output);
diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h
index 0d354630c..e56ba9d6e 100644
--- a/libnd4j/include/ops/declarable/headers/parity_ops.h
+++ b/libnd4j/include/ops/declarable/headers/parity_ops.h
@@ -513,7 +513,6 @@ namespace nd4j {
    /**
     * This op calculates polygamma function psi^(n)(x). Implementation is based on serial representation written in
     * terms of the Hurwitz zeta function: polygamma = (-1)^{n+1} * n! * zeta(n+1, x).
-    * Currently the case n = 0 is not supported.
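The doc line removed above is the substance of this change: n = 0 is no longer rejected, and polygamma dispatches to the new digamma helper via the identity psi^(0)(x) = psi(x) = d/dx ln(Gamma(x)). A quick standalone cross-check of the closed forms the helper relies on (plain <cmath>, with a central difference of std::lgamma standing in for digamma; this is an illustration, not libnd4j code):

#include <cmath>
#include <cstdio>

// digamma(x) approximated as the numerical derivative of ln Gamma(x)
static double digammaFD(double x) {
    const double h = 1e-6;
    return (std::lgamma(x + h) - std::lgamma(x - h)) / (2.0 * h);
}

int main() {
    // Closed forms used by diGammaScalar below:
    //   psi(1)   = -EulerGamma           ~ -0.577215665
    //   psi(0.5) = -EulerGamma - 2*ln(2) ~ -1.963510026
    std::printf("psi(1)   ~ %.9f\n", digammaFD(1.0));
    std::printf("psi(0.5) ~ %.9f\n", digammaFD(0.5));
    return 0;
}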
* * Input arrays: * 0: n - define derivative order (n+1), type integer (however currently is implemented as float casted to integer) @@ -528,6 +527,20 @@ namespace nd4j { DECLARE_CONFIGURABLE_OP(polygamma, 2, 1, false, 0, 0); #endif + /** + * This op calculates digamma function psi(x) = derivative of log(Gamma(x)) + * + * Input arrays: + * 0: x - abscissa points where to evaluate the digamma function, type float + * + * Output array: + * 0: values of digamma function at corresponding x, type float + * + */ + #if NOT_EXCLUDED(OP_digamma) + DECLARE_CONFIGURABLE_OP(digamma, 1, 1, false, 0, 0); + #endif + /** * This operation takes shape as first argument, and returns new NDArray filled with specific scalar value. * Input arrays: @@ -575,44 +588,47 @@ namespace nd4j { * This operation adjusts image hue by delta * Input arrays: * 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels. + * 1 - optional argument, input scalar-array containing delta * * T arguments: - * 0 - delta value + * 0 - optional argument, delta value * * Int arguments: * 0 - optional argument, corresponds to dimension with 3 channels */ #if NOT_EXCLUDED(OP_adjust_hue) - DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, 1, -2); + DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, 0, 0); #endif /** * This operation adjusts image saturation by delta * Input arrays: * 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels. + * 1 - optional argument, input scalar-array containing saturation factor * * T arguments: - * 0 - saturation factor + * 0 - optional argument, saturation factor * * Int arguments: * 0 - optional argument, corresponds to dimension with 3 channels */ #if NOT_EXCLUDED(OP_adjust_saturation) - DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, 1, -2); + DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, 0, 0); #endif /** * This operation adjusts image contrast by given factor ( z = (x - mean) * factor + mean ) * Input arrays: * 0 - input array with rank >= 3, must have last one dimension equal 3, that is dimension containing channels. 
+    * 1 - optional argument, input scalar-array containing contrast factor
+    *
    * T arguments:
-    * 0 - contrast factor
+    * 0 - optional argument, contrast factor
    *
    */
   #if NOT_EXCLUDED(OP_adjust_contrast)
-       DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, -2, 0);
-       DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, -2, 0);
+       DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, 0, 0);
+       DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, 0, 0);
   #endif
@@ -1832,7 +1848,7 @@ namespace nd4j {
   #endif

   /**
-   * compare_and_bitpack - compare with greater and pack result with uint8
+   * compare_and_bitpack - compare with greater and pack result with uint8
    *
    * input params:
    *    0 - NDArray (input)
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp b/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp
index ddd1ad892..88186b62a 100644
--- a/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp
+++ b/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp
@@ -107,12 +107,12 @@ static T betaIncCore(T a, T b, T x) {
        return x;

    const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b);
-   const T front = math::nd4j_exp<T, T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1 - x) * b - gammaPart) / a;
+   const T front = math::nd4j_exp<T, T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1.f - x) * b - gammaPart);

    if (x <= (a + static_cast<T>(1)) / (a + b + static_cast<T>(2)))
-       return front * continuedFraction(a, b, x);
+       return front * continuedFraction(a, b, x) / a;
    else // symmetry relation
-       return static_cast<T>(1) - front * continuedFraction(b, a, static_cast<T>(1) - x);
+       return static_cast<T>(1) - front * continuedFraction(b, a, static_cast<T>(1) - x) / b;
 }
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/diGamma.cpp b/libnd4j/include/ops/declarable/helpers/cpu/diGamma.cpp
new file mode 100644
index 000000000..8035f8216
--- /dev/null
+++ b/libnd4j/include/ops/declarable/helpers/cpu/diGamma.cpp
@@ -0,0 +1,53 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include + +namespace nd4j { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////// +// calculate digamma function for array elements +template +static void diGamma_(const NDArray& x, NDArray& z) { + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i += increment) + z.p(i, diGammaScalar(x.e(i))); + }; + samediff::Threads::parallel_for(func, 0, x.lengthOf()); +} + +void diGamma(nd4j::LaunchContext* context, const NDArray& x, NDArray& z) { + + BUILD_SINGLE_SELECTOR(x.dataType(), diGamma_, (x, z), FLOAT_TYPES); +} + +BUILD_SINGLE_TEMPLATE(template void diGamma_, (const NDArray& x, NDArray& z), FLOAT_TYPES); + + + +} +} +} + diff --git a/libnd4j/include/ops/declarable/helpers/cpu/polyGamma.cpp b/libnd4j/include/ops/declarable/helpers/cpu/polyGamma.cpp index cb97ffe1e..fc572677e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/polyGamma.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/polyGamma.cpp @@ -18,7 +18,7 @@ // Created by Yurii Shyrma on 12.12.2017 // -#include +#include #include #include #include @@ -42,7 +42,7 @@ static FORCEINLINE T getFactorial(const int n) { for(int i = 2; i <= n; ++i) result *= i; - + return result; } @@ -50,17 +50,15 @@ static FORCEINLINE T getFactorial(const int n) { // implementation is based on serial representation written in terms of the Hurwitz zeta function as polygamma = (-1)^{n+1} * n! * zeta(n+1, x) template static FORCEINLINE T polyGammaScalar(nd4j::LaunchContext * context, const int n, const T x) { - - // if (n < 0) + + // if (n < 0) // throw("polyGamma function: n must be >= 0 !"); - // if (x <= (T)0.) + // if (x <= (T)0.) // throw("polyGamma function: x must be > 0 !"); - - // TODO case for n = 0 (digamma) int sign = (n + 1) % 2 ? 
-1 : 1; - // T factorial = (T)std::tgamma(n + 1); + // T factorial = (T)std::tgamma(n + 1); return sign * getFactorial(n) * zetaScalar((T)(n + 1), x); } @@ -71,17 +69,18 @@ static FORCEINLINE T polyGammaScalar(nd4j::LaunchContext * context, const int n, template static void polyGamma_(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output) { - NDArray& result = output; - - int xLen = x.lengthOf(); - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i += increment) - result.p(i, polyGammaScalar(context, n.e(i), x.e(i))); + for (auto i = start; i < stop; i += increment) { + const T order = n.e(i); + if(order != static_cast(order)) // if order has fractional part then do not perform calculations and return NAN + output.p(i, std::numeric_limits::quiet_NaN()); + else if (order == 0) // polygamma function of zero order is digamma function + output.p(i, diGammaScalar(x.e(i))); + else + output.p(i, polyGammaScalar(context, order, x.e(i))); + } }; samediff::Threads::parallel_for(func, 0, x.lengthOf()); - -// return result; } void polyGamma(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu b/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu index 90619c76c..e7541a005 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu @@ -89,20 +89,6 @@ __device__ T continuedFractionCuda(const T a, const T b, const T x) { return 1.f / 0.f; // no convergence, more iterations is required } -/////////////////////////////////////////////////////////////////// -// evaluates incomplete beta function for positive a and b, and x between 0 and 1. -template -__device__ T betaIncCoreCuda(T a, T b, T x) { - - const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b); - const T front = math::nd4j_exp(math::nd4j_log(x) * a + math::nd4j_log(1 - x) * b - gammaPart) / a; - - if (x <= (a + static_cast(1)) / (a + b + static_cast(2))) - return front * continuedFractionCuda(a, b, x); - else // symmetry relation - return static_cast(1) - front * continuedFractionCuda(b, a, static_cast(1) - x); -} - /////////////////////////////////////////////////////////////////// template __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo, @@ -115,12 +101,21 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo, const Nd4jLong j = blockIdx.x; // one block per each element - Nd4jLong len = shape::length(xShapeInfo); + T& z = *(reinterpret_cast(vz) + shape::getIndexOffset(j, zShapeInfo)); - const T a = *(reinterpret_cast(va) + shape::getIndexOffset(j, aShapeInfo)); - const T b = *(reinterpret_cast(vb) + shape::getIndexOffset(j, bShapeInfo)); - const T x = *(reinterpret_cast(vx) + shape::getIndexOffset(j, xShapeInfo)); - T& z = *(reinterpret_cast(vz) + shape::getIndexOffset(j, zShapeInfo)); + __shared__ T a, b, x; + __shared__ bool symmCond; + + if (threadIdx.x == 0) { + + a = *(reinterpret_cast(va) + shape::getIndexOffset(j, aShapeInfo)); + b = *(reinterpret_cast(vb) + shape::getIndexOffset(j, bShapeInfo)); + x = *(reinterpret_cast(vx) + shape::getIndexOffset(j, xShapeInfo)); + + symmCond = x <= (a + static_cast(1)) / (a + b + static_cast(2)); + + } + __syncthreads(); // t^{n-1} * (1 - t)^{n-1} is symmetric function with respect to x = 0.5 if(a == b && x == static_cast(0.5)) { @@ -135,17 +130,31 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* 
aShapeInfo, if(threadIdx.x % 2 == 0) { /***** even part *****/ const int m = threadIdx.x + 1; - sharedMem[threadIdx.x] = m * (b - m) * x / ((a + 2 * m - static_cast(1)) * (a + 2 * m)); + if(symmCond) + sharedMem[threadIdx.x] = m * (b - m) * x / ((a + 2 * m - static_cast(1)) * (a + 2 * m)); + else + sharedMem[threadIdx.x] = m * (a - m) * (1.f-x) / ((b + 2 * m - static_cast(1)) * (b + 2 * m)); } else { /***** odd part *****/ const int m = threadIdx.x; - sharedMem[threadIdx.x] = -(a + m) * (a + b + m) * x / ((a + 2 * m + static_cast(1)) * (a + 2 * m)); + if(symmCond) + sharedMem[threadIdx.x] = -(a + m) * (a + b + m) * x / ((a + 2 * m + static_cast(1)) * (a + 2 * m)); + else + sharedMem[threadIdx.x] = -(b + m) * (a + b + m) * (1.f-x) / ((b + 2 * m + static_cast(1)) * (b + 2 * m)); } __syncthreads(); - if(threadIdx.x == 0) - z = betaIncCoreCuda(a, b, x); + if(threadIdx.x == 0) { + + const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b); + const T front = math::nd4j_exp(math::nd4j_log(x) * a + math::nd4j_log(1.f - x) * b - gammaPart); + + if (symmCond) + z = front * continuedFractionCuda(a, b, x) / a; + else // symmetry relation + z = static_cast(1) - front * continuedFractionCuda(b, a, static_cast(1) - x) / b; + } } /////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/ops/declarable/helpers/cuda/diGamma.cu b/libnd4j/include/ops/declarable/helpers/cuda/diGamma.cu new file mode 100644 index 000000000..3edd59ecc --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/diGamma.cu @@ -0,0 +1,78 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include + +namespace nd4j { +namespace ops { +namespace helpers { + +/////////////////////////////////////////////////////////////////// +template +__global__ static void diGammaCuda(const void *vx, const Nd4jLong *xShapeInfo, + void *vz, const Nd4jLong *zShapeInfo) { + + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong len; + __shared__ bool sameOffset; + + if (threadIdx.x == 0) { + len = shape::length(xShapeInfo); + sameOffset = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + } + __syncthreads(); + + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < len; i += gridDim.x * blockDim.x) { + + const auto xOffset = shape::getIndexOffset(i, xShapeInfo); + const auto zOffset = sameOffset ? 
xOffset : shape::getIndexOffset(i, zShapeInfo); + + z[zOffset] = diGammaScalar(x[xOffset]); + } +} + +/////////////////////////////////////////////////////////////////// +template +static void diGammaCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { + + diGammaCuda<<>>(vx, xShapeInfo, vz, zShapeInfo); +} + +/////////////////////////////////////////////////////////////////// +void diGamma(nd4j::LaunchContext* context, const NDArray& x, NDArray& z) { + + int threadsPerBlock = MAX_NUM_THREADS / 2; + int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({&z}, {&x}); + BUILD_SINGLE_SELECTOR(x.dataType(), diGammaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo()), FLOAT_TYPES); + NDArray::registerSpecialUse({&z}, {&x}); +} + +BUILD_SINGLE_TEMPLATE(template void diGammaCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo), FLOAT_TYPES); + +} +} +} + diff --git a/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu b/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu index 01b9464fa..6a62246a3 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu @@ -18,7 +18,7 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 26.04.2019 // -#include +#include #include #include @@ -37,9 +37,13 @@ __global__ static void polyGammaCuda(const void *vn, const Nd4jLong *nShapeInfo, auto z = reinterpret_cast(vz); __shared__ Nd4jLong len; + __shared__ bool sameOffsetNX, sameOffsetNZ; - if (threadIdx.x == 0) + if (threadIdx.x == 0) { len = shape::length(nShapeInfo); + sameOffsetNX = shape::haveSameShapeAndStrides(xShapeInfo, nShapeInfo); + sameOffsetNZ = shape::haveSameShapeAndStrides(zShapeInfo, nShapeInfo); + } __syncthreads(); const auto tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -48,19 +52,26 @@ __global__ static void polyGammaCuda(const void *vn, const Nd4jLong *nShapeInfo, for (int i = tid; i < len; i += totalThreads) { const auto nOffset = shape::getIndexOffset(i, nShapeInfo); - const auto xOffset = shape::getIndexOffset(i, xShapeInfo); - const auto zOffset = shape::getIndexOffset(i, zShapeInfo); + const auto xOffset = sameOffsetNX ? nOffset : shape::getIndexOffset(i, xShapeInfo); + const auto zOffset = sameOffsetNZ ? nOffset : shape::getIndexOffset(i, zShapeInfo); - const T nVal = n[nOffset]; + const T order = n[nOffset]; - int sign = (static_cast(nVal) + 1) % 2 ? -1 : 1; + int sign = (static_cast(order) + 1) % 2 ? 
-1 : 1; - T factorial = 1; - if(nVal != 0 && nVal != 1) - for(int i = 2; i <= nVal; ++i) - factorial *= i; + if(order != static_cast(order)) { + z[zOffset] = DataTypeUtils::nanOrZero(); + } + else if(order == 0) { + z[zOffset] = diGammaScalar(x[xOffset]); + } + else { + T factorial = 1; + for(int i = 2; i <= order; ++i) + factorial *= i; - z[zOffset] = sign * factorial * zetaScalar(nVal + 1, x[xOffset]); + z[zOffset] = sign * factorial * zetaScalar(order + 1, x[xOffset]); + } } } @@ -76,7 +87,7 @@ void polyGamma(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x NDArray::prepareSpecialUse({&z}, {&n, &x}); - int threadsPerBlock = MAX_NUM_THREADS; + int threadsPerBlock = MAX_NUM_THREADS / 2; int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; BUILD_SINGLE_SELECTOR(n.dataType(), polyGammaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), n.getSpecialBuffer(), n.getSpecialShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo()), FLOAT_TYPES); diff --git a/libnd4j/include/ops/declarable/helpers/gammaMathFunc.h b/libnd4j/include/ops/declarable/helpers/gammaMathFunc.h new file mode 100644 index 000000000..2ad540409 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/gammaMathFunc.h @@ -0,0 +1,100 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#ifndef LIBND4J_GAMMAMATHFUNC_H +#define LIBND4J_GAMMAMATHFUNC_H + +#include +#include "NDArray.h" + +namespace nd4j { +namespace ops { +namespace helpers { + + // calculate the digamma function for each element for array + void diGamma(nd4j::LaunchContext* context, const NDArray& x, NDArray& z); + + // calculate the polygamma function + void polyGamma(nd4j::LaunchContext* context, const NDArray& n, const NDArray& x, NDArray& z); + + // calculate the digamma function for one element + // implementation is based on serial representation written in terms of the Hurwitz zeta function as polygamma = (-1)^{n+1} * n! 
* zeta(n+1, x)
+    template <typename T>
+    _CUDA_HD T diGammaScalar(T x) {
+
+        const int xInt = static_cast<int>(x);
+
+        // negative and zero
+        if(x <= 0) {
+            if(x == xInt)    // integer
+                return DataTypeUtils::infOrMax<T>();
+            else
+                return diGammaScalar<T>(1 - x) - M_PI / nd4j::math::nd4j_tan<T, T>(M_PI * x);    // use reflection formula psi(1-x) = psi(x) + pi*cot(pi*x)
+        }
+
+        // positive integer
+        if(x == xInt && xInt <= 20) {    // psi(n) = -Euler_Mascheroni_const + sum_from_k=1_to_n-1( 1/k ), for n = 1,2,3,...inf; we use this formula only for n <= 20 to avoid a time-consuming sum for bigger n
+            T result = -0.577215664901532;
+            for (uint i = 1; i <= xInt - 1; ++i) {
+                result += static_cast<T>(1) / i;
+            }
+            return result;
+        }
+
+        // positive half-integer
+        if(x - xInt == 0.5 && xInt <= 20) {    // psi(n+0.5) = -Euler_Mascheroni_const - 2*ln(2) + sum_from_k=1_to_n( 2/(2*k-1) ), for n = 1,2,3,...inf; we use this formula only for n <= 20 to avoid a time-consuming sum for bigger n
+            T result = -0.577215664901532 - 2 * nd4j::math::nd4j_log<T, T>(2);
+            for (uint i = 1; i <= xInt; ++i) {
+                result += static_cast<T>(2) / (2*i - 1);
+            }
+            return result;
+        }
+
+        // positive, smaller than 5; we should use a number > 5 in order to have satisfactory accuracy in the asymptotic expansion
+        if(x < 5)
+            return diGammaScalar<T>(1 + x) - static_cast<T>(1) / x;    // recurrence formula psi(x) = psi(x+1) - 1/x.
+
+        // *** other positive **** //
+
+        // truncated expansion formula (from wiki)
+        // psi(x) = log(x) - 1/(2*x) - 1/(12*x^2) + 1/(120*x^4) - 1/(252*x^6) + 1/(240*x^8) - 5/(660*x^10) + 691/(32760*x^12) - 1/(12*x^14) + ...
+
+        if(x >= (sizeof(T) > 4 ? 1.e16 : 1.e8))    // if x is too big take into account only log(x)
+            return nd4j::math::nd4j_log<T, T>(x);
+
+        // coefficients used in truncated asymptotic expansion formula
+        const T coeffs[7] = {-(T)1/12, (T)1/120, -(T)1/252, (T)1/240, -(T)5/660, (T)691/32760, -(T)1/12};
+        // const T coeffs[7] = {-0.0833333333333333, 0.00833333333333333, -0.00396825396825397, 0.00416666666666667, -0.00757575757575758, 0.0210927960927961, -0.0833333333333333};
+
+        const T x2Inv = static_cast<T>(1) / (x * x);
+        T result = 0;
+
+        for (int i = 6; i >= 0; --i)
+            result = (result + coeffs[i]) * x2Inv;
+        return result + nd4j::math::nd4j_log<T, T>(x) - static_cast<T>(0.5) / x;
+    }
+
+}
+}
+}
+
+
+#endif //LIBND4J_GAMMAMATHFUNC_H
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp
index 027c62484..4531fda81 100644
--- a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp
+++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp
@@ -36,7 +36,7 @@ namespace platforms {
 //////////////////////////////////////////////////////////////////////
 static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights,
                           const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH,
-                          const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode,
+                          const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode,
                           const int isNCHW) {

     int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
@@ -44,8 +44,7 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

-    if(isSameMode) // SAME
-        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH,
dW); + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); dnnl_memory_desc_t empty; dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md( @@ -53,7 +52,7 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md( empty); dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; - mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW, + mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr, bias, output, &conv_src_md, nullptr, &conv_weights_md, nullptr, @@ -115,13 +114,11 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv2d) { - auto input = INPUT_VARIABLE( - 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto output = OUTPUT_VARIABLE( - 0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) int sH = INT_ARG(2); // strides height int sW = INT_ARG(3); // strides width @@ -129,13 +126,13 @@ PLATFORM_IMPL(conv2d) { int pW = INT_ARG(5); // paddings width int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) height int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) width - conv2d_mkldnn(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW); + conv2d_mkldnn(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); return Status::OK(); } @@ -155,18 +152,13 @@ PLATFORM_CHECK(conv2d) { ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv2d_bp) { - auto input = INPUT_VARIABLE( - 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE( - 1); // [kH, kW, iC, oC] always + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE( - 2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE( - 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - auto gradW = OUTPUT_VARIABLE( - 1); // [kH, kW, iC, oC] always + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto gradB = block.width() > 3 ? 
OUTPUT_VARIABLE(2) : nullptr; // [oC] int kH = INT_ARG(0); // filter(kernel) height @@ -177,7 +169,7 @@ PLATFORM_IMPL(conv2d_bp) { int pW = INT_ARG(5); // paddings width int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC REQUIRE_TRUE(input->rankOf() == 4, 0, @@ -195,8 +187,7 @@ PLATFORM_IMPL(conv2d_bp) { ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - if (isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); dnnl_memory_desc_t empty; dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), @@ -204,7 +195,7 @@ PLATFORM_IMPL(conv2d_bp) { dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty); dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; - mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW, + mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW, gradB, gradO, &conv_src_md, &conv_diff_src_md, &conv_weights_md, @@ -342,18 +333,13 @@ PLATFORM_CHECK(conv2d_bp) { if (::optimalLevel() < 2) return false; - auto input = INPUT_VARIABLE( - 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE( - 1); // [kH, kW, iC, oC] always + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE( - 2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE( - 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - auto gradW = OUTPUT_VARIABLE( - 1); // [kH, kW, iC, oC] always + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto gradB = block.width() > 3 ? 
OUTPUT_VARIABLE(2) : nullptr; // [oC] diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h index f3f9f9699..1b54889d4 100644 --- a/libnd4j/include/ops/ops.h +++ b/libnd4j/include/ops/ops.h @@ -1568,11 +1568,39 @@ namespace simdOps { return opOutput + old; } - op_def static Z update(X old, X opOutput, X *extraParams) { return opOutput + old; } + op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction; + } + }; + + template + class IsNegative { + public: + no_op_exec_special_bool + no_op_exec_special_bool_cuda + + no_op_exec_special_accumulation + no_op_exec_special_accumulation_cuda + + op_def static Z op(X d1, X *params) { + return d1 < (X)0.f; + } + + op_def static X startingValue(const X *input) { + return static_cast(0); + } + + op_def static Z merge(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static Z update(X old, X opOutput, X *extraParams) { + return opOutput + old; + } op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { return reduction; diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp index b3552a00f..99092b37d 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -1008,6 +1008,38 @@ TEST_F(ConvolutionTests1, conv1d_causal_6) { delete results; } +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_7) { + + int bS=2, iW=8, iC=3,oC=4, kW=2, sW=1, pW=0, dW=1; + int oW = (iW-1)/sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}, nd4j::DataType::FLOAT32); + NDArray weights('c', {kW, iC, oC}, nd4j::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oW, oC}, {11.000000, 11.600000, 12.200000, 12.800000, 30.099998, 32.200001, 34.299999, 36.400002, 49.899998, 53.800003, 57.699997, + 61.599998, 69.699997, 75.400002, 81.099998, 86.800003, 89.500000, 97.000000, 104.500000, 112.000000, 109.300003, 118.600006, 127.899994, 137.199997, 129.100006, + 140.199997, 151.300003, 162.399994, 148.899994, 161.800003, 174.699997, 187.600006, 133.399994, 141.200012, 149.000000, 156.800003, 188.500000, 205.000000, + 221.500000, 238.000000, 208.299988, 226.600006, 244.899994, 263.200012, 228.100006, 248.200012, 268.299988, 288.399994, 247.899994, 269.799988, 291.700012, + 313.600006, 267.700012, 291.399994, 315.100006, 338.799988, 287.500000, 313.000000, 338.500000, 364.000000, 307.299988, 334.600006, 361.899994, 389.200012}, nd4j::DataType::FLOAT32); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + nd4j::ops::conv1d op; + auto results = op.execute({&input, &weights}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + + delete results; +} + ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, conv1d_causal_bp_1) { @@ -1174,14 +1206,14 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test3) { auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); auto gradO = NDArrayFactory::create('c', {bS, oC, oH, oW}); - auto expGradI = NDArrayFactory::create('c', {bS, iC, iH, iW},{ 0.567f, 1.224f, 0.66f, 1.314f, 2.82f, 1.512f, 1.386f, 2.976f, 1.596f, 0.801f, 1.71f, 0.912f, 0.657f, 1.422f, 0.768f, 1.53f, 3.288f, 1.764f, 1.602f, 3.444f, 
1.848f, 0.927f, 1.98f, 1.056f, - 0.747f, 1.62f, 0.876f, 1.746f, 3.756f, 2.016f, 1.818f, 3.912f, 2.1f, 1.053f, 2.25f, 1.2f, 0.837f, 1.818f, 0.984f, 1.962f, 4.224f, 2.268f, 2.034f, 4.38f, 2.352f, 1.179f, 2.52f, 1.344f, - 1.467f, 3.06f, 1.596f, 3.186f, 6.636f, 3.456f, 3.402f, 7.08f, 3.684f, 1.845f, 3.834f, 1.992f, 1.773f, 3.69f, 1.92f, 3.834f, 7.968f, 4.14f, 4.05f, 8.412f, 4.368f, 2.187f, 4.536f, 2.352f, + auto expGradI = NDArrayFactory::create('c', {bS, iC, iH, iW},{ 0.567f, 1.224f, 0.66f, 1.314f, 2.82f, 1.512f, 1.386f, 2.976f, 1.596f, 0.801f, 1.71f, 0.912f, 0.657f, 1.422f, 0.768f, 1.53f, 3.288f, 1.764f, 1.602f, 3.444f, 1.848f, 0.927f, 1.98f, 1.056f, + 0.747f, 1.62f, 0.876f, 1.746f, 3.756f, 2.016f, 1.818f, 3.912f, 2.1f, 1.053f, 2.25f, 1.2f, 0.837f, 1.818f, 0.984f, 1.962f, 4.224f, 2.268f, 2.034f, 4.38f, 2.352f, 1.179f, 2.52f, 1.344f, + 1.467f, 3.06f, 1.596f, 3.186f, 6.636f, 3.456f, 3.402f, 7.08f, 3.684f, 1.845f, 3.834f, 1.992f, 1.773f, 3.69f, 1.92f, 3.834f, 7.968f, 4.14f, 4.05f, 8.412f, 4.368f, 2.187f, 4.536f, 2.352f, 2.079f, 4.32f, 2.244f, 4.482f, 9.3f, 4.824f, 4.698f, 9.744f, 5.052f, 2.529f, 5.238f, 2.712f, 2.385f, 4.95f, 2.568f, 5.13f, 10.632f, 5.508f, 5.346f, 11.076f, 5.736f, 2.871f, 5.94f, 3.072f}); - auto expGradW = NDArrayFactory::create('c', {oC, iC, kH, kW},{1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, - 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, - 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, + auto expGradW = NDArrayFactory::create('c', {oC, iC, kH, kW},{1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, + 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, + 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f}); auto expGradB = NDArrayFactory::create('c', {oC},{0.68f, 1.f, 1.32f}); @@ -1252,20 +1284,20 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test1) { auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); - auto expGradI = NDArrayFactory::create('c', {bS, iD, iH, iW, iC},{0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, 2.036f, 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, 1.878f, 2.175f, 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, 7.023f, 7.86f, 8.697f, 3.39f, 3.93f, 4.47f, 5.01f, 
9.642f, 10.803f, 11.964f, 13.125f, 11.37f, 12.693f, 14.016f, 15.339f, - 5.266f, 5.707f, 6.148f, 6.589f, 12.98f, 13.916f, 14.852f, 15.788f, 14.564f, 15.608f, 16.652f, 17.696f, 6.284f, 7.166f, 8.048f, 8.93f, 17.896f, 19.768f, 21.64f, 23.512f, 21.928f, 24.016f, 26.104f, 28.192f, 18.12f, 19.686f, 21.252f, 22.818f, 45.852f, 49.146f, 52.44f, 55.734f, 53.196f, 56.814f, 60.432f, 64.05f, - 28.164f, 30.216f, 32.268f, 34.32f, 67.884f, 72.15f, 76.416f, 80.682f, 75.228f, 79.818f, 84.408f, 88.998f, 29.324f, 30.854f, 32.384f, 33.914f, 67.432f, 70.6f, 73.768f, 76.936f, 73.192f, 76.576f, 79.96f, 83.344f, 27.884f, 30.062f, 32.24f, 34.418f, 66.28f, 70.744f, 75.208f, 79.672f, 70.312f, 74.992f, 79.672f, 84.352f, - 58.296f, 61.806f, 65.316f, 68.826f, 133.98f, 141.162f, 148.344f, 155.526f, 141.324f, 148.83f, 156.336f, 163.842f, 68.34f, 72.336f, 76.332f, 80.328f, 156.012f, 164.166f, 172.32f, 180.474f, 163.356f, 171.834f, 180.312f, 188.79f, 61.292f, 64.118f, 66.944f, 69.77f, 136.552f, 142.312f, 148.072f, 153.832f, 142.312f, 148.288f, 154.264f, 160.24f, - 9.298f, 11.359f, 13.42f, 15.481f, 27.092f, 31.268f, 35.444f, 39.62f, 27.812f, 32.096f, 36.38f, 40.664f, 26.556f, 29.769f, 32.982f, 36.195f, 66.666f, 73.173f, 79.68f, 86.187f, 68.394f, 75.063f, 81.732f, 88.401f, 28.662f, 32.118f, 35.574f, 39.03f, 71.85f, 78.843f, 85.836f, 92.829f, 73.578f, 80.733f, 87.888f, 95.043f, - 29.89f, 32.275f, 34.66f, 37.045f, 70.004f, 74.828f, 79.652f, 84.476f, 71.588f, 76.52f, 81.452f, 86.384f, 71.084f, 75.854f, 80.624f, 85.394f, 163.048f, 172.696f, 182.344f, 191.992f, 167.08f, 176.944f, 186.808f, 196.672f, 138.648f, 146.046f, 153.444f, 160.842f, 310.236f, 325.194f, 340.152f, 355.11f, 317.58f, 332.862f, 348.144f, 363.426f, - 148.692f, 156.576f, 164.46f, 172.344f, 332.268f, 348.198f, 364.128f, 380.058f, 339.612f, 355.866f, 372.12f, 388.374f, 125.228f, 130.646f, 136.064f, 141.482f, 274.792f, 285.736f, 296.68f, 307.624f, 280.552f, 291.712f, 302.872f, 314.032f, 92.684f, 98.75f, 104.816f, 110.882f, 211.432f, 223.672f, 235.912f, 248.152f, 215.464f, 227.92f, 240.376f, 252.832f, + auto expGradI = NDArrayFactory::create('c', {bS, iD, iH, iW, iC},{0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, 2.036f, 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, 1.878f, 2.175f, 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, 7.023f, 7.86f, 8.697f, 3.39f, 3.93f, 4.47f, 5.01f, 9.642f, 10.803f, 11.964f, 13.125f, 11.37f, 12.693f, 14.016f, 15.339f, + 5.266f, 5.707f, 6.148f, 6.589f, 12.98f, 13.916f, 14.852f, 15.788f, 14.564f, 15.608f, 16.652f, 17.696f, 6.284f, 7.166f, 8.048f, 8.93f, 17.896f, 19.768f, 21.64f, 23.512f, 21.928f, 24.016f, 26.104f, 28.192f, 18.12f, 19.686f, 21.252f, 22.818f, 45.852f, 49.146f, 52.44f, 55.734f, 53.196f, 56.814f, 60.432f, 64.05f, + 28.164f, 30.216f, 32.268f, 34.32f, 67.884f, 72.15f, 76.416f, 80.682f, 75.228f, 79.818f, 84.408f, 88.998f, 29.324f, 30.854f, 32.384f, 33.914f, 67.432f, 70.6f, 73.768f, 76.936f, 73.192f, 76.576f, 79.96f, 83.344f, 27.884f, 30.062f, 32.24f, 34.418f, 66.28f, 70.744f, 75.208f, 79.672f, 70.312f, 74.992f, 79.672f, 84.352f, + 58.296f, 61.806f, 65.316f, 68.826f, 133.98f, 141.162f, 148.344f, 155.526f, 141.324f, 148.83f, 156.336f, 163.842f, 68.34f, 72.336f, 76.332f, 80.328f, 156.012f, 164.166f, 172.32f, 180.474f, 163.356f, 171.834f, 180.312f, 188.79f, 61.292f, 64.118f, 66.944f, 69.77f, 136.552f, 142.312f, 148.072f, 153.832f, 142.312f, 148.288f, 154.264f, 160.24f, + 9.298f, 11.359f, 13.42f, 15.481f, 27.092f, 31.268f, 35.444f, 39.62f, 27.812f, 32.096f, 36.38f, 40.664f, 26.556f, 29.769f, 32.982f, 36.195f, 66.666f, 73.173f, 79.68f, 
86.187f, 68.394f, 75.063f, 81.732f, 88.401f, 28.662f, 32.118f, 35.574f, 39.03f, 71.85f, 78.843f, 85.836f, 92.829f, 73.578f, 80.733f, 87.888f, 95.043f, + 29.89f, 32.275f, 34.66f, 37.045f, 70.004f, 74.828f, 79.652f, 84.476f, 71.588f, 76.52f, 81.452f, 86.384f, 71.084f, 75.854f, 80.624f, 85.394f, 163.048f, 172.696f, 182.344f, 191.992f, 167.08f, 176.944f, 186.808f, 196.672f, 138.648f, 146.046f, 153.444f, 160.842f, 310.236f, 325.194f, 340.152f, 355.11f, 317.58f, 332.862f, 348.144f, 363.426f, + 148.692f, 156.576f, 164.46f, 172.344f, 332.268f, 348.198f, 364.128f, 380.058f, 339.612f, 355.866f, 372.12f, 388.374f, 125.228f, 130.646f, 136.064f, 141.482f, 274.792f, 285.736f, 296.68f, 307.624f, 280.552f, 291.712f, 302.872f, 314.032f, 92.684f, 98.75f, 104.816f, 110.882f, 211.432f, 223.672f, 235.912f, 248.152f, 215.464f, 227.92f, 240.376f, 252.832f, 178.824f, 188.166f, 197.508f, 206.85f, 398.364f, 417.21f, 436.056f, 454.902f, 405.708f, 424.878f, 444.048f, 463.218f, 188.868f, 198.696f, 208.524f, 218.352f, 420.396f, 440.214f, 460.032f, 479.85f, 427.74f, 447.882f, 468.024f, 488.166f, 157.196f, 163.91f, 170.624f, 177.338f, 343.912f, 357.448f, 370.984f, 384.52f, 349.672f, 363.424f, 377.176f, 390.928f}); - auto expGradW = NDArrayFactory::create('c', {kD, kH, kW, iC, oC},{120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, - 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, - 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, - 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, - 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, + auto expGradW = NDArrayFactory::create('c', {kD, kH, kW, iC, oC},{120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, + 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, + 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, + 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, + 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f}); // auto expGradB('c', {oC},{}); @@ -1302,18 +1334,18 @@ 
TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test2) { auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); - auto expGradI = NDArrayFactory::create('c', {bS, iD, iH, iW, iC},{ 0.014f, 0.032f, 0.05f, 0.068f, 0.118f, 0.181f, 0.244f, 0.307f, 0.212f, 0.257f, 0.302f, 0.347f, 0.208f, 0.298f, 0.388f, 0.478f, 1.028f, 1.262f, 1.496f, 1.73f, 1.036f, 1.18f, 1.324f, 1.468f, 0.928f, 1.018f, 1.108f, 1.198f, 2.9f, 3.134f, 3.368f, 3.602f, 2.188f, 2.332f, 2.476f, 2.62f, - 1.202f, 1.274f, 1.346f, 1.418f, 3.142f, 3.313f, 3.484f, 3.655f, 2.048f, 2.147f, 2.246f, 2.345f, 0.532f, 0.676f, 0.82f, 0.964f, 2.324f, 2.666f, 3.008f, 3.35f, 2.008f, 2.206f, 2.404f, 2.602f, 3.584f, 3.98f, 4.376f, 4.772f, 10.552f, 11.452f, 12.352f, 13.252f, 7.4f, 7.904f, 8.408f, 8.912f, - 6.752f, 7.148f, 7.544f, 7.94f, 17.752f, 18.652f, 19.552f, 20.452f, 11.432f, 11.936f, 12.44f, 12.944f, 5.932f, 6.184f, 6.436f, 6.688f, 14.42f, 14.978f, 15.536f, 16.094f, 8.704f, 9.01f, 9.316f, 9.622f, 3.11f, 3.236f, 3.362f, 3.488f, 7.39f, 7.669f, 7.948f, 8.227f, 4.388f, 4.541f, 4.694f, 4.847f, - 8.56f, 8.866f, 9.172f, 9.478f, 19.892f, 20.558f, 21.224f, 21.89f, 11.548f, 11.908f, 12.268f, 12.628f, 11.008f, 11.314f, 11.62f, 11.926f, 25.22f, 25.886f, 26.552f, 27.218f, 14.428f, 14.788f, 15.148f, 15.508f, 7.322f, 7.502f, 7.682f, 7.862f, 16.462f, 16.849f, 17.236f, 17.623f, 9.248f, 9.455f, 9.662f, 9.869f, - 0.158f, 0.392f, 0.626f, 0.86f, 1.27f, 1.765f, 2.26f, 2.755f, 1.22f, 1.481f, 1.742f, 2.003f, 2.224f, 2.746f, 3.268f, 3.79f, 6.788f, 7.886f, 8.984f, 10.082f, 4.78f, 5.356f, 5.932f, 6.508f, 6.4f, 6.922f, 7.444f, 7.966f, 15.572f, 16.67f, 17.768f, 18.866f, 9.388f, 9.964f, 10.54f, 11.116f, - 4.802f, 5.09f, 5.378f, 5.666f, 11.206f, 11.809f, 12.412f, 13.015f, 6.512f, 6.827f, 7.142f, 7.457f, 6.004f, 6.58f, 7.156f, 7.732f, 14.996f, 16.202f, 17.408f, 18.614f, 9.208f, 9.838f, 10.468f, 11.098f, 17.984f, 19.244f, 20.504f, 21.764f, 42.808f, 45.436f, 48.064f, 50.692f, 25.256f, 26.624f, 27.992f, 29.36f, - 28.064f, 29.324f, 30.584f, 31.844f, 63.832f, 66.46f, 69.088f, 71.716f, 36.2f, 37.568f, 38.936f, 40.304f, 18.316f, 19.f, 19.684f, 20.368f, 40.916f, 42.338f, 43.76f, 45.182f, 22.816f, 23.554f, 24.292f, 25.03f, 8.438f, 8.78f, 9.122f, 9.464f, 18.91f, 19.621f, 20.332f, 21.043f, 10.58f, 10.949f, 11.318f, 11.687f, + auto expGradI = NDArrayFactory::create('c', {bS, iD, iH, iW, iC},{ 0.014f, 0.032f, 0.05f, 0.068f, 0.118f, 0.181f, 0.244f, 0.307f, 0.212f, 0.257f, 0.302f, 0.347f, 0.208f, 0.298f, 0.388f, 0.478f, 1.028f, 1.262f, 1.496f, 1.73f, 1.036f, 1.18f, 1.324f, 1.468f, 0.928f, 1.018f, 1.108f, 1.198f, 2.9f, 3.134f, 3.368f, 3.602f, 2.188f, 2.332f, 2.476f, 2.62f, + 1.202f, 1.274f, 1.346f, 1.418f, 3.142f, 3.313f, 3.484f, 3.655f, 2.048f, 2.147f, 2.246f, 2.345f, 0.532f, 0.676f, 0.82f, 0.964f, 2.324f, 2.666f, 3.008f, 3.35f, 2.008f, 2.206f, 2.404f, 2.602f, 3.584f, 3.98f, 4.376f, 4.772f, 10.552f, 11.452f, 12.352f, 13.252f, 7.4f, 7.904f, 8.408f, 8.912f, + 6.752f, 7.148f, 7.544f, 7.94f, 17.752f, 18.652f, 19.552f, 20.452f, 11.432f, 11.936f, 12.44f, 12.944f, 5.932f, 6.184f, 6.436f, 6.688f, 14.42f, 14.978f, 15.536f, 16.094f, 8.704f, 9.01f, 9.316f, 9.622f, 3.11f, 3.236f, 3.362f, 3.488f, 7.39f, 7.669f, 7.948f, 8.227f, 4.388f, 4.541f, 4.694f, 4.847f, + 8.56f, 8.866f, 9.172f, 9.478f, 19.892f, 20.558f, 21.224f, 21.89f, 11.548f, 11.908f, 12.268f, 12.628f, 11.008f, 11.314f, 11.62f, 11.926f, 25.22f, 25.886f, 26.552f, 27.218f, 14.428f, 14.788f, 15.148f, 15.508f, 7.322f, 7.502f, 7.682f, 7.862f, 16.462f, 16.849f, 17.236f, 17.623f, 9.248f, 
9.455f, 9.662f, 9.869f, + 0.158f, 0.392f, 0.626f, 0.86f, 1.27f, 1.765f, 2.26f, 2.755f, 1.22f, 1.481f, 1.742f, 2.003f, 2.224f, 2.746f, 3.268f, 3.79f, 6.788f, 7.886f, 8.984f, 10.082f, 4.78f, 5.356f, 5.932f, 6.508f, 6.4f, 6.922f, 7.444f, 7.966f, 15.572f, 16.67f, 17.768f, 18.866f, 9.388f, 9.964f, 10.54f, 11.116f, + 4.802f, 5.09f, 5.378f, 5.666f, 11.206f, 11.809f, 12.412f, 13.015f, 6.512f, 6.827f, 7.142f, 7.457f, 6.004f, 6.58f, 7.156f, 7.732f, 14.996f, 16.202f, 17.408f, 18.614f, 9.208f, 9.838f, 10.468f, 11.098f, 17.984f, 19.244f, 20.504f, 21.764f, 42.808f, 45.436f, 48.064f, 50.692f, 25.256f, 26.624f, 27.992f, 29.36f, + 28.064f, 29.324f, 30.584f, 31.844f, 63.832f, 66.46f, 69.088f, 71.716f, 36.2f, 37.568f, 38.936f, 40.304f, 18.316f, 19.f, 19.684f, 20.368f, 40.916f, 42.338f, 43.76f, 45.182f, 22.816f, 23.554f, 24.292f, 25.03f, 8.438f, 8.78f, 9.122f, 9.464f, 18.91f, 19.621f, 20.332f, 21.043f, 10.58f, 10.949f, 11.318f, 11.687f, 20.944f, 21.682f, 22.42f, 23.158f, 46.388f, 47.918f, 49.448f, 50.978f, 25.66f, 26.452f, 27.244f, 28.036f, 26.848f, 27.586f, 28.324f, 29.062f, 58.628f, 60.158f, 61.688f, 63.218f, 31.996f, 32.788f, 33.58f, 34.372f, 16.106f, 16.502f, 16.898f, 17.294f, 34.894f, 35.713f, 36.532f, 37.351f, 18.896f, 19.319f, 19.742f, 20.165f}); - auto expGradW = NDArrayFactory::create('c', {kD, kH, kW, iC, oC},{7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, - 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, - 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, + auto expGradW = NDArrayFactory::create('c', {kD, kH, kW, iC, oC},{7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, + 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, + 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f}); // auto expGradB('c', {oC},{}); @@ -1350,20 +1382,20 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test3) { auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); auto gradO = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); - auto expGradI = NDArrayFactory::create('c', {bS, iC, iD, iH, iW},{2.091f, 4.356f, 2.268f, 4.53f, 9.42f, 4.896f, 4.65f, 9.672f, 5.028f, 2.517f, 5.226f, 2.712f, 
4.932f, 10.242f, 5.316f, 10.62f, 22.02f, 11.412f, 10.908f, 22.62f, 11.724f, 5.868f, 12.15f, 6.288f, 2.913f, 6.03f, 3.12f, 6.234f, 12.888f, 6.66f, 6.402f, 13.236f, 6.84f, 3.423f, 7.068f, 3.648f, - 2.415f, 5.04f, 2.628f, 5.25f, 10.932f, 5.688f, 5.37f, 11.184f, 5.82f, 2.913f, 6.054f, 3.144f, 5.724f, 11.898f, 6.18f, 12.348f, 25.62f, 13.284f, 12.636f, 26.22f, 13.596f, 6.804f, 14.094f, 7.296f, 3.381f, 7.002f, 3.624f, 7.242f, 14.976f, 7.74f, 7.41f, 15.324f, 7.92f, 3.963f, 8.184f, 4.224f, - 2.739f, 5.724f, 2.988f, 5.97f, 12.444f, 6.48f, 6.09f, 12.696f, 6.612f, 3.309f, 6.882f, 3.576f, 6.516f, 13.554f, 7.044f, 14.076f, 29.22f, 15.156f, 14.364f, 29.82f, 15.468f, 7.74f, 16.038f, 8.304f, 3.849f, 7.974f, 4.128f, 8.25f, 17.064f, 8.82f, 8.418f, 17.412f, 9.f, 4.503f, 9.3f, 4.8f, - 3.063f, 6.408f, 3.348f, 6.69f, 13.956f, 7.272f, 6.81f, 14.208f, 7.404f, 3.705f, 7.71f, 4.008f, 7.308f, 15.21f, 7.908f, 15.804f, 32.82f, 17.028f, 16.092f, 33.42f, 17.34f, 8.676f, 17.982f, 9.312f, 4.317f, 8.946f, 4.632f, 9.258f, 19.152f, 9.9f, 9.426f, 19.5f, 10.08f, 5.043f, 10.416f, 5.376f, - 5.619f, 11.484f, 5.868f, 11.73f, 23.964f, 12.24f, 12.138f, 24.792f, 12.66f, 6.333f, 12.93f, 6.6f, 12.42f, 25.362f, 12.948f, 25.884f, 52.836f, 26.964f, 26.748f, 54.588f, 27.852f, 13.932f, 28.422f, 14.496f, 6.873f, 14.022f, 7.152f, 14.298f, 29.16f, 14.868f, 14.754f, 30.084f, 15.336f, 7.671f, 15.636f, 7.968f, - 6.807f, 13.896f, 7.092f, 14.178f, 28.932f, 14.76f, 14.586f, 29.76f, 15.18f, 7.593f, 15.486f, 7.896f, 14.94f, 30.474f, 15.54f, 31.068f, 63.348f, 32.292f, 31.932f, 65.1f, 33.18f, 16.596f, 33.822f, 17.232f, 8.205f, 16.722f, 8.52f, 17.034f, 34.704f, 17.676f, 17.49f, 35.628f, 18.144f, 9.075f, 18.48f, 9.408f, - 7.995f, 16.308f, 8.316f, 16.626f, 33.9f, 17.28f, 17.034f, 34.728f, 17.7f, 8.853f, 18.042f, 9.192f, 17.46f, 35.586f, 18.132f, 36.252f, 73.86f, 37.62f, 37.116f, 75.612f, 38.508f, 19.26f, 39.222f, 19.968f, 9.537f, 19.422f, 9.888f, 19.77f, 40.248f, 20.484f, 20.226f, 41.172f, 20.952f, 10.479f, 21.324f, 10.848f, + auto expGradI = NDArrayFactory::create('c', {bS, iC, iD, iH, iW},{2.091f, 4.356f, 2.268f, 4.53f, 9.42f, 4.896f, 4.65f, 9.672f, 5.028f, 2.517f, 5.226f, 2.712f, 4.932f, 10.242f, 5.316f, 10.62f, 22.02f, 11.412f, 10.908f, 22.62f, 11.724f, 5.868f, 12.15f, 6.288f, 2.913f, 6.03f, 3.12f, 6.234f, 12.888f, 6.66f, 6.402f, 13.236f, 6.84f, 3.423f, 7.068f, 3.648f, + 2.415f, 5.04f, 2.628f, 5.25f, 10.932f, 5.688f, 5.37f, 11.184f, 5.82f, 2.913f, 6.054f, 3.144f, 5.724f, 11.898f, 6.18f, 12.348f, 25.62f, 13.284f, 12.636f, 26.22f, 13.596f, 6.804f, 14.094f, 7.296f, 3.381f, 7.002f, 3.624f, 7.242f, 14.976f, 7.74f, 7.41f, 15.324f, 7.92f, 3.963f, 8.184f, 4.224f, + 2.739f, 5.724f, 2.988f, 5.97f, 12.444f, 6.48f, 6.09f, 12.696f, 6.612f, 3.309f, 6.882f, 3.576f, 6.516f, 13.554f, 7.044f, 14.076f, 29.22f, 15.156f, 14.364f, 29.82f, 15.468f, 7.74f, 16.038f, 8.304f, 3.849f, 7.974f, 4.128f, 8.25f, 17.064f, 8.82f, 8.418f, 17.412f, 9.f, 4.503f, 9.3f, 4.8f, + 3.063f, 6.408f, 3.348f, 6.69f, 13.956f, 7.272f, 6.81f, 14.208f, 7.404f, 3.705f, 7.71f, 4.008f, 7.308f, 15.21f, 7.908f, 15.804f, 32.82f, 17.028f, 16.092f, 33.42f, 17.34f, 8.676f, 17.982f, 9.312f, 4.317f, 8.946f, 4.632f, 9.258f, 19.152f, 9.9f, 9.426f, 19.5f, 10.08f, 5.043f, 10.416f, 5.376f, + 5.619f, 11.484f, 5.868f, 11.73f, 23.964f, 12.24f, 12.138f, 24.792f, 12.66f, 6.333f, 12.93f, 6.6f, 12.42f, 25.362f, 12.948f, 25.884f, 52.836f, 26.964f, 26.748f, 54.588f, 27.852f, 13.932f, 28.422f, 14.496f, 6.873f, 14.022f, 7.152f, 14.298f, 29.16f, 14.868f, 14.754f, 30.084f, 15.336f, 7.671f, 15.636f, 7.968f, + 6.807f, 13.896f, 7.092f, 
14.178f, 28.932f, 14.76f, 14.586f, 29.76f, 15.18f, 7.593f, 15.486f, 7.896f, 14.94f, 30.474f, 15.54f, 31.068f, 63.348f, 32.292f, 31.932f, 65.1f, 33.18f, 16.596f, 33.822f, 17.232f, 8.205f, 16.722f, 8.52f, 17.034f, 34.704f, 17.676f, 17.49f, 35.628f, 18.144f, 9.075f, 18.48f, 9.408f, + 7.995f, 16.308f, 8.316f, 16.626f, 33.9f, 17.28f, 17.034f, 34.728f, 17.7f, 8.853f, 18.042f, 9.192f, 17.46f, 35.586f, 18.132f, 36.252f, 73.86f, 37.62f, 37.116f, 75.612f, 38.508f, 19.26f, 39.222f, 19.968f, 9.537f, 19.422f, 9.888f, 19.77f, 40.248f, 20.484f, 20.226f, 41.172f, 20.952f, 10.479f, 21.324f, 10.848f, 9.183f, 18.72f, 9.54f, 19.074f, 38.868f, 19.8f, 19.482f, 39.696f, 20.22f, 10.113f, 20.598f, 10.488f, 19.98f, 40.698f, 20.724f, 41.436f, 84.372f, 42.948f, 42.3f, 86.124f, 43.836f, 21.924f, 44.622f, 22.704f, 10.869f, 22.122f, 11.256f, 22.506f, 45.792f, 23.292f, 22.962f, 46.716f, 23.76f, 11.883f, 24.168f, 12.288f}); - auto expGradW = NDArrayFactory::create('c', {oC, iC, kD, kH, kW},{5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, - 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, - 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, - 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, - 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, + auto expGradW = NDArrayFactory::create('c', {oC, iC, kD, kH, kW},{5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, + 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, + 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, + 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, + 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f}); auto expGradB = NDArrayFactory::create('c', {oC},{2.64f, 3.92f, 5.2f}); @@ -1407,9 +1439,9 @@ TYPED_TEST(TypedConvolutionTests1, depthwise_conv2d_1) { auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, - 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f, - 12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 
5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, + auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, + 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f, + 12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f}); input = 2.; weights.linspace(0.1, 0.1); @@ -1439,7 +1471,7 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_2) { auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, + auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f}); input = 2.; weights.linspace(0.1, 0.1); @@ -1697,13 +1729,13 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test1) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto expected = NDArrayFactory::create('c', {2, 3, 4, 3, 3}, {64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, - 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, - 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f, - 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f, - 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, - 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, - 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f, + auto expected = NDArrayFactory::create('c', {2, 3, 4, 3, 3}, {64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, + 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, + 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f, + 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 32.f, 32.f, 32.f, 
32.f, 32.f, 32.f, 16.f, 16.f, 16.f, + 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, + 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, + 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f, 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f}); input = 2.; weights = 1.; @@ -1729,13 +1761,13 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test2) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto expected = NDArrayFactory::create('c', {2, 3, 4, 3, 3}, {534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, - 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, - 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, - 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f, - 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, - 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, - 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, + auto expected = NDArrayFactory::create('c', {2, 3, 4, 3, 3}, {534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f, + 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 
324.f, + 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f}); input = 2.; weights.linspace(0.1, 0.1); @@ -1760,9 +1792,9 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test3) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto expected = NDArrayFactory::create('c', {2, 2, 2, 2, 3}, {686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, - 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, - 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, + auto expected = NDArrayFactory::create('c', {2, 2, 2, 2, 3}, {686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f}); input = 2.; weights.linspace(0.1, 0.1); @@ -1844,8 +1876,8 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test6) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); auto bias = NDArrayFactory::create('c', {oC},{1.f, 2.f, 3.f}); - auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, - 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, + auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, + 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f}); input = 2.; weights = 0.5; @@ -1873,9 +1905,9 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test7) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); auto bias = NDArrayFactory::create('c', {oC},{1.f, 2.f, 3.f}); - auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f, - 698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, - 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f, + auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f, + 698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, + 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 
698.f, 698.f, 698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f}); input = 2.; weights.linspace(0.1, 0.1); @@ -1903,8 +1935,8 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test8) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); - auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, - 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, + auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, + 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f}); input = 2.; weights.linspace(0.1, 0.1); @@ -1997,9 +2029,9 @@ TYPED_TEST(TypedConvolutionTests1, pointwise_conv2d_test1) { auto bias = NDArrayFactory::create('c', {oC}); - auto expOutput = NDArrayFactory::create('c', {bS, iH, iW, oC},{ 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, - 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, - 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, + auto expOutput = NDArrayFactory::create('c', {bS, iH, iW, oC},{ 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, + 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, + 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f}); input = 2.; weights.linspace(0.1, 0.1); @@ -2110,20 +2142,20 @@ TEST_F(ConvolutionTests1, vol2col_test2) { auto columns = NDArrayFactory::create('c', {kD, iC, kH, oW, kW, bS, oD, oH}); columns.permutei({5, 1, 0, 2, 4, 6, 7, 3}); columns = -1.; - auto columnsExpected = NDArrayFactory::create('c', {bS, iC, kD, kH, kW, oD, oH, oW}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, -10.f, 11.f, 12.f, 2.f, 0.f, 4.f, 0.f, 6.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 3.f, 4.f, 5.f, 6.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 4.f, 0.f, 6.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 7.f, 8.f, -9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 14.f, 0.f, 16.f, 0.f, 18.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 15.f, 16.f, 17.f, 18.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 16.f, 0.f, 18.f, 0.f, 0.f, 0.f, 
22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f, -23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 26.f, 0.f, 28.f, 0.f, 30.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 27.f, 28.f, 29.f, 30.f, 0.f, 0.f, 33.f, 34.f, 35.f, 36.f, -0.f, 0.f, 28.f, 0.f, 30.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 29.f, 30.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 30.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 33.f, -34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 38.f, 0.f, 40.f, -0.f, 42.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 39.f, 40.f, 41.f, 42.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 40.f, 0.f, 42.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 41.f, 42.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 42.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 43.f, 44.f, 45.f, 46.f, 47.f, -48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 50.f, 0.f, 52.f, 0.f, 54.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 51.f, 52.f, 53.f, 54.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 52.f, 0.f, 54.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 53.f, 54.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, 0.f, 0.f, -0.f, 0.f, 54.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 62.f, 0.f, 64.f, 0.f, 66.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 63.f, 64.f, 65.f, 66.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 64.f, 0.f, 66.f, -0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 65.f, 66.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 66.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, + auto columnsExpected = NDArrayFactory::create('c', {bS, iC, kD, kH, kW, oD, oH, oW}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 
9.f, +10.f, 11.f, 12.f, 2.f, 0.f, 4.f, 0.f, 6.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 3.f, 4.f, 5.f, 6.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 4.f, 0.f, 6.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 7.f, 8.f, +9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 14.f, 0.f, 16.f, 0.f, 18.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 15.f, 16.f, 17.f, 18.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 16.f, 0.f, 18.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f, +23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 26.f, 0.f, 28.f, 0.f, 30.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 27.f, 28.f, 29.f, 30.f, 0.f, 0.f, 33.f, 34.f, 35.f, 36.f, +0.f, 0.f, 28.f, 0.f, 30.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 29.f, 30.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 30.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 33.f, +34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 38.f, 0.f, 40.f, +0.f, 42.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 39.f, 40.f, 41.f, 42.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 40.f, 0.f, 42.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 41.f, 42.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 42.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 43.f, 44.f, 45.f, 46.f, 47.f, +48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 50.f, 0.f, 52.f, 0.f, 54.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 51.f, 52.f, 53.f, 54.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 52.f, 0.f, 54.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 53.f, 54.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, 0.f, 0.f, +0.f, 0.f, 54.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 
0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 62.f, 0.f, 64.f, 0.f, 66.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 63.f, 64.f, 65.f, 66.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 64.f, 0.f, 66.f, +0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 65.f, 66.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 66.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); graph::Context context(1); @@ -2164,11 +2196,11 @@ TEST_F(ConvolutionTests1, upsampling2d_test1) { auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); input.linspace(1); - auto expOutput = NDArrayFactory::create('c', {bS, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, - 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, - 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, - 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, + auto expOutput = NDArrayFactory::create('c', {bS, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 
31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f}); nd4j::ops::upsampling2d op; @@ -2192,11 +2224,11 @@ TEST_F(ConvolutionTests1, upsampling2d_test2) { auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); input.linspace(1); - auto expOutput = NDArrayFactory::create('c', {bS, iC, iH*factorH, iW*factorW}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, - 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, - 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, - 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, - 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, + auto expOutput = NDArrayFactory::create('c', {bS, iC, iH*factorH, iW*factorW}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, + 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, + 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, + 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, + 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f}); nd4j::ops::upsampling2d op; @@ -2221,20 +2253,20 @@ TEST_F(ConvolutionTests1, upsampling3d_test1) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); input.linspace(1); - auto expOutput = NDArrayFactory::create('c', {bS, iD*factorD, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, - 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 
3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, - 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, - 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, - 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, - 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, - 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, - 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, - 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, - 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 
55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, - 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, - 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, + auto expOutput = NDArrayFactory::create('c', {bS, iD*factorD, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, + 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 
37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, + 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, + 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, + 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, + 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, + 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, + 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f}); nd4j::ops::upsampling3d op; @@ -2258,17 +2290,17 @@ TEST_F(ConvolutionTests1, upsampling3d_test2) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); input.linspace(1); - auto expOutput = NDArrayFactory::create('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, { 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, - 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 
11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, - 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, - 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, - 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, - 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, - 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, - 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, - 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, - 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, - 61.f, 
61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, + auto expOutput = NDArrayFactory::create('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, { 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, + 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, + 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, + 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, + 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, + 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, + 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, + 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 
43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, + 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, + 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, + 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f}); nd4j::ops::upsampling3d op; @@ -2412,13 +2444,13 @@ TEST_F(ConvolutionTests1, deconv2d_test1) { auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); auto weights = NDArrayFactory::create('c', {kH, kW, oC, iC}); - auto exp = NDArrayFactory::create('c', {bS, oH, oW, oC}, { 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f, - 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + auto exp = NDArrayFactory::create('c', {bS, oH, oW, oC}, { 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 
70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f, + 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f}); input = 0.5; weights.linspace(0.1, 0.1); @@ -2445,13 +2477,13 @@ TEST_F(ConvolutionTests1, deconv2d_test2) { auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); - auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 
241.f, 161.f, 181.f, 201.f, 221.f, 241.f }); input = 0.5; weights.linspace(0.1, 0.1); @@ -2673,13 +2705,13 @@ TYPED_TEST(TypedConvolutionTests1, deconv2d_tf_test1) { auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); auto outShape = NDArrayFactory::create('c', {4}, {static_cast(bS), static_cast(iH), static_cast(iW), static_cast(iC)}); - auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, { 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f, - 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, { 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f, + 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f}); input = 0.5; weights.linspace(0.1, 0.1); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index 209b16a7a..76a44be0b 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -428,10 +428,11 @@ TEST_F(DeclarableOpsTests13, CellContains_test_1) { TEST_F(DeclarableOpsTests13, adjustHue_1) { NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32); + NDArray factor = NDArrayFactory::create(0.5); NDArray exp ('c', {2,2,3}, {100,0,44, 208,5,220, 177,230,97, 2,255,244}, nd4j::DataType::FLOAT32); nd4j::ops::adjust_hue op; - auto results = op.execute({&input}, {0.5}, {2}); 
+    auto results = op.execute({&input, &factor}, {}, {2});

     ASSERT_EQ(ND4J_STATUS_OK, results->status());

@@ -525,10 +526,11 @@ TEST_F(DeclarableOpsTests13, adjustHue_5) {
 TEST_F(DeclarableOpsTests13, adjustSaturation_1) {

     NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
+    NDArray factor = NDArrayFactory::create(0.5);
     NDArray exp ('c', {2,2,3}, {50,100,78, 118.5,220,112.5, 190,163.5,230, 255,128.5,134}, nd4j::DataType::FLOAT32);

     nd4j::ops::adjust_saturation op;
-    auto results = op.execute({&input}, {0.5}, {2});
+    auto results = op.execute({&input, &factor}, {}, {2});

     ASSERT_EQ(ND4J_STATUS_OK, results->status());

diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp
index 41dc12a14..b2ccad86f 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp
@@ -159,18 +159,19 @@ TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) {
 TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) {
     auto x = NDArrayFactory::create<double>('c', {4,4,3});
-    auto e = NDArrayFactory::create<double>('c', {4,4,3}, {
-        -21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
-        2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5,
-        26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5,
-        50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5
-    });
+    NDArray factor = NDArrayFactory::create<double>(2.);
+    auto e = NDArrayFactory::create<double>('c', {4,4,3}, {-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
+                                                           2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5,
+                                                           26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5,
+                                                           50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5});
+
+    x.linspace(1.);
     nd4j::ops::adjust_contrast op;
-    auto result = op.execute({&x}, {2.}, {}, {});
+    auto result = op.execute({&x, &factor}, {}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto out = result->at(0);
-// out->printIndexedBuffer("Adjusted Constrast");
+
     ASSERT_TRUE(e.equalsTo(out));
     delete result;
 }
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp
index 269c13f51..a521be97b 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp
@@ -1774,6 +1774,28 @@ TEST_F(DeclarableOpsTests3, betainc_test10) {
     delete results;
 }

+///////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests3, betainc_test11) {
+
+    NDArray a('c', {4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f}, nd4j::DataType::FLOAT32);
+    NDArray b('c', {4}, {0.7717f, 0.9281f, 0.9846f, 0.4838f}, nd4j::DataType::FLOAT32);
+    NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, nd4j::DataType::FLOAT32);
+
+    NDArray expected('c', {4}, {0.912156, 0.634443, 0.898314, 0.624544}, nd4j::DataType::FLOAT32);
+
+    nd4j::ops::betainc op;
+    auto results = op.execute({&a, &b, &x}, {}, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto *output = results->at(0);
+
+    ASSERT_TRUE(expected.isSameShape(output));
+    ASSERT_TRUE(expected.equalsTo(output));
+
+    delete results;
+}
+
 ///////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests3, zeta_test1) {

@@ -2092,8 +2114,26 @@ TEST_F(DeclarableOpsTests3, polygamma_test3) {
     x.linspace(10.);
     auto expected = NDArrayFactory::create<double>('c',
{3,3}, {1.05166336e-01,-9.04983497e-03, 1.31009323e-03,-2.44459433e-04, 5.31593880e-05,-1.28049888e-05, 3.31755364e-06,-9.07408791e-07, 2.58758130e-07});
+    nd4j::ops::polygamma op;
+    auto results = op.execute({&n, &x}, {}, {});
-    //ASSERT_FALSE(true);
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto output = results->at(0);
+
+    ASSERT_TRUE(expected.isSameShape(output));
+    ASSERT_TRUE(expected.equalsTo(output));
+
+    delete results;
+}
+
+TEST_F(DeclarableOpsTests3, polygamma_test4) {
+
+    NDArray n('c', {3,4}, {/*0.7788*/0, 0,1,2,3,4,5,6,7,8,9,10}, nd4j::DataType::DOUBLE);
+    NDArray x('c', {3,4}, {0.7717,0.9281,0.9846,0.4838,0.6433,0.6041,0.6501,0.7612,0.7605,0.3948,0.9493,0.8600}, nd4j::DataType::DOUBLE);
+
+    NDArray expected('c', {3,4}, {/*std::numeric_limits<double>::quiet_NaN()*/-1.031918, -7.021327e-01, 1.682743e+00, -1.851378e+01,3.604167e+01, -3.008293e+02,
+                                  1.596005e+03, -4.876665e+03,4.510025e+04, -1.730340e+08, 6.110257e+05, -1.907087e+07}, nd4j::DataType::DOUBLE);
 nd4j::ops::polygamma op;
 auto results = op.execute({&n, &x}, {}, {});
@@ -2108,6 +2148,26 @@ TEST_F(DeclarableOpsTests3, polygamma_test3) {
     delete results;
 }
+TEST_F(DeclarableOpsTests3, digamma_1) {
+
+    NDArray x('c', {18}, {-25, -24.99999, -21.5, -21.2, -5.5, -4.1, -2.1, -0.5, -0.3, 0., 0.2, 1, 1.5, 2.2, 5.2, 19., 21, 22.2}, nd4j::DataType::DOUBLE);
+
+    NDArray expected('c', {18}, {std::numeric_limits<double>::infinity(), -99996.761229, 3.091129, 7.401432, 1.792911,11.196838,10.630354, 0.03649, 2.11331,
+                                 std::numeric_limits<double>::infinity(),-5.28904,-0.577216, 0.03649, 0.544293, 1.549434,2.917892, 3.020524, 3.077401}, nd4j::DataType::DOUBLE);
+
+    nd4j::ops::digamma op;
+    auto results = op.execute({&x}, {}, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto output = results->at(0);
+
+    ASSERT_TRUE(expected.isSameShape(output));
+    ASSERT_TRUE(expected.equalsTo(output));
+
+    delete results;
+}
+
 ///////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests3, svd_test1) {

From d8339246d9a53964f182a5c2570decefb26f3d3c Mon Sep 17 00:00:00 2001
From: raver119
Date: Tue, 3 Dec 2019 10:23:19 +0300
Subject: [PATCH 24/30] fix typo in test

Signed-off-by: raver119

---
 libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp
index 20469ed2d..3cf9eeb04 100644
--- a/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp
@@ -134,7 +134,7 @@ TEST_F(BooleanOpsTests, test_where_1) {
     auto y = NDArrayFactory::create('c', {6}, { 2, -3, 1, 1, -2, 1 });
     auto e = NDArrayFactory::create('c', {3}, { 4, 8, 5 });
-    nd4j:ops::choose op;
+    nd4j::ops::choose op;
     auto result = op.execute({&x, &y}, {}, {3});
     ASSERT_EQ(Status::OK(), result->status());

From 190575196cfe900d117d142068bbee66b406b42c Mon Sep 17 00:00:00 2001
From: shugeo
Date: Tue, 3 Dec 2019 14:06:38 +0200
Subject: [PATCH 25/30] Refactored pad and mirror_pad ops to conform with TF.
(#100)

---
 .../include/ops/declarable/generic/transforms/mirrorPad.cpp | 2 +-
 libnd4j/include/ops/declarable/generic/transforms/pad.cpp   | 3 ++-
 libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp      | 4 ++--
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp b/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp
index 603bfdf61..fac8451a5 100644
--- a/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp
+++ b/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp
@@ -56,7 +56,7 @@ CUSTOM_OP_IMPL(mirror_pad, 2, 1, false, 0, 1) {
 DECLARE_TYPES(mirror_pad) {
     getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS});
-    getOpDescriptor()->setAllowedInputTypes(1, {ALL_INTS});
+    getOpDescriptor()->setAllowedInputTypes(1, {DataType::INT32}); // to conform with TF
     getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS});
 }

diff --git a/libnd4j/include/ops/declarable/generic/transforms/pad.cpp b/libnd4j/include/ops/declarable/generic/transforms/pad.cpp
index 31a5d25b3..9d410a6c3 100644
--- a/libnd4j/include/ops/declarable/generic/transforms/pad.cpp
+++ b/libnd4j/include/ops/declarable/generic/transforms/pad.cpp
@@ -78,7 +78,8 @@ CUSTOM_OP_IMPL(pad, 2, 1, false, 0, 1) {
 DECLARE_TYPES(pad) {
     getOpDescriptor()
             ->setAllowedInputTypes(0, nd4j::DataType::ANY)
-            ->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // INT32 with TF, but used also INT64 due long shapes
+            ->setAllowedInputTypes(1, {DataType::INT32}) // INT32 with TF
+//            ->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // INT32 with TF, but INT64 was also used due to long shapes
             ->setSameMode(true);
 }

diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp
index f232411b2..23351f7af 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp
@@ -4549,7 +4549,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test13) {
 TEST_F(DeclarableOpsTests7, mirrorPad_test14) {
     auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.});
-    auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 0, 0, 1});
+    auto paddings = NDArrayFactory::create('c', {2, 2}, {1LL, 0LL, 0LL, 1LL});
     auto exp = NDArrayFactory::create('c', {3, 4}, {4, 5, 6, 5, 1, 2, 3, 2, 4, 5, 6, 5});
@@ -4567,7 +4567,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test15) {
     auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.});
-    auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 0, 0});
+    auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 0, 0});
     auto exp = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6});

From 0d14032d26c6f25bd93476a079d70701227e89a0 Mon Sep 17 00:00:00 2001
From: Fariz Rahman
Date: Wed, 4 Dec 2019 11:41:03 +0530
Subject: [PATCH 26/30] TF Updates (#87)

* tf updates
* pom
* copyright
* graphrunner tests
* gpu test
* getSessionOptionsConfigProto
* dtype fix
* Small fix
Signed-off-by: AlexDBlack
* cast graphs
* savemodel test fix
* testresource instead of local
* Logging level
Signed-off-by: AlexDBlack
* gson dependency issue fix; fix GraphRunnerTest for no session options config case
Signed-off-by: Alex Black
* Final tweaks
Signed-off-by: AlexDBlack
* few minor fixes
Signed-off-by: raver119
* one more fix
Signed-off-by: raver119
* Tweak configuration for GraphRunnerTest
Signed-off-by: AlexDBlack
* nd4j align
config * tf warmup --- .../jita/concurrency/CudaAffinityManager.java | 6 +- .../conversion/GraphRunnerTest.java | 88 +- .../conversion/GpuDeviceAlignmentTest.java | 33 - .../conversion/GpuGraphRunnerTest.java | 22 +- nd4j/nd4j-tensorflow/pom.xml | 11 + .../tensorflow/conversion/TensorDataType.java | 131 +++ .../conversion/TensorflowConversion.java | 3 +- .../conversion/graphrunner/GraphRunner.java | 880 +++++++++--------- .../main/resources/cast_graph/ai/__init__.py | 0 .../cast_graph/ai/konduit/__init__.py | 0 .../cast_graph/ai/konduit/casting.py | 17 + .../cast_graph/cast_float16_float16.pb | 5 + .../cast_graph/cast_float16_float32.pb | 11 + .../cast_graph/cast_float16_float64.pb | 11 + .../cast_graph/cast_float16_int16.pb | 11 + .../cast_graph/cast_float16_int32.pb | 11 + .../cast_graph/cast_float16_int64.pb | 11 + .../resources/cast_graph/cast_float16_int8.pb | 11 + .../cast_graph/cast_float16_uint16.pb | 11 + .../cast_graph/cast_float16_uint32.pb | 11 + .../cast_graph/cast_float16_uint64.pb | 11 + .../cast_graph/cast_float16_uint8.pb | 11 + .../cast_graph/cast_float32_float16.pb | 11 + .../cast_graph/cast_float32_float32.pb | 5 + .../cast_graph/cast_float32_float64.pb | 11 + .../cast_graph/cast_float32_int16.pb | 11 + .../cast_graph/cast_float32_int32.pb | 11 + .../cast_graph/cast_float32_int64.pb | 11 + .../resources/cast_graph/cast_float32_int8.pb | 11 + .../cast_graph/cast_float32_uint16.pb | 11 + .../cast_graph/cast_float32_uint32.pb | 11 + .../cast_graph/cast_float32_uint64.pb | 11 + .../cast_graph/cast_float32_uint8.pb | 11 + .../cast_graph/cast_float64_float16.pb | 11 + .../cast_graph/cast_float64_float32.pb | 11 + .../cast_graph/cast_float64_float64.pb | 5 + .../cast_graph/cast_float64_int16.pb | 11 + .../cast_graph/cast_float64_int32.pb | 11 + .../cast_graph/cast_float64_int64.pb | 11 + .../resources/cast_graph/cast_float64_int8.pb | 11 + .../cast_graph/cast_float64_uint16.pb | 11 + .../cast_graph/cast_float64_uint32.pb | 11 + .../cast_graph/cast_float64_uint64.pb | 11 + .../cast_graph/cast_float64_uint8.pb | 11 + .../cast_graph/cast_int16_float16.pb | 11 + .../cast_graph/cast_int16_float32.pb | 11 + .../cast_graph/cast_int16_float64.pb | 11 + .../resources/cast_graph/cast_int16_int16.pb | 5 + .../resources/cast_graph/cast_int16_int32.pb | 11 + .../resources/cast_graph/cast_int16_int64.pb | 11 + .../resources/cast_graph/cast_int16_int8.pb | 11 + .../resources/cast_graph/cast_int16_uint16.pb | 11 + .../resources/cast_graph/cast_int16_uint32.pb | 11 + .../resources/cast_graph/cast_int16_uint64.pb | 11 + .../resources/cast_graph/cast_int16_uint8.pb | 11 + .../cast_graph/cast_int32_float16.pb | 11 + .../cast_graph/cast_int32_float32.pb | 11 + .../cast_graph/cast_int32_float64.pb | 11 + .../resources/cast_graph/cast_int32_int16.pb | 11 + .../resources/cast_graph/cast_int32_int32.pb | 5 + .../resources/cast_graph/cast_int32_int64.pb | 11 + .../resources/cast_graph/cast_int32_int8.pb | 11 + .../resources/cast_graph/cast_int32_uint16.pb | 11 + .../resources/cast_graph/cast_int32_uint32.pb | 11 + .../resources/cast_graph/cast_int32_uint64.pb | 11 + .../resources/cast_graph/cast_int32_uint8.pb | 11 + .../cast_graph/cast_int64_float16.pb | 11 + .../cast_graph/cast_int64_float32.pb | 11 + .../cast_graph/cast_int64_float64.pb | 11 + .../resources/cast_graph/cast_int64_int16.pb | 11 + .../resources/cast_graph/cast_int64_int32.pb | 11 + .../resources/cast_graph/cast_int64_int64.pb | 5 + .../resources/cast_graph/cast_int64_int8.pb | 11 + .../resources/cast_graph/cast_int64_uint16.pb | 11 + 
.../resources/cast_graph/cast_int64_uint32.pb | 11 + .../resources/cast_graph/cast_int64_uint64.pb | 11 + .../resources/cast_graph/cast_int64_uint8.pb | 11 + .../resources/cast_graph/cast_int8_float16.pb | 11 + .../resources/cast_graph/cast_int8_float32.pb | 11 + .../resources/cast_graph/cast_int8_float64.pb | 11 + .../resources/cast_graph/cast_int8_int16.pb | 11 + .../resources/cast_graph/cast_int8_int32.pb | 11 + .../resources/cast_graph/cast_int8_int64.pb | 11 + .../resources/cast_graph/cast_int8_int8.pb | 5 + .../resources/cast_graph/cast_int8_uint16.pb | 11 + .../resources/cast_graph/cast_int8_uint32.pb | 11 + .../resources/cast_graph/cast_int8_uint64.pb | 11 + .../resources/cast_graph/cast_int8_uint8.pb | 11 + .../cast_graph/cast_uint16_float16.pb | 11 + .../cast_graph/cast_uint16_float32.pb | 11 + .../cast_graph/cast_uint16_float64.pb | 11 + .../resources/cast_graph/cast_uint16_int16.pb | 11 + .../resources/cast_graph/cast_uint16_int32.pb | 11 + .../resources/cast_graph/cast_uint16_int64.pb | 11 + .../resources/cast_graph/cast_uint16_int8.pb | 11 + .../cast_graph/cast_uint16_uint16.pb | 5 + .../cast_graph/cast_uint16_uint32.pb | 11 + .../cast_graph/cast_uint16_uint64.pb | 11 + .../resources/cast_graph/cast_uint16_uint8.pb | 11 + .../cast_graph/cast_uint32_float16.pb | 11 + .../cast_graph/cast_uint32_float32.pb | 11 + .../cast_graph/cast_uint32_float64.pb | 11 + .../resources/cast_graph/cast_uint32_int16.pb | 11 + .../resources/cast_graph/cast_uint32_int32.pb | 11 + .../resources/cast_graph/cast_uint32_int64.pb | 11 + .../resources/cast_graph/cast_uint32_int8.pb | 11 + .../cast_graph/cast_uint32_uint16.pb | 11 + .../cast_graph/cast_uint32_uint32.pb | 5 + .../cast_graph/cast_uint32_uint64.pb | 11 + .../resources/cast_graph/cast_uint32_uint8.pb | 11 + .../cast_graph/cast_uint64_float16.pb | 11 + .../cast_graph/cast_uint64_float32.pb | 11 + .../cast_graph/cast_uint64_float64.pb | 11 + .../resources/cast_graph/cast_uint64_int16.pb | 11 + .../resources/cast_graph/cast_uint64_int32.pb | 11 + .../resources/cast_graph/cast_uint64_int64.pb | 11 + .../resources/cast_graph/cast_uint64_int8.pb | 11 + .../cast_graph/cast_uint64_uint16.pb | 11 + .../cast_graph/cast_uint64_uint32.pb | 11 + .../cast_graph/cast_uint64_uint64.pb | 5 + .../resources/cast_graph/cast_uint64_uint8.pb | 11 + .../cast_graph/cast_uint8_float16.pb | 11 + .../cast_graph/cast_uint8_float32.pb | 11 + .../cast_graph/cast_uint8_float64.pb | 11 + .../resources/cast_graph/cast_uint8_int16.pb | 11 + .../resources/cast_graph/cast_uint8_int32.pb | 11 + .../resources/cast_graph/cast_uint8_int64.pb | 11 + .../resources/cast_graph/cast_uint8_int8.pb | 11 + .../resources/cast_graph/cast_uint8_uint16.pb | 11 + .../resources/cast_graph/cast_uint8_uint32.pb | 11 + .../resources/cast_graph/cast_uint8_uint64.pb | 11 + .../resources/cast_graph/cast_uint8_uint8.pb | 5 + 132 files changed, 1940 insertions(+), 516 deletions(-) delete mode 100644 nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuDeviceAlignmentTest.java create mode 100644 nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorDataType.java create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/ai/__init__.py create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/ai/konduit/__init__.py create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/ai/konduit/casting.py create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_float16.pb create mode 100644 
nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_float32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_float64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_int16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_int32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_int64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_int8.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_uint16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_uint32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_uint64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_uint8.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_float16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_float32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_float64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_int16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_int32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_int64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_int8.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_uint16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_uint32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_uint64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_uint8.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_float16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_float32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_float64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_int16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_int32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_int64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_int8.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_uint16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_uint32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_uint64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_uint8.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_float16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_float32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_float64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_int16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_int32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_int64.pb create mode 100644 
nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_int8.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_uint16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_uint32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_uint64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_uint8.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_float16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_float32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_float64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_int16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_int32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_int64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_int8.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_uint16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_uint32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_uint64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_uint8.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int64_float16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int64_float32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int64_float64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int64_int16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int64_int32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int64_int64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int64_int8.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int64_uint16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int64_uint32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int64_uint64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int64_uint8.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int8_float16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int8_float32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int8_float64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int8_int16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int8_int32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int8_int64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int8_int8.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int8_uint16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int8_uint32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int8_uint64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int8_uint8.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint16_float16.pb create mode 100644 
nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint16_float32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint16_float64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint16_int16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint16_int32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint16_int64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint16_int8.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint16_uint16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint16_uint32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint16_uint64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint16_uint8.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint32_float16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint32_float32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint32_float64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint32_int16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint32_int32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint32_int64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint32_int8.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint32_uint16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint32_uint32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint32_uint64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint32_uint8.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint64_float16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint64_float32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint64_float64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint64_int16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint64_int32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint64_int64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint64_int8.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint64_uint16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint64_uint32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint64_uint64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint64_uint8.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint8_float16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint8_float32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint8_float64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint8_int16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint8_int32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint8_int64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint8_int8.pb 
create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint8_uint16.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint8_uint32.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint8_uint64.pb create mode 100644 nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_uint8_uint8.pb diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java index cf362c460..aea78a4e0 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java @@ -28,6 +28,7 @@ import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.concurrency.BasicAffinityManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer; import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.nativeblas.NativeOpsHolder; import org.slf4j.Logger; @@ -298,9 +299,12 @@ public class CudaAffinityManager extends BasicAffinityManager { @Override public void ensureLocation(INDArray array, Location location) { // to location to ensure for empty array - if (array.isEmpty()) + if (array.isEmpty() || array.isS()) return; + // let's make sure host pointer actually exists + ((BaseCudaDataBuffer) array.data()).lazyAllocateHostPointer(); + val point = AtomicAllocator.getInstance().getAllocationPoint(array); switch (location) { case HOST: { diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/GraphRunnerTest.java b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/GraphRunnerTest.java index ee188605d..c959edbfe 100644 --- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/GraphRunnerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/GraphRunnerTest.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,6 +17,12 @@ package org.nd4j.tensorflow.conversion; +import junit.framework.TestCase; +import org.apache.commons.io.FileUtils; +import org.bytedeco.tensorflow.TF_Tensor; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.resources.Resources; +import org.nd4j.shade.protobuf.Descriptors; import org.nd4j.shade.protobuf.util.JsonFormat; import org.apache.commons.io.IOUtils; import org.junit.Ignore; @@ -27,6 +34,8 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.tensorflow.conversion.graphrunner.GraphRunner; import org.nd4j.tensorflow.conversion.graphrunner.SavedModelConfig; +import org.tensorflow.framework.ConfigProto; +import org.tensorflow.framework.GPUOptions; import java.io.File; import java.util.Arrays; @@ -39,12 +48,25 @@ import static org.junit.Assert.assertNotNull; public class GraphRunnerTest { + public static ConfigProto getConfig(){ + String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); + if("CUDA".equalsIgnoreCase(backend)) { + org.tensorflow.framework.ConfigProto configProto = org.tensorflow.framework.ConfigProto.getDefaultInstance(); + ConfigProto.Builder b = configProto.toBuilder().addDeviceFilters(TensorflowConversion.defaultDeviceForThread()); + return b.setGpuOptions(GPUOptions.newBuilder() + .setAllowGrowth(true) + .setPerProcessGpuMemoryFraction(0.5) + .build()).build(); + } + return null; + } + @Test public void testGraphRunner() throws Exception { List inputs = Arrays.asList("input_0","input_1"); byte[] content = IOUtils.toByteArray(new ClassPathResource("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").getInputStream()); - try(GraphRunner graphRunner = new GraphRunner(content,inputs)) { + try(GraphRunner graphRunner = GraphRunner.builder().graphBytes(content).inputNames(inputs).sessionOptionsConfigProto(getConfig()).build()) { runGraphRunnerTest(graphRunner); } } @@ -52,8 +74,9 @@ public class GraphRunnerTest { @Test public void testGraphRunnerFilePath() throws Exception { List inputs = Arrays.asList("input_0","input_1"); - File file = new ClassPathResource("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").getFile(); - try(GraphRunner graphRunner = new GraphRunner(file.getAbsolutePath(),inputs)) { + byte[] content = FileUtils.readFileToByteArray(Resources.asFile("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb")); + + try(GraphRunner graphRunner = GraphRunner.builder().graphBytes(content).inputNames(inputs).sessionOptionsConfigProto(getConfig()).build()) { runGraphRunnerTest(graphRunner); } } @@ -62,37 +85,42 @@ public class GraphRunnerTest { public void testInputOutputResolution() throws Exception { ClassPathResource lenetPb = new ClassPathResource("tf_graphs/lenet_frozen.pb"); byte[] content = IOUtils.toByteArray(lenetPb.getInputStream()); - GraphRunner graphRunner = new GraphRunner(content,Arrays.asList("Reshape/tensor")); - assertEquals(1,graphRunner.getInputOrder().size()); - assertEquals(1,graphRunner.getOutputOrder().size()); + List inputs = Arrays.asList("Reshape/tensor"); + try(GraphRunner graphRunner = GraphRunner.builder().graphBytes(content).inputNames(inputs).sessionOptionsConfigProto(getConfig()).build()) { + assertEquals(1, graphRunner.getInputOrder().size()); + assertEquals(1, graphRunner.getOutputOrder().size()); + } } @Test @Ignore //Ignored 2019/02/05: ssd_inception_v2_coco_2019_01_28 does 
not exist in test resources public void testMultiOutputGraph() throws Exception { - ClassPathResource classPathResource = new ClassPathResource("/tf_graphs/examples/ssd_inception_v2_coco_2018_01_28/frozen_inference_graph.pb"); - GraphRunner graphRunner = new GraphRunner(classPathResource.getFile().getAbsolutePath(),Arrays.asList("image_tensor")); - String[] outputs = new String[] { "detection_boxes", "detection_scores", "detection_classes", "num_detections"}; + List inputs = Arrays.asList("image_tensor"); + byte[] content = IOUtils.toByteArray(new ClassPathResource("/tf_graphs/examples/ssd_inception_v2_coco_2018_01_28/frozen_inference_graph.pb").getInputStream()); + try(GraphRunner graphRunner = GraphRunner.builder().graphBytes(content).inputNames(inputs).sessionOptionsConfigProto(getConfig()).build()) { + String[] outputs = new String[]{"detection_boxes", "detection_scores", "detection_classes", "num_detections"}; - assertEquals(1,graphRunner.getInputOrder().size()); - System.out.println(graphRunner.getOutputOrder()); - assertEquals(4,graphRunner.getOutputOrder().size()); + assertEquals(1, graphRunner.getInputOrder().size()); + System.out.println(graphRunner.getOutputOrder()); + assertEquals(4, graphRunner.getOutputOrder().size()); + } } private void runGraphRunnerTest(GraphRunner graphRunner) throws Exception { - - org.tensorflow.framework.ConfigProto.Builder builder = org.tensorflow.framework.ConfigProto.newBuilder(); String json = graphRunner.sessionOptionsToJson(); - JsonFormat.parser().merge(json,builder); - org.tensorflow.framework.ConfigProto build = builder.build(); - assertEquals(build,graphRunner.getProtoBufConfigProto()); + if( json != null ) { + org.tensorflow.framework.ConfigProto.Builder builder = org.tensorflow.framework.ConfigProto.newBuilder(); + JsonFormat.parser().merge(json, builder); + org.tensorflow.framework.ConfigProto build = builder.build(); + assertEquals(build,graphRunner.getSessionOptionsConfigProto()); + } assertNotNull(graphRunner.getInputOrder()); assertNotNull(graphRunner.getOutputOrder()); - org.tensorflow.framework.ConfigProto configProto1 = GraphRunner.fromJson(json); + org.tensorflow.framework.ConfigProto configProto1 = json == null ? 
null : GraphRunner.fromJson(json); - assertEquals(graphRunner.getProtoBufConfigProto(),configProto1); + assertEquals(graphRunner.getSessionOptionsConfigProto(),configProto1); assertEquals(2,graphRunner.getInputOrder().size()); assertEquals(1,graphRunner.getOutputOrder().size()); @@ -125,15 +153,31 @@ public class GraphRunnerTest { .signatureKey("incr_counter_by") .modelTag("serve") .build(); - try(GraphRunner graphRunner = new GraphRunner(savedModelConfig)) { + try(GraphRunner graphRunner = GraphRunner.builder().savedModelConfig(savedModelConfig).sessionOptionsConfigProto(getConfig()).build()) { INDArray delta = Nd4j.create(new float[] { 42 }, new long[0]); Map inputs = new LinkedHashMap<>(); - inputs.put("delta",delta); + inputs.put("delta:0",delta); Map outputs = graphRunner.run(inputs); assertEquals(1, outputs.size()); - INDArray output = outputs.get("output"); + System.out.println(Arrays.toString(outputs.keySet().toArray(new String[0]))); + INDArray output = outputs.values().toArray(new INDArray[0])[0]; assertEquals(42.0, output.getDouble(0), 0.0); } } + @Test + public void testGraphRunnerCast() { + INDArray arr = Nd4j.linspace(1,4,4).castTo(DataType.FLOAT); + TF_Tensor tensor = TensorflowConversion.getInstance().tensorFromNDArray(arr); + TF_Tensor tf_tensor = GraphRunner.castTensor(tensor, TensorDataType.FLOAT,TensorDataType.DOUBLE); + INDArray doubleNDArray = TensorflowConversion.getInstance().ndArrayFromTensor(tf_tensor); + TestCase.assertEquals(DataType.DOUBLE,doubleNDArray.dataType()); + + arr = arr.castTo(DataType.INT); + tensor = TensorflowConversion.getInstance().tensorFromNDArray(arr); + tf_tensor = GraphRunner.castTensor(tensor, TensorDataType.fromNd4jType(DataType.INT),TensorDataType.DOUBLE); + doubleNDArray = TensorflowConversion.getInstance().ndArrayFromTensor(tf_tensor); + TestCase.assertEquals(DataType.DOUBLE,doubleNDArray.dataType()); + + } } diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuDeviceAlignmentTest.java b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuDeviceAlignmentTest.java deleted file mode 100644 index ef3aaa872..000000000 --- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuDeviceAlignmentTest.java +++ /dev/null @@ -1,33 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.tensorflow.conversion;
-
-import org.junit.Test;
-import org.nd4j.tensorflow.conversion.graphrunner.GraphRunner;
-import org.tensorflow.framework.ConfigProto;
-
-import static junit.framework.TestCase.assertTrue;
-
-public class GpuDeviceAlignmentTest {
-
-  @Test
-  public void testDeviceAlignment() {
-    ConfigProto configProto = GraphRunner.getAlignedWithNd4j();
-    assertTrue(configProto.getDeviceFilters(0).contains("gpu"));
-  }
-
-}
diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java
index 1ecc0e39a..614330813 100644
--- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
@@ -18,7 +19,6 @@
 package org.nd4j.tensorflow.conversion;

 import org.nd4j.shade.protobuf.util.JsonFormat;
 import org.apache.commons.io.IOUtils;
-import org.junit.Ignore;
 import org.junit.Test;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
@@ -28,6 +28,7 @@
 import org.tensorflow.framework.ConfigProto;
 import org.tensorflow.framework.GPUOptions;

 import java.io.File;
+import java.io.FileInputStream;
 import java.util.Arrays;
 import java.util.LinkedHashMap;
 import java.util.List;
@@ -36,34 +37,34 @@ import java.util.Map;

 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;

-@Ignore("AB 2019/05/24 - Failing on CI - no jnitensorflow in java.library.path - see issue #7657")
 public class GpuGraphRunnerTest {

     @Test
     public void testGraphRunner() throws Exception {
-        File f = new ClassPathResource("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").getFile();
+        byte[] content = IOUtils.toByteArray(new ClassPathResource("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").getInputStream());
         List inputNames = Arrays.asList("input_0","input_1");

         ConfigProto configProto = ConfigProto.newBuilder()
                 .setGpuOptions(GPUOptions.newBuilder()
-                        .setPerProcessGpuMemoryFraction(0.01)
+                        .setPerProcessGpuMemoryFraction(0.1)
                         .setAllowGrowth(false)
                         .build())
                 .build();

-        try(GraphRunner graphRunner = new GraphRunner(f.getAbsolutePath(), inputNames, configProto)) {
+        try(GraphRunner graphRunner = GraphRunner.builder().graphBytes(content).inputNames(inputNames).sessionOptionsConfigProto(configProto).build()) {
            org.tensorflow.framework.ConfigProto.Builder builder = org.tensorflow.framework.ConfigProto.newBuilder();
            String json = graphRunner.sessionOptionsToJson();
            JsonFormat.parser().merge(json,builder);
            org.tensorflow.framework.ConfigProto build = builder.build();
-
assertEquals(build,graphRunner.getProtoBufConfigProto()); + assertEquals(build,graphRunner.getSessionOptionsConfigProto()); assertNotNull(graphRunner.getInputOrder()); assertNotNull(graphRunner.getOutputOrder()); org.tensorflow.framework.ConfigProto configProto1 = GraphRunner.fromJson(json); - assertEquals(graphRunner.getProtoBufConfigProto(),configProto1); + assertEquals(graphRunner.getSessionOptionsConfigProto(),configProto1); assertEquals(2,graphRunner.getInputOrder().size()); assertEquals(1,graphRunner.getOutputOrder().size()); @@ -83,9 +84,4 @@ public class GpuGraphRunnerTest { } } - - - - - } diff --git a/nd4j/nd4j-tensorflow/pom.xml b/nd4j/nd4j-tensorflow/pom.xml index ea9edd08f..fb859a95f 100644 --- a/nd4j/nd4j-tensorflow/pom.xml +++ b/nd4j/nd4j-tensorflow/pom.xml @@ -45,9 +45,20 @@ tensorflow ${tensorflow.javacpp.version} + + org.bytedeco + tensorflow-platform + ${tensorflow.javacpp.version} + + + com.google.code.gson + gson + ${gson.version} + junit junit + test diff --git a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorDataType.java b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorDataType.java new file mode 100644 index 000000000..74d053547 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorDataType.java @@ -0,0 +1,131 @@ +/* ****************************************************************************** + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.tensorflow.conversion; + +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.compression.CompressedDataBuffer; +import org.nd4j.linalg.compression.CompressionDescriptor; + +public enum TensorDataType { + INVALID, + FLOAT, + DOUBLE, + INT32, + UINT8, + INT16, + INT8, + STRING, + COMPLEX64, + INT64, + BOOL, + QINT8, + QUINT8, + QINT32, + BFLOAT16, + QINT16, + QUINT16, + UINT16, + COMPLEX128, + HALF, + RESOURCE, + VARIANT, + UINT32, + UINT64; + + + /** + * Map a tensor data type to a proto value found in tensorflow. 
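+ * For example, {@code fromProtoValue("DT_FLOAT")} returns {@code FLOAT}.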
+ * Generally, this is just replacing the DT_ prefix with an empty string
+ * and returning enum.valueOf(string)
+ * @param value the input string
+ * @return the associated {@link TensorDataType}
+ */
+    public static TensorDataType fromProtoValue(String value) {
+        String valueReplace = value.replace("DT_","");
+        return TensorDataType.valueOf(valueReplace);
+    }
+
+
+
+    /**
+     * Get the python name for the given data type
+     * @param tensorDataType the data type to get the python name for
+     * @return float64 for double, float32 for float, float16 for half, otherwise
+     * the type's name converted to lower case
+     */
+    public static String toPythonName(TensorDataType tensorDataType) {
+        switch(tensorDataType) {
+            case DOUBLE: return "float64";
+            case FLOAT: return "float32";
+            case HALF: return "float16";
+
+            default: return tensorDataType.name().toLowerCase();
+        }
+    }
+
+    public static DataType toNd4jType(TensorDataType tensorDataType) {
+        switch(tensorDataType) {
+            case FLOAT: return DataType.FLOAT;
+            case DOUBLE: return DataType.DOUBLE;
+            case BOOL: return DataType.BOOL;
+            case INT32: return DataType.INT;
+            case INT64: return DataType.LONG;
+            case STRING: return DataType.UTF8;
+            case HALF: return DataType.HALF;
+            default: throw new IllegalArgumentException("Unsupported type " + tensorDataType.name());
+        }
+    }
+
+
+    public static TensorDataType fromNd4jType(DataType dataType) {
+        switch(dataType) {
+            case FLOAT: return TensorDataType.FLOAT;
+            case LONG: return TensorDataType.INT64;
+            case INT: return TensorDataType.INT32;
+            case BOOL: return TensorDataType.BOOL;
+            case DOUBLE: return TensorDataType.DOUBLE;
+            case HALF: return TensorDataType.HALF;
+            case UTF8: return TensorDataType.STRING;
+            case COMPRESSED: throw new IllegalStateException("Unable to work with compressed data type. Could be 1 or more types.");
+            case SHORT: return TensorDataType.INT16;
+            default: throw new IllegalArgumentException("Unknown data type " + dataType);
+        }
+    }
+
+    public static TensorDataType fromNd4jType(INDArray array) {
+        DataType dataType = array.dataType();
+        switch(dataType) {
+            case COMPRESSED:
+                CompressedDataBuffer compressedData = (CompressedDataBuffer) array.data();
+                CompressionDescriptor desc = compressedData.getCompressionDescriptor();
+                String algo = desc.getCompressionAlgorithm();
+                switch (algo) {
+                    case "FLOAT16": return HALF;
+                    case "INT8": return INT8;
+                    case "UINT8": return UINT8;
+                    case "INT16": return INT16;
+                    case "UINT16": return UINT16;
+                    default: throw new IllegalArgumentException("Unsupported compression algorithm: " + algo);
+                }
+
+            default: return fromNd4jType(dataType);
+        }
+    }
+
+}
diff --git a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorflowConversion.java b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorflowConversion.java
index 6eff18ecc..82c9b947e 100644
--- a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorflowConversion.java
+++ b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorflowConversion.java
@@ -239,7 +239,8 @@ public class TensorflowConversion {
             DataBuffer d = Nd4j.createBuffer(indexer.pointer(),nd4jType,length,indexer);
             array = Nd4j.create(d,ndShape);
         }
-        Nd4j.getAffinityManager().tagLocation(array, AffinityManager.Location.HOST);
+        // we don't need this in this case.
Device memory will be updated right in the constructor
+        //Nd4j.getAffinityManager().tagLocation(array, AffinityManager.Location.HOST);
         return array;
     }

diff --git a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java
index 79d45f781..9cb0a609b 100644
--- a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java
+++ b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java
@@ -16,26 +16,32 @@
 package org.nd4j.tensorflow.conversion.graphrunner;

+import lombok.Builder;
+import lombok.Singular;
+import org.apache.commons.io.FileUtils;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.io.ClassPathResource;
+import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.shade.protobuf.ByteString;
 import org.nd4j.shade.protobuf.InvalidProtocolBufferException;
 import org.nd4j.shade.protobuf.util.JsonFormat;
 import lombok.Getter;
 import lombok.Setter;
 import lombok.extern.slf4j.Slf4j;
+import org.nd4j.tensorflow.conversion.TensorDataType;
 import org.apache.commons.io.IOUtils;
 import org.bytedeco.javacpp.BytePointer;
 import org.bytedeco.javacpp.PointerPointer;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.tensorflow.conversion.TensorflowConversion;
 import org.tensorflow.framework.ConfigProto;
-import org.tensorflow.framework.GPUOptions;
 import org.tensorflow.framework.NodeDef;

-import java.io.Closeable;
-import java.io.File;
-import java.io.IOException;
+import java.io.*;
 import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;

 import org.bytedeco.tensorflow.*;
 import static org.bytedeco.tensorflow.global.tensorflow.*;
@@ -51,10 +57,13 @@ import static org.bytedeco.tensorflow.global.tensorflow.*;
  */
 @Slf4j
 public class GraphRunner implements Closeable {
+
+    private static boolean isTfWarmedUp = false;
+    private static boolean isTfWarmingUp = false;
     private SavedModelConfig savedModelConfig;
     //the in memory representation parsed from protobuf
     private TF_Graph graph;
-    //the conversion between nd4j and tensorflow
+    //the conversion between nd4j and TensorFlow
     private TensorflowConversion conversion = TensorflowConversion.getInstance();
     //a persistent session to be used when running the graph
     private TF_Session session;
@@ -64,170 +73,104 @@ public class GraphRunner implements Closeable {
     private TF_Status status;
     @Getter
     @Setter
+    @Singular
     private List<String> inputOrder,outputOrder;
     @Getter
-    private org.tensorflow.framework.ConfigProto protoBufConfigProto;
+    private org.tensorflow.framework.ConfigProto sessionOptionsConfigProto;
+    @Getter
+    @Setter
+    @Singular
+    private Map<String,TensorDataType> inputDataTypes,outputDataTypes;
+    private static Map<Pair<TensorDataType,TensorDataType>,GraphRunner> recastGraphDefs;

-    /**
-     * Pass in a graph instance and
-     * the length of the protobuf
-     * that it was instantiated with.
-     * For files this is typically
-     * {@link File#length()},
-     * for byte arrays, this is
-     * byte array.length
-     * and for {@link java.nio.ByteBuffer}
-     * this would be something like the
-     * {@link java.nio.ByteBuffer#capacity()}
-     * @param inputNames the input names for the graph
-     * @param outputNames the output names in the graph
-     * @param graph a pointer to the {@link TF_Graph} to use when executing
-     * @param graphDef {@link org.tensorflow.framework.GraphDef} protobuf
-     *                 definition containing
-     *                 the graph configuration
-     *                 for automatically inferring
-     *                 things like
-     *                 graph inputs and outputs
-     *
-     *
-     */
-    public GraphRunner(List inputNames,List outputNames,TF_Graph graph,org.tensorflow.framework.GraphDef graphDef) {
-        this(inputNames,outputNames,graph,graphDef,null);
-
-    }
-
-    /**
-     * Pass in a graph instance and
-     * the length of the protobuf
-     * that it was instantiated with.
-     * For files this is typically
-     * {@link File#length()},
-     * for byte arrays, this is
-     * byte array.length
-     * and for {@link java.nio.ByteBuffer}
-     * this would be something like the
-     * {@link java.nio.ByteBuffer#capacity()}
-     * @param graph a pointer to the {@link TF_Graph} to use when executing
-     * @param graphDef {@link org.tensorflow.framework.GraphDef} protobuf
-     *                 definition containing
-     *                 the graph configuration
-     *                 for automatically inferring
-     *                 things like
-     *                 graph inputs and outputs
-     * @param configProto the session configuration proto to use with this runner
-     */
-    public GraphRunner(List inputNames,List outputNames,TF_Graph graph,org.tensorflow.framework.GraphDef graphDef,ConfigProto configProto) {
-        this.graph = graph;
-        this.protoBufConfigProto = configProto;
-        this.inputOrder = inputNames;
-        this.outputOrder = outputNames;
-        initSessionAndStatusIfNeeded(graphDef);
-
-    }
-
-    /**
-     * Initialize with the graph content to use
-     * @param inputNames the inputs to the graph
-     * @param graphToUse the raw byte content
-     * of a protobuf file saved by tensorflow
-     */
-    public GraphRunner(byte[] graphToUse,List inputNames,List outputNames) {
-        this(graphToUse,inputNames,outputNames,getAlignedWithNd4j());
+    static {
+        recastGraphDefs = new ConcurrentHashMap<>();
     }

     /**
-     * Initialize with the graph content to use
-     * @param filePath path of a protobuf file saved by tensorflow
-     * @param inputNames the input namesfor the graph
+     * The constructor for creating a graph runner via builder
+     * @param inputNames the input names to use
+     * @param outputNames the output names to use
+     * @param savedModelConfig the saved model configuration to load from (note this cannot be used in conjunction
+     *                         with graph path)
+     * @param sessionOptionsConfigProto the session options for running the model (this may be null)
+     * @param sessionOptionsProtoBytes the proto bytes equivalent of the session configuration
+     * @param sessionOptionsProtoPath the file path to a session configuration proto file
+     * @param graph the tensorflow graph to use
+     * @param graphPath the path to the graph
+     * @param graphBytes the in memory bytes of the graph
+     * @param inputDataTypes the expected input data types
+     * @param outputDataTypes the expected output data types
      */
-    public GraphRunner(String filePath,List inputNames,List outputNames) {
-        this(filePath,inputNames,outputNames,getAlignedWithNd4j());
-    }
-
-
-
-    /**
-     * Initialize with the graph content to use
-     * @param filePath path of a protobuf file saved by tensorflow
-     * @param inputNames the names of the inputs for the graph
-     * @param sessionOptionsConfiguration the session
options to use
-     * for running sessions
-     */
-    public GraphRunner(String filePath,List inputNames,List outputNames,org.tensorflow.framework.ConfigProto sessionOptionsConfiguration) {
-        byte[] graphToUse = null;
-
+    @Builder
+    public GraphRunner(List<String> inputNames,
+                       List<String> outputNames,
+                       SavedModelConfig savedModelConfig,
+                       org.tensorflow.framework.ConfigProto sessionOptionsConfigProto,
+                       byte[] sessionOptionsProtoBytes,
+                       File sessionOptionsProtoPath,
+                       TF_Graph graph,
+                       File graphPath,
+                       byte[] graphBytes,
+                       Map<String,TensorDataType> inputDataTypes,
+                       Map<String,TensorDataType> outputDataTypes) {
     try {
-            this.inputOrder = inputNames;
-            this.outputOrder = outputNames;
-            this.protoBufConfigProto = sessionOptionsConfiguration;
-            initOptionsIfNeeded();
-            graphToUse = IOUtils.toByteArray(new File(filePath).toURI());
-            this.graph = conversion.loadGraph(graphToUse, status);
-        } catch (Exception e) {
-            throw new IllegalArgumentException("Unable to parse protobuf",e);
-        }
-
-        initSessionAndStatusIfNeeded(graphToUse);
-    }
-
-    /**
-     * Initialize with the graph content to use
-     * @param graphToUse the raw byte content
-     * of a protobuf file saved by tensorflow
-     * @param sessionOptionsConfiguration the session options to use
-     * for running sessions
-     */
-    public GraphRunner(byte[] graphToUse,List inputNames,List outputNames,org.tensorflow.framework.ConfigProto sessionOptionsConfiguration) {
-        try {
-            this.inputOrder = inputNames;
-            this.outputOrder = outputNames;
-            this.protoBufConfigProto = sessionOptionsConfiguration;
-            initOptionsIfNeeded();
-            this.graph = conversion.loadGraph(graphToUse, status);
-        } catch (Exception e) {
-            throw new IllegalArgumentException("Unable to parse protobuf",e);
-        }
-
-        initSessionAndStatusIfNeeded(graphToUse);
-    }
+            if(sessionOptionsConfigProto == null) {
+                if(sessionOptionsProtoBytes != null) {
+                    this.sessionOptionsConfigProto = ConfigProto.parseFrom(sessionOptionsProtoBytes);
+                }
+                else if(sessionOptionsProtoPath != null) {
+                    byte[] load = FileUtils.readFileToByteArray(sessionOptionsProtoPath);
+                    this.sessionOptionsConfigProto = ConfigProto.parseFrom(load);
+                }
+            }
+            else
+                this.sessionOptionsConfigProto = sessionOptionsConfigProto;

-    /**
-     * Initialize with the SavedModel to use
-     * @param inputNames (optional) the input names for the tensorflow graph
-     * @param outputNames the output names for the tensorflow graph
-     * @param savedModelConfig the configuration of the model to run
-     */
-    public GraphRunner(List inputNames,List outputNames,SavedModelConfig savedModelConfig) {
-        this(inputNames,outputNames,savedModelConfig,getAlignedWithNd4j());
-    }
-
-    /**
-     * Initialize with the SavedModel to use
-     * @param inputNames (optional) the input names for the tensorflow graph
-     * @param outputNames (optional) the output names for the tensorflow graph
-     * @param savedModelConfig the configuration for the saved model
-     * @param sessionOptionsConfiguration the session options to use
-     * for running sessions
-     */
-    public GraphRunner(List inputNames,List outputNames,SavedModelConfig savedModelConfig, ConfigProto sessionOptionsConfiguration) {
-        try {
-            this.savedModelConfig = savedModelConfig;
-            this.protoBufConfigProto = sessionOptionsConfiguration;
+            this.inputDataTypes = inputDataTypes;
+            this.outputDataTypes = outputDataTypes;
             //note that the input and output order, maybe null here
             //if the names are specified, we should defer to those instead
             this.inputOrder = inputNames;
             this.outputOrder = outputNames;
             initOptionsIfNeeded();
-            Map inputsMap = new LinkedHashMap();
-            Map outputsMap = new LinkedHashMap();
-            this.graph =
TF_NewGraph();
-            this.session = conversion.loadSavedModel(savedModelConfig, options, null, graph, inputsMap, outputsMap, status);
-            inputOrder = new ArrayList(inputsMap.keySet());
-            outputOrder = new ArrayList(outputsMap.keySet());
-            savedModelConfig.setSavedModelInputOrder(new ArrayList(inputsMap.values()));
-            savedModelConfig.setSaveModelOutputOrder(new ArrayList(outputsMap.values()));
+
+            if(graph != null) {
+                this.graph = graph;
+            }
+            else if(graphBytes != null) {
+                this.graph = conversion.loadGraph(graphBytes, status);
+            }
+            else if(graphPath != null) {
+                graphBytes = IOUtils.toByteArray(graphPath.toURI());
+                this.graph = conversion.loadGraph(graphBytes, status);
+            }
+            else
+                this.graph = TF_NewGraph();
+
+            if(savedModelConfig != null) {
+                this.savedModelConfig = savedModelConfig;
+                Map<String,String> inputsMap = new LinkedHashMap<>();
+                Map<String,String> outputsMap = new LinkedHashMap<>();
+
+                this.session = conversion.loadSavedModel(savedModelConfig, options, null, this.graph, inputsMap, outputsMap, status);
+
+                if(inputOrder == null || inputOrder.isEmpty())
+                    inputOrder = new ArrayList<>(inputsMap.values());
+                if(outputOrder == null || outputOrder.isEmpty())
+                    outputOrder = new ArrayList<>(outputsMap.values());
+
+                savedModelConfig.setSavedModelInputOrder(new ArrayList<>(inputsMap.values()));
+                savedModelConfig.setSaveModelOutputOrder(new ArrayList<>(outputsMap.values()));
+                log.info("Loaded input names from saved model configuration " + inputOrder);
+                log.info("Loaded output names from saved model configuration " + outputOrder);
+
+            }
+
+
+            initSessionAndStatusIfNeeded(graphBytes);
         } catch (Exception e) {
             throw new IllegalArgumentException("Unable to parse protobuf",e);
         }
@@ -235,135 +178,249 @@
-
-    /**
-     * Pass in a graph instance and
-     * the length of the protobuf
-     * that it was instantiated with.
-     * For files this is typically
-     * {@link File#length()},
-     * for byte arrays, this is
-     * byte array.length
-     * and for {@link java.nio.ByteBuffer}
-     * this would be something like the
-     * {@link java.nio.ByteBuffer#capacity()}
-     * @param graph a pointer to the {@link TF_Graph} to use when executing
-     * @param graphDef {@link org.tensorflow.framework.GraphDef} protobuf
-     *                 definition containing
-     *                 the graph configuration
-     *                 for automatically inferring
-     *                 things like
-     *                 graph inputs and outputs
+     * Cast inputs from their original data types
+     * to the target input data types.
+     * This is for when the actual input data types do not match
+     * the expected input data types; a pre-cast runs automatically.
+     * @param inputs the inputs to cast
+     * @return the recast inputs
      */
-    public GraphRunner(List inputNames,TF_Graph graph,org.tensorflow.framework.GraphDef graphDef) {
-        this(inputNames,null,graph,graphDef,null);
-
-    }
-
-    /**
-     * Pass in a graph instance and
-     * the length of the protobuf
-     * that it was instantiated with.
-     * For files this is typically
-     * {@link File#length()},
-     * for byte arrays, this is
-     * byte array.length
-     * and for {@link java.nio.ByteBuffer}
-     * this would be something like the
-     * {@link java.nio.ByteBuffer#capacity()}
-     * @param graph a pointer to the {@link TF_Graph} to use when executing
-     * @param graphDef {@link org.tensorflow.framework.GraphDef} protobuf
-     * definition containing
-     * the graph configuration
-     * for automatically inferring
-     * things like
-     * graph inputs and outputs
-     * @param configProto the session configuration proto to use with this runner
-     */
-    public GraphRunner(List inputNames,TF_Graph graph,org.tensorflow.framework.GraphDef graphDef,ConfigProto configProto) {
-        this(inputNames,null,graph,graphDef,configProto);
-
-    }
-
-    /**
-     * Initialize with the graph content to use
-     * @param inputNames the inputs to the graph
-     * @param graphToUse the raw byte content
-     * of a protobuf file saved by tensorflow
-     */
-    public GraphRunner(byte[] graphToUse,List inputNames) {
-        this(graphToUse,inputNames,getAlignedWithNd4j());
+    public Map recastInputs(Map inputs) {
+        return recastInputs(inputs,inputOrder,inputDataTypes);
    }

    /**
-     * Initialize with the graph content to use
-     * @param filePath path of a protobuf file saved by tensorflow
-     * @param inputNames the input names for the graph
+     * Cast outputs from their original data types
+     * to the configured output data types.
+     * This is for when the produced outputs do not match
+     * the expected output data types; the cast is run automatically.
+     * @param inputs the output tensors to cast
+     * @return the recast outputs
     */
-    public GraphRunner(String filePath,List inputNames) {
-        this(filePath,inputNames,getAlignedWithNd4j());
-    }
-
-
-
-    /**
-     * Initialize with the graph content to use
-     * @param filePath path of a protobuf file saved by tensorflow
-     * @param inputNames the names of the inputs for the graph
-     * @param sessionOptionsConfiguration the session options to use
-     * for running sessions
-     */
-    public GraphRunner(String filePath,List inputNames,org.tensorflow.framework.ConfigProto sessionOptionsConfiguration) {
-        this(filePath,inputNames,null,sessionOptionsConfiguration);
-    }
-
-    /**
-     * Initialize with the graph content to use
-     * @param graphToUse the raw byte content
-     * of a protobuf file saved by tensorflow
-     * @param sessionOptionsConfiguration the session options to use
-     * for running sessions
-     */
-    public GraphRunner(byte[] graphToUse,List inputNames,org.tensorflow.framework.ConfigProto sessionOptionsConfiguration) {
-        this(graphToUse,inputNames,null,sessionOptionsConfiguration);
+    public Map recastOutputs(Map inputs) {
+        return recastInputs(inputs,outputOrder,outputDataTypes);
    }

    /**
-     * Initialize with the SavedModel to use
-     * @param savedModelConfig the configuration for loading the saved model
+     * Automatically recast the input arrays
+     * as the specified types
+     * @param inputs the input tensors to recast
+     * @param inputOrder the order of the input tensors
+     * @param inputDataTypes the data types to cast to (null means stay the same)
+     * @return the new values
     */
-    public GraphRunner(SavedModelConfig savedModelConfig) {
-        this(savedModelConfig,getAlignedWithNd4j());
+    public Map recastInputs(Map inputs, List inputOrder, Map inputDataTypes) {
+        if(inputDataTypes == null || inputDataTypes.isEmpty()) {
+
+            inputDataTypes = new LinkedHashMap<>();
+            for(int i = 0; i < inputOrder.size(); i++) {
+                TensorDataType tensorDataType = TensorDataType.values()[TF_TensorType(inputs.get(inputOrder.get(i)))];
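+                //note: the ordinal lookup above assumes the TensorDataType enum is declared
+                //in the same order as the native TF_DataType codes returned by TF_TensorType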
+                Preconditions.checkNotNull(tensorDataType,"Data type of " + TF_TensorType(inputs.get(inputOrder.get(i))) + " was null!");
+                inputDataTypes.put(inputOrder.get(i),tensorDataType);
+            }
+        }
+
+        Map ret = new HashMap<>();
+        for(int i = 0; i < inputOrder.size(); i++) {
+            TF_Tensor currInput = inputs.get(inputOrder.get(i));
+            TensorDataType fromDType = TensorDataType.values()[TF_TensorType(currInput)];
+            if(fromDType != inputDataTypes.get(inputOrder.get(i))) {
+                TF_Tensor oldTensor = currInput;
+                currInput = castTensor(currInput, fromDType, inputDataTypes.get(inputOrder.get(i)));
+                TF_DeleteTensor(oldTensor);
+            }
+
+            ret.put(inputOrder.get(i),currInput);
+        }
+
+        return ret;
    }

    /**
-     * Initialize with the SavedModel to use
-     * @param savedModelConfig the configuration for loading the saved model
-     * @param sessionOptionsConfiguration the session options to use
-     * for running sessions
+     * Run the graph definition with the given inputs
+     * in native tensorflow
+     * @param inputs the inputs to run
+     * @return the outputs from the native tensorflow wrapper
     */
-    public GraphRunner(SavedModelConfig savedModelConfig, ConfigProto sessionOptionsConfiguration) {
-        try {
-            this.savedModelConfig = savedModelConfig;
-            this.protoBufConfigProto = sessionOptionsConfiguration;
-            initOptionsIfNeeded();
-            Map inputsMap = new LinkedHashMap<>();
-            Map outputsMap = new LinkedHashMap<>();
-            this.graph = TF_NewGraph();
-            this.session = conversion.loadSavedModel(savedModelConfig, options, null, graph, inputsMap, outputsMap, status);
-            inputOrder = new ArrayList<>(inputsMap.keySet());
-            outputOrder = new ArrayList<>(outputsMap.keySet());
-            savedModelConfig.setSavedModelInputOrder(new ArrayList<>(inputsMap.values()));
-            savedModelConfig.setSaveModelOutputOrder(new ArrayList<>(outputsMap.values()));
-        } catch (Exception e) {
-            throw new IllegalArgumentException("Unable to parse protobuf",e);
+    public Map runTfTensor(Map inputs) {
+        if(graph == null) {
+            throw new IllegalStateException("Graph not initialized.");
+        }
+
+
+        if(inputs.size() != inputOrder.size()) {
+            throw new IllegalArgumentException("Number of inputs specified does not match number of arrays specified.");
+        }
+
+        if(inputDataTypes == null) {
+            inputDataTypes = new LinkedHashMap<>();
+            for(int i = 0; i < inputOrder.size(); i++) {
+                inputDataTypes.put(inputOrder.get(i),TensorDataType.values()[TF_TensorType(inputs.get(inputOrder.get(i)))]);
+            }
+        }
+
+        for(Map.Entry entry : inputs.entrySet()) {
+            Preconditions.checkNotNull(entry.getValue(),"Entry " + entry.getKey() + " was null!");
+        }
+
+        //recast inputs where they don't match the expected input data types
+        inputs = recastInputs(inputs);
+
+
+        if(savedModelConfig != null) {
+            Map outputArrays = new LinkedHashMap<>();
+
+            Map opsByName = new HashMap<>();
+            org.bytedeco.tensorflow.TF_Output inputOut = new org.bytedeco.tensorflow.TF_Output(savedModelConfig.getSavedModelInputOrder().size());
+
+            TF_Tensor[] inputTensors = new TF_Tensor[savedModelConfig.getSavedModelInputOrder().size()];
+            for(int i = 0; i < savedModelConfig.getSavedModelInputOrder().size(); i++) {
+                String[] name = savedModelConfig.getSavedModelInputOrder().get(i).split(":");
+                org.bytedeco.tensorflow.TF_Operation inputOp = TF_GraphOperationByName(graph, name[0]);
+                opsByName.put(savedModelConfig.getSavedModelInputOrder().get(i),inputOp);
+                inputOut.position(i).oper(inputOp).index(name.length > 1 ? Integer.parseInt(name[1]) : 0);
+                TF_Tensor tfTensor = inputs.get(inputOrder != null && !inputOrder.isEmpty()
+                        ? inputOrder.get(i) : savedModelConfig.getSavedModelInputOrder().get(i));
+                inputTensors[i] = tfTensor;
+            }
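+            //note: inputs above are resolved by the caller-supplied names where present,
+            //falling back to the saved model's recorded input order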
+
+            //reset the position of the pointer for execution
+            inputOut.position(0);
+
+            org.bytedeco.tensorflow.TF_Output outputOut = new org.bytedeco.tensorflow.TF_Output(savedModelConfig.getSaveModelOutputOrder().size());
+            //only setup the output ops
+            for(int i = 0; i < savedModelConfig.getSaveModelOutputOrder().size(); i++) {
+                String[] name = savedModelConfig.getSaveModelOutputOrder().get(i).split(":");
+                org.bytedeco.tensorflow.TF_Operation outputOp = TF_GraphOperationByName(graph, name[0]);
+                opsByName.put(savedModelConfig.getSaveModelOutputOrder().get(i),outputOp);
+                outputOut.position(i).oper(outputOp).index(name.length > 1 ? Integer.parseInt(name[1]) : 0);
+            }
+
+            //reset the position of the pointer for execution
+            outputOut.position(0);
+
+
+            //these are references to the nd4j ndarrays wrapped for tensorflow
+            PointerPointer inputTensorsPointer = new PointerPointer<>(inputTensors);
+            //note that these are the result pointers
+            //the result pointers are null, and will be populated automatically by the session run
+            PointerPointer outputTensorsPointer = new PointerPointer<>(savedModelConfig.getSaveModelOutputOrder().size());
+
+            long start = System.nanoTime();
+            TF_SessionRun(
+                    session,
+                    null,
+                    //inputs
+                    inputOut, inputTensorsPointer, inputTensors.length,
+                    //outputs
+                    outputOut, outputTensorsPointer, savedModelConfig.getSaveModelOutputOrder().size(),
+                    //targets
+                    null, 0,
+                    null,
+                    status);
+            long end = System.nanoTime();
+            long diff = TimeUnit.NANOSECONDS.toMillis((end - start));
+            log.debug("Session runtime: {} ms", diff);
+
+
+            if (TF_GetCode(status) != TF_OK) {
+                throw new IllegalStateException("ERROR: Unable to run session " + TF_Message(status).getString());
+            } else {
+                for(int i = 0; i < outputOrder.size(); i++) {
+                    outputArrays.put(outputOrder != null && !outputOrder.isEmpty() ? outputOrder.get(i) :
+                            savedModelConfig.getSaveModelOutputOrder().get(i),new TF_Tensor(outputTensorsPointer.get(i)));
+                }
+
+            }
+
+            return outputArrays;
+
+        }
+        else {
+            Map outputArrays = new LinkedHashMap<>();
+
+            Map opsByName = new HashMap<>();
+            org.bytedeco.tensorflow.TF_Output inputOut = new org.bytedeco.tensorflow.TF_Output(inputOrder.size());
+
+            TF_Tensor[] inputTensors = new TF_Tensor[inputOrder.size()];
+            for(int i = 0; i < inputOrder.size(); i++) {
+                String[] name = inputOrder.get(i).split(":");
+                org.bytedeco.tensorflow.TF_Operation inputOp = TF_GraphOperationByName(graph, name[0]);
+                opsByName.put(inputOrder.get(i),inputOp);
+                inputOut.position(i).oper(inputOp).index(name.length > 1 ? Integer.parseInt(name[1]) : 0);
+                TF_Tensor tf_tensor = inputs.get(inputOrder.get(i));
+
+                inputTensors[i] = tf_tensor;
+            }
+
+
+            //reset the position of the pointer for execution
+            inputOut.position(0);
+
+            org.bytedeco.tensorflow.TF_Output outputOut = new org.bytedeco.tensorflow.TF_Output(outputOrder.size());
+            //only setup the output ops
+            for(int i = 0; i < outputOrder.size(); i++) {
+                String[] name = outputOrder.get(i).split(":");
+                org.bytedeco.tensorflow.TF_Operation outputOp = TF_GraphOperationByName(graph, name[0]);
+                if(outputOp == null) {
+                    throw new IllegalArgumentException("Illegal output found " + outputOrder.get(i) + " - no op found! Misspecified name perhaps?");
+                }
+
+                opsByName.put(outputOrder.get(i),outputOp);
+                outputOut.position(i).oper(outputOp).index(name.length > 1 ? Integer.parseInt(name[1]) : 0);
+            }
+
+            //reset the position of the pointer for execution
+            outputOut.position(0);
+
+
+            //these are references to the nd4j ndarrays wrapped for tensorflow
+            PointerPointer inputTensorsPointer = new PointerPointer<>(inputTensors);
+            //note that these are the result pointers
+            //the result pointers are null, and will be populated automatically by the session run
+            PointerPointer outputTensorsPointer = new PointerPointer<>(outputOrder.size());
+
+            long start = System.nanoTime();
+            TF_SessionRun(
+                    session,
+                    null,
+                    //inputs
+                    inputOut, inputTensorsPointer, inputOrder.size(),
+                    //output
+                    outputOut, outputTensorsPointer, outputOrder.size(),
+                    //targets
+                    null, 0,
+                    null,
+                    status);
+            long end = System.nanoTime();
+            long diff = TimeUnit.NANOSECONDS.toMillis((end - start));
+            log.debug("Session runtime: {} ms", diff);
+
+
+            if (TF_GetCode(status) != TF_OK) {
+                throw new IllegalStateException("ERROR: Unable to run session " + TF_Message(status).getString());
+            } else {
+                for(int i = 0; i < outputOrder.size(); i++) {
+                    outputArrays.put(outputOrder.get(i),new TF_Tensor(outputTensorsPointer.get(i)));
+                }
+            }
+
+            return outputArrays;
+        }
+    }

-
    /**
     * Returns a map of the output names
     * to the ndarrays matching each output.
     *
@@ -382,159 +439,25 @@ public class GraphRunner implements Closeable {
     * {@link INDArray}
     * @return a map of the output names to the
     * ndarrays matching each output specified in the graph
-     * @throws IOException
     */
    public Map run(Map inputs) {
-        if(graph == null) {
-            throw new IllegalStateException("Graph not initialized.");
+        if (!isTfWarmedUp && !isTfWarmingUp){
+            isTfWarmingUp = true;
+            run(inputs);
+            isTfWarmedUp = true;
+        }
+        Map inputTensors = new LinkedHashMap<>();
+        for(Map.Entry input : inputs.entrySet()) {
+            inputTensors.put(input.getKey(),conversion.tensorFromNDArray(input.getValue()));
        }

-        if(inputs.size() != inputOrder.size()) {
-            throw new IllegalArgumentException("Number of inputs specified do not match number of arrays specified.");
-        }
-
-
-        if(savedModelConfig != null) {
-            Map outputArrays = new LinkedHashMap<>();
-
-            Map opsByName = new HashMap<>();
-            TF_Output inputOut = new TF_Output(savedModelConfig.getSavedModelInputOrder().size());
-
-            TF_Tensor[] inputTensors = new TF_Tensor[savedModelConfig.getSavedModelInputOrder().size()];
-            for(int i = 0; i < savedModelConfig.getSavedModelInputOrder().size(); i++) {
-                String[] name = savedModelConfig.getSavedModelInputOrder().get(i).split(":");
-                TF_Operation inputOp = TF_GraphOperationByName(graph, name[0]);
-                opsByName.put(savedModelConfig.getSavedModelInputOrder().get(i),inputOp);
-                inputOut.position(i).oper(inputOp).index(name.length > 1 ? Integer.parseInt(name[1]) : 0);
-                TF_Tensor tf_tensor = conversion.tensorFromNDArray(inputs.get(inputOrder != null && !inputOrder.isEmpty()
-                        ? inputOrder.get(i) : savedModelConfig.getSavedModelInputOrder().get(i)));
-                inputTensors[i] = tf_tensor;
-            }
-
-
-            //reset the position of the pointer for execution
-            inputOut.position(0);
-
-            TF_Output outputOut = new TF_Output(savedModelConfig.getSaveModelOutputOrder().size());
-            //only setup the output ops
-            for(int i = 0; i < savedModelConfig.getSaveModelOutputOrder().size(); i++) {
-                String[] name =savedModelConfig.getSaveModelOutputOrder().get(i).split(":");
-                TF_Operation outputOp = TF_GraphOperationByName(graph, name[0]);
-                opsByName.put(savedModelConfig.getSaveModelOutputOrder().get(i),outputOp);
-                outputOut.position(i).oper(outputOp).index(name.length > 1 ? Integer.parseInt(name[1]) : 0);
-            }
-
-            //reset the position of the pointer for execution
-            outputOut.position(0);
-
-
-
-            //these are references to the nd4j ndarrays wrapped for tensorflow
-            PointerPointer inputTensorsPointer = new PointerPointer<>(inputTensors);
-            //note that these are the result pointers
-            //the result pointers are null, and will be populated automatically by the session run
-            PointerPointer outputTensorsPointer = new PointerPointer<>(savedModelConfig.getSaveModelOutputOrder().size());
-
-
-            TF_SessionRun(
-                    session,
-                    null,
-                    //inputs
-                    inputOut, inputTensorsPointer, inputTensors.length,
-                    //outputs
-                    outputOut, outputTensorsPointer, savedModelConfig.getSaveModelOutputOrder().size(),
-                    //targets
-                    null, 0,
-                    null,
-                    status);
-
-
-            if (TF_GetCode(status) != TF_OK) {
-                throw new IllegalStateException("ERROR: Unable to run session " + TF_Message(status).getString());
-            } else {
-                for(int i = 0; i < outputOrder.size(); i++) {
-                    INDArray to = conversion.ndArrayFromTensor(new TF_Tensor(outputTensorsPointer.get(i)));
-                    outputArrays.put(outputOrder != null && !outputOrder.isEmpty() ? outputOrder.get(i) :
-                            savedModelConfig.getSaveModelOutputOrder().get(i),to);
-                }
-
-            }
-
-            return outputArrays;
-
-        }
-        else {
-            Map outputArrays = new LinkedHashMap<>();
-
-            Map opsByName = new HashMap<>();
-            TF_Output inputOut = new TF_Output(inputOrder.size());
-
-            TF_Tensor[] inputTensors = new TF_Tensor[inputOrder.size()];
-            for(int i = 0; i < inputOrder.size(); i++) {
-                String[] name = inputOrder.get(i).split(":");
-                TF_Operation inputOp = TF_GraphOperationByName(graph, name[0]);
-                opsByName.put(inputOrder.get(i),inputOp);
-                inputOut.position(i).oper(inputOp).index(name.length > 1 ? Integer.parseInt(name[1]) : 0);
-                TF_Tensor tf_tensor = conversion.tensorFromNDArray(inputs.get(inputOrder.get(i)));
-                inputTensors[i] = tf_tensor;
-            }
-
-
-            //reset the position of the pointer for execution
-            inputOut.position(0);
-
-            TF_Output outputOut = new TF_Output(outputOrder.size());
-            //only setup the output ops
-            for(int i = 0; i < outputOrder.size(); i++) {
-                String[] name = outputOrder.get(i).split(":");
-                TF_Operation outputOp = TF_GraphOperationByName(graph, name[0]);
-                if(outputOp == null) {
-                    throw new IllegalArgumentException("Illegal input found " + inputOrder.get(i) + " - no op found! Mis specified name perhaps?");
-                }
-
-                opsByName.put(outputOrder.get(i),outputOp);
-                outputOut.position(i).oper(outputOp).index(name.length > 1 ? Integer.parseInt(name[1]) : 0);
-            }
-
-            //reset the position of the pointer for execution
-            outputOut.position(0);
-
-
-
-            //these are references to the nd4j ndarrays wrapped for tensorflow
-            PointerPointer inputTensorsPointer = new PointerPointer<>(inputTensors);
-            //note that these are the result pointers
-            //the result pointers are null, and will be populated automatically by the session run
-            PointerPointer outputTensorsPointer = new PointerPointer<>(outputOrder.size());
-
-
-            TF_SessionRun(
-                    session,
-                    null,
-                    //inputs
-                    inputOut, inputTensorsPointer, inputTensors.length,
-                    //outputs
-                    outputOut, outputTensorsPointer, outputOrder.size(),
-                    //targets
-                    null, 0,
-                    null,
-                    status);
-
-
-            if (TF_GetCode(status) != TF_OK) {
-                throw new IllegalStateException("ERROR: Unable to run session " + TF_Message(status).getString());
-            } else {
-                for(int i = 0; i < outputOrder.size(); i++) {
-                    INDArray to = conversion.ndArrayFromTensor(new TF_Tensor(outputTensorsPointer.get(i)));
-                    outputArrays.put(outputOrder.get(i),to);
-                }
-
-            }
-
-            return outputArrays;
-
+        Map outputTensors = runTfTensor(inputTensors);
+        Map output = new LinkedHashMap<>();
+        for(Map.Entry outputTensor : outputTensors.entrySet()) {
+            output.put(outputTensor.getKey(),conversion.ndArrayFromTensor(outputTensor.getValue()));
        }
+        return output;
    }
@@ -546,8 +469,8 @@ public class GraphRunner implements Closeable {
        if (options == null) {
            options = TF_NewSessionOptions();
-            if(protoBufConfigProto != null) {
-                BytePointer bytePointer = new BytePointer(protoBufConfigProto.toByteArray());
+            if(sessionOptionsConfigProto != null) {
+                BytePointer bytePointer = new BytePointer(sessionOptionsConfigProto.toByteArray());
                TF_SetConfig(options,bytePointer,bytePointer.getStringBytes().length,status);
                if (TF_GetCode(status) != TF_OK) {
                    throw new IllegalStateException("ERROR: Unable to set value configuration:" + TF_Message(status).getString());
@@ -557,7 +480,7 @@ public class GraphRunner implements Closeable {
    }

    private void initSessionAndStatusIfNeeded(org.tensorflow.framework.GraphDef graphDef1) {
-        //infer the inputs and outputs for the graph
+        //infer the inputs and outputs for the graph
        Set seenAsInput = new LinkedHashSet<>();
        for(int i = 0; i < graphDef1.getNodeCount(); i++) {
            NodeDef node = graphDef1.getNode(i);
@@ -569,7 +492,7 @@ public class GraphRunner implements Closeable {
        if(outputOrder == null) {
            outputOrder = new ArrayList<>();
            log.trace("Attempting to automatically resolve tensorflow output names..");
-            //find the nodes that were not inputs to any nodes: these are the outputs
+            //find the nodes that were not inputs to any nodes: these are the outputs
            for(int i = 0; i < graphDef1.getNodeCount(); i++) {
                if(!seenAsInput.contains(graphDef1.getNode(i).getName()) && !graphDef1.getNode(i).getOp().equals("Placeholder")) {
                    outputOrder.add(graphDef1.getNode(i).getName());
@@ -604,37 +527,20 @@
    }

    private void initSessionAndStatusIfNeeded(byte[] graphToUse) {
+        if(graphToUse == null) {
+            //nothing to do: for saved models the session was already created from the SavedModel
+            return;
+        }
+
        try {
            //use the protobuf api to load the graph definition and load the node metadata
            org.tensorflow.framework.GraphDef graphDef1 = org.tensorflow.framework.GraphDef.parseFrom(graphToUse);
            initSessionAndStatusIfNeeded(graphDef1);
-        } catch (InvalidProtocolBufferException e) {
+        } catch (org.nd4j.shade.protobuf.InvalidProtocolBufferException e) {
            e.printStackTrace();
        }
    }

-    public static org.tensorflow.framework.ConfigProto getAlignedWithNd4j() {
-        org.tensorflow.framework.ConfigProto configProto = org.tensorflow.framework.ConfigProto.getDefaultInstance();
-        ConfigProto.Builder builder1 = configProto.toBuilder().addDeviceFilters(TensorflowConversion.defaultDeviceForThread());
-        try {
-            //cuda
-            if(Nd4j.getBackend().getClass().getName().toLowerCase().contains("jcu")) {
-                builder1.setGpuOptions(GPUOptions.newBuilder()
-                        .setAllowGrowth(true)
-                        .setPerProcessGpuMemoryFraction(0.5)
-                        .build());
-            }
-            //cpu
-            else {
-            }
-
-        } catch (Exception e) {
-            e.printStackTrace();
-        }
-
-        return builder1.build();
-    }
-

    /**
     * Convert a json string written out
     */
    public static org.tensorflow.framework.ConfigProto fromJson(String json) {
        org.tensorflow.framework.ConfigProto.Builder builder = org.tensorflow.framework.ConfigProto.newBuilder();
        try {
-            JsonFormat.parser().merge(json,builder);
+            org.nd4j.shade.protobuf.util.JsonFormat.parser().merge(json,builder);
            org.tensorflow.framework.ConfigProto build = builder.build();
-            ByteString serialized = build.toByteString();
+            org.nd4j.shade.protobuf.ByteString serialized = build.toByteString();
            byte[] binaryString = serialized.toByteArray();
            org.tensorflow.framework.ConfigProto configProto = org.tensorflow.framework.ConfigProto.parseFrom(binaryString);
            return configProto;
@@ -661,15 +567,76 @@

    /**
-     * Write out the session options used
-     * by this {@link GraphRunner}
-     * a s a json string using the
-     * {@link JsonFormat}
-     * @return
+     * Cast a tensor to another type using
+     * the tensorflow c api.
+     * This method loads a graph from the classpath at
+     * cast_graph/cast_(from type)_(to type).pb (python-style,
+     * lower case type names), which contains a simple graph
+     * with a tensorflow placeholder named 'input' of the source
+     * data type and an output named 'cast_output'.
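+     * A minimal usage sketch (hypothetical {@code someFloatTensor} variable; assumes the
+     * {@link TensorDataType} enum exposes FLOAT and DOUBLE constants):
+     * <pre>
+     * TF_Tensor asDouble = GraphRunner.castTensor(someFloatTensor, TensorDataType.FLOAT, TensorDataType.DOUBLE);
+     * </pre>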
+     * @param input the input data
+     * @param from the input data type to cast from
+     * @param to the output data type to cast to
+     * @return the cast tensor
     */
-    public String sessionOptionsToJson() {
+    public static TF_Tensor castTensor(TF_Tensor input, TensorDataType from, TensorDataType to) {
+        if(from.equals(to))
+            return input;
+
+        Map inputMap = new HashMap<>();
+        inputMap.put("input",input);
+        GraphRunner graphRunner = getRunner(from,to);
        try {
-            return JsonFormat.printer().print(protoBufConfigProto);
+            Map output = graphRunner.runTfTensor(inputMap);
+            return output.get("cast_output");
+
+        } catch(Exception e) {
+            throw new IllegalStateException("Unable to run graph",e);
+        }
+    }
+
+    private static GraphRunner getRunner(TensorDataType from,TensorDataType to) {
+        Pair key = Pair.of(from,to);
+        if(!recastGraphDefs.containsKey(key)) {
+            byte[] graphForDataType = graphForDataType(from,to);
+            GraphRunner graphRunner = GraphRunner.builder()
+                    .graphBytes(graphForDataType)
+                    .inputNames(Arrays.asList("input"))
+                    .outputNames(Arrays.asList("cast_output"))
+                    .build();
+
+            recastGraphDefs.put(key,graphRunner);
+            return graphRunner;
+        }
+
+        return recastGraphDefs.get(key);
+    }
+
+
+    private static byte[] graphForDataType(TensorDataType from,TensorDataType to) {
+        ClassPathResource classPathResource = new ClassPathResource("cast_graph/cast_" + TensorDataType.toPythonName(from) + "_" + TensorDataType.toPythonName(to) + ".pb");
+        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
+        try (InputStream is = classPathResource.getInputStream()) {
+            IOUtils.copy(is, byteArrayOutputStream);
+        } catch (IOException e) {
+            throw new IllegalStateException("Unable to read graph " + classPathResource.getFilename(),e);
+        }
+
+        return byteArrayOutputStream.toByteArray();
+    }
+
+    /**
+     * Write out the session options used
+     * by this {@link org.nd4j.tensorflow.conversion.graphrunner.GraphRunner}
+     * as a json string using the
+     * {@link org.nd4j.shade.protobuf.util.JsonFormat}
+     * @return the session options as json (mainly for debugging)
+     */
+    public String sessionOptionsToJson() {
+        if(sessionOptionsConfigProto == null)
+            return null;
+        try {
+            return org.nd4j.shade.protobuf.util.JsonFormat.printer().print(sessionOptionsConfigProto);
        } catch (Exception e) {
            e.printStackTrace();
        }
@@ -695,4 +662,25 @@
            TF_DeleteStatus(status);
        }
    }
+    public static org.tensorflow.framework.ConfigProto getAlignedWithNd4j() {
+        org.tensorflow.framework.ConfigProto configProto = org.tensorflow.framework.ConfigProto.getDefaultInstance();
+        ConfigProto.Builder builder1 = configProto.toBuilder().addDeviceFilters(TensorflowConversion.defaultDeviceForThread());
+        try {
+            //cuda
+            if(Nd4j.getBackend().getClass().getName().toLowerCase().contains("jcu")) {
+                builder1.setGpuOptions(org.tensorflow.framework.GPUOptions.newBuilder()
+                        .setAllowGrowth(true)
+                        .setPerProcessGpuMemoryFraction(0.5)
+                        .build());
+            }
+            //cpu
+            else {
+            }
+
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+
+        return builder1.build();
+    }
}
diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/ai/__init__.py b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/ai/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/ai/konduit/__init__.py b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/ai/konduit/__init__.py
new file mode 100644
index 000000000..e69de29bb
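For reference, a minimal usage sketch of the builder-based construction introduced above, assuming a frozen GraphDef saved at a hypothetical model.pb with one input tensor named "x" and one output named "y" (the file name and tensor names are illustrative placeholders, not part of this patch):

    import java.io.File;
    import java.util.Arrays;
    import java.util.Collections;
    import java.util.Map;

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.tensorflow.conversion.graphrunner.GraphRunner;

    public class GraphRunnerExample {
        public static void main(String[] args) throws Exception {
            //build a runner from a frozen GraphDef on disk; GraphRunner implements Closeable,
            //so try-with-resources releases the native graph/session/status handles
            try (GraphRunner runner = GraphRunner.builder()
                    .graphPath(new File("model.pb"))
                    .inputNames(Arrays.asList("x"))
                    .outputNames(Arrays.asList("y"))
                    .build()) {
                //inputs and outputs are keyed by tensor name
                Map<String, INDArray> outputs = runner.run(
                        Collections.singletonMap("x", Nd4j.linspace(1, 4, 4)));
                System.out.println(outputs.get("y"));
            }
        }
    }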
diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/ai/konduit/casting.py b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/ai/konduit/casting.py
new file mode 100644
index 000000000..b6123f557
--- /dev/null
+++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/ai/konduit/casting.py
@@ -0,0 +1,17 @@
+import tensorflow as tf
+
+dtypes = [tf.float16,tf.float32,tf.float64,tf.int8,tf.int16,tf.int32,tf.int64,tf.uint8,tf.uint16,tf.uint32,tf.uint64]
+# Quick solution from https://stackoverflow.com/questions/5360220/how-to-split-a-list-into-pairs-in-all-possible-ways :)
+def all_pairs(lst):
+    return [(x,y) for x in lst for y in lst]
+
+
+for item in all_pairs(dtypes):
+    from_dtype, out_dtype = item
+    tf.reset_default_graph()
+    input = tf.placeholder(name='input',dtype=from_dtype)
+    result = tf.cast(input,name='cast_output',dtype=out_dtype)
+
+    with tf.Session() as session:
+        # write the binary GraphDef (the checked-in .pb resources are binary, not pbtxt)
+        tf.train.write_graph(tf.get_default_graph(),logdir='.',name='cast_' + from_dtype.name + '_' + out_dtype.name + '.pb',as_text=False)
\ No newline at end of file
diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_float16.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_float16.pb
new file mode 100644
index 000000000..6ed3318d5
--- /dev/null
+++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_float16.pb
@@ -0,0 +1,5 @@
+ +0 +input Placeholder* +shape:* +dtype0" \ No newline at end of file
diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_float32.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_float32.pb
new file mode 100644
index 000000000..909d8f5ab
--- /dev/null
+++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_float32.pb
@@ -0,0 +1,11 @@
+ +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file
diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_float64.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_float64.pb
new file mode 100644
index 000000000..0b51cb893
--- /dev/null
+++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_float64.pb
@@ -0,0 +1,11 @@
+ +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file
diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_int16.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_int16.pb
new file mode 100644
index 000000000..ac8d85756
--- /dev/null
+++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_int16.pb
@@ -0,0 +1,11 @@
+ +0 +input Placeholder* +shape:* +dtype0 +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file
diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_int32.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_int32.pb
new file mode 100644
index 000000000..4f057d6ef
--- /dev/null
+++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_int32.pb
@@ -0,0 +1,11 @@
+ +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file
diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_int64.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_int64.pb
new file mode 100644
index 000000000..92e8fc124
--- /dev/null
+++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_int64.pb
@@ -0,0 +1,11 @@
+ +0 +input
Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +SrcT0* + +DstT0 " \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_int8.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_int8.pb new file mode 100644 index 000000000..4317504d5 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_int8.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_uint16.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_uint16.pb new file mode 100644 index 000000000..8f2f205ce --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_uint16.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +shape:* +dtype0 +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_uint32.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_uint32.pb new file mode 100644 index 000000000..0825b92ad --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_uint32.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_uint64.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_uint64.pb new file mode 100644 index 000000000..2a3284c7b --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_uint64.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +DstT0* + +SrcT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_uint8.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_uint8.pb new file mode 100644 index 000000000..f06af1592 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float16_uint8.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +DstT0* + +SrcT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_float16.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_float16.pb new file mode 100644 index 000000000..4a947478d --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_float16.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +DstT0* + +SrcT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_float32.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_float32.pb new file mode 100644 index 000000000..843c041e3 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_float32.pb @@ -0,0 +1,5 @@ + +0 +input Placeholder* +shape:* +dtype0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_float64.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_float64.pb new file mode 100644 index 000000000..d65a4b1c5 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_float64.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +shape:* +dtype0 +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No 
newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_int16.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_int16.pb new file mode 100644 index 000000000..32336df2e --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_int16.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +DstT0* + +SrcT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_int32.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_int32.pb new file mode 100644 index 000000000..e94d38894 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_int32.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_int64.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_int64.pb new file mode 100644 index 000000000..721274c2a --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_int64.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +shape:* +dtype0 +2 + cast_outputCastinput* + +SrcT0* + +DstT0 " \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_int8.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_int8.pb new file mode 100644 index 000000000..00463da39 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_int8.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_uint16.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_uint16.pb new file mode 100644 index 000000000..5428c9239 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_uint16.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +shape:* +dtype0 +2 + cast_outputCastinput* + +DstT0* + +SrcT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_uint32.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_uint32.pb new file mode 100644 index 000000000..45a199fd7 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_uint32.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +shape:* +dtype0 +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_uint64.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_uint64.pb new file mode 100644 index 000000000..4c6fd3bee --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_uint64.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +shape:* +dtype0 +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_uint8.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_uint8.pb new file mode 100644 index 000000000..653f37625 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float32_uint8.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git 
a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_float16.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_float16.pb new file mode 100644 index 000000000..7ab02b8c7 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_float16.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_float32.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_float32.pb new file mode 100644 index 000000000..3678a0f7e --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_float32.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +DstT0* + +SrcT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_float64.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_float64.pb new file mode 100644 index 000000000..96551b5f2 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_float64.pb @@ -0,0 +1,5 @@ + +0 +input Placeholder* +dtype0* +shape:" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_int16.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_int16.pb new file mode 100644 index 000000000..2290faa8e --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_int16.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +DstT0* + +SrcT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_int32.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_int32.pb new file mode 100644 index 000000000..5f4880ac4 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_int32.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_int64.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_int64.pb new file mode 100644 index 000000000..d1e43b903 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_int64.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +DstT0 * + +SrcT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_int8.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_int8.pb new file mode 100644 index 000000000..c5e83bcbf --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_int8.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +shape:* +dtype0 +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_uint16.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_uint16.pb new file mode 100644 index 000000000..41eafba45 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_uint16.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +shape:* +dtype0 +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_uint32.pb 
b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_uint32.pb new file mode 100644 index 000000000..315b6819f --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_uint32.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +shape:* +dtype0 +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_uint64.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_uint64.pb new file mode 100644 index 000000000..ebb92ce54 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_uint64.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_uint8.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_uint8.pb new file mode 100644 index 000000000..a71d3c7d2 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_float64_uint8.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_float16.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_float16.pb new file mode 100644 index 000000000..3ab637e88 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_float16.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +DstT0* + +SrcT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_float32.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_float32.pb new file mode 100644 index 000000000..b26891dd6 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_float32.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +shape:* +dtype0 +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_float64.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_float64.pb new file mode 100644 index 000000000..5be28d608 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_float64.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_int16.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_int16.pb new file mode 100644 index 000000000..73a7c2f45 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_int16.pb @@ -0,0 +1,5 @@ + +0 +input Placeholder* +shape:* +dtype0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_int32.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_int32.pb new file mode 100644 index 000000000..c05482d34 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_int32.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +DstT0* + +SrcT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_int64.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_int64.pb new file mode 100644 index 000000000..5cab21d36 
--- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_int64.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +SrcT0* + +DstT0 " \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_int8.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_int8.pb new file mode 100644 index 000000000..d57e7a282 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_int8.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +shape:* +dtype0 +2 + cast_outputCastinput* + +DstT0* + +SrcT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_uint16.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_uint16.pb new file mode 100644 index 000000000..b1d496198 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_uint16.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +DstT0* + +SrcT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_uint32.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_uint32.pb new file mode 100644 index 000000000..84a42abf3 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_uint32.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +shape:* +dtype0 +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_uint64.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_uint64.pb new file mode 100644 index 000000000..6e5fe3e73 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_uint64.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +DstT0* + +SrcT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_uint8.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_uint8.pb new file mode 100644 index 000000000..d48963b0b --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int16_uint8.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_float16.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_float16.pb new file mode 100644 index 000000000..a7dc7de1e --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_float16.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +DstT0* + +SrcT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_float32.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_float32.pb new file mode 100644 index 000000000..9ced9a4ff --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_float32.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +shape:* +dtype0 +2 + cast_outputCastinput* + +DstT0* + +SrcT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_float64.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_float64.pb new file mode 100644 index 000000000..f259ee12b --- /dev/null +++ 
b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_float64.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +shape:* +dtype0 +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_int16.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_int16.pb new file mode 100644 index 000000000..0b043570d --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_int16.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_int32.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_int32.pb new file mode 100644 index 000000000..037f42cc4 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_int32.pb @@ -0,0 +1,5 @@ + +0 +input Placeholder* +dtype0* +shape:" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_int64.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_int64.pb new file mode 100644 index 000000000..84ef4a332 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_int64.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +DstT0 * + +SrcT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_int8.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_int8.pb new file mode 100644 index 000000000..077790125 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_int8.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +shape:* +dtype0 +2 + cast_outputCastinput* + +DstT0* + +SrcT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_uint16.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_uint16.pb new file mode 100644 index 000000000..69c8ddb2a --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_uint16.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_uint32.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_uint32.pb new file mode 100644 index 000000000..678bc53d1 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_uint32.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +DstT0* + +SrcT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_uint64.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_uint64.pb new file mode 100644 index 000000000..a191518a4 --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_uint64.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +dtype0* +shape: +2 + cast_outputCastinput* + +SrcT0* + +DstT0" \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_uint8.pb b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_uint8.pb new file mode 100644 index 000000000..04ab7ddea --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/cast_graph/cast_int32_uint8.pb @@ -0,0 +1,11 @@ + +0 +input Placeholder* +shape:* +dtype0 +2 + 
[Binary resources elided: the remainder of this patch adds new TensorFlow "cast graph" protobufs under nd4j/nd4j-tensorflow/src/main/resources/cast_graph/ — cast_int64_<dst>.pb, cast_int8_<dst>.pb, cast_uint16_<dst>.pb, cast_uint32_<dst>.pb, cast_uint64_<dst>.pb and cast_uint8_<dst>.pb, one file per destination type across float16/float32/float64 and the signed/unsigned 8/16/32/64-bit integers. Each file is a serialized GraphDef holding a Placeholder named "input" that feeds a Cast op named "cast_output" with the corresponding SrcT/DstT attributes; the identity pairs (cast_int64_int64.pb, cast_int8_int8.pb, cast_uint16_uint16.pb, cast_uint32_uint32.pb, cast_uint64_uint64.pb, cast_uint8_uint8.pb) contain only the Placeholder. The raw protobuf bytes do not survive as text and are omitted.]

From cb18d3d996b9f775e01c38d3feb2734418237386 Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 4 Dec 2019 09:11:37 +0300 Subject: [PATCH 27/30] allow MKL-DNN on non-AVX machines (#104) Signed-off-by: raver119 --- .../include/ops/declarable/platform/mkldnn/avgpooling2d.cpp | 4 ---- .../ops/declarable/platform/mkldnn/avgpooling2d_bp.cpp | 4 ---- .../include/ops/declarable/platform/mkldnn/avgpooling3d.cpp | 4 ---- .../ops/declarable/platform/mkldnn/avgpooling3d_bp.cpp | 4 ---- libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp | 4 ---- libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp | 4 ---- libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp | 4 ---- libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp | 4 ---- .../include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp | 4 ---- libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp | 4 ---- libnd4j/include/ops/declarable/platform/mkldnn/lrn.cpp | 4 ---- libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp | 5 ----- .../include/ops/declarable/platform/mkldnn/maxpooling2d.cpp | 4 ---- .../ops/declarable/platform/mkldnn/maxpooling2d_bp.cpp | 4 ---- .../include/ops/declarable/platform/mkldnn/maxpooling3d.cpp | 4 ---- .../ops/declarable/platform/mkldnn/maxpooling_3d_bp.cpp | 4 ---- 16 files changed, 65 deletions(-)
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp index e70aff9d9..9a3b2916b 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp @@ -129,10 +129,6 @@ namespace nd4j { } PLATFORM_CHECK(avgpool2d) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - if (::optimalLevel() < 2) - return false; - auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d_bp.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d_bp.cpp index ceef28d33..428bd6042 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d_bp.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d_bp.cpp @@ -139,10 +139,6 @@ namespace nd4j { } PLATFORM_CHECK(avgpool2d_bp) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - if (::optimalLevel() < 2) - return false; - auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp index e42cb6a8e..22ace87de 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp @@ -131,10 +131,6 @@ namespace nd4j { } PLATFORM_CHECK(avgpool3dnew) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - if (::optimalLevel() < 2) - return false; - auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d_bp.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d_bp.cpp index 370f2b3fd..0c52608a0 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d_bp.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d_bp.cpp @@ -144,10 +144,6 @@ namespace nd4j { } PLATFORM_CHECK(avgpool3dnew_bp) { - // we don't want to use mkldnn if cpu
doesn't support avx/avx2 - if (::optimalLevel() < 2) - return false; - auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp index 23957aeb7..e66589b0a 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp @@ -736,10 +736,6 @@ PLATFORM_IMPL(batchnorm_bp) { ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(batchnorm_bp) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - // if (::optimalLevel() < 2) - // return false; - NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw NDArray* mean = INPUT_VARIABLE(1); // [c] NDArray* variance = INPUT_VARIABLE(2); // [c] diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp index 4531fda81..a01679740 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp @@ -329,10 +329,6 @@ PLATFORM_IMPL(conv2d_bp) { } PLATFORM_CHECK(conv2d_bp) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - if (::optimalLevel() < 2) - return false; - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp index a5871fada..1e28e76a5 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp @@ -375,10 +375,6 @@ PLATFORM_IMPL(conv3dnew_bp) { } PLATFORM_CHECK(conv3dnew_bp) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - if (::optimalLevel() < 2) - return false; - auto input = INPUT_VARIABLE( 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto weights = INPUT_VARIABLE( diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp index 047549e40..ced37aea8 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp @@ -507,10 +507,6 @@ PLATFORM_IMPL(deconv2d_bp) { } PLATFORM_CHECK(deconv2d_bp) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - // if (::optimalLevel() < 2) - // return false; - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always auto bias = block.width() > 3 ? 
INPUT_VARIABLE(2) : nullptr; // [oC] diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp index b0e27240f..fac53e877 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp @@ -223,10 +223,6 @@ PLATFORM_IMPL(deconv2d_tf) { } PLATFORM_CHECK(deconv2d_tf) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - // if (::optimalLevel() < 2) - // return false; - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp index 51e15349b..7259ea0db 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp @@ -525,10 +525,6 @@ PLATFORM_IMPL(deconv3d_bp) { PLATFORM_CHECK(deconv3d_bp) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - // if (::optimalLevel() < 2) - // return false; - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NHWC) or [bS, iD, iC, iH, iW] (NCDHW) auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lrn.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/lrn.cpp index 41efe6524..ecd8b4c1a 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/lrn.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/lrn.cpp @@ -83,10 +83,6 @@ namespace nd4j { }; PLATFORM_CHECK(lrn) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - if (::optimalLevel() < 2) - return false; - auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp index 50a349cc9..7417653b3 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp @@ -494,11 +494,6 @@ PLATFORM_IMPL(lstmLayer) { } PLATFORM_CHECK(lstmLayer) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - // if (::optimalLevel() < 2) { - // return false; - // } - const auto hasBiases = B_ARG(0); // indicates whether biases array is provided const auto hasInitH = B_ARG(2); // indicates whether initial output is provided const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp index 4204b93d0..03008fbc6 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp @@ -135,10 +135,6 @@ namespace nd4j { } PLATFORM_CHECK(maxpool2d) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - if (::optimalLevel() < 2) - return false; - auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d_bp.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d_bp.cpp index 0c663a59c..e50bef362 100644 --- 
a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d_bp.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d_bp.cpp @@ -164,10 +164,6 @@ namespace nd4j { } PLATFORM_CHECK(maxpool2d_bp) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - if (::optimalLevel() < 2) - return false; - auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp index 72fb79709..6f132bb56 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp @@ -141,10 +141,6 @@ namespace nd4j { } PLATFORM_CHECK(maxpool3dnew) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - if (::optimalLevel() < 2) - return false; - auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling_3d_bp.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling_3d_bp.cpp index b4c9f1ad5..4f51d6633 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling_3d_bp.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling_3d_bp.cpp @@ -171,10 +171,6 @@ namespace nd4j { } PLATFORM_CHECK(maxpool3dnew_bp) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - if (::optimalLevel() < 2) - return false; - auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); From 578a5abb681dd86b16482ed7ad754ea88790ae70 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Wed, 4 Dec 2019 22:50:17 +1100 Subject: [PATCH 28/30] DNNL/MKLDNN dilated causal conv1d + betainc (#103) * - add padding calculation in same mode in causal conv1d op for right mkl paddings Signed-off-by: Yurii * - correct causal condition in mkldnnUtils.cpp Signed-off-by: Yurii * - correct some code which caused additional round errors is betainc op Signed-off-by: Yurii * - put float in place of template parameter in nan assign in betainc op Signed-off-by: Yurii --- .../declarable/generic/parity_ops/betaInc.cpp | 6 + .../ops/declarable/helpers/cpu/betaInc.cpp | 82 ++++++------ .../ops/declarable/helpers/cuda/betaInc.cu | 122 +++++++++--------- .../platform/mkldnn/mkldnnUtils.cpp | 8 +- .../declarable/platform/mkldnn/mkldnnUtils.h | 2 +- .../layers_tests/ConvolutionTests1.cpp | 33 +++++ .../layers_tests/DeclarableOpsTests3.cpp | 23 +++- 7 files changed, 166 insertions(+), 110 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/betaInc.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/betaInc.cpp index 9d0a935a9..1b09bbf77 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/betaInc.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/betaInc.cpp @@ -38,6 +38,12 @@ CONFIGURABLE_OP_IMPL(betainc, 3, 1, false, 0, 0) { auto b = INPUT_VARIABLE(1); auto x = INPUT_VARIABLE(2); + // just skip op if input is empty + if (x->isEmpty()) { + *x = DataTypeUtils::nanOrZero(); + return Status::OK(); + } + auto output = OUTPUT_VARIABLE(0); REQUIRE_TRUE(a->isSameShape(b) && a->isSameShape(x), 0, "CONFIGURABLE_OP betainc: all three input arrays must have the same shapes, bit got a=%s, b=%s and x=%s instead !", ShapeUtils::shapeAsString(a).c_str(), ShapeUtils::shapeAsString(b).c_str(), ShapeUtils::shapeAsString(x).c_str()); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp b/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp index 
88186b62a..83cc966ba 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp @@ -31,61 +31,56 @@ namespace helpers { /////////////////////////////////////////////////////////////////// // modified Lentz’s algorithm for continued fractions, // reference: Lentz, W.J. 1976, “Generating Bessel Functions in Mie Scattering Calculations Using Continued Fractions” + template static T continuedFraction(const T a, const T b, const T x) { const T min = DataTypeUtils::min() / DataTypeUtils::eps(); const T aPlusb = a + b; - T val, delta, aPlus2i; + T val, aPlus2i; - // first iteration - T c = 1; - T d = static_cast(1) - aPlusb * x / (a + static_cast(1)); - if(math::nd4j_abs(d) < min) - d = min; - d = static_cast(1) / d; - T f = d; + T t2 = 1; + T t1 = static_cast(1) - aPlusb * x / (a + static_cast(1)); + if(math::nd4j_abs(t1) < min) + t1 = min; + t1 = static_cast(1) / t1; + T result = t1; - for(uint i = 1; i <= maxIter; i += 2) { + for(uint i = 1; i <= maxIter; ++i) { aPlus2i = a + static_cast(2*i); - - /***** even part *****/ val = i * (b - i) * x / ((aPlus2i - static_cast(1)) * aPlus2i); - // d - d = static_cast(1) + val * d; - if(math::nd4j_abs(d) < min) - d = min; - d = static_cast(1) / d; - // c - c = static_cast(1) + val / c; - if(math::nd4j_abs(c) < min) - c = min; - // f - f *= c * d; - - - /***** odd part *****/ + // t1 + t1 = static_cast(1) + val * t1; + if(math::nd4j_abs(t1) < min) + t1 = min; + t1 = static_cast(1) / t1; + // t2 + t2 = static_cast(1) + val / t2; + if(math::nd4j_abs(t2) < min) + t2 = min; + // result + result *= t2 * t1; val = -(a + i) * (aPlusb + i) * x / ((aPlus2i + static_cast(1)) * aPlus2i); - // d - d = static_cast(1) + val * d; - if(math::nd4j_abs(d) < min) - d = min; - d = static_cast(1) / d; - // c - c = static_cast(1) + val / c; - if(math::nd4j_abs(c) < min) - c = min; - // f - delta = c * d; - f *= delta; + // t1 + t1 = static_cast(1) + val * t1; + if(math::nd4j_abs(t1) < min) + t1 = min; + t1 = static_cast(1) / t1; + // t2 + t2 = static_cast(1) + val / t2; + if(math::nd4j_abs(t2) < min) + t2 = min; + // result + val = t2 * t1; + result *= val; // condition to stop loop - if(math::nd4j_abs(delta - static_cast(1)) <= DataTypeUtils::eps()) - return f; + if(math::nd4j_abs(val - static_cast(1)) <= DataTypeUtils::eps()) + return result; } - return std::numeric_limits::infinity(); // no convergence, more iterations is required + return DataTypeUtils::infOrMax(); // no convergence, more iterations is required, return infinity } /////////////////////////////////////////////////////////////////// @@ -110,10 +105,9 @@ static T betaIncCore(T a, T b, T x) { const T front = math::nd4j_exp(math::nd4j_log(x) * a + math::nd4j_log(1.f - x) * b - gammaPart); if (x <= (a + static_cast(1)) / (a + b + static_cast(2))) - return front * continuedFraction(a, b, x) / a; - else // symmetry relation - return static_cast(1) - front * continuedFraction(b, a, static_cast(1) - x) / b; - + return front * continuedFraction(a, b, x) / a; + else // symmetry relation + return static_cast(1) - front * continuedFraction(b, a, static_cast(1) - x) / b; } /////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu b/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu index e7541a005..267ae21c2 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu @@ -1,5 +1,5 @@ 
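For reference, the quantity both implementations compute is the regularized incomplete beta function, evaluated with the modified Lentz scheme the code cites (t1/t2 in the renamed CPU version play the roles of Lentz's D and C, and result is the running product):

```latex
% Regularized incomplete beta function, for x <= (a+1)/(a+b+2), with
%   F = exp( a ln x + b ln(1-x) - [ln\Gamma(a) + ln\Gamma(b) - ln\Gamma(a+b)] ):
I_x(a,b) \;=\; \frac{F}{a}\;
  \cfrac{1}{1+\cfrac{d_1}{1+\cfrac{d_2}{1+\cdots}}},
\qquad
d_{2m} = \frac{m\,(b-m)\,x}{(a+2m-1)(a+2m)},
\qquad
d_{2m+1} = -\,\frac{(a+m)(a+b+m)\,x}{(a+2m)(a+2m+1)} .
% For x > (a+1)/(a+b+2) the symmetry I_x(a,b) = 1 - I_{1-x}(b,a) is used instead.
```

The CUDA path (diff below) differs mainly in where the d_m come from: after swapping (a, b) and x := 1 - x up front when the symmetry relation applies, each thread i precomputes its pair d_{2i}, d_{2i+1} into shared memory — hence the doubled sharedMem allocation in the launch — with sharedMem[0] and sharedMem[1] seeded to the first denominator 1 + d_1 and to 1.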
/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (t2) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -39,54 +39,50 @@ __device__ T continuedFractionCuda(const T a, const T b, const T x) { const T min = DataTypeUtils::min() / DataTypeUtils::eps(); const T aPlusb = a + b; - T val, delta, aPlus2i; + T val, aPlus2i; - // first iteration - T c = 1; - T d = static_cast(1) - aPlusb * x / (a + static_cast(1)); - if(math::nd4j_abs(d) < min) - d = min; - d = static_cast(1) / d; - T f = d; + T t2 = coeffs[1]; + T t1 = coeffs[0]; + if(math::nd4j_abs(t1) < min) + t1 = min; + t1 = static_cast(1) / t1; + T result = t1; - for(uint i = 1; i <= maxIter; i += 2) { + for(uint i = 1; i <= maxIter; ++i) { - aPlus2i = a + static_cast(2*i); + const uint i2 = 2*i; + aPlus2i = a + static_cast(i2); - /***** even part *****/ - // d - d = static_cast(1) + coeffs[i - 1] * d; - if(math::nd4j_abs(d) < min) - d = min; - d = static_cast(1) / d; - // c - c = static_cast(1) + coeffs[i - 1] / c; - if(math::nd4j_abs(c) < min) - c = min; - // f - f *= c * d; - - - /***** odd part *****/ - // d - d = static_cast(1) + coeffs[i] * d; - if(math::nd4j_abs(d) < min) - d = min; - d = static_cast(1) / d; - // c - c = static_cast(1) + coeffs[i] / c; - if(math::nd4j_abs(c) < min) - c = min; - // f - delta = c * d; - f *= delta; + // t1 + t1 = static_cast(1) + coeffs[i2] * t1; + if(math::nd4j_abs(t1) < min) + t1 = min; + t1 = static_cast(1) / t1; + // t2 + t2 = static_cast(1) + coeffs[i2] / t2; + if(math::nd4j_abs(t2) < min) + t2 = min; + // result + result *= t2 * t1; + // t1 + t1 = static_cast(1) + coeffs[i2 + 1] * t1; + if(math::nd4j_abs(t1) < min) + t1 = min; + t1 = static_cast(1) / t1; + // t2 + t2 = static_cast(1) + coeffs[i2 + 1] / t2; + if(math::nd4j_abs(t2) < min) + t2 = min; + // result + val = t2 * t1; + result *= val; // condition to stop loop - if(math::nd4j_abs(delta - static_cast(1)) <= DataTypeUtils::eps()) - return f; + if(math::nd4j_abs(val - static_cast(1)) <= DataTypeUtils::eps()) + return result; } - return 1.f / 0.f; // no convergence, more iterations is required + return DataTypeUtils::infOrMax(); // no convergence, more iterations is required, return infinity } /////////////////////////////////////////////////////////////////// @@ -112,7 +108,14 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo, b = *(reinterpret_cast(vb) + shape::getIndexOffset(j, bShapeInfo)); x = *(reinterpret_cast(vx) + shape::getIndexOffset(j, xShapeInfo)); - symmCond = x <= (a + static_cast(1)) / (a + b + static_cast(2)); + symmCond = x > (a + static_cast(1)) / (a + b + static_cast(2)); + + if(symmCond) { // swap a and b, x = 1 - x + T temp = a; + a = b; + b = temp; + x = static_cast(1) - x; + } } __syncthreads(); @@ -124,23 +127,17 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo, } if (x == static_cast(0) || x == static_cast(1)) { - z = x; + z = symmCond ? 
static_cast(1) - x : x; return; } - if(threadIdx.x % 2 == 0) { /***** even part *****/ - const int m = threadIdx.x + 1; - if(symmCond) - sharedMem[threadIdx.x] = m * (b - m) * x / ((a + 2 * m - static_cast(1)) * (a + 2 * m)); - else - sharedMem[threadIdx.x] = m * (a - m) * (1.f-x) / ((b + 2 * m - static_cast(1)) * (b + 2 * m)); - } - else { /***** odd part *****/ - const int m = threadIdx.x; - if(symmCond) - sharedMem[threadIdx.x] = -(a + m) * (a + b + m) * x / ((a + 2 * m + static_cast(1)) * (a + 2 * m)); - else - sharedMem[threadIdx.x] = -(b + m) * (a + b + m) * (1.f-x) / ((b + 2 * m + static_cast(1)) * (b + 2 * m)); + // calculate two coefficients per thread + if(threadIdx.x != 0) { + + const int i = threadIdx.x; + const T aPlus2i = a + 2*i; + sharedMem[2*i] = i * (b - i) * x / ((aPlus2i - static_cast(1)) * aPlus2i); + sharedMem[2*i + 1] = -(a + i) * (a + b + i) * x / ((aPlus2i + static_cast(1)) * aPlus2i); } __syncthreads(); @@ -150,10 +147,13 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo, const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b); const T front = math::nd4j_exp(math::nd4j_log(x) * a + math::nd4j_log(1.f - x) * b - gammaPart); - if (symmCond) - z = front * continuedFractionCuda(a, b, x) / a; - else // symmetry relation - z = static_cast(1) - front * continuedFractionCuda(b, a, static_cast(1) - x) / b; + sharedMem[0] = static_cast(1) - (a + b) * x / (a + static_cast(1)); + sharedMem[1] = static_cast(1); + + z = front * continuedFractionCuda(a, b, x) / a; + + if(symmCond) // symmetry relation + z = static_cast(1) - z; } } @@ -174,7 +174,7 @@ void betaInc(nd4j::LaunchContext* context, const NDArray& a, const NDArray& b, c const int threadsPerBlock = maxIter; const int blocksPerGrid = output.lengthOf(); - const int sharedMem = output.sizeOfT() * threadsPerBlock + 128; + const int sharedMem = 2 * output.sizeOfT() * threadsPerBlock + 128; const auto xType = x.dataType(); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp index 07cf60ef9..96bbffcf8 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp @@ -141,7 +141,7 @@ namespace nd4j { void getMKLDNNMemoryDescConv2d( - int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW, + int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW, int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src, const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, @@ -154,9 +154,11 @@ namespace nd4j { dnnl::memory::dims conv_bias_tz = { oC }; dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW }; + const int pWSame = (paddingMode == 2 && dW > 1) ? 
((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d + conv_strides = { sH, sW }; conv_padding = { pH, pW }; - conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; + conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; conv_dilation = { dH-1, dW-1}; auto type = dnnl::memory::data_type::f32; @@ -220,7 +222,7 @@ namespace nd4j { } void getMKLDNNMemoryDescConv3d( - int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW, + int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool paddingMode, bool isNCDHW, int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src, const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h index 1f9b9e010..6274a645f 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h @@ -86,7 +86,7 @@ namespace nd4j{ * Utility methods for MKLDNN */ void getMKLDNNMemoryDescConv2d( - int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW, + int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW, int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src, const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp index 99092b37d..99cc98af9 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -1040,6 +1040,39 @@ TEST_F(ConvolutionTests1, conv1d_causal_7) { delete results; } +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_8) { + + int bS=2, iW=8, iC=3,oC=4, kW=2, sW=1, pW=0, dW=2; + int oW = (iW-1)/sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}, nd4j::DataType::FLOAT32); + NDArray weights('c', {kW, iC, oC}, nd4j::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oW, oC}, {11.000000, 11.600000, 12.200000, 12.800000, 26.299999, 27.799999, 29.299999, 30.799999, 45.399998, 48.399998, + 51.400002, 54.400005, 65.199997, 70.000000, 74.800003, 79.600006, 85.000000, 91.600006, 98.199997, 104.800003, 104.799995, 113.199997, 121.600006, + 130.000000, 124.599998, 134.800003, 145.000000, 155.200012, 144.399994, 156.399994, 168.399994, 180.400009, 133.400009, 141.199997, 149.000000, + 156.800003, 148.699997, 157.400009, 166.099991, 174.800003, 203.800003, 221.200012, 238.599991, 256.000000, 223.599991, 242.799988, 262.000000, + 281.200012, 243.399994, 264.399994, 285.399994, 306.399994, 263.199982, 286.000000, 308.799988, 331.600006, 283.000000, 307.600006, 332.200012, + 356.800018, 302.799988, 329.199982, 355.600006, 382.000000}, nd4j::DataType::FLOAT32); + + input.linspace(1., 1.); + 
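Spelled out, the sizing rule that both the pWSame correction above and this new test rely on (note the test's oW = (iW-1)/sW + 1) is the usual causal-convolution algebra:

```latex
% Causal 1D convolution with kernel k, stride s, dilation d, input length i:
k_{\mathrm{eff}} = (k-1)\,d + 1          % dilated kernel extent
p_{\mathrm{causal}} = (k-1)\,d           % padding applied on the left only
o \;=\; \frac{(i + p_{\mathrm{causal}}) - k_{\mathrm{eff}}}{s} + 1
  \;=\; \left\lfloor\frac{i-1}{s}\right\rfloor + 1
```

Padding only on the left is what makes the layer causal: output step t depends on inputs at positions <= t, never ahead of it. MKL-DNN, by contrast, takes explicit left and right padding, which is why the helper above recomputes the right-hand remainder for the dilated (dW > 1) case rather than reusing pW directly.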
weights.linspace(0.1, 0.1); + + nd4j::ops::conv1d op; + auto results = op.execute({&input, &weights}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + + delete results; +} + ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, conv1d_causal_bp_1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index a521be97b..6d224b323 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -1781,7 +1781,28 @@ TEST_F(DeclarableOpsTests3, betainc_test11) { NDArray b('c', {4}, {0.7717f, 0.9281f, 0.9846f, 0.4838f}, nd4j::DataType::FLOAT32); NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, nd4j::DataType::FLOAT32); - NDArray expected('c', {4}, {0.912156, 0.634443, 0.898314, 0.624544}, nd4j::DataType::FLOAT32); + NDArray expected('c', {4}, {0.912156, 0.634460, 0.898314, 0.624538}, nd4j::DataType::FLOAT32); + nd4j::ops::betainc op; + auto results = op.execute({&a, &b, &x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto *output = results->at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete results; +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, betainc_test12) { + + NDArray a('c', {4}, {8.0091f, 8.2108f, 7.5194f, 3.0780f}, nd4j::DataType::FLOAT32); + NDArray b('c', {4}, {7.9456f, 9.3527f, 9.8610f, 5.3541f}, nd4j::DataType::FLOAT32); + NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, nd4j::DataType::FLOAT32); + + NDArray expected('c', {4}, {0.9999995 , 0.8594694 , 0.999988 , 0.49124345}, nd4j::DataType::FLOAT32); nd4j::ops::betainc op; auto results = op.execute({&a, &b, &x}, {}, {}); From 9cc8803b8dbfc1b1c17dbf4a8ea452e4eece222e Mon Sep 17 00:00:00 2001 From: Alex Black Date: Wed, 4 Dec 2019 22:52:06 +1100 Subject: [PATCH 29/30] DL4J + Keras import: Causal Conv1D support (#107) * Keras causal conv1d support first steps Signed-off-by: AlexDBlack * Add tests Signed-off-by: AlexDBlack * Causal conv mode Signed-off-by: AlexDBlack * Gradient check and fixes for causal conv1d Signed-off-by: AlexDBlack * Fix Conv1D import and testing Signed-off-by: AlexDBlack * Cleanup Signed-off-by: AlexDBlack * Small keras test fix Signed-off-by: Alex Black * Don't allow setting causal convolution mode to conv2d/3d layers Signed-off-by: Alex Black * More robustly infer nIn for recurrent layers for ambiguous NCW and NWC cases Signed-off-by: Alex Black * Polish and cleanup Signed-off-by: Alex Black --- .../gradientcheck/CNN1DGradientCheckTest.java | 74 +++++++ .../convolution/ConvolutionLayerTest.java | 69 +++++++ .../nn/modelimport/keras/KerasLayer.java | 4 + .../keras/config/KerasLayerConfiguration.java | 1 + .../modelimport/keras/layers/KerasInput.java | 21 +- .../modelimport/keras/layers/KerasLoss.java | 7 +- .../convolutional/KerasConvolution.java | 1 - .../convolutional/KerasConvolution1D.java | 17 +- .../convolutional/KerasConvolutionUtils.java | 3 +- .../keras/layers/recurrent/KerasLSTM.java | 19 ++ .../layers/recurrent/KerasSimpleRnn.java | 18 ++ .../layers/wrappers/KerasBidirectional.java | 12 +- .../keras/utils/KerasLayerUtils.java | 11 ++ .../keras/e2e/KerasModelEndToEndTest.java | 185 +++++++++++++++--- 
.../nn/conf/ConvolutionMode.java | 14 +- .../nn/conf/layers/Convolution1DLayer.java | 5 + .../nn/conf/layers/Convolution3D.java | 6 + .../nn/conf/layers/ConvolutionLayer.java | 15 ++ .../nn/conf/layers/Deconvolution2D.java | 6 + .../conf/layers/DepthwiseConvolution2D.java | 6 + .../conf/layers/SeparableConvolution2D.java | 6 + .../nn/conf/layers/Subsampling1DLayer.java | 5 + .../nn/conf/layers/Subsampling3DLayer.java | 6 + .../nn/conf/layers/SubsamplingLayer.java | 15 ++ .../convolution/Convolution1DLayer.java | 94 +++++++++ .../util/Convolution1DUtils.java | 7 +- .../deeplearning4j/util/ConvolutionUtils.java | 17 +- 27 files changed, 588 insertions(+), 56 deletions(-) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java index 64748f932..a0a109cb1 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java @@ -27,6 +27,8 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.util.Convolution1DUtils; +import org.deeplearning4j.util.ConvolutionUtils; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -442,4 +444,76 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { } } } + + @Test + public void testCnn1Causal() { + int convNIn = 2; + int convNOut1 = 3; + int convNOut2 = 4; + int finalNOut = 3; + + int[] lengths = {11, 12, 13, 9, 10, 11}; + int[] kernels = {2, 3, 2, 4, 2, 3}; + int[] dilations = {1, 1, 2, 1, 2, 1}; + int[] strides = {1, 2, 1, 2, 1, 1}; + boolean[] masks = {false, true, false, true, false, true}; + boolean[] hasB = {true, false, true, false, true, true}; + + for (int i = 0; i < lengths.length; i++) { + int length = lengths[i]; + int k = kernels[i]; + int d = dilations[i]; + int st = strides[i]; + boolean mask = masks[i]; + boolean hasBias = hasB[i]; + //TODO has bias + String s = "k=" + k + ", s=" + st + "d=" + d + ", seqLen=" + length; + log.info("Starting test: " + s); + Nd4j.getRandom().setSeed(12345); + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) + .updater(new NoOp()) + .activation(Activation.TANH) + .weightInit(new NormalDistribution(0, 1)) + .seed(12345) + .list() + .layer(new Convolution1DLayer.Builder().kernelSize(k) + .dilation(d) + .hasBias(hasBias) + .convolutionMode(ConvolutionMode.Causal) + .stride(st).nIn(convNIn).nOut(convNOut1) + .build()) + .layer(new Convolution1DLayer.Builder().kernelSize(k) + .dilation(d) + .convolutionMode(ConvolutionMode.Causal) + .stride(st).nIn(convNOut1).nOut(convNOut2) + .build()) + .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) + .activation(Activation.SOFTMAX).nOut(finalNOut).build()) + .setInputType(InputType.recurrent(convNIn, length)).build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + INDArray f = Nd4j.rand(DataType.DOUBLE, 2, convNIn, length); + INDArray fm = null; + if (mask) { + fm = Nd4j.create(2, length); + fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1); + 
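The outSize1/outSize2 values asserted a few lines below come from Convolution1DUtils.getOutputSize with ConvolutionMode.Causal; as a minimal stand-alone sketch of that computation (hypothetical helper, zero extra padding as in the test):

```java
// Sketch of causal 1D output sizing. Hypothetical helper; the real API used in this
// test is Convolution1DUtils.getOutputSize(len, k, st, 0, ConvolutionMode.Causal, d).
static long causalOutLength(long inLength, long kernel, long stride, long dilation) {
    long kEff = (kernel - 1) * dilation + 1;         // dilated kernel extent
    long padLeft = (kernel - 1) * dilation;          // causal mode: left padding only
    return (inLength + padLeft - kEff) / stride + 1; // == (inLength - 1) / stride + 1
}
```

For example, the first test case (length 11, k = 2, s = 1, d = 1) keeps length 11 through both causal layers, while case i = 1 (length 12, k = 3, s = 2) shrinks to 6 and then 3.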
fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, length-2)).assign(1); + } + + long outSize1 = Convolution1DUtils.getOutputSize(length, k, st, 0, ConvolutionMode.Causal, d); + long outSize2 = Convolution1DUtils.getOutputSize(outSize1, k, st, 0, ConvolutionMode.Causal, d); + + INDArray label = TestUtils.randomOneHotTimeSeries(2, finalNOut, (int)outSize2); + + boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, + DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, label, fm, null); + + assertTrue(s, gradOK); + TestUtils.testModelSerialization(net); + } + } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java index 1c4b764bd..431831487 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java @@ -712,4 +712,73 @@ public class ConvolutionLayerTest extends BaseDL4JTest { assertTrue(msg,msg.contains("Deconvolution2D") && msg.contains("input") && msg.contains("channels")); } } + + @Test + public void testConv1dCausalAllowed(){ + new Convolution1DLayer.Builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build(); + new Subsampling1DLayer.Builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build(); + } + + @Test + public void testConv2dNoCausalAllowed(){ + + try{ + new ConvolutionLayer.Builder().convolutionMode(ConvolutionMode.Causal).build(); + fail("Expected exception"); + } catch (Throwable t){ + String m = t.getMessage().toLowerCase(); + assertTrue(m, m.contains("causal") && m.contains("1d")); + } + + try{ + new Deconvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build(); + fail("Expected exception"); + } catch (Throwable t){ + String m = t.getMessage().toLowerCase(); + assertTrue(m, m.contains("causal") && m.contains("1d")); + } + + try{ + new DepthwiseConvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build(); + fail("Expected exception"); + } catch (Throwable t){ + String m = t.getMessage().toLowerCase(); + assertTrue(m, m.contains("causal") && m.contains("1d")); + } + + try{ + new SeparableConvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build(); + fail("Expected exception"); + } catch (Throwable t){ + String m = t.getMessage().toLowerCase(); + assertTrue(m, m.contains("causal") && m.contains("1d")); + } + + try{ + new SubsamplingLayer.Builder().convolutionMode(ConvolutionMode.Causal).build(); + fail("Expected exception"); + } catch (Throwable t){ + String m = t.getMessage().toLowerCase(); + assertTrue(m, m.contains("causal") && m.contains("1d")); + } + } + + @Test + public void testConv3dNoCausalAllowed(){ + try{ + new Convolution3D.Builder().convolutionMode(ConvolutionMode.Causal).build(); + fail("Expected exception"); + } catch (Throwable t){ + String m = t.getMessage().toLowerCase(); + assertTrue(m, m.contains("causal") && m.contains("1d")); + } + + try{ + new Subsampling3DLayer.Builder().convolutionMode(ConvolutionMode.Causal).build(); + fail("Expected exception"); + } catch (Throwable t){ + String m = t.getMessage().toLowerCase(); + assertTrue(m, m.contains("causal") && m.contains("1d")); + } + } } diff --git 
a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasLayer.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasLayer.java index a31ac6177..7d70077af 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasLayer.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasLayer.java @@ -356,6 +356,10 @@ public class KerasLayer { return this.layer; } + public void setLayer(Layer layer){ + this.layer = layer; + } + /** * Whether this Keras layer maps to a DL4J Vertex. * diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java index 84a85a2d5..6d6fc42c9 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java @@ -233,6 +233,7 @@ public class KerasLayerConfiguration { private final String LAYER_BORDER_MODE_SAME = "same"; private final String LAYER_BORDER_MODE_VALID = "valid"; private final String LAYER_BORDER_MODE_FULL = "full"; + private final String LAYER_BORDER_MODE_CAUSAL = "causal"; /* Noise layers */ private final String LAYER_FIELD_RATE = "rate"; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasInput.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasInput.java index c1df4b592..785e480d1 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasInput.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasInput.java @@ -124,7 +124,26 @@ public class KerasInput extends KerasLayer { myInputType = new InputType.InputTypeFeedForward(this.inputShape[0]); break; case 2: - myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]); + if(this.dimOrder != null) { + switch (this.dimOrder) { + case TENSORFLOW: //NWC == channels_last + myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]); + break; + case THEANO: //NCW == channels_first + myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]); + break; + case NONE: + //Assume RNN in [mb, seqLen, size] format + myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]); + break; + default: + throw new IllegalStateException("Unknown/not supported dimension ordering: " + this.dimOrder); + } + } else { + //Assume RNN in [mb, seqLen, size] format + myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]); + } + break; case 3: switch (this.dimOrder) { diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasLoss.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasLoss.java index 6fd72bd3e..e3c603287 100644 --- 
a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasLoss.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasLoss.java @@ -27,6 +27,7 @@ import org.deeplearning4j.nn.conf.layers.RnnLossLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; +import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.ArrayList; @@ -96,13 +97,13 @@ public class KerasLoss extends KerasLayer { */ public FeedForwardLayer getLossLayer(InputType type) throws UnsupportedKerasConfigurationException { if (type instanceof InputType.InputTypeFeedForward) { - this.layer = new LossLayer.Builder(loss).name(this.layerName).build(); + this.layer = new LossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build(); } else if (type instanceof InputType.InputTypeRecurrent) { - this.layer = new RnnLossLayer.Builder(loss).name(this.layerName).build(); + this.layer = new RnnLossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build(); } else if (type instanceof InputType.InputTypeConvolutional) { - this.layer = new CnnLossLayer.Builder(loss).name(this.layerName).build(); + this.layer = new CnnLossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build(); } else { throw new UnsupportedKerasConfigurationException("Unsupported output layer type" + "got : " + type.toString()); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution.java index f1d2f0210..a5f3e15ae 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution.java @@ -79,7 +79,6 @@ abstract public class KerasConvolution extends KerasLayer { public KerasConvolution(Map layerConfig, boolean enforceTrainingConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { super(layerConfig, enforceTrainingConfig); - } /** diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java index 3da88d3b1..120870de9 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java @@ -185,18 +185,11 @@ public class KerasConvolution1D extends KerasConvolution { break; case THEANO: - paramValue = kerasParamValue.permute(2, 1, 0); - paramValue = paramValue.reshape( - paramValue.size(0), paramValue.size(1), - paramValue.size(2), 1).dup(); - for (int i = 0; i < paramValue.tensorsAlongDimension(2, 3); 
i++) { - INDArray copyFilter = paramValue.tensorAlongDimension(i, 2, 3).dup(); - double[] flattenedFilter = copyFilter.ravel().data().asDouble(); - ArrayUtils.reverse(flattenedFilter); - INDArray newFilter = Nd4j.create(flattenedFilter, copyFilter.shape()); - INDArray inPlaceFilter = paramValue.tensorAlongDimension(i, 2, 3); - inPlaceFilter.muli(0).addi(newFilter.castTo(inPlaceFilter.dataType())); - } + //Convert from keras [k,nIn,nOut] to DL4J conv2d [nOut, nIn, k, 1] + long k = kerasParamValue.size(0); + long nIn = kerasParamValue.size(1); + long nOut = kerasParamValue.size(2); + paramValue = kerasParamValue.permute(2, 1, 0).dup('c').reshape(nOut, nIn, k, 1); break; default: throw new InvalidKerasConfigurationException("Unknown keras backend " + this.getDimOrder()); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java index b60b41459..0968260b7 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java @@ -264,7 +264,8 @@ public class KerasConvolutionUtils { } else if (borderMode.equals(conf.getLAYER_BORDER_MODE_VALID()) || borderMode.equals(conf.getLAYER_BORDER_MODE_FULL())) { convolutionMode = ConvolutionMode.Truncate; - + } else if(borderMode.equals(conf.getLAYER_BORDER_MODE_CAUSAL())) { + convolutionMode = ConvolutionMode.Causal; } else { throw new UnsupportedKerasConfigurationException("Unsupported convolution border mode: " + borderMode); } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java index 7d5603261..1c205bbca 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java @@ -23,11 +23,13 @@ import lombok.val; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.layers.InputTypeUtil; import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; +import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; @@ -186,6 +188,9 @@ public class KerasLSTM extends KerasLayer { .biasInit(0.0) // TODO: this is incorrect .l1(this.weightL1Regularization) .l2(this.weightL2Regularization); + Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf); + if(nIn != 
null) + builder.setNIn(nIn); if (biasConstraint != null) builder.constrainBias(biasConstraint); if (weightConstraint != null) @@ -436,6 +441,20 @@ public class KerasLSTM extends KerasLayer { log.warn("Attemping to set weights for unknown parameters: " + unknownParamNames.substring(1, unknownParamNames.length() - 1)); } + + + FeedForwardLayer ffl; + if(this.layer instanceof BaseWrapperLayer){ + BaseWrapperLayer bwl = (BaseWrapperLayer)this.layer; + ffl = (FeedForwardLayer)bwl.getUnderlying(); + } else { + ffl = (FeedForwardLayer) this.layer; + } + if(ffl.getNIn() != wRows){ + //Workaround/hack for ambiguous input shapes (nIn inference) for some RNN models (using NCW format but not recorded in config) + //We can reliably infer nIn from the shape of the weights array however + ffl.setNIn(wRows); + } } /** diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java index 6f5edf597..f6ecbb6a5 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java @@ -22,11 +22,13 @@ import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.layers.InputTypeUtil; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; +import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; @@ -154,6 +156,9 @@ public class KerasSimpleRnn extends KerasLayer { .biasInit(0.0) .l1(this.weightL1Regularization) .l2(this.weightL2Regularization); + Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf); + if(nIn != null) + builder.setNIn(nIn); if (biasConstraint != null) builder.constrainBias(biasConstraint); if (weightConstraint != null) @@ -282,6 +287,19 @@ public class KerasSimpleRnn extends KerasLayer { log.warn("Attemping to set weights for unknown parameters: " + unknownParamNames.substring(1, unknownParamNames.length() - 1)); } + + FeedForwardLayer ffl; + if(this.layer instanceof BaseWrapperLayer){ + BaseWrapperLayer bwl = (BaseWrapperLayer)this.layer; + ffl = (FeedForwardLayer)bwl.getUnderlying(); + } else { + ffl = (FeedForwardLayer) this.layer; + } + if(ffl.getNIn() != W.rows()){ + //Workaround/hack for ambiguous input shapes (nIn inference) for some RNN models (using NCW format but not recorded in config) + //We can reliably infer nIn from the shape of the weights array however + ffl.setNIn(W.rows()); + } } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java 
b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java index 40f1f7074..d37ee399c 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java @@ -229,8 +229,8 @@ public class KerasBidirectional extends KerasLayer { @Override public void setWeights(Map weights) throws InvalidKerasConfigurationException { - Map forwardWeights = getUnderlyingWeights(weights, "forward"); - Map backwardWeights = getUnderlyingWeights(weights, "backward"); + Map forwardWeights = getUnderlyingWeights(((Bidirectional)this.layer).getFwd(), weights, "forward"); + Map backwardWeights = getUnderlyingWeights(((Bidirectional)this.layer).getBwd(), weights, "backward"); this.weights = new HashMap<>(); @@ -241,7 +241,7 @@ public class KerasBidirectional extends KerasLayer { } - private Map getUnderlyingWeights(Map weights, String direction) + private Map getUnderlyingWeights(Layer l, Map weights, String direction) throws InvalidKerasConfigurationException { int keras1SubstringLength; if (kerasRnnlayer instanceof KerasLSTM) @@ -270,8 +270,12 @@ public class KerasBidirectional extends KerasLayer { weights = newWeights; } + Layer layerBefore = kerasRnnlayer.getLayer(); + kerasRnnlayer.setLayer(l); kerasRnnlayer.setWeights(weights); - return kerasRnnlayer.getWeights(); + Map ret = kerasRnnlayer.getWeights(); + kerasRnnlayer.setLayer(layerBefore); + return ret; } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java index 8d80d3f38..3494ecf49 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java @@ -505,6 +505,17 @@ public class KerasLayerUtils { return nOut; } + public static Integer getNInFromInputDim(Map layerConfig, KerasLayerConfiguration conf) throws InvalidKerasConfigurationException { + Map innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf); + if(innerConfig.containsKey(conf.getLAYER_FIELD_INPUT_DIM())){ + Object id = innerConfig.get(conf.getLAYER_FIELD_INPUT_DIM()); + if(id instanceof Number){ + return ((Number)id).intValue(); + } + } + return null; + } + /** * Get dropout from Keras layer configuration. 
* diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java index 0565cc091..d4f458a39 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java @@ -24,6 +24,8 @@ import org.deeplearning4j.eval.ROCMultiClass; import org.deeplearning4j.gradientcheck.GradientCheckUtil; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.layers.IOutputLayer; +import org.deeplearning4j.nn.conf.ConvolutionMode; +import org.deeplearning4j.nn.conf.layers.Convolution1DLayer; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.layers.LossLayer; import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; @@ -47,6 +49,8 @@ import org.nd4j.linalg.activations.impl.*; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.function.BiFunction; +import org.nd4j.linalg.function.Function; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT; @@ -58,10 +62,7 @@ import java.io.InputStream; import java.net.URL; import java.nio.file.Files; import java.nio.file.StandardCopyOption; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Random; +import java.util.*; import static org.junit.Assert.*; @@ -86,7 +87,16 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { @Rule public final TemporaryFolder testDir = new TemporaryFolder(); - @Test(expected = IllegalStateException.class) + public static final BiFunction nwc2ncwExpected = new BiFunction() { + @Override + public INDArray apply(String s, INDArray array) { + if(array.rank() == 3) + return array.permute(0, 2, 1); //NWC to NCW + return array; + } + }; + + @Test(expected = IllegalStateException.class) public void fileNotFoundEndToEnd() throws Exception { String modelPath = "modelimport/keras/examples/foo/bar.h5"; importEndModelTest(modelPath, null, true, true, false, false); @@ -154,28 +164,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { public void importImdbLstmTfKeras1() throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected); } @Test public void importImdbLstmThKeras1() throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected); } @Test public void importImdbLstmTfKeras2() throws Exception { String modelPath = 
"modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected); } @Test public void importImdbLstmThKeras2() throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, false, true, false, false); + importEndModelTest(modelPath, inputsOutputPath, false, true, false, false, true, null, nwc2ncwExpected); } /** @@ -247,7 +257,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" + "simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected); } /** @@ -598,6 +608,122 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { model.summary(); } + @Test + public void testCausalCon1D() throws Exception { + String[] names = new String[]{ + "causal_conv1d_k2_s1_d1_cl_model.h5", + "causal_conv1d_k2_s1_d2_cl_model.h5", + "causal_conv1d_k2_s2_d1_cl_model.h5", + "causal_conv1d_k2_s3_d1_cl_model.h5", + "causal_conv1d_k3_s1_d1_cl_model.h5", + "causal_conv1d_k3_s1_d2_cl_model.h5", + "causal_conv1d_k3_s2_d1_cl_model.h5", + "causal_conv1d_k3_s3_d1_cl_model.h5", + "causal_conv1d_k4_s1_d1_cl_model.h5", + "causal_conv1d_k4_s1_d2_cl_model.h5", + "causal_conv1d_k4_s2_d1_cl_model.h5", + "causal_conv1d_k4_s3_d1_cl_model.h5" + }; + + for(String name : names ){ + System.out.println("Starting test: " + name); + String modelPath = "modelimport/keras/examples/causal_conv1d/" + name; + String inputsOutputPath = "modelimport/keras/examples/causal_conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5"); + Function f = new Function() { + @Override + public INDArray apply(INDArray i) { + //NWC to NCW + return i.permute(0, 2, 1); + } + }; + + MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true, + true, true, false, f, nwc2ncwExpected); + Layer l = net.getLayer(0); + Convolution1DLayer c1d = (Convolution1DLayer) l.getConfig(); + assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode()); + } + } + + @Test + public void testCon1D() throws Exception { + String[] names = new String[]{ + "conv1d_k2_s1_d1_cf_same_model.h5", + "conv1d_k2_s1_d1_cf_valid_model.h5", + "conv1d_k2_s1_d1_cl_same_model.h5", + "conv1d_k2_s1_d1_cl_valid_model.h5", + "conv1d_k2_s1_d2_cf_same_model.h5", + "conv1d_k2_s1_d2_cf_valid_model.h5", + "conv1d_k2_s1_d2_cl_same_model.h5", + "conv1d_k2_s1_d2_cl_valid_model.h5", + "conv1d_k2_s2_d1_cf_same_model.h5", + "conv1d_k2_s2_d1_cf_valid_model.h5", + "conv1d_k2_s2_d1_cl_same_model.h5", + "conv1d_k2_s2_d1_cl_valid_model.h5", + "conv1d_k2_s3_d1_cf_same_model.h5", + "conv1d_k2_s3_d1_cf_valid_model.h5", + "conv1d_k2_s3_d1_cl_same_model.h5", + "conv1d_k2_s3_d1_cl_valid_model.h5", + "conv1d_k3_s1_d1_cf_same_model.h5", + "conv1d_k3_s1_d1_cf_valid_model.h5", + "conv1d_k3_s1_d1_cl_same_model.h5", 
+ "conv1d_k3_s1_d1_cl_valid_model.h5", + "conv1d_k3_s1_d2_cf_same_model.h5", + "conv1d_k3_s1_d2_cf_valid_model.h5", + "conv1d_k3_s1_d2_cl_same_model.h5", + "conv1d_k3_s1_d2_cl_valid_model.h5", + "conv1d_k3_s2_d1_cf_same_model.h5", + "conv1d_k3_s2_d1_cf_valid_model.h5", + "conv1d_k3_s2_d1_cl_same_model.h5", + "conv1d_k3_s2_d1_cl_valid_model.h5", + "conv1d_k3_s3_d1_cf_same_model.h5", + "conv1d_k3_s3_d1_cf_valid_model.h5", + "conv1d_k3_s3_d1_cl_same_model.h5", + "conv1d_k3_s3_d1_cl_valid_model.h5", + "conv1d_k4_s1_d1_cf_same_model.h5", + "conv1d_k4_s1_d1_cf_valid_model.h5", + "conv1d_k4_s1_d1_cl_same_model.h5", + "conv1d_k4_s1_d1_cl_valid_model.h5", + "conv1d_k4_s1_d2_cf_same_model.h5", + "conv1d_k4_s1_d2_cf_valid_model.h5", + "conv1d_k4_s1_d2_cl_same_model.h5", + "conv1d_k4_s1_d2_cl_valid_model.h5", + "conv1d_k4_s2_d1_cf_same_model.h5", + "conv1d_k4_s2_d1_cf_valid_model.h5", + "conv1d_k4_s2_d1_cl_same_model.h5", + "conv1d_k4_s2_d1_cl_valid_model.h5", + "conv1d_k4_s3_d1_cf_same_model.h5", + "conv1d_k4_s3_d1_cf_valid_model.h5", + "conv1d_k4_s3_d1_cl_same_model.h5", + "conv1d_k4_s3_d1_cl_valid_model.h5", + }; + + for(String name : names ){ + System.out.println("Starting test: " + name); + String modelPath = "modelimport/keras/examples/conv1d/" + name; + String inputsOutputPath = "modelimport/keras/examples/conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5"); + Function f = name.contains("_cf_") ? null : new Function() { + @Override + public INDArray apply(INDArray i) { + //NWC to NCW + return i.permute(0, 2, 1); + } + }; + + BiFunction f2 = name.contains("_cf_") ? null : new BiFunction() { + @Override + public INDArray apply(String s, INDArray array) { +// if("conv".equals(s)){ + return array.permute(0, 2, 1); +// } + } + }; + + importEndModelTest(modelPath, inputsOutputPath, true, true, + true, true, false, f, f2); + } + } + private ComputationGraph importFunctionalModelH5Test(String modelPath) throws Exception { return importFunctionalModelH5Test(modelPath, null, false); } @@ -640,6 +766,12 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions, boolean checkGradients, boolean enforceTrainingConfig) throws Exception { + return importEndModelTest(modelPath, inputsOutputsPath, tfOrdering, checkPredictions, checkGradients, true, enforceTrainingConfig, null, null); + } + + public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions, + boolean checkGradients, boolean enforceTrainingConfig, boolean checkAuc, Function inputPreProc, + BiFunction expectedPreProc) throws Exception { MultiLayerNetwork model; try(InputStream is = Resources.asStream(modelPath)) { File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION); @@ -658,20 +790,25 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { if (checkPredictions) { INDArray input = getInputs(outputsArchive, tfOrdering)[0]; + if(inputPreProc != null) + input = inputPreProc.apply(input); + Map activationsKeras = getActivations(outputsArchive, tfOrdering); for (int i = 0; i < model.getLayers().length; i++) { String layerName = model.getLayerNames().get(i); if (activationsKeras.containsKey(layerName)) { INDArray activationsDl4j = model.feedForwardToLayer(i, input, false).get(i + 1); - if (activationsDl4j.shape().length == 3) - activationsDl4j = activationsDl4j.permute(0, 2, 1); - 
compareINDArrays(layerName, activationsKeras.get(layerName), activationsDl4j, EPS); - + INDArray exp = activationsKeras.get(layerName); + if(expectedPreProc != null) + exp = expectedPreProc.apply(layerName, exp); + compareINDArrays(layerName, exp, activationsDl4j, EPS); } } INDArray predictionsKeras = getPredictions(outputsArchive, tfOrdering)[0]; INDArray predictionsDl4j = model.output(input, false); + if(expectedPreProc != null) + predictionsKeras = expectedPreProc.apply("output", predictionsKeras); compareINDArrays("predictions", predictionsKeras, predictionsDl4j, EPS); INDArray outputs = getOutputs(outputsArchive, true)[0]; @@ -680,7 +817,8 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { } val nOut = (int) outputs.size(-1); - compareMulticlassAUC("predictions", outputs, predictionsKeras, predictionsDl4j, nOut, EPS); + if(checkAuc) + compareMulticlassAUC("predictions", outputs, predictionsKeras, predictionsDl4j, nOut, EPS); } if (checkGradients && ! SKIP_GRAD_CHECKS) { @@ -760,20 +898,23 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { return predictions; } - private static void compareINDArrays(String label, INDArray a, INDArray b, double eps) { - INDArray diff = a.sub(b.castTo(a.dataType())); + private static void compareINDArrays(String label, INDArray expected, INDArray actual, double eps) { + if(!expected.equalShapes(actual)){ + throw new IllegalStateException("Shapes do not match for \"" + label + "\": got " + Arrays.toString(expected.shape()) + " vs " + Arrays.toString(actual.shape())); + } + INDArray diff = expected.sub(actual.castTo(expected.dataType())); double min = diff.minNumber().doubleValue(); double max = diff.maxNumber().doubleValue(); - log.info(label + ": " + a.equalsWithEps(b, eps) + ", " + min + ", " + max); + log.info(label + ": " + expected.equalsWithEps(actual, eps) + ", " + min + ", " + max); double threshold = 1e-7; - double aAbsMax = Math.max(Math.abs(a.minNumber().doubleValue()), Math.abs(a.maxNumber().doubleValue())); - double bAbsMax = Math.max(Math.abs(b.minNumber().doubleValue()), Math.abs(b.maxNumber().doubleValue())); + double aAbsMax = Math.max(Math.abs(expected.minNumber().doubleValue()), Math.abs(expected.maxNumber().doubleValue())); + double bAbsMax = Math.max(Math.abs(actual.minNumber().doubleValue()), Math.abs(actual.maxNumber().doubleValue())); // skip too small absolute inputs if (Math.abs(aAbsMax) > threshold && Math.abs(bAbsMax) > threshold) { - assertTrue(a.equalsWithEps(b.castTo(a.dataType()), eps)); + boolean eq = expected.equalsWithEps(actual.castTo(expected.dataType()), eps); + assertTrue("Output differs: " + label, eq); } - } private static void compareMulticlassAUC(String label, INDArray target, INDArray a, INDArray b, int nbClasses, diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ConvolutionMode.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ConvolutionMode.java index d6b1e0b55..4bd1050f0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ConvolutionMode.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ConvolutionMode.java @@ -69,6 +69,18 @@ package org.deeplearning4j.nn.conf; *
*
*
+ * Causal: Causal padding mode can only be used for 1D convolutional neural networks.
+ * The motivation behind causal padding mode is that the output time steps depend only on current and past time steps.
+ * That is, out[t] (for time t) depends only on values in[T] for T <= t<br>
+ * The output size of 1D convolution/subsampling layers is the same as with SAME convolution mode - + * i.e., outSize = ceil( inputSize / stride )
+ * Padding is also the same as SAME mode, but all padding is on the left (start of sequence) instead of being on both + * left and right of the input<br>
+ * For more details on causal convolutions, see WaveNet: A Generative Model for Raw Audio, + * section 2.1. + *<br>
+ *
+ *
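+ * As an illustrative check of the sizing rules above (plain Java arithmetic; the variable names here are not a DL4J API):<br>
+ * {@code<br>
+ *          int kernel = 3, dilation = 2, stride = 1;<br>
+ *          long inLength = 10;<br>
+ *          long leftPad = (kernel - 1) * dilation;                      //4 values padded at the start of the sequence<br>
+ *          long outSize = (long) Math.ceil(inLength / (double) stride); //10 - same as the input length<br>
+ * }<br>
+ * A layer is configured for this mode in the usual way, e.g. {@code new Convolution1DLayer.Builder().kernelSize(3)
+ * .dilation(2).convolutionMode(ConvolutionMode.Causal).build()}, as exercised by the gradient checks above.<br>
+ *<br>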
* For further information on output sizes for convolutional neural networks, see the "Spatial arrangement" section at * http://cs231n.github.io/convolutional-networks/ * @@ -76,6 +88,6 @@ package org.deeplearning4j.nn.conf; */ public enum ConvolutionMode { - Strict, Truncate, Same + Strict, Truncate, Same, Causal } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java index d4ccc4811..b220ba5a6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java @@ -124,6 +124,11 @@ public class Convolution1DLayer extends ConvolutionLayer { this.setKernelSize((int[]) null); } + @Override + protected boolean allowCausal() { + return true; + } + /** * @param kernelSize Kernel size * @param stride Stride diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java index 61475bf98..cc26169cf 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java @@ -163,6 +163,12 @@ public class Convolution3D extends ConvolutionLayer { super(new int[] {2, 2, 2}, new int[] {1, 1, 1}, new int[] {0, 0, 0}, new int[] {1, 1, 1}, 3); } + @Override + protected boolean allowCausal() { + //Causal convolution - allowed for 1D only + return false; + } + public Builder(int[] kernelSize, int[] stride, int[] padding, int[] dilation) { super(kernelSize, stride, padding, dilation, 3); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java index 4fdf1e9cc..b0c5bb3d4 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java @@ -30,6 +30,7 @@ import org.deeplearning4j.nn.params.ConvolutionParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -283,6 +284,12 @@ public class ConvolutionLayer extends FeedForwardLayer { super(); } + @Override + protected boolean allowCausal() { + //Causal convolution - allowed for 1D only + return false; + } + /** * Size of the convolution rows/columns * @@ -456,6 +463,14 @@ public class ConvolutionLayer extends FeedForwardLayer { protected BaseConvBuilder() {} + protected abstract boolean allowCausal(); + + protected void setConvolutionMode(ConvolutionMode convolutionMode){ + Preconditions.checkState(allowCausal() || convolutionMode != ConvolutionMode.Causal, "Causal convolution mode can only be used with 1D" + + " convolutional neural network layers"); + this.convolutionMode = convolutionMode; + } + /** * If true (default): include bias parameters in the model. False: no bias. 
* diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java index 03b6ec405..11c9fdb7b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java @@ -133,6 +133,12 @@ public class Deconvolution2D extends ConvolutionLayer { super(); } + @Override + protected boolean allowCausal() { + //Causal convolution - allowed for 1D only + return false; + } + /** * Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more details * diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java index 03fec1191..e103cb0a0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java @@ -133,6 +133,12 @@ public class DepthwiseConvolution2D extends ConvolutionLayer { super(); } + @Override + protected boolean allowCausal() { + //Causal convolution - allowed for 1D only + return false; + } + /** * Set channels multiplier for depth-wise convolution * diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java index 181cc5311..133c14869 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java @@ -184,6 +184,12 @@ public class SeparableConvolution2D extends ConvolutionLayer { super(); } + @Override + protected boolean allowCausal() { + //Causal convolution - allowed for 1D only + return false; + } + /** * Set channels multiplier of channels-wise step in separable convolution * diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java index 4da7ff011..9f3162374 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java @@ -167,6 +167,11 @@ public class Subsampling1DLayer extends SubsamplingLayer { this(poolingType, DEFAULT_KERNEL, DEFAULT_STRIDE, DEFAULT_PADDING); } + @Override + protected boolean allowCausal() { + return true; + } + public Builder() { this(DEFAULT_POOLING, DEFAULT_KERNEL, DEFAULT_STRIDE, DEFAULT_PADDING); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java index 2fcc345a1..550e29e4f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java +++ 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java @@ -431,6 +431,12 @@ public class Subsampling3DLayer extends NoParamLayer { this.setPoolingType(poolingType); } + protected void setConvolutionMode(ConvolutionMode convolutionMode){ + Preconditions.checkState(convolutionMode != ConvolutionMode.Causal, "Causal convolution mode can only be used with 1D" + + " convolutional neural network layers"); + this.convolutionMode = convolutionMode; + } + /** * Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more details * diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java index c20526cf1..be6764e9a 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java @@ -28,6 +28,7 @@ import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -270,6 +271,12 @@ public class SubsamplingLayer extends NoParamLayer { super(poolingType); } + @Override + protected boolean allowCausal() { + //Only conv1d/subsampling1d can use causal mode + return false; + } + /** * Kernel size * @@ -449,6 +456,14 @@ public class SubsamplingLayer extends NoParamLayer { this.eps = eps; } + protected abstract boolean allowCausal(); + + public void setConvolutionMode(ConvolutionMode convolutionMode){ + Preconditions.checkState(allowCausal() || convolutionMode != ConvolutionMode.Causal, "Causal convolution mode can only be used with 1D" + + " convolutional neural network layers"); + this.convolutionMode = convolutionMode; + } + /** * Set the convolution mode for the Convolution layer. 
See {@link ConvolutionMode} for more details * diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution1DLayer.java index 1ffd19062..985c2f06b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution1DLayer.java @@ -18,18 +18,30 @@ package org.deeplearning4j.nn.layers.convolution; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.api.MaskState; +import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.params.ConvolutionParamInitializer; +import org.deeplearning4j.util.Convolution1DUtils; import org.deeplearning4j.util.ConvolutionUtils; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D; +import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1DDerivative; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.factory.Broadcast; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import java.util.Arrays; +import java.util.List; /** * 1D (temporal) convolutional layer. 
Currently, we just subclass off the @@ -70,6 +82,52 @@ public class Convolution1DLayer extends ConvolutionLayer { Broadcast.mul(epsilon, maskOut, epsilon, 0, 2); } + if(layerConf().getConvolutionMode() == ConvolutionMode.Causal){ + Pair fwd = causalConv1dForward(); + IActivation afn = layerConf().getActivationFn(); + INDArray delta = afn.backprop(fwd.getFirst(), epsilon).getFirst(); //TODO handle activation function params + + //TODO eventually we'll use this for all convolution modes - but only after libnd4j has cuDNN support + org.deeplearning4j.nn.conf.layers.Convolution1DLayer c = (org.deeplearning4j.nn.conf.layers.Convolution1DLayer) layerConf(); + Conv1DConfig conf = Conv1DConfig.builder() + .k(c.getKernelSize()[0]) + .s(c.getStride()[0]) + .d(c.getDilation()[0]) + .p(c.getPadding()[0]) + .dataFormat(Conv1DConfig.NCW) + .paddingMode(PaddingMode.CAUSAL) + .build(); + + INDArray w = getParam(ConvolutionParamInitializer.WEIGHT_KEY); + w = w.reshape(w.ordering(), w.size(0), w.size(1), w.size(2)).permute(2, 1, 0); //[oC, iC, k, 1] to [k, iC, oC] + + INDArray[] inputArrs; + INDArray[] outputArrs; + INDArray wg = gradientViews.get(ConvolutionParamInitializer.WEIGHT_KEY); + wg = wg.reshape(wg.ordering(), wg.size(0), wg.size(1), wg.size(2)).permute(2, 1, 0); //[oC, iC, k, 1] -> [kW, iC, oC] + INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape()); + if(layerConf().hasBias()){ + INDArray b = getParam(ConvolutionParamInitializer.BIAS_KEY); + b = b.reshape(b.length()); + inputArrs = new INDArray[]{input.castTo(w.dataType()), w, b, delta}; + INDArray bg = gradientViews.get(ConvolutionParamInitializer.BIAS_KEY); + bg = bg.reshape(bg.length()); + outputArrs = new INDArray[]{epsOut, wg, bg}; + } else { + inputArrs = new INDArray[]{input.castTo(w.dataType()), w, delta}; + outputArrs = new INDArray[]{epsOut, wg}; + } + Conv1DDerivative op = new Conv1DDerivative(inputArrs, outputArrs, conf); + Nd4j.exec(op); + + Gradient retGradient = new DefaultGradient(); + if(layerConf().hasBias()){ + retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, gradientViews.get(ConvolutionParamInitializer.BIAS_KEY)); + } + retGradient.setGradientFor(ConvolutionParamInitializer.WEIGHT_KEY, gradientViews.get(ConvolutionParamInitializer.WEIGHT_KEY), 'c'); + return new Pair<>(retGradient, epsOut); + } + // add singleton fourth dimension to input and next layer's epsilon epsilon = epsilon.reshape(epsilon.size(0), epsilon.size(1), epsilon.size(2), 1); INDArray origInput = input; @@ -98,6 +156,12 @@ public class Convolution1DLayer extends ConvolutionLayer { @Override protected Pair preOutput(boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) { assertInputSet(false); + + if(layerConf().getConvolutionMode() == ConvolutionMode.Causal){ + return causalConv1dForward(); + } + + INDArray origInput = input; input = input.reshape(input.size(0), input.size(1), input.size(2), 1); @@ -113,6 +177,36 @@ public class Convolution1DLayer extends ConvolutionLayer { return preOutput; } + protected Pair causalConv1dForward(){ + //TODO eventually we'll use this for all convolution modes - but only after libnd4j has cuDNN support + org.deeplearning4j.nn.conf.layers.Convolution1DLayer c = (org.deeplearning4j.nn.conf.layers.Convolution1DLayer) layerConf(); + Conv1DConfig conf = Conv1DConfig.builder() + .k(c.getKernelSize()[0]) + .s(c.getStride()[0]) + .d(c.getDilation()[0]) + .p(c.getPadding()[0]) + .dataFormat(Conv1DConfig.NCW) + .paddingMode(PaddingMode.CAUSAL) + 
.build(); + INDArray w = getParam(ConvolutionParamInitializer.WEIGHT_KEY); + w = w.reshape(w.ordering(), w.size(0), w.size(1), w.size(2)).permute(2, 1, 0); //[oC, iC, k, 1] to [k, iC, oC] + + INDArray[] inputs; + if(layerConf().hasBias()){ + INDArray b = getParam(ConvolutionParamInitializer.BIAS_KEY); + b = b.reshape(b.length()); + inputs = new INDArray[]{input.castTo(w.dataType()), w, b}; + } else { + inputs = new INDArray[]{input.castTo(w.dataType()), w}; + } + + Conv1D op = new Conv1D(inputs, null, conf); + List outShape = op.calculateOutputShape(); + op.setOutputArgument(0, Nd4j.create(outShape.get(0), false)); + Nd4j.exec(op); + return new Pair<>(op.getOutputArgument(0), null); + } + @Override public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr){ INDArray act4d = super.activate(training, workspaceMgr); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java index f0c8d76c9..165483e1c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java @@ -66,7 +66,7 @@ public class Convolution1DUtils { public static long getOutputSize(long inH, int kernel, int strides, int padding, ConvolutionMode convolutionMode, int dilation) { long eKernel = effectiveKernelSize(kernel, dilation); - if (convolutionMode == ConvolutionMode.Same) { + if (convolutionMode == ConvolutionMode.Same || convolutionMode == ConvolutionMode.Causal) { return (int) Math.ceil(inH / ((double) strides)); } return (inH - eKernel + 2 * padding) / strides + 1; @@ -92,7 +92,7 @@ public class Convolution1DUtils { boolean atrous = (eKernel == kernel); validateShapes(inputData, eKernel, strides, padding, convolutionMode, dilation, inH, atrous); - if (convolutionMode == ConvolutionMode.Same) { + if (convolutionMode == ConvolutionMode.Same || convolutionMode == ConvolutionMode.Causal) { int outH = (int) Math.ceil(inH / ((double) strides)); return outH; } @@ -106,8 +106,9 @@ public class Convolution1DUtils { boolean atrous) { int inH = inShape; + boolean t = convolutionMode == ConvolutionMode.Truncate; - if (convolutionMode != ConvolutionMode.Same && (eKernel <= 0 || eKernel > inH + 2 * padding)) { + if (t && (eKernel <= 0 || eKernel > inH + 2 * padding)) { StringBuilder sb = new StringBuilder(); sb.append("Invalid input data or configuration: "); if (atrous) sb.append("effective "); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java index 56421bc00..3a447c361 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java @@ -121,7 +121,7 @@ public class ConvolutionUtils { int[] inShape = new int[]{inH, inW}; validateShapes(inputData, eKernel, strides, padding, convolutionMode, dilation, inShape, atrous); - if (convolutionMode == ConvolutionMode.Same) { + if (convolutionMode == ConvolutionMode.Same || convolutionMode == ConvolutionMode.Causal) { int outH = (int) Math.ceil(inH / ((double) strides[0])); int outW = (int) Math.ceil(inW / ((double) strides[1])); @@ -142,7 +142,9 @@ public class ConvolutionUtils { int inH = inShape[0]; int 
inW = inShape[1]; - if (convolutionMode != ConvolutionMode.Same && (eKernel[0] <= 0 || eKernel[0] > inH + 2 * padding[0])) { + boolean t = (convolutionMode == ConvolutionMode.Truncate); + + if (t && (eKernel[0] <= 0 || eKernel[0] > inH + 2 * padding[0])) { StringBuilder sb = new StringBuilder(); sb.append("Invalid input data or configuration: "); if (atrous) sb.append("effective "); @@ -158,7 +160,7 @@ public class ConvolutionUtils { throw new DL4JInvalidInputException(sb.toString()); } - if (convolutionMode != ConvolutionMode.Same && (eKernel[1] <= 0 || eKernel[1] > inW + 2 * padding[1])) { + if (t && (eKernel[1] <= 0 || eKernel[1] > inW + 2 * padding[1])) { StringBuilder sb = new StringBuilder(); sb.append("Invalid input data or configuration: "); if (atrous) sb.append("effective "); @@ -175,8 +177,7 @@ public class ConvolutionUtils { throw new DL4JInvalidInputException(sb.toString()); } - if (eKernel.length == 3 && convolutionMode != ConvolutionMode.Same - && (eKernel[2] <= 0 || eKernel[2] > inShape[2] + 2 * padding[2])) { + if (eKernel.length == 3 && t && (eKernel[2] <= 0 || eKernel[2] > inShape[2] + 2 * padding[2])) { int inD = inShape[2]; StringBuilder sb = new StringBuilder(); sb.append("Invalid input data or configuration: "); @@ -615,7 +616,7 @@ public class ConvolutionUtils { */ public static INDArray cnn1dMaskReduction(INDArray in, int kernel, int stride, int padding, int dilation, ConvolutionMode cm){ Preconditions.checkState(in.rank()==2, "Rank must be 2 for cnn1d mask array - shape ", in.shape()); - if(cm == ConvolutionMode.Same && stride == 1 ){ + if((cm == ConvolutionMode.Same || cm == ConvolutionMode.Causal) && stride == 1 ){ return in; } @@ -630,7 +631,7 @@ public class ConvolutionUtils { int[] k = new int[]{kernel,1}; int[] s = new int[]{stride, 1}; int[] d = new int[]{dilation, 1}; - if (cm == ConvolutionMode.Same) { + if (cm == ConvolutionMode.Same || cm == ConvolutionMode.Causal) { outSize = ConvolutionUtils.getOutputSize(reshaped4d, k, s, null, cm, d); //Also performs validation } else { pad = new int[]{padding, 0}; @@ -645,7 +646,7 @@ public class ConvolutionUtils { .sH(s[0]).sW(s[1]) .pH(pad == null ? 0 : pad[0]).pW(pad == null ? 
0 : pad[1]) .dH(d[0]).dW(d[1]) - .isSameMode(cm== ConvolutionMode.Same) + .isSameMode(cm == ConvolutionMode.Same || cm == ConvolutionMode.Causal) .isNHWC(false) .build()); From 91de96588cabf617341fa439beab88abc3c0c288 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Wed, 4 Dec 2019 23:35:38 +1100 Subject: [PATCH 30/30] BertIterator sentence pair support for supervised training (#108) * bert iterator sentence pair handling Signed-off-by: eraly * bert iterator sentence pair handling -seg Signed-off-by: eraly * bert iterator sentence pair handling tests Signed-off-by: eraly * test with pairs long done Signed-off-by: eraly * more tests with bert iter sent pairs done Signed-off-by: eraly * fixed copyright, formatting Signed-off-by: eraly * bert iterator - added featurizer for sentence pair inference Signed-off-by: eraly * bert iterator - finished tests Signed-off-by: eraly * bert iterator - finished tests, polish Signed-off-by: eraly * collection labeled sentence provider Signed-off-by: eraly * lombok fix for pojo class Signed-off-by: eraly * java doc misc clean up Signed-off-by: eraly * Private access modifiers Signed-off-by: AlexDBlack --- .../deeplearning4j/iterator/BertIterator.java | 269 +++++++++++-- .../iterator/LabeledPairSentenceProvider.java | 60 +++ ...CollectionLabeledPairSentenceProvider.java | 135 +++++++ .../CollectionLabeledSentenceProvider.java | 13 +- .../iterator/TestBertIterator.java | 363 ++++++++++++++++-- 5 files changed, 768 insertions(+), 72 deletions(-) create mode 100644 deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/LabeledPairSentenceProvider.java create mode 100644 deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/provider/CollectionLabeledPairSentenceProvider.java diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java index b5ca6c91a..40c43113c 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java @@ -34,6 +34,7 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.primitives.Pair; +import org.nd4j.linalg.primitives.Triple; import java.util.ArrayList; import java.util.Arrays; @@ -85,10 +86,20 @@ import java.util.Map; *
  * {@code
  *          BertIterator b;
+ *          Pair<INDArray[], INDArray[]> featuresAndMask;<br>
+ *          INDArray[] features;
+ *          INDArray[] featureMasks;
+ *
+ *          //With sentences
  *          List<String> forInference;<br>
- *          Pair<INDArray[], INDArray[]> featuresAndMask = b.featurizeSentences(forInference);<br>
- *          INDArray[] features = featuresAndMask.getFirst();
- *          INDArray[] featureMasks = featuresAndMask.getSecond();
+ *          featuresAndMask = b.featurizeSentences(forInference);
+ *
+ *          //OR with sentence pairs
+ *          List<Pair<String, String>> forInferencePair;<br>
+ *          featuresAndMask = b.featurizeSentencePairs(forInferencePair);<br>
+ *
+ *          features = featuresAndMask.getFirst();
+ *          featureMasks = featuresAndMask.getSecond();
  * }
  * 
* This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.
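+ * <br>
+ * For supervised training on sentence pairs, the wiring is sketched below. The sketch is illustrative only: the builder<br>
+ * and provider calls mirror those exercised in the tests in this patch, and the tokenizer factory, vocabulary map and<br>
+ * data lists are placeholders to be supplied by the caller.<br>
+ * {@code<br>
+ *          List<String> firsts;     //First sentence of each pair<br>
+ *          List<String> seconds;    //Second sentence of each pair<br>
+ *          List<String> labels;     //One label per pair<br>
+ *          BertIterator iter = BertIterator.builder()<br>
+ *                  .tokenizer(tokenizerFactory)<br>
+ *                  .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 128)<br>
+ *                  .minibatchSize(32)<br>
+ *                  .sentencePairProvider(new CollectionLabeledPairSentenceProvider(firsts, seconds, labels, null))<br>
+ *                  .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)<br>
+ *                  .vocabMap(vocab)<br>
+ *                  .task(BertIterator.Task.SEQ_CLASSIFICATION)<br>
+ *                  .prependToken("[CLS]")<br>
+ *                  .appendToken("[SEP]")<br>
+ *                  .build();<br>
+ * }<br>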
@@ -135,6 +146,7 @@ public class BertIterator implements MultiDataSetIterator { @Setter protected MultiDataSetPreProcessor preProcessor; protected LabeledSentenceProvider sentenceProvider = null; + protected LabeledPairSentenceProvider sentencePairProvider = null; protected LengthHandling lengthHandling; protected FeatureArrays featureArrays; protected Map vocabMap; //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects? @@ -142,6 +154,7 @@ public class BertIterator implements MultiDataSetIterator { protected UnsupervisedLabelFormat unsupervisedLabelFormat = null; protected String maskToken; protected String prependToken; + protected String appendToken; protected List vocabKeysAsList; @@ -154,6 +167,7 @@ public class BertIterator implements MultiDataSetIterator { this.padMinibatches = b.padMinibatches; this.preProcessor = b.preProcessor; this.sentenceProvider = b.sentenceProvider; + this.sentencePairProvider = b.sentencePairProvider; this.lengthHandling = b.lengthHandling; this.featureArrays = b.featureArrays; this.vocabMap = b.vocabMap; @@ -161,11 +175,14 @@ public class BertIterator implements MultiDataSetIterator { this.unsupervisedLabelFormat = b.unsupervisedLabelFormat; this.maskToken = b.maskToken; this.prependToken = b.prependToken; + this.appendToken = b.appendToken; } @Override public boolean hasNext() { - return sentenceProvider.hasNext(); + if (sentenceProvider != null) + return sentenceProvider.hasNext(); + return sentencePairProvider.hasNext(); } @Override @@ -181,29 +198,38 @@ public class BertIterator implements MultiDataSetIterator { @Override public MultiDataSet next(int num) { Preconditions.checkState(hasNext(), "No next element available"); - - List> list = new ArrayList<>(num); + List, String>> tokensAndLabelList; int mbSize = 0; + int outLength; + long[] segIdOnesFrom = null; if (sentenceProvider != null) { + List> list = new ArrayList<>(num); while (sentenceProvider.hasNext() && mbSize++ < num) { list.add(sentenceProvider.nextSentence()); } + SentenceListProcessed sentenceListProcessed = tokenizeMiniBatch(list); + tokensAndLabelList = sentenceListProcessed.getTokensAndLabelList(); + outLength = sentenceListProcessed.getMaxL(); + } else if (sentencePairProvider != null) { + List> listPairs = new ArrayList<>(num); + while (sentencePairProvider.hasNext() && mbSize++ < num) { + listPairs.add(sentencePairProvider.nextSentencePair()); + } + SentencePairListProcessed sentencePairListProcessed = tokenizePairsMiniBatch(listPairs); + tokensAndLabelList = sentencePairListProcessed.getTokensAndLabelList(); + outLength = sentencePairListProcessed.getMaxL(); + segIdOnesFrom = sentencePairListProcessed.getSegIdOnesFrom(); } else { //TODO - other types of iterators... 
                throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented");
            }
-
-        Pair<Integer, List<Pair<List<String>, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(list);
-        List<Pair<List<String>, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight();
-        int outLength = outLTokenizedSentencesPair.getLeft();
-
-        Pair<INDArray[], INDArray[]> featuresAndMaskArraysPair = convertMiniBatchFeatures(tokenizedSentences, outLength);
+        Pair<INDArray[], INDArray[]> featuresAndMaskArraysPair = convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
         INDArray[] featureArray = featuresAndMaskArraysPair.getFirst();
         INDArray[] featureMaskArray = featuresAndMaskArraysPair.getSecond();
-        Pair<INDArray[], INDArray[]> labelsAndMaskArraysPair = convertMiniBatchLabels(tokenizedSentences, featureArray, outLength);
+        Pair<INDArray[], INDArray[]> labelsAndMaskArraysPair = convertMiniBatchLabels(tokensAndLabelList, featureArray, outLength);
         INDArray[] labelArray = labelsAndMaskArraysPair.getFirst();
         INDArray[] labelMaskArray = labelsAndMaskArraysPair.getSecond();
@@ -224,32 +250,59 @@ public class BertIterator implements MultiDataSetIterator {
     public Pair<INDArray[], INDArray[]> featurizeSentences(List<String> listOnlySentences) {
         List<Pair<String, String>> sentencesWithNullLabel = addDummyLabel(listOnlySentences);
+        SentenceListProcessed sentenceListProcessed = tokenizeMiniBatch(sentencesWithNullLabel);
+        List<Pair<List<String>, String>> tokensAndLabelList = sentenceListProcessed.getTokensAndLabelList();
+        int outLength = sentenceListProcessed.getMaxL();
-        Pair<Integer, List<Pair<List<String>, String>>> outLTokenizedSentencesPair = tokenizeMiniBatch(sentencesWithNullLabel);
-        List<Pair<List<String>, String>> tokenizedSentences = outLTokenizedSentencesPair.getRight();
-        int outLength = outLTokenizedSentencesPair.getLeft();
-
-        Pair<INDArray[], INDArray[]> featureFeatureMasks = convertMiniBatchFeatures(tokenizedSentences, outLength);
         if (preProcessor != null) {
+            Pair<INDArray[], INDArray[]> featureFeatureMasks = convertMiniBatchFeatures(tokensAndLabelList, outLength, null);
             MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featureFeatureMasks.getFirst(), null, featureFeatureMasks.getSecond(), null);
             preProcessor.preProcess(dummyMDS);
-            return new Pair(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
+            return new Pair<>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
         }
-        return convertMiniBatchFeatures(tokenizedSentences, outLength);
+        return convertMiniBatchFeatures(tokensAndLabelList, outLength, null);
     }
-    private Pair<INDArray[], INDArray[]> convertMiniBatchFeatures(List<Pair<List<String>, String>> tokenizedSentences, int outLength) {
-        int mbPadded = padMinibatches ? minibatchSize : tokenizedSentences.size();
+    /**
+     * For use during inference. Converts a given list of sentence pairs to features and feature masks as appropriate.
+     *
+     * @param listOnlySentencePairs List of sentence pairs to featurize (no labels)
+     * @return Pair of INDArray[]: the first element holds the feature arrays, the second the feature mask arrays
+     */
+    public Pair<INDArray[], INDArray[]> featurizeSentencePairs(List<Pair<String, String>> listOnlySentencePairs) {
+        Preconditions.checkState(sentencePairProvider != null, "The featurizeSentencePairs method is meant for inference with sentence pairs. 
Use only when the sentence pair provider is set (i.e not null)."); + + List> sentencePairsWithNullLabel = addDummyLabelForPairs(listOnlySentencePairs); + SentencePairListProcessed sentencePairListProcessed = tokenizePairsMiniBatch(sentencePairsWithNullLabel); + List, String>> tokensAndLabelList = sentencePairListProcessed.getTokensAndLabelList(); + int outLength = sentencePairListProcessed.getMaxL(); + long[] segIdOnesFrom = sentencePairListProcessed.getSegIdOnesFrom(); + if (preProcessor != null) { + Pair featuresAndMaskArraysPair = convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom); + MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featuresAndMaskArraysPair.getFirst(), null, featuresAndMaskArraysPair.getSecond(), null); + preProcessor.preProcess(dummyMDS); + return new Pair<>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays()); + } + return convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom); + } + + private Pair convertMiniBatchFeatures(List, String>> tokensAndLabelList, int outLength, long[] segIdOnesFrom) { + int mbPadded = padMinibatches ? minibatchSize : tokensAndLabelList.size(); int[][] outIdxs = new int[mbPadded][outLength]; int[][] outMask = new int[mbPadded][outLength]; - for (int i = 0; i < tokenizedSentences.size(); i++) { - Pair, String> p = tokenizedSentences.get(i); + int[][] outSegmentId = null; + if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID) + outSegmentId = new int[mbPadded][outLength]; + for (int i = 0; i < tokensAndLabelList.size(); i++) { + Pair, String> p = tokensAndLabelList.get(i); List t = p.getFirst(); for (int j = 0; j < outLength && j < t.size(); j++) { Preconditions.checkState(vocabMap.containsKey(t.get(j)), "Unknown token encountered: token \"%s\" is not in vocabulary", t.get(j)); int idx = vocabMap.get(t.get(j)); outIdxs[i][j] = idx; outMask[i][j] = 1; + if (segIdOnesFrom != null && j >= segIdOnesFrom[i]) + outSegmentId[i][j] = 1; } } @@ -260,8 +313,7 @@ public class BertIterator implements MultiDataSetIterator { INDArray[] f; INDArray[] fm; if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID) { - //For now: always segment index 0 (only single s sequence input supported) - outSegmentIdArr = Nd4j.zeros(DataType.INT, mbPadded, outLength); + outSegmentIdArr = Nd4j.createFromArray(outSegmentId); f = new INDArray[]{outIdxsArr, outSegmentIdArr}; fm = new INDArray[]{outMaskArr, null}; } else { @@ -271,16 +323,15 @@ public class BertIterator implements MultiDataSetIterator { return new Pair<>(f, fm); } - private Pair, String>>> tokenizeMiniBatch(List> list) { + private SentenceListProcessed tokenizeMiniBatch(List> list) { //Get and tokenize the sentences for this minibatch - List, String>> tokenizedSentences = new ArrayList<>(list.size()); + SentenceListProcessed sentenceListProcessed = new SentenceListProcessed(list.size()); int longestSeq = -1; for (Pair p : list) { List tokens = tokenizeSentence(p.getFirst()); - tokenizedSentences.add(new Pair<>(tokens, p.getSecond())); + sentenceListProcessed.addProcessedToList(new Pair<>(tokens, p.getSecond())); longestSeq = Math.max(longestSeq, tokens.size()); } - //Determine output array length... 
int outLength; switch (lengthHandling) { @@ -296,7 +347,52 @@ public class BertIterator implements MultiDataSetIterator { default: throw new RuntimeException("Not implemented length handling mode: " + lengthHandling); } - return new Pair<>(outLength, tokenizedSentences); + sentenceListProcessed.setMaxL(outLength); + return sentenceListProcessed; + } + + private SentencePairListProcessed tokenizePairsMiniBatch(List> listPairs) { + SentencePairListProcessed sentencePairListProcessed = new SentencePairListProcessed(listPairs.size()); + for (Triple t : listPairs) { + List tokensL = tokenizeSentence(t.getFirst(), true); + List tokensR = tokenizeSentence(t.getSecond(), true); + List tokens = new ArrayList<>(maxTokens); + int maxLength = maxTokens; + if (prependToken != null) + maxLength--; + if (appendToken != null) + maxLength -= 2; + if (tokensL.size() + tokensR.size() > maxLength) { + boolean shortOnL = tokensL.size() < tokensR.size(); + int shortSize = Math.min(tokensL.size(), tokensR.size()); + if (shortSize > maxLength / 2) { + //both lists need to be sliced + tokensL.subList(maxLength / 2, tokensL.size()).clear(); //if maxsize/2 is odd pop extra on L side to match implementation in TF + tokensR.subList(maxLength - maxLength / 2, tokensR.size()).clear(); + } else { + //slice longer list + if (shortOnL) { + //longer on R - slice R + tokensR.subList(maxLength - tokensL.size(), tokensR.size()).clear(); + } else { + //longer on L - slice L + tokensL.subList(maxLength - tokensR.size(), tokensL.size()).clear(); + } + } + } + if (prependToken != null) + tokens.add(prependToken); + tokens.addAll(tokensL); + if (appendToken != null) + tokens.add(appendToken); + int segIdOnesFrom = tokens.size(); + tokens.addAll(tokensR); + if (appendToken != null) + tokens.add(appendToken); + sentencePairListProcessed.addProcessedToList(segIdOnesFrom, new Pair<>(tokens, t.getThird())); + } + sentencePairListProcessed.setMaxL(maxTokens); + return sentencePairListProcessed; } private Pair convertMiniBatchLabels(List, String>> tokenizedSentences, INDArray[] featureArray, int outLength) { @@ -316,6 +412,14 @@ public class BertIterator implements MultiDataSetIterator { classLabels[i] = labels.indexOf(lbl); Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl); } + } else if (sentencePairProvider != null) { + numClasses = sentencePairProvider.numLabelClasses(); + List labels = sentencePairProvider.allLabels(); + for (int i = 0; i < mbSize; i++) { + String lbl = tokenizedSentences.get(i).getRight(); + classLabels[i] = labels.indexOf(lbl); + Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl); + } } else { throw new RuntimeException(); } @@ -392,16 +496,22 @@ public class BertIterator implements MultiDataSetIterator { } private List tokenizeSentence(String sentence) { + return tokenizeSentence(sentence, false); + } + + private List tokenizeSentence(String sentence, boolean ignorePrependAppend) { Tokenizer t = tokenizerFactory.create(sentence); List tokens = new ArrayList<>(); - if (prependToken != null) + if (prependToken != null && !ignorePrependAppend) tokens.add(prependToken); while (t.hasMoreTokens()) { String token = t.nextToken(); tokens.add(token); } + if (appendToken != null && !ignorePrependAppend) + tokens.add(appendToken); return tokens; } @@ -414,6 +524,13 @@ public class BertIterator implements MultiDataSetIterator { return list; } + private List> 
addDummyLabelForPairs(List> listOnlySentencePairs) { + List> list = new ArrayList<>(listOnlySentencePairs.size()); + for (Pair p : listOnlySentencePairs) { + list.add(new Triple(p.getFirst(), p.getSecond(), null)); + } + return list; + } @Override public boolean resetSupported() { @@ -446,12 +563,14 @@ public class BertIterator implements MultiDataSetIterator { protected boolean padMinibatches = false; protected MultiDataSetPreProcessor preProcessor; protected LabeledSentenceProvider sentenceProvider = null; + protected LabeledPairSentenceProvider sentencePairProvider = null; protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID; protected Map vocabMap; //TODO maybe use Eclipse ObjectIntHashMap for fewer objects? protected BertSequenceMasker masker = new BertMaskedLMMasker(); protected UnsupervisedLabelFormat unsupervisedLabelFormat; protected String maskToken; protected String prependToken; + protected String appendToken; /** * Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details. @@ -519,14 +638,21 @@ public class BertIterator implements MultiDataSetIterator { } /** - * Specify the source of the data for classification. Can also be used for unsupervised learning; in the unsupervised - * use case, the labels will be ignored. + * Specify the source of the data for classification. */ public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider) { this.sentenceProvider = sentenceProvider; return this; } + /** + * Specify the source of the data for classification on sentence pairs. + */ + public Builder sentencePairProvider(LabeledPairSentenceProvider sentencePairProvider) { + this.sentencePairProvider = sentencePairProvider; + return this; + } + /** * Specify what arrays should be returned. See {@link BertIterator} for more details. */ @@ -591,6 +717,19 @@ public class BertIterator implements MultiDataSetIterator { return this; } + /** + * Append the specified token to the sequences, when doing training on sentence pairs.
+         * Generally "[SEP]" is used.
+         * No token is appended by default.
+         *
+         * @param appendToken Token at end of each sentence for pairs of sentences (null: no token will be appended)
+         * @return Builder
+         */
+        public Builder appendToken(String appendToken) {
+            this.appendToken = appendToken;
+            return this;
+        }
+
         public BertIterator build() {
             Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
             Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");
@@ -598,9 +737,69 @@
             Preconditions.checkState(task != Task.UNSUPERVISED || masker != null, "If task is UNSUPERVISED training, a masker must be set via masker(BertSequenceMasker) method");
             Preconditions.checkState(task != Task.UNSUPERVISED || unsupervisedLabelFormat != null, "If task is UNSUPERVISED training, a label format must be set via masker(BertSequenceMasker) method");
             Preconditions.checkState(task != Task.UNSUPERVISED || maskToken != null, "If task is UNSUPERVISED training, the mask token in the vocab (such as \"[MASK]\" must be specified");
-
+            if (sentencePairProvider != null) {
+                Preconditions.checkState(task == Task.SEQ_CLASSIFICATION, "Currently only supervised sequence classification is set up with sentence pairs. \".task(BertIterator.Task.SEQ_CLASSIFICATION)\" is required with a sentence pair provider");
+                Preconditions.checkState(featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID, "Currently only supervised sequence classification is set up with sentence pairs. \".featureArrays(FeatureArrays.INDICES_MASK_SEGMENTID)\" is required with a sentence pair provider");
+                Preconditions.checkState(lengthHandling == LengthHandling.FIXED_LENGTH, "Currently only fixed length is supported for sentence pairs. \".lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxLength)\" is required with a sentence pair provider");
+                Preconditions.checkState(sentenceProvider == null, "Provide either a sentence provider or a sentence pair provider. Both cannot be non null");
+            }
+            if (appendToken != null) {
+                Preconditions.checkState(sentencePairProvider != null, "Tokens are only appended with sentence pairs. Sentence pair provider is not set. 
Set sentence pair provider."); + } return new BertIterator(this); } } + private static class SentencePairListProcessed { + private int listLength = 0; + + @Getter + private long[] segIdOnesFrom; + private int cursor = 0; + private SentenceListProcessed sentenceListProcessed; + + private SentencePairListProcessed(int listLength) { + this.listLength = listLength; + segIdOnesFrom = new long[listLength]; + sentenceListProcessed = new SentenceListProcessed(listLength); + } + + private void addProcessedToList(long segIdIdx, Pair, String> tokenizedSentencePairAndLabel) { + segIdOnesFrom[cursor] = segIdIdx; + sentenceListProcessed.addProcessedToList(tokenizedSentencePairAndLabel); + cursor++; + } + + private void setMaxL(int maxL) { + sentenceListProcessed.setMaxL(maxL); + } + + private int getMaxL() { + return sentenceListProcessed.getMaxL(); + } + + private List, String>> getTokensAndLabelList() { + return sentenceListProcessed.getTokensAndLabelList(); + } + } + + private static class SentenceListProcessed { + private int listLength; + + @Getter + @Setter + private int maxL; + + @Getter + private List, String>> tokensAndLabelList; + + private SentenceListProcessed(int listLength) { + this.listLength = listLength; + tokensAndLabelList = new ArrayList<>(listLength); + } + + private void addProcessedToList(Pair, String> tokenizedSentenceAndLabel) { + tokensAndLabelList.add(tokenizedSentenceAndLabel); + } + } } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/LabeledPairSentenceProvider.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/LabeledPairSentenceProvider.java new file mode 100644 index 000000000..ee68477ee --- /dev/null +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/LabeledPairSentenceProvider.java @@ -0,0 +1,60 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.iterator; + +import org.nd4j.linalg.primitives.Triple; + +import java.util.List; + +/** + * LabeledPairSentenceProvider: a simple iterator interface over a pair of sentences/documents that have a label.
+ */ +public interface LabeledPairSentenceProvider { + + /** + * Are there more sentences/documents available? + */ + boolean hasNext(); + + /** + * @return Triple: two sentence/document texts and label + */ + Triple nextSentencePair(); + + /** + * Reset the iterator - including shuffling the order, if necessary/appropriate + */ + void reset(); + + /** + * Return the total number of sentences, or -1 if not available + */ + int totalNumSentences(); + + /** + * Return the list of labels - this also defines the class/integer label assignment order + */ + List allLabels(); + + /** + * Equivalent to allLabels().size() + */ + int numLabelClasses(); + +} + + diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/provider/CollectionLabeledPairSentenceProvider.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/provider/CollectionLabeledPairSentenceProvider.java new file mode 100644 index 000000000..c3c752bed --- /dev/null +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/provider/CollectionLabeledPairSentenceProvider.java @@ -0,0 +1,135 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.iterator.provider; + +import lombok.NonNull; +import org.deeplearning4j.iterator.LabeledPairSentenceProvider; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.primitives.Triple; +import org.nd4j.linalg.util.MathUtils; + +import java.util.*; + +/** + * Iterate over a pair of sentences/documents, + * where the sentences and labels are provided in lists. 
+ * + */ +public class CollectionLabeledPairSentenceProvider implements LabeledPairSentenceProvider { + + private final List sentenceL; + private final List sentenceR; + private final List labels; + private final Random rng; + private final int[] order; + private final List allLabels; + + private int cursor = 0; + + /** + * Lists containing sentences to iterate over with a third for labels + * Sentences in the same position in the first two lists are considered a pair + * @param sentenceL + * @param sentenceR + * @param labelsForSentences + */ + public CollectionLabeledPairSentenceProvider(@NonNull List sentenceL, @NonNull List sentenceR, + @NonNull List labelsForSentences) { + this(sentenceL, sentenceR, labelsForSentences, new Random()); + } + + /** + * Lists containing sentences to iterate over with a third for labels + * Sentences in the same position in the first two lists are considered a pair + * @param sentenceL + * @param sentenceR + * @param labelsForSentences + * @param rng If null, list order is not shuffled + */ + public CollectionLabeledPairSentenceProvider(@NonNull List sentenceL, List sentenceR, @NonNull List labelsForSentences, + Random rng) { + if (sentenceR.size() != sentenceL.size()) { + throw new IllegalArgumentException("Sentence lists must be same size (first list size: " + + sentenceL.size() + ", second list size: " + sentenceR.size() + ")"); + } + if (sentenceR.size() != labelsForSentences.size()) { + throw new IllegalArgumentException("Sentence pairs and labels must be same size (sentence pair size: " + + sentenceR.size() + ", labels size: " + labelsForSentences.size() + ")"); + } + + this.sentenceL = sentenceL; + this.sentenceR = sentenceR; + this.labels = labelsForSentences; + this.rng = rng; + if (rng == null) { + order = null; + } else { + order = new int[sentenceR.size()]; + for (int i = 0; i < sentenceR.size(); i++) { + order[i] = i; + } + + MathUtils.shuffleArray(order, rng); + } + + //Collect set of unique labels for all sentences + Set uniqueLabels = new HashSet<>(labelsForSentences); + allLabels = new ArrayList<>(uniqueLabels); + Collections.sort(allLabels); + } + + @Override + public boolean hasNext() { + return cursor < sentenceR.size(); + } + + @Override + public Triple nextSentencePair() { + Preconditions.checkState(hasNext(),"No next element available"); + int idx; + if (rng == null) { + idx = cursor++; + } else { + idx = order[cursor++]; + } + return new Triple<>(sentenceL.get(idx), sentenceR.get(idx), labels.get(idx)); + } + + @Override + public void reset() { + cursor = 0; + if (rng != null) { + MathUtils.shuffleArray(order, rng); + } + } + + @Override + public int totalNumSentences() { + return sentenceR.size(); + } + + @Override + public List allLabels() { + return allLabels; + } + + @Override + public int numLabelClasses() { + return allLabels.size(); + } +} + diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/provider/CollectionLabeledSentenceProvider.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/provider/CollectionLabeledSentenceProvider.java index 3dbaa7db8..e6d65b48b 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/provider/CollectionLabeledSentenceProvider.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/iterator/provider/CollectionLabeledSentenceProvider.java @@ -18,6 +18,7 @@ package 
org.deeplearning4j.iterator.provider; import lombok.NonNull; import org.deeplearning4j.iterator.LabeledSentenceProvider; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.MathUtils; @@ -40,15 +41,15 @@ public class CollectionLabeledSentenceProvider implements LabeledSentenceProvide private int cursor = 0; public CollectionLabeledSentenceProvider(@NonNull List sentences, - @NonNull List labelsForSentences) { + @NonNull List labelsForSentences) { this(sentences, labelsForSentences, new Random()); } public CollectionLabeledSentenceProvider(@NonNull List sentences, @NonNull List labelsForSentences, - Random rng) { + Random rng) { if (sentences.size() != labelsForSentences.size()) { throw new IllegalArgumentException("Sentences and labels must be same size (sentences size: " - + sentences.size() + ", labels size: " + labelsForSentences.size() + ")"); + + sentences.size() + ", labels size: " + labelsForSentences.size() + ")"); } this.sentences = sentences; @@ -66,10 +67,7 @@ public class CollectionLabeledSentenceProvider implements LabeledSentenceProvide } //Collect set of unique labels for all sentences - Set uniqueLabels = new HashSet<>(); - for (String s : labelsForSentences) { - uniqueLabels.add(s); - } + Set uniqueLabels = new HashSet<>(labelsForSentences); allLabels = new ArrayList<>(uniqueLabels); Collections.sort(allLabels); } @@ -81,6 +79,7 @@ public class CollectionLabeledSentenceProvider implements LabeledSentenceProvide @Override public Pair nextSentence() { + Preconditions.checkState(hasNext(), "No next element available"); int idx; if (rng == null) { idx = cursor++; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java index d4be5e352..a6716ba40 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -27,6 +28,7 @@ import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.primitives.Pair; +import org.nd4j.linalg.primitives.Triple; import org.nd4j.resources.Resources; import java.io.File; @@ -43,7 +45,8 @@ public class TestBertIterator extends BaseDL4JTest { private File pathToVocab = Resources.asFile("other/vocab.txt"); private static Charset c = StandardCharsets.UTF_8; - public TestBertIterator() throws IOException{ } + public TestBertIterator() throws IOException { + } @Test(timeout = 20000L) public void testBertSequenceClassification() throws Exception { @@ -74,8 +77,8 @@ public class TestBertIterator extends BaseDL4JTest { INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); List tokens = t.create(toTokenize1).getTokens(); - Map m = t.getVocab(); - for( int i=0; i m = t.getVocab(); + for (int i = 0; i < tokens.size(); i++) { int idx = m.get(tokens.get(i)); expEx0.putScalar(0, i, idx); expM0.putScalar(0, i, 1); @@ -84,9 +87,9 @@ public class TestBertIterator extends BaseDL4JTest { INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16); INDArray expM1 = Nd4j.create(DataType.INT, 1, 16); List tokens2 = t.create(toTokenize2).getTokens(); - for( int i=0; i tokens = t.create(toTokenize1).getTokens(); - Map m = t.getVocab(); - for( int i=0; i m = t.getVocab(); + for (int i = 0; i < tokens.size(); i++) { int idx = m.get(tokens.get(i)); expEx0.putScalar(0, i, idx); expM0.putScalar(0, i, 1); @@ -178,9 +184,10 @@ public class TestBertIterator extends BaseDL4JTest { INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16); INDArray expM1 = Nd4j.create(DataType.INT, 1, 16); List tokens2 = t.create(toTokenize2).getTokens(); - for( int i=0; i forInference = new ArrayList<>(); forInference.add(toTokenize1); forInference.add(toTokenize2); + forInference.add(toTokenize3); BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); List tokens = t.create(toTokenize1).getTokens(); - Map m = t.getVocab(); - for( int i=0; i m = t.getVocab(); + for (int i = 0; i < tokens.size(); i++) { int idx = m.get(tokens.get(i)); expEx0.putScalar(0, i, idx); expM0.putScalar(0, i, 1); @@ -253,9 +262,9 @@ public class TestBertIterator extends BaseDL4JTest { INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16); INDArray expM1 = Nd4j.create(DataType.INT, 1, 16); List tokens2 = t.create(toTokenize2).getTokens(); - for( int i=0; i tokens3 = t.create(toTokenize3).getTokens(); + for (int i = 0; i < tokens3.size(); i++) { + String token = tokens3.get(i); + if (!m.containsKey(token)) { + throw new IllegalStateException("Unknown token: \"" + token + "\""); + } + int idx = m.get(token); + expEx3.putScalar(0, i, idx); + expM3.putScalar(0, i, 1); + } - INDArray expF = Nd4j.vstack(expEx0, expEx1, zeros); - INDArray expM = Nd4j.vstack(expM0, expM1, zeros); - INDArray expL = Nd4j.createFromArray(new float[][]{{1, 0}, {0, 1}, {0, 0}, {0, 0}}); + INDArray zeros = Nd4j.create(DataType.INT, 1, 16); + INDArray expF = Nd4j.vstack(expEx0, expEx1, expEx3, zeros); + INDArray expM = Nd4j.vstack(expM0, expM1, expM3, zeros); + INDArray expL = Nd4j.createFromArray(new float[][]{{1, 0}, {0, 1}, {1, 0}, {0, 0}}); INDArray expLM = 
Nd4j.create(DataType.FLOAT, 4, 1); expLM.putScalar(0, 0, 1); expLM.putScalar(1, 0, 1); + expLM.putScalar(2, 0, 1); //-------------------------------------------------------------- @@ -305,9 +327,234 @@ public class TestBertIterator extends BaseDL4JTest { assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]); } + @Test + public void testSentencePairsSingle() throws IOException { + String shortSent = "I saw a girl with a telescope."; + String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; + boolean prependAppend; + BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); + int shortL = t.create(shortSent).countTokens(); + int longL = t.create(longSent).countTokens(); + + Triple multiDataSetTriple; + MultiDataSet shortLongPair, shortSentence, longSentence; + + // check for pair max length exactly equal to sum of lengths - pop neither no padding + // should be the same as hstack with segment ids 1 for second sentence + prependAppend = true; + multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL, shortL, longL), prependAppend); + shortLongPair = multiDataSetTriple.getFirst(); + shortSentence = multiDataSetTriple.getSecond(); + longSentence = multiDataSetTriple.getThird(); + assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0))); + longSentence.getFeatures(1).addi(1); + assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1))); + assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0))); + + //check for pair max length greater than sum of lengths - pop neither with padding + // features should be the same as hstack of shorter and longer padded with prepend/append + // segment id should 1 only in the longer for part of the length of the sentence + prependAppend = true; + multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL + 5, shortL, longL + 5), prependAppend); + shortLongPair = multiDataSetTriple.getFirst(); + shortSentence = multiDataSetTriple.getSecond(); + longSentence = multiDataSetTriple.getThird(); + assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0))); + longSentence.getFeatures(1).get(NDArrayIndex.all(), NDArrayIndex.interval(0, longL + 1)).addi(1); //segmentId stays 0 for the padded part + assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1))); + assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0))); + + //check for pair max length less than shorter sentence - pop both + //should be the same as hstack with segment ids 1 for second sentence if no prepend/append + int maxL = shortL - 2; + prependAppend = false; + multiDataSetTriple = generateMultiDataSets(new Triple<>(maxL, maxL / 2, maxL - maxL / 2), prependAppend); + shortLongPair = multiDataSetTriple.getFirst(); + shortSentence = multiDataSetTriple.getSecond(); + longSentence = multiDataSetTriple.getThird(); + assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0))); + longSentence.getFeatures(1).addi(1); + assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1))); + 
assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0))); + } + + @Test + public void testSentencePairsUnequalLengths() throws IOException { + //check for pop only longer (i.e between longer and longer + shorter), first row pop from second sentence, next row pop from first sentence, nothing to pop in the third row + //should be identical to hstack if there is no append, prepend + //batch size is 2 + int mbS = 4; + String shortSent = "I saw a girl with a telescope."; + String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; + String sent1 = "Goodnight noises everywhere"; //shorter than shortSent - no popping + String sent2 = "Goodnight moon"; //shorter than shortSent - no popping + BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); + int shortL = t.create(shortSent).countTokens(); + int longL = t.create(longSent).countTokens(); + int sent1L = t.create(sent1).countTokens(); + int sent2L = t.create(sent2).countTokens(); + //won't check 2*shortL + 1 because this will always pop on the left + for (int maxL = longL + shortL - 1; maxL > 2 * shortL; maxL--) { + MultiDataSet leftMDS = BertIterator.builder() + .tokenizer(t) + .minibatchSize(mbS) + .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) + .vocabMap(t.getVocab()) + .task(BertIterator.Task.SEQ_CLASSIFICATION) + .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either + .sentenceProvider(new TestSentenceProvider()) + .padMinibatches(true) + .build().next(); + + MultiDataSet rightMDS = BertIterator.builder() + .tokenizer(t) + .minibatchSize(mbS) + .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) + .vocabMap(t.getVocab()) + .task(BertIterator.Task.SEQ_CLASSIFICATION) + .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either + .sentenceProvider(new TestSentenceProvider(true)) + .padMinibatches(true) + .build().next(); + + MultiDataSet pairMDS = BertIterator.builder() + .tokenizer(t) + .minibatchSize(mbS) + .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) + .vocabMap(t.getVocab()) + .task(BertIterator.Task.SEQ_CLASSIFICATION) + .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxL) //random big num guaranteed to be longer than either + .sentencePairProvider(new TestSentencePairProvider()) + .padMinibatches(true) + .build().next(); + + //Left sentences here are {{shortSent}, + // {longSent}, + // {Sent1}} + //Right sentences here are {{longSent}, + // {shortSent}, + // {Sent2}} + //The sentence pairs here are {{shortSent,longSent}, + // {longSent,shortSent} + // {Sent1, Sent2}} + + //CHECK FEATURES + INDArray combinedFeat = Nd4j.create(DataType.INT,mbS,maxL); + //left side + INDArray leftFeatures = leftMDS.getFeatures(0); + INDArray topLSentFeat = leftFeatures.getRow(0).get(NDArrayIndex.interval(0, shortL)); + INDArray midLSentFeat = leftFeatures.getRow(1).get(NDArrayIndex.interval(0, maxL - shortL)); + INDArray bottomLSentFeat = leftFeatures.getRow(2).get(NDArrayIndex.interval(0,sent1L)); + //right side + INDArray rightFeatures = rightMDS.getFeatures(0); + INDArray topRSentFeat = rightFeatures.getRow(0).get(NDArrayIndex.interval(0, maxL - shortL)); + INDArray midRSentFeat = rightFeatures.getRow(1).get(NDArrayIndex.interval(0, shortL)); + INDArray bottomRSentFeat = 
rightFeatures.getRow(2).get(NDArrayIndex.interval(0,sent2L)); + //expected pair + combinedFeat.getRow(0).addi(Nd4j.hstack(topLSentFeat,topRSentFeat)); + combinedFeat.getRow(1).addi(Nd4j.hstack(midLSentFeat,midRSentFeat)); + combinedFeat.getRow(2).get(NDArrayIndex.interval(0,sent1L+sent2L)).addi(Nd4j.hstack(bottomLSentFeat,bottomRSentFeat)); + + assertEquals(maxL, pairMDS.getFeatures(0).shape()[1]); + assertArrayEquals(combinedFeat.shape(), pairMDS.getFeatures(0).shape()); + assertEquals(combinedFeat, pairMDS.getFeatures(0)); + + //CHECK SEGMENT ID + INDArray combinedFetSeg = Nd4j.create(DataType.INT, mbS, maxL); + combinedFetSeg.get(NDArrayIndex.point(0), NDArrayIndex.interval(shortL, maxL)).addi(1); + combinedFetSeg.get(NDArrayIndex.point(1), NDArrayIndex.interval(maxL - shortL, maxL)).addi(1); + combinedFetSeg.get(NDArrayIndex.point(2), NDArrayIndex.interval(sent1L, sent1L+sent2L)).addi(1); + assertArrayEquals(combinedFetSeg.shape(), pairMDS.getFeatures(1).shape()); + assertEquals(maxL, combinedFetSeg.shape()[1]); + assertEquals(combinedFetSeg, pairMDS.getFeatures(1)); + } + } + + @Test + public void testSentencePairFeaturizer() throws IOException { + String shortSent = "I saw a girl with a telescope."; + String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; + List> listSentencePair = new ArrayList<>(); + listSentencePair.add(new Pair<>(shortSent, longSent)); + listSentencePair.add(new Pair<>(longSent, shortSent)); + BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); + BertIterator b = BertIterator.builder() + .tokenizer(t) + .minibatchSize(2) + .padMinibatches(true) + .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) + .vocabMap(t.getVocab()) + .task(BertIterator.Task.SEQ_CLASSIFICATION) + .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 128) + .sentencePairProvider(new TestSentencePairProvider()) + .prependToken("[CLS]") + .appendToken("[SEP]") + .build(); + MultiDataSet mds = b.next(); + INDArray[] featuresArr = mds.getFeatures(); + INDArray[] featuresMaskArr = mds.getFeaturesMaskArrays(); + + Pair p = b.featurizeSentencePairs(listSentencePair); + assertEquals(p.getFirst().length, 2); + assertEquals(featuresArr[0], p.getFirst()[0]); + assertEquals(featuresArr[1], p.getFirst()[1]); + //assertEquals(p.getSecond().length, 2); + assertEquals(featuresMaskArr[0], p.getSecond()[0]); + //assertEquals(featuresMaskArr[1], p.getSecond()[1]); + } + + /** + * Returns three multidatasets from bert iterator based on given max lengths and whether to prepend/append + * Idea is the sentence pair dataset can be constructed from the single sentence datasets + * First one is constructed from a sentence pair "I saw a girl with a telescope." & "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum" + * Second one is constructed from the left of the sentence pair i.e "I saw a girl with a telescope." 
+ * Third one is constructed from the right of the sentence pair i.e "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum" + */ + private Triple generateMultiDataSets(Triple maxLengths, boolean prependAppend) throws IOException { + BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); + int maxforPair = maxLengths.getFirst(); + int maxPartOne = maxLengths.getSecond(); + int maxPartTwo = maxLengths.getThird(); + BertIterator.Builder commonBuilder; + commonBuilder = BertIterator.builder() + .tokenizer(t) + .minibatchSize(1) + .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) + .vocabMap(t.getVocab()) + .task(BertIterator.Task.SEQ_CLASSIFICATION); + BertIterator shortLongPairFirstIter = commonBuilder + .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxforPair + 3 : maxforPair) + .sentencePairProvider(new TestSentencePairProvider()) + .prependToken(prependAppend ? "[CLS]" : null) + .appendToken(prependAppend ? "[SEP]" : null) + .build(); + BertIterator shortFirstIter = commonBuilder + .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartOne + 2 : maxPartOne) + .sentenceProvider(new TestSentenceProvider()) + .prependToken(prependAppend ? "[CLS]" : null) + .appendToken(prependAppend ? "[SEP]" : null) + .build(); + BertIterator longFirstIter = commonBuilder + .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartTwo + 1 : maxPartTwo) + .sentenceProvider(new TestSentenceProvider(true)) + .prependToken(null) + .appendToken(prependAppend ? "[SEP]" : null) + .build(); + return new Triple<>(shortLongPairFirstIter.next(), shortFirstIter.next(), longFirstIter.next()); + } + private static class TestSentenceProvider implements LabeledSentenceProvider { private int pos = 0; + private boolean invert; + + private TestSentenceProvider() { + this.invert = false; + } + + private TestSentenceProvider(boolean invert) { + this.invert = invert; + } @Override public boolean hasNext() { @@ -317,10 +564,20 @@ public class TestBertIterator extends BaseDL4JTest { @Override public Pair nextSentence() { Preconditions.checkState(hasNext()); - if(pos++ == 0){ - return new Pair<>("I saw a girl with a telescope.", "positive"); - } else { + if (pos == 0) { + pos++; + if (!invert) return new Pair<>("I saw a girl with a telescope.", "positive"); return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative"); + } else { + if (pos == 1) { + pos++; + if (!invert) return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative"); + return new Pair<>("I saw a girl with a telescope.", "positive"); + } + pos++; + if (!invert) + return new Pair<>("Goodnight noises everywhere", "positive"); + return new Pair<>("Goodnight moon", "positive"); } } @@ -331,8 +588,54 @@ public class TestBertIterator extends BaseDL4JTest { @Override public int totalNumSentences() { + return 3; + } + + @Override + public List allLabels() { + return Arrays.asList("positive", "negative"); + } + + @Override + public int numLabelClasses() { return 2; } + } + + private static class TestSentencePairProvider implements LabeledPairSentenceProvider { + + private int pos = 0; + + @Override + public boolean hasNext() { + return pos < totalNumSentences(); + } + + @Override + public Triple nextSentencePair() { + Preconditions.checkState(hasNext()); + if (pos == 0) { + pos++; + return new Triple<>("I saw a girl with a telescope.", "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", 
"positive"); + } else { + if (pos == 1) { + pos++; + return new Triple<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "I saw a girl with a telescope.", "negative"); + } + pos++; + return new Triple<>("Goodnight noises everywhere", "Goodnight moon", "positive"); + } + } + + @Override + public void reset() { + pos = 0; + } + + @Override + public int totalNumSentences() { + return 3; + } @Override public List allLabels() {