SameDiff TF import (#49)
* Added implementation files for image_resize and resize_bicubic ops. * Image resize and image.resize_bicubic ops implementation. Initial revision. * Minor fix * Some TF imports disabled. * Finished with infrastructure development for image.resize_bilinear op and image_resizo op implementation. * Refactored resize methods. * Added processing for Mitchelcubic algorithm. * adjust_contrast * Small fix for TF import expected value loading when variable name starts with the test name Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tests * Tests added. * Removed tf names absent in mapping. * Some fixes. * Small fixes * Minor change * Some failing tests. * Disable failed test * Ignore some tests * Fix import class mapping Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix float property mapping (flatbuffers) Signed-off-by: AlexDBlack <blacka101@gmail.com> * Override equality function for model 'dropout' Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fail tests * Failed tests ignored temporarily. * Minor fixes * Small fix * Conflict resolved * Default implementations of tensorflowName and onnxNamemaster
parent
ce2ef20f96
commit
da1944e8e1
|
@ -0,0 +1,91 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author sgazeos@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_image_resize)
|
||||||
|
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <ops/declarable/helpers/image_resize.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
CUSTOM_OP_IMPL(image_resize, 2, 1, false, 0, 0) {
|
||||||
|
|
||||||
|
auto image = INPUT_VARIABLE(0);
|
||||||
|
auto size = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
int width;
|
||||||
|
int height;
|
||||||
|
bool preserveAspectRatio = false; // - default value
|
||||||
|
bool antialias = false;
|
||||||
|
REQUIRE_TRUE(size->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %lld.", size->lengthOf());
|
||||||
|
width = size->e<int>(0);
|
||||||
|
height = size->e<int>(1);
|
||||||
|
if (block.getBArguments()->size()) {
|
||||||
|
preserveAspectRatio = B_ARG(0);
|
||||||
|
if (block.getBArguments()->size() > 1)
|
||||||
|
antialias = B_ARG(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto method = helpers::ImageResizeMethods::kResizeBilinear;
|
||||||
|
if (block.numI() == 1) {
|
||||||
|
method = (helpers::ImageResizeMethods)INT_ARG(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
return helpers::resizeFunctor(block.launchContext(), image, width, height, method, preserveAspectRatio, antialias, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_SHAPE_FN(image_resize) {
|
||||||
|
auto shapeList = SHAPELIST();
|
||||||
|
auto in = inputShape->at(0);
|
||||||
|
|
||||||
|
Nd4jLong* outputShape;
|
||||||
|
|
||||||
|
int width;
|
||||||
|
int height;
|
||||||
|
auto newImageSize = INPUT_VARIABLE(1);
|
||||||
|
REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
|
||||||
|
REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bilinear: Resize params already given by the second param. Int params are expensive.");
|
||||||
|
width = newImageSize->e<int>(0);
|
||||||
|
height = newImageSize->e<int>(1);
|
||||||
|
|
||||||
|
ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(4), Nd4jLong);
|
||||||
|
outputShape[0] = 4;
|
||||||
|
outputShape[1] = in[1];
|
||||||
|
outputShape[2] = width;
|
||||||
|
outputShape[3] = height;
|
||||||
|
outputShape[4] = in[4];
|
||||||
|
ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in));
|
||||||
|
|
||||||
|
shapeList->push_back(CONSTANT(outputShape));
|
||||||
|
return shapeList;
|
||||||
|
}
|
||||||
|
DECLARE_TYPES(image_resize) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(0, {ALL_FLOATS})
|
||||||
|
->setAllowedInputTypes(1, {ALL_INTS})
|
||||||
|
->setAllowedOutputTypes({ALL_FLOATS});
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,93 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author sgazeos@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_resize_bicubic)
|
||||||
|
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <ops/declarable/helpers/image_resize.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
CUSTOM_OP_IMPL(resize_bicubic, 2, 1, false, 0, 0) {
|
||||||
|
|
||||||
|
auto image = INPUT_VARIABLE(0);
|
||||||
|
auto size = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
int width;
|
||||||
|
int height;
|
||||||
|
bool center = false; // - default value
|
||||||
|
|
||||||
|
REQUIRE_TRUE(size->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %i.", size->lengthOf());
|
||||||
|
REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bilinear: Resize params already given by the second param. Int params are expensive.");
|
||||||
|
width = size->e<int>(0);
|
||||||
|
height = size->e<int>(1);
|
||||||
|
auto method = 1; //kResizeBilinear;
|
||||||
|
if (block.numI() == 1) {
|
||||||
|
method = INT_ARG(0);
|
||||||
|
}
|
||||||
|
auto preserveAspectRatio = false;
|
||||||
|
auto antialias = false;
|
||||||
|
if (block.numB() > 0) {
|
||||||
|
preserveAspectRatio = block.getBArguments()->at(0);
|
||||||
|
if (block.numB()> 1)
|
||||||
|
antialias = block.getBArguments()->at(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
return helpers::resizeBicubicFunctor(block.launchContext(), image, width, height, preserveAspectRatio, antialias, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_SHAPE_FN(resize_bicubic) {
|
||||||
|
auto shapeList = SHAPELIST();
|
||||||
|
auto in = inputShape->at(0);
|
||||||
|
|
||||||
|
Nd4jLong* outputShape;
|
||||||
|
|
||||||
|
int width;
|
||||||
|
int height;
|
||||||
|
auto newImageSize = INPUT_VARIABLE(1);
|
||||||
|
REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
|
||||||
|
REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bilinear: Resize params already given by the second param. Int params are expensive.");
|
||||||
|
width = newImageSize->e<int>(0);
|
||||||
|
height = newImageSize->e<int>(1);
|
||||||
|
|
||||||
|
ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(4), Nd4jLong);
|
||||||
|
outputShape[0] = 4;
|
||||||
|
outputShape[1] = in[1];
|
||||||
|
outputShape[2] = width;
|
||||||
|
outputShape[3] = height;
|
||||||
|
outputShape[4] = in[4];
|
||||||
|
ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in));
|
||||||
|
|
||||||
|
shapeList->push_back(CONSTANT(outputShape));
|
||||||
|
return shapeList;
|
||||||
|
}
|
||||||
|
DECLARE_TYPES(resize_bicubic) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(0, {ALL_FLOATS})
|
||||||
|
->setAllowedInputTypes(1, {ALL_INTS})
|
||||||
|
->setAllowedOutputTypes({ALL_FLOATS});
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -1649,7 +1650,7 @@ namespace nd4j {
|
||||||
* 1 - new height
|
* 1 - new height
|
||||||
*
|
*
|
||||||
* output array:
|
* output array:
|
||||||
* the 4D-Tensor with calculated backproped dots
|
* the 4D-Tensor with resized image (shape is {batch, newWidth, newHeight, channels})
|
||||||
*
|
*
|
||||||
* CAUTION: either size tensor or a pair of int params should be provided.
|
* CAUTION: either size tensor or a pair of int params should be provided.
|
||||||
*/
|
*/
|
||||||
|
@ -1659,20 +1660,56 @@ namespace nd4j {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This op calculates backprop dot for two tensors along given dimensions
|
* This op make bicubic interpolated resize for given tensor
|
||||||
*
|
*
|
||||||
* input array:
|
* input array:
|
||||||
* x: tensor to calculate dot for
|
* 0 - 4D-Tensor with shape (batch, sizeX, sizeY, channels)
|
||||||
* y: tensor to calculate dot for
|
* 1 - 1D-Tensor with 2 values (newWidth, newHeight)
|
||||||
* z: tensor with gradient output of the FF dot for x and y
|
|
||||||
*
|
|
||||||
* int arguments:
|
|
||||||
* list of integers - dimensions to calculate dot along,
|
|
||||||
* default corresponds to empty list in which case calculation
|
|
||||||
* is performed for all dimensions and scalar is returned.
|
|
||||||
*
|
*
|
||||||
* output array:
|
* output array:
|
||||||
* the tensor with calculated backproped dots
|
* the 4D-Tensor with resized image (shape is {batch, newWidth, newHeight, channels})
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_resize_bicubic)
|
||||||
|
DECLARE_CUSTOM_OP(resize_bicubic, 1, 1, false, 0, -2);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This op make interpolated resize for given tensor with given algorithm.
|
||||||
|
* Supported algorithms are bilinear, bicubic, nearest_neighbor.
|
||||||
|
* Need to implement to full compatibility with TF: lanczos5, gaussian, area and mitchellcubic
|
||||||
|
*
|
||||||
|
* input array:
|
||||||
|
* 0 - 4D-Tensor with shape (batch, sizeX, sizeY, channels)
|
||||||
|
* 1 - 1D-Tensor with 2 values (newWidth, newHeight)
|
||||||
|
*
|
||||||
|
* optional int args:
|
||||||
|
* 0 - algorithm - bilinear by default
|
||||||
|
* optional bool args:
|
||||||
|
* 0 - preserve_aspect_ratio - default False
|
||||||
|
* 1 - antialias - default False
|
||||||
|
*
|
||||||
|
* output array:
|
||||||
|
* the 4D-Tensor with resized by given algorithm image (shape is {batch, newWidth, newHeight, channels})
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
#if NOT_EXCLUDED(OP_image_resize)
|
||||||
|
DECLARE_CUSTOM_OP(image_resize, 2, 1, false, 0, 0);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Copy a tensor setting everything outside a central band in each innermost matrix
|
||||||
|
*
|
||||||
|
* input array:
|
||||||
|
* x: given tensor with shape {..., M, N} - as vector (matrix) of matricies MxN
|
||||||
|
*
|
||||||
|
* int arguments:
|
||||||
|
* lower band
|
||||||
|
* upper band
|
||||||
|
*
|
||||||
|
* output array:
|
||||||
|
* matrix with given bands between lower and upper diagonals
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
@ -1684,7 +1721,8 @@ namespace nd4j {
|
||||||
#if NOT_EXCLUDED(OP_Assert)
|
#if NOT_EXCLUDED(OP_Assert)
|
||||||
DECLARE_OP(Assert, 1, 1, false);
|
DECLARE_OP(Assert, 1, 1, false);
|
||||||
#endif
|
#endif
|
||||||
/*
|
|
||||||
|
/**
|
||||||
* image.non_max_suppression op.
|
* image.non_max_suppression op.
|
||||||
* input:
|
* input:
|
||||||
* 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type
|
* 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -334,6 +335,25 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int resizeBicubicFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
|
||||||
|
bool preserveAspectRatio, bool antialias, NDArray* output) {
|
||||||
|
return ND4J_STATUS_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
|
||||||
|
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) {
|
||||||
|
switch (method) {
|
||||||
|
case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, output); break;
|
||||||
|
case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, false, output); break;
|
||||||
|
case kResizeBicubic: return resizeBicubicFunctor(context, image, width, height, preserveAspectRatio, antialias, output); break;
|
||||||
|
case kResizeLanczos5:
|
||||||
|
case kResizeGaussian:
|
||||||
|
case kResizeArea:
|
||||||
|
case kResizeMitchelcubic:
|
||||||
|
throw std::runtime_error("helper::resizeFunctor: Non implemented yet.");
|
||||||
|
}
|
||||||
|
return ND4J_STATUS_OK;
|
||||||
|
}
|
||||||
|
|
||||||
void
|
void
|
||||||
cropAndResizeFunctor(nd4j::LaunchContext * context, NDArray const *images, NDArray const *boxes, NDArray const *indices, NDArray const *cropSize,
|
cropAndResizeFunctor(nd4j::LaunchContext * context, NDArray const *images, NDArray const *boxes, NDArray const *indices, NDArray const *cropSize,
|
||||||
|
|
|
@ -293,7 +293,27 @@ namespace helpers {
|
||||||
BUILD_SINGLE_TEMPLATE(template int resizeNeighborFunctor_, (nd4j::LaunchContext* context, NDArray const* images,
|
BUILD_SINGLE_TEMPLATE(template int resizeNeighborFunctor_, (nd4j::LaunchContext* context, NDArray const* images,
|
||||||
int width, int height, bool center, NDArray* output), LIBND4J_TYPES);
|
int width, int height, bool center, NDArray* output), LIBND4J_TYPES);
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
int resizeBicubicFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
|
||||||
|
bool preserveAspectRatio, bool antialias, NDArray* output) {
|
||||||
|
return ND4J_STATUS_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
|
||||||
|
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) {
|
||||||
|
switch (method) {
|
||||||
|
case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, output); break;
|
||||||
|
case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, true, output); break;
|
||||||
|
case kResizeBicubic: return resizeBicubicFunctor(context, image, width, height, preserveAspectRatio, antialias, output); break;
|
||||||
|
case kResizeLanczos5:
|
||||||
|
case kResizeGaussian:
|
||||||
|
case kResizeArea:
|
||||||
|
case kResizeMitchelcubic:
|
||||||
|
throw std::runtime_error("helper::resizeFunctor: Non implemented yet.");
|
||||||
|
}
|
||||||
|
return ND4J_STATUS_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// --------------------------------------------------------------------------------------------------------------- //
|
// --------------------------------------------------------------------------------------------------------------- //
|
||||||
// Crop and Resize helper implementation
|
// Crop and Resize helper implementation
|
||||||
// --------------------------------------------------------------------------------------------------------------- //
|
// --------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -26,9 +27,27 @@ namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, bool center, NDArray* output);
|
enum ImageResizeMethods {
|
||||||
int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, bool center, NDArray* output);
|
kResizeBilinear = 1,
|
||||||
void cropAndResizeFunctor(nd4j::LaunchContext * context, NDArray const* images, NDArray const* boxes, NDArray const* indices, NDArray const* cropSize, int method, double extrapolationVal, NDArray* crops);
|
kResizeBicubic,
|
||||||
|
kResizeNearest,
|
||||||
|
kResizeGaussian,
|
||||||
|
kResizeLanczos5,
|
||||||
|
kResizeMitchelcubic,
|
||||||
|
kResizeArea
|
||||||
|
};
|
||||||
|
|
||||||
|
int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, bool center,
|
||||||
|
NDArray* output);
|
||||||
|
int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, bool center,
|
||||||
|
NDArray* output);
|
||||||
|
int resizeBicubicFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
|
||||||
|
bool preserveAspectRatio, bool antialias, NDArray* output);
|
||||||
|
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
|
||||||
|
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output);
|
||||||
|
|
||||||
|
void cropAndResizeFunctor(nd4j::LaunchContext * context, NDArray const* images, NDArray const* boxes,
|
||||||
|
NDArray const* indices, NDArray const* cropSize, int method, double extrapolationVal, NDArray* crops);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -255,6 +255,11 @@ public abstract class DifferentialFunction {
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
try {
|
try {
|
||||||
|
//Edge case: we store float fields as doubles, rather than introduce an extra property
|
||||||
|
if(target.getType() == float.class && value instanceof Double){
|
||||||
|
value = ((Double) value).floatValue();
|
||||||
|
}
|
||||||
|
|
||||||
target.set(this,value);
|
target.set(this,value);
|
||||||
} catch (IllegalAccessException e) {
|
} catch (IllegalAccessException e) {
|
||||||
throw new RuntimeException("Error setting property for function " + getClass().getName(), e);
|
throw new RuntimeException("Error setting property for function " + getClass().getName(), e);
|
||||||
|
|
|
@ -479,6 +479,8 @@ public class FlatBuffersMapper {
|
||||||
} else if (v instanceof Number) {
|
} else if (v instanceof Number) {
|
||||||
if (v instanceof Double) {
|
if (v instanceof Double) {
|
||||||
d = new double[]{(Double) v};
|
d = new double[]{(Double) v};
|
||||||
|
} else if (v instanceof Float){
|
||||||
|
d = new double[]{(Float) v};
|
||||||
} else if (v instanceof Integer) {
|
} else if (v instanceof Integer) {
|
||||||
i = new int[]{(Integer) v};
|
i = new int[]{(Integer) v};
|
||||||
} else if (v instanceof Long) {
|
} else if (v instanceof Long) {
|
||||||
|
|
|
@ -46,6 +46,7 @@ public class ImportClassMapping {
|
||||||
org.nd4j.linalg.api.ops.custom.BarnesEdgeForces.class,
|
org.nd4j.linalg.api.ops.custom.BarnesEdgeForces.class,
|
||||||
org.nd4j.linalg.api.ops.custom.BarnesHutGains.class,
|
org.nd4j.linalg.api.ops.custom.BarnesHutGains.class,
|
||||||
org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize.class,
|
org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize.class,
|
||||||
|
org.nd4j.linalg.api.ops.custom.DrawBoundingBoxes.class,
|
||||||
org.nd4j.linalg.api.ops.custom.KnnMinDistance.class,
|
org.nd4j.linalg.api.ops.custom.KnnMinDistance.class,
|
||||||
org.nd4j.linalg.api.ops.custom.SpTreeCell.class,
|
org.nd4j.linalg.api.ops.custom.SpTreeCell.class,
|
||||||
org.nd4j.linalg.api.ops.custom.Flatten.class,
|
org.nd4j.linalg.api.ops.custom.Flatten.class,
|
||||||
|
@ -584,7 +585,6 @@ public class ImportClassMapping {
|
||||||
org.nd4j.linalg.api.ops.custom.BitCast.class,
|
org.nd4j.linalg.api.ops.custom.BitCast.class,
|
||||||
org.nd4j.linalg.api.ops.custom.CompareAndBitpack.class,
|
org.nd4j.linalg.api.ops.custom.CompareAndBitpack.class,
|
||||||
org.nd4j.linalg.api.ops.custom.DivideNoNan.class,
|
org.nd4j.linalg.api.ops.custom.DivideNoNan.class,
|
||||||
org.nd4j.linalg.api.ops.custom.DrawBoundingBoxes.class,
|
|
||||||
org.nd4j.linalg.api.ops.custom.FakeQuantWithMinMaxVarsPerChannel.class
|
org.nd4j.linalg.api.ops.custom.FakeQuantWithMinMaxVarsPerChannel.class
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -267,7 +267,7 @@ public class TFGraphMapper {
|
||||||
https://github.com/eclipse/deeplearning4j/issues/8285
|
https://github.com/eclipse/deeplearning4j/issues/8285
|
||||||
*/
|
*/
|
||||||
DifferentialFunction dfInstance = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(opName);
|
DifferentialFunction dfInstance = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(opName);
|
||||||
Preconditions.checkState(dfInstance != null, "Could not find class for TF Ops: {}", opName);
|
Preconditions.checkState(dfInstance != null, "Could not find class for TF Ops: %s", opName);
|
||||||
|
|
||||||
DifferentialFunction df;
|
DifferentialFunction df;
|
||||||
try {
|
try {
|
||||||
|
|
|
@ -25,6 +25,7 @@ import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -386,4 +387,15 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
|
||||||
y = null;
|
y = null;
|
||||||
z = null;
|
z = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String onnxName() {
|
||||||
|
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String tensorflowName() {
|
||||||
|
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,9 +3,13 @@ package org.nd4j.linalg.api.ops.custom;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
public class AdjustContrast extends BaseAdjustContrast {
|
public class AdjustContrast extends BaseAdjustContrast {
|
||||||
|
|
||||||
public AdjustContrast() {super();}
|
public AdjustContrast() {super();}
|
||||||
|
|
|
@ -3,9 +3,13 @@ package org.nd4j.linalg.api.ops.custom;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
public class AdjustContrastV2 extends BaseAdjustContrast {
|
public class AdjustContrastV2 extends BaseAdjustContrast {
|
||||||
|
|
||||||
public AdjustContrastV2() {super();}
|
public AdjustContrastV2() {super();}
|
||||||
|
@ -25,6 +29,6 @@ public class AdjustContrastV2 extends BaseAdjustContrast {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String tensorflowName() {
|
public String tensorflowName() {
|
||||||
return "AdjustContrastV2";
|
return "AdjustContrastv2";
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -3,9 +3,13 @@ package org.nd4j.linalg.api.ops.custom;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
public abstract class BaseAdjustContrast extends DynamicCustomOp {
|
public abstract class BaseAdjustContrast extends DynamicCustomOp {
|
||||||
public BaseAdjustContrast() {
|
public BaseAdjustContrast() {
|
||||||
}
|
}
|
||||||
|
@ -22,4 +26,11 @@ public abstract class BaseAdjustContrast extends DynamicCustomOp {
|
||||||
public BaseAdjustContrast(SameDiff sameDiff, SDVariable[] vars) {
|
public BaseAdjustContrast(SameDiff sameDiff, SDVariable[] vars) {
|
||||||
super("", sameDiff, vars);
|
super("", sameDiff, vars);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
|
int n = args().length;
|
||||||
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||||
|
return Collections.singletonList(inputDataTypes.get(0));
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -3,16 +3,18 @@ package org.nd4j.linalg.api.ops.custom;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
|
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
import org.tensorflow.framework.NodeDef;
|
import org.tensorflow.framework.NodeDef;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
public class BitCast extends DynamicCustomOp {
|
public class BitCast extends DynamicCustomOp {
|
||||||
|
@ -58,4 +60,11 @@ public class BitCast extends DynamicCustomOp {
|
||||||
public String tensorflowName() {
|
public String tensorflowName() {
|
||||||
return "Bitcast";
|
return "Bitcast";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
|
int n = args().length;
|
||||||
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||||
|
return Collections.singletonList(inputDataTypes.get(0));
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -3,8 +3,14 @@ package org.nd4j.linalg.api.ops.custom;
|
||||||
import org.apache.commons.math3.analysis.function.Divide;
|
import org.apache.commons.math3.analysis.function.Divide;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.api.shape.Shape;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
public class DivideNoNan extends DynamicCustomOp {
|
public class DivideNoNan extends DynamicCustomOp {
|
||||||
public DivideNoNan() {
|
public DivideNoNan() {
|
||||||
|
@ -29,4 +35,12 @@ public class DivideNoNan extends DynamicCustomOp {
|
||||||
public String tensorflowName() {
|
public String tensorflowName() {
|
||||||
return "DivNoNan";
|
return "DivNoNan";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||||
|
Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got input %s", getClass(), dataTypes);
|
||||||
|
|
||||||
|
DataType z = Shape.pickPairwiseDataType(dataTypes.get(0), dataTypes.get(1));
|
||||||
|
return Collections.singletonList(z);
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -2,9 +2,14 @@ package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
public class DrawBoundingBoxes extends DynamicCustomOp {
|
public class DrawBoundingBoxes extends DynamicCustomOp {
|
||||||
public DrawBoundingBoxes() {}
|
public DrawBoundingBoxes() {}
|
||||||
|
|
||||||
|
@ -26,7 +31,14 @@ public class DrawBoundingBoxes extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String tensorflowName() {
|
public String[] tensorflowNames() {
|
||||||
return "DrawBoundingBoxes";
|
return new String[]{"DrawBoundingBoxes", "DrawBoundingBoxesV2"};
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
|
int n = args().length;
|
||||||
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||||
|
return Collections.singletonList(inputDataTypes.get(0));
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -3,9 +3,13 @@ package org.nd4j.linalg.api.ops.custom;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp {
|
public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp {
|
||||||
public FakeQuantWithMinMaxVarsPerChannel() {}
|
public FakeQuantWithMinMaxVarsPerChannel() {}
|
||||||
|
|
||||||
|
@ -33,4 +37,10 @@ public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp {
|
||||||
public String tensorflowName() {
|
public String tensorflowName() {
|
||||||
return "FakeQuantWithMinMaxVarsPerChannel";
|
return "FakeQuantWithMinMaxVarsPerChannel";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 inputs, got %s", inputDataTypes);
|
||||||
|
return Collections.singletonList(inputDataTypes.get(0));
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -21,6 +21,7 @@ import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.BaseBroadcastOp;
|
import org.nd4j.linalg.api.ops.BaseBroadcastOp;
|
||||||
|
@ -65,11 +66,6 @@ public class BiasAddGrad extends DynamicCustomOp {
|
||||||
return "BiasAddGrad";
|
return "BiasAddGrad";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "BiasAddGrad";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected 3 input data types for %s, got %s", getClass(), inputDataTypes);
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected 3 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||||
|
|
|
@ -61,14 +61,4 @@ public class BroadcastAddOp extends BaseBroadcastOp {
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "BroadcastAdd";
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,14 +61,4 @@ public class BroadcastGradientArgs extends BaseBroadcastOp {
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No op name found for " + opName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "BroadcastGradientArgs";
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -51,7 +51,8 @@ public class CropAndResize extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
public CropAndResize(@NonNull INDArray image, @NonNull INDArray cropBoxes, @NonNull INDArray boxIndices,
|
public CropAndResize(@NonNull INDArray image, @NonNull INDArray cropBoxes, @NonNull INDArray boxIndices,
|
||||||
@NonNull INDArray cropOutSize, @NonNull Method method, double extrapolationValue){
|
@NonNull INDArray cropOutSize, @NonNull Method method, double extrapolationValue,
|
||||||
|
INDArray output){
|
||||||
super(new INDArray[]{image, cropBoxes, boxIndices, cropOutSize}, null);
|
super(new INDArray[]{image, cropBoxes, boxIndices, cropOutSize}, null);
|
||||||
Preconditions.checkArgument(image.rank() == 4, "Input image must be rank 4 with shape [batch, height, width, channels], got %ndShape", image);
|
Preconditions.checkArgument(image.rank() == 4, "Input image must be rank 4 with shape [batch, height, width, channels], got %ndShape", image);
|
||||||
Preconditions.checkArgument(cropBoxes.rank() == 2 && cropBoxes.size(1) == 4, "Crop boxes must be rank 4 with shape [num_boxes, 5], got %ndShape", cropBoxes);
|
Preconditions.checkArgument(cropBoxes.rank() == 2 && cropBoxes.size(1) == 4, "Crop boxes must be rank 4 with shape [num_boxes, 5], got %ndShape", cropBoxes);
|
||||||
|
@ -60,6 +61,7 @@ public class CropAndResize extends DynamicCustomOp {
|
||||||
this.method = method;
|
this.method = method;
|
||||||
this.extrapolationValue = extrapolationValue;
|
this.extrapolationValue = extrapolationValue;
|
||||||
addArgs();
|
addArgs();
|
||||||
|
outputArguments.add(output);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -89,8 +91,6 @@ public class CropAndResize extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void addArgs() {
|
protected void addArgs() {
|
||||||
iArguments.clear();
|
|
||||||
tArguments.clear();
|
|
||||||
addIArgument(method == Method.BILINEAR ? 0 : 1);
|
addIArgument(method == Method.BILINEAR ? 0 : 1);
|
||||||
addTArgument(extrapolationValue);
|
addTArgument(extrapolationValue);
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,7 +53,7 @@ public class NonMaxSuppression extends DynamicCustomOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String[] tensorflowNames() {
|
public String[] tensorflowNames() {
|
||||||
return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2"};
|
return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2", "NonMaxSuppressionV3"};
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -30,6 +30,7 @@ import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
@ -231,11 +232,6 @@ public class Pooling2D extends DynamicCustomOp {
|
||||||
return "Pooling";
|
return "Pooling";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "Pooling2D";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes);
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes);
|
||||||
|
|
|
@ -64,11 +64,6 @@ public class SigmoidCrossEntropyLoss extends BaseLoss {
|
||||||
return "sigm_cross_entropy_loss";
|
return "sigm_cross_entropy_loss";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "sigmoid_cross_entropy";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> grad){
|
public List<SDVariable> doDiff(List<SDVariable> grad){
|
||||||
//No external gradient
|
//No external gradient
|
||||||
|
|
|
@ -73,16 +73,6 @@ public class Moments extends DynamicCustomOp {
|
||||||
return "moments";
|
return "moments";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "moments";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> grad){
|
public List<SDVariable> doDiff(List<SDVariable> grad){
|
||||||
SDVariable dLdMean = grad.get(0);
|
SDVariable dLdMean = grad.get(0);
|
||||||
|
|
|
@ -69,16 +69,6 @@ public class NormalizeMoments extends DynamicCustomOp {
|
||||||
return "normalize_moments";
|
return "normalize_moments";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "normalize_moments";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce.bool;
|
||||||
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.BaseReduceBoolOp;
|
import org.nd4j.linalg.api.ops.BaseReduceBoolOp;
|
||||||
|
|
||||||
|
@ -64,12 +65,6 @@ public class IsInf extends BaseReduceBoolOp {
|
||||||
return "HasInf";
|
return "HasInf";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "HasInf";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce.bool;
|
||||||
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.BaseReduceBoolOp;
|
import org.nd4j.linalg.api.ops.BaseReduceBoolOp;
|
||||||
|
|
||||||
|
@ -64,12 +65,6 @@ public class IsNaN extends BaseReduceBoolOp {
|
||||||
return "hasNaNs";
|
return "hasNaNs";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "hasNans";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.reduce.custom;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
@ -89,9 +90,4 @@ public class LogSumExp extends DynamicCustomOp {
|
||||||
public String onnxName() {
|
public String onnxName() {
|
||||||
return "ReduceLogSumExp";
|
return "ReduceLogSumExp";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "reduce_logsumexp";
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,16 +57,6 @@ public class Entropy extends BaseReduceFloatOp {
|
||||||
return "entropy";
|
return "entropy";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "entropy";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Type getOpType() {
|
public Type getOpType() {
|
||||||
return Type.REDUCE_FLOAT;
|
return Type.REDUCE_FLOAT;
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce.floating;
|
||||||
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.BaseReduceFloatOp;
|
import org.nd4j.linalg.api.ops.BaseReduceFloatOp;
|
||||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
|
@ -80,8 +81,4 @@ public class Norm2 extends BaseReduceFloatOp {
|
||||||
return "Norm";
|
return "Norm";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "norm";
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,17 +54,6 @@ public class CountNonZero extends BaseReduceLongOp {
|
||||||
return "countNonZero";
|
return "countNonZero";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No onnx name found for shape " + opName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "count_nonzero";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
return Collections.singletonList(f().zerosLike(arg()));
|
return Collections.singletonList(f().zerosLike(arg()));
|
||||||
|
|
|
@ -92,9 +92,4 @@ public class CosineDistance extends BaseReduce3Op {
|
||||||
List<SDVariable> diff = CosineSimilarity.doDiff(sameDiff, f(), larg(), rarg(), i_v1.get(0), keepDims, dimensions);
|
List<SDVariable> diff = CosineSimilarity.doDiff(sameDiff, f(), larg(), rarg(), i_v1.get(0), keepDims, dimensions);
|
||||||
return Arrays.asList(f().neg(diff.get(0)), f().neg(diff.get(1)));
|
return Arrays.asList(f().neg(diff.get(0)), f().neg(diff.get(1)));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "cosine_distance";
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -72,16 +72,6 @@ public class ScalarAdd extends BaseScalarOp {
|
||||||
return "add_scalar";
|
return "add_scalar";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "RealAdd";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
|
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
|
||||||
SDVariable g = i_v1.get(0);
|
SDVariable g = i_v1.get(0);
|
||||||
|
|
|
@ -58,16 +58,6 @@ public class ScalarMax extends BaseScalarOp {
|
||||||
return "max_scalar";
|
return "max_scalar";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "RealMax";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
|
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
|
||||||
SDVariable mask = arg().gt(scalarValue.getDouble(0)).castTo(arg().dataType());
|
SDVariable mask = arg().gt(scalarValue.getDouble(0)).castTo(arg().dataType());
|
||||||
|
|
|
@ -52,16 +52,6 @@ public class ScalarMin extends BaseScalarOp {
|
||||||
return 13;
|
return 13;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "RealMin";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "scalar_min";
|
return "scalar_min";
|
||||||
|
|
|
@ -66,15 +66,7 @@ public class ScalarMultiplication extends BaseScalarOp {
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "mul_scalar";
|
return "mul_scalar";
|
||||||
}
|
}
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "RealMul";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
|
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
|
||||||
|
|
|
@ -66,19 +66,6 @@ public class ScalarSubtraction extends BaseScalarOp {
|
||||||
return "sub_scalar";
|
return "sub_scalar";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "RealSub";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
|
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
|
||||||
SDVariable g = i_v1.get(0);
|
SDVariable g = i_v1.get(0);
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison;
|
||||||
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
|
import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
|
||||||
|
|
||||||
|
@ -61,12 +62,6 @@ public class ScalarAnd extends BaseScalarBoolOp {
|
||||||
return "AndScalar";
|
return "AndScalar";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "AndScalar";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
//Not continuously differentiable, but 0 gradient in most places
|
//Not continuously differentiable, but 0 gradient in most places
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison;
|
||||||
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
|
import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
|
||||||
import org.nd4j.linalg.api.ops.BaseScalarOp;
|
import org.nd4j.linalg.api.ops.BaseScalarOp;
|
||||||
|
@ -87,11 +88,4 @@ public class ScalarEps extends BaseScalarBoolOp {
|
||||||
public String onnxName() {
|
public String onnxName() {
|
||||||
return "ScalarEps";
|
return "ScalarEps";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "ScalarEps";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -83,15 +83,4 @@ public class ScalarEquals extends BaseScalarBoolOp {
|
||||||
return Arrays.asList(sameDiff.zerosLike(arg()));
|
return Arrays.asList(sameDiff.zerosLike(arg()));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "equal";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -77,16 +77,4 @@ public class ScalarGreaterThan extends BaseScalarBoolOp {
|
||||||
|
|
||||||
return Arrays.asList(sameDiff.zerosLike(arg()));
|
return Arrays.asList(sameDiff.zerosLike(arg()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "greater";
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,18 +70,6 @@ public class ScalarGreaterThanOrEqual extends BaseScalarBoolOp {
|
||||||
return "greaterthanorequal_scalar";
|
return "greaterthanorequal_scalar";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "greater_equal";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
//Not continuously differentiable, but 0 gradient in most places
|
//Not continuously differentiable, but 0 gradient in most places
|
||||||
|
|
|
@ -63,12 +63,6 @@ public class ScalarLessThan extends BaseScalarBoolOp {
|
||||||
return "Less";
|
return "Less";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "less";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
//Not continuously differentiable, but 0 gradient in most places
|
//Not continuously differentiable, but 0 gradient in most places
|
||||||
|
|
|
@ -63,17 +63,6 @@ public class ScalarLessThanOrEqual extends BaseScalarBoolOp {
|
||||||
return "lessthanorequal_scalar";
|
return "lessthanorequal_scalar";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "less_equal";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
//Not continuously differentiable, but 0 gradient in most places
|
//Not continuously differentiable, but 0 gradient in most places
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison;
|
||||||
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
|
import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
|
||||||
|
|
||||||
|
@ -66,12 +67,6 @@ public class ScalarNot extends BaseScalarBoolOp {
|
||||||
return "NotScalar";
|
return "NotScalar";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "Not_Scalar";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
//Not continuously differentiable, but 0 gradient in most places
|
//Not continuously differentiable, but 0 gradient in most places
|
||||||
|
|
|
@ -63,17 +63,6 @@ public class ScalarNotEquals extends BaseScalarBoolOp {
|
||||||
return "notequals_scalar";
|
return "notequals_scalar";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "logical_not";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
//Not continuously differentiable, but 0 gradient in most places
|
//Not continuously differentiable, but 0 gradient in most places
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison;
|
||||||
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
|
import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
|
||||||
|
|
||||||
|
@ -66,11 +67,6 @@ public class ScalarOr extends BaseScalarBoolOp {
|
||||||
return "OrScalar";
|
return "OrScalar";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "Or_Scalar";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison;
|
||||||
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
|
import org.nd4j.linalg.api.ops.BaseScalarBoolOp;
|
||||||
|
|
||||||
|
@ -66,12 +67,6 @@ public class ScalarXor extends BaseScalarBoolOp {
|
||||||
return "Xor_scalar";
|
return "Xor_scalar";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "Xor_scalar";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
//Not continuously differentiable, but 0 gradient in most places
|
//Not continuously differentiable, but 0 gradient in most places
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
@ -52,12 +53,6 @@ public class ApplyGradientDescent extends DynamicCustomOp {
|
||||||
return "ApplyGradientDescent";
|
return "ApplyGradientDescent";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "ApplyGradientDescent";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||||
/*
|
/*
|
||||||
|
|
|
@ -18,14 +18,10 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
@ -112,17 +108,6 @@ public class Eye extends DynamicCustomOp {
|
||||||
addTArgument((double) dataType.toInt());
|
addTArgument((double) dataType.toInt());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "Eye";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "eye";
|
return "eye";
|
||||||
|
|
|
@ -21,10 +21,8 @@ import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
@ -69,16 +67,6 @@ public class MergeAvg extends DynamicCustomOp {
|
||||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "MergeAvg";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||||
int nArgs = args().length;
|
int nArgs = args().length;
|
||||||
|
|
|
@ -22,11 +22,8 @@ import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
import org.tensorflow.framework.NodeDef;
|
import org.tensorflow.framework.NodeDef;
|
||||||
|
@ -68,16 +65,6 @@ public class MergeMax extends DynamicCustomOp {
|
||||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "MergeMax";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||||
SDVariable gradient = sameDiff.setupFunction(i_v.get(0));
|
SDVariable gradient = sameDiff.setupFunction(i_v.get(0));
|
||||||
|
|
|
@ -21,7 +21,6 @@ import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
|
||||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
@ -46,17 +45,6 @@ public class ParallelStack extends DynamicCustomOp {
|
||||||
super(null, sameDiff, values, false);
|
super(null, sameDiff, values, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No onnx opName found for " + opName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "parallel_stack";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "parallel_stack";
|
return "parallel_stack";
|
||||||
|
|
|
@ -21,7 +21,6 @@ import lombok.val;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
|
||||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
@ -115,12 +114,6 @@ public class Repeat extends DynamicCustomOp {
|
||||||
return "Repeat";
|
return "Repeat";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "Repeat";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||||
SDVariable ret = outputVariables()[0];
|
SDVariable ret = outputVariables()[0];
|
||||||
|
|
|
@ -18,11 +18,9 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.OnnxMl;
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
|
||||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
@ -97,17 +95,6 @@ public class SequenceMask extends DynamicCustomOp {
|
||||||
return "sequence_mask";
|
return "sequence_mask";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No onnx opName found for " + opName());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "SequenceMask";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> grad){
|
public List<SDVariable> doDiff(List<SDVariable> grad){
|
||||||
//Input is integer indices
|
//Input is integer indices
|
||||||
|
|
|
@ -18,9 +18,6 @@ package org.nd4j.linalg.api.ops.impl.transforms;
|
||||||
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.api.ops.BaseTransformOp;
|
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
@ -45,16 +42,6 @@ public class Angle extends DynamicCustomOp {
|
||||||
return "zeros_like";
|
return "zeros_like";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() { throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "Angle";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||||
return Collections.singletonList(f().zerosLike(arg()));
|
return Collections.singletonList(f().zerosLike(arg()));
|
||||||
|
|
|
@ -80,7 +80,8 @@ public class MaxOut extends BaseTransformOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String tensorflowName() {
|
public String tensorflowName() {
|
||||||
return "Maxout";
|
throw new NoOpNameFoundException("Tensorflow name not found for " + opName());
|
||||||
|
//return "Maxout";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -20,6 +20,7 @@ import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.BaseTransformBoolOp;
|
import org.nd4j.linalg.api.ops.BaseTransformBoolOp;
|
||||||
|
@ -59,12 +60,12 @@ public class BooleanNot extends BaseTransformBoolOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String onnxName() {
|
public String onnxName() {
|
||||||
return "not_applicable";
|
throw new NoOpNameFoundException("Onnx name not found for " + opName());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String tensorflowName() {
|
public String tensorflowName() {
|
||||||
return "not_applicable";
|
throw new NoOpNameFoundException("Tensorflow name not found for " + opName());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -61,7 +61,7 @@ public class BitwiseAnd extends BaseDynamicTransformOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String tensorflowName() {
|
public String tensorflowName() {
|
||||||
return "bitwise_and";
|
return "BitwiseAnd";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,6 @@ import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
@ -49,12 +48,6 @@ public class IsNumericTensor extends DynamicCustomOp {
|
||||||
return "is_numeric_tensor";
|
return "is_numeric_tensor";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "IsNumericTensor";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
throw new UnsupportedOperationException("");
|
throw new UnsupportedOperationException("");
|
||||||
|
|
|
@ -22,9 +22,7 @@ import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
@ -50,12 +48,6 @@ public class IsStrictlyIncreasing extends DynamicCustomOp {
|
||||||
return "is_strictly_increasing";
|
return "is_strictly_increasing";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "IsStrictlyIncreasing";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
return Collections.singletonList(sameDiff.zerosLike(arg()));
|
return Collections.singletonList(sameDiff.zerosLike(arg()));
|
||||||
|
|
|
@ -39,12 +39,6 @@ public class LogicalXor extends DynamicCustomOp {
|
||||||
return "boolean_xor";
|
return "boolean_xor";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "LogicalXor";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
return Arrays.asList( sameDiff.zerosLike(larg()), sameDiff.zerosLike(rarg()));
|
return Arrays.asList( sameDiff.zerosLike(larg()), sameDiff.zerosLike(rarg()));
|
||||||
|
|
|
@ -62,11 +62,6 @@ public class SigmoidDerivative extends DynamicCustomOp {
|
||||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "SigmoidGrad";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
|
|
|
@ -20,6 +20,7 @@ import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.BaseTransformBoolOp;
|
import org.nd4j.linalg.api.ops.BaseTransformBoolOp;
|
||||||
import org.nd4j.linalg.api.ops.BaseTransformOp;
|
import org.nd4j.linalg.api.ops.BaseTransformOp;
|
||||||
|
@ -68,7 +69,8 @@ public class Not extends BaseTransformBoolOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String tensorflowName() {
|
public String tensorflowName() {
|
||||||
return "Not";
|
throw new NoOpNameFoundException("Tensorflow name not found for " + opName());
|
||||||
|
//return "Not";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -18,11 +18,14 @@ package org.nd4j.linalg.api.ops.impl.transforms.same;
|
||||||
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.BaseTransformOp;
|
import org.nd4j.linalg.api.ops.BaseTransformOp;
|
||||||
import org.nd4j.linalg.api.ops.BaseTransformSameOp;
|
import org.nd4j.linalg.api.ops.BaseTransformSameOp;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -73,4 +76,11 @@ public class Abs extends BaseTransformSameOp {
|
||||||
SDVariable ret = f().sign(arg()).mul(i_v.get(0));
|
SDVariable ret = f().sign(arg()).mul(i_v.get(0));
|
||||||
return Arrays.asList(ret);
|
return Arrays.asList(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
|
int n = args().length;
|
||||||
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||||
|
return Collections.singletonList(inputDataTypes.get(0));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -65,8 +65,10 @@ public class GELU extends BaseTransformStrictOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String tensorflowName() {
|
public String tensorflowName()
|
||||||
return "GELU";
|
{
|
||||||
|
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
|
||||||
|
//return "GELU";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -65,7 +65,8 @@ public class PreciseGELU extends BaseTransformStrictOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String tensorflowName() {
|
public String tensorflowName() {
|
||||||
return "PreciseGELU";
|
throw new NoOpNameFoundException("Tensorflow name not found for " + opName());
|
||||||
|
//return "PreciseGELU";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -73,7 +73,8 @@ public class DropOut extends BaseRandomOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String tensorflowName() {
|
public String tensorflowName() {
|
||||||
return opName();
|
throw new NoOpNameFoundException("No tensorflow op name found for: " + getClass().getName());
|
||||||
|
//return opName();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -1641,6 +1641,19 @@ public class SameDiffTests extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Ignore(/*AS - 20191114 https://github.com/eclipse/deeplearning4j/issues/8393*/)
|
||||||
|
@Test
|
||||||
|
public void testIsStrictlyIncShape() {
|
||||||
|
int nOut = 0;
|
||||||
|
int minibatch = 0;
|
||||||
|
|
||||||
|
INDArray ia = Nd4j.randn(minibatch, nOut);
|
||||||
|
INDArray expOut = Nd4j.create(DataType.BOOL, ia.shape());
|
||||||
|
|
||||||
|
Nd4j.exec(new IsStrictlyIncreasing(new INDArray[]{ia}, new INDArray[]{expOut}));
|
||||||
|
System.out.println(expOut);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testExpandDims2d() {
|
public void testExpandDims2d() {
|
||||||
val origShape = new long[]{3, 4};
|
val origShape = new long[]{3, 4};
|
||||||
|
|
|
@ -635,7 +635,7 @@ public class TFGraphTestAllHelper {
|
||||||
for (int i = 0; i < resources.size(); i++) {
|
for (int i = 0; i < resources.size(); i++) {
|
||||||
URI u = resources.get(i).getFirst().getURI();
|
URI u = resources.get(i).getFirst().getURI();
|
||||||
String varName = u.toString();
|
String varName = u.toString();
|
||||||
int idx = varName.lastIndexOf(modelName);
|
int idx = varName.indexOf(modelName);
|
||||||
varName = varName.substring(idx + modelName.length()+1); //+1 for "/"
|
varName = varName.substring(idx + modelName.length()+1); //+1 for "/"
|
||||||
varName = varName.replaceAll("____","/");
|
varName = varName.replaceAll("____","/");
|
||||||
varName = varName.replaceAll(".placeholder.shape","");
|
varName = varName.replaceAll(".placeholder.shape","");
|
||||||
|
@ -752,7 +752,8 @@ public class TFGraphTestAllHelper {
|
||||||
return (t, s) -> Nd4j.sort(t, true).equals(Nd4j.sort(s, true));
|
return (t, s) -> Nd4j.sort(t, true).equals(Nd4j.sort(s, true));
|
||||||
}
|
}
|
||||||
|
|
||||||
if(modelName.startsWith("alpha_dropout") || modelName.startsWith("layers_dropout"))
|
if(modelName.startsWith("alpha_dropout") || modelName.startsWith("layers_dropout") || modelName.equals("dropout"))
|
||||||
|
//We can't compare dropout using simple equality due to randomness
|
||||||
return (t, s) -> {
|
return (t, s) -> {
|
||||||
double[] tfNums = t.ravel().toDoubleVector();
|
double[] tfNums = t.ravel().toDoubleVector();
|
||||||
double[] sdNums = s.ravel().toDoubleVector();
|
double[] sdNums = s.ravel().toDoubleVector();
|
||||||
|
|
|
@ -70,8 +70,15 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
||||||
//Still failing 2019/09/11
|
//Still failing 2019/09/11
|
||||||
"slogdet/.*",
|
"slogdet/.*",
|
||||||
|
|
||||||
|
// Failing 2019/11/14 - |https://github.com/eclipse/deeplearning4j/issues/8374
|
||||||
|
"adjust_contrast/*",
|
||||||
|
"adjust_contrast/.*",
|
||||||
//Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965
|
//Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965
|
||||||
"bincount/.*",
|
"bincount/.*",
|
||||||
|
// Failing 2019/11/15 https://github.com/eclipse/deeplearning4j/issues/8400
|
||||||
|
"bitcast/.*",
|
||||||
|
// Failing 2019/11/14 https://github.com/eclipse/deeplearning4j/issues/8393
|
||||||
|
"is_strictly_increasing/emptyArrayTest/.*",
|
||||||
|
|
||||||
//TODO floormod and truncatemod behave differently - i.e., "c" vs. "python" semantics. Need to check implementations too
|
//TODO floormod and truncatemod behave differently - i.e., "c" vs. "python" semantics. Need to check implementations too
|
||||||
"truncatemod/.*",
|
"truncatemod/.*",
|
||||||
|
@ -100,7 +107,25 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
||||||
"multinomial/.*",
|
"multinomial/.*",
|
||||||
|
|
||||||
//2019/11/04 AB - disabled, pending libnd4j deconv3d_tf implementation
|
//2019/11/04 AB - disabled, pending libnd4j deconv3d_tf implementation
|
||||||
"conv3d_transpose.*"
|
"conv3d_transpose.*",
|
||||||
|
|
||||||
|
//2019/11/15 - mapping is not present yet https://github.com/eclipse/deeplearning4j/issues/8397
|
||||||
|
"ragged/reduce_mean/.*",
|
||||||
|
|
||||||
|
// 2019/11/15 - missing dtype argument in nd4j, tests are useless https://github.com/eclipse/deeplearning4j/issues/8398
|
||||||
|
"zeros_like/rank2_float32_dtype_int.*",
|
||||||
|
|
||||||
|
// 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8399
|
||||||
|
"crop_and_resize.*",
|
||||||
|
|
||||||
|
// 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8401
|
||||||
|
"draw_bounding_boxes.*",
|
||||||
|
|
||||||
|
// 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8402
|
||||||
|
"fake_quant/min_max_args_per_channel.*",
|
||||||
|
|
||||||
|
// 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8403
|
||||||
|
"resize_bilinear/int32.*"
|
||||||
};
|
};
|
||||||
|
|
||||||
@BeforeClass
|
@BeforeClass
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.custom;
|
package org.nd4j.linalg.custom;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.junit.Ignore;
|
import org.junit.Ignore;
|
||||||
|
@ -29,12 +30,15 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.custom.*;
|
import org.nd4j.linalg.api.ops.custom.*;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpStatus;
|
import org.nd4j.linalg.api.ops.executioner.OpStatus;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
import org.nd4j.linalg.api.ops.impl.controlflow.Where;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.image.CropAndResize;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
|
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
|
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
|
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp;
|
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp;
|
||||||
import org.nd4j.linalg.api.ops.random.compat.RandomStandardNormal;
|
import org.nd4j.linalg.api.ops.random.compat.RandomStandardNormal;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.DropOut;
|
||||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||||
import org.nd4j.linalg.api.shape.Shape;
|
import org.nd4j.linalg.api.shape.Shape;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
|
@ -823,6 +827,17 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
assertEquals(expected, out);
|
assertEquals(expected, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Ignore("AS 11/13/2019 https://github.com/eclipse/deeplearning4j/issues/8374")
|
||||||
|
@Test
|
||||||
|
public void testAdjustContrastShape(){
|
||||||
|
DynamicCustomOp op = DynamicCustomOp.builder("adjust_contrast_v2")
|
||||||
|
.addInputs(Nd4j.create(DataType.FLOAT, 256, 256,3), Nd4j.scalar(0.5f))
|
||||||
|
.build();
|
||||||
|
List<LongShapeDescriptor> lsd = op.calculateOutputShape();
|
||||||
|
assertEquals(1, lsd.size());
|
||||||
|
assertArrayEquals(new long[]{256, 256, 3}, lsd.get(0).getShape());
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAdjustContrastV2() {
|
public void testAdjustContrastV2() {
|
||||||
INDArray in = Nd4j.linspace(DataType.DOUBLE,1.0,1.0, 4*4*3).reshape(4,4,3);
|
INDArray in = Nd4j.linspace(DataType.DOUBLE,1.0,1.0, 4*4*3).reshape(4,4,3);
|
||||||
|
@ -840,6 +855,16 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
assertEquals(expected, out);
|
assertEquals(expected, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Ignore("AS 11/13/2019 https://github.com/eclipse/deeplearning4j/issues/8374")
|
||||||
|
@Test
|
||||||
|
public void testBitCastShape(){
|
||||||
|
INDArray out = Nd4j.createUninitialized(1,10);
|
||||||
|
BitCast op = new BitCast(Nd4j.zeros(1,10), DataType.FLOAT.toInt(), out);
|
||||||
|
List<LongShapeDescriptor> lsd = op.calculateOutputShape();
|
||||||
|
assertEquals(1, lsd.size());
|
||||||
|
assertArrayEquals(new long[]{1,10}, lsd.get(0).getShape());
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBitCast() {
|
public void testBitCast() {
|
||||||
INDArray in = Nd4j.linspace(DataType.FLOAT, 1.0f, 1.0f, 8).reshape(2,2,2);
|
INDArray in = Nd4j.linspace(DataType.FLOAT, 1.0f, 1.0f, 8).reshape(2,2,2);
|
||||||
|
@ -852,6 +877,79 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
assertEquals(expected, out);
|
assertEquals(expected, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Ignore("AS 11/13/2019 https://github.com/eclipse/deeplearning4j/issues/8374")
|
||||||
|
@Test
|
||||||
|
public void testDrawBoundingBoxesShape() {
|
||||||
|
INDArray images = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f,
|
||||||
|
0.1804f,0.5056f,0.8925f,0.5461f,0.9234f,0.0856f,0.7938f,0.6591f,0.5555f,0.1596f,
|
||||||
|
0.3087f,0.1548f,0.4695f,0.9939f,0.6113f,0.6765f,0.1800f,0.6750f,0.2246f,0.0509f,
|
||||||
|
0.4601f,0.8284f,0.2354f,0.9752f,0.8361f,0.2585f,0.4189f,0.7028f,0.7679f,0.5373f,
|
||||||
|
0.7234f,0.2690f,0.0062f,0.0327f,0.0644f,0.8428f,0.7494f,0.0755f,0.6245f,0.3491f,
|
||||||
|
0.5793f,0.5730f,0.1822f,0.6420f,0.9143f}).reshape(2,5,5,1);
|
||||||
|
INDArray boxes = Nd4j.createFromArray(new float[]{0.7717f, 0.9281f, 0.9846f, 0.4838f,
|
||||||
|
0.6433f, 0.6041f, 0.6501f, 0.7612f,
|
||||||
|
0.7605f, 0.3948f, 0.9493f, 0.8600f,
|
||||||
|
0.7876f, 0.8945f, 0.4638f, 0.7157f}).reshape(2,2,4);
|
||||||
|
INDArray colors = Nd4j.createFromArray(new float[]{0.9441f, 0.5957f}).reshape(1,2);
|
||||||
|
INDArray output = Nd4j.create(DataType.FLOAT, images.shape());
|
||||||
|
val op = new DrawBoundingBoxes(images, boxes, colors, output);
|
||||||
|
Nd4j.exec(op);
|
||||||
|
INDArray expected = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f,
|
||||||
|
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.9441f,
|
||||||
|
0.9441f, 0.1596f, 0.3087f, 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f,
|
||||||
|
0.1800f, 0.6750f, 0.2246f, 0.0509f, 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f,
|
||||||
|
0.2585f, 0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f,
|
||||||
|
0.8428f, 0.9441f,0.9441f,0.9441f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f});
|
||||||
|
assertEquals(expected, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Ignore(" 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8402")
|
||||||
|
@Test
|
||||||
|
public void testFakeQuantAgainstTF_1() {
|
||||||
|
INDArray x = Nd4j.createFromArray(new float[]{ 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f,
|
||||||
|
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
|
||||||
|
0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}).reshape(3,5);
|
||||||
|
INDArray min = Nd4j.createFromArray(new float[]{ -0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f});
|
||||||
|
INDArray max = Nd4j.createFromArray(new float[]{ 0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
|
||||||
|
|
||||||
|
INDArray expected = Nd4j.createFromArray(new float[]{0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f,
|
||||||
|
0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f,
|
||||||
|
0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f}).reshape(3,5);
|
||||||
|
|
||||||
|
INDArray out = Nd4j.createUninitialized(x.shape());
|
||||||
|
val op = new FakeQuantWithMinMaxVarsPerChannel(x,min,max,out);
|
||||||
|
Nd4j.exec(op);
|
||||||
|
assertEquals(expected, out);
|
||||||
|
|
||||||
|
/*TF: [[ 0.7801, 0.5966, 0.7260, 0.2320, 0.5084],
|
||||||
|
[ 0.1800, 0.5046, 0.8684, 0.3513, 0.5084],
|
||||||
|
[ 0.0877, 0.5966, 0.6600, 0.3513, 0.1604]]
|
||||||
|
SD: [[ 0.7770, 0.5969, 0.7232, 0.2310, 0.5098],
|
||||||
|
[ 0.1793, 0.5053, 0.8685, 0.3500, 0.5098],
|
||||||
|
[ 0.0874, 0.5969, 0.6574, 0.3500, 0.1597]]*/
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testWhereFail() {
|
||||||
|
INDArray in = Nd4j.createFromArray(new float[]{0f, 1.0000f, 1.0000f, 1.0000f, 1.0000f});
|
||||||
|
INDArray out = Nd4j.createUninitialized(4,1);
|
||||||
|
INDArray expected = Nd4j.createFromArray(4,1);
|
||||||
|
val op = new Where(new INDArray[]{in}, new INDArray[]{out});
|
||||||
|
Nd4j.exec(op);
|
||||||
|
assertArrayEquals(new long[]{4,1} , out.shape());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Ignore("2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8403")
|
||||||
|
@Test
|
||||||
|
public void testResizeBilinear1() {
|
||||||
|
|
||||||
|
INDArray x = Nd4j.rand(1, 2,3,4);
|
||||||
|
INDArray z = Nd4j.createUninitialized(x.shape());
|
||||||
|
boolean align = false;
|
||||||
|
val op = new ResizeBilinear(x, z, 10, 10, align);
|
||||||
|
Nd4j.exec(op);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCompareAndBitpack() {
|
public void testCompareAndBitpack() {
|
||||||
INDArray in = Nd4j.createFromArray(new double[]{-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f,
|
INDArray in = Nd4j.createFromArray(new double[]{-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f,
|
||||||
|
@ -932,6 +1030,30 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
System.out.println(distance);
|
System.out.println(distance);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Ignore("2019/11/15 AS - https://github.com/eclipse/deeplearning4j/issues/8399")
|
||||||
|
@Test
|
||||||
|
public void testCropAndResize() {
|
||||||
|
INDArray image = Nd4j.createUninitialized(DataType.FLOAT, 1, 2, 2, 1);
|
||||||
|
INDArray boxes = Nd4j.createFromArray(new float[]{1,2,3,4}).reshape(1,4);
|
||||||
|
INDArray box_indices = Nd4j.createFromArray(new int[]{1});
|
||||||
|
INDArray crop_size = Nd4j.createFromArray(new int[]{1,2}).reshape(1,2);
|
||||||
|
|
||||||
|
//Output shape mismatch - TF [2, 2, 1, 1] vs SD: [1, 2, 1, 1]
|
||||||
|
INDArray output = Nd4j.create(DataType.FLOAT, 2,2,1,1);
|
||||||
|
|
||||||
|
|
||||||
|
Nd4j.exec(new CropAndResize(image, boxes, box_indices, crop_size, CropAndResize.Method.BILINEAR, 0.5,
|
||||||
|
output));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testLayersDropoutFail() {
|
||||||
|
INDArray input = Nd4j.rand(4, 5);
|
||||||
|
INDArray output = Nd4j.createUninitialized(4, 5);
|
||||||
|
DropOut op = new DropOut(input, output, 0.1);
|
||||||
|
Nd4j.exec(op);
|
||||||
|
System.out.println(output);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRange(){
|
public void testRange(){
|
||||||
|
|
Loading…
Reference in New Issue