Base namespace (#287)
* wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * up to assign operation. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * fix Imax, IMin. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * concat. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * dynamicPartition Signed-off-by: Robert Altena <Rob@Ra-ai.com> * new ops up to gte. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * updated review items. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * up to matchCondition. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * up to OneHot. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip. up to permute. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip. up to rank. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip. up to scatterMul. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * resolving code review issues. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip. inclides UnsortedSegment ops. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip. up to stridedSlice. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * fix stridedSlice. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * first pass of SDBaseops.kt complete. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * fix review items. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * put branch in compilable state. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * add NDBaseTest. fix dynamicPartition signature. failed fix of assign. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * make tests public. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * adds tests up to invertedPermutation. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * fix ScalarEquals, Assign. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * updates NDBaseTest. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * updates 'check' comments based on test pass/fail. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * fix scalar ops. Update tests, Signed-off-by: Robert Altena <Rob@Ra-ai.com> * dev-tools review items. wip. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * dev-tools code review items. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * Test fixes Signed-off-by: Alex Black <blacka101@gmail.com> * complete review items. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * Comment for logged issue; fix test case Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * More fixes * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * undo changes to Nd4jCpu.java Signed-off-by: Robert Altena <Rob@Ra-ai.com> * update tests. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * Fixes and regenerate Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small test fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * small fixes to tests. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * Cleanup Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Small CUDAExecutioner fix Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Small CudaExecutioner fix Signed-off-by: Alex Black <blacka101@gmail.com> * Another small CudaExecutioner fix Signed-off-by: Alex Black <blacka101@gmail.com> * Another small CudaExecutioner fix Signed-off-by: Alex Black <blacka101@gmail.com> Co-authored-by: Robert Altena <Rob@Ra-ai.com>master
parent
a5db0e33be
commit
191bda3228
|
@ -45,15 +45,14 @@ public class IMax extends BaseIndexAccumulation {
|
||||||
super(x, z, dimensions);
|
super(x, z, dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
public IMax(INDArray x, boolean keepDims, int... dimensions) {
|
|
||||||
super(x, keepDims, dimensions);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public IMax(INDArray x, int... dimensions) {
|
public IMax(INDArray x, int... dimensions) {
|
||||||
super(x, null, dimensions);
|
super(x, null, dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public IMax(INDArray x, boolean keepDims, int... dimensions) {
|
||||||
|
super(x, null, dimensions);
|
||||||
|
this.keepDims = keepDims;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int opNum() {
|
public int opNum() {
|
||||||
|
|
|
@ -53,6 +53,7 @@ public class IMin extends BaseIndexAccumulation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int opNum() {
|
public int opNum() {
|
||||||
return 1;
|
return 1;
|
||||||
|
|
|
@ -306,4 +306,3 @@ public class Mmul extends DynamicCustomOp {
|
||||||
return Collections.singletonList(dataTypes.get(0));
|
return Collections.singletonList(dataTypes.get(0));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -51,6 +51,35 @@ public class TensorMmul extends DynamicCustomOp {
|
||||||
protected boolean addedEdges;
|
protected boolean addedEdges;
|
||||||
protected MMulTranspose mMulTranspose;
|
protected MMulTranspose mMulTranspose;
|
||||||
|
|
||||||
|
|
||||||
|
public TensorMmul(INDArray x, INDArray y, int[][] axes) {
|
||||||
|
this(x,y,axes[0], axes[1], false, false, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initialize with the given
|
||||||
|
* input, pairwise transform, result, and number
|
||||||
|
* of elements
|
||||||
|
*
|
||||||
|
* @param x the input
|
||||||
|
* @param y the pairwise transform
|
||||||
|
* @param z the result
|
||||||
|
*/
|
||||||
|
public TensorMmul(INDArray x, INDArray y, INDArray z, int[][] axes) {
|
||||||
|
this(x, y, axes[0], axes[1], false, false, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
public TensorMmul(INDArray x, INDArray y, int[] dimensionsX, int[] dimensionsY,
|
||||||
|
boolean transposeX, boolean transposeY, boolean transposeZ) {
|
||||||
|
super(null,new INDArray[]{x, y},null);
|
||||||
|
this.axes = new int[][]{dimensionsX, dimensionsY};
|
||||||
|
addIArgument(dimensionsX.length);
|
||||||
|
addIArgument(dimensionsX);
|
||||||
|
addIArgument(dimensionsY.length);
|
||||||
|
addIArgument(dimensionsY);
|
||||||
|
addBArgument(transposeX, transposeY, transposeZ);
|
||||||
|
}
|
||||||
|
|
||||||
public TensorMmul(SameDiff sameDiff,
|
public TensorMmul(SameDiff sameDiff,
|
||||||
SDVariable i_v1,
|
SDVariable i_v1,
|
||||||
SDVariable i_v2,
|
SDVariable i_v2,
|
||||||
|
@ -229,34 +258,6 @@ public class TensorMmul extends DynamicCustomOp {
|
||||||
return sameDiff.reshape(ret, aPlusB);
|
return sameDiff.reshape(ret, aPlusB);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public TensorMmul(INDArray x, INDArray y, int[][] axes) {
|
|
||||||
super(null,new INDArray[]{x, y},null);
|
|
||||||
this.axes = axes;
|
|
||||||
this.extraArgs = new Object[] {axes};
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Initialize with the given
|
|
||||||
* input, pairwise transform, result, and number
|
|
||||||
* of elements
|
|
||||||
*
|
|
||||||
* @param x the input
|
|
||||||
* @param y the pairwise transform
|
|
||||||
* @param z the result
|
|
||||||
*/
|
|
||||||
public TensorMmul(INDArray x, INDArray y, INDArray z, int[][] axes) {
|
|
||||||
super(null,new INDArray[]{x, y, z},null);
|
|
||||||
this.axes = axes;
|
|
||||||
}
|
|
||||||
|
|
||||||
public TensorMmul(INDArray x, INDArray y, int[] dimensionsX, int[] dimensionsY,
|
|
||||||
boolean transposeX, boolean transposeY, boolean transposeZ) {
|
|
||||||
super(null,new INDArray[]{x, y},null);
|
|
||||||
this.axes = new int[][]{dimensionsX, dimensionsY};
|
|
||||||
addBArgument(transposeX, transposeY, transposeZ);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "tensordot";
|
return "tensordot";
|
||||||
|
|
|
@ -48,7 +48,6 @@ public class NormMax extends BaseReduceFloatOp {
|
||||||
super(x, null, z, dimensions);
|
super(x, null, z, dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public NormMax(INDArray x, int... dimensions) {
|
public NormMax(INDArray x, int... dimensions) {
|
||||||
super(x, dimensions);
|
super(x, dimensions);
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,6 +48,10 @@ public class SquaredNorm extends BaseReduceFloatOp {
|
||||||
|
|
||||||
public SquaredNorm(){}
|
public SquaredNorm(){}
|
||||||
|
|
||||||
|
public SquaredNorm(INDArray x, int... dimensions){
|
||||||
|
super(x, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int opNum() {
|
public int opNum() {
|
||||||
return 7;
|
return 7;
|
||||||
|
|
|
@ -52,11 +52,15 @@ public class MatchCondition extends BaseReduceLongOp {
|
||||||
|
|
||||||
public MatchCondition() {}
|
public MatchCondition() {}
|
||||||
|
|
||||||
|
|
||||||
public MatchCondition(INDArray x, Condition condition, int... dimensions) {
|
public MatchCondition(INDArray x, Condition condition, int... dimensions) {
|
||||||
this(x, Nd4j.EPS_THRESHOLD, condition, dimensions);
|
this(x, Nd4j.EPS_THRESHOLD, condition, dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public MatchCondition(INDArray x, Condition condition, boolean keepDims, int... dimensions) {
|
||||||
|
this(x, Nd4j.EPS_THRESHOLD, condition, dimensions);
|
||||||
|
this.keepDims = keepDims;
|
||||||
|
}
|
||||||
|
|
||||||
public MatchCondition(INDArray x, double eps, Condition condition, int... dimensions) {
|
public MatchCondition(INDArray x, double eps, Condition condition, int... dimensions) {
|
||||||
super(x);
|
super(x);
|
||||||
this.compare = condition.getValue();
|
this.compare = condition.getValue();
|
||||||
|
@ -68,10 +72,6 @@ public class MatchCondition extends BaseReduceLongOp {
|
||||||
defineDimensions(dimensions);
|
defineDimensions(dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
public MatchCondition(INDArray in, Condition condition, boolean keepDim, int... dimensions) {
|
|
||||||
this(in, condition, dimensions);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int opNum() {
|
public int opNum() {
|
||||||
return 2;
|
return 2;
|
||||||
|
|
|
@ -41,7 +41,6 @@ public class Sum extends BaseReduceSameOp {
|
||||||
super(sameDiff, i_v, i_v2, dimensions);
|
super(sameDiff, i_v, i_v2, dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public Sum() {
|
public Sum() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,7 @@ public class ScalarEquals extends BaseScalarBoolOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
public ScalarEquals(INDArray x, Number num) {
|
public ScalarEquals(INDArray x, Number num) {
|
||||||
super(x, num);
|
this(x, null, num);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -19,9 +19,11 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
|
import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
|
||||||
import org.nd4j.linalg.api.ops.BaseScalarOp;
|
import org.nd4j.linalg.api.ops.BaseScalarOp;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -56,7 +58,7 @@ public class ScalarGreaterThan extends BaseScalarBoolOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
public ScalarGreaterThan(INDArray x, Number num) {
|
public ScalarGreaterThan(INDArray x, Number num) {
|
||||||
super(x, num);
|
this(x, null, num);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -19,9 +19,11 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
|
import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
|
||||||
import org.nd4j.linalg.api.ops.BaseScalarOp;
|
import org.nd4j.linalg.api.ops.BaseScalarOp;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -40,7 +42,7 @@ public class ScalarGreaterThanOrEqual extends BaseScalarBoolOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
public ScalarGreaterThanOrEqual(INDArray x, Number num) {
|
public ScalarGreaterThanOrEqual(INDArray x, Number num) {
|
||||||
super(x, num);
|
this(x, null, num);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,9 +18,11 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison;
|
||||||
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
|
import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
|
||||||
import org.nd4j.linalg.api.ops.BaseScalarOp;
|
import org.nd4j.linalg.api.ops.BaseScalarOp;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -39,7 +41,7 @@ public class ScalarLessThan extends BaseScalarBoolOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
public ScalarLessThan(INDArray x, Number num) {
|
public ScalarLessThan(INDArray x, Number num) {
|
||||||
super(x, num);
|
this(x, null, num);
|
||||||
}
|
}
|
||||||
|
|
||||||
public ScalarLessThan(SameDiff sameDiff, SDVariable i_v, Number scalar, boolean inPlace) {
|
public ScalarLessThan(SameDiff sameDiff, SDVariable i_v, Number scalar, boolean inPlace) {
|
||||||
|
|
|
@ -19,9 +19,11 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
|
import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
|
||||||
import org.nd4j.linalg.api.ops.BaseScalarOp;
|
import org.nd4j.linalg.api.ops.BaseScalarOp;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -49,7 +51,7 @@ public class ScalarLessThanOrEqual extends BaseScalarBoolOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
public ScalarLessThanOrEqual(INDArray x, Number num) {
|
public ScalarLessThanOrEqual(INDArray x, Number num) {
|
||||||
super(x, num);
|
this(x, null, num);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -41,10 +41,9 @@ public class ScalarNotEquals extends BaseScalarBoolOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
public ScalarNotEquals(INDArray x, Number num) {
|
public ScalarNotEquals(INDArray x, Number num) {
|
||||||
super(x, num);
|
this(x, null, num);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public ScalarNotEquals(SameDiff sameDiff, SDVariable i_v, Number scalar) {
|
public ScalarNotEquals(SameDiff sameDiff, SDVariable i_v, Number scalar) {
|
||||||
super(sameDiff, i_v, scalar);
|
super(sameDiff, i_v, scalar);
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.scatter;
|
package org.nd4j.linalg.api.ops.impl.scatter;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -44,12 +45,12 @@ public class ScatterAdd extends DynamicCustomOp {
|
||||||
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
public ScatterAdd(INDArray ref, INDArray indices, INDArray updates) {
|
|
||||||
addInputArgument(ref, indices, updates);
|
|
||||||
}
|
|
||||||
|
|
||||||
public ScatterAdd(){}
|
public ScatterAdd(){}
|
||||||
|
|
||||||
|
public ScatterAdd(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){
|
||||||
|
super(new INDArray[]{ref, indices, update}, null);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "scatter_add";
|
return "scatter_add";
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.scatter;
|
package org.nd4j.linalg.api.ops.impl.scatter;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -44,11 +45,12 @@ public class ScatterDiv extends DynamicCustomOp {
|
||||||
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
public ScatterDiv(INDArray ref, INDArray indices, INDArray updates) {
|
public ScatterDiv() {}
|
||||||
addInputArgument(ref, indices, updates);
|
|
||||||
|
public ScatterDiv(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){
|
||||||
|
super(new INDArray[]{ref, indices, update}, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public ScatterDiv() {}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.scatter;
|
package org.nd4j.linalg.api.ops.impl.scatter;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -42,12 +43,12 @@ public class ScatterMax extends DynamicCustomOp {
|
||||||
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
public ScatterMax(INDArray ref, INDArray indices, INDArray updates) {
|
|
||||||
addInputArgument(ref, indices, updates);
|
|
||||||
}
|
|
||||||
|
|
||||||
public ScatterMax() {}
|
public ScatterMax() {}
|
||||||
|
|
||||||
|
public ScatterMax(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){
|
||||||
|
super(new INDArray[]{ref, indices, update}, null);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "scatter_max";
|
return "scatter_max";
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.scatter;
|
package org.nd4j.linalg.api.ops.impl.scatter;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -42,12 +43,12 @@ public class ScatterMin extends DynamicCustomOp {
|
||||||
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
public ScatterMin(INDArray ref, INDArray indices, INDArray updates) {
|
|
||||||
addInputArgument(ref, indices, updates);
|
|
||||||
}
|
|
||||||
|
|
||||||
public ScatterMin() {}
|
public ScatterMin() {}
|
||||||
|
|
||||||
|
public ScatterMin(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){
|
||||||
|
super(new INDArray[]{ref, indices, update}, null);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "scatter_min";
|
return "scatter_min";
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.scatter;
|
package org.nd4j.linalg.api.ops.impl.scatter;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -44,12 +45,12 @@ public class ScatterMul extends DynamicCustomOp {
|
||||||
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
public ScatterMul(INDArray ref, INDArray indices, INDArray updates) {
|
|
||||||
addInputArgument(ref, indices, updates);
|
|
||||||
}
|
|
||||||
|
|
||||||
public ScatterMul() {}
|
public ScatterMul() {}
|
||||||
|
|
||||||
|
public ScatterMul(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){
|
||||||
|
super(new INDArray[]{ref, indices, update}, null);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "scatter_mul";
|
return "scatter_mul";
|
||||||
|
|
|
@ -44,12 +44,12 @@ public class ScatterSub extends DynamicCustomOp {
|
||||||
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
public ScatterSub(INDArray ref, INDArray indices, INDArray updates) {
|
|
||||||
addInputArgument(ref, indices, updates);
|
|
||||||
}
|
|
||||||
|
|
||||||
public ScatterSub() {}
|
public ScatterSub() {}
|
||||||
|
|
||||||
|
public ScatterSub(INDArray ref, INDArray indices, INDArray update){
|
||||||
|
super(new INDArray[]{ref, indices, update}, null);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "scatter_sub";
|
return "scatter_sub";
|
||||||
|
|
|
@ -54,12 +54,12 @@ public class ScatterUpdate extends DynamicCustomOp {
|
||||||
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
public ScatterUpdate(INDArray ref, INDArray indices, INDArray updates) {
|
|
||||||
addInputArgument(ref, indices, updates);
|
|
||||||
}
|
|
||||||
|
|
||||||
public ScatterUpdate(){}
|
public ScatterUpdate(){}
|
||||||
|
|
||||||
|
public ScatterUpdate(INDArray ref, INDArray indices, INDArray update){
|
||||||
|
super(new INDArray[]{ref, indices, update}, null);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "scatter_upd";
|
return "scatter_upd";
|
||||||
|
|
|
@ -44,7 +44,7 @@ public class Concat extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
public Concat(int concatDimension, INDArray... arrays) {
|
public Concat(int concatDimension, INDArray... arrays) {
|
||||||
super(null, arrays, new INDArray[0]);
|
super(null, arrays, null);
|
||||||
this.concatDimension = concatDimension;
|
this.concatDimension = concatDimension;
|
||||||
addIArgument(concatDimension);
|
addIArgument(concatDimension);
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,15 +67,16 @@ public class ExpandDims extends DynamicCustomOp {
|
||||||
super(null, inputs, outputs);
|
super(null, inputs, outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
public ExpandDims(INDArray input, int axis) {
|
|
||||||
addInputArgument(input);
|
|
||||||
addIArgument(axis);
|
|
||||||
}
|
|
||||||
|
|
||||||
public ExpandDims(SameDiff sameDiff, SDVariable[] args, boolean inPlace) {
|
public ExpandDims(SameDiff sameDiff, SDVariable[] args, boolean inPlace) {
|
||||||
super(null, sameDiff, args, inPlace);
|
super(null, sameDiff, args, inPlace);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public ExpandDims(INDArray x, int axis){
|
||||||
|
super(new INDArray[]{x}, null);
|
||||||
|
this.jaxis = axis;
|
||||||
|
addIArgument(axis);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||||
val targetNode = TFGraphMapper.getNodeWithNameFromGraph(graph, nodeDef.getInput(1));
|
val targetNode = TFGraphMapper.getNodeWithNameFromGraph(graph, nodeDef.getInput(1));
|
||||||
|
|
|
@ -42,6 +42,10 @@ public class GatherNd extends DynamicCustomOp {
|
||||||
super(new INDArray[]{df, indices}, null);
|
super(new INDArray[]{df, indices}, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public GatherNd(INDArray[] inputs, INDArray[] outputs){
|
||||||
|
super(inputs, outputs);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "gather_nd";
|
return "gather_nd";
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
|
|
@ -67,8 +67,7 @@ public class OneHot extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
public OneHot(INDArray indices, int depth) {
|
public OneHot(INDArray indices, int depth) {
|
||||||
addInputArgument(indices);
|
this(indices, null, depth, 0, 1.0, 0.0);
|
||||||
addIArgument(depth);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public OneHot(INDArray indices, INDArray output, int depth, int axis, double on, double off) {
|
public OneHot(INDArray indices, INDArray output, int depth, int axis, double on, double off) {
|
||||||
|
@ -80,14 +79,16 @@ public class OneHot extends DynamicCustomOp {
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
public OneHot(INDArray indices, int depth, int axis, double on, double off, DataType dataType) {
|
public OneHot(INDArray indices, int depth, int axis, double on, double off) {
|
||||||
addInputArgument(indices);
|
this(indices, null, depth, axis, on, off);
|
||||||
addIArgument(depth, axis);
|
|
||||||
addTArgument(on, off);
|
|
||||||
addDArgument(dataType);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public OneHot(INDArray indices, int depth, int axis, double on, double off, DataType dataType) {
|
||||||
|
this(indices, null, depth, axis, on, off);
|
||||||
|
this.outputType = dataType;
|
||||||
|
if (outputType != null)
|
||||||
|
addDArgument(outputType);
|
||||||
|
}
|
||||||
|
|
||||||
protected void addArgs() {
|
protected void addArgs() {
|
||||||
addIArgument(jaxis);
|
addIArgument(jaxis);
|
||||||
|
|
|
@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
@ -44,6 +45,10 @@ public class ParallelStack extends DynamicCustomOp {
|
||||||
super(null, sameDiff, values, false);
|
super(null, sameDiff, values, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public ParallelStack(INDArray[] inputs){
|
||||||
|
super(inputs, null);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "parallel_stack";
|
return "parallel_stack";
|
||||||
|
|
|
@ -55,15 +55,16 @@ public class Permute extends Transpose {
|
||||||
addIArgument(permuteDims);
|
addIArgument(permuteDims);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Permute(INDArray input, int... permuteDims){
|
|
||||||
addInputArgument(input);
|
|
||||||
addIArgument(permuteDims);
|
|
||||||
}
|
|
||||||
|
|
||||||
public Permute(SameDiff sd, SDVariable input, SDVariable permuteDims){
|
public Permute(SameDiff sd, SDVariable input, SDVariable permuteDims){
|
||||||
super(sd, input, permuteDims);
|
super(sd, input, permuteDims);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Permute(INDArray input, int... permuteDims){
|
||||||
|
super(input, null);
|
||||||
|
this.permuteDims = permuteDims;
|
||||||
|
addIArgument(permuteDims);
|
||||||
|
}
|
||||||
|
|
||||||
public Permute() {
|
public Permute() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.shape;
|
package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
@ -41,6 +42,7 @@ import java.util.Map;
|
||||||
* @author Adam Gibson
|
* @author Adam Gibson
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
@NoArgsConstructor
|
||||||
public class Reshape extends DynamicCustomOp {
|
public class Reshape extends DynamicCustomOp {
|
||||||
|
|
||||||
private long[] shape;
|
private long[] shape;
|
||||||
|
@ -61,15 +63,13 @@ public class Reshape extends DynamicCustomOp {
|
||||||
addIArgument(shape);
|
addIArgument(shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Reshape(INDArray in, INDArray shape){
|
|
||||||
this(in, shape, null);
|
|
||||||
}
|
|
||||||
|
|
||||||
public Reshape(@NonNull INDArray in, @NonNull INDArray shape, INDArray out){
|
public Reshape(@NonNull INDArray in, @NonNull INDArray shape, INDArray out){
|
||||||
super(null, new INDArray[]{in, shape}, wrapOrNull(out), null, (List<Integer>)null);
|
super(null, new INDArray[]{in, shape}, wrapOrNull(out), null, (List<Integer>)null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Reshape() {
|
public Reshape(INDArray in, INDArray shape){
|
||||||
|
this(in, shape, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -49,10 +49,9 @@ public class Size extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
public Size(INDArray in){
|
public Size(INDArray in){
|
||||||
addInputArgument(in);
|
super(new INDArray[] {in}, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String onnxName() {
|
public String onnxName() {
|
||||||
throw new NoOpNameFoundException("No onnx name found for shape " + opName());
|
throw new NoOpNameFoundException("No onnx name found for shape " + opName());
|
||||||
|
|
|
@ -54,8 +54,10 @@ public class Slice extends DynamicCustomOp {
|
||||||
super(null, sameDiff, new SDVariable[]{input, begin, end});
|
super(null, sameDiff, new SDVariable[]{input, begin, end});
|
||||||
}
|
}
|
||||||
|
|
||||||
public Slice(INDArray in, int[] begin, int... size) {
|
public Slice(INDArray input, int[] begin, int... size){
|
||||||
addInputArgument(in);
|
super(new INDArray[] {input}, null);
|
||||||
|
this.begin = begin;
|
||||||
|
this.size = size;
|
||||||
addIArgument(begin);
|
addIArgument(begin);
|
||||||
addIArgument(size);
|
addIArgument(size);
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.shape;
|
package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
|
|
@ -68,12 +68,15 @@ public class Tile extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
public Tile(INDArray x, INDArray repeat){
|
public Tile(INDArray x, INDArray repeat){
|
||||||
addInputArgument(x, repeat);
|
super(null, new INDArray[] {x, repeat}, null);
|
||||||
|
this.jaxis = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Tile(INDArray x, int... repeat) {
|
public Tile(INDArray inputs, int... axis){
|
||||||
addInputArgument(x);
|
super(null, new INDArray[] {inputs}, null);
|
||||||
addIArgument(repeat);
|
this.jaxis = axis;
|
||||||
|
this.is_static_reps = true;
|
||||||
|
addArguments();
|
||||||
}
|
}
|
||||||
|
|
||||||
public Tile() {}
|
public Tile() {}
|
||||||
|
|
|
@ -61,7 +61,7 @@ public class Transpose extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
public Transpose(INDArray input){
|
public Transpose(INDArray input){
|
||||||
addInputArgument(input);
|
this(input, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Transpose() {
|
public Transpose() {
|
||||||
|
|
|
@ -62,12 +62,10 @@ public class MatchConditionTransform extends BaseTransformBoolOp {
|
||||||
this(x, z, Nd4j.EPS_THRESHOLD, condition);
|
this(x, z, Nd4j.EPS_THRESHOLD, condition);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public MatchConditionTransform(INDArray x, @NonNull Condition condition) {
|
public MatchConditionTransform(INDArray x, @NonNull Condition condition) {
|
||||||
this(x, null, Nd4j.EPS_THRESHOLD, condition);
|
this(x, null, Nd4j.EPS_THRESHOLD, condition);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public MatchConditionTransform(INDArray x, INDArray z, double eps, @NonNull Condition condition) {
|
public MatchConditionTransform(INDArray x, INDArray z, double eps, @NonNull Condition condition) {
|
||||||
super(x, null, z);
|
super(x, null, z);
|
||||||
|
|
||||||
|
|
|
@ -69,7 +69,7 @@ public class CompareAndReplace extends BaseTransformSameOp {
|
||||||
* @param condition
|
* @param condition
|
||||||
*/
|
*/
|
||||||
public CompareAndReplace(INDArray x, INDArray y, Condition condition) {
|
public CompareAndReplace(INDArray x, INDArray y, Condition condition) {
|
||||||
this(x, y, x, condition);
|
this(x, y, null, condition);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.Op;
|
import org.nd4j.linalg.api.ops.Op;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
import org.tensorflow.framework.NodeDef;
|
import org.tensorflow.framework.NodeDef;
|
||||||
|
@ -46,6 +47,10 @@ public class Assign extends DynamicCustomOp {
|
||||||
super(null,inputs, outputs);
|
super(null,inputs, outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Assign(INDArray x, INDArray y ) {
|
||||||
|
this( new INDArray[]{y ,x},new INDArray[]{y}); // TODO: Still check. y cannot be null, must be same shape as x.
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void addIArgument(int... arg) {
|
public void addIArgument(int... arg) {
|
||||||
super.addIArgument(arg);
|
super.addIArgument(arg);
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -67,11 +68,19 @@ public class DynamicPartition extends DynamicCustomOp {
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
public DynamicPartition(INDArray input, INDArray partitions, int numPartitions) {
|
public DynamicPartition(@NonNull INDArray input, @NonNull INDArray partitions, int numPartitions) {
|
||||||
addInputArgument(input);
|
super(new INDArray[]{input, partitions}, null);
|
||||||
addIArgument(numPartitions);
|
this.numPartitions = numPartitions;
|
||||||
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public DynamicPartition(INDArray x, INDArray [] partitions, int numPartitions){
|
||||||
|
//TODO; This needs fixing.
|
||||||
|
super(new INDArray[]{x}, null);
|
||||||
|
// this.partitions = partitions;
|
||||||
|
this.numPartitions = numPartitions;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.apache.commons.lang3.ArrayUtils;
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -61,14 +62,8 @@ public class DynamicStitch extends DynamicCustomOp {
|
||||||
this.numPartitions = inputs.length;
|
this.numPartitions = inputs.length;
|
||||||
}
|
}
|
||||||
|
|
||||||
public DynamicStitch(INDArray[] inputs, INDArray[] indices) {
|
public DynamicStitch(@NonNull INDArray[] indices, @NonNull INDArray[] inputs) {
|
||||||
for (INDArray input : inputs) {
|
super(ArrayUtils.addAll(indices, inputs), null);
|
||||||
addInputArgument(input);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (INDArray index : indices) {
|
|
||||||
addInputArgument(index);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -48,14 +48,14 @@ public class EqualTo extends BaseDynamicTransformOp {
|
||||||
super(inputs, outputs);
|
super(inputs, outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
public EqualTo( INDArray x, INDArray y) {
|
|
||||||
addInputArgument(x, y);
|
|
||||||
}
|
|
||||||
|
|
||||||
public EqualTo(INDArray x, INDArray y, INDArray z){
|
public EqualTo(INDArray x, INDArray y, INDArray z){
|
||||||
this(new INDArray[]{x, y}, new INDArray[]{z});
|
this(new INDArray[]{x, y}, new INDArray[]{z});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public EqualTo(INDArray x, INDArray y){
|
||||||
|
this(new INDArray[]{x, y}, null);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "equals";
|
return "equals";
|
||||||
|
|
|
@ -27,6 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.Op;
|
import org.nd4j.linalg.api.ops.Op;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
import org.tensorflow.framework.NodeDef;
|
import org.tensorflow.framework.NodeDef;
|
||||||
|
@ -55,19 +56,21 @@ public class Fill extends DynamicCustomOp {
|
||||||
super(null,sameDiff, new SDVariable[] {shape}, false);
|
super(null,sameDiff, new SDVariable[] {shape}, false);
|
||||||
this.value = value;
|
this.value = value;
|
||||||
this.outputDataType = outputDataType;
|
this.outputDataType = outputDataType;
|
||||||
|
this.outputDataType = outputDataType;
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Fill(INDArray shape, DataType outputDataType, double value) {
|
||||||
|
super(new INDArray[]{shape, Nd4j.scalar(outputDataType, value)}, null);
|
||||||
|
this.value = value;
|
||||||
|
this.outputDataType = outputDataType;
|
||||||
|
}
|
||||||
|
|
||||||
public Fill(INDArray shape, INDArray result, double value) {
|
public Fill(INDArray shape, INDArray result, double value) {
|
||||||
super(null, shape, result, Collections.singletonList(value), null);
|
super(null, shape, result, Collections.singletonList(value), null);
|
||||||
this.value = value;
|
this.value = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Fill(INDArray shape, DataType dataType, double value) {
|
|
||||||
super(null, shape, null, Collections.singletonList(value), null);
|
|
||||||
this.value = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Fill(INDArray shape, INDArray value, INDArray result) {
|
public Fill(INDArray shape, INDArray value, INDArray result) {
|
||||||
super(null, new INDArray[]{shape, value}, new INDArray[]{result});
|
super(null, new INDArray[]{shape, value}, new INDArray[]{result});
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,14 +49,14 @@ public class GreaterThan extends BaseDynamicTransformOp {
|
||||||
super(inputs, outputs);
|
super(inputs, outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
public GreaterThan( INDArray x, INDArray y) {
|
|
||||||
addInputArgument(x,y);
|
|
||||||
}
|
|
||||||
|
|
||||||
public GreaterThan(INDArray x, INDArray y, INDArray z){
|
public GreaterThan(INDArray x, INDArray y, INDArray z){
|
||||||
this(new INDArray[]{x, y}, new INDArray[]{z});
|
this(new INDArray[]{x, y}, new INDArray[]{z});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public GreaterThan(INDArray x, INDArray y){
|
||||||
|
this(new INDArray[]{x, y}, null);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "greater";
|
return "greater";
|
||||||
|
|
|
@ -52,8 +52,8 @@ public class GreaterThanOrEqual extends BaseDynamicTransformOp {
|
||||||
this(new INDArray[]{x, y}, new INDArray[]{z});
|
this(new INDArray[]{x, y}, new INDArray[]{z});
|
||||||
}
|
}
|
||||||
|
|
||||||
public GreaterThanOrEqual(INDArray x, INDArray y) {
|
|
||||||
|
|
||||||
|
public GreaterThanOrEqual(INDArray x, INDArray y){
|
||||||
this(new INDArray[]{x, y}, null);
|
this(new INDArray[]{x, y}, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -45,8 +45,8 @@ public class IsNumericTensor extends DynamicCustomOp {
|
||||||
super(null, inputs, outputs);
|
super(null, inputs, outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
public IsNumericTensor(INDArray input) {
|
public IsNumericTensor(INDArray inputs) {
|
||||||
addInputArgument(input);
|
super( new INDArray[] {inputs}, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -49,14 +49,14 @@ public class LessThan extends BaseDynamicTransformOp {
|
||||||
super(inputs, outputs);
|
super(inputs, outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
public LessThan( INDArray x, INDArray y) {
|
|
||||||
addInputArgument(x,y);
|
|
||||||
}
|
|
||||||
|
|
||||||
public LessThan(INDArray x, INDArray y, INDArray z){
|
public LessThan(INDArray x, INDArray y, INDArray z){
|
||||||
this(new INDArray[]{x, y}, new INDArray[]{z});
|
this(new INDArray[]{x, y}, new INDArray[]{z});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public LessThan(INDArray x, INDArray y){
|
||||||
|
this(new INDArray[]{x, y}, null);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "less";
|
return "less";
|
||||||
|
|
|
@ -48,14 +48,14 @@ public class LessThanOrEqual extends BaseDynamicTransformOp {
|
||||||
super(inputs, outputs);
|
super(inputs, outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
public LessThanOrEqual( INDArray x, INDArray y) {
|
|
||||||
addInputArgument(x,y);
|
|
||||||
}
|
|
||||||
|
|
||||||
public LessThanOrEqual(INDArray x, INDArray y, INDArray z){
|
public LessThanOrEqual(INDArray x, INDArray y, INDArray z){
|
||||||
this(new INDArray[]{x, y}, new INDArray[]{z});
|
this(new INDArray[]{x, y}, new INDArray[]{z});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public LessThanOrEqual(INDArray x, INDArray y){
|
||||||
|
this(new INDArray[]{x, y}, null);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "less_equal";
|
return "less_equal";
|
||||||
|
|
|
@ -48,12 +48,12 @@ public class Max extends BaseDynamicTransformOp {
|
||||||
super(new INDArray[]{first, second}, out == null ? null : new INDArray[]{out});
|
super(new INDArray[]{first, second}, out == null ? null : new INDArray[]{out});
|
||||||
}
|
}
|
||||||
|
|
||||||
public Max( INDArray[] inputs, INDArray[] outputs) {
|
public Max( INDArray first, INDArray second){
|
||||||
super(inputs, outputs);
|
this(first, second, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Max( INDArray x, INDArray y) {
|
public Max( INDArray[] inputs, INDArray[] outputs) {
|
||||||
addInputArgument(x,y);
|
super(inputs, outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -48,12 +48,12 @@ public class Min extends BaseDynamicTransformOp {
|
||||||
super(new INDArray[]{first, second}, out == null ? null : new INDArray[]{out});
|
super(new INDArray[]{first, second}, out == null ? null : new INDArray[]{out});
|
||||||
}
|
}
|
||||||
|
|
||||||
public Min( INDArray[] inputs, INDArray[] outputs) {
|
public Min( INDArray first, INDArray second){
|
||||||
super(inputs, outputs);
|
this(first, second, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Min( INDArray x, INDArray y) {
|
public Min( INDArray[] inputs, INDArray[] outputs) {
|
||||||
addInputArgument(x,y);
|
super(inputs, outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -48,14 +48,14 @@ public class NotEqualTo extends BaseDynamicTransformOp {
|
||||||
super(inputs, outputs);
|
super(inputs, outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
public NotEqualTo( INDArray x, INDArray y) {
|
|
||||||
addInputArgument(x,y);
|
|
||||||
}
|
|
||||||
|
|
||||||
public NotEqualTo(INDArray x, INDArray y, INDArray z){
|
public NotEqualTo(INDArray x, INDArray y, INDArray z){
|
||||||
this(new INDArray[]{x, y}, new INDArray[]{z});
|
this(new INDArray[]{x, y}, new INDArray[]{z});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public NotEqualTo(INDArray x, INDArray y){
|
||||||
|
this(new INDArray[]{x, y}, null);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "not_equals";
|
return "not_equals";
|
||||||
|
|
|
@ -59,6 +59,17 @@ public class ReverseSequence extends DynamicCustomOp {
|
||||||
addArguments();
|
addArguments();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public ReverseSequence(INDArray x, INDArray seq_lengths, int seqDim, int batchDim){
|
||||||
|
super(new INDArray[]{x, seq_lengths}, null);
|
||||||
|
this.seqDim = seqDim;
|
||||||
|
this.batchDim = batchDim;
|
||||||
|
addArguments();
|
||||||
|
}
|
||||||
|
|
||||||
|
public ReverseSequence(INDArray x, INDArray seq_lengths){
|
||||||
|
this(x, seq_lengths, 1, 0);
|
||||||
|
}
|
||||||
|
|
||||||
private void addArguments(){
|
private void addArguments(){
|
||||||
addIArgument(seqDim);
|
addIArgument(seqDim);
|
||||||
addIArgument(batchDim);
|
addIArgument(batchDim);
|
||||||
|
@ -67,11 +78,6 @@ public class ReverseSequence extends DynamicCustomOp {
|
||||||
public ReverseSequence() {
|
public ReverseSequence() {
|
||||||
}
|
}
|
||||||
|
|
||||||
public ReverseSequence(INDArray x, INDArray seq_lengths, int seqDim, int batchDim) {
|
|
||||||
addInputArgument(x, seq_lengths);
|
|
||||||
addIArgument(seqDim, batchDim);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "reverse_sequence";
|
return "reverse_sequence";
|
||||||
|
|
|
@ -40,7 +40,7 @@ public class SegmentMax extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
public SegmentMax(INDArray data, INDArray segmentIds){
|
public SegmentMax(INDArray data, INDArray segmentIds){
|
||||||
addInputArgument(data, segmentIds);
|
super(new INDArray[]{data, segmentIds}, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public SegmentMax(){ }
|
public SegmentMax(){ }
|
||||||
|
|
|
@ -39,12 +39,12 @@ public class SegmentMean extends DynamicCustomOp {
|
||||||
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
|
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
public SegmentMean(INDArray data, INDArray segmentIds) {
|
|
||||||
addInputArgument(data, segmentIds);
|
|
||||||
}
|
|
||||||
|
|
||||||
public SegmentMean(){ }
|
public SegmentMean(){ }
|
||||||
|
|
||||||
|
public SegmentMean(INDArray data, INDArray segmentIds){
|
||||||
|
super(new INDArray[]{data, segmentIds}, null);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName(){
|
public String opName(){
|
||||||
return "segment_mean";
|
return "segment_mean";
|
||||||
|
|
|
@ -39,12 +39,12 @@ public class SegmentMin extends DynamicCustomOp {
|
||||||
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
|
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
public SegmentMin(INDArray data, INDArray segmentIds) {
|
|
||||||
addInputArgument(data, segmentIds);
|
|
||||||
}
|
|
||||||
|
|
||||||
public SegmentMin(){ }
|
public SegmentMin(){ }
|
||||||
|
|
||||||
|
public SegmentMin(INDArray data, INDArray segmentIds){
|
||||||
|
super(new INDArray[]{data, segmentIds}, null);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName(){
|
public String opName(){
|
||||||
return "segment_min";
|
return "segment_min";
|
||||||
|
|
|
@ -39,12 +39,12 @@ public class SegmentProd extends DynamicCustomOp {
|
||||||
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
|
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
public SegmentProd(INDArray data, INDArray segmentIds) {
|
|
||||||
addInputArgument(data, segmentIds);
|
|
||||||
}
|
|
||||||
|
|
||||||
public SegmentProd(){ }
|
public SegmentProd(){ }
|
||||||
|
|
||||||
|
public SegmentProd(INDArray data, INDArray segmentIds){
|
||||||
|
super(new INDArray[]{data, segmentIds}, null);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName(){
|
public String opName(){
|
||||||
return "segment_prod";
|
return "segment_prod";
|
||||||
|
|
|
@ -39,12 +39,12 @@ public class SegmentSum extends DynamicCustomOp {
|
||||||
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
|
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
public SegmentSum(INDArray data, INDArray segmentIds) {
|
|
||||||
addInputArgument(data, segmentIds);
|
|
||||||
}
|
|
||||||
|
|
||||||
public SegmentSum(){ }
|
public SegmentSum(){ }
|
||||||
|
|
||||||
|
public SegmentSum(INDArray data, INDArray segmentIds){
|
||||||
|
super(new INDArray[]{data, segmentIds}, null);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName(){
|
public String opName(){
|
||||||
return "segment_sum";
|
return "segment_sum";
|
||||||
|
|
|
@ -42,7 +42,7 @@ public class Identity extends BaseDynamicTransformOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
public Identity(INDArray x){
|
public Identity(INDArray x){
|
||||||
addInputArgument(x);
|
super(new INDArray[]{x}, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Identity(){ }
|
public Identity(){ }
|
||||||
|
|
|
@ -41,13 +41,14 @@ public class UnsortedSegmentMax extends DynamicCustomOp {
|
||||||
addIArgument(numSegments);
|
addIArgument(numSegments);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public UnsortedSegmentMax(){ }
|
||||||
|
|
||||||
public UnsortedSegmentMax(INDArray data, INDArray segmentIds, int numSegments){
|
public UnsortedSegmentMax(INDArray data, INDArray segmentIds, int numSegments){
|
||||||
addInputArgument(data, segmentIds);
|
super(new INDArray[]{data, segmentIds}, null);
|
||||||
|
this.numSegments = numSegments;
|
||||||
addIArgument(numSegments);
|
addIArgument(numSegments);
|
||||||
}
|
}
|
||||||
|
|
||||||
public UnsortedSegmentMax(){ }
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName(){
|
public String opName(){
|
||||||
return "unsorted_segment_max";
|
return "unsorted_segment_max";
|
||||||
|
|
|
@ -46,7 +46,8 @@ public class UnsortedSegmentMean extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
public UnsortedSegmentMean(INDArray data, INDArray segmentIds, int numSegments){
|
public UnsortedSegmentMean(INDArray data, INDArray segmentIds, int numSegments){
|
||||||
addInputArgument(data, segmentIds);
|
super(new INDArray[]{data, segmentIds}, null);
|
||||||
|
this.numSegments = numSegments;
|
||||||
addIArgument(numSegments);
|
addIArgument(numSegments);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -46,7 +46,8 @@ public class UnsortedSegmentMin extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
public UnsortedSegmentMin(INDArray data, INDArray segmentIds, int numSegments){
|
public UnsortedSegmentMin(INDArray data, INDArray segmentIds, int numSegments){
|
||||||
addInputArgument(data, segmentIds);
|
super(new INDArray[]{data, segmentIds}, null);
|
||||||
|
this.numSegments = numSegments;
|
||||||
addIArgument(numSegments);
|
addIArgument(numSegments);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -46,7 +46,8 @@ public class UnsortedSegmentProd extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
public UnsortedSegmentProd(INDArray data, INDArray segmentIds, int numSegments){
|
public UnsortedSegmentProd(INDArray data, INDArray segmentIds, int numSegments){
|
||||||
addInputArgument(data, segmentIds);
|
super(new INDArray[]{data, segmentIds}, null);
|
||||||
|
this.numSegments = numSegments;
|
||||||
addIArgument(numSegments);
|
addIArgument(numSegments);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -39,18 +39,18 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp {
|
||||||
|
|
||||||
private int numSegments;
|
private int numSegments;
|
||||||
|
|
||||||
public UnsortedSegmentSqrtN(INDArray data, INDArray segmentIds, int numSegments) {
|
|
||||||
addInputArgument(data, segmentIds);
|
|
||||||
addIArgument(numSegments);
|
|
||||||
this.numSegments = numSegments;
|
|
||||||
}
|
|
||||||
|
|
||||||
public UnsortedSegmentSqrtN(SameDiff sameDiff, SDVariable data, SDVariable segmentIds, int numSegments) {
|
public UnsortedSegmentSqrtN(SameDiff sameDiff, SDVariable data, SDVariable segmentIds, int numSegments) {
|
||||||
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
|
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
|
||||||
this.numSegments = numSegments;
|
this.numSegments = numSegments;
|
||||||
addIArgument(numSegments);
|
addIArgument(numSegments);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public UnsortedSegmentSqrtN(INDArray data, INDArray segmentIds, int numSegments){
|
||||||
|
super(new INDArray[]{data, segmentIds}, null);
|
||||||
|
this.numSegments = numSegments;
|
||||||
|
addIArgument(numSegments);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName(){
|
public String opName(){
|
||||||
return "unsorted_segment_sqrt_n";
|
return "unsorted_segment_sqrt_n";
|
||||||
|
|
|
@ -47,7 +47,8 @@ public class UnsortedSegmentSum extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
public UnsortedSegmentSum(INDArray data, INDArray segmentIds, int numSegments){
|
public UnsortedSegmentSum(INDArray data, INDArray segmentIds, int numSegments){
|
||||||
addInputArgument(data, segmentIds);
|
super(new INDArray[]{data, segmentIds}, null);
|
||||||
|
this.numSegments = numSegments;
|
||||||
addIArgument(numSegments);
|
addIArgument(numSegments);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -130,13 +130,20 @@ public class NDValidation {
|
||||||
" type; got array with non-integer data type " + v.dataType());
|
" type; got array with non-integer data type " + v.dataType());
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void validateInteger(String opName, String inputName, INDArray[] vars) {
|
/**
|
||||||
for (INDArray v : vars) {
|
* Validate that the operation is being applied on an integer type INDArray []
|
||||||
|
*
|
||||||
|
* @param opName Operation name to print in the exception
|
||||||
|
* @param inputName Name of the input to the op to validate
|
||||||
|
* @param v Variable to validate datatype for (input to operation)
|
||||||
|
*/
|
||||||
|
public static void validateInteger(String opName, String inputName, INDArray [] v) {
|
||||||
if (v == null)
|
if (v == null)
|
||||||
return;
|
return;
|
||||||
if (!v.dataType().isIntType())
|
for (int i = 0; i < v.length; i++) {
|
||||||
|
if (!v[i].dataType().isIntType())
|
||||||
throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer" +
|
throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer" +
|
||||||
" type; got array with non-integer data type " + v.dataType());
|
" type; got array with non-integer data type member" + v[i].dataType());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -246,10 +253,11 @@ public class NDValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
public static boolean isSameType(INDArray[] x) {
|
public static boolean isSameType(INDArray[] x) {
|
||||||
DataType firstDataType = x[0].dataType();
|
if(x.length == 0)
|
||||||
if (x.length > 1) {
|
return true;
|
||||||
for (int i = 1; i < x.length; ++i) {
|
DataType first = x[0].dataType();
|
||||||
if (firstDataType != x[i].dataType())
|
for( int i=1; i<x.length; i++ ){
|
||||||
|
if(first != x[i].dataType()){
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2006,10 +2006,9 @@ public class Nd4j {
|
||||||
* @return the linearly spaced vector
|
* @return the linearly spaced vector
|
||||||
*/
|
*/
|
||||||
public static INDArray linspace(@NonNull DataType dataType, double lower, double step, long num) {
|
public static INDArray linspace(@NonNull DataType dataType, double lower, double step, long num) {
|
||||||
Preconditions.checkState(dataType.isFPType());
|
Preconditions.checkState(dataType.isFPType(), "Datatype must be a floating point type for linspace, got %s", dataType);
|
||||||
if (num == 1)
|
if (num == 1)
|
||||||
return Nd4j.scalar(dataType, lower);
|
return Nd4j.scalar(dataType, lower);
|
||||||
|
|
||||||
return Nd4j.getExecutioner().exec(new Linspace(lower, num, step, dataType));
|
return Nd4j.getExecutioner().exec(new Linspace(lower, num, step, dataType));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2022,10 +2021,9 @@ public class Nd4j {
|
||||||
* @return the linearly spaced vector
|
* @return the linearly spaced vector
|
||||||
*/
|
*/
|
||||||
public static INDArray linspace( double lower, double upper, long num, @NonNull DataType dataType) {
|
public static INDArray linspace( double lower, double upper, long num, @NonNull DataType dataType) {
|
||||||
Preconditions.checkState(dataType.isFPType());
|
Preconditions.checkState(dataType.isFPType(), "Datatype must be a floating point type for linspace, got %s", dataType);
|
||||||
if (num == 1)
|
if (num == 1)
|
||||||
return Nd4j.scalar(dataType, lower);
|
return Nd4j.scalar(dataType, lower);
|
||||||
|
|
||||||
return Nd4j.getExecutioner().exec(new Linspace(lower, upper, num, dataType));
|
return Nd4j.getExecutioner().exec(new Linspace(lower, upper, num, dataType));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -159,7 +159,7 @@ public class BooleanIndexing {
|
||||||
if (to.length() != from.length())
|
if (to.length() != from.length())
|
||||||
throw new IllegalStateException("Mis matched length for to and from");
|
throw new IllegalStateException("Mis matched length for to and from");
|
||||||
|
|
||||||
Nd4j.getExecutioner().exec(new CompareAndSet(to, from, condition));
|
Nd4j.getExecutioner().exec(new CompareAndSet(to, from, to, condition));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -177,7 +177,7 @@ public class BooleanIndexing {
|
||||||
if (to.length() != from.length())
|
if (to.length() != from.length())
|
||||||
throw new IllegalStateException("Mis matched length for to and from");
|
throw new IllegalStateException("Mis matched length for to and from");
|
||||||
|
|
||||||
Nd4j.getExecutioner().exec(new CompareAndReplace(to, from, condition));
|
Nd4j.getExecutioner().exec(new CompareAndReplace(to, from, to, condition));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
|
|
@ -19,13 +19,11 @@ package org.nd4j.linalg.learning;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import org.apache.commons.math3.util.FastMath;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.shape.Shape;
|
import org.nd4j.linalg.api.shape.Shape;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
import org.nd4j.linalg.learning.config.Nadam;
|
import org.nd4j.linalg.learning.config.Nadam;
|
||||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
|
@ -792,6 +792,19 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
boolean keepDims = op.isKeepDims();
|
||||||
|
long[] retShape = Shape.reductionShape(x, dimension, true, keepDims);
|
||||||
|
|
||||||
|
if(z == null || x == z) {
|
||||||
|
val ret = Nd4j.createUninitialized(DataType.LONG, retShape);
|
||||||
|
|
||||||
|
setZ(ret, op, oc);
|
||||||
|
z = ret;
|
||||||
|
} else if(!Arrays.equals(retShape, z.shape())){
|
||||||
|
throw new IllegalStateException("Z array shape does not match expected return type for op " + op
|
||||||
|
+ ": expected shape " + Arrays.toString(retShape) + ", z.shape()=" + Arrays.toString(z.shape()));
|
||||||
|
}
|
||||||
|
|
||||||
long st = profilingConfigurableHookIn(op);
|
long st = profilingConfigurableHookIn(op);
|
||||||
|
|
||||||
checkForCompression(op);
|
checkForCompression(op);
|
||||||
|
@ -2060,7 +2073,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");
|
throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");
|
||||||
|
|
||||||
for (val shape: list)
|
for (val shape: list)
|
||||||
op.addOutputArgument(Nd4j.create(shape));
|
op.addOutputArgument(Nd4j.create(shape, false));
|
||||||
|
|
||||||
shapeOverride = true;
|
shapeOverride = true;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
|
|
|
@ -772,8 +772,10 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
if (y != null) {
|
if (y != null) {
|
||||||
|
|
||||||
if (z == null)
|
if (z == null) {
|
||||||
setZ(Nd4j.create(op.resultType(), x.shape()), op, oc);
|
setZ(Nd4j.create(op.resultType(), x.shape()), op, oc);
|
||||||
|
z = getZ(op, oc);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
op.validateDataTypes(oc, experimentalMode.get());
|
op.validateDataTypes(oc, experimentalMode.get());
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -234,9 +234,9 @@ public class BooleanIndexingTest extends BaseNd4jTest {
|
||||||
INDArray array = Nd4j.create(new double[] {1, 2, 0, 4, 5});
|
INDArray array = Nd4j.create(new double[] {1, 2, 0, 4, 5});
|
||||||
INDArray comp = Nd4j.create(new double[] {1, 2, 3, 4, 5});
|
INDArray comp = Nd4j.create(new double[] {1, 2, 3, 4, 5});
|
||||||
|
|
||||||
Nd4j.getExecutioner().exec(new CompareAndReplace(array, comp, Conditions.lessThan(1)));
|
INDArray z = Nd4j.exec(new CompareAndReplace(array, comp, Conditions.lessThan(1)));
|
||||||
|
|
||||||
assertEquals(comp, array);
|
assertEquals(comp, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -256,9 +256,9 @@ public class BooleanIndexingTest extends BaseNd4jTest {
|
||||||
INDArray y = Nd4j.create(new double[] {2, 4, 3, 4, 5});
|
INDArray y = Nd4j.create(new double[] {2, 4, 3, 4, 5});
|
||||||
INDArray comp = Nd4j.create(new double[] {2, 4, 0, 4, 5});
|
INDArray comp = Nd4j.create(new double[] {2, 4, 0, 4, 5});
|
||||||
|
|
||||||
Nd4j.getExecutioner().exec(new CompareAndReplace(x, y, Conditions.epsNotEquals(0.0)));
|
INDArray z = Nd4j.exec(new CompareAndReplace(x, y, Conditions.epsNotEquals(0.0)));
|
||||||
|
|
||||||
assertEquals(comp, x);
|
assertEquals(comp, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -267,9 +267,9 @@ public class BooleanIndexingTest extends BaseNd4jTest {
|
||||||
INDArray y = Nd4j.create(new double[] {2, 4, 3, 4, 5});
|
INDArray y = Nd4j.create(new double[] {2, 4, 3, 4, 5});
|
||||||
INDArray comp = Nd4j.create(new double[] {2, 4, 3, 4, 5});
|
INDArray comp = Nd4j.create(new double[] {2, 4, 3, 4, 5});
|
||||||
|
|
||||||
Nd4j.getExecutioner().exec(new CompareAndReplace(x, y, Conditions.lessThan(4)));
|
INDArray z = Nd4j.exec(new CompareAndReplace(x, y, Conditions.lessThan(4)));
|
||||||
|
|
||||||
assertEquals(comp, x);
|
assertEquals(comp, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -278,9 +278,9 @@ public class BooleanIndexingTest extends BaseNd4jTest {
|
||||||
INDArray y = Nd4j.create(new double[] {2, 4, 3, 4, 5});
|
INDArray y = Nd4j.create(new double[] {2, 4, 3, 4, 5});
|
||||||
INDArray comp = Nd4j.create(new double[] {2, 2, 3, 4, 5});
|
INDArray comp = Nd4j.create(new double[] {2, 2, 3, 4, 5});
|
||||||
|
|
||||||
Nd4j.getExecutioner().exec(new CompareAndReplace(x, y, Conditions.lessThan(2)));
|
INDArray z = Nd4j.exec(new CompareAndReplace(x, y, Conditions.lessThan(2)));
|
||||||
|
|
||||||
assertEquals(comp, x);
|
assertEquals(comp, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue