diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index e16bd3dc2..1f361ce9c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -163,7 +163,6 @@ import org.nd4j.linalg.api.ops.impl.scatter.ScatterMin; import org.nd4j.linalg.api.ops.impl.scatter.ScatterMul; import org.nd4j.linalg.api.ops.impl.scatter.ScatterSub; import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; -import org.nd4j.linalg.api.ops.impl.shape.Broadcast; import org.nd4j.linalg.api.ops.impl.shape.Concat; import org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix; import org.nd4j.linalg.api.ops.impl.shape.Cross; @@ -1450,14 +1449,6 @@ public class DifferentialFunctionFactory { return new MatrixInverse(sameDiff(), in, false).outputVariable(); } - public SDVariable broadcast(SDVariable iX, int... shape) { - return broadcast(iX, ArrayUtil.toLongArray(shape)); - } - - public SDVariable broadcast(SDVariable iX, long... shape) { - return new Broadcast(sameDiff(), iX, shape).outputVariable(); - } - public SDVariable onehot(SDVariable indices, int depth, int axis, double on, double off, DataType dataType) { return new OneHot(sameDiff(), indices, depth, axis, on, off, dataType).outputVariable(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java index 430b4d83a..0d2700b43 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java @@ -111,9 +111,6 @@ public class SDVariable implements Serializable { return variableType == VariableType.CONSTANT; } - - - /** * Allocate and return a new array * based on the vertex id and weight initialization. diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 3f270e342..e2d3cf87b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -98,7 +98,6 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax.class, org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin.class, org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction.class, - org.nd4j.linalg.api.ops.impl.layers.Linear.class, org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D.class, org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling3D.class, org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm.class, @@ -267,7 +266,6 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.scatter.ScatterSub.class, org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate.class, org.nd4j.linalg.api.ops.impl.shape.ApplyGradientDescent.class, - org.nd4j.linalg.api.ops.impl.shape.Broadcast.class, org.nd4j.linalg.api.ops.impl.shape.BroadcastDynamicShape.class, org.nd4j.linalg.api.ops.impl.shape.Concat.class, org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseModule.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseModule.java deleted file mode 100644 index 0826ac5f9..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseModule.java +++ /dev/null @@ -1,63 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops; - -import lombok.NoArgsConstructor; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.api.ndarray.INDArray; - -import java.util.ArrayList; -import java.util.List; - -/** - * Abstract base class for {@link Module} - * that handles Dynamic ops and handles nesting. - * - * This is a logical unit for defining layers - * very similar to pytorch's modules, or tensorflow's layers. - * - * @author Adam Gibson - */ -@NoArgsConstructor -public abstract class BaseModule extends DynamicCustomOp implements Module { - private List modules = new ArrayList<>(); - - public BaseModule(String opName, INDArray[] inputs, INDArray[] outputs, List tArguments, List iArguments, List modules) { - super(opName, inputs, outputs, tArguments, iArguments); - this.modules = modules; - } - - public BaseModule(String opName, SameDiff sameDiff, SDVariable[] args, boolean inPlace, List modules) { - super(opName, sameDiff, args, inPlace); - this.modules = modules; - } - - @Override - public Module[] subModules() { - return modules.toArray(new Module[modules.size()]); - } - - @Override - public void addModule(Module module) { - modules.add(module); - } - - - - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/Module.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/Module.java deleted file mode 100644 index f41f94e5c..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/Module.java +++ /dev/null @@ -1,51 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.linalg.api.ndarray.INDArray; - -/** - * A Module is a {@link CustomOp} - * with varying input arguments - * and automatically calculated outputs - * defined at a higher level than c++. - * - * A Module is meant to be a way of implementing custom operations - * in straight java/nd4j. - */ -public interface Module extends CustomOp { - - /** - * - * @param inputs - */ - void exec(INDArray... inputs); - - - Module[] subModules(); - - - void addModule(Module module); - - - void execSameDiff(SDVariable... input); - - - - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java index 6b174bd07..554ad917e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java @@ -80,6 +80,7 @@ public class NoOp extends DynamicCustomOp { return 1; } + @Override public List calculateOutputShape(){ if(inputArguments != null && !inputArguments.isEmpty()){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java index 4c0fb2e6b..3f56096a2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.controlflow.compat; +import java.util.List; import lombok.NonNull; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; @@ -25,6 +26,7 @@ import org.nd4j.imports.descriptors.properties.AttributeAdapter; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -89,4 +91,9 @@ public abstract class BaseCompatOp extends DynamicCustomOp { public Map> attributeAdaptersForFunction() { return super.attributeAdaptersForFunction(); } + + @Override + public List calculateOutputShape() { + throw new UnsupportedOperationException("calculateOutputShape() is not supported for control flow ops."); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Enter.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Enter.java index 85a94eb13..52705ce9e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Enter.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Enter.java @@ -68,15 +68,6 @@ public class Enter extends BaseCompatOp { return OP_NAME; } - @Override - public List calculateOutputShape() { - if(arg().getArr() != null) { - return Collections.singletonList(LongShapeDescriptor.fromShape(arg().getShape(), arg().getArr().dataType())); - } - else - return Collections.emptyList(); - } - @Override public SDVariable[] outputVariables() { return super.outputVariables(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Exit.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Exit.java index f9e358f3c..a7fecf03c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Exit.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Exit.java @@ -56,15 +56,6 @@ public class Exit extends BaseCompatOp { return OP_NAME; } - @Override - public List calculateOutputShape() { - if(arg().getArr() != null) { - return Collections.singletonList(arg().getShapeDescriptor()); - } - else - return Collections.emptyList(); - } - @Override public SDVariable[] outputVariables() { return super.outputVariables(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/LoopCond.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/LoopCond.java index a3ace4f13..4f5d11b38 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/LoopCond.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/LoopCond.java @@ -37,15 +37,6 @@ public class LoopCond extends BaseCompatOp { return "loop_cond"; } - @Override - public List calculateOutputShape() { - if(arg().getArr() != null) { - return Collections.singletonList(LongShapeDescriptor.fromShape(arg().getShape(), arg().dataType())); - } - else - return Collections.emptyList(); - } - @Override public SDVariable[] outputVariables() { return super.outputVariables(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java index 386f4a075..6cf52e7b5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java @@ -64,15 +64,6 @@ public class Merge extends BaseCompatOp { return 60L; } - @Override - public List calculateOutputShape() { - if(arg().getArr() != null) { - return Collections.singletonList(arg().getShapeDescriptor()); - } - else - return Collections.emptyList(); - } - @Override public SDVariable[] outputVariables() { return super.outputVariables(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/NextIteration.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/NextIteration.java index fabd0479b..367a134a3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/NextIteration.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/NextIteration.java @@ -53,15 +53,6 @@ public class NextIteration extends BaseCompatOp { return OP_NAME; } - @Override - public List calculateOutputShape() { - if(arg().getArr() != null) { - return Collections.singletonList(LongShapeDescriptor.fromShape(arg().getShape(), arg().dataType())); - } - else - return Collections.emptyList(); - } - @Override public SDVariable[] outputVariables() { return super.outputVariables(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java index 77145a625..331dea887 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java @@ -63,18 +63,6 @@ public class Switch extends BaseCompatOp { return OP_NAME; } - @Override - public List calculateOutputShape() { - if(args()[0].getArr() != null) { - val arg0 = args()[0]; - val arr0 = arg0.getArr(); - val dtype = arr0.dataType(); - return Arrays.asList(LongShapeDescriptor.fromShape(arg0.getShape(), dtype),LongShapeDescriptor.fromShape(arg0.getShape(), dtype)); - } - else - return Collections.emptyList(); - } - @Override public SDVariable[] outputVariables() { return super.outputVariables(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/Linear.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/Linear.java deleted file mode 100644 index 27f357b4b..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/Linear.java +++ /dev/null @@ -1,198 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.layers; - -import lombok.Builder; -import lombok.NoArgsConstructor; -import lombok.val; -import onnx.Onnx; -import org.nd4j.autodiff.functions.DifferentialFunction; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.autodiff.samediff.VariableType; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.blas.params.MMulTranspose; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseModule; -import org.nd4j.linalg.api.ops.Module; -import org.nd4j.linalg.api.ops.impl.reduce.Mmul; -import org.nd4j.linalg.api.shape.LongShapeDescriptor; -import org.nd4j.linalg.api.shape.Shape; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.weightinit.WeightInitScheme; -import org.nd4j.weightinit.impl.ZeroInitScheme; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.NodeDef; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - -/** - * Linear: - * a * bT - * - * @author Adam Gibson - */ -@NoArgsConstructor -public class Linear extends BaseModule { - private DifferentialFunction forward; - private int nIn,nOut; - private WeightInitScheme weightInitScheme,biasWeightInitScheme; - - @Builder(builderMethodName = "execBuilder") - public Linear(int nIn, - int nOut, - WeightInitScheme weightInitScheme, - WeightInitScheme biasWeightInitScheme) { - super(null, - getParams(nIn,nOut,weightInitScheme,biasWeightInitScheme), - new INDArray[]{}, - new ArrayList(), new ArrayList(), new ArrayList()); - this.weightInitScheme = weightInitScheme; - this.biasWeightInitScheme = biasWeightInitScheme; - this.nIn = nIn; - this.nOut = nOut; - } - - @Builder(builderMethodName = "sameDiffBuilder") - public Linear(SameDiff sameDiff, - int nIn, - int nOut, - WeightInitScheme weightInitScheme, - WeightInitScheme biasWeightInitScheme) { - super(null, sameDiff, null, false, new ArrayList()); - this.weightInitScheme = weightInitScheme; - this.biasWeightInitScheme = biasWeightInitScheme; - - this.nIn = nIn; - this.nOut = nOut; - - } - - @Override - public String opName() { - return "linear"; - } - - @Override - public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - - } - - @Override - public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { - - } - - @Override - public List doDiff(List f1) { - execSameDiff(); - return forward.doDiff(f1); - } - - @Override - public List calculateOutputShape() { - List ret = new ArrayList<>(); - ret.add(LongShapeDescriptor.fromShape(Shape.getMatrixMultiplyShape(inputArguments()[0].shape(),new long[]{nOut,nIn}), inputArguments()[1].dataType())); - - ret.add(LongShapeDescriptor.fromShape(Shape.getMatrixMultiplyShape(inputArguments()[0].shape(),inputArguments()[1].transpose().shape()), inputArguments()[1].dataType())); - if(biasWeightInitScheme != null) { - ret.add(LongShapeDescriptor.fromShape(new long[]{nOut,1}, inputArguments()[1].dataType())); - } - return ret; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - - - - @Override - public void exec(INDArray... inputs) { - val inputArguments = inputArguments(); - if(inputArguments == null || inputArguments.length < 1) { - throw new IllegalStateException("No arguments found."); - } - - INDArray weights = inputArguments[0]; - INDArray right = inputArguments[1]; - - val outputArguments = outputArguments(); - - if(outputArguments == null || outputArguments.length < 1) { - if(inputArguments.length == 1) - addOutputArgument(inputs[0].mmul(weights.transpose())); - else - addOutputArgument(inputs[0].mmul(weights.transpose()).addiColumnVector(right)); - - } - else { - inputs[0].mmul(weights.transpose(),outputArguments[0]); - } - - } - - @Override - public void execSameDiff(SDVariable... input) { - val args = args(); - if(args == null || args.length == 0) { - throw new IllegalStateException("No arguments found"); - } - - if(forward == null) { - //bias needs to be added yet - if(args.length > 1) { - /* - forward = f().add(new Mmul(sameDiff, input[0],args()[0], - MMulTranspose.builder() - .transposeA(false) - .transposeB(true) - .build()).outputVariables()[0],args()[1]); - */ - } else { - forward = new Mmul(sameDiff, input[0],args()[0], - MMulTranspose.builder().transposeA(false).transposeB(true).build()); - } - - this.outputVariables = forward.outputVariables(); - } - - - } - - private static INDArray[] getParams(int nIn, - int nOut, - WeightInitScheme paramsScheme, - WeightInitScheme biasInitScheme) { - if(biasInitScheme != null) { - return new INDArray[] {paramsScheme.create(Nd4j.defaultFloatingPointType(), new long[]{nOut,nIn}),biasInitScheme.create(Nd4j.defaultFloatingPointType(), new long[]{nOut,1})}; - } - else { - return new INDArray[] {paramsScheme.create(Nd4j.defaultFloatingPointType(), new long[]{nOut,nIn})}; - - } - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java index cfef2a61b..21941de93 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java @@ -38,11 +38,6 @@ public class L2Loss extends DynamicCustomOp { super(sameDiff, new SDVariable[]{var}); } - @Override - public List calculateOutputShape() { - return Collections.singletonList(LongShapeDescriptor.fromShape(new long[0], Nd4j.defaultFloatingPointType())); - } - @Override public String opName() { return "l2_loss"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/HashCode.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/HashCode.java index 4ad8fc5d9..cdddee2f6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/HashCode.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/HashCode.java @@ -47,11 +47,6 @@ public class HashCode extends DynamicCustomOp { this.outputArguments.add(result); } - @Override - public List calculateOutputShape() { - return Collections.singletonList(LongShapeDescriptor.fromShape(new long[0], DataType.LONG)); - } - @Override public String opName() { return "hashcode"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Broadcast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Broadcast.java deleted file mode 100644 index 8e2f6c790..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Broadcast.java +++ /dev/null @@ -1,78 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.shape; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.base.Preconditions; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.api.shape.LongShapeDescriptor; - -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - -/** - * Broadcast function - * - * @author Adam Gibson - */ -public class Broadcast extends DynamicCustomOp { - private long[] shape; - public Broadcast(SameDiff sameDiff,SDVariable iX, long[] shape) { - super(null,sameDiff,new SDVariable[]{iX}); - this.shape = shape; - } - - - public Broadcast() {} - - - @Override - public List calculateOutputShape() { - return Arrays.asList(LongShapeDescriptor.fromShape(shape, larg().dataType())); - } - - @Override - public String opName() { - return "broadcast"; - } - - - - @Override - public List doDiff(List i_v) { - throw new UnsupportedOperationException(); - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - - @Override - public List calculateOutputDataTypes(List dataTypes){ - return Collections.singletonList(dataTypes.get(0)); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java index a11b241e5..1ad13ad40 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java @@ -82,12 +82,4 @@ public class Linspace extends DynamicCustomOp { public List doDiff(List gradients){ return Arrays.asList(f().zerosLike(arg(0)), f().zerosLike(arg(1)), f().zerosLike(arg(2))); } - - @Override - public List calculateOutputShape(){ - INDArray l = arg(2).getArr(); - if(l == null) - return Collections.emptyList(); - return Collections.singletonList(LongShapeDescriptor.fromShape(new long[]{l.getLong(0)}, dataType)); - } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java index aacfa19e1..568b14a44 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java @@ -85,14 +85,6 @@ public class Rank extends DynamicCustomOp { return "Rank"; } - @Override - public List calculateOutputShape() { - List ret = new ArrayList<>(); - ret.add(LongShapeDescriptor.fromShape(new long[]{}, DataType.INT)); - return ret; - } - - @Override public List doDiff(List i_v) { return Collections.singletonList(sameDiff.zerosLike(arg())); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Unstack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Unstack.java index 9dd6b6338..f9eb1f95a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Unstack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Unstack.java @@ -152,17 +152,4 @@ public class Unstack extends DynamicCustomOp { return out; } - @Override - public List calculateOutputShape(){ - //TEMPORARY workaround for: https://github.com/deeplearning4j/deeplearning4j/issues/7093 - if(inputArguments.size() == 1 && inputArguments.get(0).rank() == 1){ - INDArray arr = inputArguments.get(0); - Preconditions.checkState(jaxis == 0, "Can only unstack along dimension 0 for rank 1 arrays, got axis %s for array %ndShape", jaxis, arr); - LongShapeDescriptor lsd = LongShapeDescriptor.fromShape(new long[0], arr.dataType()); - List out = Arrays.asList(ArrayUtil.nTimes((int)arr.length(), lsd, LongShapeDescriptor.class)); - return out; - } - return super.calculateOutputShape(); - } - } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/BaseTensorOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/BaseTensorOp.java index bcc4e1844..b027750fc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/BaseTensorOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/BaseTensorOp.java @@ -32,7 +32,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; -public abstract class BaseTensorOp extends DynamicCustomOp { +public abstract class BaseTensorOp extends DynamicCustomOp { public BaseTensorOp(String name, SameDiff sameDiff, SDVariable[] args){ super(name, sameDiff, args); @@ -78,8 +78,7 @@ public abstract class BaseTensorOp extends DynamicCustomOp { @Override public List calculateOutputShape() { - //Not used/not required - return Collections.emptyList(); + throw new UnsupportedOperationException("calculateOutputShape() is not supported for tensor ops."); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArraySize.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArraySize.java index 276dadcab..0503d377b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArraySize.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArraySize.java @@ -41,12 +41,6 @@ public class TensorArraySize extends BaseTensorOp { return "tensorarraysizev3"; } - @Override - public List calculateOutputShape() { - // output is scalar only - return Collections.singletonList(LongShapeDescriptor.fromShape(new long[]{}, DataType.LONG)); - } - @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/BaseDynamicTransformOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/BaseDynamicTransformOp.java index aa85461e6..ae4a21df7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/BaseDynamicTransformOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/BaseDynamicTransformOp.java @@ -46,68 +46,6 @@ public abstract class BaseDynamicTransformOp extends DynamicCustomOp { super(null, inputs, outputs); } - - @Override - public List calculateOutputShape() { - long[] firstArgShape; - long[] secondArgShape; - DataType dtypeZ; - - if(numInputArguments() == 2){ - return super.calculateOutputShape(); //Use c++ shape calc, which also accounts for empty broadcast cases, etc -// firstArgShape = inputArguments.get(0).shape(); -// secondArgShape = inputArguments.get(1).shape(); -// dtypeZ = Shape.pickPairwiseDataType(inputArguments.get(0).dataType(), inputArguments.get(1).dataType()); - } else { - val args = args(); - if (args.length < 2) { - if (args[0] == null || (inputArguments.isEmpty() && args[0].getShape() == null)) { - return Collections.emptyList(); - } - DataType dtypeX = !inputArguments.isEmpty() ? inputArguments.get(0).dataType() : args[0].dataType(); - long[] shape = !inputArguments.isEmpty() ? inputArguments.get(0).shape() : args[0].getShape(); - - return Collections.singletonList(LongShapeDescriptor.fromShape(shape, dtypeX)); - } - - if(inputArguments.size() == 2 && inputArguments.get(0) != null && inputArguments.get(1) != null){ - firstArgShape = inputArguments.get(0).shape(); - secondArgShape = inputArguments.get(1).shape(); - } else { - firstArgShape = args[0].getShape(); - secondArgShape = args[1].getShape(); - } - if (args[0] == null || args[0].getShape() == null) { - return Collections.emptyList(); - } - - if (args[1] == null || args[1].getShape() == null) { - return Collections.emptyList(); - } - - // detecting datatype based on both args - val dtypeX = inputArguments.size() > 0 ? inputArguments.get(0).dataType() : args[0].dataType(); - val dtypeY = inputArguments.size() > 1 ? inputArguments.get(1).dataType() : args[1].dataType(); - dtypeZ = Shape.pickPairwiseDataType(dtypeX, dtypeY); - } - - - - if(Arrays.equals(firstArgShape, secondArgShape)){ - try { - return Collections.singletonList(LongShapeDescriptor.fromShape(firstArgShape, dtypeZ)); - } catch (Throwable e) { - throw new RuntimeException("calculateOutputShape() failed for [" + this.opName() + "]", e); - } - } else { - //Handle broadcast shape: [1,4]+[3,1] = [3,4] - Shape.assertBroadcastable(firstArgShape, secondArgShape, this.getClass()); - val outShape = Shape.broadcastOutputShape(firstArgShape, secondArgShape); - - return Collections.singletonList(LongShapeDescriptor.fromShape(outShape, dtypeZ)); - } - } - @Override public List calculateOutputDataTypes(List dataTypes){ Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got input %s", getClass(), dataTypes); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java index b786602d3..95019c5b2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java @@ -71,18 +71,6 @@ public class EqualTo extends BaseDynamicTransformOp { return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1])); } - @Override - public List calculateOutputShape() { - if (args() == null) - return Collections.emptyList(); - - if (inputArguments.size() == 0) - return Collections.emptyList(); - - - return Nd4j.getExecutioner().calculateOutputShape(this); - } - @Override public List calculateOutputDataTypes(List dataTypes){ Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java index 27b2ea189..72a46f111 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java @@ -73,18 +73,6 @@ public class GreaterThan extends BaseDynamicTransformOp { return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1])); } - @Override - public List calculateOutputShape() { - if (args() == null) - return Collections.emptyList(); - - if (inputArguments.size() == 0) - return Collections.emptyList(); - - - return Nd4j.getExecutioner().calculateOutputShape(this); - } - @Override public List calculateOutputDataTypes(List dataTypes){ Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java index 48d3953aa..50d1c7c43 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java @@ -76,18 +76,6 @@ public class GreaterThanOrEqual extends BaseDynamicTransformOp { return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1])); } - @Override - public List calculateOutputShape() { - if (args() == null) - return Collections.emptyList(); - - if (inputArguments.size() == 0) - return Collections.emptyList(); - - - return Nd4j.getExecutioner().calculateOutputShape(this); - } - @Override public List calculateOutputDataTypes(List dataTypes){ Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java index 0ee59458c..4c345070e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java @@ -73,18 +73,6 @@ public class LessThan extends BaseDynamicTransformOp { return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1])); } - @Override - public List calculateOutputShape() { - if (args() == null) - return Collections.emptyList(); - - if (inputArguments.size() == 0) - return Collections.emptyList(); - - - return Nd4j.getExecutioner().calculateOutputShape(this); - } - @Override public List calculateOutputDataTypes(List dataTypes){ Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java index 56a5882db..89c08fe65 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java @@ -71,18 +71,6 @@ public class LessThanOrEqual extends BaseDynamicTransformOp { return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1])); } - @Override - public List calculateOutputShape() { - if (args() == null) - return Collections.emptyList(); - - if (inputArguments.size() == 0) - return Collections.emptyList(); - - - return Nd4j.getExecutioner().calculateOutputShape(this); - } - @Override public List calculateOutputDataTypes(List dataTypes){ Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java index b1f0dadbb..62d3bedfa 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java @@ -72,18 +72,6 @@ public class NotEqualTo extends BaseDynamicTransformOp { return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1])); } - @Override - public List calculateOutputShape() { - if (args() == null) - return Collections.emptyList(); - - if (inputArguments.size() == 0) - return Collections.emptyList(); - - - return Nd4j.getExecutioner().calculateOutputShape(this); - } - @Override public List calculateOutputDataTypes(List dataTypes){ Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java index 01408ea4d..98a479542 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java @@ -125,41 +125,6 @@ public class Cast extends BaseDynamicTransformOp { return ret; } - @Override - public List calculateOutputShape() { - if(inputArguments.size() > 0){ - long[] s = inputArguments.get(0).shape(); - LongShapeDescriptor lsd = LongShapeDescriptor.fromShape(s, typeDst); - if(inputArguments.get(0).isEmpty()){ - long e = lsd.getExtras(); - e = ArrayOptionsHelper.setOptionBit(e, ArrayType.EMPTY); - lsd.setExtras(e); - } - return Collections.singletonList(lsd); - } - - if (arg() != null && (arg().getArr() != null || arg().getShape() != null)) { - if (arg().getArr() != null) { - long[] s = arg().getArr().shape(); - LongShapeDescriptor lsd = LongShapeDescriptor.fromShape(s, typeDst); - if(inputArguments.size() > 0 && inputArguments.get(0) != null && inputArguments.get(0).isEmpty()){ - long e = lsd.getExtras(); - e = ArrayOptionsHelper.setOptionBit(e, ArrayType.EMPTY); - lsd.setExtras(e); - } - return Collections.singletonList(lsd); - } else { - long[] s = arg().getShape(); - if(Shape.isPlaceholderShape(s)){ - return Collections.emptyList(); - } - return Collections.singletonList(LongShapeDescriptor.fromShape(s, typeDst)); - } - } - - return Collections.emptyList(); - } - @Override public void setValueFor(Field target, Object value) { //This is a hack around a property mapping issue - TF datatype DT_DOUBLE return attribute.getType() of DT_DOUBLE which doesn't make sense diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/GradientBackwardsMarker.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/GradientBackwardsMarker.java index d61f7a067..10abb69a2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/GradientBackwardsMarker.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/GradientBackwardsMarker.java @@ -77,7 +77,7 @@ public class GradientBackwardsMarker extends DynamicCustomOp { @Override public List calculateOutputShape() { - return Collections.singletonList(LongShapeDescriptor.fromShape(new long[0], DataType.FLOAT)); + throw new UnsupportedOperationException("calculateOutputShape() is not supported for control flow ops."); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/BaseArithmeticBackpropOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/BaseArithmeticBackpropOp.java index dcbcc271f..44648cc4a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/BaseArithmeticBackpropOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/BaseArithmeticBackpropOp.java @@ -46,11 +46,6 @@ public abstract class BaseArithmeticBackpropOp extends BaseDynamicTransformOp { throw new UnsupportedOperationException("Not supported"); } - @Override - public List calculateOutputShape(){ - return Nd4j.getExecutioner().calculateOutputShape(this); - } - @Override public List calculateOutputDataTypes(List dataTypes){ Preconditions.checkState(dataTypes != null && dataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got input %s", getClass(), dataTypes); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java index 3e548226f..4b5825d8f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java @@ -79,11 +79,4 @@ public class Identity extends BaseDynamicTransformOp { return dataTypes; } - @Override - public List calculateOutputShape() { - if(inputArguments == null || inputArguments.isEmpty()) - return Collections.emptyList(); - return Collections.singletonList(inputArguments.get(0).shapeDescriptor()); - } - } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java index ffb585183..bf4b331a3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java @@ -1619,21 +1619,6 @@ public class ShapeOpValidation extends BaseOpValidation { assertEquals(expected, result.eval()); } - @Test - public void testBroadcast() { - OpValidationSuite.ignoreFailing(); - SameDiff sd = SameDiff.create(); - SDVariable in = sd.var("in", Nd4j.rand(3, 4)); - SDVariable broadcast = sd.f().broadcast(in, 3, 4, 5); - - INDArray out = sd.execAndEndResult(); - assertArrayEquals(new long[]{3, 4, 5}, out.shape()); - - for (int i = 0; i < 5; i++) { - assertEquals(in.getArr(), out.get(all(), all(), point(i))); - } - } - @Test public void testSlice2d() { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index ef6d1268b..bb15b3392 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -53,7 +53,6 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; -import org.nd4j.linalg.api.ops.impl.layers.Linear; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig; import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance; @@ -522,39 +521,6 @@ public class SameDiffTests extends BaseNd4jTest { } - @Test - public void testLinearModule() { - int nIn = 5; - Linear linear = Linear.execBuilder() - .nIn(nIn) - .nOut(4) - .weightInitScheme(new UniformInitScheme('f', nIn)) - .biasWeightInitScheme(new ZeroInitScheme('f')) - .build(); - linear.exec(Nd4j.linspace(1, 20, 20).reshape(4, 5)); - assertEquals(1, linear.numOutputArguments()); - - } - - - @Test - public void testLinearModule2() { - Linear linear = Linear.execBuilder() - .nIn(3) - .nOut(2) - .weightInitScheme(new OneInitScheme('f')) - .biasWeightInitScheme(new ZeroInitScheme('f')) - .build(); - linear.exec(Nd4j.linspace(1, 6, 6).reshape(2, 3)); - INDArray assertion = Nd4j.create(new double[][]{ - {6, 6}, - {15, 15} - }); - assertEquals(assertion, linear.outputArguments()[0]); - - } - - @Test public void testDefineFunctionArrayExistence() { SameDiff sameDiff = SameDiff.create(); @@ -3577,4 +3543,24 @@ public class SameDiffTests extends BaseNd4jTest { assertEquals(e, mod.eval()); } + + @Test + public void castShapeTest1(){ + SameDiff sd = SameDiff.create(); + SDVariable x = sd.constant(Nd4j.createFromArray(1, 2, 3, 4)); + SDVariable casted = x.castTo(DataType.FLOAT); + + assertEquals(casted.dataType(), DataType.FLOAT); + } + + @Test + @Ignore // casted shape is null + public void castShapeTestEmpty(){ + SameDiff sd = SameDiff.create(); + SDVariable x = sd.constant(Nd4j.empty(DataType.INT)); + SDVariable casted = x.castTo(DataType.FLOAT); + + assertEquals(casted.dataType(), DataType.FLOAT); + assertTrue(casted.getShapeDescriptor().isEmpty()); + } } diff --git a/nd4s/src/test/scala/org/nd4s/samediff/SameDiffTest.scala b/nd4s/src/test/scala/org/nd4s/samediff/SameDiffTest.scala index c9ee41bfb..2f5746b95 100644 --- a/nd4s/src/test/scala/org/nd4s/samediff/SameDiffTest.scala +++ b/nd4s/src/test/scala/org/nd4s/samediff/SameDiffTest.scala @@ -15,7 +15,7 @@ import org.nd4j.linalg.api.blas.params.MMulTranspose import org.nd4j.linalg.api.buffer.DataType import org.nd4j.linalg.api.ndarray.INDArray import org.nd4j.linalg.api.ops.DynamicCustomOp -import org.nd4j.linalg.api.ops.impl.layers.{ ExternalErrorsFunction, Linear } +import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction import org.nd4j.linalg.api.ops.impl.layers.convolution.config.{ Conv2DConfig, LocalResponseNormalizationConfig } import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray