TF import tests - adding missing operations (#65)

* Add and fix mappings.

* Intermediate

* Added and fixed some mappings

* Added op

* Missing constructors added.

* Added new mappings

* SDImage wrappers and minor tweaks.

* Added missing constructor

* Some corrections

* Cleanup

* Small fixes

* Ops wrappers

* Minor fixes.

* Max Pooling

* MaxPoolWithArgmax

* Some fixes

* Ignores for failures

* Some ops fixed.

* Some fixes

* Missing package added

* Some fixes

* Ignored tests fixed.

* Some fixes

* Merge master

* bitcast fix

Signed-off-by: raver119 <raver119@gmail.com>

* Bitcast fixed
master
Alexander Stoyakin 2019-12-02 12:23:06 +02:00 committed by Alex Black
parent 1adc25919c
commit 5e152c0d9a
39 changed files with 1545 additions and 86 deletions

View File

@ -2616,6 +2616,36 @@ public class DifferentialFunctionFactory {
return new FakeQuantWithMinMaxVarsPerChannel(sameDiff,x,min,max).outputVariable(); return new FakeQuantWithMinMaxVarsPerChannel(sameDiff,x,min,max).outputVariable();
} }
public SDVariable betainc( SDVariable a, SDVariable b, SDVariable x) {
return new BetaInc(sameDiff, a, b, x).outputVariable();
}
public SDVariable[] fusedBatchNorm(SDVariable x, SDVariable scale, SDVariable offset,
SDVariable dataFormat, SDVariable isTraining) {
return new FusedBatchNorm(sameDiff,x,scale,offset,dataFormat,isTraining).outputVariables();
}
public SDVariable matrixBandPart(SDVariable input, SDVariable minLower, SDVariable maxUpper) {
return new MatrixBandPart(sameDiff,input,minLower,maxUpper).outputVariable();
}
public SDVariable[] maxPoolWithArgmaxs(SDVariable x, Pooling2DConfig pooling2DConfig) {
return new MaxPoolWithArgmax(sameDiff, x, pooling2DConfig).outputVariables();
}
public SDVariable polygamma(SDVariable n, SDVariable x) {
return new Polygamma(sameDiff, n,x).outputVariable();
}
public SDVariable roll(SDVariable input, SDVariable shift) {
return new Roll(sameDiff, input, shift).outputVariable();
}
public SDVariable toggleBits(SDVariable x) {
return new ToggleBits(sameDiff, x).outputVariable();
}
public String toString() { public String toString() {
return "DifferentialFunctionFactory{methodNames=" + methodNames + "}"; return "DifferentialFunctionFactory{methodNames=" + methodNames + "}";
} }

View File

@ -202,4 +202,16 @@ public class SDBitwise extends SDOps {
SDVariable ret = f().bitwiseXor(x, y); SDVariable ret = f().bitwiseXor(x, y);
return updateVariableNameAndReference(ret, name); return updateVariableNameAndReference(ret, name);
} }
/**
* Flip bits
*
* @param name Name of the output variable
* @param x input array
* @return array after flipping each input bit
*/
public SDVariable toggleBits(String name, SDVariable x) {
SDVariable res = f().toggleBits(x);
return updateVariableNameAndReference(res, name);
}
} }

View File

@ -3,6 +3,10 @@ package org.nd4j.autodiff.samediff.ops;
import lombok.NonNull; import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.custom.AdjustContrast;
import org.nd4j.linalg.api.ops.custom.AdjustHue;
import org.nd4j.linalg.api.ops.custom.AdjustSaturation;
import org.nd4j.linalg.api.ops.custom.RandomCrop;
import org.nd4j.linalg.api.ops.impl.image.CropAndResize; import org.nd4j.linalg.api.ops.impl.image.CropAndResize;
import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches; import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches;
import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression; import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression;
@ -52,10 +56,67 @@ public class SDImage extends SDOps {
return updateVariableNameAndReference(out, name); return updateVariableNameAndReference(out, name);
} }
/**
* Greedily selects a subset of bounding boxes in descending order of score
* @param name Might be null. Name for the output variable
* @param boxes 2D array of shape [num_boxes,4]
* @param scores vector of shape [num_boxes]
* @param maxOutSize scalar representing the maximum number of boxes to be selected
* @param iouThreshold float - threshold for deciding whether boxes overlap too much with respect to IOU
* @param scoreThreshold float - threshold for deciding when to remove boxes based on score
* @return vectort of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size
*/
public SDVariable nonMaxSuppression(String name, @NonNull SDVariable boxes, @NonNull SDVariable scores, @NonNull SDVariable maxOutSize, public SDVariable nonMaxSuppression(String name, @NonNull SDVariable boxes, @NonNull SDVariable scores, @NonNull SDVariable maxOutSize,
@NonNull SDVariable iouThreshold, @NonNull SDVariable scoreThreshold){ @NonNull SDVariable iouThreshold, @NonNull SDVariable scoreThreshold){
SDVariable out = new NonMaxSuppression(sd, boxes, scores, maxOutSize, iouThreshold, scoreThreshold).outputVariable(); SDVariable out = new NonMaxSuppression(sd, boxes, scores, maxOutSize, iouThreshold, scoreThreshold).outputVariable();
return updateVariableNameAndReference(out, name); return updateVariableNameAndReference(out, name);
} }
/**
* Adjusts contrast of RGB or grayscale images.
* @param name name for the output variable
* @param in images to adjust. 3D shape or higher.
* @param factor float multiplier for adjusting contrast.
* @return Contrast-adjusted image
*/
public SDVariable adjustContrast(String name, @NonNull SDVariable in, @NonNull SDVariable factor) {
SDVariable out = new AdjustContrast(sd, in, factor).outputVariable();
return updateVariableNameAndReference(out, name);
}
/**
* Adjust saturation of RGB images
* @param name name for the output variable
* @param in RGB image as 3D array
* @param factor factor for saturation
* @return adjusted image
*/
public SDVariable adjustSaturation(String name, @NonNull SDVariable in, @NonNull SDVariable factor) {
SDVariable out = new AdjustSaturation(sd, in, factor).outputVariable();
return updateVariableNameAndReference(out, name);
}
/**
* Adjust hue of RGB image
* @param name name for the output variable
* @param in RGB image as 3D array
* @param delta value to add to hue channel
* @return adjusted image
*/
public SDVariable adjustHue(String name, @NonNull SDVariable in, @NonNull SDVariable delta) {
SDVariable out = new AdjustHue(sd, in, delta).outputVariable();
return updateVariableNameAndReference(out, name);
}
/**
* Randomly crops image
* @param name name for the output variable
* @param input input array
* @param shape shape for crop
* @return cropped array
*/
public SDVariable randomCrop(String name, @NonNull SDVariable input, @NonNull SDVariable shape) {
SDVariable out = new RandomCrop(sd, input, shape).outputVariable();
return updateVariableNameAndReference(out, name);
}
} }

View File

@ -2496,5 +2496,57 @@ public class SDMath extends SDOps {
return updateVariableNameAndReference(res, name); return updateVariableNameAndReference(res, name);
} }
/**
* Compute the regularized incomplete beta integral
*
* @param name Name of the output variable
* @param a input array
* @param b input array
* @param x input array
* @return array
*/
public SDVariable betainc(String name,SDVariable a,SDVariable b,SDVariable x) {
SDVariable res = f().betainc(a,b,x);
return updateVariableNameAndReference(res, name);
}
/**
* Copy a tensor setting everything outside a central band in each innermost matrix.
*
* @param name Name of the output variable
* @param input Rank k array
* @param minLower Number of subdiagonals to keep.
* @param maxUpper Number of superdiagonals to keep.
* @return Rank k array of the same shape as input.
*/
public SDVariable matrixBandPart(String name, SDVariable input, SDVariable minLower, SDVariable maxUpper) {
SDVariable res = f().matrixBandPart(input,minLower,maxUpper);
return updateVariableNameAndReference(res, name);
}
/**
* Polygamma function
*
* @param name Name of the output variable
* @param n array
* @param x array
* @return array
*/
public SDVariable polygamma(String name, SDVariable n, SDVariable x) {
SDVariable res = f().polygamma(n,x);
return updateVariableNameAndReference(res, name);
}
/**
* Rolls the elements of input
*
* @param name Name of the output variable
* @param input array
* @param shift number of places to shift elements
* @return array
*/
public SDVariable roll(String name, SDVariable input, SDVariable shift) {
SDVariable res = f().roll(input,shift);
return updateVariableNameAndReference(res, name);
}
} }

View File

@ -19,6 +19,7 @@ package org.nd4j.autodiff.samediff.ops;
import lombok.NonNull; import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.transforms.Pad; import org.nd4j.linalg.api.ops.impl.transforms.Pad;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -1032,4 +1033,35 @@ public class SDNN extends SDOps {
); );
} }
} }
/**
* Max pooling on the input and outputs both max values and indices
*
* @param name Name of the output variable
* @param x input array
* @return output array and argmax array
*/
public SDVariable[] maxPoolWithArgmax(String[] names, SDVariable x, Pooling2DConfig pooling2DConfig) {
SDVariable[] res = f().maxPoolWithArgmaxs(x, pooling2DConfig);
return sd.updateVariableNamesAndReferences(res, names);
}
/**
* Batch normalization
*
* @param name Name of the output variable
* @param x 4D array
* @param scale vector for scaling factor of normalized x
* @param offset vector to shift to the normalized x
* @param dataFormat integer scalar - data format
* @param isTraining boolean scalar - is training mode
* @return y: 4D array
* batch_mean: vector
* batch_var: vector
*/
public SDVariable[] fusedBatchNorm(String[] names, SDVariable x, SDVariable scale, SDVariable offset,
SDVariable dataFormat, SDVariable isTraining) {
SDVariable[] res = f().fusedBatchNorm(x,scale,offset,dataFormat,isTraining);
return sd.updateVariableNamesAndReferences(res, names);
}
} }

View File

@ -46,7 +46,6 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.custom.BarnesEdgeForces.class, org.nd4j.linalg.api.ops.custom.BarnesEdgeForces.class,
org.nd4j.linalg.api.ops.custom.BarnesHutGains.class, org.nd4j.linalg.api.ops.custom.BarnesHutGains.class,
org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize.class, org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize.class,
org.nd4j.linalg.api.ops.custom.DrawBoundingBoxes.class,
org.nd4j.linalg.api.ops.custom.KnnMinDistance.class, org.nd4j.linalg.api.ops.custom.KnnMinDistance.class,
org.nd4j.linalg.api.ops.custom.SpTreeCell.class, org.nd4j.linalg.api.ops.custom.SpTreeCell.class,
org.nd4j.linalg.api.ops.custom.Flatten.class, org.nd4j.linalg.api.ops.custom.Flatten.class,
@ -122,6 +121,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalizationDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalizationDerivative.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D.class, org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling3D.class, org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling3D.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3DDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3DDerivative.class,
@ -589,7 +589,17 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.custom.BitCast.class, org.nd4j.linalg.api.ops.custom.BitCast.class,
org.nd4j.linalg.api.ops.custom.CompareAndBitpack.class, org.nd4j.linalg.api.ops.custom.CompareAndBitpack.class,
org.nd4j.linalg.api.ops.custom.DivideNoNan.class, org.nd4j.linalg.api.ops.custom.DivideNoNan.class,
org.nd4j.linalg.api.ops.custom.FakeQuantWithMinMaxVarsPerChannel.class org.nd4j.linalg.api.ops.custom.DrawBoundingBoxes.class,
org.nd4j.linalg.api.ops.custom.FakeQuantWithMinMaxVarsPerChannel.class,
org.nd4j.linalg.api.ops.custom.AdjustSaturation.class,
org.nd4j.linalg.api.ops.custom.AdjustHue.class,
org.nd4j.linalg.api.ops.custom.FusedBatchNorm.class,
org.nd4j.linalg.api.ops.custom.BetaInc.class,
org.nd4j.linalg.api.ops.custom.MatrixBandPart.class,
org.nd4j.linalg.api.ops.custom.Polygamma.class,
org.nd4j.linalg.api.ops.custom.RandomCrop.class,
org.nd4j.linalg.api.ops.custom.Roll.class,
org.nd4j.linalg.api.ops.custom.ToggleBits.class
); );
static { static {

View File

@ -1,5 +1,22 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom; package org.nd4j.linalg.api.ops.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -14,11 +31,11 @@ public class AdjustContrast extends BaseAdjustContrast {
public AdjustContrast() {super();} public AdjustContrast() {super();}
public AdjustContrast(INDArray in, double factor, INDArray out) { public AdjustContrast(@NonNull INDArray in, double factor, INDArray out) {
super(in, factor, out); super(in, factor, out);
} }
public AdjustContrast(SameDiff sameDiff, SDVariable in, SDVariable factor) { public AdjustContrast(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable factor) {
super(sameDiff,new SDVariable[]{in,factor}); super(sameDiff,new SDVariable[]{in,factor});
} }

View File

@ -1,5 +1,21 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom; package org.nd4j.linalg.api.ops.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -14,11 +30,11 @@ public class AdjustContrastV2 extends BaseAdjustContrast {
public AdjustContrastV2() {super();} public AdjustContrastV2() {super();}
public AdjustContrastV2(INDArray in, double factor, INDArray out) { public AdjustContrastV2(@NonNull INDArray in, double factor, INDArray out) {
super(in, factor, out); super(in, factor, out);
} }
public AdjustContrastV2(SameDiff sameDiff, SDVariable in, SDVariable factor) { public AdjustContrastV2(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable factor) {
super( sameDiff,new SDVariable[]{in,factor}); super( sameDiff,new SDVariable[]{in,factor});
} }

View File

@ -0,0 +1,69 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
public class AdjustHue extends DynamicCustomOp {
public AdjustHue() {}
public AdjustHue(@NonNull INDArray in, double delta, INDArray out) {
this(in, delta);
if (out != null) {
outputArguments.add(out);
}
}
public AdjustHue(@NonNull INDArray in, double delta) {
Preconditions.checkArgument(in.rank() >= 3,
"AdjustSaturation: op expects rank of input array to be >= 3, but got %s instead", in.rank());
Preconditions.checkArgument(-1.0 <= delta && delta <= 1.0, "AdjustHue: parameter delta must be within [-1, 1] interval," +
" but got %s instead", delta);
inputArguments.add(in);
addTArgument(delta);
}
public AdjustHue(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable factor) {
super(sameDiff,new SDVariable[]{in,factor});
}
@Override
public String opName() {
return "adjust_hue";
}
@Override
public String tensorflowName() {
return "AdjustHue";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -0,0 +1,68 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
public class AdjustSaturation extends DynamicCustomOp {
public AdjustSaturation() {}
public AdjustSaturation(@NonNull INDArray in, double factor, INDArray out) {
this(in, factor);
if (out != null) {
outputArguments.add(out);
}
}
public AdjustSaturation(@NonNull INDArray in, double factor) {
Preconditions.checkArgument(in.rank() >= 3,
"AdjustSaturation: op expects rank of input array to be >= 3, but got %s instead", in.rank());
inputArguments.add(in);
addTArgument(factor);
}
public AdjustSaturation(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable factor) {
super(sameDiff, new SDVariable[]{in, factor});
}
@Override
public String opName() {
return "adjust_saturation";
}
@Override
public String tensorflowName() {
return "AdjustSaturation";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -1,5 +1,21 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom; package org.nd4j.linalg.api.ops.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -14,16 +30,16 @@ public abstract class BaseAdjustContrast extends DynamicCustomOp {
public BaseAdjustContrast() { public BaseAdjustContrast() {
} }
public BaseAdjustContrast(INDArray in, double factor, INDArray out) { public BaseAdjustContrast(@NonNull INDArray in, double factor, INDArray out) {
Preconditions.checkArgument(in.rank() >= 3, Preconditions.checkArgument(in.rank() >= 3,
String.format("AdjustContrast: op expects rank of input array to be >= 3, but got %d instead", in.rank())); "AdjustContrast: op expects rank of input array to be >= 3, but got %s instead", in.rank());
inputArguments.add(in); inputArguments.add(in);
outputArguments.add(out); outputArguments.add(out);
addTArgument(factor); addTArgument(factor);
} }
public BaseAdjustContrast(SameDiff sameDiff, SDVariable[] vars) { public BaseAdjustContrast(@NonNull SameDiff sameDiff, @NonNull SDVariable[] vars) {
super("", sameDiff, vars); super("", sameDiff, vars);
} }

View File

@ -0,0 +1,67 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
public class BetaInc extends DynamicCustomOp {
public BetaInc() {}
public BetaInc(@NonNull INDArray a_input, @NonNull INDArray b_input, @NonNull INDArray x_input,
INDArray output) {
addInputArgument(a_input, b_input, x_input);
if (output != null) {
addOutputArgument(output);
}
}
public BetaInc(@NonNull INDArray a_input, @NonNull INDArray b_input, @NonNull INDArray x_input) {
inputArguments.add(a_input);
inputArguments.add(b_input);
inputArguments.add(x_input);
}
public BetaInc(@NonNull SameDiff sameDiff, @NonNull SDVariable a, @NonNull SDVariable b, @NonNull SDVariable x) {
super(sameDiff, new SDVariable[]{a,b,x});
}
@Override
public String opName() {
return "betainc";
}
@Override
public String tensorflowName() {
return "Betainc";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -1,3 +1,18 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom; package org.nd4j.linalg.api.ops.custom;
import lombok.val; import lombok.val;
@ -20,6 +35,8 @@ import java.util.Map;
public class BitCast extends DynamicCustomOp { public class BitCast extends DynamicCustomOp {
public BitCast() {} public BitCast() {}
private DataType dtype;
public BitCast(INDArray in, DataType dataType, INDArray out) { public BitCast(INDArray in, DataType dataType, INDArray out) {
this(in, dataType.toInt(), out); this(in, dataType.toInt(), out);
} }
@ -28,6 +45,8 @@ public class BitCast extends DynamicCustomOp {
inputArguments.add(in); inputArguments.add(in);
outputArguments.add(out); outputArguments.add(out);
iArguments.add(Long.valueOf(dataType)); iArguments.add(Long.valueOf(dataType));
dtype = DataType.fromInt(dataType);
} }
public BitCast(INDArray in, DataType dataType) { public BitCast(INDArray in, DataType dataType) {
@ -37,6 +56,7 @@ public class BitCast extends DynamicCustomOp {
public BitCast(INDArray in, int dataType) { public BitCast(INDArray in, int dataType) {
inputArguments.add(in); inputArguments.add(in);
iArguments.add(Long.valueOf(dataType)); iArguments.add(Long.valueOf(dataType));
dtype = DataType.fromInt(dataType);
} }
public BitCast(SameDiff sameDiff, SDVariable in, SDVariable dataType) { public BitCast(SameDiff sameDiff, SDVariable in, SDVariable dataType) {
@ -49,6 +69,8 @@ public class BitCast extends DynamicCustomOp {
val t = nodeDef.getAttrOrDefault("type", null); val t = nodeDef.getAttrOrDefault("type", null);
val type = ArrayOptionsHelper.convertToDataType(t.getType()); val type = ArrayOptionsHelper.convertToDataType(t.getType());
addIArgument(type.toInt()); addIArgument(type.toInt());
dtype = type;
} }
@Override @Override
@ -65,6 +87,6 @@ public class BitCast extends DynamicCustomOp {
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length; int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0)); return Collections.singletonList(dtype);
} }
} }

View File

@ -1,3 +1,18 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom; package org.nd4j.linalg.api.ops.custom;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -9,9 +24,13 @@ import org.nd4j.linalg.factory.Nd4j;
public class CompareAndBitpack extends DynamicCustomOp { public class CompareAndBitpack extends DynamicCustomOp {
public CompareAndBitpack() {} public CompareAndBitpack() {}
public CompareAndBitpack(INDArray in, double threshold, INDArray out) { public CompareAndBitpack(INDArray in, double threshold) {
inputArguments.add(in); inputArguments.add(in);
inputArguments.add(Nd4j.scalar(threshold)); inputArguments.add(Nd4j.scalar(threshold));
}
public CompareAndBitpack(INDArray in, double threshold, INDArray out) {
this(in, threshold);
outputArguments.add(out); outputArguments.add(out);
} }

View File

@ -1,3 +1,18 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom; package org.nd4j.linalg.api.ops.custom;
import org.apache.commons.math3.analysis.function.Divide; import org.apache.commons.math3.analysis.function.Divide;
@ -16,9 +31,13 @@ public class DivideNoNan extends DynamicCustomOp {
public DivideNoNan() { public DivideNoNan() {
} }
public DivideNoNan(INDArray in1, INDArray in2, INDArray out) { public DivideNoNan(INDArray in1, INDArray in2) {
inputArguments.add(in1); inputArguments.add(in1);
inputArguments.add(in2); inputArguments.add(in2);
}
public DivideNoNan(INDArray in1, INDArray in2, INDArray out) {
this(in1,in2);
outputArguments.add(out); outputArguments.add(out);
} }

View File

@ -1,3 +1,18 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom; package org.nd4j.linalg.api.ops.custom;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -13,11 +28,15 @@ import java.util.List;
public class DrawBoundingBoxes extends DynamicCustomOp { public class DrawBoundingBoxes extends DynamicCustomOp {
public DrawBoundingBoxes() {} public DrawBoundingBoxes() {}
public DrawBoundingBoxes(INDArray images, INDArray boxes, INDArray colors, public DrawBoundingBoxes(INDArray images, INDArray boxes, INDArray colors) {
INDArray output) {
inputArguments.add(images); inputArguments.add(images);
inputArguments.add(boxes); inputArguments.add(boxes);
inputArguments.add(colors); inputArguments.add(colors);
}
public DrawBoundingBoxes(INDArray images, INDArray boxes, INDArray colors,
INDArray output) {
this(images, boxes, colors);
outputArguments.add(output); outputArguments.add(output);
} }

View File

@ -1,3 +1,18 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom; package org.nd4j.linalg.api.ops.custom;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -13,14 +28,18 @@ import java.util.List;
public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp { public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp {
public FakeQuantWithMinMaxVarsPerChannel() {} public FakeQuantWithMinMaxVarsPerChannel() {}
public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max, public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max) {
INDArray output) {
Preconditions.checkArgument(min.isVector() && max.isVector() && Preconditions.checkArgument(min.isVector() && max.isVector() &&
min.length() == max.length(), min.length() == max.length(),
"FakeQuantWithMinMaxVarsPerChannel: min and max should be 1D tensors with the same length"); "FakeQuantWithMinMaxVarsPerChannel: min and max should be 1D tensors with the same length");
inputArguments.add(x); inputArguments.add(x);
inputArguments.add(min); inputArguments.add(min);
inputArguments.add(max); inputArguments.add(max);
}
public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max,
INDArray output) {
this(x,min,max);
outputArguments.add(output); outputArguments.add(output);
} }

View File

@ -0,0 +1,64 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
public class FusedBatchNorm extends DynamicCustomOp {
public FusedBatchNorm() {}
public FusedBatchNorm(@NonNull INDArray x, @NonNull INDArray scale, @NonNull INDArray offset,
int dataFormat, int isTraining,
INDArray yOut, INDArray batchMeanOut, INDArray batchMeanVar) {
addInputArgument(x, scale, offset);
addIArgument(dataFormat, isTraining);
if (yOut != null && batchMeanOut != null && batchMeanVar != null) {
addOutputArgument(yOut, batchMeanOut, batchMeanVar);
}
}
public FusedBatchNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable scale, @NonNull SDVariable offset,
@NonNull SDVariable dataFormat, @NonNull SDVariable isTraining) {
super("", sameDiff, new SDVariable[]{x, scale, offset, dataFormat, isTraining});
}
@Override
public String opName() {
return "fused_batch_norm";
}
@Override
public String tensorflowName() {
return "FusedBatchNormV2";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -0,0 +1,65 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
public class MatrixBandPart extends DynamicCustomOp {
public MatrixBandPart() {}
public MatrixBandPart(@NonNull INDArray input, int minLower, int maxUpper) {
Preconditions.checkArgument(input.rank() >= 2, "MatrixBandPart: Input rank should be 2 or higher");
long N = input.size(-2);
long M = input.size(-1);
Preconditions.checkArgument(minLower > -N && minLower < N, "MatrixBandPart: lower diagonal count %s should be less than %s",
minLower, N);
Preconditions.checkArgument(maxUpper > -M && maxUpper < M, "MatrixBandPart: upper diagonal count %s should be less than %s.",
maxUpper, M);
addInputArgument(input);
addIArgument(minLower, maxUpper);
}
public MatrixBandPart(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable minLower, SDVariable maxUpper) {
super("", sameDiff, new SDVariable[]{input, minLower, maxUpper});
}
@Override
public String opName() {
return "matrix_band_part";
}
@Override
public String tensorflowName() {
return "MatrixBandPart";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -0,0 +1,66 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
public class Polygamma extends DynamicCustomOp {
public Polygamma() {}
public Polygamma(@NonNull INDArray n, @NonNull INDArray x) {
Preconditions.checkArgument(n.shape() != x.shape(),
"Polygamma: n and x must have the same shapes");
addInputArgument(n,x);
}
public Polygamma(@NonNull INDArray n, @NonNull INDArray x, INDArray output) {
this(n,x);
if (output != null) {
addOutputArgument(output);
}
}
public Polygamma(@NonNull SameDiff sameDiff, @NonNull SDVariable n, @NonNull SDVariable x) {
super("", sameDiff, new SDVariable[]{n ,x});
}
@Override
public String opName() {
return "polygamma";
}
@Override
public String tensorflowName() {
return "Polygamma";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -0,0 +1,61 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.rng.Random;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public class RandomCrop extends DynamicCustomOp {
public RandomCrop() {}
public RandomCrop(@NonNull INDArray input, @NonNull INDArray shape) {
Preconditions.checkArgument(shape.isVector(),"RandomCrop:Shape tensor should be a vector");
Preconditions.checkArgument(input.rank() == shape.length(), "RandomCrop:The length of the shape vector is not match input rank");
addInputArgument(input, shape);
}
public RandomCrop(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable shape) {
super("", sameDiff, new SDVariable[]{input, shape});
}
@Override
public String opName() {
return "random_crop";
}
@Override
public String tensorflowName() {
return "RandomCrop";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null /*&& inputDataTypes.size() == 4*/,
"Expected 4 input datatypes for %s, got %s", getClass(), inputDataTypes);
return Collections.singletonList(DataType.FLOAT); //TF import: always returns float32...
}
}

View File

@ -0,0 +1,64 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
public class Roll extends DynamicCustomOp {
public Roll() {}
public Roll(@NonNull INDArray input, @NonNull INDArray axes, @NonNull INDArray shifts) {
Preconditions.checkArgument(axes.rank() == shifts.rank(), "Roll: shifts and axes should be the same rank");
Preconditions.checkArgument(axes.length() == shifts.length(), "Roll: shifts and axes should be the same length");
addInputArgument(input, axes, shifts);
}
public Roll(@NonNull INDArray input, int shift) {
addInputArgument(input);
addIArgument(shift);
}
public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable shift) {
super("", sameDiff, new SDVariable[]{input,shift});
}
@Override
public String opName() {
return "roll";
}
@Override
public String tensorflowName() {
return "Roll";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -0,0 +1,64 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
public class ToggleBits extends DynamicCustomOp {
public ToggleBits() {}
public ToggleBits(@NonNull INDArray input, INDArray output) {
this(input);
if (output != null) {
addOutputArgument(output);
}
}
public ToggleBits(@NonNull INDArray input) {
addInputArgument(input);
}
public ToggleBits(@NonNull SameDiff sameDiff, @NonNull SDVariable input) {
super("", sameDiff, new SDVariable[]{input});
}
@Override
public String opName() {
return "toggle_bits";
}
@Override
public String tensorflowName() {
return "Invert";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.Op;
@ -41,6 +42,12 @@ public class NonMaxSuppression extends DynamicCustomOp {
super(null, sameDiff, new SDVariable[]{boxes, scores, maxOutSize, iouThreshold, scoreThreshold}, false); super(null, sameDiff, new SDVariable[]{boxes, scores, maxOutSize, iouThreshold, scoreThreshold}, false);
} }
public NonMaxSuppression(INDArray boxes, INDArray scores, int maxOutSize, double iouThreshold, double scoreThreshold) {
addInputArgument(boxes,scores);
addIArgument(maxOutSize);
addTArgument(iouThreshold, scoreThreshold);
}
@Override @Override
public String onnxName() { public String onnxName() {
throw new NoOpNameFoundException("No onnx name found for shape " + opName()); throw new NoOpNameFoundException("No onnx name found for shape " + opName());
@ -53,7 +60,7 @@ public class NonMaxSuppression extends DynamicCustomOp {
@Override @Override
public String[] tensorflowNames() { public String[] tensorflowNames() {
return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2"}; return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2","NonMaxSuppressionV3","NonMaxSuppressionV4"};
} }
@Override @Override

View File

@ -0,0 +1,284 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.lang.reflect.Field;
import java.util.*;
@Slf4j
@Getter
public class MaxPoolWithArgmax extends DynamicCustomOp {
protected Pooling2DConfig config;
protected DataType outputType;
public MaxPoolWithArgmax() {
}
@Builder(builderMethodName = "sameDiffBuilder")
@SuppressWarnings("Used in lombok")
public MaxPoolWithArgmax(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) {
super(null, sameDiff, new SDVariable[]{input}, false);
config.setType(Pooling2D.Pooling2DType.MAX);
this.config = config;
addArgs();
}
public MaxPoolWithArgmax(INDArray input, INDArray output,INDArray outArgMax, @NonNull Pooling2DConfig config){
super(null, new INDArray[]{input}, new INDArray[]{output, outArgMax});
config.setType(Pooling2D.Pooling2DType.MAX);
this.config = config;
addArgs();
}
@Override
public boolean isConfigProperties() {
return true;
}
@Override
public String configFieldName() {
return "config";
}
@Override
public Map<String, Object> propertiesForFunction() {
if(config == null && iArguments.size() > 0){
//Perhaps loaded from FlatBuffers - hence we have IArgs but not Config object
config = Pooling2DConfig.builder()
.kH(iArguments.get(0))
.kW(iArguments.get(1))
.sH(iArguments.get(2))
.sW(iArguments.get(3))
.pH(iArguments.get(4))
.pW(iArguments.get(5))
.dH(iArguments.get(6))
.dW(iArguments.get(7))
.isSameMode(iArguments.get(8) == 1)
.extra(iArguments.get(9))
.isNHWC(iArguments.get(10) == 1)
.type(Pooling2D.Pooling2DType.MAX)
.build();
}
return config.toProperties();
}
private void addArgs() {
addIArgument(config.getKH(),
config.getKW(),
config.getSH(),
config.getSW(),
config.getPH(),
config.getPW(),
config.getDH(),
config.getDW(),
ArrayUtil.fromBoolean(config.isSameMode()),
(int) config.getExtra(),
ArrayUtil.fromBoolean(config.isNHWC())
);
}
public String getPoolingPrefix() {
return "max";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
List<SDVariable> ret = new ArrayList<>();
List<SDVariable> inputs = new ArrayList<>();
inputs.addAll(Arrays.asList(args()));
inputs.add(f1.get(0));
Pooling2DDerivative pooling2DDerivative = Pooling2DDerivative.derivativeBuilder()
.inputs(inputs.toArray(new SDVariable[inputs.size()]))
.sameDiff(sameDiff)
.config(config)
.build();
ret.addAll(Arrays.asList(pooling2DDerivative.outputVariables()));
return ret;
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
val aStrides = nodeDef.getAttrOrThrow("strides");
val tfStrides = aStrides.getList().getIList();
val aKernels = nodeDef.getAttrOrThrow("ksize");
val tfKernels = aKernels.getList().getIList();
int sH = 0;
int sW = 0;
int pH = 0;
int pW = 0;
int kH = 0;
int kW = 0;
val aPadding = nodeDef.getAttrOrThrow("padding");
val padding = aPadding.getList().getIList();
val paddingMode = aPadding.getS().toStringUtf8().replaceAll("\"", "");
boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");
String data_format = "nhwc";
if (nodeDef.containsAttr("data_format")) {
val attr = nodeDef.getAttrOrThrow("data_format");
data_format = attr.getS().toStringUtf8().toLowerCase();
}
if (data_format.equalsIgnoreCase("nhwc")) {
sH = tfStrides.get(1).intValue();
sW = tfStrides.get(2).intValue();
kH = tfKernels.get(1).intValue();
kW = tfKernels.get(2).intValue();
pH = padding.size() > 0 ? padding.get(1).intValue() : 0;
pW = padding.size() > 0 ? padding.get(2).intValue() : 0;
} else {
sH = tfStrides.get(2).intValue();
sW = tfStrides.get(3).intValue();
kH = tfKernels.get(2).intValue();
kW = tfKernels.get(3).intValue();
pH = padding.size() > 0 ? padding.get(2).intValue() : 0;
pW = padding.size() > 0 ? padding.get(3).intValue() : 0;
}
Pooling2DConfig pooling2DConfig = Pooling2DConfig.builder()
.sH(sH)
.sW(sW)
.type(Pooling2D.Pooling2DType.MAX)
.isSameMode(isSameMode)
.kH(kH)
.kW(kW)
.pH(pH)
.pW(pW)
.isNHWC(data_format.equalsIgnoreCase("nhwc"))
.extra(1.0) // averaging only for non-padded values
.build();
this.config = pooling2DConfig;
addArgs();
if(attributesForNode.containsKey("argmax")) {
outputType = TFGraphMapper.convertType(attributesForNode.get("argmax").getType());
} else {
outputType = DataType.UINT32;
}
}
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();
Map<String, PropertyMapping> map = new HashMap<>();
val strideMapping = PropertyMapping.builder()
.tfAttrName("strides")
.onnxAttrName("strides")
.propertyNames(new String[]{"sW", "sH"})
.build();
val paddingMapping = PropertyMapping.builder()
.onnxAttrName("padding")
.tfAttrName("padding")
.propertyNames(new String[]{"pH", "pW"})
.build();
val kernelMapping = PropertyMapping.builder()
.propertyNames(new String[]{"kH", "kW"})
.tfInputPosition(1)
.onnxAttrName("ksize")
.build();
val dilationMapping = PropertyMapping.builder()
.onnxAttrName("dilations")
.propertyNames(new String[]{"dW", "dH"})
.tfAttrName("rates")
.build();
//data_format
val dataFormatMapping = PropertyMapping.builder()
.propertyNames(new String[]{"isNHWC"})
.tfAttrName("data_format")
.build();
map.put("sW", strideMapping);
map.put("sH", strideMapping);
map.put("kH", kernelMapping);
map.put("kW", kernelMapping);
map.put("dW", dilationMapping);
map.put("dH", dilationMapping);
map.put("pH", paddingMapping);
map.put("pW", paddingMapping);
map.put("isNHWC", dataFormatMapping);
ret.put(onnxName(), map);
ret.put(tensorflowName(), map);
return ret;
}
@Override
public String opName() {
return "max_pool_with_argmax";
}
@Override
public String onnxName() {
return "MaxPoolWithArgmax";
}
@Override
public String tensorflowName() {
return "MaxPoolWithArgmax";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes);
List<DataType> result = new ArrayList<>();
result.add(inputDataTypes.get(0));
result.add(outputType == null ? DataType.UINT32 : outputType);
return result;
}
}

View File

@ -293,8 +293,8 @@ public class MaxPooling2D extends DynamicCustomOp {
} }
@Override @Override
public String tensorflowName() { public String[] tensorflowNames() {
return "MaxPool"; return new String[]{"MaxPool","MaxPoolV2"};
} }
@Override @Override

View File

@ -68,7 +68,7 @@ public class ClipByValue extends DynamicCustomOp {
@Override @Override
public String opName() { public String opName() {
return "clipbyvalue"; return "ClipByValue";
} }
@Override @Override

View File

@ -53,15 +53,9 @@ public class RShiftBits extends BaseDynamicTransformOp {
return "rshift_bits"; return "rshift_bits";
} }
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override @Override
public String tensorflowName() { public String tensorflowName() {
throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName()); return "RightShift";
} }

View File

@ -53,15 +53,9 @@ public class ShiftBits extends BaseDynamicTransformOp {
return "shift_bits"; return "shift_bits";
} }
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override @Override
public String tensorflowName() { public String tensorflowName() {
throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName()); return "LeftShift";
} }

View File

@ -46,8 +46,8 @@ public class UniqueWithCounts extends DynamicCustomOp {
} }
@Override @Override
public String tensorflowName() { public String[] tensorflowNames() {
return "UniqueWithCounts"; return new String[]{"UniqueWithCounts","UniqueWithCountsV2"};
} }
@Override @Override

View File

@ -77,8 +77,8 @@ public class CopyOp extends BaseTransformSameOp {
} }
@Override @Override
public String tensorflowName() { public String[] tensorflowNames() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); return new String[]{"Copy","DeepCopy","CopyHost"};
} }
@Override @Override

View File

@ -57,7 +57,7 @@ public class ModOp extends BaseDynamicTransformOp {
@Override @Override
public String tensorflowName() { public String tensorflowName() {
return "mod"; return "Mod";
} }
@Override @Override

View File

@ -67,12 +67,6 @@ public class Not extends BaseTransformBoolOp {
return "Not"; return "Not";
} }
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("Tensorflow name not found for " + opName());
//return "Not";
}
@Override @Override
public List<SDVariable> doDiff(List<SDVariable> f1) { public List<SDVariable> doDiff(List<SDVariable> f1) {
return Collections.singletonList(f().zerosLike(arg())); return Collections.singletonList(f().zerosLike(arg()));

View File

@ -59,19 +59,6 @@ public class GELU extends BaseTransformStrictOp {
return "gelu"; return "gelu";
} }
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName()
{
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
//return "GELU";
}
@Override @Override
public List<SDVariable> doDiff(List<SDVariable> i_v) { public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable ret = f().geluDerivative(arg(), false).mul(i_v.get(0)); SDVariable ret = f().geluDerivative(arg(), false).mul(i_v.get(0));

View File

@ -24,6 +24,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.NodeDef;
@ -71,7 +72,12 @@ public class DistributionUniform extends DynamicCustomOp {
@Override @Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
AttrValue v = attributesForNode.get("dtype"); AttrValue vDtype = attributesForNode.get("dtype");
AttrValue vTout = attributesForNode.get("Tout");
if (vDtype == null && vTout == null) {
throw new ND4JIllegalStateException("Unable to find output data type for node " + nodeDef.getName());
}
AttrValue v = vDtype == null ? vTout : vDtype;
dataType = TFGraphMapper.convertType(v.getType()); dataType = TFGraphMapper.convertType(v.getType());
addIArgument(dataType.toInt()); addIArgument(dataType.toInt());
addTArgument(0.0, 1.0); //TF version is hardcoded 0 to 1 addTArgument(0.0, 1.0); //TF version is hardcoded 0 to 1
@ -92,8 +98,8 @@ public class DistributionUniform extends DynamicCustomOp {
} }
@Override @Override
public String tensorflowName() { public String[] tensorflowNames() {
return "RandomUniform"; return new String[]{"RandomUniform","RandomUniformInt"};
} }
@Override @Override
@ -103,7 +109,7 @@ public class DistributionUniform extends DynamicCustomOp {
@Override @Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes); Preconditions.checkState(inputDataTypes != null /*&& inputDataTypes.size() == 1*/, "Expected input datatypes for %s, got %s", getClass(), inputDataTypes);
//Input data type specifies the shape //Input data type specifies the shape
if(dataType != null){ if(dataType != null){
return Collections.singletonList(dataType); return Collections.singletonList(dataType);

View File

@ -65,18 +65,6 @@ public class DropOut extends BaseRandomOp {
return "dropout"; return "dropout";
} }
@Override
public String onnxName() {
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op name found for: " + getClass().getName());
//return opName();
}
@Override @Override
public Type opType() { public Type opType() {
return Type.RANDOM ; return Type.RANDOM ;

View File

@ -736,6 +736,35 @@ public class LayerOpValidation extends BaseOpValidation {
// sd.execBackwards(); // TODO: test failing here // sd.execBackwards(); // TODO: test failing here
} }
@Test
public void testMaxPoolingArgMax() {
Nd4j.getRandom().setSeed(12345);
int nIn = 3;
int kH = 2;
int kW = 2;
int mb = 3;
int imgH = 8;
int imgW = 8;
SameDiff sd = SameDiff.create();
INDArray inArr = Nd4j.rand(new int[]{mb, nIn, imgH, imgW});
SDVariable in = sd.var("in", inArr);
Pooling2DConfig pooling2DConfig = Pooling2DConfig.builder()
.kH(kH).kW(kW)
.pH(0).pW(0)
.sH(1).sW(1)
.dH(1).dW(1)
.isSameMode(true)
.build();
SDVariable[] results = sd.nn().maxPoolWithArgmax(new String[]{"",""}, in, pooling2DConfig);
assertArrayEquals(inArr.shape(), results[0].eval().shape());
assertArrayEquals(inArr.shape(), results[1].eval().shape());
}
@Test @Test
public void testMaxPooling2dBasic() { public void testMaxPooling2dBasic() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

View File

@ -76,8 +76,6 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
"adjust_contrast/.*", "adjust_contrast/.*",
//Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965 //Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965
"bincount/.*", "bincount/.*",
// Failing 2019/11/15 https://github.com/eclipse/deeplearning4j/issues/8400
"bitcast/.*",
// Failing 2019/11/14 https://github.com/eclipse/deeplearning4j/issues/8393 // Failing 2019/11/14 https://github.com/eclipse/deeplearning4j/issues/8393
"is_strictly_increasing/emptyArrayTest/.*", "is_strictly_increasing/emptyArrayTest/.*",
@ -116,20 +114,32 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
// 2019/11/15 - missing dtype argument in nd4j, tests are useless https://github.com/eclipse/deeplearning4j/issues/8398 // 2019/11/15 - missing dtype argument in nd4j, tests are useless https://github.com/eclipse/deeplearning4j/issues/8398
"zeros_like/rank2_float32_dtype_int.*", "zeros_like/rank2_float32_dtype_int.*",
// 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8399
"crop_and_resize.*",
// 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8401
"draw_bounding_boxes.*",
// 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8402 // 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8402
"fake_quant/min_max_args_per_channel.*", "fake_quant/min_max_args_per_channel.*",
// 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8403 // 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8403
"resize_bilinear/int32.*", "resize_bilinear/int32.*",
// Suggesting TF 1.15 bug - see https://github.com/eclipse/deeplearning4j/issues/8449 // Suggesting TF 1.15 bug
"non_max_suppression_v2/float16.*" "non_max_suppression_v2/float16.*",
// 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8450
"betainc.*",
// 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8452
"polygamma.*",
// 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8453
"roll/.*",
// 11.26.2019 failing https://github.com/eclipse/deeplearning4j/issues/8455
"matrix_band_part/.*",
// 11.28.2019 failing https://github.com/eclipse/deeplearning4j/issues/8458
"adjust_hue/.*",
// 11.28.2019 failing https://github.com/eclipse/deeplearning4j/issues/8459
"adjust_saturation/.*"
}; };
/* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have /* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have

View File

@ -32,7 +32,10 @@ import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpStatus; import org.nd4j.linalg.api.ops.executioner.OpStatus;
import org.nd4j.linalg.api.ops.impl.controlflow.Where; import org.nd4j.linalg.api.ops.impl.controlflow.Where;
import org.nd4j.linalg.api.ops.impl.image.CropAndResize; import org.nd4j.linalg.api.ops.impl.image.CropAndResize;
import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression;
import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear; import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
import org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp; import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
import org.nd4j.linalg.api.ops.impl.shape.Create; import org.nd4j.linalg.api.ops.impl.shape.Create;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
@ -53,6 +56,7 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static java.lang.Float.NaN;
import static org.junit.Assert.*; import static org.junit.Assert.*;
/** /**
@ -867,6 +871,26 @@ public class CustomOpsTests extends BaseNd4jTest {
assertArrayEquals(new long[]{1,10}, lsd.get(0).getShape()); assertArrayEquals(new long[]{1,10}, lsd.get(0).getShape());
} }
@Test
public void testAdjustSaturation() {
INDArray in = Nd4j.createFromArray(new double[]{50,100,78, 118.5,220,112.5,190,163.5,230, 255,128.5,134}).reshape(2,2,3);
INDArray out = Nd4j.create(in.shape());
INDArray expected = Nd4j.createFromArray(new double[]{0,100,56, 17,220,5, 150,97,230, 255,2,13}).reshape(2,2,3);
Nd4j.exec(new AdjustSaturation(in, 2.0, out));
assertEquals(expected, out);
}
@Test
public void testAdjustHue() {
INDArray in = Nd4j.createFromArray(new double[]{0,100,56, 17,220,5, 150,97,230, 255,2,13}).reshape(2,2,3);
INDArray out = Nd4j.create(in.shape());
INDArray expected = Nd4j.createFromArray(new double[]{100,0,44, 208,5,220, 177,230,97, 2,255,244}).reshape(2,2,3);
Nd4j.exec(new AdjustHue(in, 0.5, out));
assertEquals(expected, out);
}
@Test @Test
public void testBitCast() { public void testBitCast() {
INDArray in = Nd4j.linspace(DataType.FLOAT, 1.0f, 1.0f, 8).reshape(2,2,2); INDArray in = Nd4j.linspace(DataType.FLOAT, 1.0f, 1.0f, 8).reshape(2,2,2);
@ -1088,6 +1112,216 @@ public class CustomOpsTests extends BaseNd4jTest {
assertArrayEquals(new long[]{1,10, 2}, lsd.get(0).getShape()); assertArrayEquals(new long[]{1,10, 2}, lsd.get(0).getShape());
} }
@Test
public void testBetaInc() {
Nd4j.getRandom().setSeed(10);
INDArray a = Nd4j.linspace(DataType.BFLOAT16, 0.1, 0.1, 9).reshape(3,3);
INDArray b = Nd4j.linspace(DataType.BFLOAT16, 0.1, 0.1, 9).reshape(3,3);
INDArray x = Nd4j.linspace(DataType.BFLOAT16, 0.1, 0.1, 9).reshape(3,3);
INDArray expected = Nd4j.createFromArray(new float[]{0.4121f, 0.3926f, 0.4082f,
0.4414f, 0.5000f, 0.5703f,
0.6562f, 0.7656f, 0.8828f}).reshape(3,3);
BetaInc op = new BetaInc(a,b,x);
INDArray[] out = Nd4j.exec(op);
assertArrayEquals(expected.shape(), out[0].shape());
for (int i = 0; i < 3; ++i)
assertArrayEquals(expected.toDoubleMatrix()[i], out[0].toDoubleMatrix()[i], 1e-4);
}
@Test
public void testFusedBatchNorm() {
INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*2*3*4).reshape(2,2,3,4);
INDArray scale = Nd4j.create(DataType.DOUBLE, 4);
scale.assign(0.5);
INDArray offset = Nd4j.create(DataType.DOUBLE, 4);
offset.assign(2.0);
INDArray y = Nd4j.createUninitialized(DataType.DOUBLE, x.shape());
INDArray batchMean = Nd4j.create(4);
INDArray batchVar = Nd4j.create(4);
FusedBatchNorm op = new FusedBatchNorm(x,scale,offset,0,1,
y, batchMean, batchVar);
INDArray expectedY = Nd4j.createFromArray(new double[]{1.20337462, 1.20337462, 1.20337462,
1.20337462, 1.34821558, 1.34821558, 1.34821558, 1.34821558, 1.49305654, 1.49305654,
1.49305654, 1.49305654, 1.63789749, 1.63789749, 1.63789749, 1.63789749, 1.78273857,
1.78273857, 1.78273857, 1.78273857, 1.92757952, 1.92757952, 1.92757952, 1.92757952,
2.0724206 , 2.0724206 , 2.0724206 , 2.0724206 , 2.21726155, 2.21726155, 2.21726155,
2.21726155, 2.36210251, 2.36210251, 2.36210251, 2.36210251, 2.50694346, 2.50694346,
2.50694346, 2.50694346, 2.65178442, 2.65178442, 2.65178442, 2.65178442, 2.79662538,
2.79662538, 2.79662538, 2.79662538}).reshape(x.shape());
INDArray expectedBatchMean = Nd4j.createFromArray(new double[]{23., 24., 25., 26.});
INDArray expectedBatchVar = Nd4j.createFromArray(new double[]{208.00001526, 208.00001526, 208.00001526, 208.00001526});
Nd4j.exec(op);
assertArrayEquals(expectedY.shape(), y.shape());
assertArrayEquals(expectedBatchMean.shape(), batchMean.shape());
assertArrayEquals(expectedBatchVar.shape(), batchVar.shape());
}
@Test
public void testMatrixBandPart() {
INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3);
val op = new MatrixBandPart(x,1,1);
INDArray expected = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3);
/*expected.putScalar(0, 0, 2, 0.);
expected.putScalar(1, 0, 2, 0.);
expected.putScalar(0, 2, 0, 0.);
expected.putScalar(1, 2, 0, 0.);*/
INDArray[] out = Nd4j.exec(op);
assertEquals(expected, x);
}
@Test
public void testPolygamma() {
INDArray n = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 9).reshape(3,3);
INDArray x = Nd4j.create(DataType.FLOAT, 3,3);
x.assign(0.5);
INDArray expected = Nd4j.createFromArray(new float[]{4.934802f, -16.828796f, 97.409088f, -771.474243f,
7691.113770f, -92203.460938f, 1290440.250000f, -20644900.000000f, 3.71595e+08f}).reshape(3,3);
INDArray output = Nd4j.create(DataType.FLOAT, expected.shape());
val op = new Polygamma(x,n,output);
Nd4j.exec(op);
assertEquals(expected, output);
}
@Test
public void testRandomCrop() {
INDArray x = Nd4j.createFromArray(new double[]{1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }).reshape(2,2,4);
INDArray shape = Nd4j.createFromArray(new int[] {1,2,3});
val op = new RandomCrop(x, shape);
INDArray[] res = Nd4j.exec(op);
}
@Test
public void testRoll() {
INDArray x = Nd4j.createFromArray(new double[]{ 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42,
21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}).
reshape(2,2,4,2);
INDArray expected = Nd4j.createFromArray(new double[]{ 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42,
12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32,
21.41, 21.42, 22.11, 22.12
}).reshape(x.shape());
val op = new Roll(x, 6);
INDArray[] res = Nd4j.exec(op);
assertEquals(expected, res[0]);
}
@Test
public void testToggleBits() {
INDArray input = Nd4j.createFromArray(new int[]{2,2});
INDArray expected = Nd4j.createFromArray(new int[]{-3,-3});
ToggleBits op = new ToggleBits(input);
val result = Nd4j.exec(op);
assertEquals(expected, result[0]);
}
@Ignore("AS 11.28.2019 - https://github.com/eclipse/deeplearning4j/issues/8449")
@Test
public void testNonMaxSuppression() {
INDArray boxes = Nd4j.createFromArray(new float[] {0.8115f, 0.4121f, 0.0771f, 0.4863f,
0.7412f, 0.7607f, 0.1543f, 0.5479f,
0.8223f, 0.2246f, 0.0049f, 0.6465f}).reshape(3,4);
INDArray scores = Nd4j.createFromArray(new float[]{0.0029f, 0.8135f, 0.4873f});
val op = new NonMaxSuppression(boxes,scores,2,0.5,0.5);
val res = Nd4j.exec(op);
assertEquals(new long[]{1}, res[0].shape());
}
@Test
public void testMatrixBand() {
INDArray input = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,0.2309f,
0.7271f,0.1804f,0.5056f,0.8925f,
0.5461f,0.9234f,0.0856f,0.7938f}).reshape(3,4);
MatrixBandPart op = new MatrixBandPart(input,1,-1);
List<LongShapeDescriptor> lsd = op.calculateOutputShape();
assertEquals(1, lsd.size());
}
@Ignore("Failed AS 11.26.2019 - https://github.com/eclipse/deeplearning4j/issues/8450")
@Test
public void testBetaInc1() {
INDArray a = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f});
INDArray b = Nd4j.createFromArray(new float[]{0.7717f, 0.9281f, 0.9846f, 0.4838f});
INDArray c = Nd4j.createFromArray(new float[]{0.9441f, 0.5957f, 0.8669f, 0.3502f});
BetaInc op = new BetaInc(a,b,c);
INDArray[] ret = Nd4j.exec(op);
INDArray expected = Nd4j.createFromArray(new float[]{0.9122f, 0.6344f, 0.8983f, 0.6245f});
assertEquals(expected, ret[0]);
}
@Ignore("Failure AS 11.28.2019 - https://github.com/eclipse/deeplearning4j/issues/8452")
@Test
public void testPolygamma1() {
INDArray a = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f,
0.7271f, 0.1804f, 0.5056f, 0.8925f,
0.5461f, 0.9234f, 0.0856f, 0.7938f}).reshape(3,4);
INDArray b = Nd4j.createFromArray(new float[]{0.7717f, 0.9281f, 0.9846f, 0.4838f,
0.6433f, 0.6041f, 0.6501f, 0.7612f,
0.7605f, 0.3948f, 0.9493f, 0.8600f}).reshape(3,4);
INDArray expected = Nd4j.createFromArray(new float[]{NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN, }).reshape(3,4);
Polygamma op = new Polygamma(a,b);
INDArray[] ret = Nd4j.exec(op);
assertEquals(expected, ret[0]);
}
@Ignore("Failure AS 11.28.2019 - https://github.com/eclipse/deeplearning4j/issues/8453")
@Test
public void testRoll1() {
INDArray a = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f});
Roll op = new Roll(a,Nd4j.scalar(2),Nd4j.scalar(0));
INDArray[] ret = Nd4j.exec(op);
INDArray expected = Nd4j.createFromArray(new float[]{0.7244f, 0.2309f, 0.7788f, 0.8012f});
assertEquals(expected, ret[0]);
}
@Test
public void testAdjustHueShape(){
INDArray image = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f,
0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, 0.5461f,
0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f,
0.3087f, 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f,
0.1800f, 0.6750f, 0.2246f, 0.0509f, 0.4601f, 0.8284f,
0.2354f, 0.9752f, 0.8361f, 0.2585f, 0.4189f, 0.7028f,
0.7679f, 0.5373f, 0.7234f, 0.2690f, 0.0062f, 0.0327f,
0.0644f, 0.8428f, 0.7494f, 0.0755f, 0.6245f, 0.3491f,
0.5793f, 0.5730f, 0.1822f, 0.6420f, 0.9143f, 0.3019f,
0.3574f, 0.1704f, 0.8395f, 0.5468f, 0.0744f, 0.9011f,
0.6574f, 0.4124f, 0.2445f, 0.4248f, 0.5219f, 0.6952f,
0.4900f, 0.2158f, 0.9549f, 0.1386f, 0.1544f, 0.5365f,
0.0134f, 0.4163f, 0.1456f, 0.4109f, 0.2484f, 0.3330f,
0.2974f, 0.6636f, 0.3808f, 0.8664f, 0.1896f, 0.7530f,
0.7215f, 0.6612f, 0.7270f, 0.5704f, 0.2666f, 0.7453f,
0.0444f, 0.3024f, 0.4850f, 0.7982f, 0.0965f, 0.7843f,
0.5075f, 0.0844f, 0.8370f, 0.6103f, 0.4604f, 0.6087f,
0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, 0.4143f,
0.7821f, 0.3505f, 0.5040f, 0.1180f, 0.8307f, 0.1817f,
0.8442f, 0.5074f, 0.4471f, 0.5105f, 0.6666f, 0.2576f,
0.2341f, 0.6801f, 0.2652f, 0.5394f, 0.4690f, 0.6146f,
0.1210f, 0.2576f, 0.0769f, 0.4643f, 0.1628f, 0.2026f,
0.3774f, 0.0506f, 0.3462f, 0.5720f, 0.0838f, 0.4228f,
0.0588f, 0.5362f, 0.4756f, 0.2530f, 0.1778f, 0.0751f,
0.8977f, 0.3648f, 0.3065f, 0.4739f, 0.7014f, 0.4473f,
0.5171f, 0.1744f, 0.3487f, 0.7759f, 0.9491f, 0.2072f,
0.2182f, 0.6520f, 0.3092f, 0.9545f, 0.1881f, 0.9579f,
0.1785f, 0.9636f, 0.4830f, 0.6569f, 0.3353f, 0.9997f,
0.5869f, 0.5747f, 0.0238f, 0.2943f, 0.5248f, 0.5879f,
0.7266f, 0.1965f, 0.9167f, 0.9726f, 0.9206f, 0.0519f,
0.2997f, 0.0039f, 0.7652f, 0.5498f, 0.3794f, 0.3791f,
0.3528f, 0.2873f, 0.8082f, 0.4732f, 0.4399f, 0.6606f,
0.5991f, 0.0034f, 0.4874f}).reshape(8,8,3);
AdjustHue op = new AdjustHue(image, 0.2f);
INDArray[] res = Nd4j.exec(op);
System.out.println(res[0]);
List<LongShapeDescriptor> lsd = op.calculateOutputShape();
assertEquals(1, lsd.size());
assertArrayEquals(new long[]{8, 8, 3}, lsd.get(0).getShape());
}
@Test @Test
public void testBitCastShape_3(){ public void testBitCastShape_3(){
val x = Nd4j.createFromArray(new int[]{1, 2, 3, 4, 5, 6, 7, 8}).reshape(1, 4, 2); val x = Nd4j.createFromArray(new int[]{1, 2, 3, 4, 5, 6, 7, 8}).reshape(1, 4, 2);