Base namespace (#287)
* wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * up to assign operation. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * fix Imax, IMin. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * concat. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * dynamicPartition Signed-off-by: Robert Altena <Rob@Ra-ai.com> * new ops up to gte. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * updated review items. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * up to matchCondition. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * up to OneHot. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip. up to permute. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip. up to rank. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip. up to scatterMul. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * resolving code review issues. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip. inclides UnsortedSegment ops. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip. up to stridedSlice. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * fix stridedSlice. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * first pass of SDBaseops.kt complete. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * fix review items. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * put branch in compilable state. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * add NDBaseTest. fix dynamicPartition signature. failed fix of assign. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * make tests public. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * adds tests up to invertedPermutation. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * fix ScalarEquals, Assign. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * updates NDBaseTest. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * updates 'check' comments based on test pass/fail. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * fix scalar ops. Update tests, Signed-off-by: Robert Altena <Rob@Ra-ai.com> * dev-tools review items. wip. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * dev-tools code review items. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * Test fixes Signed-off-by: Alex Black <blacka101@gmail.com> * complete review items. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * Comment for logged issue; fix test case Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * More fixes * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * undo changes to Nd4jCpu.java Signed-off-by: Robert Altena <Rob@Ra-ai.com> * update tests. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * Fixes and regenerate Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small test fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * small fixes to tests. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * Cleanup Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Small CUDAExecutioner fix Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Small CudaExecutioner fix Signed-off-by: Alex Black <blacka101@gmail.com> * Another small CudaExecutioner fix Signed-off-by: Alex Black <blacka101@gmail.com> * Another small CudaExecutioner fix Signed-off-by: Alex Black <blacka101@gmail.com> Co-authored-by: Robert Altena <Rob@Ra-ai.com>master
parent
a5db0e33be
commit
191bda3228
|
@ -45,15 +45,14 @@ public class IMax extends BaseIndexAccumulation {
|
|||
super(x, z, dimensions);
|
||||
}
|
||||
|
||||
public IMax(INDArray x, boolean keepDims, int... dimensions) {
|
||||
super(x, keepDims, dimensions);
|
||||
|
||||
}
|
||||
|
||||
public IMax(INDArray x, int... dimensions) {
|
||||
super(x, null, dimensions);
|
||||
}
|
||||
|
||||
public IMax(INDArray x, boolean keepDims, int... dimensions) {
|
||||
super(x, null, dimensions);
|
||||
this.keepDims = keepDims;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int opNum() {
|
||||
|
|
|
@ -53,6 +53,7 @@ public class IMin extends BaseIndexAccumulation {
|
|||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public int opNum() {
|
||||
return 1;
|
||||
|
|
|
@ -306,4 +306,3 @@ public class Mmul extends DynamicCustomOp {
|
|||
return Collections.singletonList(dataTypes.get(0));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -51,6 +51,35 @@ public class TensorMmul extends DynamicCustomOp {
|
|||
protected boolean addedEdges;
|
||||
protected MMulTranspose mMulTranspose;
|
||||
|
||||
|
||||
public TensorMmul(INDArray x, INDArray y, int[][] axes) {
|
||||
this(x,y,axes[0], axes[1], false, false, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize with the given
|
||||
* input, pairwise transform, result, and number
|
||||
* of elements
|
||||
*
|
||||
* @param x the input
|
||||
* @param y the pairwise transform
|
||||
* @param z the result
|
||||
*/
|
||||
public TensorMmul(INDArray x, INDArray y, INDArray z, int[][] axes) {
|
||||
this(x, y, axes[0], axes[1], false, false, false);
|
||||
}
|
||||
|
||||
public TensorMmul(INDArray x, INDArray y, int[] dimensionsX, int[] dimensionsY,
|
||||
boolean transposeX, boolean transposeY, boolean transposeZ) {
|
||||
super(null,new INDArray[]{x, y},null);
|
||||
this.axes = new int[][]{dimensionsX, dimensionsY};
|
||||
addIArgument(dimensionsX.length);
|
||||
addIArgument(dimensionsX);
|
||||
addIArgument(dimensionsY.length);
|
||||
addIArgument(dimensionsY);
|
||||
addBArgument(transposeX, transposeY, transposeZ);
|
||||
}
|
||||
|
||||
public TensorMmul(SameDiff sameDiff,
|
||||
SDVariable i_v1,
|
||||
SDVariable i_v2,
|
||||
|
@ -229,34 +258,6 @@ public class TensorMmul extends DynamicCustomOp {
|
|||
return sameDiff.reshape(ret, aPlusB);
|
||||
}
|
||||
|
||||
|
||||
public TensorMmul(INDArray x, INDArray y, int[][] axes) {
|
||||
super(null,new INDArray[]{x, y},null);
|
||||
this.axes = axes;
|
||||
this.extraArgs = new Object[] {axes};
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize with the given
|
||||
* input, pairwise transform, result, and number
|
||||
* of elements
|
||||
*
|
||||
* @param x the input
|
||||
* @param y the pairwise transform
|
||||
* @param z the result
|
||||
*/
|
||||
public TensorMmul(INDArray x, INDArray y, INDArray z, int[][] axes) {
|
||||
super(null,new INDArray[]{x, y, z},null);
|
||||
this.axes = axes;
|
||||
}
|
||||
|
||||
public TensorMmul(INDArray x, INDArray y, int[] dimensionsX, int[] dimensionsY,
|
||||
boolean transposeX, boolean transposeY, boolean transposeZ) {
|
||||
super(null,new INDArray[]{x, y},null);
|
||||
this.axes = new int[][]{dimensionsX, dimensionsY};
|
||||
addBArgument(transposeX, transposeY, transposeZ);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "tensordot";
|
||||
|
|
|
@ -48,7 +48,6 @@ public class NormMax extends BaseReduceFloatOp {
|
|||
super(x, null, z, dimensions);
|
||||
}
|
||||
|
||||
|
||||
public NormMax(INDArray x, int... dimensions) {
|
||||
super(x, dimensions);
|
||||
}
|
||||
|
|
|
@ -48,6 +48,10 @@ public class SquaredNorm extends BaseReduceFloatOp {
|
|||
|
||||
public SquaredNorm(){}
|
||||
|
||||
public SquaredNorm(INDArray x, int... dimensions){
|
||||
super(x, dimensions);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int opNum() {
|
||||
return 7;
|
||||
|
|
|
@ -52,11 +52,15 @@ public class MatchCondition extends BaseReduceLongOp {
|
|||
|
||||
public MatchCondition() {}
|
||||
|
||||
|
||||
public MatchCondition(INDArray x, Condition condition, int... dimensions) {
|
||||
this(x, Nd4j.EPS_THRESHOLD, condition, dimensions);
|
||||
}
|
||||
|
||||
public MatchCondition(INDArray x, Condition condition, boolean keepDims, int... dimensions) {
|
||||
this(x, Nd4j.EPS_THRESHOLD, condition, dimensions);
|
||||
this.keepDims = keepDims;
|
||||
}
|
||||
|
||||
public MatchCondition(INDArray x, double eps, Condition condition, int... dimensions) {
|
||||
super(x);
|
||||
this.compare = condition.getValue();
|
||||
|
@ -68,10 +72,6 @@ public class MatchCondition extends BaseReduceLongOp {
|
|||
defineDimensions(dimensions);
|
||||
}
|
||||
|
||||
public MatchCondition(INDArray in, Condition condition, boolean keepDim, int... dimensions) {
|
||||
this(in, condition, dimensions);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int opNum() {
|
||||
return 2;
|
||||
|
|
|
@ -41,7 +41,6 @@ public class Sum extends BaseReduceSameOp {
|
|||
super(sameDiff, i_v, i_v2, dimensions);
|
||||
}
|
||||
|
||||
|
||||
public Sum() {
|
||||
}
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ public class ScalarEquals extends BaseScalarBoolOp {
|
|||
}
|
||||
|
||||
public ScalarEquals(INDArray x, Number num) {
|
||||
super(x, num);
|
||||
this(x, null, num);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -19,9 +19,11 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison;
|
|||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
|
||||
import org.nd4j.linalg.api.ops.BaseScalarOp;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
@ -56,7 +58,7 @@ public class ScalarGreaterThan extends BaseScalarBoolOp {
|
|||
}
|
||||
|
||||
public ScalarGreaterThan(INDArray x, Number num) {
|
||||
super(x, num);
|
||||
this(x, null, num);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -19,9 +19,11 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison;
|
|||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
|
||||
import org.nd4j.linalg.api.ops.BaseScalarOp;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
@ -40,7 +42,7 @@ public class ScalarGreaterThanOrEqual extends BaseScalarBoolOp {
|
|||
}
|
||||
|
||||
public ScalarGreaterThanOrEqual(INDArray x, Number num) {
|
||||
super(x, num);
|
||||
this(x, null, num);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -18,9 +18,11 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison;
|
|||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
|
||||
import org.nd4j.linalg.api.ops.BaseScalarOp;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
@ -39,7 +41,7 @@ public class ScalarLessThan extends BaseScalarBoolOp {
|
|||
}
|
||||
|
||||
public ScalarLessThan(INDArray x, Number num) {
|
||||
super(x, num);
|
||||
this(x, null, num);
|
||||
}
|
||||
|
||||
public ScalarLessThan(SameDiff sameDiff, SDVariable i_v, Number scalar, boolean inPlace) {
|
||||
|
|
|
@ -19,9 +19,11 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison;
|
|||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
|
||||
import org.nd4j.linalg.api.ops.BaseScalarOp;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
@ -49,7 +51,7 @@ public class ScalarLessThanOrEqual extends BaseScalarBoolOp {
|
|||
}
|
||||
|
||||
public ScalarLessThanOrEqual(INDArray x, Number num) {
|
||||
super(x, num);
|
||||
this(x, null, num);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -41,10 +41,9 @@ public class ScalarNotEquals extends BaseScalarBoolOp {
|
|||
}
|
||||
|
||||
public ScalarNotEquals(INDArray x, Number num) {
|
||||
super(x, num);
|
||||
this(x, null, num);
|
||||
}
|
||||
|
||||
|
||||
public ScalarNotEquals(SameDiff sameDiff, SDVariable i_v, Number scalar) {
|
||||
super(sameDiff, i_v, scalar);
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.scatter;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -44,12 +45,12 @@ public class ScatterAdd extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
||||
}
|
||||
|
||||
public ScatterAdd(INDArray ref, INDArray indices, INDArray updates) {
|
||||
addInputArgument(ref, indices, updates);
|
||||
}
|
||||
|
||||
public ScatterAdd(){}
|
||||
|
||||
public ScatterAdd(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){
|
||||
super(new INDArray[]{ref, indices, update}, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "scatter_add";
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.scatter;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -44,11 +45,12 @@ public class ScatterDiv extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
||||
}
|
||||
|
||||
public ScatterDiv(INDArray ref, INDArray indices, INDArray updates) {
|
||||
addInputArgument(ref, indices, updates);
|
||||
public ScatterDiv() {}
|
||||
|
||||
public ScatterDiv(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){
|
||||
super(new INDArray[]{ref, indices, update}, null);
|
||||
}
|
||||
|
||||
public ScatterDiv() {}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.scatter;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -42,12 +43,12 @@ public class ScatterMax extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
||||
}
|
||||
|
||||
public ScatterMax(INDArray ref, INDArray indices, INDArray updates) {
|
||||
addInputArgument(ref, indices, updates);
|
||||
}
|
||||
|
||||
public ScatterMax() {}
|
||||
|
||||
public ScatterMax(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){
|
||||
super(new INDArray[]{ref, indices, update}, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "scatter_max";
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.scatter;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -42,12 +43,12 @@ public class ScatterMin extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
||||
}
|
||||
|
||||
public ScatterMin(INDArray ref, INDArray indices, INDArray updates) {
|
||||
addInputArgument(ref, indices, updates);
|
||||
}
|
||||
|
||||
public ScatterMin() {}
|
||||
|
||||
public ScatterMin(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){
|
||||
super(new INDArray[]{ref, indices, update}, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "scatter_min";
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.scatter;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -44,12 +45,12 @@ public class ScatterMul extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
||||
}
|
||||
|
||||
public ScatterMul(INDArray ref, INDArray indices, INDArray updates) {
|
||||
addInputArgument(ref, indices, updates);
|
||||
}
|
||||
|
||||
public ScatterMul() {}
|
||||
|
||||
public ScatterMul(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){
|
||||
super(new INDArray[]{ref, indices, update}, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "scatter_mul";
|
||||
|
|
|
@ -44,12 +44,12 @@ public class ScatterSub extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
||||
}
|
||||
|
||||
public ScatterSub(INDArray ref, INDArray indices, INDArray updates) {
|
||||
addInputArgument(ref, indices, updates);
|
||||
}
|
||||
|
||||
public ScatterSub() {}
|
||||
|
||||
public ScatterSub(INDArray ref, INDArray indices, INDArray update){
|
||||
super(new INDArray[]{ref, indices, update}, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "scatter_sub";
|
||||
|
|
|
@ -54,12 +54,12 @@ public class ScatterUpdate extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
||||
}
|
||||
|
||||
public ScatterUpdate(INDArray ref, INDArray indices, INDArray updates) {
|
||||
addInputArgument(ref, indices, updates);
|
||||
}
|
||||
|
||||
public ScatterUpdate(){}
|
||||
|
||||
public ScatterUpdate(INDArray ref, INDArray indices, INDArray update){
|
||||
super(new INDArray[]{ref, indices, update}, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "scatter_upd";
|
||||
|
|
|
@ -44,7 +44,7 @@ public class Concat extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
public Concat(int concatDimension, INDArray... arrays) {
|
||||
super(null, arrays, new INDArray[0]);
|
||||
super(null, arrays, null);
|
||||
this.concatDimension = concatDimension;
|
||||
addIArgument(concatDimension);
|
||||
}
|
||||
|
|
|
@ -67,15 +67,16 @@ public class ExpandDims extends DynamicCustomOp {
|
|||
super(null, inputs, outputs);
|
||||
}
|
||||
|
||||
public ExpandDims(INDArray input, int axis) {
|
||||
addInputArgument(input);
|
||||
addIArgument(axis);
|
||||
}
|
||||
|
||||
public ExpandDims(SameDiff sameDiff, SDVariable[] args, boolean inPlace) {
|
||||
super(null, sameDiff, args, inPlace);
|
||||
}
|
||||
|
||||
public ExpandDims(INDArray x, int axis){
|
||||
super(new INDArray[]{x}, null);
|
||||
this.jaxis = axis;
|
||||
addIArgument(axis);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||
val targetNode = TFGraphMapper.getNodeWithNameFromGraph(graph, nodeDef.getInput(1));
|
||||
|
|
|
@ -42,6 +42,10 @@ public class GatherNd extends DynamicCustomOp {
|
|||
super(new INDArray[]{df, indices}, null);
|
||||
}
|
||||
|
||||
public GatherNd(INDArray[] inputs, INDArray[] outputs){
|
||||
super(inputs, outputs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "gather_nd";
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
|
|
@ -67,8 +67,7 @@ public class OneHot extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
public OneHot(INDArray indices, int depth) {
|
||||
addInputArgument(indices);
|
||||
addIArgument(depth);
|
||||
this(indices, null, depth, 0, 1.0, 0.0);
|
||||
}
|
||||
|
||||
public OneHot(INDArray indices, INDArray output, int depth, int axis, double on, double off) {
|
||||
|
@ -80,14 +79,16 @@ public class OneHot extends DynamicCustomOp {
|
|||
addArgs();
|
||||
}
|
||||
|
||||
public OneHot(INDArray indices, int depth, int axis, double on, double off, DataType dataType) {
|
||||
addInputArgument(indices);
|
||||
addIArgument(depth, axis);
|
||||
addTArgument(on, off);
|
||||
addDArgument(dataType);
|
||||
public OneHot(INDArray indices, int depth, int axis, double on, double off) {
|
||||
this(indices, null, depth, axis, on, off);
|
||||
}
|
||||
|
||||
|
||||
public OneHot(INDArray indices, int depth, int axis, double on, double off, DataType dataType) {
|
||||
this(indices, null, depth, axis, on, off);
|
||||
this.outputType = dataType;
|
||||
if (outputType != null)
|
||||
addDArgument(outputType);
|
||||
}
|
||||
|
||||
protected void addArgs() {
|
||||
addIArgument(jaxis);
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions;
|
|||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
|
@ -44,6 +45,10 @@ public class ParallelStack extends DynamicCustomOp {
|
|||
super(null, sameDiff, values, false);
|
||||
}
|
||||
|
||||
public ParallelStack(INDArray[] inputs){
|
||||
super(inputs, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "parallel_stack";
|
||||
|
|
|
@ -55,15 +55,16 @@ public class Permute extends Transpose {
|
|||
addIArgument(permuteDims);
|
||||
}
|
||||
|
||||
public Permute(INDArray input, int... permuteDims){
|
||||
addInputArgument(input);
|
||||
addIArgument(permuteDims);
|
||||
}
|
||||
|
||||
public Permute(SameDiff sd, SDVariable input, SDVariable permuteDims){
|
||||
super(sd, input, permuteDims);
|
||||
}
|
||||
|
||||
public Permute(INDArray input, int... permuteDims){
|
||||
super(input, null);
|
||||
this.permuteDims = permuteDims;
|
||||
addIArgument(permuteDims);
|
||||
}
|
||||
|
||||
public Permute() {
|
||||
}
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
|
@ -41,6 +42,7 @@ import java.util.Map;
|
|||
* @author Adam Gibson
|
||||
*/
|
||||
@Slf4j
|
||||
@NoArgsConstructor
|
||||
public class Reshape extends DynamicCustomOp {
|
||||
|
||||
private long[] shape;
|
||||
|
@ -61,15 +63,13 @@ public class Reshape extends DynamicCustomOp {
|
|||
addIArgument(shape);
|
||||
}
|
||||
|
||||
public Reshape(INDArray in, INDArray shape){
|
||||
this(in, shape, null);
|
||||
}
|
||||
|
||||
public Reshape(@NonNull INDArray in, @NonNull INDArray shape, INDArray out){
|
||||
super(null, new INDArray[]{in, shape}, wrapOrNull(out), null, (List<Integer>)null);
|
||||
}
|
||||
|
||||
public Reshape() {
|
||||
public Reshape(INDArray in, INDArray shape){
|
||||
this(in, shape, null);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -48,11 +48,10 @@ public class Size extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[] {input}, false);
|
||||
}
|
||||
|
||||
public Size(INDArray in) {
|
||||
addInputArgument(in);
|
||||
public Size(INDArray in){
|
||||
super(new INDArray[] {in}, null);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx name found for shape " + opName());
|
||||
|
|
|
@ -54,8 +54,10 @@ public class Slice extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[]{input, begin, end});
|
||||
}
|
||||
|
||||
public Slice(INDArray in, int[] begin, int... size) {
|
||||
addInputArgument(in);
|
||||
public Slice(INDArray input, int[] begin, int... size){
|
||||
super(new INDArray[] {input}, null);
|
||||
this.begin = begin;
|
||||
this.size = size;
|
||||
addIArgument(begin);
|
||||
addIArgument(size);
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
|
|
|
@ -67,13 +67,16 @@ public class Tile extends DynamicCustomOp {
|
|||
this(inputs,outputs,axis,false);
|
||||
}
|
||||
|
||||
public Tile(INDArray x, INDArray repeat) {
|
||||
addInputArgument(x, repeat);
|
||||
public Tile(INDArray x, INDArray repeat){
|
||||
super(null, new INDArray[] {x, repeat}, null);
|
||||
this.jaxis = null;
|
||||
}
|
||||
|
||||
public Tile(INDArray x, int... repeat) {
|
||||
addInputArgument(x);
|
||||
addIArgument(repeat);
|
||||
public Tile(INDArray inputs, int... axis){
|
||||
super(null, new INDArray[] {inputs}, null);
|
||||
this.jaxis = axis;
|
||||
this.is_static_reps = true;
|
||||
addArguments();
|
||||
}
|
||||
|
||||
public Tile() {}
|
||||
|
|
|
@ -60,8 +60,8 @@ public class Transpose extends DynamicCustomOp {
|
|||
super(null, new INDArray[]{input}, result == null ? null : new INDArray[]{result}, null, (List<Integer>) null);
|
||||
}
|
||||
|
||||
public Transpose(INDArray input) {
|
||||
addInputArgument(input);
|
||||
public Transpose(INDArray input){
|
||||
this(input, null);
|
||||
}
|
||||
|
||||
public Transpose() {
|
||||
|
|
|
@ -62,12 +62,10 @@ public class MatchConditionTransform extends BaseTransformBoolOp {
|
|||
this(x, z, Nd4j.EPS_THRESHOLD, condition);
|
||||
}
|
||||
|
||||
|
||||
public MatchConditionTransform(INDArray x, @NonNull Condition condition) {
|
||||
this(x, null, Nd4j.EPS_THRESHOLD, condition);
|
||||
}
|
||||
|
||||
|
||||
public MatchConditionTransform(INDArray x, INDArray z, double eps, @NonNull Condition condition) {
|
||||
super(x, null, z);
|
||||
|
||||
|
|
|
@ -69,7 +69,7 @@ public class CompareAndReplace extends BaseTransformSameOp {
|
|||
* @param condition
|
||||
*/
|
||||
public CompareAndReplace(INDArray x, INDArray y, Condition condition) {
|
||||
this(x, y, x, condition);
|
||||
this(x, y, null, condition);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -24,6 +24,7 @@ import org.nd4j.linalg.api.buffer.DataType;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.Op;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
@ -46,6 +47,10 @@ public class Assign extends DynamicCustomOp {
|
|||
super(null,inputs, outputs);
|
||||
}
|
||||
|
||||
public Assign(INDArray x, INDArray y ) {
|
||||
this( new INDArray[]{y ,x},new INDArray[]{y}); // TODO: Still check. y cannot be null, must be same shape as x.
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addIArgument(int... arg) {
|
||||
super.addIArgument(arg);
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import lombok.NonNull;
|
||||
import lombok.val;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -67,11 +68,19 @@ public class DynamicPartition extends DynamicCustomOp {
|
|||
addArgs();
|
||||
}
|
||||
|
||||
public DynamicPartition(INDArray input, INDArray partitions, int numPartitions) {
|
||||
addInputArgument(input);
|
||||
addIArgument(numPartitions);
|
||||
public DynamicPartition(@NonNull INDArray input, @NonNull INDArray partitions, int numPartitions) {
|
||||
super(new INDArray[]{input, partitions}, null);
|
||||
this.numPartitions = numPartitions;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public DynamicPartition(INDArray x, INDArray [] partitions, int numPartitions){
|
||||
//TODO; This needs fixing.
|
||||
super(new INDArray[]{x}, null);
|
||||
// this.partitions = partitions;
|
||||
this.numPartitions = numPartitions;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -61,14 +62,8 @@ public class DynamicStitch extends DynamicCustomOp {
|
|||
this.numPartitions = inputs.length;
|
||||
}
|
||||
|
||||
public DynamicStitch(INDArray[] inputs, INDArray[] indices) {
|
||||
for (INDArray input : inputs) {
|
||||
addInputArgument(input);
|
||||
}
|
||||
|
||||
for (INDArray index : indices) {
|
||||
addInputArgument(index);
|
||||
}
|
||||
public DynamicStitch(@NonNull INDArray[] indices, @NonNull INDArray[] inputs) {
|
||||
super(ArrayUtils.addAll(indices, inputs), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -48,14 +48,14 @@ public class EqualTo extends BaseDynamicTransformOp {
|
|||
super(inputs, outputs);
|
||||
}
|
||||
|
||||
public EqualTo( INDArray x, INDArray y) {
|
||||
addInputArgument(x, y);
|
||||
}
|
||||
|
||||
public EqualTo(INDArray x, INDArray y, INDArray z){
|
||||
this(new INDArray[]{x, y}, new INDArray[]{z});
|
||||
}
|
||||
|
||||
public EqualTo(INDArray x, INDArray y){
|
||||
this(new INDArray[]{x, y}, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "equals";
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
|||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.Op;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
@ -55,19 +56,21 @@ public class Fill extends DynamicCustomOp {
|
|||
super(null,sameDiff, new SDVariable[] {shape}, false);
|
||||
this.value = value;
|
||||
this.outputDataType = outputDataType;
|
||||
this.outputDataType = outputDataType;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public Fill(INDArray shape, DataType outputDataType, double value) {
|
||||
super(new INDArray[]{shape, Nd4j.scalar(outputDataType, value)}, null);
|
||||
this.value = value;
|
||||
this.outputDataType = outputDataType;
|
||||
}
|
||||
|
||||
public Fill(INDArray shape, INDArray result, double value) {
|
||||
super(null, shape, result, Collections.singletonList(value), null);
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
public Fill(INDArray shape, DataType dataType, double value) {
|
||||
super(null, shape, null, Collections.singletonList(value), null);
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
public Fill(INDArray shape, INDArray value, INDArray result) {
|
||||
super(null, new INDArray[]{shape, value}, new INDArray[]{result});
|
||||
}
|
||||
|
|
|
@ -49,14 +49,14 @@ public class GreaterThan extends BaseDynamicTransformOp {
|
|||
super(inputs, outputs);
|
||||
}
|
||||
|
||||
public GreaterThan( INDArray x, INDArray y) {
|
||||
addInputArgument(x,y);
|
||||
}
|
||||
|
||||
public GreaterThan(INDArray x, INDArray y, INDArray z){
|
||||
this(new INDArray[]{x, y}, new INDArray[]{z});
|
||||
}
|
||||
|
||||
public GreaterThan(INDArray x, INDArray y){
|
||||
this(new INDArray[]{x, y}, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "greater";
|
||||
|
|
|
@ -52,9 +52,9 @@ public class GreaterThanOrEqual extends BaseDynamicTransformOp {
|
|||
this(new INDArray[]{x, y}, new INDArray[]{z});
|
||||
}
|
||||
|
||||
public GreaterThanOrEqual(INDArray x, INDArray y) {
|
||||
|
||||
this(new INDArray[]{x,y}, null);
|
||||
public GreaterThanOrEqual(INDArray x, INDArray y){
|
||||
this(new INDArray[]{x, y}, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -45,8 +45,8 @@ public class IsNumericTensor extends DynamicCustomOp {
|
|||
super(null, inputs, outputs);
|
||||
}
|
||||
|
||||
public IsNumericTensor(INDArray input) {
|
||||
addInputArgument(input);
|
||||
public IsNumericTensor(INDArray inputs) {
|
||||
super( new INDArray[] {inputs}, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -49,14 +49,14 @@ public class LessThan extends BaseDynamicTransformOp {
|
|||
super(inputs, outputs);
|
||||
}
|
||||
|
||||
public LessThan( INDArray x, INDArray y) {
|
||||
addInputArgument(x,y);
|
||||
}
|
||||
|
||||
public LessThan(INDArray x, INDArray y, INDArray z){
|
||||
this(new INDArray[]{x, y}, new INDArray[]{z});
|
||||
}
|
||||
|
||||
public LessThan(INDArray x, INDArray y){
|
||||
this(new INDArray[]{x, y}, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "less";
|
||||
|
|
|
@ -48,14 +48,14 @@ public class LessThanOrEqual extends BaseDynamicTransformOp {
|
|||
super(inputs, outputs);
|
||||
}
|
||||
|
||||
public LessThanOrEqual( INDArray x, INDArray y) {
|
||||
addInputArgument(x,y);
|
||||
}
|
||||
|
||||
public LessThanOrEqual(INDArray x, INDArray y, INDArray z){
|
||||
this(new INDArray[]{x, y}, new INDArray[]{z});
|
||||
}
|
||||
|
||||
public LessThanOrEqual(INDArray x, INDArray y){
|
||||
this(new INDArray[]{x, y}, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "less_equal";
|
||||
|
|
|
@ -48,12 +48,12 @@ public class Max extends BaseDynamicTransformOp {
|
|||
super(new INDArray[]{first, second}, out == null ? null : new INDArray[]{out});
|
||||
}
|
||||
|
||||
public Max( INDArray[] inputs, INDArray[] outputs) {
|
||||
super(inputs, outputs);
|
||||
public Max( INDArray first, INDArray second){
|
||||
this(first, second, null);
|
||||
}
|
||||
|
||||
public Max( INDArray x, INDArray y) {
|
||||
addInputArgument(x,y);
|
||||
public Max( INDArray[] inputs, INDArray[] outputs) {
|
||||
super(inputs, outputs);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -48,12 +48,12 @@ public class Min extends BaseDynamicTransformOp {
|
|||
super(new INDArray[]{first, second}, out == null ? null : new INDArray[]{out});
|
||||
}
|
||||
|
||||
public Min( INDArray[] inputs, INDArray[] outputs) {
|
||||
super(inputs, outputs);
|
||||
public Min( INDArray first, INDArray second){
|
||||
this(first, second, null);
|
||||
}
|
||||
|
||||
public Min( INDArray x, INDArray y) {
|
||||
addInputArgument(x,y);
|
||||
public Min( INDArray[] inputs, INDArray[] outputs) {
|
||||
super(inputs, outputs);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -48,14 +48,14 @@ public class NotEqualTo extends BaseDynamicTransformOp {
|
|||
super(inputs, outputs);
|
||||
}
|
||||
|
||||
public NotEqualTo( INDArray x, INDArray y) {
|
||||
addInputArgument(x,y);
|
||||
}
|
||||
|
||||
public NotEqualTo(INDArray x, INDArray y, INDArray z){
|
||||
this(new INDArray[]{x, y}, new INDArray[]{z});
|
||||
}
|
||||
|
||||
public NotEqualTo(INDArray x, INDArray y){
|
||||
this(new INDArray[]{x, y}, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "not_equals";
|
||||
|
|
|
@ -59,6 +59,17 @@ public class ReverseSequence extends DynamicCustomOp {
|
|||
addArguments();
|
||||
}
|
||||
|
||||
public ReverseSequence(INDArray x, INDArray seq_lengths, int seqDim, int batchDim){
|
||||
super(new INDArray[]{x, seq_lengths}, null);
|
||||
this.seqDim = seqDim;
|
||||
this.batchDim = batchDim;
|
||||
addArguments();
|
||||
}
|
||||
|
||||
public ReverseSequence(INDArray x, INDArray seq_lengths){
|
||||
this(x, seq_lengths, 1, 0);
|
||||
}
|
||||
|
||||
private void addArguments(){
|
||||
addIArgument(seqDim);
|
||||
addIArgument(batchDim);
|
||||
|
@ -67,11 +78,6 @@ public class ReverseSequence extends DynamicCustomOp {
|
|||
public ReverseSequence() {
|
||||
}
|
||||
|
||||
public ReverseSequence(INDArray x, INDArray seq_lengths, int seqDim, int batchDim) {
|
||||
addInputArgument(x, seq_lengths);
|
||||
addIArgument(seqDim, batchDim);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "reverse_sequence";
|
||||
|
|
|
@ -39,8 +39,8 @@ public class SegmentMax extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
|
||||
}
|
||||
|
||||
public SegmentMax(INDArray data, INDArray segmentIds) {
|
||||
addInputArgument(data, segmentIds);
|
||||
public SegmentMax(INDArray data, INDArray segmentIds){
|
||||
super(new INDArray[]{data, segmentIds}, null);
|
||||
}
|
||||
|
||||
public SegmentMax(){ }
|
||||
|
|
|
@ -39,12 +39,12 @@ public class SegmentMean extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
|
||||
}
|
||||
|
||||
public SegmentMean(INDArray data, INDArray segmentIds) {
|
||||
addInputArgument(data, segmentIds);
|
||||
}
|
||||
|
||||
public SegmentMean(){ }
|
||||
|
||||
public SegmentMean(INDArray data, INDArray segmentIds){
|
||||
super(new INDArray[]{data, segmentIds}, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName(){
|
||||
return "segment_mean";
|
||||
|
|
|
@ -39,12 +39,12 @@ public class SegmentMin extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
|
||||
}
|
||||
|
||||
public SegmentMin(INDArray data, INDArray segmentIds) {
|
||||
addInputArgument(data, segmentIds);
|
||||
}
|
||||
|
||||
public SegmentMin(){ }
|
||||
|
||||
public SegmentMin(INDArray data, INDArray segmentIds){
|
||||
super(new INDArray[]{data, segmentIds}, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName(){
|
||||
return "segment_min";
|
||||
|
|
|
@ -39,12 +39,12 @@ public class SegmentProd extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
|
||||
}
|
||||
|
||||
public SegmentProd(INDArray data, INDArray segmentIds) {
|
||||
addInputArgument(data, segmentIds);
|
||||
}
|
||||
|
||||
public SegmentProd(){ }
|
||||
|
||||
public SegmentProd(INDArray data, INDArray segmentIds){
|
||||
super(new INDArray[]{data, segmentIds}, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName(){
|
||||
return "segment_prod";
|
||||
|
|
|
@ -39,12 +39,12 @@ public class SegmentSum extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
|
||||
}
|
||||
|
||||
public SegmentSum(INDArray data, INDArray segmentIds) {
|
||||
addInputArgument(data, segmentIds);
|
||||
}
|
||||
|
||||
public SegmentSum(){ }
|
||||
|
||||
public SegmentSum(INDArray data, INDArray segmentIds){
|
||||
super(new INDArray[]{data, segmentIds}, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName(){
|
||||
return "segment_sum";
|
||||
|
|
|
@ -42,7 +42,7 @@ public class Identity extends BaseDynamicTransformOp {
|
|||
}
|
||||
|
||||
public Identity(INDArray x){
|
||||
addInputArgument(x);
|
||||
super(new INDArray[]{x}, null);
|
||||
}
|
||||
|
||||
public Identity(){ }
|
||||
|
|
|
@ -41,13 +41,14 @@ public class UnsortedSegmentMax extends DynamicCustomOp {
|
|||
addIArgument(numSegments);
|
||||
}
|
||||
|
||||
public UnsortedSegmentMax(INDArray data, INDArray segmentIds, int numSegments) {
|
||||
addInputArgument(data, segmentIds);
|
||||
public UnsortedSegmentMax(){ }
|
||||
|
||||
public UnsortedSegmentMax(INDArray data, INDArray segmentIds, int numSegments){
|
||||
super(new INDArray[]{data, segmentIds}, null);
|
||||
this.numSegments = numSegments;
|
||||
addIArgument(numSegments);
|
||||
}
|
||||
|
||||
public UnsortedSegmentMax(){ }
|
||||
|
||||
@Override
|
||||
public String opName(){
|
||||
return "unsorted_segment_max";
|
||||
|
|
|
@ -45,8 +45,9 @@ public class UnsortedSegmentMean extends DynamicCustomOp {
|
|||
addIArgument(numSegments);
|
||||
}
|
||||
|
||||
public UnsortedSegmentMean(INDArray data, INDArray segmentIds, int numSegments) {
|
||||
addInputArgument(data, segmentIds);
|
||||
public UnsortedSegmentMean(INDArray data, INDArray segmentIds, int numSegments){
|
||||
super(new INDArray[]{data, segmentIds}, null);
|
||||
this.numSegments = numSegments;
|
||||
addIArgument(numSegments);
|
||||
}
|
||||
|
||||
|
|
|
@ -45,8 +45,9 @@ public class UnsortedSegmentMin extends DynamicCustomOp {
|
|||
addIArgument(numSegments);
|
||||
}
|
||||
|
||||
public UnsortedSegmentMin(INDArray data, INDArray segmentIds, int numSegments) {
|
||||
addInputArgument(data, segmentIds);
|
||||
public UnsortedSegmentMin(INDArray data, INDArray segmentIds, int numSegments){
|
||||
super(new INDArray[]{data, segmentIds}, null);
|
||||
this.numSegments = numSegments;
|
||||
addIArgument(numSegments);
|
||||
}
|
||||
|
||||
|
|
|
@ -45,8 +45,9 @@ public class UnsortedSegmentProd extends DynamicCustomOp {
|
|||
addIArgument(numSegments);
|
||||
}
|
||||
|
||||
public UnsortedSegmentProd(INDArray data, INDArray segmentIds, int numSegments) {
|
||||
addInputArgument(data, segmentIds);
|
||||
public UnsortedSegmentProd(INDArray data, INDArray segmentIds, int numSegments){
|
||||
super(new INDArray[]{data, segmentIds}, null);
|
||||
this.numSegments = numSegments;
|
||||
addIArgument(numSegments);
|
||||
}
|
||||
|
||||
|
|
|
@ -39,18 +39,18 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp {
|
|||
|
||||
private int numSegments;
|
||||
|
||||
public UnsortedSegmentSqrtN(INDArray data, INDArray segmentIds, int numSegments) {
|
||||
addInputArgument(data, segmentIds);
|
||||
addIArgument(numSegments);
|
||||
this.numSegments = numSegments;
|
||||
}
|
||||
|
||||
public UnsortedSegmentSqrtN(SameDiff sameDiff, SDVariable data, SDVariable segmentIds, int numSegments) {
|
||||
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
|
||||
this.numSegments = numSegments;
|
||||
addIArgument(numSegments);
|
||||
}
|
||||
|
||||
public UnsortedSegmentSqrtN(INDArray data, INDArray segmentIds, int numSegments){
|
||||
super(new INDArray[]{data, segmentIds}, null);
|
||||
this.numSegments = numSegments;
|
||||
addIArgument(numSegments);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName(){
|
||||
return "unsorted_segment_sqrt_n";
|
||||
|
|
|
@ -46,8 +46,9 @@ public class UnsortedSegmentSum extends DynamicCustomOp {
|
|||
addIArgument(numSegments);
|
||||
}
|
||||
|
||||
public UnsortedSegmentSum(INDArray data, INDArray segmentIds, int numSegments) {
|
||||
addInputArgument(data, segmentIds);
|
||||
public UnsortedSegmentSum(INDArray data, INDArray segmentIds, int numSegments){
|
||||
super(new INDArray[]{data, segmentIds}, null);
|
||||
this.numSegments = numSegments;
|
||||
addIArgument(numSegments);
|
||||
}
|
||||
|
||||
|
|
|
@ -130,13 +130,20 @@ public class NDValidation {
|
|||
" type; got array with non-integer data type " + v.dataType());
|
||||
}
|
||||
|
||||
public static void validateInteger(String opName, String inputName, INDArray[] vars) {
|
||||
for (INDArray v : vars) {
|
||||
/**
|
||||
* Validate that the operation is being applied on an integer type INDArray []
|
||||
*
|
||||
* @param opName Operation name to print in the exception
|
||||
* @param inputName Name of the input to the op to validate
|
||||
* @param v Variable to validate datatype for (input to operation)
|
||||
*/
|
||||
public static void validateInteger(String opName, String inputName, INDArray [] v) {
|
||||
if (v == null)
|
||||
return;
|
||||
if (!v.dataType().isIntType())
|
||||
for (int i = 0; i < v.length; i++) {
|
||||
if (!v[i].dataType().isIntType())
|
||||
throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer" +
|
||||
" type; got array with non-integer data type " + v.dataType());
|
||||
" type; got array with non-integer data type member" + v[i].dataType());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -246,10 +253,11 @@ public class NDValidation {
|
|||
}
|
||||
|
||||
public static boolean isSameType(INDArray[] x) {
|
||||
DataType firstDataType = x[0].dataType();
|
||||
if (x.length > 1) {
|
||||
for (int i = 1; i < x.length; ++i) {
|
||||
if (firstDataType != x[i].dataType())
|
||||
if(x.length == 0)
|
||||
return true;
|
||||
DataType first = x[0].dataType();
|
||||
for( int i=1; i<x.length; i++ ){
|
||||
if(first != x[i].dataType()){
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2006,10 +2006,9 @@ public class Nd4j {
|
|||
* @return the linearly spaced vector
|
||||
*/
|
||||
public static INDArray linspace(@NonNull DataType dataType, double lower, double step, long num) {
|
||||
Preconditions.checkState(dataType.isFPType());
|
||||
Preconditions.checkState(dataType.isFPType(), "Datatype must be a floating point type for linspace, got %s", dataType);
|
||||
if (num == 1)
|
||||
return Nd4j.scalar(dataType, lower);
|
||||
|
||||
return Nd4j.getExecutioner().exec(new Linspace(lower, num, step, dataType));
|
||||
}
|
||||
|
||||
|
@ -2022,10 +2021,9 @@ public class Nd4j {
|
|||
* @return the linearly spaced vector
|
||||
*/
|
||||
public static INDArray linspace( double lower, double upper, long num, @NonNull DataType dataType) {
|
||||
Preconditions.checkState(dataType.isFPType());
|
||||
Preconditions.checkState(dataType.isFPType(), "Datatype must be a floating point type for linspace, got %s", dataType);
|
||||
if (num == 1)
|
||||
return Nd4j.scalar(dataType, lower);
|
||||
|
||||
return Nd4j.getExecutioner().exec(new Linspace(lower, upper, num, dataType));
|
||||
}
|
||||
|
||||
|
|
|
@ -159,7 +159,7 @@ public class BooleanIndexing {
|
|||
if (to.length() != from.length())
|
||||
throw new IllegalStateException("Mis matched length for to and from");
|
||||
|
||||
Nd4j.getExecutioner().exec(new CompareAndSet(to, from, condition));
|
||||
Nd4j.getExecutioner().exec(new CompareAndSet(to, from, to, condition));
|
||||
}
|
||||
|
||||
|
||||
|
@ -177,7 +177,7 @@ public class BooleanIndexing {
|
|||
if (to.length() != from.length())
|
||||
throw new IllegalStateException("Mis matched length for to and from");
|
||||
|
||||
Nd4j.getExecutioner().exec(new CompareAndReplace(to, from, condition));
|
||||
Nd4j.getExecutioner().exec(new CompareAndReplace(to, from, to, condition));
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
|
|
@ -19,13 +19,11 @@ package org.nd4j.linalg.learning;
|
|||
|
||||
import lombok.Data;
|
||||
import lombok.NonNull;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
import org.nd4j.linalg.learning.config.Nadam;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
|
|
@ -792,6 +792,19 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
}
|
||||
}
|
||||
|
||||
boolean keepDims = op.isKeepDims();
|
||||
long[] retShape = Shape.reductionShape(x, dimension, true, keepDims);
|
||||
|
||||
if(z == null || x == z) {
|
||||
val ret = Nd4j.createUninitialized(DataType.LONG, retShape);
|
||||
|
||||
setZ(ret, op, oc);
|
||||
z = ret;
|
||||
} else if(!Arrays.equals(retShape, z.shape())){
|
||||
throw new IllegalStateException("Z array shape does not match expected return type for op " + op
|
||||
+ ": expected shape " + Arrays.toString(retShape) + ", z.shape()=" + Arrays.toString(z.shape()));
|
||||
}
|
||||
|
||||
long st = profilingConfigurableHookIn(op);
|
||||
|
||||
checkForCompression(op);
|
||||
|
@ -2060,7 +2073,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");
|
||||
|
||||
for (val shape: list)
|
||||
op.addOutputArgument(Nd4j.create(shape));
|
||||
op.addOutputArgument(Nd4j.create(shape, false));
|
||||
|
||||
shapeOverride = true;
|
||||
} catch (Exception e) {
|
||||
|
|
|
@ -772,8 +772,10 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
if (y != null) {
|
||||
|
||||
if (z == null)
|
||||
if (z == null) {
|
||||
setZ(Nd4j.create(op.resultType(), x.shape()), op, oc);
|
||||
z = getZ(op, oc);
|
||||
}
|
||||
|
||||
|
||||
op.validateDataTypes(oc, experimentalMode.get());
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -234,9 +234,9 @@ public class BooleanIndexingTest extends BaseNd4jTest {
|
|||
INDArray array = Nd4j.create(new double[] {1, 2, 0, 4, 5});
|
||||
INDArray comp = Nd4j.create(new double[] {1, 2, 3, 4, 5});
|
||||
|
||||
Nd4j.getExecutioner().exec(new CompareAndReplace(array, comp, Conditions.lessThan(1)));
|
||||
INDArray z = Nd4j.exec(new CompareAndReplace(array, comp, Conditions.lessThan(1)));
|
||||
|
||||
assertEquals(comp, array);
|
||||
assertEquals(comp, z);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -256,9 +256,9 @@ public class BooleanIndexingTest extends BaseNd4jTest {
|
|||
INDArray y = Nd4j.create(new double[] {2, 4, 3, 4, 5});
|
||||
INDArray comp = Nd4j.create(new double[] {2, 4, 0, 4, 5});
|
||||
|
||||
Nd4j.getExecutioner().exec(new CompareAndReplace(x, y, Conditions.epsNotEquals(0.0)));
|
||||
INDArray z = Nd4j.exec(new CompareAndReplace(x, y, Conditions.epsNotEquals(0.0)));
|
||||
|
||||
assertEquals(comp, x);
|
||||
assertEquals(comp, z);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -267,9 +267,9 @@ public class BooleanIndexingTest extends BaseNd4jTest {
|
|||
INDArray y = Nd4j.create(new double[] {2, 4, 3, 4, 5});
|
||||
INDArray comp = Nd4j.create(new double[] {2, 4, 3, 4, 5});
|
||||
|
||||
Nd4j.getExecutioner().exec(new CompareAndReplace(x, y, Conditions.lessThan(4)));
|
||||
INDArray z = Nd4j.exec(new CompareAndReplace(x, y, Conditions.lessThan(4)));
|
||||
|
||||
assertEquals(comp, x);
|
||||
assertEquals(comp, z);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -278,9 +278,9 @@ public class BooleanIndexingTest extends BaseNd4jTest {
|
|||
INDArray y = Nd4j.create(new double[] {2, 4, 3, 4, 5});
|
||||
INDArray comp = Nd4j.create(new double[] {2, 2, 3, 4, 5});
|
||||
|
||||
Nd4j.getExecutioner().exec(new CompareAndReplace(x, y, Conditions.lessThan(2)));
|
||||
INDArray z = Nd4j.exec(new CompareAndReplace(x, y, Conditions.lessThan(2)));
|
||||
|
||||
assertEquals(comp, x);
|
||||
assertEquals(comp, z);
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue