diff --git a/libnd4j/include/graph/Context.h b/libnd4j/include/graph/Context.h index 96079e5a2..d5b85b543 100644 --- a/libnd4j/include/graph/Context.h +++ b/libnd4j/include/graph/Context.h @@ -190,6 +190,10 @@ namespace nd4j { void setIArguments(Nd4jLong *arguments, int numberOfArguments); void setBArguments(bool *arguments, int numberOfArguments); + void setTArguments(const std::vector<double> &tArgs); + void setIArguments(const std::vector<Nd4jLong> &tArgs); + void setBArguments(const std::vector<bool> &tArgs); + void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer); diff --git a/libnd4j/include/graph/impl/Context.cpp b/libnd4j/include/graph/impl/Context.cpp index b18d3f347..146e66067 100644 --- a/libnd4j/include/graph/impl/Context.cpp +++ b/libnd4j/include/graph/impl/Context.cpp @@ -469,6 +469,21 @@ namespace nd4j { bool Context::helpersAllowed() { return _helpersAllowed; } + + void Context::setTArguments(const std::vector<double> &tArgs) { + for (auto t:tArgs) + _tArgs.emplace_back(t); + } + + void Context::setIArguments(const std::vector<Nd4jLong> &iArgs) { + for (auto i:iArgs) + _iArgs.emplace_back(i); + } + + void Context::setBArguments(const std::vector<bool> &bArgs) { + for (auto b:bArgs) + _bArgs.emplace_back(b); + } } } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp index 538214b14..cc11eedca 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp @@ -27,12 +27,14 @@ namespace nd4j { namespace ops { -CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, 1, 0) { +CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, -2, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - const double factor = T_ARG(0); + REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST: Scale factor required"); + + const double factor = 
block.width() > 1 ? INPUT_VARIABLE(1)->e<double>(0) : T_ARG(0); REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST: op expects rank of input array to be >= 3, but got %i instead", input->rankOf()); REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1)); @@ -59,15 +61,17 @@ DECLARE_TYPES(adjust_contrast) { } - CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, 1, 0) { + CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, -2, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - const double factor = T_ARG(0); + REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST_V2: Scale factor required"); - REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST: op expects rank of input array to be >= 3, but got %i instead", input->rankOf()); - REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1)); + const double factor = block.width() > 1 ? 
INPUT_VARIABLE(1)->e<double>(0) : T_ARG(0); + + REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST_V2: op expects rank of input array to be >= 3, but got %i instead", input->rankOf()); + REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST_V2: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1)); // compute mean before std::vector<int> axes(input->rankOf() - 1); @@ -78,10 +82,10 @@ DECLARE_TYPES(adjust_contrast) { auto mean = input->reduceAlongDims(reduce::Mean, axes); // result as (x - mean) * factor + mean - std::unique_ptr<NDArray> temp(input->dup()); - input->applyTrueBroadcast(BroadcastOpsTuple::Subtract(), &mean, temp.get()); - temp->applyScalar(scalar::Multiply, factor); - temp->applyTrueBroadcast(BroadcastOpsTuple::Add(), &mean, output); + auto temp = input->ulike(); + input->applyTrueBroadcast(BroadcastOpsTuple::Subtract(), &mean, &temp); + temp.applyScalar(scalar::Multiply, factor); + temp.applyTrueBroadcast(BroadcastOpsTuple::Add(), &mean, output); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index 3660ee229..590d99308 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -610,8 +610,8 @@ namespace nd4j { * */ #if NOT_EXCLUDED(OP_adjust_contrast) - DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, 1, 0); - DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, 1, 0); + DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, -2, 0); + DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, -2, 0); #endif diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp index 1a459a012..d29d1f0e1 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp @@ -161,4 +161,29 @@ TEST_F(DeclarableOpsTests16, test_empty_cast_1) { ASSERT_EQ(e, 
*result->at(0)); delete result; +} + +TEST_F(DeclarableOpsTests16, test_range_1) { + nd4j::ops::range op; + auto z = NDArrayFactory::create<float>('c', {200}); + + Context ctx(1); + ctx.setTArguments({-1.0, 1.0, 0.01}); + ctx.setOutputArray(0, &z); + + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(DeclarableOpsTests16, test_range_2) { + nd4j::ops::range op; + auto z = NDArrayFactory::create<float>('c', {200}); + + double tArgs[] = {-1.0, 1.0, 0.01}; + + auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0); + shape::printShapeInfoLinear("Result", shapes->at(0)); + ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0))); + + delete shapes; } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java index ee0adfb94..43bff11e6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java @@ -1,25 +1,54 @@ package org.nd4j.linalg.api.ops.custom; +import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; import org.nd4j.linalg.factory.Nd4j; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; + +import java.util.Map; public class BitCast extends DynamicCustomOp { public BitCast() {} + public BitCast(INDArray in, DataType dataType, INDArray out) { + this(in, 
dataType.toInt(), out); + } + public BitCast(INDArray in, int dataType, INDArray out) { inputArguments.add(in); outputArguments.add(out); iArguments.add(Long.valueOf(dataType)); } + public BitCast(INDArray in, DataType dataType) { + this(in, dataType.toInt()); + } + + public BitCast(INDArray in, int dataType) { + inputArguments.add(in); + iArguments.add(Long.valueOf(dataType)); + } + public BitCast(SameDiff sameDiff, SDVariable in, SDVariable dataType) { super("", sameDiff, new SDVariable[]{in, dataType}); } + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + val t = nodeDef.getAttrOrDefault("type", null); + val type = ArrayOptionsHelper.convertToDataType(t.getType()); + addIArgument(type.toInt()); + } + @Override public String opName() { return "bitcast"; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 8fe744b38..20f2b5f22 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -2226,7 +2226,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { cnt = 0; for (val t: op.tArgs()) - tArgs.put(cnt++, (float) t); + tArgs.put(cnt++, t); OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments()); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java 
b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 22b2068d4..e8b5e15c9 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -6754,6 +6754,15 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { public native void setBArguments(@Cast("bool*") BooleanPointer arguments, int numberOfArguments); public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments); + public native void setTArguments(@StdVector DoublePointer tArgs); + public native void setTArguments(@StdVector DoubleBuffer tArgs); + public native void setTArguments(@StdVector double[] tArgs); + public native void setIArguments(@Cast("Nd4jLong*") @StdVector LongPointer tArgs); + public native void setIArguments(@Cast("Nd4jLong*") @StdVector LongBuffer tArgs); + public native void setIArguments(@Cast("Nd4jLong*") @StdVector long[] tArgs); + public native void setBArguments(@Cast("bool*") @StdVector BooleanPointer tArgs); + public native void setBArguments(@Cast("bool*") @StdVector boolean[] tArgs); + public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index d99a8240a..e2e9b0c2f 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -6754,6 +6754,15 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { public native void setBArguments(@Cast("bool*") BooleanPointer arguments, int 
numberOfArguments); public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments); + public native void setTArguments(@StdVector DoublePointer tArgs); + public native void setTArguments(@StdVector DoubleBuffer tArgs); + public native void setTArguments(@StdVector double[] tArgs); + public native void setIArguments(@Cast("Nd4jLong*") @StdVector LongPointer tArgs); + public native void setIArguments(@Cast("Nd4jLong*") @StdVector LongBuffer tArgs); + public native void setIArguments(@Cast("Nd4jLong*") @StdVector long[] tArgs); + public native void setBArguments(@Cast("bool*") @StdVector BooleanPointer tArgs); + public native void setBArguments(@Cast("bool*") @StdVector boolean[] tArgs); + public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index ad38f39d7..556405c14 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -931,4 +931,36 @@ public class CustomOpsTests extends BaseNd4jTest { Nd4j.exec(new KnnMinDistance(point, lowest, highest, distance)); System.out.println(distance); } + + + @Test + public void testRange(){ + DynamicCustomOp op = DynamicCustomOp.builder("range") + .addFloatingPointArguments(-1.0, 1.0, 0.01) + .build(); + + List<LongShapeDescriptor> lsd = op.calculateOutputShape(); + //System.out.println("Calculated output shape: " + Arrays.toString(lsd.get(0).getShape())); + op.setOutputArgument(0, Nd4j.create(lsd.get(0))); + + Nd4j.exec(op); + } + + @Test + public void testBitCastShape_1(){ + val out = Nd4j.createUninitialized(1,10); + BitCast op = new BitCast(Nd4j.zeros(DataType.FLOAT,1,10), 
DataType.INT.toInt(), out); + List<LongShapeDescriptor> lsd = op.calculateOutputShape(); + assertEquals(1, lsd.size()); + assertArrayEquals(new long[]{1,10}, lsd.get(0).getShape()); + } + + @Test + public void testBitCastShape_2(){ + val out = Nd4j.createUninitialized(1,10); + BitCast op = new BitCast(Nd4j.zeros(DataType.DOUBLE,1,10), DataType.INT.toInt(), out); + List<LongShapeDescriptor> lsd = op.calculateOutputShape(); + assertEquals(1, lsd.size()); + assertArrayEquals(new long[]{1,10, 2}, lsd.get(0).getShape()); + } }