[WIP] Create and small fix (#67)
* - create op - skip exec for empty inputs for non_max_suppression - EmptyHandling idea Signed-off-by: raver119 <raver119@gmail.com> * Create op and mapping for it Signed-off-by: raver119 <raver119@gmail.com>master
parent
dc0036f2c6
commit
83cb0d9329
|
@ -32,6 +32,7 @@
|
||||||
#include <array/ResultSet.h>
|
#include <array/ResultSet.h>
|
||||||
#include <helpers/OpArgsHolder.h>
|
#include <helpers/OpArgsHolder.h>
|
||||||
#include <dll.h>
|
#include <dll.h>
|
||||||
|
#include <ops/declarable/EmptyHandling.h>
|
||||||
//#include <ops/declarable/declarable_ops.h>
|
//#include <ops/declarable/declarable_ops.h>
|
||||||
|
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
|
@ -111,7 +112,7 @@ namespace nd4j {
|
||||||
*/
|
*/
|
||||||
int prepareOutputs(Context& block);
|
int prepareOutputs(Context& block);
|
||||||
|
|
||||||
//std::vector<int>* calculateOutputShape(std::vector<int>* inputShape, nd4j::graph::Block<T>& block);
|
virtual samediff::EmptyHandling emptyHandling();
|
||||||
public:
|
public:
|
||||||
// for special cases, like BooleanOps
|
// for special cases, like BooleanOps
|
||||||
DeclarableOp();
|
DeclarableOp();
|
||||||
|
|
|
@ -0,0 +1,32 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef SAMEDIFF_EMPTYHANDLING_H
|
||||||
|
#define SAMEDIFF_EMPTYHANDLING_H
|
||||||
|
|
||||||
|
namespace samediff {
|
||||||
|
enum EmptyHandling {
|
||||||
|
EMPTY_SKIP = 1,
|
||||||
|
EMPTY_EXCEPTION = 2,
|
||||||
|
EMPTY_EXECUTE = 3
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif //SAMEDIFF_EMPTYHANDLING_H
|
|
@ -37,6 +37,9 @@ namespace nd4j {
|
||||||
else
|
else
|
||||||
REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved.");
|
REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved.");
|
||||||
|
|
||||||
|
if (boxes->isEmpty() || scales->isEmpty())
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
REQUIRE_TRUE(boxes->rankOf() == 2, 0, "image.non_max_suppression: The rank of boxes array should be 2, but %i is given", boxes->rankOf());
|
REQUIRE_TRUE(boxes->rankOf() == 2, 0, "image.non_max_suppression: The rank of boxes array should be 2, but %i is given", boxes->rankOf());
|
||||||
REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array should be 4, but %i is given", boxes->sizeAt(1));
|
REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array should be 4, but %i is given", boxes->sizeAt(1));
|
||||||
REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, "image.non_max_suppression: The rank of scales array should be 1, but %i is given", boxes->rankOf());
|
REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, "image.non_max_suppression: The rank of scales array should be 1, but %i is given", boxes->rankOf());
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_shapes_of)
|
||||||
|
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
CUSTOM_OP_IMPL(create, 1, 1, false, 0, 1) {
|
||||||
|
auto init = block.numB() > 0 ? B_ARG(0) : true;
|
||||||
|
|
||||||
|
if (init)
|
||||||
|
OUTPUT_VARIABLE(0)->nullify();
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_SHAPE_FN(create) {
|
||||||
|
auto shapeInput = INPUT_VARIABLE(0);
|
||||||
|
auto order = (char) INT_ARG(0);
|
||||||
|
auto dtype = DataTypeUtils::fromInt(INT_ARG(1));
|
||||||
|
|
||||||
|
REQUIRE_TRUE(order == 'c' || order == 'f', 0, "create: order must be either c or f");
|
||||||
|
|
||||||
|
auto shape = shapeInput->getBufferAsVector<Nd4jLong>();
|
||||||
|
|
||||||
|
return SHAPELIST(nd4j::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, order, shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(create) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes({ALL_INTS})
|
||||||
|
->setAllowedOutputTypes(nd4j::DataType::ANY);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -99,6 +99,22 @@ namespace nd4j {
|
||||||
#if NOT_EXCLUDED(OP_evaluate_reduction_shape)
|
#if NOT_EXCLUDED(OP_evaluate_reduction_shape)
|
||||||
DECLARE_CUSTOM_OP(evaluate_reduction_shape, 2, 1, false, 0, 0);
|
DECLARE_CUSTOM_OP(evaluate_reduction_shape, 2, 1, false, 0, 0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This operation creates new array
|
||||||
|
* Input:
|
||||||
|
* array with shape values
|
||||||
|
*
|
||||||
|
* IArgs:
|
||||||
|
* order value
|
||||||
|
* data type value
|
||||||
|
*
|
||||||
|
* BArgs:
|
||||||
|
* initialization option
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_create)
|
||||||
|
DECLARE_CUSTOM_OP(create, 1, 1, false, 0, 1);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -933,6 +933,10 @@ namespace nd4j {
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
samediff::EmptyHandling DeclarableOp::emptyHandling() {
|
||||||
|
return samediff::EmptyHandling::EMPTY_SKIP;
|
||||||
|
}
|
||||||
|
|
||||||
void DeclarableOp::registerTypes() {
|
void DeclarableOp::registerTypes() {
|
||||||
this->getOpDescriptor()->setSameMode(true);
|
this->getOpDescriptor()->setSameMode(true);
|
||||||
}
|
}
|
||||||
|
|
|
@ -149,36 +149,7 @@ import org.nd4j.linalg.api.ops.impl.scatter.ScatterMin;
|
||||||
import org.nd4j.linalg.api.ops.impl.scatter.ScatterMul;
|
import org.nd4j.linalg.api.ops.impl.scatter.ScatterMul;
|
||||||
import org.nd4j.linalg.api.ops.impl.scatter.ScatterSub;
|
import org.nd4j.linalg.api.ops.impl.scatter.ScatterSub;
|
||||||
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
|
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Concat;
|
import org.nd4j.linalg.api.ops.impl.shape.*;
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Cross;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Diag;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.DiagPart;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.ExpandDims;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Gather;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.GatherNd;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.MergeAvg;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.MergeMax;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.MeshGrid;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.OneHot;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.OnesLike;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.ParallelStack;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Permute;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Rank;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.ReductionShape;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Repeat;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Reshape;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.SequenceMask;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Size;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.SizeAt;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Slice;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Squeeze;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Stack;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.StridedSlice;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Tile;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Transpose;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Unstack;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.ZerosLike;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp;
|
import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp;
|
import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp;
|
import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp;
|
||||||
|
@ -339,6 +310,15 @@ public class DifferentialFunctionFactory {
|
||||||
return new ZerosLike(name, sameDiff(), input).outputVariable();
|
return new ZerosLike(name, sameDiff(), input).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public SDVariable create(String name, SDVariable shape, boolean initialize, DataType dataType) {
|
||||||
|
return create(name, shape, 'c', initialize, dataType);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SDVariable create(String name, SDVariable shape, char order, boolean initialize, DataType dataType) {
|
||||||
|
validateDifferentialFunctionsameDiff(shape);
|
||||||
|
return new Create(name, sameDiff(), shape, order, initialize, dataType).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
public SDVariable onesLike(String name, SDVariable input, DataType dataType) {
|
public SDVariable onesLike(String name, SDVariable input, DataType dataType) {
|
||||||
validateDifferentialFunctionsameDiff(input);
|
validateDifferentialFunctionsameDiff(input);
|
||||||
return new OnesLike(name, sameDiff(), input, dataType).outputVariable();
|
return new OnesLike(name, sameDiff(), input, dataType).outputVariable();
|
||||||
|
|
|
@ -65,6 +65,7 @@ public class ImportClassMapping {
|
||||||
org.nd4j.linalg.api.ops.impl.broadcast.BroadcastRSubOp.class,
|
org.nd4j.linalg.api.ops.impl.broadcast.BroadcastRSubOp.class,
|
||||||
org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp.class,
|
org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp.class,
|
||||||
org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo.class,
|
org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.shape.Create.class,
|
||||||
org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo.class,
|
org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo.class,
|
||||||
org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan.class,
|
org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan.class,
|
||||||
org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThanOrEqual.class,
|
org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThanOrEqual.class,
|
||||||
|
|
|
@ -0,0 +1,138 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.tensorflow.framework.AttrValue;
|
||||||
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
import org.tensorflow.framework.NodeDef;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This operation creates a new, optionally nullified, array with a given shape, order and data type
|
||||||
|
*
|
||||||
|
* @author raver119@gmail.com
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class Create extends DynamicCustomOp {
|
||||||
|
|
||||||
|
protected boolean initialize = false;
|
||||||
|
protected char order = 'c';
|
||||||
|
protected DataType outputType = DataType.FLOAT; //Allow customizing dtype for TF import
|
||||||
|
|
||||||
|
public Create() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public Create(String name, SameDiff sameDiff, SDVariable input, boolean initialize) {
|
||||||
|
this(name, sameDiff, input, 'c', initialize, input.dataType());
|
||||||
|
}
|
||||||
|
|
||||||
|
public Create(String name, SameDiff sameDiff, SDVariable input, char order, boolean initialize, DataType dataType) {
|
||||||
|
super(name, sameDiff, new SDVariable[]{input}, false);
|
||||||
|
this.outputType = dataType;
|
||||||
|
this.initialize = initialize;
|
||||||
|
this.order = order;
|
||||||
|
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
public Create(INDArray shape, DataType dataType) {
|
||||||
|
this(shape, 'c', false, dataType);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Create(INDArray shape, boolean initialize, DataType dataType) {
|
||||||
|
this(shape, 'c', initialize, dataType);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Create(@NonNull INDArray shape, char order, boolean initialize, DataType dataType) {
|
||||||
|
super(new INDArray[]{shape}, new INDArray[0]);
|
||||||
|
this.order = order;
|
||||||
|
this.initialize = initialize;
|
||||||
|
this.outputType = dataType;
|
||||||
|
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void addArgs() {
|
||||||
|
addBArgument(initialize);
|
||||||
|
addIArgument((int) order,outputType.toInt());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "create";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String onnxName() {
|
||||||
|
throw new NoOpNameFoundException("No op found for " + opName());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String tensorflowName() {
|
||||||
|
return "Empty";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||||
|
// convert output data type
|
||||||
|
if(attributesForNode.containsKey("dtype")) {
|
||||||
|
outputType = TFGraphMapper.convertType(attributesForNode.get("dtype").getType());
|
||||||
|
}
|
||||||
|
|
||||||
|
// get init field
|
||||||
|
if(attributesForNode.containsKey("init")) {
|
||||||
|
initialize = attributesForNode.get("init").getB();
|
||||||
|
}
|
||||||
|
|
||||||
|
// there's no order in TF, just plain C
|
||||||
|
this.order = 'c';
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||||
|
SDVariable ret = sameDiff.zerosLike(outputVariables()[0]);
|
||||||
|
return Arrays.asList(ret);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||||
|
Preconditions.checkState(dataTypes.size() == 1, "Expected list with exactly 1 datatype for %s, got %s", getClass(), dataTypes);
|
||||||
|
if(outputType != null){
|
||||||
|
return Collections.singletonList(outputType);
|
||||||
|
} else {
|
||||||
|
//Output type is same as input type
|
||||||
|
return dataTypes;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -9632,6 +9632,7 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
// #include <array/ResultSet.h>
|
// #include <array/ResultSet.h>
|
||||||
// #include <helpers/OpArgsHolder.h>
|
// #include <helpers/OpArgsHolder.h>
|
||||||
// #include <dll.h>
|
// #include <dll.h>
|
||||||
|
// #include <ops/declarable/EmptyHandling.h>
|
||||||
//#include <ops/declarable/declarable_ops.h>
|
//#include <ops/declarable/declarable_ops.h>
|
||||||
|
|
||||||
// #include <chrono>
|
// #include <chrono>
|
||||||
|
|
|
@ -11834,6 +11834,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
// #include <array/ResultSet.h>
|
// #include <array/ResultSet.h>
|
||||||
// #include <helpers/OpArgsHolder.h>
|
// #include <helpers/OpArgsHolder.h>
|
||||||
// #include <dll.h>
|
// #include <dll.h>
|
||||||
|
// #include <ops/declarable/EmptyHandling.h>
|
||||||
//#include <ops/declarable/declarable_ops.h>
|
//#include <ops/declarable/declarable_ops.h>
|
||||||
|
|
||||||
// #include <chrono>
|
// #include <chrono>
|
||||||
|
@ -17126,6 +17127,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
|
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -20576,7 +20578,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
}
|
}
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This op make bilinear or nearest neighbor interpolated resize for given tensor
|
* This op make bilinear or nearest neighbor interpolated resize for given tensor
|
||||||
*
|
*
|
||||||
* input array:
|
* input array:
|
||||||
|
@ -20612,7 +20614,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
}
|
}
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This op make bilinear interpolated resize for given tensor
|
* This op make bilinear interpolated resize for given tensor
|
||||||
*
|
*
|
||||||
* input array:
|
* input array:
|
||||||
|
@ -20647,7 +20649,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
}
|
}
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This op make nearest neighbor interpolated resize for given tensor
|
* This op make nearest neighbor interpolated resize for given tensor
|
||||||
*
|
*
|
||||||
* input array:
|
* input array:
|
||||||
|
@ -20659,7 +20661,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
* 1 - new height
|
* 1 - new height
|
||||||
*
|
*
|
||||||
* output array:
|
* output array:
|
||||||
* the 4D-Tensor with calculated backproped dots
|
* the 4D-Tensor with resized image (shape is {batch, newWidth, newHeight, channels})
|
||||||
*
|
*
|
||||||
* CAUTION: either size tensor or a pair of int params should be provided.
|
* CAUTION: either size tensor or a pair of int params should be provided.
|
||||||
*/
|
*/
|
||||||
|
@ -20682,21 +20684,85 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
}
|
}
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This op calculates backprop dot for two tensors along given dimensions
|
* This op make bicubic interpolated resize for given tensor
|
||||||
*
|
*
|
||||||
* input array:
|
* input array:
|
||||||
* x: tensor to calculate dot for
|
* 0 - 4D-Tensor with shape (batch, sizeX, sizeY, channels)
|
||||||
* y: tensor to calculate dot for
|
* 1 - 1D-Tensor with 2 values (newWidth, newHeight)
|
||||||
* z: tensor with gradient output of the FF dot for x and y
|
|
||||||
*
|
|
||||||
* int arguments:
|
|
||||||
* list of integers - dimensions to calculate dot along,
|
|
||||||
* default corresponds to empty list in which case calculation
|
|
||||||
* is performed for all dimensions and scalar is returned.
|
|
||||||
*
|
*
|
||||||
* output array:
|
* output array:
|
||||||
* the tensor with calculated backproped dots
|
* the 4D-Tensor with resized image (shape is {batch, newWidth, newHeight, channels})
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
// #if NOT_EXCLUDED(OP_resize_bicubic)
|
||||||
|
@Namespace("nd4j::ops") public static class resize_bicubic extends DeclarableCustomOp {
|
||||||
|
static { Loader.load(); }
|
||||||
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
|
public resize_bicubic(Pointer p) { super(p); }
|
||||||
|
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||||
|
public resize_bicubic(long size) { super((Pointer)null); allocateArray(size); }
|
||||||
|
private native void allocateArray(long size);
|
||||||
|
@Override public resize_bicubic position(long position) {
|
||||||
|
return (resize_bicubic)super.position(position);
|
||||||
|
}
|
||||||
|
|
||||||
|
public resize_bicubic() { super((Pointer)null); allocate(); }
|
||||||
|
private native void allocate();
|
||||||
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
|
}
|
||||||
|
// #endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This op make interpolated resize for given tensor with given algorithm.
|
||||||
|
* Supported algorithms are bilinear, bicubic, nearest_neighbor.
|
||||||
|
* Need to implement to full compatibility with TF: lanczos5, gaussian, area and mitchellcubic
|
||||||
|
*
|
||||||
|
* input array:
|
||||||
|
* 0 - 4D-Tensor with shape (batch, sizeX, sizeY, channels)
|
||||||
|
* 1 - 1D-Tensor with 2 values (newWidth, newHeight)
|
||||||
|
*
|
||||||
|
* optional int args:
|
||||||
|
* 0 - algorithm - bilinear by default
|
||||||
|
* optional bool args:
|
||||||
|
* 0 - preserve_aspect_ratio - default False
|
||||||
|
* 1 - antialias - default False
|
||||||
|
*
|
||||||
|
* output array:
|
||||||
|
* the 4D-Tensor with resized by given algorithm image (shape is {batch, newWidth, newHeight, channels})
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
// #if NOT_EXCLUDED(OP_image_resize)
|
||||||
|
@Namespace("nd4j::ops") public static class image_resize extends DeclarableCustomOp {
|
||||||
|
static { Loader.load(); }
|
||||||
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
|
public image_resize(Pointer p) { super(p); }
|
||||||
|
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||||
|
public image_resize(long size) { super((Pointer)null); allocateArray(size); }
|
||||||
|
private native void allocateArray(long size);
|
||||||
|
@Override public image_resize position(long position) {
|
||||||
|
return (image_resize)super.position(position);
|
||||||
|
}
|
||||||
|
|
||||||
|
public image_resize() { super((Pointer)null); allocate(); }
|
||||||
|
private native void allocate();
|
||||||
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
|
}
|
||||||
|
// #endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Copy a tensor setting everything outside a central band in each innermost matrix
|
||||||
|
*
|
||||||
|
* input array:
|
||||||
|
* x: given tensor with shape {..., M, N} - as vector (matrix) of matricies MxN
|
||||||
|
*
|
||||||
|
* int arguments:
|
||||||
|
* lower band
|
||||||
|
* upper band
|
||||||
|
*
|
||||||
|
* output array:
|
||||||
|
* matrix with given bands between lower and upper diagonals
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
@ -20736,7 +20802,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
}
|
}
|
||||||
// #endif
|
// #endif
|
||||||
/*
|
|
||||||
|
/**
|
||||||
* image.non_max_suppression op.
|
* image.non_max_suppression op.
|
||||||
* input:
|
* input:
|
||||||
* 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type
|
* 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type
|
||||||
|
@ -21271,6 +21338,36 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
}
|
}
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This operation creates new array
|
||||||
|
* Input:
|
||||||
|
* array with shape values
|
||||||
|
*
|
||||||
|
* IArgs:
|
||||||
|
* order value
|
||||||
|
* data type value
|
||||||
|
*
|
||||||
|
* BArgs:
|
||||||
|
* initialization option
|
||||||
|
*/
|
||||||
|
// #if NOT_EXCLUDED(OP_create)
|
||||||
|
@Namespace("nd4j::ops") public static class create extends DeclarableCustomOp {
|
||||||
|
static { Loader.load(); }
|
||||||
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
|
public create(Pointer p) { super(p); }
|
||||||
|
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||||
|
public create(long size) { super((Pointer)null); allocateArray(size); }
|
||||||
|
private native void allocateArray(long size);
|
||||||
|
@Override public create position(long position) {
|
||||||
|
return (create)super.position(position);
|
||||||
|
}
|
||||||
|
|
||||||
|
public create() { super((Pointer)null); allocate(); }
|
||||||
|
private native void allocate();
|
||||||
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
|
}
|
||||||
|
// #endif
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// #endif
|
// #endif
|
||||||
|
|
|
@ -34,6 +34,7 @@ import org.nd4j.linalg.api.ops.impl.controlflow.Where;
|
||||||
import org.nd4j.linalg.api.ops.impl.image.CropAndResize;
|
import org.nd4j.linalg.api.ops.impl.image.CropAndResize;
|
||||||
import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
|
import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
|
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.shape.Create;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
|
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
|
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp;
|
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp;
|
||||||
|
@ -1085,4 +1086,15 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
assertEquals(1, lsd.size());
|
assertEquals(1, lsd.size());
|
||||||
assertArrayEquals(new long[]{1,10, 2}, lsd.get(0).getShape());
|
assertArrayEquals(new long[]{1,10, 2}, lsd.get(0).getShape());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCreateOp_1() {
|
||||||
|
val shape = Nd4j.createFromArray(new int[] {3, 4, 5});
|
||||||
|
val exp = Nd4j.create(DataType.INT, 3, 4, 5);
|
||||||
|
|
||||||
|
val result = Nd4j.exec(new Create(shape, 'c', true, DataType.INT))[0];
|
||||||
|
|
||||||
|
assertEquals(exp, result);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue