TF import tests - adding missing operations (#65)
* Add and fix mappings. * Intermediate * Added and fixed some mappings * Added op * Missing constructors added. * Added new mappings * SDImage wrappers and minor tweaks. * Added missing constructor * Some corrections * Cleanup * Small fixes * Ops wrappers * Minor fixes. * Max Pooling * MaxPoolWithArgmax * Some fixes * Ignores for failures * Some ops fixed. * Some fixes * Missing package added * Some fixes * Ignored tests fixed. * Some fixes * Merge master * bitcast fix Signed-off-by: raver119 <raver119@gmail.com> * Bitcast fixedmaster
parent
1adc25919c
commit
5e152c0d9a
|
@ -2616,6 +2616,36 @@ public class DifferentialFunctionFactory {
|
|||
return new FakeQuantWithMinMaxVarsPerChannel(sameDiff,x,min,max).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable betainc( SDVariable a, SDVariable b, SDVariable x) {
|
||||
return new BetaInc(sameDiff, a, b, x).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable[] fusedBatchNorm(SDVariable x, SDVariable scale, SDVariable offset,
|
||||
SDVariable dataFormat, SDVariable isTraining) {
|
||||
return new FusedBatchNorm(sameDiff,x,scale,offset,dataFormat,isTraining).outputVariables();
|
||||
}
|
||||
|
||||
public SDVariable matrixBandPart(SDVariable input, SDVariable minLower, SDVariable maxUpper) {
|
||||
return new MatrixBandPart(sameDiff,input,minLower,maxUpper).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable[] maxPoolWithArgmaxs(SDVariable x, Pooling2DConfig pooling2DConfig) {
|
||||
return new MaxPoolWithArgmax(sameDiff, x, pooling2DConfig).outputVariables();
|
||||
}
|
||||
|
||||
public SDVariable polygamma(SDVariable n, SDVariable x) {
|
||||
return new Polygamma(sameDiff, n,x).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable roll(SDVariable input, SDVariable shift) {
|
||||
return new Roll(sameDiff, input, shift).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable toggleBits(SDVariable x) {
|
||||
return new ToggleBits(sameDiff, x).outputVariable();
|
||||
}
|
||||
|
||||
|
||||
public String toString() {
|
||||
return "DifferentialFunctionFactory{methodNames=" + methodNames + "}";
|
||||
}
|
||||
|
|
|
@ -202,4 +202,16 @@ public class SDBitwise extends SDOps {
|
|||
SDVariable ret = f().bitwiseXor(x, y);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Flip bits
|
||||
*
|
||||
* @param name Name of the output variable
|
||||
* @param x input array
|
||||
* @return array after flipping each input bit
|
||||
*/
|
||||
public SDVariable toggleBits(String name, SDVariable x) {
|
||||
SDVariable res = f().toggleBits(x);
|
||||
return updateVariableNameAndReference(res, name);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,10 @@ package org.nd4j.autodiff.samediff.ops;
|
|||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ops.custom.AdjustContrast;
|
||||
import org.nd4j.linalg.api.ops.custom.AdjustHue;
|
||||
import org.nd4j.linalg.api.ops.custom.AdjustSaturation;
|
||||
import org.nd4j.linalg.api.ops.custom.RandomCrop;
|
||||
import org.nd4j.linalg.api.ops.impl.image.CropAndResize;
|
||||
import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches;
|
||||
import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression;
|
||||
|
@ -52,10 +56,67 @@ public class SDImage extends SDOps {
|
|||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Greedily selects a subset of bounding boxes in descending order of score
|
||||
* @param name Might be null. Name for the output variable
|
||||
* @param boxes 2D array of shape [num_boxes,4]
|
||||
* @param scores vector of shape [num_boxes]
|
||||
* @param maxOutSize scalar representing the maximum number of boxes to be selected
|
||||
* @param iouThreshold float - threshold for deciding whether boxes overlap too much with respect to IOU
|
||||
* @param scoreThreshold float - threshold for deciding when to remove boxes based on score
|
||||
* @return vectort of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size
|
||||
*/
|
||||
public SDVariable nonMaxSuppression(String name, @NonNull SDVariable boxes, @NonNull SDVariable scores, @NonNull SDVariable maxOutSize,
|
||||
@NonNull SDVariable iouThreshold, @NonNull SDVariable scoreThreshold){
|
||||
SDVariable out = new NonMaxSuppression(sd, boxes, scores, maxOutSize, iouThreshold, scoreThreshold).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adjusts contrast of RGB or grayscale images.
|
||||
* @param name name for the output variable
|
||||
* @param in images to adjust. 3D shape or higher.
|
||||
* @param factor float multiplier for adjusting contrast.
|
||||
* @return Contrast-adjusted image
|
||||
*/
|
||||
public SDVariable adjustContrast(String name, @NonNull SDVariable in, @NonNull SDVariable factor) {
|
||||
SDVariable out = new AdjustContrast(sd, in, factor).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adjust saturation of RGB images
|
||||
* @param name name for the output variable
|
||||
* @param in RGB image as 3D array
|
||||
* @param factor factor for saturation
|
||||
* @return adjusted image
|
||||
*/
|
||||
public SDVariable adjustSaturation(String name, @NonNull SDVariable in, @NonNull SDVariable factor) {
|
||||
SDVariable out = new AdjustSaturation(sd, in, factor).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adjust hue of RGB image
|
||||
* @param name name for the output variable
|
||||
* @param in RGB image as 3D array
|
||||
* @param delta value to add to hue channel
|
||||
* @return adjusted image
|
||||
*/
|
||||
public SDVariable adjustHue(String name, @NonNull SDVariable in, @NonNull SDVariable delta) {
|
||||
SDVariable out = new AdjustHue(sd, in, delta).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Randomly crops image
|
||||
* @param name name for the output variable
|
||||
* @param input input array
|
||||
* @param shape shape for crop
|
||||
* @return cropped array
|
||||
*/
|
||||
public SDVariable randomCrop(String name, @NonNull SDVariable input, @NonNull SDVariable shape) {
|
||||
SDVariable out = new RandomCrop(sd, input, shape).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2496,5 +2496,57 @@ public class SDMath extends SDOps {
|
|||
return updateVariableNameAndReference(res, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the regularized incomplete beta integral
|
||||
*
|
||||
* @param name Name of the output variable
|
||||
* @param a input array
|
||||
* @param b input array
|
||||
* @param x input array
|
||||
* @return array
|
||||
*/
|
||||
public SDVariable betainc(String name,SDVariable a,SDVariable b,SDVariable x) {
|
||||
SDVariable res = f().betainc(a,b,x);
|
||||
return updateVariableNameAndReference(res, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Copy a tensor setting everything outside a central band in each innermost matrix.
|
||||
*
|
||||
* @param name Name of the output variable
|
||||
* @param input Rank k array
|
||||
* @param minLower Number of subdiagonals to keep.
|
||||
* @param maxUpper Number of superdiagonals to keep.
|
||||
* @return Rank k array of the same shape as input.
|
||||
*/
|
||||
public SDVariable matrixBandPart(String name, SDVariable input, SDVariable minLower, SDVariable maxUpper) {
|
||||
SDVariable res = f().matrixBandPart(input,minLower,maxUpper);
|
||||
return updateVariableNameAndReference(res, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Polygamma function
|
||||
*
|
||||
* @param name Name of the output variable
|
||||
* @param n array
|
||||
* @param x array
|
||||
* @return array
|
||||
*/
|
||||
public SDVariable polygamma(String name, SDVariable n, SDVariable x) {
|
||||
SDVariable res = f().polygamma(n,x);
|
||||
return updateVariableNameAndReference(res, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Rolls the elements of input
|
||||
*
|
||||
* @param name Name of the output variable
|
||||
* @param input array
|
||||
* @param shift number of places to shift elements
|
||||
* @return array
|
||||
*/
|
||||
public SDVariable roll(String name, SDVariable input, SDVariable shift) {
|
||||
SDVariable res = f().roll(input,shift);
|
||||
return updateVariableNameAndReference(res, name);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.autodiff.samediff.ops;
|
|||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.Pad;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
|
@ -1032,4 +1033,35 @@ public class SDNN extends SDOps {
|
|||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Max pooling on the input and outputs both max values and indices
|
||||
*
|
||||
* @param name Name of the output variable
|
||||
* @param x input array
|
||||
* @return output array and argmax array
|
||||
*/
|
||||
public SDVariable[] maxPoolWithArgmax(String[] names, SDVariable x, Pooling2DConfig pooling2DConfig) {
|
||||
SDVariable[] res = f().maxPoolWithArgmaxs(x, pooling2DConfig);
|
||||
return sd.updateVariableNamesAndReferences(res, names);
|
||||
}
|
||||
|
||||
/**
|
||||
* Batch normalization
|
||||
*
|
||||
* @param name Name of the output variable
|
||||
* @param x 4D array
|
||||
* @param scale vector for scaling factor of normalized x
|
||||
* @param offset vector to shift to the normalized x
|
||||
* @param dataFormat integer scalar - data format
|
||||
* @param isTraining boolean scalar - is training mode
|
||||
* @return y: 4D array
|
||||
* batch_mean: vector
|
||||
* batch_var: vector
|
||||
*/
|
||||
public SDVariable[] fusedBatchNorm(String[] names, SDVariable x, SDVariable scale, SDVariable offset,
|
||||
SDVariable dataFormat, SDVariable isTraining) {
|
||||
SDVariable[] res = f().fusedBatchNorm(x,scale,offset,dataFormat,isTraining);
|
||||
return sd.updateVariableNamesAndReferences(res, names);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -46,7 +46,6 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.custom.BarnesEdgeForces.class,
|
||||
org.nd4j.linalg.api.ops.custom.BarnesHutGains.class,
|
||||
org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize.class,
|
||||
org.nd4j.linalg.api.ops.custom.DrawBoundingBoxes.class,
|
||||
org.nd4j.linalg.api.ops.custom.KnnMinDistance.class,
|
||||
org.nd4j.linalg.api.ops.custom.SpTreeCell.class,
|
||||
org.nd4j.linalg.api.ops.custom.Flatten.class,
|
||||
|
@ -122,6 +121,7 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalizationDerivative.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling3D.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative.class,
|
||||
org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3DDerivative.class,
|
||||
|
@ -589,7 +589,17 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.custom.BitCast.class,
|
||||
org.nd4j.linalg.api.ops.custom.CompareAndBitpack.class,
|
||||
org.nd4j.linalg.api.ops.custom.DivideNoNan.class,
|
||||
org.nd4j.linalg.api.ops.custom.FakeQuantWithMinMaxVarsPerChannel.class
|
||||
org.nd4j.linalg.api.ops.custom.DrawBoundingBoxes.class,
|
||||
org.nd4j.linalg.api.ops.custom.FakeQuantWithMinMaxVarsPerChannel.class,
|
||||
org.nd4j.linalg.api.ops.custom.AdjustSaturation.class,
|
||||
org.nd4j.linalg.api.ops.custom.AdjustHue.class,
|
||||
org.nd4j.linalg.api.ops.custom.FusedBatchNorm.class,
|
||||
org.nd4j.linalg.api.ops.custom.BetaInc.class,
|
||||
org.nd4j.linalg.api.ops.custom.MatrixBandPart.class,
|
||||
org.nd4j.linalg.api.ops.custom.Polygamma.class,
|
||||
org.nd4j.linalg.api.ops.custom.RandomCrop.class,
|
||||
org.nd4j.linalg.api.ops.custom.Roll.class,
|
||||
org.nd4j.linalg.api.ops.custom.ToggleBits.class
|
||||
);
|
||||
|
||||
static {
|
||||
|
|
|
@ -1,5 +1,22 @@
|
|||
|
||||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -14,11 +31,11 @@ public class AdjustContrast extends BaseAdjustContrast {
|
|||
|
||||
public AdjustContrast() {super();}
|
||||
|
||||
public AdjustContrast(INDArray in, double factor, INDArray out) {
|
||||
public AdjustContrast(@NonNull INDArray in, double factor, INDArray out) {
|
||||
super(in, factor, out);
|
||||
}
|
||||
|
||||
public AdjustContrast(SameDiff sameDiff, SDVariable in, SDVariable factor) {
|
||||
public AdjustContrast(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable factor) {
|
||||
super(sameDiff,new SDVariable[]{in,factor});
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,21 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -14,11 +30,11 @@ public class AdjustContrastV2 extends BaseAdjustContrast {
|
|||
|
||||
public AdjustContrastV2() {super();}
|
||||
|
||||
public AdjustContrastV2(INDArray in, double factor, INDArray out) {
|
||||
public AdjustContrastV2(@NonNull INDArray in, double factor, INDArray out) {
|
||||
super(in, factor, out);
|
||||
}
|
||||
|
||||
public AdjustContrastV2(SameDiff sameDiff, SDVariable in, SDVariable factor) {
|
||||
public AdjustContrastV2(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable factor) {
|
||||
super( sameDiff,new SDVariable[]{in,factor});
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class AdjustHue extends DynamicCustomOp {
|
||||
public AdjustHue() {}
|
||||
|
||||
public AdjustHue(@NonNull INDArray in, double delta, INDArray out) {
|
||||
this(in, delta);
|
||||
if (out != null) {
|
||||
outputArguments.add(out);
|
||||
}
|
||||
}
|
||||
|
||||
public AdjustHue(@NonNull INDArray in, double delta) {
|
||||
Preconditions.checkArgument(in.rank() >= 3,
|
||||
"AdjustSaturation: op expects rank of input array to be >= 3, but got %s instead", in.rank());
|
||||
Preconditions.checkArgument(-1.0 <= delta && delta <= 1.0, "AdjustHue: parameter delta must be within [-1, 1] interval," +
|
||||
" but got %s instead", delta);
|
||||
inputArguments.add(in);
|
||||
|
||||
addTArgument(delta);
|
||||
}
|
||||
|
||||
public AdjustHue(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable factor) {
|
||||
super(sameDiff,new SDVariable[]{in,factor});
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "adjust_hue";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "AdjustHue";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class AdjustSaturation extends DynamicCustomOp {
|
||||
|
||||
public AdjustSaturation() {}
|
||||
|
||||
public AdjustSaturation(@NonNull INDArray in, double factor, INDArray out) {
|
||||
this(in, factor);
|
||||
if (out != null) {
|
||||
outputArguments.add(out);
|
||||
}
|
||||
}
|
||||
|
||||
public AdjustSaturation(@NonNull INDArray in, double factor) {
|
||||
Preconditions.checkArgument(in.rank() >= 3,
|
||||
"AdjustSaturation: op expects rank of input array to be >= 3, but got %s instead", in.rank());
|
||||
inputArguments.add(in);
|
||||
|
||||
addTArgument(factor);
|
||||
}
|
||||
|
||||
public AdjustSaturation(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable factor) {
|
||||
super(sameDiff, new SDVariable[]{in, factor});
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "adjust_saturation";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "AdjustSaturation";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -1,5 +1,21 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -14,16 +30,16 @@ public abstract class BaseAdjustContrast extends DynamicCustomOp {
|
|||
public BaseAdjustContrast() {
|
||||
}
|
||||
|
||||
public BaseAdjustContrast(INDArray in, double factor, INDArray out) {
|
||||
public BaseAdjustContrast(@NonNull INDArray in, double factor, INDArray out) {
|
||||
Preconditions.checkArgument(in.rank() >= 3,
|
||||
String.format("AdjustContrast: op expects rank of input array to be >= 3, but got %d instead", in.rank()));
|
||||
"AdjustContrast: op expects rank of input array to be >= 3, but got %s instead", in.rank());
|
||||
inputArguments.add(in);
|
||||
outputArguments.add(out);
|
||||
|
||||
addTArgument(factor);
|
||||
}
|
||||
|
||||
public BaseAdjustContrast(SameDiff sameDiff, SDVariable[] vars) {
|
||||
public BaseAdjustContrast(@NonNull SameDiff sameDiff, @NonNull SDVariable[] vars) {
|
||||
super("", sameDiff, vars);
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class BetaInc extends DynamicCustomOp {
|
||||
|
||||
public BetaInc() {}
|
||||
|
||||
public BetaInc(@NonNull INDArray a_input, @NonNull INDArray b_input, @NonNull INDArray x_input,
|
||||
INDArray output) {
|
||||
addInputArgument(a_input, b_input, x_input);
|
||||
if (output != null) {
|
||||
addOutputArgument(output);
|
||||
}
|
||||
}
|
||||
|
||||
public BetaInc(@NonNull INDArray a_input, @NonNull INDArray b_input, @NonNull INDArray x_input) {
|
||||
inputArguments.add(a_input);
|
||||
inputArguments.add(b_input);
|
||||
inputArguments.add(x_input);
|
||||
}
|
||||
|
||||
public BetaInc(@NonNull SameDiff sameDiff, @NonNull SDVariable a, @NonNull SDVariable b, @NonNull SDVariable x) {
|
||||
super(sameDiff, new SDVariable[]{a,b,x});
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "betainc";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "Betainc";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -1,3 +1,18 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.val;
|
||||
|
@ -20,6 +35,8 @@ import java.util.Map;
|
|||
public class BitCast extends DynamicCustomOp {
|
||||
public BitCast() {}
|
||||
|
||||
private DataType dtype;
|
||||
|
||||
public BitCast(INDArray in, DataType dataType, INDArray out) {
|
||||
this(in, dataType.toInt(), out);
|
||||
}
|
||||
|
@ -28,6 +45,8 @@ public class BitCast extends DynamicCustomOp {
|
|||
inputArguments.add(in);
|
||||
outputArguments.add(out);
|
||||
iArguments.add(Long.valueOf(dataType));
|
||||
|
||||
dtype = DataType.fromInt(dataType);
|
||||
}
|
||||
|
||||
public BitCast(INDArray in, DataType dataType) {
|
||||
|
@ -37,6 +56,7 @@ public class BitCast extends DynamicCustomOp {
|
|||
public BitCast(INDArray in, int dataType) {
|
||||
inputArguments.add(in);
|
||||
iArguments.add(Long.valueOf(dataType));
|
||||
dtype = DataType.fromInt(dataType);
|
||||
}
|
||||
|
||||
public BitCast(SameDiff sameDiff, SDVariable in, SDVariable dataType) {
|
||||
|
@ -49,6 +69,8 @@ public class BitCast extends DynamicCustomOp {
|
|||
val t = nodeDef.getAttrOrDefault("type", null);
|
||||
val type = ArrayOptionsHelper.convertToDataType(t.getType());
|
||||
addIArgument(type.toInt());
|
||||
|
||||
dtype = type;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -65,6 +87,6 @@ public class BitCast extends DynamicCustomOp {
|
|||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
return Collections.singletonList(dtype);
|
||||
}
|
||||
}
|
|
@ -1,3 +1,18 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
|
@ -9,9 +24,13 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
public class CompareAndBitpack extends DynamicCustomOp {
|
||||
public CompareAndBitpack() {}
|
||||
|
||||
public CompareAndBitpack(INDArray in, double threshold, INDArray out) {
|
||||
public CompareAndBitpack(INDArray in, double threshold) {
|
||||
inputArguments.add(in);
|
||||
inputArguments.add(Nd4j.scalar(threshold));
|
||||
}
|
||||
|
||||
public CompareAndBitpack(INDArray in, double threshold, INDArray out) {
|
||||
this(in, threshold);
|
||||
outputArguments.add(out);
|
||||
}
|
||||
|
||||
|
|
|
@ -1,3 +1,18 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import org.apache.commons.math3.analysis.function.Divide;
|
||||
|
@ -16,9 +31,13 @@ public class DivideNoNan extends DynamicCustomOp {
|
|||
public DivideNoNan() {
|
||||
}
|
||||
|
||||
public DivideNoNan(INDArray in1, INDArray in2, INDArray out) {
|
||||
public DivideNoNan(INDArray in1, INDArray in2) {
|
||||
inputArguments.add(in1);
|
||||
inputArguments.add(in2);
|
||||
}
|
||||
|
||||
public DivideNoNan(INDArray in1, INDArray in2, INDArray out) {
|
||||
this(in1,in2);
|
||||
outputArguments.add(out);
|
||||
}
|
||||
|
||||
|
|
|
@ -1,3 +1,18 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
|
@ -13,11 +28,15 @@ import java.util.List;
|
|||
public class DrawBoundingBoxes extends DynamicCustomOp {
|
||||
public DrawBoundingBoxes() {}
|
||||
|
||||
public DrawBoundingBoxes(INDArray images, INDArray boxes, INDArray colors,
|
||||
INDArray output) {
|
||||
public DrawBoundingBoxes(INDArray images, INDArray boxes, INDArray colors) {
|
||||
inputArguments.add(images);
|
||||
inputArguments.add(boxes);
|
||||
inputArguments.add(colors);
|
||||
}
|
||||
|
||||
public DrawBoundingBoxes(INDArray images, INDArray boxes, INDArray colors,
|
||||
INDArray output) {
|
||||
this(images, boxes, colors);
|
||||
outputArguments.add(output);
|
||||
}
|
||||
|
||||
|
|
|
@ -1,3 +1,18 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
|
@ -13,14 +28,18 @@ import java.util.List;
|
|||
public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp {
|
||||
public FakeQuantWithMinMaxVarsPerChannel() {}
|
||||
|
||||
public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max,
|
||||
INDArray output) {
|
||||
public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max) {
|
||||
Preconditions.checkArgument(min.isVector() && max.isVector() &&
|
||||
min.length() == max.length(),
|
||||
"FakeQuantWithMinMaxVarsPerChannel: min and max should be 1D tensors with the same length");
|
||||
inputArguments.add(x);
|
||||
inputArguments.add(min);
|
||||
inputArguments.add(max);
|
||||
}
|
||||
|
||||
public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max,
|
||||
INDArray output) {
|
||||
this(x,min,max);
|
||||
outputArguments.add(output);
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class FusedBatchNorm extends DynamicCustomOp {
|
||||
|
||||
public FusedBatchNorm() {}
|
||||
|
||||
public FusedBatchNorm(@NonNull INDArray x, @NonNull INDArray scale, @NonNull INDArray offset,
|
||||
int dataFormat, int isTraining,
|
||||
INDArray yOut, INDArray batchMeanOut, INDArray batchMeanVar) {
|
||||
addInputArgument(x, scale, offset);
|
||||
addIArgument(dataFormat, isTraining);
|
||||
if (yOut != null && batchMeanOut != null && batchMeanVar != null) {
|
||||
addOutputArgument(yOut, batchMeanOut, batchMeanVar);
|
||||
}
|
||||
}
|
||||
|
||||
public FusedBatchNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable scale, @NonNull SDVariable offset,
|
||||
@NonNull SDVariable dataFormat, @NonNull SDVariable isTraining) {
|
||||
super("", sameDiff, new SDVariable[]{x, scale, offset, dataFormat, isTraining});
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "fused_batch_norm";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "FusedBatchNormV2";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class MatrixBandPart extends DynamicCustomOp {
|
||||
|
||||
public MatrixBandPart() {}
|
||||
|
||||
public MatrixBandPart(@NonNull INDArray input, int minLower, int maxUpper) {
|
||||
Preconditions.checkArgument(input.rank() >= 2, "MatrixBandPart: Input rank should be 2 or higher");
|
||||
long N = input.size(-2);
|
||||
long M = input.size(-1);
|
||||
Preconditions.checkArgument(minLower > -N && minLower < N, "MatrixBandPart: lower diagonal count %s should be less than %s",
|
||||
minLower, N);
|
||||
Preconditions.checkArgument(maxUpper > -M && maxUpper < M, "MatrixBandPart: upper diagonal count %s should be less than %s.",
|
||||
maxUpper, M);
|
||||
addInputArgument(input);
|
||||
addIArgument(minLower, maxUpper);
|
||||
}
|
||||
|
||||
public MatrixBandPart(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable minLower, SDVariable maxUpper) {
|
||||
super("", sameDiff, new SDVariable[]{input, minLower, maxUpper});
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "matrix_band_part";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "MatrixBandPart";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class Polygamma extends DynamicCustomOp {
|
||||
|
||||
public Polygamma() {}
|
||||
|
||||
public Polygamma(@NonNull INDArray n, @NonNull INDArray x) {
|
||||
Preconditions.checkArgument(n.shape() != x.shape(),
|
||||
"Polygamma: n and x must have the same shapes");
|
||||
addInputArgument(n,x);
|
||||
}
|
||||
|
||||
public Polygamma(@NonNull INDArray n, @NonNull INDArray x, INDArray output) {
|
||||
this(n,x);
|
||||
if (output != null) {
|
||||
addOutputArgument(output);
|
||||
}
|
||||
}
|
||||
|
||||
public Polygamma(@NonNull SameDiff sameDiff, @NonNull SDVariable n, @NonNull SDVariable x) {
|
||||
super("", sameDiff, new SDVariable[]{n ,x});
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "polygamma";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "Polygamma";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class RandomCrop extends DynamicCustomOp {
|
||||
|
||||
public RandomCrop() {}
|
||||
|
||||
public RandomCrop(@NonNull INDArray input, @NonNull INDArray shape) {
|
||||
Preconditions.checkArgument(shape.isVector(),"RandomCrop:Shape tensor should be a vector");
|
||||
Preconditions.checkArgument(input.rank() == shape.length(), "RandomCrop:The length of the shape vector is not match input rank");
|
||||
addInputArgument(input, shape);
|
||||
}
|
||||
|
||||
public RandomCrop(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable shape) {
|
||||
super("", sameDiff, new SDVariable[]{input, shape});
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "random_crop";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "RandomCrop";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null /*&& inputDataTypes.size() == 4*/,
|
||||
"Expected 4 input datatypes for %s, got %s", getClass(), inputDataTypes);
|
||||
return Collections.singletonList(DataType.FLOAT); //TF import: always returns float32...
|
||||
}
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class Roll extends DynamicCustomOp {
|
||||
|
||||
public Roll() {}
|
||||
|
||||
public Roll(@NonNull INDArray input, @NonNull INDArray axes, @NonNull INDArray shifts) {
|
||||
Preconditions.checkArgument(axes.rank() == shifts.rank(), "Roll: shifts and axes should be the same rank");
|
||||
Preconditions.checkArgument(axes.length() == shifts.length(), "Roll: shifts and axes should be the same length");
|
||||
addInputArgument(input, axes, shifts);
|
||||
}
|
||||
|
||||
public Roll(@NonNull INDArray input, int shift) {
|
||||
addInputArgument(input);
|
||||
addIArgument(shift);
|
||||
}
|
||||
|
||||
public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable shift) {
|
||||
super("", sameDiff, new SDVariable[]{input,shift});
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "roll";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "Roll";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class ToggleBits extends DynamicCustomOp {
|
||||
|
||||
public ToggleBits() {}
|
||||
|
||||
public ToggleBits(@NonNull INDArray input, INDArray output) {
|
||||
this(input);
|
||||
if (output != null) {
|
||||
addOutputArgument(output);
|
||||
}
|
||||
}
|
||||
|
||||
public ToggleBits(@NonNull INDArray input) {
|
||||
addInputArgument(input);
|
||||
}
|
||||
|
||||
public ToggleBits(@NonNull SameDiff sameDiff, @NonNull SDVariable input) {
|
||||
super("", sameDiff, new SDVariable[]{input});
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "toggle_bits";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "Invert";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
|
|||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.Op;
|
||||
|
||||
|
@ -41,6 +42,12 @@ public class NonMaxSuppression extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[]{boxes, scores, maxOutSize, iouThreshold, scoreThreshold}, false);
|
||||
}
|
||||
|
||||
public NonMaxSuppression(INDArray boxes, INDArray scores, int maxOutSize, double iouThreshold, double scoreThreshold) {
|
||||
addInputArgument(boxes,scores);
|
||||
addIArgument(maxOutSize);
|
||||
addTArgument(iouThreshold, scoreThreshold);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx name found for shape " + opName());
|
||||
|
@ -53,7 +60,7 @@ public class NonMaxSuppression extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public String[] tensorflowNames() {
|
||||
return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2"};
|
||||
return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2","NonMaxSuppressionV3","NonMaxSuppressionV4"};
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -0,0 +1,284 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||
|
||||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.util.*;
|
||||
|
||||
@Slf4j
|
||||
@Getter
|
||||
public class MaxPoolWithArgmax extends DynamicCustomOp {
|
||||
|
||||
protected Pooling2DConfig config;
|
||||
protected DataType outputType;
|
||||
|
||||
public MaxPoolWithArgmax() {
|
||||
}
|
||||
|
||||
@Builder(builderMethodName = "sameDiffBuilder")
|
||||
@SuppressWarnings("Used in lombok")
|
||||
public MaxPoolWithArgmax(SameDiff sameDiff, SDVariable input, Pooling2DConfig config) {
|
||||
super(null, sameDiff, new SDVariable[]{input}, false);
|
||||
|
||||
config.setType(Pooling2D.Pooling2DType.MAX);
|
||||
this.config = config;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public MaxPoolWithArgmax(INDArray input, INDArray output,INDArray outArgMax, @NonNull Pooling2DConfig config){
|
||||
super(null, new INDArray[]{input}, new INDArray[]{output, outArgMax});
|
||||
config.setType(Pooling2D.Pooling2DType.MAX);
|
||||
|
||||
this.config = config;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isConfigProperties() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String configFieldName() {
|
||||
return "config";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public Map<String, Object> propertiesForFunction() {
|
||||
if(config == null && iArguments.size() > 0){
|
||||
//Perhaps loaded from FlatBuffers - hence we have IArgs but not Config object
|
||||
config = Pooling2DConfig.builder()
|
||||
.kH(iArguments.get(0))
|
||||
.kW(iArguments.get(1))
|
||||
.sH(iArguments.get(2))
|
||||
.sW(iArguments.get(3))
|
||||
.pH(iArguments.get(4))
|
||||
.pW(iArguments.get(5))
|
||||
.dH(iArguments.get(6))
|
||||
.dW(iArguments.get(7))
|
||||
.isSameMode(iArguments.get(8) == 1)
|
||||
.extra(iArguments.get(9))
|
||||
.isNHWC(iArguments.get(10) == 1)
|
||||
.type(Pooling2D.Pooling2DType.MAX)
|
||||
.build();
|
||||
}
|
||||
return config.toProperties();
|
||||
}
|
||||
|
||||
private void addArgs() {
|
||||
addIArgument(config.getKH(),
|
||||
config.getKW(),
|
||||
config.getSH(),
|
||||
config.getSW(),
|
||||
config.getPH(),
|
||||
config.getPW(),
|
||||
config.getDH(),
|
||||
config.getDW(),
|
||||
ArrayUtil.fromBoolean(config.isSameMode()),
|
||||
(int) config.getExtra(),
|
||||
ArrayUtil.fromBoolean(config.isNHWC())
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
|
||||
public String getPoolingPrefix() {
|
||||
return "max";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||
List<SDVariable> ret = new ArrayList<>();
|
||||
List<SDVariable> inputs = new ArrayList<>();
|
||||
inputs.addAll(Arrays.asList(args()));
|
||||
inputs.add(f1.get(0));
|
||||
Pooling2DDerivative pooling2DDerivative = Pooling2DDerivative.derivativeBuilder()
|
||||
.inputs(inputs.toArray(new SDVariable[inputs.size()]))
|
||||
.sameDiff(sameDiff)
|
||||
.config(config)
|
||||
.build();
|
||||
ret.addAll(Arrays.asList(pooling2DDerivative.outputVariables()));
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||
val aStrides = nodeDef.getAttrOrThrow("strides");
|
||||
val tfStrides = aStrides.getList().getIList();
|
||||
|
||||
val aKernels = nodeDef.getAttrOrThrow("ksize");
|
||||
val tfKernels = aKernels.getList().getIList();
|
||||
|
||||
int sH = 0;
|
||||
int sW = 0;
|
||||
|
||||
int pH = 0;
|
||||
int pW = 0;
|
||||
|
||||
int kH = 0;
|
||||
int kW = 0;
|
||||
|
||||
val aPadding = nodeDef.getAttrOrThrow("padding");
|
||||
val padding = aPadding.getList().getIList();
|
||||
|
||||
val paddingMode = aPadding.getS().toStringUtf8().replaceAll("\"", "");
|
||||
|
||||
boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");
|
||||
|
||||
String data_format = "nhwc";
|
||||
if (nodeDef.containsAttr("data_format")) {
|
||||
val attr = nodeDef.getAttrOrThrow("data_format");
|
||||
|
||||
data_format = attr.getS().toStringUtf8().toLowerCase();
|
||||
}
|
||||
|
||||
if (data_format.equalsIgnoreCase("nhwc")) {
|
||||
sH = tfStrides.get(1).intValue();
|
||||
sW = tfStrides.get(2).intValue();
|
||||
|
||||
kH = tfKernels.get(1).intValue();
|
||||
kW = tfKernels.get(2).intValue();
|
||||
|
||||
pH = padding.size() > 0 ? padding.get(1).intValue() : 0;
|
||||
pW = padding.size() > 0 ? padding.get(2).intValue() : 0;
|
||||
} else {
|
||||
sH = tfStrides.get(2).intValue();
|
||||
sW = tfStrides.get(3).intValue();
|
||||
|
||||
kH = tfKernels.get(2).intValue();
|
||||
kW = tfKernels.get(3).intValue();
|
||||
|
||||
pH = padding.size() > 0 ? padding.get(2).intValue() : 0;
|
||||
pW = padding.size() > 0 ? padding.get(3).intValue() : 0;
|
||||
}
|
||||
|
||||
Pooling2DConfig pooling2DConfig = Pooling2DConfig.builder()
|
||||
.sH(sH)
|
||||
.sW(sW)
|
||||
.type(Pooling2D.Pooling2DType.MAX)
|
||||
.isSameMode(isSameMode)
|
||||
.kH(kH)
|
||||
.kW(kW)
|
||||
.pH(pH)
|
||||
.pW(pW)
|
||||
.isNHWC(data_format.equalsIgnoreCase("nhwc"))
|
||||
.extra(1.0) // averaging only for non-padded values
|
||||
.build();
|
||||
this.config = pooling2DConfig;
|
||||
addArgs();
|
||||
if(attributesForNode.containsKey("argmax")) {
|
||||
outputType = TFGraphMapper.convertType(attributesForNode.get("argmax").getType());
|
||||
} else {
|
||||
outputType = DataType.UINT32;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
|
||||
Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();
|
||||
Map<String, PropertyMapping> map = new HashMap<>();
|
||||
val strideMapping = PropertyMapping.builder()
|
||||
.tfAttrName("strides")
|
||||
.onnxAttrName("strides")
|
||||
.propertyNames(new String[]{"sW", "sH"})
|
||||
.build();
|
||||
|
||||
val paddingMapping = PropertyMapping.builder()
|
||||
.onnxAttrName("padding")
|
||||
.tfAttrName("padding")
|
||||
.propertyNames(new String[]{"pH", "pW"})
|
||||
.build();
|
||||
|
||||
val kernelMapping = PropertyMapping.builder()
|
||||
.propertyNames(new String[]{"kH", "kW"})
|
||||
.tfInputPosition(1)
|
||||
.onnxAttrName("ksize")
|
||||
.build();
|
||||
|
||||
val dilationMapping = PropertyMapping.builder()
|
||||
.onnxAttrName("dilations")
|
||||
.propertyNames(new String[]{"dW", "dH"})
|
||||
.tfAttrName("rates")
|
||||
.build();
|
||||
|
||||
|
||||
//data_format
|
||||
val dataFormatMapping = PropertyMapping.builder()
|
||||
.propertyNames(new String[]{"isNHWC"})
|
||||
.tfAttrName("data_format")
|
||||
.build();
|
||||
|
||||
map.put("sW", strideMapping);
|
||||
map.put("sH", strideMapping);
|
||||
map.put("kH", kernelMapping);
|
||||
map.put("kW", kernelMapping);
|
||||
map.put("dW", dilationMapping);
|
||||
map.put("dH", dilationMapping);
|
||||
map.put("pH", paddingMapping);
|
||||
map.put("pW", paddingMapping);
|
||||
map.put("isNHWC", dataFormatMapping);
|
||||
|
||||
ret.put(onnxName(), map);
|
||||
ret.put(tensorflowName(), map);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "max_pool_with_argmax";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
return "MaxPoolWithArgmax";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "MaxPoolWithArgmax";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes);
|
||||
List<DataType> result = new ArrayList<>();
|
||||
result.add(inputDataTypes.get(0));
|
||||
result.add(outputType == null ? DataType.UINT32 : outputType);
|
||||
return result;
|
||||
}
|
||||
}
|
|
@ -293,8 +293,8 @@ public class MaxPooling2D extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "MaxPool";
|
||||
public String[] tensorflowNames() {
|
||||
return new String[]{"MaxPool","MaxPoolV2"};
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -68,7 +68,7 @@ public class ClipByValue extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "clipbyvalue";
|
||||
return "ClipByValue";
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -53,15 +53,9 @@ public class RShiftBits extends BaseDynamicTransformOp {
|
|||
return "rshift_bits";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
|
||||
return "RightShift";
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -53,15 +53,9 @@ public class ShiftBits extends BaseDynamicTransformOp {
|
|||
return "shift_bits";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
throw new NoOpNameFoundException("No TensorFlow op opName found for " + opName());
|
||||
return "LeftShift";
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -46,8 +46,8 @@ public class UniqueWithCounts extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "UniqueWithCounts";
|
||||
public String[] tensorflowNames() {
|
||||
return new String[]{"UniqueWithCounts","UniqueWithCountsV2"};
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -77,8 +77,8 @@ public class CopyOp extends BaseTransformSameOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
|
||||
public String[] tensorflowNames() {
|
||||
return new String[]{"Copy","DeepCopy","CopyHost"};
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -57,7 +57,7 @@ public class ModOp extends BaseDynamicTransformOp {
|
|||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "mod";
|
||||
return "Mod";
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -67,12 +67,6 @@ public class Not extends BaseTransformBoolOp {
|
|||
return "Not";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
throw new NoOpNameFoundException("Tensorflow name not found for " + opName());
|
||||
//return "Not";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||
return Collections.singletonList(f().zerosLike(arg()));
|
||||
|
|
|
@ -59,19 +59,6 @@ public class GELU extends BaseTransformStrictOp {
|
|||
return "gelu";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName()
|
||||
{
|
||||
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
|
||||
//return "GELU";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||
SDVariable ret = f().geluDerivative(arg(), false).mul(i_v.get(0));
|
||||
|
|
|
@ -24,6 +24,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
|||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
@ -71,7 +72,12 @@ public class DistributionUniform extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||
AttrValue v = attributesForNode.get("dtype");
|
||||
AttrValue vDtype = attributesForNode.get("dtype");
|
||||
AttrValue vTout = attributesForNode.get("Tout");
|
||||
if (vDtype == null && vTout == null) {
|
||||
throw new ND4JIllegalStateException("Unable to find output data type for node " + nodeDef.getName());
|
||||
}
|
||||
AttrValue v = vDtype == null ? vTout : vDtype;
|
||||
dataType = TFGraphMapper.convertType(v.getType());
|
||||
addIArgument(dataType.toInt());
|
||||
addTArgument(0.0, 1.0); //TF version is hardcoded 0 to 1
|
||||
|
@ -92,8 +98,8 @@ public class DistributionUniform extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "RandomUniform";
|
||||
public String[] tensorflowNames() {
|
||||
return new String[]{"RandomUniform","RandomUniformInt"};
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -103,7 +109,7 @@ public class DistributionUniform extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes);
|
||||
Preconditions.checkState(inputDataTypes != null /*&& inputDataTypes.size() == 1*/, "Expected input datatypes for %s, got %s", getClass(), inputDataTypes);
|
||||
//Input data type specifies the shape
|
||||
if(dataType != null){
|
||||
return Collections.singletonList(dataType);
|
||||
|
|
|
@ -65,18 +65,6 @@ public class DropOut extends BaseRandomOp {
|
|||
return "dropout";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
throw new NoOpNameFoundException("No tensorflow op name found for: " + getClass().getName());
|
||||
//return opName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Type opType() {
|
||||
return Type.RANDOM ;
|
||||
|
|
|
@ -736,6 +736,35 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
// sd.execBackwards(); // TODO: test failing here
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMaxPoolingArgMax() {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
int nIn = 3;
|
||||
int kH = 2;
|
||||
int kW = 2;
|
||||
|
||||
int mb = 3;
|
||||
int imgH = 8;
|
||||
int imgW = 8;
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
INDArray inArr = Nd4j.rand(new int[]{mb, nIn, imgH, imgW});
|
||||
|
||||
SDVariable in = sd.var("in", inArr);
|
||||
|
||||
Pooling2DConfig pooling2DConfig = Pooling2DConfig.builder()
|
||||
.kH(kH).kW(kW)
|
||||
.pH(0).pW(0)
|
||||
.sH(1).sW(1)
|
||||
.dH(1).dW(1)
|
||||
.isSameMode(true)
|
||||
.build();
|
||||
|
||||
SDVariable[] results = sd.nn().maxPoolWithArgmax(new String[]{"",""}, in, pooling2DConfig);
|
||||
assertArrayEquals(inArr.shape(), results[0].eval().shape());
|
||||
assertArrayEquals(inArr.shape(), results[1].eval().shape());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMaxPooling2dBasic() {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
|
|
@ -76,8 +76,6 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
|||
"adjust_contrast/.*",
|
||||
//Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965
|
||||
"bincount/.*",
|
||||
// Failing 2019/11/15 https://github.com/eclipse/deeplearning4j/issues/8400
|
||||
"bitcast/.*",
|
||||
// Failing 2019/11/14 https://github.com/eclipse/deeplearning4j/issues/8393
|
||||
"is_strictly_increasing/emptyArrayTest/.*",
|
||||
|
||||
|
@ -116,20 +114,32 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
|||
// 2019/11/15 - missing dtype argument in nd4j, tests are useless https://github.com/eclipse/deeplearning4j/issues/8398
|
||||
"zeros_like/rank2_float32_dtype_int.*",
|
||||
|
||||
// 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8399
|
||||
"crop_and_resize.*",
|
||||
|
||||
// 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8401
|
||||
"draw_bounding_boxes.*",
|
||||
|
||||
// 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8402
|
||||
"fake_quant/min_max_args_per_channel.*",
|
||||
|
||||
// 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8403
|
||||
"resize_bilinear/int32.*",
|
||||
|
||||
// Suggesting TF 1.15 bug - see https://github.com/eclipse/deeplearning4j/issues/8449
|
||||
"non_max_suppression_v2/float16.*"
|
||||
// Suggesting TF 1.15 bug
|
||||
"non_max_suppression_v2/float16.*",
|
||||
|
||||
// 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8450
|
||||
"betainc.*",
|
||||
|
||||
// 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8452
|
||||
"polygamma.*",
|
||||
|
||||
// 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8453
|
||||
"roll/.*",
|
||||
|
||||
// 11.26.2019 failing https://github.com/eclipse/deeplearning4j/issues/8455
|
||||
"matrix_band_part/.*",
|
||||
|
||||
// 11.28.2019 failing https://github.com/eclipse/deeplearning4j/issues/8458
|
||||
"adjust_hue/.*",
|
||||
|
||||
// 11.28.2019 failing https://github.com/eclipse/deeplearning4j/issues/8459
|
||||
"adjust_saturation/.*"
|
||||
};
|
||||
|
||||
/* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have
|
||||
|
|
|
@ -32,7 +32,10 @@ import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
|||
import org.nd4j.linalg.api.ops.executioner.OpStatus;
|
||||
import org.nd4j.linalg.api.ops.impl.controlflow.Where;
|
||||
import org.nd4j.linalg.api.ops.impl.image.CropAndResize;
|
||||
import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression;
|
||||
import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.Create;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
|
||||
|
@ -53,6 +56,7 @@ import java.util.ArrayList;
|
|||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static java.lang.Float.NaN;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
|
@ -867,6 +871,26 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
assertArrayEquals(new long[]{1,10}, lsd.get(0).getShape());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAdjustSaturation() {
|
||||
INDArray in = Nd4j.createFromArray(new double[]{50,100,78, 118.5,220,112.5,190,163.5,230, 255,128.5,134}).reshape(2,2,3);
|
||||
INDArray out = Nd4j.create(in.shape());
|
||||
INDArray expected = Nd4j.createFromArray(new double[]{0,100,56, 17,220,5, 150,97,230, 255,2,13}).reshape(2,2,3);
|
||||
|
||||
Nd4j.exec(new AdjustSaturation(in, 2.0, out));
|
||||
assertEquals(expected, out);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAdjustHue() {
|
||||
INDArray in = Nd4j.createFromArray(new double[]{0,100,56, 17,220,5, 150,97,230, 255,2,13}).reshape(2,2,3);
|
||||
INDArray out = Nd4j.create(in.shape());
|
||||
INDArray expected = Nd4j.createFromArray(new double[]{100,0,44, 208,5,220, 177,230,97, 2,255,244}).reshape(2,2,3);
|
||||
|
||||
Nd4j.exec(new AdjustHue(in, 0.5, out));
|
||||
assertEquals(expected, out);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBitCast() {
|
||||
INDArray in = Nd4j.linspace(DataType.FLOAT, 1.0f, 1.0f, 8).reshape(2,2,2);
|
||||
|
@ -1088,6 +1112,216 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
assertArrayEquals(new long[]{1,10, 2}, lsd.get(0).getShape());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBetaInc() {
|
||||
Nd4j.getRandom().setSeed(10);
|
||||
INDArray a = Nd4j.linspace(DataType.BFLOAT16, 0.1, 0.1, 9).reshape(3,3);
|
||||
INDArray b = Nd4j.linspace(DataType.BFLOAT16, 0.1, 0.1, 9).reshape(3,3);
|
||||
INDArray x = Nd4j.linspace(DataType.BFLOAT16, 0.1, 0.1, 9).reshape(3,3);
|
||||
INDArray expected = Nd4j.createFromArray(new float[]{0.4121f, 0.3926f, 0.4082f,
|
||||
0.4414f, 0.5000f, 0.5703f,
|
||||
0.6562f, 0.7656f, 0.8828f}).reshape(3,3);
|
||||
|
||||
BetaInc op = new BetaInc(a,b,x);
|
||||
INDArray[] out = Nd4j.exec(op);
|
||||
assertArrayEquals(expected.shape(), out[0].shape());
|
||||
for (int i = 0; i < 3; ++i)
|
||||
assertArrayEquals(expected.toDoubleMatrix()[i], out[0].toDoubleMatrix()[i], 1e-4);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFusedBatchNorm() {
|
||||
INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*2*3*4).reshape(2,2,3,4);
|
||||
INDArray scale = Nd4j.create(DataType.DOUBLE, 4);
|
||||
scale.assign(0.5);
|
||||
INDArray offset = Nd4j.create(DataType.DOUBLE, 4);
|
||||
offset.assign(2.0);
|
||||
|
||||
INDArray y = Nd4j.createUninitialized(DataType.DOUBLE, x.shape());
|
||||
INDArray batchMean = Nd4j.create(4);
|
||||
INDArray batchVar = Nd4j.create(4);
|
||||
|
||||
FusedBatchNorm op = new FusedBatchNorm(x,scale,offset,0,1,
|
||||
y, batchMean, batchVar);
|
||||
|
||||
INDArray expectedY = Nd4j.createFromArray(new double[]{1.20337462, 1.20337462, 1.20337462,
|
||||
1.20337462, 1.34821558, 1.34821558, 1.34821558, 1.34821558, 1.49305654, 1.49305654,
|
||||
1.49305654, 1.49305654, 1.63789749, 1.63789749, 1.63789749, 1.63789749, 1.78273857,
|
||||
1.78273857, 1.78273857, 1.78273857, 1.92757952, 1.92757952, 1.92757952, 1.92757952,
|
||||
2.0724206 , 2.0724206 , 2.0724206 , 2.0724206 , 2.21726155, 2.21726155, 2.21726155,
|
||||
2.21726155, 2.36210251, 2.36210251, 2.36210251, 2.36210251, 2.50694346, 2.50694346,
|
||||
2.50694346, 2.50694346, 2.65178442, 2.65178442, 2.65178442, 2.65178442, 2.79662538,
|
||||
2.79662538, 2.79662538, 2.79662538}).reshape(x.shape());
|
||||
INDArray expectedBatchMean = Nd4j.createFromArray(new double[]{23., 24., 25., 26.});
|
||||
INDArray expectedBatchVar = Nd4j.createFromArray(new double[]{208.00001526, 208.00001526, 208.00001526, 208.00001526});
|
||||
Nd4j.exec(op);
|
||||
assertArrayEquals(expectedY.shape(), y.shape());
|
||||
assertArrayEquals(expectedBatchMean.shape(), batchMean.shape());
|
||||
assertArrayEquals(expectedBatchVar.shape(), batchVar.shape());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMatrixBandPart() {
|
||||
INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3);
|
||||
val op = new MatrixBandPart(x,1,1);
|
||||
INDArray expected = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3);
|
||||
/*expected.putScalar(0, 0, 2, 0.);
|
||||
expected.putScalar(1, 0, 2, 0.);
|
||||
expected.putScalar(0, 2, 0, 0.);
|
||||
expected.putScalar(1, 2, 0, 0.);*/
|
||||
|
||||
INDArray[] out = Nd4j.exec(op);
|
||||
assertEquals(expected, x);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPolygamma() {
|
||||
INDArray n = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 9).reshape(3,3);
|
||||
INDArray x = Nd4j.create(DataType.FLOAT, 3,3);
|
||||
x.assign(0.5);
|
||||
INDArray expected = Nd4j.createFromArray(new float[]{4.934802f, -16.828796f, 97.409088f, -771.474243f,
|
||||
7691.113770f, -92203.460938f, 1290440.250000f, -20644900.000000f, 3.71595e+08f}).reshape(3,3);
|
||||
INDArray output = Nd4j.create(DataType.FLOAT, expected.shape());
|
||||
val op = new Polygamma(x,n,output);
|
||||
Nd4j.exec(op);
|
||||
assertEquals(expected, output);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRandomCrop() {
|
||||
INDArray x = Nd4j.createFromArray(new double[]{1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }).reshape(2,2,4);
|
||||
INDArray shape = Nd4j.createFromArray(new int[] {1,2,3});
|
||||
val op = new RandomCrop(x, shape);
|
||||
INDArray[] res = Nd4j.exec(op);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRoll() {
|
||||
INDArray x = Nd4j.createFromArray(new double[]{ 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42,
|
||||
21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}).
|
||||
reshape(2,2,4,2);
|
||||
|
||||
INDArray expected = Nd4j.createFromArray(new double[]{ 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42,
|
||||
12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32,
|
||||
21.41, 21.42, 22.11, 22.12
|
||||
}).reshape(x.shape());
|
||||
val op = new Roll(x, 6);
|
||||
INDArray[] res = Nd4j.exec(op);
|
||||
assertEquals(expected, res[0]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testToggleBits() {
|
||||
INDArray input = Nd4j.createFromArray(new int[]{2,2});
|
||||
INDArray expected = Nd4j.createFromArray(new int[]{-3,-3});
|
||||
ToggleBits op = new ToggleBits(input);
|
||||
val result = Nd4j.exec(op);
|
||||
assertEquals(expected, result[0]);
|
||||
}
|
||||
|
||||
@Ignore("AS 11.28.2019 - https://github.com/eclipse/deeplearning4j/issues/8449")
|
||||
@Test
|
||||
public void testNonMaxSuppression() {
|
||||
INDArray boxes = Nd4j.createFromArray(new float[] {0.8115f, 0.4121f, 0.0771f, 0.4863f,
|
||||
0.7412f, 0.7607f, 0.1543f, 0.5479f,
|
||||
0.8223f, 0.2246f, 0.0049f, 0.6465f}).reshape(3,4);
|
||||
INDArray scores = Nd4j.createFromArray(new float[]{0.0029f, 0.8135f, 0.4873f});
|
||||
val op = new NonMaxSuppression(boxes,scores,2,0.5,0.5);
|
||||
val res = Nd4j.exec(op);
|
||||
assertEquals(new long[]{1}, res[0].shape());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMatrixBand() {
|
||||
INDArray input = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,0.2309f,
|
||||
0.7271f,0.1804f,0.5056f,0.8925f,
|
||||
0.5461f,0.9234f,0.0856f,0.7938f}).reshape(3,4);
|
||||
MatrixBandPart op = new MatrixBandPart(input,1,-1);
|
||||
List<LongShapeDescriptor> lsd = op.calculateOutputShape();
|
||||
assertEquals(1, lsd.size());
|
||||
}
|
||||
|
||||
@Ignore("Failed AS 11.26.2019 - https://github.com/eclipse/deeplearning4j/issues/8450")
|
||||
@Test
|
||||
public void testBetaInc1() {
|
||||
INDArray a = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f});
|
||||
INDArray b = Nd4j.createFromArray(new float[]{0.7717f, 0.9281f, 0.9846f, 0.4838f});
|
||||
INDArray c = Nd4j.createFromArray(new float[]{0.9441f, 0.5957f, 0.8669f, 0.3502f});
|
||||
BetaInc op = new BetaInc(a,b,c);
|
||||
INDArray[] ret = Nd4j.exec(op);
|
||||
INDArray expected = Nd4j.createFromArray(new float[]{0.9122f, 0.6344f, 0.8983f, 0.6245f});
|
||||
assertEquals(expected, ret[0]);
|
||||
}
|
||||
|
||||
@Ignore("Failure AS 11.28.2019 - https://github.com/eclipse/deeplearning4j/issues/8452")
|
||||
@Test
|
||||
public void testPolygamma1() {
|
||||
INDArray a = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f,
|
||||
0.7271f, 0.1804f, 0.5056f, 0.8925f,
|
||||
0.5461f, 0.9234f, 0.0856f, 0.7938f}).reshape(3,4);
|
||||
INDArray b = Nd4j.createFromArray(new float[]{0.7717f, 0.9281f, 0.9846f, 0.4838f,
|
||||
0.6433f, 0.6041f, 0.6501f, 0.7612f,
|
||||
0.7605f, 0.3948f, 0.9493f, 0.8600f}).reshape(3,4);
|
||||
INDArray expected = Nd4j.createFromArray(new float[]{NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN, }).reshape(3,4);
|
||||
Polygamma op = new Polygamma(a,b);
|
||||
INDArray[] ret = Nd4j.exec(op);
|
||||
assertEquals(expected, ret[0]);
|
||||
}
|
||||
|
||||
@Ignore("Failure AS 11.28.2019 - https://github.com/eclipse/deeplearning4j/issues/8453")
|
||||
@Test
|
||||
public void testRoll1() {
|
||||
INDArray a = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f});
|
||||
Roll op = new Roll(a,Nd4j.scalar(2),Nd4j.scalar(0));
|
||||
INDArray[] ret = Nd4j.exec(op);
|
||||
INDArray expected = Nd4j.createFromArray(new float[]{0.7244f, 0.2309f, 0.7788f, 0.8012f});
|
||||
assertEquals(expected, ret[0]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAdjustHueShape(){
|
||||
INDArray image = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f,
|
||||
0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, 0.5461f,
|
||||
0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f,
|
||||
0.3087f, 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f,
|
||||
0.1800f, 0.6750f, 0.2246f, 0.0509f, 0.4601f, 0.8284f,
|
||||
0.2354f, 0.9752f, 0.8361f, 0.2585f, 0.4189f, 0.7028f,
|
||||
0.7679f, 0.5373f, 0.7234f, 0.2690f, 0.0062f, 0.0327f,
|
||||
0.0644f, 0.8428f, 0.7494f, 0.0755f, 0.6245f, 0.3491f,
|
||||
0.5793f, 0.5730f, 0.1822f, 0.6420f, 0.9143f, 0.3019f,
|
||||
0.3574f, 0.1704f, 0.8395f, 0.5468f, 0.0744f, 0.9011f,
|
||||
0.6574f, 0.4124f, 0.2445f, 0.4248f, 0.5219f, 0.6952f,
|
||||
0.4900f, 0.2158f, 0.9549f, 0.1386f, 0.1544f, 0.5365f,
|
||||
0.0134f, 0.4163f, 0.1456f, 0.4109f, 0.2484f, 0.3330f,
|
||||
0.2974f, 0.6636f, 0.3808f, 0.8664f, 0.1896f, 0.7530f,
|
||||
0.7215f, 0.6612f, 0.7270f, 0.5704f, 0.2666f, 0.7453f,
|
||||
0.0444f, 0.3024f, 0.4850f, 0.7982f, 0.0965f, 0.7843f,
|
||||
0.5075f, 0.0844f, 0.8370f, 0.6103f, 0.4604f, 0.6087f,
|
||||
0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, 0.4143f,
|
||||
0.7821f, 0.3505f, 0.5040f, 0.1180f, 0.8307f, 0.1817f,
|
||||
0.8442f, 0.5074f, 0.4471f, 0.5105f, 0.6666f, 0.2576f,
|
||||
0.2341f, 0.6801f, 0.2652f, 0.5394f, 0.4690f, 0.6146f,
|
||||
0.1210f, 0.2576f, 0.0769f, 0.4643f, 0.1628f, 0.2026f,
|
||||
0.3774f, 0.0506f, 0.3462f, 0.5720f, 0.0838f, 0.4228f,
|
||||
0.0588f, 0.5362f, 0.4756f, 0.2530f, 0.1778f, 0.0751f,
|
||||
0.8977f, 0.3648f, 0.3065f, 0.4739f, 0.7014f, 0.4473f,
|
||||
0.5171f, 0.1744f, 0.3487f, 0.7759f, 0.9491f, 0.2072f,
|
||||
0.2182f, 0.6520f, 0.3092f, 0.9545f, 0.1881f, 0.9579f,
|
||||
0.1785f, 0.9636f, 0.4830f, 0.6569f, 0.3353f, 0.9997f,
|
||||
0.5869f, 0.5747f, 0.0238f, 0.2943f, 0.5248f, 0.5879f,
|
||||
0.7266f, 0.1965f, 0.9167f, 0.9726f, 0.9206f, 0.0519f,
|
||||
0.2997f, 0.0039f, 0.7652f, 0.5498f, 0.3794f, 0.3791f,
|
||||
0.3528f, 0.2873f, 0.8082f, 0.4732f, 0.4399f, 0.6606f,
|
||||
0.5991f, 0.0034f, 0.4874f}).reshape(8,8,3);
|
||||
|
||||
AdjustHue op = new AdjustHue(image, 0.2f);
|
||||
INDArray[] res = Nd4j.exec(op);
|
||||
System.out.println(res[0]);
|
||||
List<LongShapeDescriptor> lsd = op.calculateOutputShape();
|
||||
assertEquals(1, lsd.size());
|
||||
assertArrayEquals(new long[]{8, 8, 3}, lsd.get(0).getShape());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBitCastShape_3(){
|
||||
val x = Nd4j.createFromArray(new int[]{1, 2, 3, 4, 5, 6, 7, 8}).reshape(1, 4, 2);
|
||||
|
|
Loading…
Reference in New Issue