[WIP] Small fixes here and there (#50)

* one range test

Signed-off-by: raver119 <raver119@gmail.com>

* few Context convenience singatures

Signed-off-by: raver119 <raver119@gmail.com>

* one more range test

Signed-off-by: raver119 <raver119@gmail.com>

* "range" "fix"

Signed-off-by: raver119 <raver119@gmail.com>

* adjuct_contrast_v2 now allows scale factor to be provided via input_variable

Signed-off-by: raver119 <raver119@gmail.com>

* adjust_contrast now allows scale factor as variable too

Signed-off-by: raver119 <raver119@gmail.com>

* bitcast shape tests

Signed-off-by: raver119 <raver119@gmail.com>

* BitCast import dtype added

Signed-off-by: raver119 <raver119@gmail.com>

* few more BitCast signatures

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-11-15 17:04:29 +03:00 committed by GitHub
parent d7718c28fb
commit 1780dcc883
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 140 additions and 13 deletions

View File

@ -190,6 +190,10 @@ namespace nd4j {
void setIArguments(Nd4jLong *arguments, int numberOfArguments); void setIArguments(Nd4jLong *arguments, int numberOfArguments);
void setBArguments(bool *arguments, int numberOfArguments); void setBArguments(bool *arguments, int numberOfArguments);
void setTArguments(const std::vector<double> &tArgs);
void setIArguments(const std::vector<Nd4jLong> &tArgs);
void setBArguments(const std::vector<bool> &tArgs);
void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer); void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer);

View File

@ -469,6 +469,21 @@ namespace nd4j {
bool Context::helpersAllowed() { bool Context::helpersAllowed() {
return _helpersAllowed; return _helpersAllowed;
} }
void Context::setTArguments(const std::vector<double> &tArgs) {
for (auto t:tArgs)
_tArgs.emplace_back(t);
}
void Context::setIArguments(const std::vector<Nd4jLong> &iArgs) {
for (auto i:iArgs)
_iArgs.emplace_back(i);
}
void Context::setBArguments(const std::vector<bool> &bArgs) {
for (auto b:bArgs)
_bArgs.emplace_back(b);
}
} }
} }

View File

@ -27,12 +27,14 @@
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, 1, 0) { CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, -2, 0) {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
const double factor = T_ARG(0); REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST: Scale factor required");
const double factor = block.width() > 1 ? INPUT_VARIABLE(1)->e<double>(0) : T_ARG(0);
REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST: op expects rank of input array to be >= 3, but got %i instead", input->rankOf()); REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST: op expects rank of input array to be >= 3, but got %i instead", input->rankOf());
REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1)); REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1));
@ -59,15 +61,17 @@ DECLARE_TYPES(adjust_contrast) {
} }
CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, 1, 0) { CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, -2, 0) {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
const double factor = T_ARG(0); REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST_V2: Scale factor required");
REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST: op expects rank of input array to be >= 3, but got %i instead", input->rankOf()); const double factor = block.width() > 1 ? INPUT_VARIABLE(1)->e<double>(0) : T_ARG(0);
REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1));
REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST_V2: op expects rank of input array to be >= 3, but got %i instead", input->rankOf());
REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST_V2: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1));
// compute mean before // compute mean before
std::vector<int> axes(input->rankOf() - 1); std::vector<int> axes(input->rankOf() - 1);
@ -78,10 +82,10 @@ DECLARE_TYPES(adjust_contrast) {
auto mean = input->reduceAlongDims(reduce::Mean, axes); auto mean = input->reduceAlongDims(reduce::Mean, axes);
// result as (x - mean) * factor + mean // result as (x - mean) * factor + mean
std::unique_ptr<NDArray> temp(input->dup()); auto temp = input->ulike();
input->applyTrueBroadcast(BroadcastOpsTuple::Subtract(), &mean, temp.get()); input->applyTrueBroadcast(BroadcastOpsTuple::Subtract(), &mean, &temp);
temp->applyScalar(scalar::Multiply, factor); temp.applyScalar(scalar::Multiply, factor);
temp->applyTrueBroadcast(BroadcastOpsTuple::Add(), &mean, output); temp.applyTrueBroadcast(BroadcastOpsTuple::Add(), &mean, output);
return Status::OK(); return Status::OK();
} }

View File

@ -610,8 +610,8 @@ namespace nd4j {
* *
*/ */
#if NOT_EXCLUDED(OP_adjust_contrast) #if NOT_EXCLUDED(OP_adjust_contrast)
DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, 1, 0); DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, -2, 0);
DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, 1, 0); DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, -2, 0);
#endif #endif

View File

@ -162,3 +162,28 @@ TEST_F(DeclarableOpsTests16, test_empty_cast_1) {
delete result; delete result;
} }
TEST_F(DeclarableOpsTests16, test_range_1) {
nd4j::ops::range op;
auto z = NDArrayFactory::create<float>('c', {200});
Context ctx(1);
ctx.setTArguments({-1.0, 1.0, 0.01});
ctx.setOutputArray(0, &z);
auto status = op.execute(&ctx);
ASSERT_EQ(Status::OK(), status);
}
TEST_F(DeclarableOpsTests16, test_range_2) {
nd4j::ops::range op;
auto z = NDArrayFactory::create<float>('c', {200});
double tArgs[] = {-1.0, 1.0, 0.01};
auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0);
shape::printShapeInfoLinear("Result", shapes->at(0));
ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0)));
delete shapes;
}

View File

@ -1,25 +1,54 @@
package org.nd4j.linalg.api.ops.custom; package org.nd4j.linalg.api.ops.custom;
import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.Map;
public class BitCast extends DynamicCustomOp { public class BitCast extends DynamicCustomOp {
public BitCast() {} public BitCast() {}
public BitCast(INDArray in, DataType dataType, INDArray out) {
this(in, dataType.toInt(), out);
}
public BitCast(INDArray in, int dataType, INDArray out) { public BitCast(INDArray in, int dataType, INDArray out) {
inputArguments.add(in); inputArguments.add(in);
outputArguments.add(out); outputArguments.add(out);
iArguments.add(Long.valueOf(dataType)); iArguments.add(Long.valueOf(dataType));
} }
public BitCast(INDArray in, DataType dataType) {
this(in, dataType.toInt());
}
public BitCast(INDArray in, int dataType) {
inputArguments.add(in);
iArguments.add(Long.valueOf(dataType));
}
public BitCast(SameDiff sameDiff, SDVariable in, SDVariable dataType) { public BitCast(SameDiff sameDiff, SDVariable in, SDVariable dataType) {
super("", sameDiff, new SDVariable[]{in, dataType}); super("", sameDiff, new SDVariable[]{in, dataType});
} }
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
val t = nodeDef.getAttrOrDefault("type", null);
val type = ArrayOptionsHelper.convertToDataType(t.getType());
addIArgument(type.toInt());
}
@Override @Override
public String opName() { public String opName() {
return "bitcast"; return "bitcast";

View File

@ -2226,7 +2226,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
cnt = 0; cnt = 0;
for (val t: op.tArgs()) for (val t: op.tArgs())
tArgs.put(cnt++, (float) t); tArgs.put(cnt++, t);
OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments()); OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments());

View File

@ -6754,6 +6754,15 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) {
public native void setBArguments(@Cast("bool*") BooleanPointer arguments, int numberOfArguments); public native void setBArguments(@Cast("bool*") BooleanPointer arguments, int numberOfArguments);
public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments); public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments);
public native void setTArguments(@StdVector DoublePointer tArgs);
public native void setTArguments(@StdVector DoubleBuffer tArgs);
public native void setTArguments(@StdVector double[] tArgs);
public native void setIArguments(@Cast("Nd4jLong*") @StdVector LongPointer tArgs);
public native void setIArguments(@Cast("Nd4jLong*") @StdVector LongBuffer tArgs);
public native void setIArguments(@Cast("Nd4jLong*") @StdVector long[] tArgs);
public native void setBArguments(@Cast("bool*") @StdVector BooleanPointer tArgs);
public native void setBArguments(@Cast("bool*") @StdVector boolean[] tArgs);
public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer); public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer);

View File

@ -6754,6 +6754,15 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) {
public native void setBArguments(@Cast("bool*") BooleanPointer arguments, int numberOfArguments); public native void setBArguments(@Cast("bool*") BooleanPointer arguments, int numberOfArguments);
public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments); public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments);
public native void setTArguments(@StdVector DoublePointer tArgs);
public native void setTArguments(@StdVector DoubleBuffer tArgs);
public native void setTArguments(@StdVector double[] tArgs);
public native void setIArguments(@Cast("Nd4jLong*") @StdVector LongPointer tArgs);
public native void setIArguments(@Cast("Nd4jLong*") @StdVector LongBuffer tArgs);
public native void setIArguments(@Cast("Nd4jLong*") @StdVector long[] tArgs);
public native void setBArguments(@Cast("bool*") @StdVector BooleanPointer tArgs);
public native void setBArguments(@Cast("bool*") @StdVector boolean[] tArgs);
public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer); public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer);

View File

@ -931,4 +931,36 @@ public class CustomOpsTests extends BaseNd4jTest {
Nd4j.exec(new KnnMinDistance(point, lowest, highest, distance)); Nd4j.exec(new KnnMinDistance(point, lowest, highest, distance));
System.out.println(distance); System.out.println(distance);
} }
@Test
public void testRange(){
DynamicCustomOp op = DynamicCustomOp.builder("range")
.addFloatingPointArguments(-1.0, 1.0, 0.01)
.build();
List<LongShapeDescriptor> lsd = op.calculateOutputShape();
//System.out.println("Calculated output shape: " + Arrays.toString(lsd.get(0).getShape()));
op.setOutputArgument(0, Nd4j.create(lsd.get(0)));
Nd4j.exec(op);
}
@Test
public void testBitCastShape_1(){
val out = Nd4j.createUninitialized(1,10);
BitCast op = new BitCast(Nd4j.zeros(DataType.FLOAT,1,10), DataType.INT.toInt(), out);
List<LongShapeDescriptor> lsd = op.calculateOutputShape();
assertEquals(1, lsd.size());
assertArrayEquals(new long[]{1,10}, lsd.get(0).getShape());
}
@Test
public void testBitCastShape_2(){
val out = Nd4j.createUninitialized(1,10);
BitCast op = new BitCast(Nd4j.zeros(DataType.DOUBLE,1,10), DataType.INT.toInt(), out);
List<LongShapeDescriptor> lsd = op.calculateOutputShape();
assertEquals(1, lsd.size());
assertArrayEquals(new long[]{1,10, 2}, lsd.get(0).getShape());
}
} }