Remove calculate output shape from java side (#151)
* remove some unneeded java-side output shape calculations
* delete Broadcast
* delete Linear and Module
* update Identity, HashCode, and NoOp
* removed Cast java-side shape function, added tests and SDVariable.isEmpty
* ignoring test w/ issues on master
* noop needs more work; fixed BaseArithmeticBackprop and BaseDynamicTransform ops; merge in master for c++ build fix
* fix EqualTo
* fix other cond ops
* "fake" ops' calculateOutputShape() throws exception
* use c++ shape calc for Linspace
* fix exception message, move most to BaseCompatOp
* remove SDVariable.isEmpty
* remove commented out code

Signed-off-by: Ryan Nett <rnett@skymind.io>
parent b46f9827b8
commit 2a1431264f
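The pattern this commit applies throughout: ops whose output shape was computed in Java either drop their calculateOutputShape() override entirely, inheriting DynamicCustomOp's implementation (which defers to the native C++ executioner, as the removed comments in the hunks below note), or, for "fake" ops with no C++ counterpart, throw. A minimal hedged sketch of the two resulting forms — SomeRealOp and SomeFakeOp are hypothetical names, not classes touched by this commit:

import java.util.List;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;

// Form 1: no override at all. The inherited DynamicCustomOp implementation
// delegates to the native executioner, effectively doing
// Nd4j.getExecutioner().calculateOutputShape(this).
class SomeRealOp extends DynamicCustomOp {
    @Override
    public String opName() { return "some_real_op"; } // hypothetical op name
}

// Form 2: "fake" ops (control flow, tensor ops) with no C++ shape function
// now fail fast instead of returning a guessed or empty shape:
class SomeFakeOp extends DynamicCustomOp {
    @Override
    public String opName() { return "some_fake_op"; } // hypothetical op name

    @Override
    public List<LongShapeDescriptor> calculateOutputShape() {
        throw new UnsupportedOperationException(
                "calculateOutputShape() is not supported for this op.");
    }
}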
@@ -163,7 +163,6 @@ import org.nd4j.linalg.api.ops.impl.scatter.ScatterMin;
 import org.nd4j.linalg.api.ops.impl.scatter.ScatterMul;
 import org.nd4j.linalg.api.ops.impl.scatter.ScatterSub;
 import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
-import org.nd4j.linalg.api.ops.impl.shape.Broadcast;
 import org.nd4j.linalg.api.ops.impl.shape.Concat;
 import org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix;
 import org.nd4j.linalg.api.ops.impl.shape.Cross;

@@ -1450,14 +1449,6 @@ public class DifferentialFunctionFactory {
         return new MatrixInverse(sameDiff(), in, false).outputVariable();
     }
 
-    public SDVariable broadcast(SDVariable iX, int... shape) {
-        return broadcast(iX, ArrayUtil.toLongArray(shape));
-    }
-
-    public SDVariable broadcast(SDVariable iX, long... shape) {
-        return new Broadcast(sameDiff(), iX, shape).outputVariable();
-    }
-
     public SDVariable onehot(SDVariable indices, int depth, int axis, double on, double off, DataType dataType) {
         return new OneHot(sameDiff(), indices, depth, axis, on, off, dataType).outputVariable();
     }

@@ -111,9 +111,6 @@ public class SDVariable implements Serializable {
         return variableType == VariableType.CONSTANT;
     }
 
-
-
-
     /**
      * Allocate and return a new array
      * based on the vertex id and weight initialization.

@@ -98,7 +98,6 @@ public class ImportClassMapping {
         org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax.class,
         org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin.class,
         org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction.class,
-        org.nd4j.linalg.api.ops.impl.layers.Linear.class,
         org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D.class,
         org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling3D.class,
         org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm.class,

@@ -267,7 +266,6 @@ public class ImportClassMapping {
         org.nd4j.linalg.api.ops.impl.scatter.ScatterSub.class,
         org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate.class,
         org.nd4j.linalg.api.ops.impl.shape.ApplyGradientDescent.class,
-        org.nd4j.linalg.api.ops.impl.shape.Broadcast.class,
         org.nd4j.linalg.api.ops.impl.shape.BroadcastDynamicShape.class,
         org.nd4j.linalg.api.ops.impl.shape.Concat.class,
         org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix.class,

@@ -1,63 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops;
-
-import lombok.NoArgsConstructor;
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.linalg.api.ndarray.INDArray;
-
-import java.util.ArrayList;
-import java.util.List;
-
-/**
- * Abstract base class for {@link Module}
- * that handles Dynamic ops and handles nesting.
- *
- * This is a logical unit for defining layers
- * very similar to pytorch's modules, or tensorflow's layers.
- *
- * @author Adam Gibson
- */
-@NoArgsConstructor
-public abstract class BaseModule extends DynamicCustomOp implements Module {
-    private List<Module> modules = new ArrayList<>();
-
-    public BaseModule(String opName, INDArray[] inputs, INDArray[] outputs, List<Double> tArguments, List<Integer> iArguments, List<Module> modules) {
-        super(opName, inputs, outputs, tArguments, iArguments);
-        this.modules = modules;
-    }
-
-    public BaseModule(String opName, SameDiff sameDiff, SDVariable[] args, boolean inPlace, List<Module> modules) {
-        super(opName, sameDiff, args, inPlace);
-        this.modules = modules;
-    }
-
-    @Override
-    public Module[] subModules() {
-        return modules.toArray(new Module[modules.size()]);
-    }
-
-    @Override
-    public void addModule(Module module) {
-        modules.add(module);
-    }
-
-
-
-
-}

@@ -1,51 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.linalg.api.ndarray.INDArray;
-
-/**
- * A Module is a {@link CustomOp}
- * with varying input arguments
- * and automatically calculated outputs
- * defined at a higher level than c++.
- *
- * A Module is meant to be a way of implementing custom operations
- * in straight java/nd4j.
- */
-public interface Module extends CustomOp {
-
-    /**
-     *
-     * @param inputs
-     */
-    void exec(INDArray... inputs);
-
-
-    Module[] subModules();
-
-
-    void addModule(Module module);
-
-
-    void execSameDiff(SDVariable... input);
-
-
-
-
-}

@@ -80,6 +80,7 @@ public class NoOp extends DynamicCustomOp {
         return 1;
     }
 
+
     @Override
     public List<LongShapeDescriptor> calculateOutputShape(){
         if(inputArguments != null && !inputArguments.isEmpty()){

@@ -16,6 +16,7 @@
 
 package org.nd4j.linalg.api.ops.impl.controlflow.compat;
 
+import java.util.List;
 import lombok.NonNull;
 import lombok.val;
 import org.nd4j.autodiff.samediff.SDVariable;

@@ -25,6 +26,7 @@ import org.nd4j.imports.descriptors.properties.AttributeAdapter;
 import org.nd4j.imports.descriptors.properties.PropertyMapping;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
 import org.tensorflow.framework.NodeDef;

@@ -89,4 +91,9 @@ public abstract class BaseCompatOp extends DynamicCustomOp {
     public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
         return super.attributeAdaptersForFunction();
     }
+
+    @Override
+    public List<LongShapeDescriptor> calculateOutputShape() {
+        throw new UnsupportedOperationException("calculateOutputShape() is not supported for control flow ops.");
+    }
 }

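With the throwing override now living in BaseCompatOp, each control-flow subclass below (Enter, Exit, LoopCond, Merge, NextIteration, Switch) simply deletes its per-class shape logic. An abbreviated sketch of what such a subclass looks like after this change — only the members relevant here, not the full class:

// Abbreviated sketch; the real Enter class has more members than shown.
public class Enter extends BaseCompatOp {
    // calculateOutputShape() is inherited from BaseCompatOp and now throws
    // UnsupportedOperationException - no per-op Java shape logic remains.

    @Override
    public SDVariable[] outputVariables() {
        return super.outputVariables();
    }
}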
@@ -68,15 +68,6 @@ public class Enter extends BaseCompatOp {
         return OP_NAME;
     }
 
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if(arg().getArr() != null) {
-            return Collections.singletonList(LongShapeDescriptor.fromShape(arg().getShape(), arg().getArr().dataType()));
-        }
-        else
-            return Collections.emptyList();
-    }
-
     @Override
     public SDVariable[] outputVariables() {
         return super.outputVariables();

@@ -56,15 +56,6 @@ public class Exit extends BaseCompatOp {
         return OP_NAME;
     }
 
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if(arg().getArr() != null) {
-            return Collections.singletonList(arg().getShapeDescriptor());
-        }
-        else
-            return Collections.emptyList();
-    }
-
     @Override
     public SDVariable[] outputVariables() {
         return super.outputVariables();

@@ -37,15 +37,6 @@ public class LoopCond extends BaseCompatOp {
         return "loop_cond";
     }
 
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if(arg().getArr() != null) {
-            return Collections.singletonList(LongShapeDescriptor.fromShape(arg().getShape(), arg().dataType()));
-        }
-        else
-            return Collections.emptyList();
-    }
-
     @Override
     public SDVariable[] outputVariables() {
         return super.outputVariables();

@@ -64,15 +64,6 @@ public class Merge extends BaseCompatOp {
         return 60L;
     }
 
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if(arg().getArr() != null) {
-            return Collections.singletonList(arg().getShapeDescriptor());
-        }
-        else
-            return Collections.emptyList();
-    }
-
     @Override
     public SDVariable[] outputVariables() {
         return super.outputVariables();

@@ -53,15 +53,6 @@ public class NextIteration extends BaseCompatOp {
         return OP_NAME;
     }
 
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if(arg().getArr() != null) {
-            return Collections.singletonList(LongShapeDescriptor.fromShape(arg().getShape(), arg().dataType()));
-        }
-        else
-            return Collections.emptyList();
-    }
-
     @Override
     public SDVariable[] outputVariables() {
         return super.outputVariables();

@@ -63,18 +63,6 @@ public class Switch extends BaseCompatOp {
         return OP_NAME;
     }
 
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if(args()[0].getArr() != null) {
-            val arg0 = args()[0];
-            val arr0 = arg0.getArr();
-            val dtype = arr0.dataType();
-            return Arrays.asList(LongShapeDescriptor.fromShape(arg0.getShape(), dtype),LongShapeDescriptor.fromShape(arg0.getShape(), dtype));
-        }
-        else
-            return Collections.emptyList();
-    }
-
     @Override
     public SDVariable[] outputVariables() {
         return super.outputVariables();

@@ -1,198 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.layers;
-
-import lombok.Builder;
-import lombok.NoArgsConstructor;
-import lombok.val;
-import onnx.Onnx;
-import org.nd4j.autodiff.functions.DifferentialFunction;
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.autodiff.samediff.VariableType;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.blas.params.MMulTranspose;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseModule;
-import org.nd4j.linalg.api.ops.Module;
-import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
-import org.nd4j.linalg.api.shape.LongShapeDescriptor;
-import org.nd4j.linalg.api.shape.Shape;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.weightinit.WeightInitScheme;
-import org.nd4j.weightinit.impl.ZeroInitScheme;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.GraphDef;
-import org.tensorflow.framework.NodeDef;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-
-/**
- * Linear:
- * a * bT
- *
- * @author Adam Gibson
- */
-@NoArgsConstructor
-public class Linear extends BaseModule {
-    private DifferentialFunction forward;
-    private int nIn,nOut;
-    private WeightInitScheme weightInitScheme,biasWeightInitScheme;
-
-    @Builder(builderMethodName = "execBuilder")
-    public Linear(int nIn,
-                  int nOut,
-                  WeightInitScheme weightInitScheme,
-                  WeightInitScheme biasWeightInitScheme) {
-        super(null,
-                getParams(nIn,nOut,weightInitScheme,biasWeightInitScheme),
-                new INDArray[]{},
-                new ArrayList<Double>(), new ArrayList<Integer>(), new ArrayList<Module>());
-        this.weightInitScheme = weightInitScheme;
-        this.biasWeightInitScheme = biasWeightInitScheme;
-        this.nIn = nIn;
-        this.nOut = nOut;
-    }
-
-    @Builder(builderMethodName = "sameDiffBuilder")
-    public Linear(SameDiff sameDiff,
-                  int nIn,
-                  int nOut,
-                  WeightInitScheme weightInitScheme,
-                  WeightInitScheme biasWeightInitScheme) {
-        super(null, sameDiff, null, false, new ArrayList<Module>());
-        this.weightInitScheme = weightInitScheme;
-        this.biasWeightInitScheme = biasWeightInitScheme;
-
-        this.nIn = nIn;
-        this.nOut = nOut;
-
-    }
-
-    @Override
-    public String opName() {
-        return "linear";
-    }
-
-    @Override
-    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-
-    }
-
-    @Override
-    public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
-
-    }
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> f1) {
-        execSameDiff();
-        return forward.doDiff(f1);
-    }
-
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        List<LongShapeDescriptor> ret = new ArrayList<>();
-        ret.add(LongShapeDescriptor.fromShape(Shape.getMatrixMultiplyShape(inputArguments()[0].shape(),new long[]{nOut,nIn}), inputArguments()[1].dataType()));
-
-        ret.add(LongShapeDescriptor.fromShape(Shape.getMatrixMultiplyShape(inputArguments()[0].shape(),inputArguments()[1].transpose().shape()), inputArguments()[1].dataType()));
-        if(biasWeightInitScheme != null) {
-            ret.add(LongShapeDescriptor.fromShape(new long[]{nOut,1}, inputArguments()[1].dataType()));
-        }
-        return ret;
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
-    }
-
-
-
-    @Override
-    public void exec(INDArray... inputs) {
-        val inputArguments = inputArguments();
-        if(inputArguments == null || inputArguments.length < 1) {
-            throw new IllegalStateException("No arguments found.");
-        }
-
-        INDArray weights = inputArguments[0];
-        INDArray right = inputArguments[1];
-
-        val outputArguments = outputArguments();
-
-        if(outputArguments == null || outputArguments.length < 1) {
-            if(inputArguments.length == 1)
-                addOutputArgument(inputs[0].mmul(weights.transpose()));
-            else
-                addOutputArgument(inputs[0].mmul(weights.transpose()).addiColumnVector(right));
-
-        }
-        else {
-            inputs[0].mmul(weights.transpose(),outputArguments[0]);
-        }
-
-    }
-
-    @Override
-    public void execSameDiff(SDVariable... input) {
-        val args = args();
-        if(args == null || args.length == 0) {
-            throw new IllegalStateException("No arguments found");
-        }
-
-        if(forward == null) {
-            //bias needs to be added yet
-            if(args.length > 1) {
-                /*
-                forward = f().add(new Mmul(sameDiff, input[0],args()[0],
-                        MMulTranspose.builder()
-                                .transposeA(false)
-                                .transposeB(true)
-                                .build()).outputVariables()[0],args()[1]);
-                 */
-            } else {
-                forward = new Mmul(sameDiff, input[0],args()[0],
-                        MMulTranspose.builder().transposeA(false).transposeB(true).build());
-            }
-
-            this.outputVariables = forward.outputVariables();
-        }
-
-
-    }
-
-    private static INDArray[] getParams(int nIn,
-                                        int nOut,
-                                        WeightInitScheme paramsScheme,
-                                        WeightInitScheme biasInitScheme) {
-        if(biasInitScheme != null) {
-            return new INDArray[] {paramsScheme.create(Nd4j.defaultFloatingPointType(), new long[]{nOut,nIn}),biasInitScheme.create(Nd4j.defaultFloatingPointType(), new long[]{nOut,1})};
-        }
-        else {
-            return new INDArray[] {paramsScheme.create(Nd4j.defaultFloatingPointType(), new long[]{nOut,nIn})};
-
-        }
-    }
-}

@@ -38,11 +38,6 @@ public class L2Loss extends DynamicCustomOp {
         super(sameDiff, new SDVariable[]{var});
     }
 
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        return Collections.singletonList(LongShapeDescriptor.fromShape(new long[0], Nd4j.defaultFloatingPointType()));
-    }
-
     @Override
     public String opName() {
         return "l2_loss";

@@ -47,11 +47,6 @@ public class HashCode extends DynamicCustomOp {
         this.outputArguments.add(result);
     }
 
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        return Collections.singletonList(LongShapeDescriptor.fromShape(new long[0], DataType.LONG));
-    }
-
     @Override
     public String opName() {
         return "hashcode";

@@ -1,78 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.shape;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.base.Preconditions;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.buffer.DataType;
-import org.nd4j.linalg.api.ops.DynamicCustomOp;
-import org.nd4j.linalg.api.shape.LongShapeDescriptor;
-
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-
-/**
- * Broadcast function
- *
- * @author Adam Gibson
- */
-public class Broadcast extends DynamicCustomOp {
-    private long[] shape;
-    public Broadcast(SameDiff sameDiff,SDVariable iX, long[] shape) {
-        super(null,sameDiff,new SDVariable[]{iX});
-        this.shape = shape;
-    }
-
-
-    public Broadcast() {}
-
-
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        return Arrays.asList(LongShapeDescriptor.fromShape(shape, larg().dataType()));
-    }
-
-    @Override
-    public String opName() {
-        return "broadcast";
-    }
-
-
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        throw new UnsupportedOperationException();
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
-    }
-
-    @Override
-    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
-        return Collections.singletonList(dataTypes.get(0));
-    }
-}

@@ -82,12 +82,4 @@ public class Linspace extends DynamicCustomOp {
     public List<SDVariable> doDiff(List<SDVariable> gradients){
         return Arrays.asList(f().zerosLike(arg(0)), f().zerosLike(arg(1)), f().zerosLike(arg(2)));
     }
-
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape(){
-        INDArray l = arg(2).getArr();
-        if(l == null)
-            return Collections.emptyList();
-        return Collections.singletonList(LongShapeDescriptor.fromShape(new long[]{l.getLong(0)}, dataType));
-    }
 }

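For context: the override deleted above read the requested point count from the op's third argument (arg(2)) and returned a rank-1 shape of that length; that same shape now comes from the C++ shape function through the inherited DynamicCustomOp path. The shape semantics, sketched with the Nd4j factory method that tests elsewhere in this diff also use — illustrative, not code from this commit:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// linspace(start, stop, num) yields a rank-1 array of num evenly spaced values
INDArray out = Nd4j.linspace(1, 10, 10);
// out.shape() is [10] - the length the deleted Java code read from arg(2),
// and the same value the native calculation now reports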
@@ -85,14 +85,6 @@ public class Rank extends DynamicCustomOp {
         return "Rank";
     }
 
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        List<LongShapeDescriptor> ret = new ArrayList<>();
-        ret.add(LongShapeDescriptor.fromShape(new long[]{}, DataType.INT));
-        return ret;
-    }
-
-
     @Override
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
         return Collections.singletonList(sameDiff.zerosLike(arg()));

@@ -152,17 +152,4 @@ public class Unstack extends DynamicCustomOp {
         return out;
     }
 
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape(){
-        //TEMPORARY workaround for: https://github.com/deeplearning4j/deeplearning4j/issues/7093
-        if(inputArguments.size() == 1 && inputArguments.get(0).rank() == 1){
-            INDArray arr = inputArguments.get(0);
-            Preconditions.checkState(jaxis == 0, "Can only unstack along dimension 0 for rank 1 arrays, got axis %s for array %ndShape", jaxis, arr);
-            LongShapeDescriptor lsd = LongShapeDescriptor.fromShape(new long[0], arr.dataType());
-            List<LongShapeDescriptor> out = Arrays.asList(ArrayUtil.nTimes((int)arr.length(), lsd, LongShapeDescriptor.class));
-            return out;
-        }
-        return super.calculateOutputShape();
-    }
-
 }

@@ -78,8 +78,7 @@ public abstract class BaseTensorOp extends DynamicCustomOp {
 
     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
-        //Not used/not required
-        return Collections.emptyList();
+        throw new UnsupportedOperationException("calculateOutputShape() is not supported for tensor ops.");
     }
 
     @Override

@@ -41,12 +41,6 @@ public class TensorArraySize extends BaseTensorOp {
         return "tensorarraysizev3";
     }
 
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        // output is scalar only
-        return Collections.singletonList(LongShapeDescriptor.fromShape(new long[]{}, DataType.LONG));
-    }
-
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
         super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);

@@ -46,68 +46,6 @@ public abstract class BaseDynamicTransformOp extends DynamicCustomOp {
         super(null, inputs, outputs);
     }
 
-
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        long[] firstArgShape;
-        long[] secondArgShape;
-        DataType dtypeZ;
-
-        if(numInputArguments() == 2){
-            return super.calculateOutputShape(); //Use c++ shape calc, which also accounts for empty broadcast cases, etc
-//            firstArgShape = inputArguments.get(0).shape();
-//            secondArgShape = inputArguments.get(1).shape();
-//            dtypeZ = Shape.pickPairwiseDataType(inputArguments.get(0).dataType(), inputArguments.get(1).dataType());
-        } else {
-            val args = args();
-            if (args.length < 2) {
-                if (args[0] == null || (inputArguments.isEmpty() && args[0].getShape() == null)) {
-                    return Collections.emptyList();
-                }
-                DataType dtypeX = !inputArguments.isEmpty() ? inputArguments.get(0).dataType() : args[0].dataType();
-                long[] shape = !inputArguments.isEmpty() ? inputArguments.get(0).shape() : args[0].getShape();
-
-                return Collections.singletonList(LongShapeDescriptor.fromShape(shape, dtypeX));
-            }
-
-            if(inputArguments.size() == 2 && inputArguments.get(0) != null && inputArguments.get(1) != null){
-                firstArgShape = inputArguments.get(0).shape();
-                secondArgShape = inputArguments.get(1).shape();
-            } else {
-                firstArgShape = args[0].getShape();
-                secondArgShape = args[1].getShape();
-            }
-            if (args[0] == null || args[0].getShape() == null) {
-                return Collections.emptyList();
-            }
-
-            if (args[1] == null || args[1].getShape() == null) {
-                return Collections.emptyList();
-            }
-
-            // detecting datatype based on both args
-            val dtypeX = inputArguments.size() > 0 ? inputArguments.get(0).dataType() : args[0].dataType();
-            val dtypeY = inputArguments.size() > 1 ? inputArguments.get(1).dataType() : args[1].dataType();
-            dtypeZ = Shape.pickPairwiseDataType(dtypeX, dtypeY);
-        }
-
-
-
-        if(Arrays.equals(firstArgShape, secondArgShape)){
-            try {
-                return Collections.singletonList(LongShapeDescriptor.fromShape(firstArgShape, dtypeZ));
-            } catch (Throwable e) {
-                throw new RuntimeException("calculateOutputShape() failed for [" + this.opName() + "]", e);
-            }
-        } else {
-            //Handle broadcast shape: [1,4]+[3,1] = [3,4]
-            Shape.assertBroadcastable(firstArgShape, secondArgShape, this.getClass());
-            val outShape = Shape.broadcastOutputShape(firstArgShape, secondArgShape);
-
-            return Collections.singletonList(LongShapeDescriptor.fromShape(outShape, dtypeZ));
-        }
-    }
-
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
         Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got input %s", getClass(), dataTypes);

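The block deleted above was a Java reimplementation of pairwise broadcast shape inference; the two-input branch already deferred to the C++ calculation ("Use c++ shape calc" in the removed comment), and after this commit everything does. For reference, the broadcast rule it handled, sketched with the same Shape helpers the removed code called — the values are illustrative, and the Class argument to assertBroadcastable appears to be used only for error reporting (an assumption, not verified here):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.shape.Shape;

long[] first  = {1, 4};
long[] second = {3, 1};
Shape.assertBroadcastable(first, second, Object.class);      // throws if the shapes cannot broadcast
long[] outShape = Shape.broadcastOutputShape(first, second); // [1,4]+[3,1] -> [3,4]
DataType dtypeZ = Shape.pickPairwiseDataType(DataType.FLOAT, DataType.DOUBLE); // promoted pair type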
@@ -71,18 +71,6 @@ public class EqualTo extends BaseDynamicTransformOp {
         return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1]));
     }
 
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if (args() == null)
-            return Collections.emptyList();
-
-        if (inputArguments.size() == 0)
-            return Collections.emptyList();
-
-
-        return Nd4j.getExecutioner().calculateOutputShape(this);
-    }
-
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
         Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes);

@@ -73,18 +73,6 @@ public class GreaterThan extends BaseDynamicTransformOp {
         return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1]));
     }
 
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if (args() == null)
-            return Collections.emptyList();
-
-        if (inputArguments.size() == 0)
-            return Collections.emptyList();
-
-
-        return Nd4j.getExecutioner().calculateOutputShape(this);
-    }
-
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
         Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes);

@@ -76,18 +76,6 @@ public class GreaterThanOrEqual extends BaseDynamicTransformOp {
         return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1]));
     }
 
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if (args() == null)
-            return Collections.emptyList();
-
-        if (inputArguments.size() == 0)
-            return Collections.emptyList();
-
-
-        return Nd4j.getExecutioner().calculateOutputShape(this);
-    }
-
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
         Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes);

@@ -73,18 +73,6 @@ public class LessThan extends BaseDynamicTransformOp {
         return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1]));
     }
 
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if (args() == null)
-            return Collections.emptyList();
-
-        if (inputArguments.size() == 0)
-            return Collections.emptyList();
-
-
-        return Nd4j.getExecutioner().calculateOutputShape(this);
-    }
-
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
         Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes);

@@ -71,18 +71,6 @@ public class LessThanOrEqual extends BaseDynamicTransformOp {
         return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1]));
     }
 
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if (args() == null)
-            return Collections.emptyList();
-
-        if (inputArguments.size() == 0)
-            return Collections.emptyList();
-
-
-        return Nd4j.getExecutioner().calculateOutputShape(this);
-    }
-
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
         Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes);

@@ -72,18 +72,6 @@ public class NotEqualTo extends BaseDynamicTransformOp {
         return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1]));
    }
 
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if (args() == null)
-            return Collections.emptyList();
-
-        if (inputArguments.size() == 0)
-            return Collections.emptyList();
-
-
-        return Nd4j.getExecutioner().calculateOutputShape(this);
-    }
-
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
         Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes);

@@ -125,41 +125,6 @@ public class Cast extends BaseDynamicTransformOp {
         return ret;
     }
 
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if(inputArguments.size() > 0){
-            long[] s = inputArguments.get(0).shape();
-            LongShapeDescriptor lsd = LongShapeDescriptor.fromShape(s, typeDst);
-            if(inputArguments.get(0).isEmpty()){
-                long e = lsd.getExtras();
-                e = ArrayOptionsHelper.setOptionBit(e, ArrayType.EMPTY);
-                lsd.setExtras(e);
-            }
-            return Collections.singletonList(lsd);
-        }
-
-        if (arg() != null && (arg().getArr() != null || arg().getShape() != null)) {
-            if (arg().getArr() != null) {
-                long[] s = arg().getArr().shape();
-                LongShapeDescriptor lsd = LongShapeDescriptor.fromShape(s, typeDst);
-                if(inputArguments.size() > 0 && inputArguments.get(0) != null && inputArguments.get(0).isEmpty()){
-                    long e = lsd.getExtras();
-                    e = ArrayOptionsHelper.setOptionBit(e, ArrayType.EMPTY);
-                    lsd.setExtras(e);
-                }
-                return Collections.singletonList(lsd);
-            } else {
-                long[] s = arg().getShape();
-                if(Shape.isPlaceholderShape(s)){
-                    return Collections.emptyList();
-                }
-                return Collections.singletonList(LongShapeDescriptor.fromShape(s, typeDst));
-            }
-        }
-
-        return Collections.emptyList();
-    }
-
     @Override
     public void setValueFor(Field target, Object value) {
         //This is a hack around a property mapping issue - TF datatype DT_DOUBLE return attribute.getType() of DT_DOUBLE which doesn't make sense

@@ -77,7 +77,7 @@ public class GradientBackwardsMarker extends DynamicCustomOp {
 
     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
-        return Collections.singletonList(LongShapeDescriptor.fromShape(new long[0], DataType.FLOAT));
+        throw new UnsupportedOperationException("calculateOutputShape() is not supported for control flow ops.");
     }
 
     @Override

@@ -46,11 +46,6 @@ public abstract class BaseArithmeticBackpropOp extends BaseDynamicTransformOp {
         throw new UnsupportedOperationException("Not supported");
     }
 
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape(){
-        return Nd4j.getExecutioner().calculateOutputShape(this);
-    }
-
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
         Preconditions.checkState(dataTypes != null && dataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got input %s", getClass(), dataTypes);

@@ -79,11 +79,4 @@ public class Identity extends BaseDynamicTransformOp {
         return dataTypes;
     }
 
-    @Override
-    public List<LongShapeDescriptor> calculateOutputShape() {
-        if(inputArguments == null || inputArguments.isEmpty())
-            return Collections.emptyList();
-        return Collections.singletonList(inputArguments.get(0).shapeDescriptor());
-    }
-
 }

@@ -1619,21 +1619,6 @@ public class ShapeOpValidation extends BaseOpValidation {
         assertEquals(expected, result.eval());
     }
 
-    @Test
-    public void testBroadcast() {
-        OpValidationSuite.ignoreFailing();
-        SameDiff sd = SameDiff.create();
-        SDVariable in = sd.var("in", Nd4j.rand(3, 4));
-        SDVariable broadcast = sd.f().broadcast(in, 3, 4, 5);
-
-        INDArray out = sd.execAndEndResult();
-        assertArrayEquals(new long[]{3, 4, 5}, out.shape());
-
-        for (int i = 0; i < 5; i++) {
-            assertEquals(in.getArr(), out.get(all(), all(), point(i)));
-        }
-    }
-
 
     @Test
     public void testSlice2d() {

@@ -53,7 +53,6 @@ import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
-import org.nd4j.linalg.api.ops.impl.layers.Linear;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig;
 import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;

@@ -522,39 +521,6 @@ public class SameDiffTests extends BaseNd4jTest {
     }
 
 
-    @Test
-    public void testLinearModule() {
-        int nIn = 5;
-        Linear linear = Linear.execBuilder()
-                .nIn(nIn)
-                .nOut(4)
-                .weightInitScheme(new UniformInitScheme('f', nIn))
-                .biasWeightInitScheme(new ZeroInitScheme('f'))
-                .build();
-        linear.exec(Nd4j.linspace(1, 20, 20).reshape(4, 5));
-        assertEquals(1, linear.numOutputArguments());
-
-    }
-
-
-    @Test
-    public void testLinearModule2() {
-        Linear linear = Linear.execBuilder()
-                .nIn(3)
-                .nOut(2)
-                .weightInitScheme(new OneInitScheme('f'))
-                .biasWeightInitScheme(new ZeroInitScheme('f'))
-                .build();
-        linear.exec(Nd4j.linspace(1, 6, 6).reshape(2, 3));
-        INDArray assertion = Nd4j.create(new double[][]{
-                {6, 6},
-                {15, 15}
-        });
-        assertEquals(assertion, linear.outputArguments()[0]);
-
-    }
-
-
     @Test
     public void testDefineFunctionArrayExistence() {
         SameDiff sameDiff = SameDiff.create();

@@ -3577,4 +3543,24 @@ public class SameDiffTests extends BaseNd4jTest {
 
         assertEquals(e, mod.eval());
     }
+
+    @Test
+    public void castShapeTest1(){
+        SameDiff sd = SameDiff.create();
+        SDVariable x = sd.constant(Nd4j.createFromArray(1, 2, 3, 4));
+        SDVariable casted = x.castTo(DataType.FLOAT);
+
+        assertEquals(casted.dataType(), DataType.FLOAT);
+    }
+
+    @Test
+    @Ignore // casted shape is null
+    public void castShapeTestEmpty(){
+        SameDiff sd = SameDiff.create();
+        SDVariable x = sd.constant(Nd4j.empty(DataType.INT));
+        SDVariable casted = x.castTo(DataType.FLOAT);
+
+        assertEquals(casted.dataType(), DataType.FLOAT);
+        assertTrue(casted.getShapeDescriptor().isEmpty());
+    }
 }

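The two new tests exercise the native shape path through Cast via SDVariable.castTo. The second is @Ignore'd because, per its inline comment, the casted shape currently comes back null for empty inputs. A short hedged extension of the first test, using the same eval() call other tests in this diff rely on, to show how the casted value would be materialized:

SameDiff sd = SameDiff.create();
SDVariable x = sd.constant(Nd4j.createFromArray(1, 2, 3, 4));
SDVariable casted = x.castTo(DataType.FLOAT);

INDArray result = casted.eval(); // runs the graph; shape [4], data type FLOAT
assertEquals(DataType.FLOAT, result.dataType());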
@@ -15,7 +15,7 @@ import org.nd4j.linalg.api.blas.params.MMulTranspose
 import org.nd4j.linalg.api.buffer.DataType
 import org.nd4j.linalg.api.ndarray.INDArray
 import org.nd4j.linalg.api.ops.DynamicCustomOp
-import org.nd4j.linalg.api.ops.impl.layers.{ ExternalErrorsFunction, Linear }
+import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction
 import org.nd4j.linalg.api.ops.impl.layers.convolution.config.{ Conv2DConfig, LocalResponseNormalizationConfig }
 import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance
 import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray