Remove calculateOutputShape() from the Java side (#151)

* remove some unneeded java-side output shape calculations

Signed-off-by: Ryan Nett <rnett@skymind.io>

* delete Broadcast

Signed-off-by: Ryan Nett <rnett@skymind.io>

* delete Linear and Module

Signed-off-by: Ryan Nett <rnett@skymind.io>

* update Identity, HashCode, and NoOp

Signed-off-by: Ryan Nett <rnett@skymind.io>

* removed Cast java-side shape function, added tests and SDVariable.isEmpty

Signed-off-by: Ryan Nett <rnett@skymind.io>

* ignoring test w/ issues on master

Signed-off-by: Ryan Nett <rnett@skymind.io>

* noop needs more work, fixed BaseArithmeticBackprop and BaseDynamicTransform ops

merge in master for c++ build fix

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fix EqualTo

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fix other cond ops

Signed-off-by: Ryan Nett <rnett@skymind.io>

* "fake" ops calculateOutputShape() throws exception

Signed-off-by: Ryan Nett <rnett@skymind.io>

* use c++ shape calc for Linspace

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fix exception message, move most to BaseCompatOp

Signed-off-by: Ryan Nett <rnett@skymind.io>

* remove SDVariable.isEmpty

Signed-off-by: Ryan Nett <rnett@skymind.io>

* remove commented out code

Signed-off-by: Ryan Nett <rnett@skymind.io>
master
Ryan Nett 2019-08-27 20:39:32 -07:00 committed by GitHub
parent b46f9827b8
commit 2a1431264f
36 changed files with 32 additions and 741 deletions
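
The recurring pattern in the diffs below: ops stop overriding calculateOutputShape() on the Java side and inherit the base DynamicCustomOp behavior, which defers to the native (C++) shape calculation, while ops that can never be executed directly ("fake" ops such as the control-flow compat ops) throw instead. A minimal sketch of the two patterns, assuming the base class delegates through the executioner the same way the removed Java-side code did:

// Sketch only - simplified from the diffs below, not a verbatim excerpt.
// Pattern 1: regular ops drop their override; the inherited implementation
// defers to the native executioner, e.g.:
@Override
public List<LongShapeDescriptor> calculateOutputShape() {
    return Nd4j.getExecutioner().calculateOutputShape(this);
}

// Pattern 2: "fake" ops (control-flow compat ops, GradientBackwardsMarker,
// tensor ops) cannot compute shapes directly and now throw:
@Override
public List<LongShapeDescriptor> calculateOutputShape() {
    throw new UnsupportedOperationException("calculateOutputShape() is not supported for control flow ops.");
}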

View File

@@ -163,7 +163,6 @@ import org.nd4j.linalg.api.ops.impl.scatter.ScatterMin;
 import org.nd4j.linalg.api.ops.impl.scatter.ScatterMul;
 import org.nd4j.linalg.api.ops.impl.scatter.ScatterSub;
 import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
-import org.nd4j.linalg.api.ops.impl.shape.Broadcast;
 import org.nd4j.linalg.api.ops.impl.shape.Concat;
 import org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix;
 import org.nd4j.linalg.api.ops.impl.shape.Cross;
@@ -1450,14 +1449,6 @@ public class DifferentialFunctionFactory {
         return new MatrixInverse(sameDiff(), in, false).outputVariable();
     }
-    public SDVariable broadcast(SDVariable iX, int... shape) {
-        return broadcast(iX, ArrayUtil.toLongArray(shape));
-    }
-    public SDVariable broadcast(SDVariable iX, long... shape) {
-        return new Broadcast(sameDiff(), iX, shape).outputVariable();
-    }
     public SDVariable onehot(SDVariable indices, int depth, int axis, double on, double off, DataType dataType) {
         return new OneHot(sameDiff(), indices, depth, axis, on, off, dataType).outputVariable();
     }

View File

@@ -111,9 +111,6 @@ public class SDVariable implements Serializable {
         return variableType == VariableType.CONSTANT;
     }
     /**
      * Allocate and return a new array
      * based on the vertex id and weight initialization.

View File

@@ -98,7 +98,6 @@ public class ImportClassMapping {
         org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax.class,
         org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin.class,
         org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction.class,
-        org.nd4j.linalg.api.ops.impl.layers.Linear.class,
         org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D.class,
         org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling3D.class,
         org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm.class,
@@ -267,7 +266,6 @@ public class ImportClassMapping {
         org.nd4j.linalg.api.ops.impl.scatter.ScatterSub.class,
         org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate.class,
         org.nd4j.linalg.api.ops.impl.shape.ApplyGradientDescent.class,
-        org.nd4j.linalg.api.ops.impl.shape.Broadcast.class,
         org.nd4j.linalg.api.ops.impl.shape.BroadcastDynamicShape.class,
         org.nd4j.linalg.api.ops.impl.shape.Concat.class,
         org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix.class,

View File

@@ -1,63 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.ArrayList;
import java.util.List;
/**
* Abstract base class for {@link Module}
* that handles Dynamic ops and handles nesting.
*
* This is a logical unit for defining layers
* very similar to pytorch's modules, or tensorflow's layers.
*
* @author Adam Gibson
*/
@NoArgsConstructor
public abstract class BaseModule extends DynamicCustomOp implements Module {
private List<Module> modules = new ArrayList<>();
public BaseModule(String opName, INDArray[] inputs, INDArray[] outputs, List<Double> tArguments, List<Integer> iArguments, List<Module> modules) {
super(opName, inputs, outputs, tArguments, iArguments);
this.modules = modules;
}
public BaseModule(String opName, SameDiff sameDiff, SDVariable[] args, boolean inPlace, List<Module> modules) {
super(opName, sameDiff, args, inPlace);
this.modules = modules;
}
@Override
public Module[] subModules() {
return modules.toArray(new Module[modules.size()]);
}
@Override
public void addModule(Module module) {
modules.add(module);
}
}

View File

@@ -1,51 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
* A Module is a {@link CustomOp}
* with varying input arguments
* and automatically calculated outputs
* defined at a higher level than c++.
*
* A Module is meant to be a way of implementing custom operations
* in straight java/nd4j.
*/
public interface Module extends CustomOp {
/**
*
* @param inputs
*/
void exec(INDArray... inputs);
Module[] subModules();
void addModule(Module module);
void execSameDiff(SDVariable... input);
}

View File

@@ -80,6 +80,7 @@ public class NoOp extends DynamicCustomOp {
         return 1;
     }
     @Override
     public List<LongShapeDescriptor> calculateOutputShape(){
         if(inputArguments != null && !inputArguments.isEmpty()){

View File

@@ -16,6 +16,7 @@
 package org.nd4j.linalg.api.ops.impl.controlflow.compat;
+import java.util.List;
 import lombok.NonNull;
 import lombok.val;
 import org.nd4j.autodiff.samediff.SDVariable;
@@ -25,6 +26,7 @@ import org.nd4j.imports.descriptors.properties.AttributeAdapter;
 import org.nd4j.imports.descriptors.properties.PropertyMapping;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
 import org.tensorflow.framework.NodeDef;
@@ -89,4 +91,9 @@ public abstract class BaseCompatOp extends DynamicCustomOp {
     public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
         return super.attributeAdaptersForFunction();
     }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape() {
+        throw new UnsupportedOperationException("calculateOutputShape() is not supported for control flow ops.");
+    }
 }
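
A hypothetical usage sketch (not part of the diff) of what callers now see for these compat ops; direct shape queries fail, and shape inference must come from graph execution instead:

// Hypothetical illustration - Enter is one of the BaseCompatOp subclasses changed below,
// and the no-arg constructor is assumed here for brevity.
DynamicCustomOp op = new Enter();
try {
    op.calculateOutputShape(); // now throws
} catch (UnsupportedOperationException e) {
    // expected: "calculateOutputShape() is not supported for control flow ops."
}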

View File

@@ -68,15 +68,6 @@ public class Enter extends BaseCompatOp {
         return OP_NAME;
     }
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if(arg().getArr() != null) {
-            return Collections.singletonList(LongShapeDescriptor.fromShape(arg().getShape(), arg().getArr().dataType()));
-        }
-        else
-            return Collections.emptyList();
-    }
     @Override
     public SDVariable[] outputVariables() {
         return super.outputVariables();

View File

@@ -56,15 +56,6 @@ public class Exit extends BaseCompatOp {
         return OP_NAME;
     }
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if(arg().getArr() != null) {
-            return Collections.singletonList(arg().getShapeDescriptor());
-        }
-        else
-            return Collections.emptyList();
-    }
     @Override
     public SDVariable[] outputVariables() {
         return super.outputVariables();

View File

@@ -37,15 +37,6 @@ public class LoopCond extends BaseCompatOp {
         return "loop_cond";
     }
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if(arg().getArr() != null) {
-            return Collections.singletonList(LongShapeDescriptor.fromShape(arg().getShape(), arg().dataType()));
-        }
-        else
-            return Collections.emptyList();
-    }
     @Override
     public SDVariable[] outputVariables() {
         return super.outputVariables();

View File

@@ -64,15 +64,6 @@ public class Merge extends BaseCompatOp {
         return 60L;
     }
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if(arg().getArr() != null) {
-            return Collections.singletonList(arg().getShapeDescriptor());
-        }
-        else
-            return Collections.emptyList();
-    }
     @Override
     public SDVariable[] outputVariables() {
         return super.outputVariables();

View File

@@ -53,15 +53,6 @@ public class NextIteration extends BaseCompatOp {
         return OP_NAME;
     }
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if(arg().getArr() != null) {
-            return Collections.singletonList(LongShapeDescriptor.fromShape(arg().getShape(), arg().dataType()));
-        }
-        else
-            return Collections.emptyList();
-    }
     @Override
     public SDVariable[] outputVariables() {
         return super.outputVariables();

View File

@@ -63,18 +63,6 @@ public class Switch extends BaseCompatOp {
         return OP_NAME;
     }
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if(args()[0].getArr() != null) {
-            val arg0 = args()[0];
-            val arr0 = arg0.getArr();
-            val dtype = arr0.dataType();
-            return Arrays.asList(LongShapeDescriptor.fromShape(arg0.getShape(), dtype),LongShapeDescriptor.fromShape(arg0.getShape(), dtype));
-        }
-        else
-            return Collections.emptyList();
-    }
     @Override
     public SDVariable[] outputVariables() {
         return super.outputVariables();

View File

@@ -1,198 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.layers;
import lombok.Builder;
import lombok.NoArgsConstructor;
import lombok.val;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseModule;
import org.nd4j.linalg.api.ops.Module;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
* Linear:
* a * bT
*
* @author Adam Gibson
*/
@NoArgsConstructor
public class Linear extends BaseModule {
private DifferentialFunction forward;
private int nIn,nOut;
private WeightInitScheme weightInitScheme,biasWeightInitScheme;
@Builder(builderMethodName = "execBuilder")
public Linear(int nIn,
int nOut,
WeightInitScheme weightInitScheme,
WeightInitScheme biasWeightInitScheme) {
super(null,
getParams(nIn,nOut,weightInitScheme,biasWeightInitScheme),
new INDArray[]{},
new ArrayList<Double>(), new ArrayList<Integer>(), new ArrayList<Module>());
this.weightInitScheme = weightInitScheme;
this.biasWeightInitScheme = biasWeightInitScheme;
this.nIn = nIn;
this.nOut = nOut;
}
@Builder(builderMethodName = "sameDiffBuilder")
public Linear(SameDiff sameDiff,
int nIn,
int nOut,
WeightInitScheme weightInitScheme,
WeightInitScheme biasWeightInitScheme) {
super(null, sameDiff, null, false, new ArrayList<Module>());
this.weightInitScheme = weightInitScheme;
this.biasWeightInitScheme = biasWeightInitScheme;
this.nIn = nIn;
this.nOut = nOut;
}
@Override
public String opName() {
return "linear";
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
}
@Override
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
execSameDiff();
return forward.doDiff(f1);
}
@Override
public List<LongShapeDescriptor> calculateOutputShape() {
List<LongShapeDescriptor> ret = new ArrayList<>();
ret.add(LongShapeDescriptor.fromShape(Shape.getMatrixMultiplyShape(inputArguments()[0].shape(),new long[]{nOut,nIn}), inputArguments()[1].dataType()));
ret.add(LongShapeDescriptor.fromShape(Shape.getMatrixMultiplyShape(inputArguments()[0].shape(),inputArguments()[1].transpose().shape()), inputArguments()[1].dataType()));
if(biasWeightInitScheme != null) {
ret.add(LongShapeDescriptor.fromShape(new long[]{nOut,1}, inputArguments()[1].dataType()));
}
return ret;
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public void exec(INDArray... inputs) {
val inputArguments = inputArguments();
if(inputArguments == null || inputArguments.length < 1) {
throw new IllegalStateException("No arguments found.");
}
INDArray weights = inputArguments[0];
INDArray right = inputArguments[1];
val outputArguments = outputArguments();
if(outputArguments == null || outputArguments.length < 1) {
if(inputArguments.length == 1)
addOutputArgument(inputs[0].mmul(weights.transpose()));
else
addOutputArgument(inputs[0].mmul(weights.transpose()).addiColumnVector(right));
}
else {
inputs[0].mmul(weights.transpose(),outputArguments[0]);
}
}
@Override
public void execSameDiff(SDVariable... input) {
val args = args();
if(args == null || args.length == 0) {
throw new IllegalStateException("No arguments found");
}
if(forward == null) {
//bias needs to be added yet
if(args.length > 1) {
/*
forward = f().add(new Mmul(sameDiff, input[0],args()[0],
MMulTranspose.builder()
.transposeA(false)
.transposeB(true)
.build()).outputVariables()[0],args()[1]);
*/
} else {
forward = new Mmul(sameDiff, input[0],args()[0],
MMulTranspose.builder().transposeA(false).transposeB(true).build());
}
this.outputVariables = forward.outputVariables();
}
}
private static INDArray[] getParams(int nIn,
int nOut,
WeightInitScheme paramsScheme,
WeightInitScheme biasInitScheme) {
if(biasInitScheme != null) {
return new INDArray[] {paramsScheme.create(Nd4j.defaultFloatingPointType(), new long[]{nOut,nIn}),biasInitScheme.create(Nd4j.defaultFloatingPointType(), new long[]{nOut,1})};
}
else {
return new INDArray[] {paramsScheme.create(Nd4j.defaultFloatingPointType(), new long[]{nOut,nIn})};
}
}
}

View File

@@ -38,11 +38,6 @@ public class L2Loss extends DynamicCustomOp {
         super(sameDiff, new SDVariable[]{var});
     }
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        return Collections.singletonList(LongShapeDescriptor.fromShape(new long[0], Nd4j.defaultFloatingPointType()));
-    }
     @Override
     public String opName() {
         return "l2_loss";

View File

@@ -47,11 +47,6 @@ public class HashCode extends DynamicCustomOp {
         this.outputArguments.add(result);
     }
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        return Collections.singletonList(LongShapeDescriptor.fromShape(new long[0], DataType.LONG));
-    }
     @Override
     public String opName() {
         return "hashcode";

View File

@@ -1,78 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.shape;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/**
* Broadcast function
*
* @author Adam Gibson
*/
public class Broadcast extends DynamicCustomOp {
private long[] shape;
public Broadcast(SameDiff sameDiff,SDVariable iX, long[] shape) {
super(null,sameDiff,new SDVariable[]{iX});
this.shape = shape;
}
public Broadcast() {}
@Override
public List<LongShapeDescriptor> calculateOutputShape() {
return Arrays.asList(LongShapeDescriptor.fromShape(shape, larg().dataType()));
}
@Override
public String opName() {
return "broadcast";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
throw new UnsupportedOperationException();
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
return Collections.singletonList(dataTypes.get(0));
}
}

View File

@@ -82,12 +82,4 @@ public class Linspace extends DynamicCustomOp {
     public List<SDVariable> doDiff(List<SDVariable> gradients){
         return Arrays.asList(f().zerosLike(arg(0)), f().zerosLike(arg(1)), f().zerosLike(arg(2)));
     }
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape(){
-        INDArray l = arg(2).getArr();
-        if(l == null)
-            return Collections.emptyList();
-        return Collections.singletonList(LongShapeDescriptor.fromShape(new long[]{l.getLong(0)}, dataType));
-    }
 }

View File

@@ -85,14 +85,6 @@ public class Rank extends DynamicCustomOp {
         return "Rank";
     }
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        List<LongShapeDescriptor> ret = new ArrayList<>();
-        ret.add(LongShapeDescriptor.fromShape(new long[]{}, DataType.INT));
-        return ret;
-    }
     @Override
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
         return Collections.singletonList(sameDiff.zerosLike(arg()));

View File

@@ -152,17 +152,4 @@ public class Unstack extends DynamicCustomOp {
         return out;
     }
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape(){
-        //TEMPORARY workaround for: https://github.com/deeplearning4j/deeplearning4j/issues/7093
-        if(inputArguments.size() == 1 && inputArguments.get(0).rank() == 1){
-            INDArray arr = inputArguments.get(0);
-            Preconditions.checkState(jaxis == 0, "Can only unstack along dimension 0 for rank 1 arrays, got axis %s for array %ndShape", jaxis, arr);
-            LongShapeDescriptor lsd = LongShapeDescriptor.fromShape(new long[0], arr.dataType());
-            List<LongShapeDescriptor> out = Arrays.asList(ArrayUtil.nTimes((int)arr.length(), lsd, LongShapeDescriptor.class));
-            return out;
-        }
-        return super.calculateOutputShape();
-    }
 }

View File

@@ -32,7 +32,7 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 public abstract class BaseTensorOp extends DynamicCustomOp {
     public BaseTensorOp(String name, SameDiff sameDiff, SDVariable[] args){
         super(name, sameDiff, args);
@@ -78,8 +78,7 @@ public abstract class BaseTensorOp extends DynamicCustomOp {
     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
-        //Not used/not required
-        return Collections.emptyList();
+        throw new UnsupportedOperationException("calculateOutputShape() is not supported for tensor ops.");
     }
     @Override

View File

@@ -41,12 +41,6 @@ public class TensorArraySize extends BaseTensorOp {
         return "tensorarraysizev3";
     }
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        // output is scalar only
-        return Collections.singletonList(LongShapeDescriptor.fromShape(new long[]{}, DataType.LONG));
-    }
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
         super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);

View File

@@ -46,68 +46,6 @@ public abstract class BaseDynamicTransformOp extends DynamicCustomOp {
         super(null, inputs, outputs);
     }
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        long[] firstArgShape;
-        long[] secondArgShape;
-        DataType dtypeZ;
-        if(numInputArguments() == 2){
-            return super.calculateOutputShape(); //Use c++ shape calc, which also accounts for empty broadcast cases, etc
-            // firstArgShape = inputArguments.get(0).shape();
-            // secondArgShape = inputArguments.get(1).shape();
-            // dtypeZ = Shape.pickPairwiseDataType(inputArguments.get(0).dataType(), inputArguments.get(1).dataType());
-        } else {
-            val args = args();
-            if (args.length < 2) {
-                if (args[0] == null || (inputArguments.isEmpty() && args[0].getShape() == null)) {
-                    return Collections.emptyList();
-                }
-                DataType dtypeX = !inputArguments.isEmpty() ? inputArguments.get(0).dataType() : args[0].dataType();
-                long[] shape = !inputArguments.isEmpty() ? inputArguments.get(0).shape() : args[0].getShape();
-                return Collections.singletonList(LongShapeDescriptor.fromShape(shape, dtypeX));
-            }
-            if(inputArguments.size() == 2 && inputArguments.get(0) != null && inputArguments.get(1) != null){
-                firstArgShape = inputArguments.get(0).shape();
-                secondArgShape = inputArguments.get(1).shape();
-            } else {
-                firstArgShape = args[0].getShape();
-                secondArgShape = args[1].getShape();
-            }
-            if (args[0] == null || args[0].getShape() == null) {
-                return Collections.emptyList();
-            }
-            if (args[1] == null || args[1].getShape() == null) {
-                return Collections.emptyList();
-            }
-            // detecting datatype based on both args
-            val dtypeX = inputArguments.size() > 0 ? inputArguments.get(0).dataType() : args[0].dataType();
-            val dtypeY = inputArguments.size() > 1 ? inputArguments.get(1).dataType() : args[1].dataType();
-            dtypeZ = Shape.pickPairwiseDataType(dtypeX, dtypeY);
-        }
-        if(Arrays.equals(firstArgShape, secondArgShape)){
-            try {
-                return Collections.singletonList(LongShapeDescriptor.fromShape(firstArgShape, dtypeZ));
-            } catch (Throwable e) {
-                throw new RuntimeException("calculateOutputShape() failed for [" + this.opName() + "]", e);
-            }
-        } else {
-            //Handle broadcast shape: [1,4]+[3,1] = [3,4]
-            Shape.assertBroadcastable(firstArgShape, secondArgShape, this.getClass());
-            val outShape = Shape.broadcastOutputShape(firstArgShape, secondArgShape);
-            return Collections.singletonList(LongShapeDescriptor.fromShape(outShape, dtypeZ));
-        }
-    }
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
         Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got input %s", getClass(), dataTypes);

View File

@@ -71,18 +71,6 @@ public class EqualTo extends BaseDynamicTransformOp {
         return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1]));
     }
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if (args() == null)
-            return Collections.emptyList();
-        if (inputArguments.size() == 0)
-            return Collections.emptyList();
-        return Nd4j.getExecutioner().calculateOutputShape(this);
-    }
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
         Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes);

View File

@@ -73,18 +73,6 @@ public class GreaterThan extends BaseDynamicTransformOp {
         return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1]));
     }
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if (args() == null)
-            return Collections.emptyList();
-        if (inputArguments.size() == 0)
-            return Collections.emptyList();
-        return Nd4j.getExecutioner().calculateOutputShape(this);
-    }
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
         Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes);

View File

@@ -76,18 +76,6 @@ public class GreaterThanOrEqual extends BaseDynamicTransformOp {
         return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1]));
     }
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if (args() == null)
-            return Collections.emptyList();
-        if (inputArguments.size() == 0)
-            return Collections.emptyList();
-        return Nd4j.getExecutioner().calculateOutputShape(this);
-    }
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
         Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes);

View File

@@ -73,18 +73,6 @@ public class LessThan extends BaseDynamicTransformOp {
         return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1]));
     }
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if (args() == null)
-            return Collections.emptyList();
-        if (inputArguments.size() == 0)
-            return Collections.emptyList();
-        return Nd4j.getExecutioner().calculateOutputShape(this);
-    }
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
         Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes);

View File

@@ -71,18 +71,6 @@ public class LessThanOrEqual extends BaseDynamicTransformOp {
         return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1]));
     }
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if (args() == null)
-            return Collections.emptyList();
-        if (inputArguments.size() == 0)
-            return Collections.emptyList();
-        return Nd4j.getExecutioner().calculateOutputShape(this);
-    }
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
         Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes);

View File

@@ -72,18 +72,6 @@ public class NotEqualTo extends BaseDynamicTransformOp {
         return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1]));
     }
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if (args() == null)
-            return Collections.emptyList();
-        if (inputArguments.size() == 0)
-            return Collections.emptyList();
-        return Nd4j.getExecutioner().calculateOutputShape(this);
-    }
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
         Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes);

View File

@@ -125,41 +125,6 @@ public class Cast extends BaseDynamicTransformOp {
         return ret;
     }
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if(inputArguments.size() > 0){
-            long[] s = inputArguments.get(0).shape();
-            LongShapeDescriptor lsd = LongShapeDescriptor.fromShape(s, typeDst);
-            if(inputArguments.get(0).isEmpty()){
-                long e = lsd.getExtras();
-                e = ArrayOptionsHelper.setOptionBit(e, ArrayType.EMPTY);
-                lsd.setExtras(e);
-            }
-            return Collections.singletonList(lsd);
-        }
-        if (arg() != null && (arg().getArr() != null || arg().getShape() != null)) {
-            if (arg().getArr() != null) {
-                long[] s = arg().getArr().shape();
-                LongShapeDescriptor lsd = LongShapeDescriptor.fromShape(s, typeDst);
-                if(inputArguments.size() > 0 && inputArguments.get(0) != null && inputArguments.get(0).isEmpty()){
-                    long e = lsd.getExtras();
-                    e = ArrayOptionsHelper.setOptionBit(e, ArrayType.EMPTY);
-                    lsd.setExtras(e);
-                }
-                return Collections.singletonList(lsd);
-            } else {
-                long[] s = arg().getShape();
-                if(Shape.isPlaceholderShape(s)){
-                    return Collections.emptyList();
-                }
-                return Collections.singletonList(LongShapeDescriptor.fromShape(s, typeDst));
-            }
-        }
-        return Collections.emptyList();
-    }
     @Override
     public void setValueFor(Field target, Object value) {
         //This is a hack around a property mapping issue - TF datatype DT_DOUBLE return attribute.getType() of DT_DOUBLE which doesn't make sense

View File

@@ -77,7 +77,7 @@ public class GradientBackwardsMarker extends DynamicCustomOp {
     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
-        return Collections.singletonList(LongShapeDescriptor.fromShape(new long[0], DataType.FLOAT));
+        throw new UnsupportedOperationException("calculateOutputShape() is not supported for control flow ops.");
     }
     @Override

View File

@@ -46,11 +46,6 @@ public abstract class BaseArithmeticBackpropOp extends BaseDynamicTransformOp {
         throw new UnsupportedOperationException("Not supported");
     }
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape(){
-        return Nd4j.getExecutioner().calculateOutputShape(this);
-    }
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
         Preconditions.checkState(dataTypes != null && dataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got input %s", getClass(), dataTypes);

View File

@@ -79,11 +79,4 @@ public class Identity extends BaseDynamicTransformOp {
         return dataTypes;
     }
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if(inputArguments == null || inputArguments.isEmpty())
-            return Collections.emptyList();
-        return Collections.singletonList(inputArguments.get(0).shapeDescriptor());
-    }
 }

View File

@@ -1619,21 +1619,6 @@ public class ShapeOpValidation extends BaseOpValidation {
         assertEquals(expected, result.eval());
     }
-    @Test
-    public void testBroadcast() {
-        OpValidationSuite.ignoreFailing();
-        SameDiff sd = SameDiff.create();
-        SDVariable in = sd.var("in", Nd4j.rand(3, 4));
-        SDVariable broadcast = sd.f().broadcast(in, 3, 4, 5);
-        INDArray out = sd.execAndEndResult();
-        assertArrayEquals(new long[]{3, 4, 5}, out.shape());
-        for (int i = 0; i < 5; i++) {
-            assertEquals(in.getArr(), out.get(all(), all(), point(i)));
-        }
-    }
     @Test
     public void testSlice2d() {

View File

@@ -53,7 +53,6 @@ import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
-import org.nd4j.linalg.api.ops.impl.layers.Linear;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig;
 import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
@@ -522,39 +521,6 @@ public class SameDiffTests extends BaseNd4jTest {
     }
-    @Test
-    public void testLinearModule() {
-        int nIn = 5;
-        Linear linear = Linear.execBuilder()
-                .nIn(nIn)
-                .nOut(4)
-                .weightInitScheme(new UniformInitScheme('f', nIn))
-                .biasWeightInitScheme(new ZeroInitScheme('f'))
-                .build();
-        linear.exec(Nd4j.linspace(1, 20, 20).reshape(4, 5));
-        assertEquals(1, linear.numOutputArguments());
-    }
-    @Test
-    public void testLinearModule2() {
-        Linear linear = Linear.execBuilder()
-                .nIn(3)
-                .nOut(2)
-                .weightInitScheme(new OneInitScheme('f'))
-                .biasWeightInitScheme(new ZeroInitScheme('f'))
-                .build();
-        linear.exec(Nd4j.linspace(1, 6, 6).reshape(2, 3));
-        INDArray assertion = Nd4j.create(new double[][]{
-                {6, 6},
-                {15, 15}
-        });
-        assertEquals(assertion, linear.outputArguments()[0]);
-    }
     @Test
     public void testDefineFunctionArrayExistence() {
         SameDiff sameDiff = SameDiff.create();
@@ -3577,4 +3543,24 @@ public class SameDiffTests extends BaseNd4jTest {
         assertEquals(e, mod.eval());
     }
+
+    @Test
+    public void castShapeTest1(){
+        SameDiff sd = SameDiff.create();
+        SDVariable x = sd.constant(Nd4j.createFromArray(1, 2, 3, 4));
+        SDVariable casted = x.castTo(DataType.FLOAT);
+        assertEquals(casted.dataType(), DataType.FLOAT);
+    }
+
+    @Test
+    @Ignore // casted shape is null
+    public void castShapeTestEmpty(){
+        SameDiff sd = SameDiff.create();
+        SDVariable x = sd.constant(Nd4j.empty(DataType.INT));
+        SDVariable casted = x.castTo(DataType.FLOAT);
+        assertEquals(casted.dataType(), DataType.FLOAT);
+        assertTrue(casted.getShapeDescriptor().isEmpty());
+    }
 }

View File

@@ -15,7 +15,7 @@ import org.nd4j.linalg.api.blas.params.MMulTranspose
 import org.nd4j.linalg.api.buffer.DataType
 import org.nd4j.linalg.api.ndarray.INDArray
 import org.nd4j.linalg.api.ops.DynamicCustomOp
-import org.nd4j.linalg.api.ops.impl.layers.{ ExternalErrorsFunction, Linear }
+import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction
 import org.nd4j.linalg.api.ops.impl.layers.convolution.config.{ Conv2DConfig, LocalResponseNormalizationConfig }
 import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance
 import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray