Base namespace (#287)

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* up to assign operation.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* fix Imax, IMin.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* concat.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* dynamicPartition

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* new ops up to gte.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* updated review items.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* up to matchCondition.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* up to OneHot.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip. up to permute.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip. up to rank.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip. up to scatterMul.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* resolving code review issues.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip. inclides UnsortedSegment ops.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip. up to stridedSlice.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* fix stridedSlice.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* first pass of SDBaseops.kt complete.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* fix review items.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* put branch in compilable state.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* add NDBaseTest. fix dynamicPartition signature. failed fix of assign.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* make tests public.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* adds tests up to invertedPermutation.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* fix ScalarEquals, Assign.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* updates NDBaseTest.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* updates 'check' comments based on test pass/fail.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* fix scalar ops. Update tests,

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* dev-tools review items. wip.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* dev-tools code review items.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* Test fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* complete review items.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* Comment for logged issue; fix test case

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* More fixes

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* undo changes to Nd4jCpu.java

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* update tests.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* Fixes and regenerate

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small test fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* small fixes to tests.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* Cleanup

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Small CUDAExecutioner fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Small CudaExecutioner fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* Another small CudaExecutioner fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* Another small CudaExecutioner fix

Signed-off-by: Alex Black <blacka101@gmail.com>

Co-authored-by: Robert Altena <Rob@Ra-ai.com>
master
Alex Black 2020-04-20 16:57:00 +10:00 committed by GitHub
parent a5db0e33be
commit 191bda3228
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
71 changed files with 1338 additions and 234 deletions

View File

@ -45,15 +45,14 @@ public class IMax extends BaseIndexAccumulation {
super(x, z, dimensions); super(x, z, dimensions);
} }
public IMax(INDArray x, boolean keepDims, int... dimensions) {
super(x, keepDims, dimensions);
}
public IMax(INDArray x, int... dimensions) { public IMax(INDArray x, int... dimensions) {
super(x, null, dimensions); super(x, null, dimensions);
} }
public IMax(INDArray x, boolean keepDims, int... dimensions) {
super(x, null, dimensions);
this.keepDims = keepDims;
}
@Override @Override
public int opNum() { public int opNum() {

View File

@ -53,6 +53,7 @@ public class IMin extends BaseIndexAccumulation {
} }
@Override @Override
public int opNum() { public int opNum() {
return 1; return 1;

View File

@ -306,4 +306,3 @@ public class Mmul extends DynamicCustomOp {
return Collections.singletonList(dataTypes.get(0)); return Collections.singletonList(dataTypes.get(0));
} }
} }

View File

@ -51,6 +51,35 @@ public class TensorMmul extends DynamicCustomOp {
protected boolean addedEdges; protected boolean addedEdges;
protected MMulTranspose mMulTranspose; protected MMulTranspose mMulTranspose;
public TensorMmul(INDArray x, INDArray y, int[][] axes) {
this(x,y,axes[0], axes[1], false, false, false);
}
/**
* Initialize with the given
* input, pairwise transform, result, and number
* of elements
*
* @param x the input
* @param y the pairwise transform
* @param z the result
*/
public TensorMmul(INDArray x, INDArray y, INDArray z, int[][] axes) {
this(x, y, axes[0], axes[1], false, false, false);
}
public TensorMmul(INDArray x, INDArray y, int[] dimensionsX, int[] dimensionsY,
boolean transposeX, boolean transposeY, boolean transposeZ) {
super(null,new INDArray[]{x, y},null);
this.axes = new int[][]{dimensionsX, dimensionsY};
addIArgument(dimensionsX.length);
addIArgument(dimensionsX);
addIArgument(dimensionsY.length);
addIArgument(dimensionsY);
addBArgument(transposeX, transposeY, transposeZ);
}
public TensorMmul(SameDiff sameDiff, public TensorMmul(SameDiff sameDiff,
SDVariable i_v1, SDVariable i_v1,
SDVariable i_v2, SDVariable i_v2,
@ -229,34 +258,6 @@ public class TensorMmul extends DynamicCustomOp {
return sameDiff.reshape(ret, aPlusB); return sameDiff.reshape(ret, aPlusB);
} }
public TensorMmul(INDArray x, INDArray y, int[][] axes) {
super(null,new INDArray[]{x, y},null);
this.axes = axes;
this.extraArgs = new Object[] {axes};
}
/**
* Initialize with the given
* input, pairwise transform, result, and number
* of elements
*
* @param x the input
* @param y the pairwise transform
* @param z the result
*/
public TensorMmul(INDArray x, INDArray y, INDArray z, int[][] axes) {
super(null,new INDArray[]{x, y, z},null);
this.axes = axes;
}
public TensorMmul(INDArray x, INDArray y, int[] dimensionsX, int[] dimensionsY,
boolean transposeX, boolean transposeY, boolean transposeZ) {
super(null,new INDArray[]{x, y},null);
this.axes = new int[][]{dimensionsX, dimensionsY};
addBArgument(transposeX, transposeY, transposeZ);
}
@Override @Override
public String opName() { public String opName() {
return "tensordot"; return "tensordot";

View File

@ -48,7 +48,6 @@ public class NormMax extends BaseReduceFloatOp {
super(x, null, z, dimensions); super(x, null, z, dimensions);
} }
public NormMax(INDArray x, int... dimensions) { public NormMax(INDArray x, int... dimensions) {
super(x, dimensions); super(x, dimensions);
} }

View File

@ -48,6 +48,10 @@ public class SquaredNorm extends BaseReduceFloatOp {
public SquaredNorm(){} public SquaredNorm(){}
public SquaredNorm(INDArray x, int... dimensions){
super(x, dimensions);
}
@Override @Override
public int opNum() { public int opNum() {
return 7; return 7;

View File

@ -52,11 +52,15 @@ public class MatchCondition extends BaseReduceLongOp {
public MatchCondition() {} public MatchCondition() {}
public MatchCondition(INDArray x, Condition condition, int... dimensions) { public MatchCondition(INDArray x, Condition condition, int... dimensions) {
this(x, Nd4j.EPS_THRESHOLD, condition, dimensions); this(x, Nd4j.EPS_THRESHOLD, condition, dimensions);
} }
public MatchCondition(INDArray x, Condition condition, boolean keepDims, int... dimensions) {
this(x, Nd4j.EPS_THRESHOLD, condition, dimensions);
this.keepDims = keepDims;
}
public MatchCondition(INDArray x, double eps, Condition condition, int... dimensions) { public MatchCondition(INDArray x, double eps, Condition condition, int... dimensions) {
super(x); super(x);
this.compare = condition.getValue(); this.compare = condition.getValue();
@ -68,10 +72,6 @@ public class MatchCondition extends BaseReduceLongOp {
defineDimensions(dimensions); defineDimensions(dimensions);
} }
public MatchCondition(INDArray in, Condition condition, boolean keepDim, int... dimensions) {
this(in, condition, dimensions);
}
@Override @Override
public int opNum() { public int opNum() {
return 2; return 2;

View File

@ -41,7 +41,6 @@ public class Sum extends BaseReduceSameOp {
super(sameDiff, i_v, i_v2, dimensions); super(sameDiff, i_v, i_v2, dimensions);
} }
public Sum() { public Sum() {
} }

View File

@ -40,7 +40,7 @@ public class ScalarEquals extends BaseScalarBoolOp {
} }
public ScalarEquals(INDArray x, Number num) { public ScalarEquals(INDArray x, Number num) {
super(x, num); this(x, null, num);
} }

View File

@ -19,9 +19,11 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
import org.nd4j.linalg.api.ops.BaseScalarOp; import org.nd4j.linalg.api.ops.BaseScalarOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
@ -56,7 +58,7 @@ public class ScalarGreaterThan extends BaseScalarBoolOp {
} }
public ScalarGreaterThan(INDArray x, Number num) { public ScalarGreaterThan(INDArray x, Number num) {
super(x, num); this(x, null, num);
} }

View File

@ -19,9 +19,11 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
import org.nd4j.linalg.api.ops.BaseScalarOp; import org.nd4j.linalg.api.ops.BaseScalarOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
@ -40,7 +42,7 @@ public class ScalarGreaterThanOrEqual extends BaseScalarBoolOp {
} }
public ScalarGreaterThanOrEqual(INDArray x, Number num) { public ScalarGreaterThanOrEqual(INDArray x, Number num) {
super(x, num); this(x, null, num);
} }

View File

@ -18,9 +18,11 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
import org.nd4j.linalg.api.ops.BaseScalarOp; import org.nd4j.linalg.api.ops.BaseScalarOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
@ -39,7 +41,7 @@ public class ScalarLessThan extends BaseScalarBoolOp {
} }
public ScalarLessThan(INDArray x, Number num) { public ScalarLessThan(INDArray x, Number num) {
super(x, num); this(x, null, num);
} }
public ScalarLessThan(SameDiff sameDiff, SDVariable i_v, Number scalar, boolean inPlace) { public ScalarLessThan(SameDiff sameDiff, SDVariable i_v, Number scalar, boolean inPlace) {

View File

@ -19,9 +19,11 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
import org.nd4j.linalg.api.ops.BaseScalarOp; import org.nd4j.linalg.api.ops.BaseScalarOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
@ -49,7 +51,7 @@ public class ScalarLessThanOrEqual extends BaseScalarBoolOp {
} }
public ScalarLessThanOrEqual(INDArray x, Number num) { public ScalarLessThanOrEqual(INDArray x, Number num) {
super(x, num); this(x, null, num);
} }

View File

@ -41,10 +41,9 @@ public class ScalarNotEquals extends BaseScalarBoolOp {
} }
public ScalarNotEquals(INDArray x, Number num) { public ScalarNotEquals(INDArray x, Number num) {
super(x, num); this(x, null, num);
} }
public ScalarNotEquals(SameDiff sameDiff, SDVariable i_v, Number scalar) { public ScalarNotEquals(SameDiff sameDiff, SDVariable i_v, Number scalar) {
super(sameDiff, i_v, scalar); super(sameDiff, i_v, scalar);
} }

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.scatter; package org.nd4j.linalg.api.ops.impl.scatter;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -44,12 +45,12 @@ public class ScatterAdd extends DynamicCustomOp {
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
} }
public ScatterAdd(INDArray ref, INDArray indices, INDArray updates) {
addInputArgument(ref, indices, updates);
}
public ScatterAdd(){} public ScatterAdd(){}
public ScatterAdd(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){
super(new INDArray[]{ref, indices, update}, null);
}
@Override @Override
public String opName() { public String opName() {
return "scatter_add"; return "scatter_add";

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.scatter; package org.nd4j.linalg.api.ops.impl.scatter;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -44,11 +45,12 @@ public class ScatterDiv extends DynamicCustomOp {
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
} }
public ScatterDiv(INDArray ref, INDArray indices, INDArray updates) { public ScatterDiv() {}
addInputArgument(ref, indices, updates);
public ScatterDiv(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){
super(new INDArray[]{ref, indices, update}, null);
} }
public ScatterDiv() {}
@Override @Override
public String opName() { public String opName() {

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.scatter; package org.nd4j.linalg.api.ops.impl.scatter;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -42,12 +43,12 @@ public class ScatterMax extends DynamicCustomOp {
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
} }
public ScatterMax(INDArray ref, INDArray indices, INDArray updates) {
addInputArgument(ref, indices, updates);
}
public ScatterMax() {} public ScatterMax() {}
public ScatterMax(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){
super(new INDArray[]{ref, indices, update}, null);
}
@Override @Override
public String opName() { public String opName() {
return "scatter_max"; return "scatter_max";

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.scatter; package org.nd4j.linalg.api.ops.impl.scatter;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -42,12 +43,12 @@ public class ScatterMin extends DynamicCustomOp {
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
} }
public ScatterMin(INDArray ref, INDArray indices, INDArray updates) {
addInputArgument(ref, indices, updates);
}
public ScatterMin() {} public ScatterMin() {}
public ScatterMin(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){
super(new INDArray[]{ref, indices, update}, null);
}
@Override @Override
public String opName() { public String opName() {
return "scatter_min"; return "scatter_min";

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.scatter; package org.nd4j.linalg.api.ops.impl.scatter;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -44,12 +45,12 @@ public class ScatterMul extends DynamicCustomOp {
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
} }
public ScatterMul(INDArray ref, INDArray indices, INDArray updates) {
addInputArgument(ref, indices, updates);
}
public ScatterMul() {} public ScatterMul() {}
public ScatterMul(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){
super(new INDArray[]{ref, indices, update}, null);
}
@Override @Override
public String opName() { public String opName() {
return "scatter_mul"; return "scatter_mul";

View File

@ -44,12 +44,12 @@ public class ScatterSub extends DynamicCustomOp {
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
} }
public ScatterSub(INDArray ref, INDArray indices, INDArray updates) {
addInputArgument(ref, indices, updates);
}
public ScatterSub() {} public ScatterSub() {}
public ScatterSub(INDArray ref, INDArray indices, INDArray update){
super(new INDArray[]{ref, indices, update}, null);
}
@Override @Override
public String opName() { public String opName() {
return "scatter_sub"; return "scatter_sub";

View File

@ -54,12 +54,12 @@ public class ScatterUpdate extends DynamicCustomOp {
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
} }
public ScatterUpdate(INDArray ref, INDArray indices, INDArray updates) {
addInputArgument(ref, indices, updates);
}
public ScatterUpdate(){} public ScatterUpdate(){}
public ScatterUpdate(INDArray ref, INDArray indices, INDArray update){
super(new INDArray[]{ref, indices, update}, null);
}
@Override @Override
public String opName() { public String opName() {
return "scatter_upd"; return "scatter_upd";

View File

@ -44,7 +44,7 @@ public class Concat extends DynamicCustomOp {
} }
public Concat(int concatDimension, INDArray... arrays) { public Concat(int concatDimension, INDArray... arrays) {
super(null, arrays, new INDArray[0]); super(null, arrays, null);
this.concatDimension = concatDimension; this.concatDimension = concatDimension;
addIArgument(concatDimension); addIArgument(concatDimension);
} }

View File

@ -67,15 +67,16 @@ public class ExpandDims extends DynamicCustomOp {
super(null, inputs, outputs); super(null, inputs, outputs);
} }
public ExpandDims(INDArray input, int axis) {
addInputArgument(input);
addIArgument(axis);
}
public ExpandDims(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { public ExpandDims(SameDiff sameDiff, SDVariable[] args, boolean inPlace) {
super(null, sameDiff, args, inPlace); super(null, sameDiff, args, inPlace);
} }
public ExpandDims(INDArray x, int axis){
super(new INDArray[]{x}, null);
this.jaxis = axis;
addIArgument(axis);
}
@Override @Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
val targetNode = TFGraphMapper.getNodeWithNameFromGraph(graph, nodeDef.getInput(1)); val targetNode = TFGraphMapper.getNodeWithNameFromGraph(graph, nodeDef.getInput(1));

View File

@ -42,6 +42,10 @@ public class GatherNd extends DynamicCustomOp {
super(new INDArray[]{df, indices}, null); super(new INDArray[]{df, indices}, null);
} }
public GatherNd(INDArray[] inputs, INDArray[] outputs){
super(inputs, outputs);
}
@Override @Override
public String opName() { public String opName() {
return "gather_nd"; return "gather_nd";

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at

View File

@ -67,8 +67,7 @@ public class OneHot extends DynamicCustomOp {
} }
public OneHot(INDArray indices, int depth) { public OneHot(INDArray indices, int depth) {
addInputArgument(indices); this(indices, null, depth, 0, 1.0, 0.0);
addIArgument(depth);
} }
public OneHot(INDArray indices, INDArray output, int depth, int axis, double on, double off) { public OneHot(INDArray indices, INDArray output, int depth, int axis, double on, double off) {
@ -80,14 +79,16 @@ public class OneHot extends DynamicCustomOp {
addArgs(); addArgs();
} }
public OneHot(INDArray indices, int depth, int axis, double on, double off, DataType dataType) { public OneHot(INDArray indices, int depth, int axis, double on, double off) {
addInputArgument(indices); this(indices, null, depth, axis, on, off);
addIArgument(depth, axis);
addTArgument(on, off);
addDArgument(dataType);
} }
public OneHot(INDArray indices, int depth, int axis, double on, double off, DataType dataType) {
this(indices, null, depth, axis, on, off);
this.outputType = dataType;
if (outputType != null)
addDArgument(outputType);
}
protected void addArgs() { protected void addArgs() {
addIArgument(jaxis); addIArgument(jaxis);

View File

@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions;
import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
@ -44,6 +45,10 @@ public class ParallelStack extends DynamicCustomOp {
super(null, sameDiff, values, false); super(null, sameDiff, values, false);
} }
public ParallelStack(INDArray[] inputs){
super(inputs, null);
}
@Override @Override
public String opName() { public String opName() {
return "parallel_stack"; return "parallel_stack";

View File

@ -55,15 +55,16 @@ public class Permute extends Transpose {
addIArgument(permuteDims); addIArgument(permuteDims);
} }
public Permute(INDArray input, int... permuteDims){
addInputArgument(input);
addIArgument(permuteDims);
}
public Permute(SameDiff sd, SDVariable input, SDVariable permuteDims){ public Permute(SameDiff sd, SDVariable input, SDVariable permuteDims){
super(sd, input, permuteDims); super(sd, input, permuteDims);
} }
public Permute(INDArray input, int... permuteDims){
super(input, null);
this.permuteDims = permuteDims;
addIArgument(permuteDims);
}
public Permute() { public Permute() {
} }

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape; package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NoArgsConstructor;
import lombok.NonNull; import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
@ -41,6 +42,7 @@ import java.util.Map;
* @author Adam Gibson * @author Adam Gibson
*/ */
@Slf4j @Slf4j
@NoArgsConstructor
public class Reshape extends DynamicCustomOp { public class Reshape extends DynamicCustomOp {
private long[] shape; private long[] shape;
@ -61,15 +63,13 @@ public class Reshape extends DynamicCustomOp {
addIArgument(shape); addIArgument(shape);
} }
public Reshape(INDArray in, INDArray shape){
this(in, shape, null);
}
public Reshape(@NonNull INDArray in, @NonNull INDArray shape, INDArray out){ public Reshape(@NonNull INDArray in, @NonNull INDArray shape, INDArray out){
super(null, new INDArray[]{in, shape}, wrapOrNull(out), null, (List<Integer>)null); super(null, new INDArray[]{in, shape}, wrapOrNull(out), null, (List<Integer>)null);
} }
public Reshape() { public Reshape(INDArray in, INDArray shape){
this(in, shape, null);
} }

View File

@ -49,10 +49,9 @@ public class Size extends DynamicCustomOp {
} }
public Size(INDArray in){ public Size(INDArray in){
addInputArgument(in); super(new INDArray[] {in}, null);
} }
@Override @Override
public String onnxName() { public String onnxName() {
throw new NoOpNameFoundException("No onnx name found for shape " + opName()); throw new NoOpNameFoundException("No onnx name found for shape " + opName());

View File

@ -54,8 +54,10 @@ public class Slice extends DynamicCustomOp {
super(null, sameDiff, new SDVariable[]{input, begin, end}); super(null, sameDiff, new SDVariable[]{input, begin, end});
} }
public Slice(INDArray in, int[] begin, int... size) { public Slice(INDArray input, int[] begin, int... size){
addInputArgument(in); super(new INDArray[] {input}, null);
this.begin = begin;
this.size = size;
addIArgument(begin); addIArgument(begin);
addIArgument(size); addIArgument(size);
} }

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape; package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NoArgsConstructor;
import lombok.NonNull; import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;

View File

@ -68,12 +68,15 @@ public class Tile extends DynamicCustomOp {
} }
public Tile(INDArray x, INDArray repeat){ public Tile(INDArray x, INDArray repeat){
addInputArgument(x, repeat); super(null, new INDArray[] {x, repeat}, null);
this.jaxis = null;
} }
public Tile(INDArray x, int... repeat) { public Tile(INDArray inputs, int... axis){
addInputArgument(x); super(null, new INDArray[] {inputs}, null);
addIArgument(repeat); this.jaxis = axis;
this.is_static_reps = true;
addArguments();
} }
public Tile() {} public Tile() {}

View File

@ -61,7 +61,7 @@ public class Transpose extends DynamicCustomOp {
} }
public Transpose(INDArray input){ public Transpose(INDArray input){
addInputArgument(input); this(input, null);
} }
public Transpose() { public Transpose() {

View File

@ -62,12 +62,10 @@ public class MatchConditionTransform extends BaseTransformBoolOp {
this(x, z, Nd4j.EPS_THRESHOLD, condition); this(x, z, Nd4j.EPS_THRESHOLD, condition);
} }
public MatchConditionTransform(INDArray x, @NonNull Condition condition) { public MatchConditionTransform(INDArray x, @NonNull Condition condition) {
this(x, null, Nd4j.EPS_THRESHOLD, condition); this(x, null, Nd4j.EPS_THRESHOLD, condition);
} }
public MatchConditionTransform(INDArray x, INDArray z, double eps, @NonNull Condition condition) { public MatchConditionTransform(INDArray x, INDArray z, double eps, @NonNull Condition condition) {
super(x, null, z); super(x, null, z);

View File

@ -69,7 +69,7 @@ public class CompareAndReplace extends BaseTransformSameOp {
* @param condition * @param condition
*/ */
public CompareAndReplace(INDArray x, INDArray y, Condition condition) { public CompareAndReplace(INDArray x, INDArray y, Condition condition) {
this(x, y, x, condition); this(x, y, null, condition);
} }
/** /**

View File

@ -24,6 +24,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.NodeDef;
@ -46,6 +47,10 @@ public class Assign extends DynamicCustomOp {
super(null,inputs, outputs); super(null,inputs, outputs);
} }
public Assign(INDArray x, INDArray y ) {
this( new INDArray[]{y ,x},new INDArray[]{y}); // TODO: Still check. y cannot be null, must be same shape as x.
}
@Override @Override
public void addIArgument(int... arg) { public void addIArgument(int... arg) {
super.addIArgument(arg); super.addIArgument(arg);

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom; package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.NonNull;
import lombok.val; import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -67,11 +68,19 @@ public class DynamicPartition extends DynamicCustomOp {
addArgs(); addArgs();
} }
public DynamicPartition(INDArray input, INDArray partitions, int numPartitions) { public DynamicPartition(@NonNull INDArray input, @NonNull INDArray partitions, int numPartitions) {
addInputArgument(input); super(new INDArray[]{input, partitions}, null);
addIArgument(numPartitions); this.numPartitions = numPartitions;
addArgs();
} }
public DynamicPartition(INDArray x, INDArray [] partitions, int numPartitions){
//TODO; This needs fixing.
super(new INDArray[]{x}, null);
// this.partitions = partitions;
this.numPartitions = numPartitions;
addArgs();
}
@Override @Override
public List<SDVariable> doDiff(List<SDVariable> i_v) { public List<SDVariable> doDiff(List<SDVariable> i_v) {

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom; package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.NonNull;
import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -61,14 +62,8 @@ public class DynamicStitch extends DynamicCustomOp {
this.numPartitions = inputs.length; this.numPartitions = inputs.length;
} }
public DynamicStitch(INDArray[] inputs, INDArray[] indices) { public DynamicStitch(@NonNull INDArray[] indices, @NonNull INDArray[] inputs) {
for (INDArray input : inputs) { super(ArrayUtils.addAll(indices, inputs), null);
addInputArgument(input);
}
for (INDArray index : indices) {
addInputArgument(index);
}
} }
@Override @Override

View File

@ -48,14 +48,14 @@ public class EqualTo extends BaseDynamicTransformOp {
super(inputs, outputs); super(inputs, outputs);
} }
public EqualTo( INDArray x, INDArray y) {
addInputArgument(x, y);
}
public EqualTo(INDArray x, INDArray y, INDArray z){ public EqualTo(INDArray x, INDArray y, INDArray z){
this(new INDArray[]{x, y}, new INDArray[]{z}); this(new INDArray[]{x, y}, new INDArray[]{z});
} }
public EqualTo(INDArray x, INDArray y){
this(new INDArray[]{x, y}, null);
}
@Override @Override
public String opName() { public String opName() {
return "equals"; return "equals";

View File

@ -27,6 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.NodeDef;
@ -55,19 +56,21 @@ public class Fill extends DynamicCustomOp {
super(null,sameDiff, new SDVariable[] {shape}, false); super(null,sameDiff, new SDVariable[] {shape}, false);
this.value = value; this.value = value;
this.outputDataType = outputDataType; this.outputDataType = outputDataType;
this.outputDataType = outputDataType;
addArgs(); addArgs();
} }
public Fill(INDArray shape, DataType outputDataType, double value) {
super(new INDArray[]{shape, Nd4j.scalar(outputDataType, value)}, null);
this.value = value;
this.outputDataType = outputDataType;
}
public Fill(INDArray shape, INDArray result, double value) { public Fill(INDArray shape, INDArray result, double value) {
super(null, shape, result, Collections.singletonList(value), null); super(null, shape, result, Collections.singletonList(value), null);
this.value = value; this.value = value;
} }
public Fill(INDArray shape, DataType dataType, double value) {
super(null, shape, null, Collections.singletonList(value), null);
this.value = value;
}
public Fill(INDArray shape, INDArray value, INDArray result) { public Fill(INDArray shape, INDArray value, INDArray result) {
super(null, new INDArray[]{shape, value}, new INDArray[]{result}); super(null, new INDArray[]{shape, value}, new INDArray[]{result});
} }

View File

@ -49,14 +49,14 @@ public class GreaterThan extends BaseDynamicTransformOp {
super(inputs, outputs); super(inputs, outputs);
} }
public GreaterThan( INDArray x, INDArray y) {
addInputArgument(x,y);
}
public GreaterThan(INDArray x, INDArray y, INDArray z){ public GreaterThan(INDArray x, INDArray y, INDArray z){
this(new INDArray[]{x, y}, new INDArray[]{z}); this(new INDArray[]{x, y}, new INDArray[]{z});
} }
public GreaterThan(INDArray x, INDArray y){
this(new INDArray[]{x, y}, null);
}
@Override @Override
public String opName() { public String opName() {
return "greater"; return "greater";

View File

@ -52,8 +52,8 @@ public class GreaterThanOrEqual extends BaseDynamicTransformOp {
this(new INDArray[]{x, y}, new INDArray[]{z}); this(new INDArray[]{x, y}, new INDArray[]{z});
} }
public GreaterThanOrEqual(INDArray x, INDArray y) {
public GreaterThanOrEqual(INDArray x, INDArray y){
this(new INDArray[]{x, y}, null); this(new INDArray[]{x, y}, null);
} }

View File

@ -45,8 +45,8 @@ public class IsNumericTensor extends DynamicCustomOp {
super(null, inputs, outputs); super(null, inputs, outputs);
} }
public IsNumericTensor(INDArray input) { public IsNumericTensor(INDArray inputs) {
addInputArgument(input); super( new INDArray[] {inputs}, null);
} }
@Override @Override

View File

@ -49,14 +49,14 @@ public class LessThan extends BaseDynamicTransformOp {
super(inputs, outputs); super(inputs, outputs);
} }
public LessThan( INDArray x, INDArray y) {
addInputArgument(x,y);
}
public LessThan(INDArray x, INDArray y, INDArray z){ public LessThan(INDArray x, INDArray y, INDArray z){
this(new INDArray[]{x, y}, new INDArray[]{z}); this(new INDArray[]{x, y}, new INDArray[]{z});
} }
public LessThan(INDArray x, INDArray y){
this(new INDArray[]{x, y}, null);
}
@Override @Override
public String opName() { public String opName() {
return "less"; return "less";

View File

@ -48,14 +48,14 @@ public class LessThanOrEqual extends BaseDynamicTransformOp {
super(inputs, outputs); super(inputs, outputs);
} }
public LessThanOrEqual( INDArray x, INDArray y) {
addInputArgument(x,y);
}
public LessThanOrEqual(INDArray x, INDArray y, INDArray z){ public LessThanOrEqual(INDArray x, INDArray y, INDArray z){
this(new INDArray[]{x, y}, new INDArray[]{z}); this(new INDArray[]{x, y}, new INDArray[]{z});
} }
public LessThanOrEqual(INDArray x, INDArray y){
this(new INDArray[]{x, y}, null);
}
@Override @Override
public String opName() { public String opName() {
return "less_equal"; return "less_equal";

View File

@ -48,12 +48,12 @@ public class Max extends BaseDynamicTransformOp {
super(new INDArray[]{first, second}, out == null ? null : new INDArray[]{out}); super(new INDArray[]{first, second}, out == null ? null : new INDArray[]{out});
} }
public Max( INDArray[] inputs, INDArray[] outputs) { public Max( INDArray first, INDArray second){
super(inputs, outputs); this(first, second, null);
} }
public Max( INDArray x, INDArray y) { public Max( INDArray[] inputs, INDArray[] outputs) {
addInputArgument(x,y); super(inputs, outputs);
} }
@Override @Override

View File

@ -48,12 +48,12 @@ public class Min extends BaseDynamicTransformOp {
super(new INDArray[]{first, second}, out == null ? null : new INDArray[]{out}); super(new INDArray[]{first, second}, out == null ? null : new INDArray[]{out});
} }
public Min( INDArray[] inputs, INDArray[] outputs) { public Min( INDArray first, INDArray second){
super(inputs, outputs); this(first, second, null);
} }
public Min( INDArray x, INDArray y) { public Min( INDArray[] inputs, INDArray[] outputs) {
addInputArgument(x,y); super(inputs, outputs);
} }
@Override @Override

View File

@ -48,14 +48,14 @@ public class NotEqualTo extends BaseDynamicTransformOp {
super(inputs, outputs); super(inputs, outputs);
} }
public NotEqualTo( INDArray x, INDArray y) {
addInputArgument(x,y);
}
public NotEqualTo(INDArray x, INDArray y, INDArray z){ public NotEqualTo(INDArray x, INDArray y, INDArray z){
this(new INDArray[]{x, y}, new INDArray[]{z}); this(new INDArray[]{x, y}, new INDArray[]{z});
} }
public NotEqualTo(INDArray x, INDArray y){
this(new INDArray[]{x, y}, null);
}
@Override @Override
public String opName() { public String opName() {
return "not_equals"; return "not_equals";

View File

@ -59,6 +59,17 @@ public class ReverseSequence extends DynamicCustomOp {
addArguments(); addArguments();
} }
public ReverseSequence(INDArray x, INDArray seq_lengths, int seqDim, int batchDim){
super(new INDArray[]{x, seq_lengths}, null);
this.seqDim = seqDim;
this.batchDim = batchDim;
addArguments();
}
public ReverseSequence(INDArray x, INDArray seq_lengths){
this(x, seq_lengths, 1, 0);
}
private void addArguments(){ private void addArguments(){
addIArgument(seqDim); addIArgument(seqDim);
addIArgument(batchDim); addIArgument(batchDim);
@ -67,11 +78,6 @@ public class ReverseSequence extends DynamicCustomOp {
public ReverseSequence() { public ReverseSequence() {
} }
public ReverseSequence(INDArray x, INDArray seq_lengths, int seqDim, int batchDim) {
addInputArgument(x, seq_lengths);
addIArgument(seqDim, batchDim);
}
@Override @Override
public String opName() { public String opName() {
return "reverse_sequence"; return "reverse_sequence";

View File

@ -40,7 +40,7 @@ public class SegmentMax extends DynamicCustomOp {
} }
public SegmentMax(INDArray data, INDArray segmentIds){ public SegmentMax(INDArray data, INDArray segmentIds){
addInputArgument(data, segmentIds); super(new INDArray[]{data, segmentIds}, null);
} }
public SegmentMax(){ } public SegmentMax(){ }

View File

@ -39,12 +39,12 @@ public class SegmentMean extends DynamicCustomOp {
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
} }
public SegmentMean(INDArray data, INDArray segmentIds) {
addInputArgument(data, segmentIds);
}
public SegmentMean(){ } public SegmentMean(){ }
public SegmentMean(INDArray data, INDArray segmentIds){
super(new INDArray[]{data, segmentIds}, null);
}
@Override @Override
public String opName(){ public String opName(){
return "segment_mean"; return "segment_mean";

View File

@ -39,12 +39,12 @@ public class SegmentMin extends DynamicCustomOp {
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
} }
public SegmentMin(INDArray data, INDArray segmentIds) {
addInputArgument(data, segmentIds);
}
public SegmentMin(){ } public SegmentMin(){ }
public SegmentMin(INDArray data, INDArray segmentIds){
super(new INDArray[]{data, segmentIds}, null);
}
@Override @Override
public String opName(){ public String opName(){
return "segment_min"; return "segment_min";

View File

@ -39,12 +39,12 @@ public class SegmentProd extends DynamicCustomOp {
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
} }
public SegmentProd(INDArray data, INDArray segmentIds) {
addInputArgument(data, segmentIds);
}
public SegmentProd(){ } public SegmentProd(){ }
public SegmentProd(INDArray data, INDArray segmentIds){
super(new INDArray[]{data, segmentIds}, null);
}
@Override @Override
public String opName(){ public String opName(){
return "segment_prod"; return "segment_prod";

View File

@ -39,12 +39,12 @@ public class SegmentSum extends DynamicCustomOp {
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
} }
public SegmentSum(INDArray data, INDArray segmentIds) {
addInputArgument(data, segmentIds);
}
public SegmentSum(){ } public SegmentSum(){ }
public SegmentSum(INDArray data, INDArray segmentIds){
super(new INDArray[]{data, segmentIds}, null);
}
@Override @Override
public String opName(){ public String opName(){
return "segment_sum"; return "segment_sum";

View File

@ -42,7 +42,7 @@ public class Identity extends BaseDynamicTransformOp {
} }
public Identity(INDArray x){ public Identity(INDArray x){
addInputArgument(x); super(new INDArray[]{x}, null);
} }
public Identity(){ } public Identity(){ }

View File

@ -41,13 +41,14 @@ public class UnsortedSegmentMax extends DynamicCustomOp {
addIArgument(numSegments); addIArgument(numSegments);
} }
public UnsortedSegmentMax(){ }
public UnsortedSegmentMax(INDArray data, INDArray segmentIds, int numSegments){ public UnsortedSegmentMax(INDArray data, INDArray segmentIds, int numSegments){
addInputArgument(data, segmentIds); super(new INDArray[]{data, segmentIds}, null);
this.numSegments = numSegments;
addIArgument(numSegments); addIArgument(numSegments);
} }
public UnsortedSegmentMax(){ }
@Override @Override
public String opName(){ public String opName(){
return "unsorted_segment_max"; return "unsorted_segment_max";

View File

@ -46,7 +46,8 @@ public class UnsortedSegmentMean extends DynamicCustomOp {
} }
public UnsortedSegmentMean(INDArray data, INDArray segmentIds, int numSegments){ public UnsortedSegmentMean(INDArray data, INDArray segmentIds, int numSegments){
addInputArgument(data, segmentIds); super(new INDArray[]{data, segmentIds}, null);
this.numSegments = numSegments;
addIArgument(numSegments); addIArgument(numSegments);
} }

View File

@ -46,7 +46,8 @@ public class UnsortedSegmentMin extends DynamicCustomOp {
} }
public UnsortedSegmentMin(INDArray data, INDArray segmentIds, int numSegments){ public UnsortedSegmentMin(INDArray data, INDArray segmentIds, int numSegments){
addInputArgument(data, segmentIds); super(new INDArray[]{data, segmentIds}, null);
this.numSegments = numSegments;
addIArgument(numSegments); addIArgument(numSegments);
} }

View File

@ -46,7 +46,8 @@ public class UnsortedSegmentProd extends DynamicCustomOp {
} }
public UnsortedSegmentProd(INDArray data, INDArray segmentIds, int numSegments){ public UnsortedSegmentProd(INDArray data, INDArray segmentIds, int numSegments){
addInputArgument(data, segmentIds); super(new INDArray[]{data, segmentIds}, null);
this.numSegments = numSegments;
addIArgument(numSegments); addIArgument(numSegments);
} }

View File

@ -39,18 +39,18 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp {
private int numSegments; private int numSegments;
public UnsortedSegmentSqrtN(INDArray data, INDArray segmentIds, int numSegments) {
addInputArgument(data, segmentIds);
addIArgument(numSegments);
this.numSegments = numSegments;
}
public UnsortedSegmentSqrtN(SameDiff sameDiff, SDVariable data, SDVariable segmentIds, int numSegments) { public UnsortedSegmentSqrtN(SameDiff sameDiff, SDVariable data, SDVariable segmentIds, int numSegments) {
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
this.numSegments = numSegments; this.numSegments = numSegments;
addIArgument(numSegments); addIArgument(numSegments);
} }
public UnsortedSegmentSqrtN(INDArray data, INDArray segmentIds, int numSegments){
super(new INDArray[]{data, segmentIds}, null);
this.numSegments = numSegments;
addIArgument(numSegments);
}
@Override @Override
public String opName(){ public String opName(){
return "unsorted_segment_sqrt_n"; return "unsorted_segment_sqrt_n";

View File

@ -47,7 +47,8 @@ public class UnsortedSegmentSum extends DynamicCustomOp {
} }
public UnsortedSegmentSum(INDArray data, INDArray segmentIds, int numSegments){ public UnsortedSegmentSum(INDArray data, INDArray segmentIds, int numSegments){
addInputArgument(data, segmentIds); super(new INDArray[]{data, segmentIds}, null);
this.numSegments = numSegments;
addIArgument(numSegments); addIArgument(numSegments);
} }

View File

@ -130,13 +130,20 @@ public class NDValidation {
" type; got array with non-integer data type " + v.dataType()); " type; got array with non-integer data type " + v.dataType());
} }
public static void validateInteger(String opName, String inputName, INDArray[] vars) { /**
for (INDArray v : vars) { * Validate that the operation is being applied on an integer type INDArray []
*
* @param opName Operation name to print in the exception
* @param inputName Name of the input to the op to validate
* @param v Variable to validate datatype for (input to operation)
*/
public static void validateInteger(String opName, String inputName, INDArray [] v) {
if (v == null) if (v == null)
return; return;
if (!v.dataType().isIntType()) for (int i = 0; i < v.length; i++) {
if (!v[i].dataType().isIntType())
throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer" + throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer" +
" type; got array with non-integer data type " + v.dataType()); " type; got array with non-integer data type member" + v[i].dataType());
} }
} }
@ -246,10 +253,11 @@ public class NDValidation {
} }
public static boolean isSameType(INDArray[] x) { public static boolean isSameType(INDArray[] x) {
DataType firstDataType = x[0].dataType(); if(x.length == 0)
if (x.length > 1) { return true;
for (int i = 1; i < x.length; ++i) { DataType first = x[0].dataType();
if (firstDataType != x[i].dataType()) for( int i=1; i<x.length; i++ ){
if(first != x[i].dataType()){
return false; return false;
} }
} }

View File

@ -2006,10 +2006,9 @@ public class Nd4j {
* @return the linearly spaced vector * @return the linearly spaced vector
*/ */
public static INDArray linspace(@NonNull DataType dataType, double lower, double step, long num) { public static INDArray linspace(@NonNull DataType dataType, double lower, double step, long num) {
Preconditions.checkState(dataType.isFPType()); Preconditions.checkState(dataType.isFPType(), "Datatype must be a floating point type for linspace, got %s", dataType);
if (num == 1) if (num == 1)
return Nd4j.scalar(dataType, lower); return Nd4j.scalar(dataType, lower);
return Nd4j.getExecutioner().exec(new Linspace(lower, num, step, dataType)); return Nd4j.getExecutioner().exec(new Linspace(lower, num, step, dataType));
} }
@ -2022,10 +2021,9 @@ public class Nd4j {
* @return the linearly spaced vector * @return the linearly spaced vector
*/ */
public static INDArray linspace( double lower, double upper, long num, @NonNull DataType dataType) { public static INDArray linspace( double lower, double upper, long num, @NonNull DataType dataType) {
Preconditions.checkState(dataType.isFPType()); Preconditions.checkState(dataType.isFPType(), "Datatype must be a floating point type for linspace, got %s", dataType);
if (num == 1) if (num == 1)
return Nd4j.scalar(dataType, lower); return Nd4j.scalar(dataType, lower);
return Nd4j.getExecutioner().exec(new Linspace(lower, upper, num, dataType)); return Nd4j.getExecutioner().exec(new Linspace(lower, upper, num, dataType));
} }

View File

@ -159,7 +159,7 @@ public class BooleanIndexing {
if (to.length() != from.length()) if (to.length() != from.length())
throw new IllegalStateException("Mis matched length for to and from"); throw new IllegalStateException("Mis matched length for to and from");
Nd4j.getExecutioner().exec(new CompareAndSet(to, from, condition)); Nd4j.getExecutioner().exec(new CompareAndSet(to, from, to, condition));
} }
@ -177,7 +177,7 @@ public class BooleanIndexing {
if (to.length() != from.length()) if (to.length() != from.length())
throw new IllegalStateException("Mis matched length for to and from"); throw new IllegalStateException("Mis matched length for to and from");
Nd4j.getExecutioner().exec(new CompareAndReplace(to, from, condition)); Nd4j.getExecutioner().exec(new CompareAndReplace(to, from, to, condition));
} }
/** /**

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at

View File

@ -19,13 +19,11 @@ package org.nd4j.linalg.learning;
import lombok.Data; import lombok.Data;
import lombok.NonNull; import lombok.NonNull;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.Nadam; import org.nd4j.linalg.learning.config.Nadam;
import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;

View File

@ -792,6 +792,19 @@ public class CudaExecutioner extends DefaultOpExecutioner {
} }
} }
boolean keepDims = op.isKeepDims();
long[] retShape = Shape.reductionShape(x, dimension, true, keepDims);
if(z == null || x == z) {
val ret = Nd4j.createUninitialized(DataType.LONG, retShape);
setZ(ret, op, oc);
z = ret;
} else if(!Arrays.equals(retShape, z.shape())){
throw new IllegalStateException("Z array shape does not match expected return type for op " + op
+ ": expected shape " + Arrays.toString(retShape) + ", z.shape()=" + Arrays.toString(z.shape()));
}
long st = profilingConfigurableHookIn(op); long st = profilingConfigurableHookIn(op);
checkForCompression(op); checkForCompression(op);
@ -2060,7 +2073,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified"); throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");
for (val shape: list) for (val shape: list)
op.addOutputArgument(Nd4j.create(shape)); op.addOutputArgument(Nd4j.create(shape, false));
shapeOverride = true; shapeOverride = true;
} catch (Exception e) { } catch (Exception e) {

View File

@ -772,8 +772,10 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
if (y != null) { if (y != null) {
if (z == null) if (z == null) {
setZ(Nd4j.create(op.resultType(), x.shape()), op, oc); setZ(Nd4j.create(op.resultType(), x.shape()), op, oc);
z = getZ(op, oc);
}
op.validateDataTypes(oc, experimentalMode.get()); op.validateDataTypes(oc, experimentalMode.get());

View File

@ -234,9 +234,9 @@ public class BooleanIndexingTest extends BaseNd4jTest {
INDArray array = Nd4j.create(new double[] {1, 2, 0, 4, 5}); INDArray array = Nd4j.create(new double[] {1, 2, 0, 4, 5});
INDArray comp = Nd4j.create(new double[] {1, 2, 3, 4, 5}); INDArray comp = Nd4j.create(new double[] {1, 2, 3, 4, 5});
Nd4j.getExecutioner().exec(new CompareAndReplace(array, comp, Conditions.lessThan(1))); INDArray z = Nd4j.exec(new CompareAndReplace(array, comp, Conditions.lessThan(1)));
assertEquals(comp, array); assertEquals(comp, z);
} }
@Test @Test
@ -256,9 +256,9 @@ public class BooleanIndexingTest extends BaseNd4jTest {
INDArray y = Nd4j.create(new double[] {2, 4, 3, 4, 5}); INDArray y = Nd4j.create(new double[] {2, 4, 3, 4, 5});
INDArray comp = Nd4j.create(new double[] {2, 4, 0, 4, 5}); INDArray comp = Nd4j.create(new double[] {2, 4, 0, 4, 5});
Nd4j.getExecutioner().exec(new CompareAndReplace(x, y, Conditions.epsNotEquals(0.0))); INDArray z = Nd4j.exec(new CompareAndReplace(x, y, Conditions.epsNotEquals(0.0)));
assertEquals(comp, x); assertEquals(comp, z);
} }
@Test @Test
@ -267,9 +267,9 @@ public class BooleanIndexingTest extends BaseNd4jTest {
INDArray y = Nd4j.create(new double[] {2, 4, 3, 4, 5}); INDArray y = Nd4j.create(new double[] {2, 4, 3, 4, 5});
INDArray comp = Nd4j.create(new double[] {2, 4, 3, 4, 5}); INDArray comp = Nd4j.create(new double[] {2, 4, 3, 4, 5});
Nd4j.getExecutioner().exec(new CompareAndReplace(x, y, Conditions.lessThan(4))); INDArray z = Nd4j.exec(new CompareAndReplace(x, y, Conditions.lessThan(4)));
assertEquals(comp, x); assertEquals(comp, z);
} }
@Test @Test
@ -278,9 +278,9 @@ public class BooleanIndexingTest extends BaseNd4jTest {
INDArray y = Nd4j.create(new double[] {2, 4, 3, 4, 5}); INDArray y = Nd4j.create(new double[] {2, 4, 3, 4, 5});
INDArray comp = Nd4j.create(new double[] {2, 2, 3, 4, 5}); INDArray comp = Nd4j.create(new double[] {2, 2, 3, 4, 5});
Nd4j.getExecutioner().exec(new CompareAndReplace(x, y, Conditions.lessThan(2))); INDArray z = Nd4j.exec(new CompareAndReplace(x, y, Conditions.lessThan(2)));
assertEquals(comp, x); assertEquals(comp, z);
} }