Merge remote-tracking branch 'konduit/master'

master
AlexDBlack 2019-10-14 17:21:23 +11:00
commit 2d750b69e5
31 changed files with 1422 additions and 110 deletions

View File

@@ -1422,6 +1422,7 @@ namespace nd4j {
template <typename T>
void p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const T value);
void p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, NDArray const& value);
template <typename T>

View File

@@ -4071,6 +4071,24 @@ void NDArray::p(const Nd4jLong i, const NDArray& scalar) {
NDArray::registerPrimaryUse({this}, {&scalar});
}
////////////////////////////////////////////////////////////////////////
void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const NDArray& scalar) {
if(!scalar.isScalar())
throw std::invalid_argument("NDArray::p method: input array must be scalar!");
if (i >= _length)
throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !");
// void *p = reinterpret_cast<void *>(scalar.getBuffer());
Nd4jLong coords[4] = {i, j, k, l};
auto xOffset = shape::getOffset(getShapeInfo(), coords);
NDArray::preparePrimaryUse({this}, {&scalar}, true);
// BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR(scalar.dataType(), templatedSet, (this->getBuffer(), xOffset, scalar.dataType(), scalar.getBuffer()), LIBND4J_TYPES);
NDArray::registerPrimaryUse({this}, {&scalar});
}
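A minimal usage sketch of the new 4-index overload (illustrative values, not part of the commit):
// write a scalar NDArray at coordinates (0, 1, 2, 3)
auto arr    = NDArrayFactory::create<float>('c', {2, 3, 4, 5});
auto scalar = NDArrayFactory::create<float>(0.5f);
arr.p(0, 1, 2, 3, scalar); // afterwards arr.e<float>(0, 1, 2, 3) == 0.5f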
//////////////////////////////////////////////////////////////////////////
void NDArray::addRowVector(const NDArray *row, NDArray *target) const {

View File

@@ -77,7 +77,8 @@
(27, LogicalOr) ,\
(28, LogicalXor) ,\
(29, LogicalNot) ,\
(30, LogicalAnd)
(30, LogicalAnd), \
(31, DivideNoNan)
// these ops return same data type as input
#define TRANSFORM_SAME_OPS \
@@ -243,8 +244,8 @@
(42, LstmClip), \
(43, TruncateMod) ,\
(44, SquaredReverseSubtract) ,\
(45, ReversePow)
(45, ReversePow), \
(46, DivideNoNan)
@@ -378,7 +379,8 @@
(34, AMaxPairwise), \
(35, AMinPairwise) ,\
(36, TruncateMod), \
(37, ReplaceNans)
(37, ReplaceNans), \
(38, DivideNoNan)

View File

@@ -46,6 +46,7 @@ namespace nd4j {
static BroadcastOpsTuple Add();
static BroadcastOpsTuple Assign();
static BroadcastOpsTuple Divide();
static BroadcastOpsTuple DivideNoNan();
static BroadcastOpsTuple Multiply();
static BroadcastOpsTuple Subtract();
};

View File

@@ -0,0 +1,57 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author George A. Shulinok <sgazeos@gmail.com>
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_divide_no_nan)
#include <ops/declarable/generic/helpers/BroadcastHelper.h>
#include <ops/declarable/CustomOperations.h>
namespace nd4j {
namespace ops {
BROADCASTABLE_OP_IMPL(divide_no_nan, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
BROADCAST_CHECK_EMPTY(x,y,z);
REQUIRE_TRUE(!y->isB(), 0, "DIVIDE_NO_NAN OP: you can't divide by bool array!");
auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::DivideNoNan(), x, y, z);
if (tZ == nullptr)
return ND4J_STATUS_KERNEL_FAILURE;
else if (tZ != z) {
OVERWRITE_RESULT(tZ);
}
return Status::OK();
}
DECLARE_SYN(DivNoNan, divide_no_nan);
DECLARE_TYPES(divide_no_nan) {
getOpDescriptor()
->setAllowedInputTypes(0, DataType::ANY)
->setAllowedInputTypes(1, DataType::ANY)
->setAllowedOutputTypes(0, DataType::INHERIT);
}
}
}
#endif

View File

@@ -0,0 +1,92 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author George A. Shulinok <sgazeos@gmail.com>
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_bitcast)
#include <array/DataTypeUtils.h>
#include <ops/declarable/CustomOperations.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(bitcast, 1, 1, false, 0, 1) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
// when empty - nothing to do
if(input->isEmpty()){
REQUIRE_TRUE(output->isEmpty(), 0, "BITCAST: If input is empty, output array must also be empty.");
return Status::OK();
}
// input and output buffers must be equal in byte length
DataBuffer buf(input->buffer(), input->specialBuffer(), input->lengthOf() * input->sizeOfT(), input->dataType());
*(output->dataBuffer()) = buf;
return Status::OK();
}
DECLARE_SYN(BitCast, bitcast);
DECLARE_SHAPE_FN(bitcast) {
auto inShape = inputShape->at(0);
auto inputRank = shape::rank(inShape);
auto it = INT_ARG(0);
DataType newType = DataTypeUtils::fromInt(it);
DataType oldType = ArrayOptions::dataType(inShape);
// correct output shape to conform with output data type
auto inputSize = DataTypeUtils::sizeOf(oldType);
auto outputSize = DataTypeUtils::sizeOf(newType);
if (shape::length(inShape) == 0)
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, newType)));
if (inputSize == outputSize) {
// only type should be changed
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, newType)));
}
else if (inputSize > outputSize) {
// the output rank is increased by 1, with inputSize / outputSize as the new last dimension
std::vector<Nd4jLong> shapeOf(inputRank + 1);
int i;
for (i = 0; i < inputRank; ++i) {
shapeOf[i] = inShape[i + 1];
}
shapeOf[i] = inputSize / outputSize;
auto outputShape = ConstantShapeHelper::getInstance()->createShapeInfo(newType, shape::order(inShape), shapeOf);
return SHAPELIST(outputShape);
}
REQUIRE_TRUE(shape::sizeAt(inShape, -1) == outputSize / inputSize, 0, "BITCAST: %llu < %llu, so the last dimension should be %llu, but %i was given.", inputSize, outputSize, outputSize / inputSize, shape::sizeAt(inShape, -1));
std::vector<Nd4jLong> shapeOf(inputRank - 1);
for (auto i = 0; i < shapeOf.size(); ++i) {
shapeOf[i] = inShape[i + 1];
}
auto outputShape = ConstantShapeHelper::getInstance()->createShapeInfo(newType, shape::order(inShape), shapeOf);
return SHAPELIST(outputShape);
}
DECLARE_TYPES(bitcast) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes(nd4j::DataType::ANY);
}
}
}
#endif

View File

@@ -0,0 +1,98 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author George A. Shulinok <sgazeos@gmail.com>
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_adjust_contrast)
#include <ops/declarable/headers/parity_ops.h>
#include <NDArrayFactory.h>
namespace nd4j {
namespace ops {
CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, 1, 0) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
const double factor = T_ARG(0);
REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST: op expects rank of input array to be >= 3, but got %i instead", input->rankOf());
REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1));
// compute mean before
// fill up axes vector first
std::vector<int> axes(input->rankOf() - 1);
for (auto i = 0; i < axes.size(); ++i)
axes[i] = i;
// mean over all dimensions except the last (per-channel mean)
auto mean = input->reduceAlongDims(reduce::Mean, axes);
NDArray factorT(output->dataType(), block.launchContext()); // = NDArrayFactory::create(factor, block.launchContext());
factorT.p(0, factor);
// this is contrast calculation
*output = (*input - mean) * factorT + mean;
return Status::OK();
}
DECLARE_TYPES(adjust_contrast) {
getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS})
->setSameMode(true);
}
CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, 1, 0) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
const double factor = T_ARG(0);
REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST: op expects rank of input array to be >= 3, but got %i instead", input->rankOf());
REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1));
// compute mean before
std::vector<int> axes(input->rankOf() - 1);
for (auto i = 0; i < axes.size(); ++i)
axes[i] = i;
// mean over all dimensions except the last (per-channel mean)
auto mean = input->reduceAlongDims(reduce::Mean, axes);
// result as (x - mean) * factor + mean
std::unique_ptr<NDArray> temp(input->dup());
input->applyTrueBroadcast(BroadcastOpsTuple::Subtract(), &mean, temp.get());
temp->applyScalar(scalar::Multiply, factor);
temp->applyTrueBroadcast(BroadcastOpsTuple::Add(), &mean, output);
return Status::OK();
}
DECLARE_TYPES(adjust_contrast_v2) {
getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS})
->setSameMode(true);
}
}
}
#endif

View File

@@ -0,0 +1,60 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author sgazeos@gmail.com
//
#include <ops/declarable/generic/helpers/BroadcastHelper.h>
#include <ops/declarable/headers/parity_ops.h>
#include <ops/declarable/headers/datatypes.h>
#include <NDArrayFactory.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(compare_and_bitpack, 2, 1, false, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
auto z0 = NDArrayFactory::create<bool>(x->ordering(), x->getShapeAsVector());
BROADCAST_CHECK_EMPTY(x, y, (&z0));
auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0);
bitcast res;
auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, false);
if (tZ != &z0) {
delete tZ;
}
return status;
}
DECLARE_TYPES(compare_and_bitpack) {
getOpDescriptor()
->setAllowedInputTypes(0, DataType::ANY)
->setAllowedInputTypes(1, DataType::ANY)
->setAllowedOutputTypes(0, DataType::UINT8);
}
DECLARE_SHAPE_FN(compare_and_bitpack) {
auto inShape = inputShape->at(0);
DataType newType = DataType::UINT8;
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, newType)));
}
}
}

View File

@@ -0,0 +1,49 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author George A. Shulinok <sgazeos@gmail.com>
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_draw_bounding_boxes)
#include <ops/declarable/headers/parity_ops.h>
#include <ops/declarable/helpers/image_draw_bounding_boxes.h>
namespace nd4j {
namespace ops {
OP_IMPL(draw_bounding_boxes, 3, 1, true) {
auto images = INPUT_VARIABLE(0);
auto boxes = INPUT_VARIABLE(1);
auto colors = INPUT_VARIABLE(2);
auto output = OUTPUT_VARIABLE(0);
helpers::drawBoundingBoxesFunctor(block.launchContext(), images, boxes, colors, output);
return ND4J_STATUS_OK;
}
DECLARE_TYPES(draw_bounding_boxes) {
getOpDescriptor()
->setAllowedInputTypes(0, {HALF, FLOAT32})// TF allows HALF and FLOAT32 only
->setAllowedInputTypes(1, {FLOAT32}) // as TF
->setAllowedInputTypes(2, {FLOAT32}) // as TF
->setAllowedOutputTypes({HALF, FLOAT32}); // TF allows HALF and FLOAT32 only
}
}
}
#endif

View File

@@ -0,0 +1,71 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author George Shulinok <sgazeos@gmail.com>, created on 08.10.2019
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars_per_channel)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/fake_quantization.h>
namespace nd4j {
namespace ops {
CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars_per_channel, 1, 1, true, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto min = INPUT_VARIABLE(1);
auto max = INPUT_VARIABLE(2);
REQUIRE_TRUE(block.width() == 3 || block.getTArguments()->size() == 2, 0, "fake_quant_with_min_max_vars_per_channel: No minimum/maximum values provided by either input arrays or TArgs");
auto depth = x->sizeAt(-1);
REQUIRE_TRUE(min->rankOf() == 1 && max->rankOf() == 1 && min->lengthOf() == max->lengthOf(), 0,
"fake_quant_with_min_max_vars_per_channel: Min and Max should be 1D tensors with the same length");
REQUIRE_TRUE(depth == min->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Min length should be"
" %lld, but %lld was given.", depth, min->lengthOf());
REQUIRE_TRUE(depth == max->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Max length should be"
" %lld, but %lld was given.", depth, max->lengthOf());
auto output = OUTPUT_VARIABLE(0);
int numBits = 8;
if (block.getIArguments() && block.getIArguments()->size())
numBits = INT_ARG(0);
bool narrowed = false; // optional INT_ARG(1)
if (block.getIArguments()->size() == 2) {
numBits = INT_ARG(0);
narrowed = INT_ARG(1);
REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars_per_channel: Number of bits"
" for quantization should be between 2 and 16, but %i "
"was given.", numBits);
}
helpers::fakeQuantWithMinMaxVarsPerChannel(block.launchContext(), x, min, max, numBits, narrowed, output);
return ND4J_STATUS_OK;
}
DECLARE_TYPES(fake_quant_with_min_max_vars_per_channel) {
getOpDescriptor()
->setAllowedOutputTypes({ALL_FLOATS})
->setAllowedInputTypes({ALL_INTS, ALL_FLOATS});
}
DECLARE_SYN(fake_quant_with_min_max_args_per_channel, fake_quant_with_min_max_vars_per_channel);
}
}
#endif

View File

@@ -156,6 +156,18 @@ namespace nd4j {
DECLARE_CUSTOM_OP(divide_bp, 3, 2, false, 0, 0);
#endif
/**
* This is one of the auto-broadcastable operations. It accepts 2 operands, and the operation is applied based on their shapes:
* 1) if shapes are equal, it's a pairwise operation and the result has the same shape.
* 2) if shape X is scalar and shape Y is an array - the result has shape equal to Y.
* 3) if shape X is an array and shape Y is scalar - the result has shape equal to X.
* 4) if shapes X and Y are both arrays but aren't equal - the result shape is their broadcast.
*
* This operation returns Z = Divide(X, Y), except that it returns 0 wherever Y = 0 (instead of NaN/Inf).
*/
#if NOT_EXCLUDED(OP_divide_no_nan)
DECLARE_BROADCASTABLE_OP(divide_no_nan, 0, 0);
#endif
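The element-wise semantics reduce to this sketch (illustrative helper, not part of the commit):
// divide_no_nan per element: 0 instead of NaN/Inf when the divisor is 0
template <typename T>
static T divideNoNan(T x, T y) {
    return y == static_cast<T>(0) ? static_cast<T>(0) : x / y;
}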
/**
* This is one of the auto-broadcastable operations. It accepts 2 operands, and the operation is applied based on their shapes:
* 1) if shapes are equal, it's a pairwise operation and the result has the same shape.

View File

@@ -99,6 +99,14 @@ namespace nd4j {
#if NOT_EXCLUDED(OP_cast)
DECLARE_CUSTOM_OP(cast, 1, 1, false, 0, 1);
#endif
/**
* This operation reinterprets the input buffer as the given data type and adjusts the output shape
* to conform with the new element size.
*
* Inputs and arguments are the same as for the cast op above.
* */
#if NOT_EXCLUDED(OP_bitcast)
DECLARE_CUSTOM_OP(bitcast, 1, 1, false, 0, 1);
#endif
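A worked example of the shape rule (mirroring Test_BitCast_2 below): bitcasting a float32 array of shape {2, 4} to float16 halves the element size, so a trailing dimension of sizeof(float) / sizeof(float16) = 2 is appended, giving {2, 4, 2}. A same-size cast keeps the shape unchanged, and casting to a larger type consumes a matching trailing dimension, e.g. float32 {2, 2, 2} to double gives {2, 2} as in Test_BitCast_1.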
}
}

View File

@@ -600,6 +600,22 @@ namespace nd4j {
DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, 1, -2);
#endif
/**
* This operation adjusts image contrast by given factor ( z = (x - mean) * factor + mean )
* Input arrays:
* 0 - input array with rank >= 3; the last dimension must be 3 (the channel dimension).
*
* T arguments:
* 0 - contrast factor
*
*/
#if NOT_EXCLUDED(OP_adjust_contrast)
DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, 1, 0);
DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, 1, 0);
#endif
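A worked instance of the formula (matching Test_AdjustContrast_1 below): for a 4x4x3 image filled with linspace(1.), channel 0 has mean m = 23.5, so with factor 2 the first pixel x = 1 maps to z = (1 - 23.5) * 2 + 23.5 = -21.5.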
/**
* This operation rearranges data from depth into blocks of spatial data. This is the reverse transformation
@@ -1228,6 +1244,23 @@ namespace nd4j {
DECLARE_CUSTOM_OP(extract_image_patches, 1, 1, false, 0, 7);
#endif
/**
* draw_bounding_boxes op - draws the given boxes onto the input images using the given colors.
*
* input params:
* 0 - images tensor (4D) with shape {batch, width, height, channels}, where channels is 1 (BW image),
* 3 (RGB) or 4 (RGBA)
* 1 - boxes tensor (3D) with shape {batch, number_of_boxes, 4} where the last dimension is encoded as
* (y_min, x_min, y_max, x_max), all values between 0. and 1.
* 2 - colors tensor (2D) with shape {number_of_boxes, channels} -- the bordering color set (palette)
*
* output:
* 0 - 4D tensor with same shape as images (input 0)
*/
#if NOT_EXCLUDED(OP_draw_bounding_boxes)
DECLARE_OP(draw_bounding_boxes, 3, 1, true);
#endif
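Since the box coordinates are normalized, the pixel bounds reduce to the mapping below (sketch lifted from the helpers in this commit; yMin/yMax stand for the box entries):
// normalized corner -> clamped pixel row; columns are computed the same way from width, x_min and x_max
auto rowStart = nd4j::math::nd4j_max(Nd4jLong(0), Nd4jLong((height - 1) * yMin));
auto rowEnd   = nd4j::math::nd4j_min(Nd4jLong(height - 1), Nd4jLong((height - 1) * yMax));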
/**
* roll - op porting from numpy (https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.roll.html)
*
@@ -1715,6 +1748,39 @@ namespace nd4j {
DECLARE_CONFIGURABLE_OP(fake_quant_with_min_max_vars, 3, 1, true, 0, -2);
#endif
/**
* fake_quant_with_min_max_vars_per_channel - tf.quantization.fake_quant_with_min_max_vars_per_channel
*
* input params:
* 0 - NDArray (input) - at least 2D.
* 1 - 1D Tensor - min values (length equals the last dimension of the input)
* 2 - 1D Tensor - max values (same length as min)
*
* int params (optional):
* 0 - num_bits (allowed interval [2, 16], default 8)
* 1 - narrow_range (default False)
*
* output:
* 0 - NDArray with the same shape as input
*/
#if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars_per_channel)
DECLARE_CONFIGURABLE_OP(fake_quant_with_min_max_vars_per_channel, 3, 1, true, 0, -2);
#endif
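A usage sketch (mirroring the tests below; the optional IArgs are {num_bits, narrow_range}):
nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {}); // defaults: 8 bits, narrow range off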
/**
* compare_and_bitpack - compares each element with a threshold (greater-than) and stores the result as uint8
*
* input params:
* 0 - NDArray (input)
* 1 - 0D Tensor - threshold
*
*
* output:
* 0 - NDArray with the same shape as input and type uint8
*/
#if NOT_EXCLUDED(OP_compare_and_bitpack)
DECLARE_CUSTOM_OP(compare_and_bitpack, 2, 1, false, 0, 0);
#endif
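As implemented in this commit, the per-element semantics are simply (sketch):
// out[i] = static_cast<uint8_t>(x[i] > threshold); // see compare_and_bitpack_test1 below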
}
}

View File

@@ -25,74 +25,89 @@ namespace nd4j {
namespace ops {
namespace helpers {
//
// nudge - computes the scale and the nudged min/max
// scale = (Max - Min) / (quantMax - quantMin)
// quantMin = 0 or 1, quantMax = 2^b - 1 == (1 << b) - 1
//
template <typename T>
static void nudge(T min, T max, int quantMin, int quantMax, T* scale, T* nudgedMin, T* nudgedMax) {
// floating-point versions of the integer bounds
T quantMaxF = static_cast<T>(quantMax);
T quantMinF = static_cast<T>(quantMin);
// compute scale
*scale = (max - min) / (quantMaxF - quantMinF);
// compute the zero point implied by min
auto zeroPointFromMin = quantMinF - min / *scale;
// clamp the zero point to the range [0 or 1, 2^b - 1]
uint16_t const nudged_zero_point = [zeroPointFromMin, quantMin, quantMax, quantMaxF, quantMinF] {
if (zeroPointFromMin < quantMinF) {
return static_cast<uint16_t>(quantMin);
}
if (zeroPointFromMin > quantMaxF) {
return static_cast<uint16_t>(quantMax);
}
return nd4j::math::nd4j_round<T,uint16_t>(zeroPointFromMin);
}();
// compute nudged min and max with computed nudged zero point
*nudgedMin = (quantMinF - nudged_zero_point) * (*scale);
*nudgedMax = (quantMaxF - nudged_zero_point) * (*scale);
}
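A worked example with the scalars from FakeQuantWithMinMaxVars_Test_1 below: min = -63.65, max = 0.1, quantMin = 0, quantMax = 255 give scale = 63.75 / 255 = 0.25 and zeroPointFromMin = 0 - (-63.65 / 0.25) = 254.6, which rounds to a nudged zero point of 255; hence nudgedMin = (0 - 255) * 0.25 = -63.75 and nudgedMax = (255 - 255) * 0.25 = 0.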
template <typename T>
void fakeQuantWithMinMaxVarsPerChannel_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
int lowIntBound = narrowed ? 1 : 0; // 0 or 1
int upperIntBound = (1 << numBits) - 1; // 2^b - 1
auto channels = input->sizeAt(-1); // last dimension
PRAGMA_OMP_PARALLEL_FOR
for (auto i = 0; i < channels; i++) {
T scale, nudged_min, nudged_max;
// nudge min and max first, computing the scale along the way
nudge<T>(min->t<T>(i), max->t<T>(i), lowIntBound, upperIntBound, &scale, &nudged_min, &nudged_max);
// stride over the last dimension and process all values for the given channel
for (auto e = 0; e < input->lengthOf(); e += channels) {
T val = input->t<T>(e + i);
if ( val <= nudged_min)
val = nudged_min;
else if (val >= nudged_max)
val = nudged_max;
// quantization itself
output->t<T>(e + i) = math::nd4j_floor<T,T>((val - nudged_min)/scale + T(0.5)) * scale + nudged_min;
}
}
}
template <typename T>
void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
int lowIntBound = narrowed ? 1 : 0;
int upperIntBound = 1 << numBits - 1;
int upperIntBound = (1 << numBits) - 1;
const float quant_min_float = static_cast<float>(lowIntBound);
const float quant_max_float = static_cast<float>(upperIntBound);
T scale = (max->t<T>(0) - min->t<T>(0)) / (quant_max_float - quant_min_float);
const T zero_point_from_min = quant_min_float - min->e<T>(0) / scale;
const uint16_t nudged_zero_point = [zero_point_from_min, lowIntBound,
quant_min_float, upperIntBound,
quant_max_float] {
if (zero_point_from_min < quant_min_float) {
return static_cast<uint16_t>(lowIntBound);
}
if (zero_point_from_min > quant_max_float) {
return static_cast<uint16_t>(upperIntBound);
}
return static_cast<uint16_t>(roundf(zero_point_from_min));
}();
auto nudged_min = (quant_min_float - nudged_zero_point) * (scale);
auto nudged_max = (quant_max_float - nudged_zero_point) * (scale);
//input->applyScalar(scalar::CompareAndSet, nudged_max, clamped, nullptr); //.cwiseMin(nudged_max).cwiseMax(nudged_min);
//input->applyScalar(scalar::CompareAndSet, nudged_min, clamped, nullptr); //.cwiseMin(nudged_max).cwiseMax(nudged_min);
auto wiseMax = LAMBDA_T(x, nudged_min) {
if (x < nudged_min) {
return nudged_min;
T nudgedMin, nudgedMax, scale;
// nudge the given min and max, computing the scale and the nudged min/max
nudge<T>(min->t<T>(0), max->t<T>(0), lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax);
// quantization as a single lambda
auto fakeQuantizationWithMinMax = LAMBDA_T(x, nudgedMin, nudgedMax, scale) {
T val = x; // bound the value between the nudged min and max
if (val < nudgedMin) {
val = nudgedMin;
}
return x;
else if (val > nudgedMax)
val = nudgedMax;
// quantize: snap to the scale grid and shift back by the nudged min
return (nd4j::math::nd4j_floor<T,T>((val - nudgedMin)/scale + T(0.5)) * scale + nudgedMin);
};
auto wiseMin = LAMBDA_T(x, nudged_max) {
if (x > nudged_max) {
return nudged_max;
}
return x;
};
auto scaleTensor(*input); // = NDArrayFactory::create(input->ordering(), input->getShapeAsVector(), input->getWorkspace());
auto clamped(*input); // = NDArrayFactory::create(input->ordering(), input->getShapeAsVector(), input->getWorkspace());
scaleTensor.assign(scale);
input->applyLambda<T>(wiseMin, &clamped);
// const auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min);
clamped.applyLambda<T>(wiseMax, output);
// const auto clamped_shifted = clamped - nudged_min;
*output -= nudged_min;
// auto nudgedScale = scale;
(*output) /= scaleTensor;
(*output) += T(0.5f);
output->applyTransform(transform::Floor, nullptr, nullptr);
(*output) *= scaleTensor;
(*output) += nudged_min;
//output->printIndexedBuffer("FAKE QUANTED");
/*
const auto nudged_scale_repl = inputs.constant(nudged_scale);
const auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min);
const auto clamped_shifted = clamped - nudged_min;
*output = (clamped_shifted / nudged_scale_repl + 0.5f).floor() *
nudged_scale_repl +
nudged_min;
*/
input->applyLambda<T>(fakeQuantizationWithMinMax, output);
}
void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVars_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES);
}
void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVarsPerChannel_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void fakeQuantWithMinMaxVars_, (NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output), FLOAT_TYPES);
}

View File

@@ -0,0 +1,78 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author sgazeos@gmail.com
//
#include <op_boilerplate.h>
#include <NDArray.h>
namespace nd4j {
namespace ops {
namespace helpers {
void drawBoundingBoxesFunctor(nd4j::LaunchContext * context, NDArray* images, NDArray* boxes, NDArray* colors, NDArray* output) {
// images - batch of 3D images with BW (last dim = 1), RGB (last dim = 3) or RGBA (last dim = 4) channel set
// boxes - batch of 2D bounds with last dim (y_start, x_start, y_end, x_end) to compute i and j as
// floor((height - 1 ) * y_start) => rowStart, floor((height - 1) * y_end) => rowEnd
// floor((width - 1 ) * x_start) => colStart, floor((width - 1) * x_end) => colEnd
// height = images->sizeAt(1), width = images->sizeAt(2)
// colors - colors for each box given
// set up color for each box as frame
auto batchSize = images->sizeAt(0);
auto height = images->sizeAt(1);
auto width = images->sizeAt(2);
auto channels = images->sizeAt(3);
//auto imageList = images->allTensorsAlongDimension({1, 2, 3}); // split images by batch
// auto boxList = boxes->allTensorsAlongDimension({1, 2}); // split boxes by batch
auto colorSet = colors->allTensorsAlongDimension({1});
output->assign(images); // fill up all output with input images, then fill up boxes
PRAGMA_OMP_PARALLEL_FOR
for (auto b = 0; b < batchSize; ++b) { // loop by batch
// auto image = imageList->at(b);
for (auto c = 0; c < colorSet->size(); ++c) {
// normalized box coordinates (y_min, x_min, y_max, x_max)
auto internalBox = (*boxes)(b, {0})(c, {0});//internalBoxes->at(c);
auto color = colorSet->at(c);
auto rowStart = nd4j::math::nd4j_max(Nd4jLong (0), Nd4jLong ((height - 1) * internalBox.e<float>(0)));
auto rowEnd = nd4j::math::nd4j_min(Nd4jLong (height - 1), Nd4jLong ((height - 1) * internalBox.e<float>(2)));
auto colStart = nd4j::math::nd4j_max(Nd4jLong (0), Nd4jLong ((width - 1) * internalBox.e<float>(1)));
auto colEnd = nd4j::math::nd4j_min(Nd4jLong(width - 1), Nd4jLong ((width - 1) * internalBox.e<float>(3)));
for (auto y = rowStart; y <= rowEnd; y++) {
for (auto e = 0; e < color->lengthOf(); ++e) {
output->p(b, y, colStart, e, color->e(e));
output->p(b, y, colEnd, e, color->e(e));
}
}
for (auto x = colStart + 1; x < colEnd; x++) {
for (auto e = 0; e < color->lengthOf(); ++e) {
output->p(b, rowStart, x, e, color->e(e));
output->p(b, rowEnd, x, e, color->e(e));
}
}
}
// delete internalBoxes;
}
delete colorSet;
// delete imageList;
// delete boxList;
}
}
}
}

View File

@@ -33,65 +33,105 @@ namespace helpers {
// narrowed - if true, use the narrow quantization range (lower bound 1 instead of 0)
// output - output tensor
//
template <typename T>
static __host__ __device__ void
nudge(T min, T max, int quantMin, int quantMax, T* scale, T* nudgedMin, T* nudgedMax) {
T quantMaxF = static_cast<T>(quantMax);
T quantMinF = static_cast<T>(quantMin);
*scale = (max - min) / (quantMaxF - quantMinF);
auto zeroPointFromMin = quantMinF - min / *scale;
uint16_t const nudgedZeroPoint = [zeroPointFromMin, quantMin, quantMax, quantMaxF, quantMinF] {
if (zeroPointFromMin < quantMinF) {
return static_cast<uint16_t>(quantMin);
}
if (zeroPointFromMin > quantMaxF) {
return static_cast<uint16_t>(quantMax);
}
return nd4j::math::nd4j_round<T,uint16_t>(zeroPointFromMin);
}();
*nudgedMin = (quantMinF - nudgedZeroPoint) * (*scale);
*nudgedMax = (quantMaxF - nudgedZeroPoint) * (*scale);
}
template <typename T>
void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
int lowIntBound = narrowed?1:0;
int upperIntBound = 1 << numBits - 1;
min->syncToHost();
int upperIntBound = (1 << numBits) - 1;
min->syncToHost(); // these are scalars, so this sync is cheap
max->syncToHost();
const float quant_min_float = static_cast<float>(lowIntBound);
const float quant_max_float = static_cast<float>(upperIntBound);
T scale = (max->t<T>(0) - min->t<T>(0)) / (quant_max_float - quant_min_float);
const T zero_point_from_min = quant_min_float - min->t<T>(0) / scale;
T scale, nudgedMin, nudgedMax;
nudge(min->t<T>(0), max->t<T>(0), lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax);
const uint16_t nudged_zero_point = [zero_point_from_min, lowIntBound,
quant_min_float, upperIntBound,
quant_max_float] {
if (zero_point_from_min < quant_min_float) {
return static_cast<uint16_t>(lowIntBound);
auto fakeQuantizationWithMinMax = LAMBDA_T(x, nudgedMin, nudgedMax, scale) {
T val = x;
if (x < nudgedMin) {
val = nudgedMin;
}
if (zero_point_from_min > quant_max_float) {
return static_cast<uint16_t>(upperIntBound);
else if (x > nudgedMax) {
val = nudgedMax;
}
return static_cast<uint16_t>(roundf(zero_point_from_min));
}();
auto nudged_min = (quant_min_float - nudged_zero_point) * (scale);
auto nudged_max = (quant_max_float - nudged_zero_point) * (scale);
auto wiseMax = LAMBDA_T(x, nudged_min) {
if (x < nudged_min) {
return nudged_min;
}
return x;
else
val = x;
return (math::nd4j_floor<T,T>((val - nudgedMin) / scale + T(0.5)) * scale + nudgedMin);
};
auto wiseMin = LAMBDA_T(x, nudged_max) {
if (x > nudged_max) {
return nudged_max;
}
return x;
};
input->applyLambda(fakeQuantizationWithMinMax, output);
}
auto scaleTensor(*input);
auto clamped(*input);
scaleTensor.assign(scale);
input->applyLambda(wiseMin, &clamped);
template <typename T>
static __global__ void fakeQuantWithMinMaxKernel(T* input, Nd4jLong* inputShape, T* min, T* max,
int lowIntBound, int upperIntBound, Nd4jLong channels,
T* output, Nd4jLong* outputShape, Nd4jLong length) {
__shared__ int block;
if (threadIdx.x == 0) {
block = length / channels; // number of elements per channel, looped with the last dimension as stride
}
__syncthreads();
clamped.applyLambda(wiseMax, output);
*output -= nudged_min;
for (auto i = blockIdx.x; i < (int)channels; i += gridDim.x) {
T scale, nudgedMin, nudgedMax;
nudge(min[i], max[i], lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax);
// loop over the strides, quantizing between the nudged min and max
for (auto b = threadIdx.x; b < block; b += blockDim.x) {
T val = input[shape::getIndexOffset(b * channels + i, inputShape)];
if (val < nudgedMin) {
val = nudgedMin;
} else if (val > nudgedMax) {
val = nudgedMax;
}
output[shape::getIndexOffset(b * channels + i, outputShape)] =
(math::nd4j_floor<T, T>((val - nudgedMin) / scale + T(0.5)) * scale + nudgedMin);
};
}
}
template <typename T>
void fakeQuantWithMinMaxVarsPerChannel_(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
int lowIntBound = narrowed?1:0;
int upperIntBound = (1 << numBits) - 1;
auto channels = min->lengthOf();
auto length = input->lengthOf();
NDArray::prepareSpecialUse({output}, {min, max, input});
auto stream = context->getCudaStream();
T* inputBuf = input->dataBuffer()->specialAsT<T>();
T* outputBuf = output->dataBuffer()->specialAsT<T>();
T* minBuf = min->dataBuffer()->specialAsT<T>();
T* maxBuf = max->dataBuffer()->specialAsT<T>();
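// launch config (heuristic sizes): up to 128 blocks iterate over channels, 256 threads stride the
// per-channel elements; the 256 bytes of dynamic shared memory are reserved but unused, since the
// kernel declares its own static __shared__ int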
fakeQuantWithMinMaxKernel<<<128, 256, 256, *stream>>>(inputBuf, input->specialShapeInfo(),
minBuf, maxBuf, lowIntBound, upperIntBound, channels, outputBuf, output->specialShapeInfo(), length);
NDArray::registerSpecialUse({output}, {min, max, input});
(*output) /= scaleTensor;
(*output) += T(0.5f);
output->applyTransform(transform::Floor, nullptr, nullptr);
(*output) *= scaleTensor;
(*output) += nudged_min;
}
void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVars_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES);
}
void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVarsPerChannel_, (context, input, min, max, numBits, narrowed, output), FLOAT_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void fakeQuantWithMinMaxVars_, (NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output), FLOAT_TYPES);
BUILD_SINGLE_TEMPLATE(template void fakeQuantWithMinMaxVarsPerChannel_, (LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output), FLOAT_TYPES);
}
}

View File

@@ -0,0 +1,100 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author sgazeos@gmail.com
//
#include <op_boilerplate.h>
#include <NDArray.h>
namespace nd4j {
namespace ops {
namespace helpers {
template <typename T>
static __global__ void drawBoundingBoxesKernel(T const* images, Nd4jLong* imagesShape, T const* boxes,
Nd4jLong* boxesShape, T const* colors, Nd4jLong* colorsShape, T* output, Nd4jLong* outputShape,
Nd4jLong batchSize, Nd4jLong width, Nd4jLong height, Nd4jLong channels, Nd4jLong colorSetSize) {
for (auto b = blockIdx.x; b < (int)batchSize; b += gridDim.x) { // loop by batch
for (auto c = 0; c < colorSetSize; c++) {
// normalized box coordinates (y_min, x_min, y_max, x_max) for this batch/box
auto internalBox = &boxes[b * colorSetSize * 4 + c * 4];//(*boxes)(b, {0})(c, {0});//internalBoxes->at(c);
auto color = &colors[channels * c];//colorSet->at(c);
auto rowStart = nd4j::math::nd4j_max(Nd4jLong (0), Nd4jLong ((height - 1) * internalBox[0]));
auto rowEnd = nd4j::math::nd4j_min(Nd4jLong (height - 1), Nd4jLong ((height - 1) * internalBox[2]));
auto colStart = nd4j::math::nd4j_max(Nd4jLong (0), Nd4jLong ((width - 1) * internalBox[1]));
auto colEnd = nd4j::math::nd4j_min(Nd4jLong(width - 1), Nd4jLong ((width - 1) * internalBox[3]));
for (auto y = rowStart + threadIdx.x; y <= rowEnd; y += blockDim.x) {
for (auto e = 0; e < channels; ++e) {
Nd4jLong yMinPos[] = {b, y, colStart, e};
Nd4jLong yMaxPos[] = {b, y, colEnd, e};
auto zIndexYmin = shape::getOffset(outputShape, yMinPos);
auto zIndexYmax = shape::getOffset(outputShape, yMaxPos);
output[zIndexYmin] = color[e];
output[zIndexYmax] = color[e];
}
}
for (auto x = colStart + 1 + threadIdx.x; x < colEnd; x += blockDim.x) {
for (auto e = 0; e < channels; ++e) {
Nd4jLong xMinPos[] = {b, rowStart, x, e};
Nd4jLong xMaxPos[] = {b, rowEnd, x, e};
auto zIndexXmin = shape::getOffset(outputShape, xMinPos);
auto zIndexXmax = shape::getOffset(outputShape, xMaxPos);
output[zIndexXmin] = color[e];
output[zIndexXmax] = color[e];
}
}
}
}
}
template <typename T>
void drawBoundingBoxesH(nd4j::LaunchContext* context, NDArray const* images, NDArray const* boxes, NDArray const* colors, NDArray* output) {
auto batchSize = images->sizeAt(0);
auto height = images->sizeAt(1);
auto width = images->sizeAt(2);
auto channels = images->sizeAt(3);
auto stream = context->getCudaStream();
auto colorSetSize = colors->sizeAt(0);
auto imagesBuf = images->getDataBuffer()->specialAsT<T>();
auto boxesBuf = boxes->getDataBuffer()->specialAsT<T>();
auto colorsBuf = colors->getDataBuffer()->specialAsT<T>();
auto outputBuf = output->dataBuffer()->specialAsT<T>();
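// launch config (heuristic sizes): one block per batch image (capped at 128), 256 threads stride
// the box border pixels; the 1024 bytes of dynamic shared memory appear unused by the kernel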
drawBoundingBoxesKernel<<<batchSize > 128? 128: batchSize, 256, 1024, *stream>>>(imagesBuf, images->getSpecialShapeInfo(),
boxesBuf, boxes->getSpecialShapeInfo(), colorsBuf, colors->getSpecialShapeInfo(),
outputBuf, output->specialShapeInfo(), batchSize, width, height, channels, colorSetSize);
}
void drawBoundingBoxesFunctor(nd4j::LaunchContext * context, NDArray* images, NDArray* boxes, NDArray* colors, NDArray* output) {
// images - batch of 3D images with BW (last dim = 1), RGB (last dim = 3) or RGBA (last dim = 4) channel set
// boxes - batch of 2D bounds with last dim (y_start, x_start, y_end, x_end) to compute i and j as
// floor((height - 1 ) * y_start) => rowStart, floor((height - 1) * y_end) => rowEnd
// floor((width - 1 ) * x_start) => colStart, floor((width - 1) * x_end) => colEnd
// height = images->sizeAt(1), width = images->sizeAt(2)
// colors - colors for each box given
// set up color for each box as frame
NDArray::prepareSpecialUse({output}, {images, boxes, colors});
output->assign(images);
BUILD_SINGLE_SELECTOR(output->dataType(), drawBoundingBoxesH, (context, images, boxes, colors, output), FLOAT_TYPES);
NDArray::registerSpecialUse({output}, {images, boxes, colors});
}
BUILD_SINGLE_TEMPLATE(template void drawBoundingBoxesH, (nd4j::LaunchContext* context, NDArray const* images, NDArray const* boxes, NDArray const* colors, NDArray* output), FLOAT_TYPES);
}
}
}

View File

@ -27,6 +27,7 @@ namespace ops {
namespace helpers {
void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output);
void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output);
}
}
}

View File

@@ -0,0 +1,34 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author sgazeos@gmail.com
//
#ifndef __IMAGE_DRAW_BOUNDING_BOXES_H_HELPERS__
#define __IMAGE_DRAW_BOUNDING_BOXES_H_HELPERS__
#include <op_boilerplate.h>
#include <NDArray.h>
namespace nd4j {
namespace ops {
namespace helpers {
void drawBoundingBoxesFunctor(nd4j::LaunchContext * context, NDArray* images, NDArray* boxes, NDArray* colors, NDArray* output);
}
}
}
#endif

View File

@@ -37,6 +37,10 @@ namespace nd4j {
return custom(nd4j::scalar::Divide, nd4j::pairwise::Divide, nd4j::broadcast::Divide);
}
BroadcastOpsTuple BroadcastOpsTuple::DivideNoNan() {
return custom(nd4j::scalar::DivideNoNan, nd4j::pairwise::DivideNoNan, nd4j::broadcast::DivideNoNan);
}
BroadcastOpsTuple BroadcastOpsTuple::Multiply() {
return custom(nd4j::scalar::Multiply, nd4j::pairwise::Multiply, nd4j::broadcast::Multiply);
}

View File

@@ -360,6 +360,34 @@ namespace simdOps {
}
};
template <typename X, typename Y, typename Z>
class DivideNoNan {
public:
op_def static Z op(X d1, Y d2) {
if (d2 == (Y)0) return (Z)0;
return static_cast<Z>(d1 / d2);
}
op_def static Z op(X d1, Y d2, Z *params) {
if (d2 == (Y)0) return (Z)0;
return static_cast<Z>(d1 / d2);
}
op_def static Z op(X d1) {
return static_cast<Z>(d1);
}
// op for MetaOps
op_def static Z op(X d1, Y *params) {
if (params[0] == (Y)0) return (Z)0;
return static_cast<Z>(d1 / params[0]);
}
op_def static X startingValue() {
return static_cast<X>(1);
}
};
template <typename X, typename Y, typename Z>
class SafeDivide {
public:

View File

@@ -1194,6 +1194,41 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_1) {
delete res;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, BroadcastDivideTest_2) {
auto x = NDArrayFactory::create<float>('c', {3, 4, 5, 1});
auto y = NDArrayFactory::create<float>('c', {1, 6});
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5, 6});
x.assign(6);
y.assign(2);
exp.assign(3);
nd4j::ops::divide_no_nan div;
auto res = div.execute({&x, &y}, {}, {});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
ASSERT_TRUE(res->at(0)->equalsTo(exp));
delete res;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, BroadcastDivideTest_3) {
auto x = NDArrayFactory::create<float>({6,6,6,6,6});
auto y = NDArrayFactory::create<float>({3,3,0,3,3});
auto exp = NDArrayFactory::create<float>({2, 2, 0, 2, 2});
nd4j::ops::divide_no_nan div;
auto res = div.execute({&x, &y}, {}, {});
ASSERT_EQ(res->status(), ND4J_STATUS_OK);
ASSERT_TRUE(res->at(0)->equalsTo(exp));
delete res;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, BroadcastReverseDivideTest_1) {

View File

@@ -2043,11 +2043,81 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_5) {
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) {
NDArray images = NDArrayFactory::create<float>('c', {2,4,5,3});
NDArray boxes = NDArrayFactory::create<float>('c', {2, 2, 4}, {
0. , 0. , 1. , 1. , 0.1, 0.2, 0.9, 0.8,
0.3, 0.3, 0.7, 0.7, 0.4, 0.4, 0.6, 0.6
});
NDArray colors = NDArrayFactory::create<float>('c', {2, 3}, {201., 202., 203., 127., 128., 129.});
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
NDArray expected = NDArrayFactory::create<float>('c', {2,4,5,3}, {
127., 128., 129., 127., 128., 129., 127., 128., 129., 127., 128., 129., 201., 202., 203.,
127., 128., 129., 19., 20., 21., 22., 23., 24., 127., 128., 129., 201., 202., 203.,
127., 128., 129., 127., 128., 129., 127., 128., 129., 127., 128., 129., 201., 202., 203.,
201., 202., 203., 201. ,202. ,203., 201., 202., 203., 201., 202., 203., 201., 202., 203.,
61., 62., 63., 201., 202., 203., 201., 202., 203., 70., 71., 72., 73., 74., 75.,
76., 77., 78., 127., 128., 129., 127., 128., 129., 85., 86., 87., 88., 89., 90.,
91., 92., 93., 201., 202., 203., 201., 202., 203., 100., 101., 102., 103., 104., 105.,
106., 107., 108., 109., 110., 111., 112., 113., 114., 115., 116., 117., 118., 119., 120.
});
images.linspace(1.);
nd4j::ops::draw_bounding_boxes op;
auto results = op.execute({&images, &boxes, &colors}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
// result->printBuffer("Bounded boxes");
// expected.printBuffer("Bounded expec");
ASSERT_TRUE(expected.isSameShapeStrict(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_2) {
NDArray images = NDArrayFactory::create<float>('c', {1,9,9,1});
NDArray boxes = NDArrayFactory::create<float>('c', {1, 1, 4}, {0.2, 0.2, 0.7, 0.7});
NDArray colors = NDArrayFactory::create<float>('c', {1, 1}, {0.95});
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
NDArray expected = NDArrayFactory::create<float>('c', {1,9,9,1}, {
1.1 , 2.1, 3.1 , 4.1 , 5.1 , 6.1 , 7.1 , 8.1 , 9.1 ,
10.1 , 0.95, 0.95, 0.95, 0.95, 0.95, 16.1 , 17.1 , 18.1 ,
19.1 , 0.95, 21.1, 22.1, 23.1, 0.95, 25.1 , 26.1 , 27.1 ,
28.1 , 0.95, 30.1, 31.1, 32.1, 0.95, 34.1 , 35.1 , 36.1 ,
37.1 , 0.95, 39.1, 40.1, 41.1, 0.95, 43.1 , 44.1 , 45.1 ,
46.1 , 0.95, 0.95, 0.95, 0.95, 0.95, 52.1 , 53.1 , 54.1 ,
55.1 , 56.1, 57.1 , 58.1 , 59.1 , 60.1 , 61.1 , 62.1 , 63.1 ,
64.1 , 65.1, 66.1 , 67.1 , 68.1 , 69.1 , 70.1 , 71.1 , 72.1 ,
73.1 , 74.1, 75.1 , 76.1 , 77.1 , 78.1 , 79.1 , 80.1 , 81.1 });
images.linspace(1.1);
nd4j::ops::draw_bounding_boxes op;
auto results = op.execute({&images, &boxes, &colors}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
// result->syncToHost();
// result->printBuffer("Bounded boxes 2");
// expected.printBuffer("Bounded expec 2");
ASSERT_TRUE(expected.isSameShapeStrict(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, nd4j::DataType::FLOAT32);
NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.251953f, 0.0f, 0.0f}, nd4j::DataType::FLOAT32);
NDArray exp('c', {2,3}, {-63.75, -63.75, -63.75, -63.5, 0., 0.}, nd4j::DataType::FLOAT32);
NDArray min('c', {}, {-63.65f}, nd4j::DataType::FLOAT32);
NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32);
@@ -2057,7 +2127,8 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
// result->printIndexedBuffer("Quantized");
// result->printBuffer("Quantized");
// exp.printBuffer("Expected");
ASSERT_TRUE(exp.isSameShapeStrict(result));
ASSERT_TRUE(exp.equalsTo(result));
@@ -2067,7 +2138,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_2) {
NDArray x = NDArrayFactory::create<double>('c', {2,3}, {-63.80, -63.75, -63.4, -63.5, 0.0, 0.1});
NDArray exp = NDArrayFactory::create<double>('c', {2,3}, {-63.75, -63.75, -63.251953, -63.251953, 0.0, 0.0});
NDArray exp = NDArrayFactory::create<double>('c', {2,3}, {-63.75, -63.75, -63.5 , -63.5 , 0. , 0. });
NDArray min = NDArrayFactory::create<double>(-63.65);
NDArray max = NDArrayFactory::create<double>(0.1);
@@ -2084,6 +2155,119 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_2) {
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_3) {
NDArray x = NDArrayFactory::create<double>('c', {1,2,3,1}, {-63.80, -63.75, -63.4, -63.5, 0.0, 0.1});
NDArray exp = NDArrayFactory::create<double>('c', {1,2,3,1}, {-63.75, -63.75, -63.5 , -63.5 , 0. , 0. });
NDArray min = NDArrayFactory::create<double>('c', {1},{-63.65});
NDArray max = NDArrayFactory::create<double>('c', {1}, {0.1});
nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
// result->printIndexedBuffer("Quantized2");
ASSERT_TRUE(exp.isSameShapeStrict(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) {
NDArray x = NDArrayFactory::create<float>('c', {2,4,5,3});
NDArray exp = NDArrayFactory::create<float>('c', {2,4,5,3},{
1.0588236, 1.9607843, 3.019608, 4.0588236, 5.098039, 6.039216, 7.0588236, 8.039216, 9.058824,
10.058824, 10.980392, 12.078432, 13.058824, 13.921569, 15.09804, 16.058825, 17.058825, 18.117647,
19.058825, 20., 21.137257, 22.058825, 22.941177, 23.882355, 25.058825, 26.078432, 26.901962,
28.058825, 29.019608, 29.92157, 31.058825, 31.960785, 32.941177, 34.058823, 35.09804, 35.960785,
37.058823, 38.039215, 38.980392, 40.058823, 40.980392, 42.000004, 43.058826, 43.92157, 45.01961,
45., 47.058823, 48.03922, 45., 50., 51.058826, 45., 50., 54.078434,
45., 50., 57.09804, 45., 50., 60.11765, 45., 50., 62.862747,
45., 50., 65.882355, 45., 50., 68.90196, 45., 50., 70.,
45., 50., 70., 45., 50., 70., 45., 50., 70.,
45., 50., 70., 45., 50., 70., 45., 50., 70.,
45., 50., 70., 45., 50., 70., 45., 50., 70.,
45., 50., 70., 45., 50., 70., 45., 50., 70.,
45., 50., 70., 45., 50., 70., 45., 50., 70.,
45., 50., 70.});
NDArray min = NDArrayFactory::create<float>({20., 20., 20.});
NDArray max = NDArrayFactory::create<float>({65., 70., 90.});
x.linspace(1.);
nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
// result->printBuffer("Quantized per channels 4");
// exp.printBuffer("Quantized per channest E");
// auto diff = *result - exp;
// diff.printIndexedBuffer("Difference");
ASSERT_TRUE(exp.isSameShapeStrict(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) {
NDArray x = NDArrayFactory::create<float>('c', {2, 3, 5, 4});
NDArray exp = NDArrayFactory::create<float>('c', {2, 3, 5, 4},{
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-16. , -15.058824 , -13.960785 , -13.0196085 ,
-11.92157 , -10.980392 , -10.039217 , -8.941177 ,
-8.000001 , -7.0588236 , -5.960785 , -5.0196085 ,
-3.9215698 , -2.9803925 , -2.039217 , -0.94117737,
0. , 0.94117737, 2.039215 , 2.9803925 ,
4.07843 , 5.0196075 , 5.960783 , 7.0588226 ,
8. , 8.941177 , 10.039215 , 10.980392 ,
12.07843 , 13.019608 , 13.960783 , 15.058823 ,
16. , 16.941177 , 18.039217 , 18.980392 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823
});
NDArray min = NDArrayFactory::create<float>({-20., -19., -18., -17});
NDArray max = NDArrayFactory::create<float>({20., 21., 22., 23});
x.linspace(-60.);
nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
// result->printBuffer("Quantized per channels 5");
// exp.printBuffer("Quantized per channest E");
// auto diff = *result - exp;
// diff.printIndexedBuffer("Difference");
ASSERT_TRUE(exp.isSameShapeStrict(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test1) {

View File

@@ -157,6 +157,104 @@ TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) {
delete result;
}
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) {
auto x = NDArrayFactory::create<double>('c', {4,4,3});
auto e = NDArrayFactory::create<double>('c', {4,4,3}, {
-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5,
26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5,
50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5
});
x.linspace(1.);
nd4j::ops::adjust_contrast op;
auto result = op.execute({&x}, {2.}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0);
// out->printIndexedBuffer("Adjusted Constrast");
ASSERT_TRUE(e.equalsTo(out));
delete result;
}
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) {
auto x = NDArrayFactory::create<float>('c', {1, 4,4,3});
auto e = NDArrayFactory::create<float>('c', {1, 4,4,3}, {
-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5,
26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5,
50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5
});
x.linspace(1.);
nd4j::ops::adjust_contrast op;
auto result = op.execute({&x}, {2.}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0);
// out->printIndexedBuffer("Adjusted Constrast");
ASSERT_TRUE(e.equalsTo(out));
delete result;
}
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_3) {
auto x = NDArrayFactory::create<float>('c', {1, 4,4,3});
auto e = NDArrayFactory::create<float>('c', {1, 4,4,3}, {
-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5,
26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5,
50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5
});
x.linspace(1.);
nd4j::ops::adjust_contrast_v2 op;
auto result = op.execute({&x}, {2.}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0);
// out->printIndexedBuffer("Adjusted Constrast");
ASSERT_TRUE(e.equalsTo(out));
delete result;
}
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) {
auto x = NDArrayFactory::create<double>('c', {4, 4, 3});
auto e = NDArrayFactory::create<double>('c', {4, 4, 3}, {
-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5,
26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5,
50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5
});
x.linspace(1.);
nd4j::ops::adjust_contrast_v2 op;
auto result = op.execute({&x}, {2.}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0);
// out->printIndexedBuffer("Adjusted Constrast");
ASSERT_TRUE(e.equalsTo(out));
delete result;
}
TEST_F(DeclarableOpsTests15, Test_BitCast_1) {
auto x = NDArrayFactory::create<float>('c', {2, 2, 2});
auto e = NDArrayFactory::create<double>('c', {2, 2}, {2., 512., 8192., 131072.032 });
x.linspace(1.);
nd4j::ops::bitcast op;
auto result = op.execute({&x}, {}, {nd4j::DataType::DOUBLE}, {});
ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0);
// out->printIndexedBuffer("Casted result");
ASSERT_TRUE(e.equalsTo(out));
delete result;
}
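The expected doubles above can be reproduced by hand: a bitcast to a wider type fuses two adjacent 32-bit floats into one 64-bit word, low word first (which is what the expected values imply). A Scala sketch of the reinterpretation:
def fuse(lo32: Float, hi32: Float): Double = {
  val lo = java.lang.Float.floatToIntBits(lo32).toLong & 0xFFFFFFFFL
  val hi = java.lang.Float.floatToIntBits(hi32).toLong
  java.lang.Double.longBitsToDouble((hi << 32) | lo)
}
fuse(1f, 2f) // ~= 2.0
fuse(3f, 4f) // ~= 512.0
fuse(5f, 6f) // ~= 8192.0
fuse(7f, 8f) // ~= 131072.032, the .032 coming from 7f's bits landing in the mantissa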
TEST_F(DeclarableOpsTests15, Test_BitCast_2) {
auto x = NDArrayFactory::create<float>('c', {2, 4});
auto e = NDArrayFactory::create<float16>('c', {2, 4, 2}, {0, 1.875, 0, 2., 0, 2.125, 0, 2.25,
0, 2.312, 0, 2.375, 0, 2.438, 0., 2.5});
x.linspace(1.);
nd4j::ops::bitcast op;
auto result = op.execute({&x}, {}, {nd4j::DataType::HALF}, {});
ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0);
out->printIndexedBuffer("Casted result");
ASSERT_TRUE(e.equalsTo(out));
delete result;
}
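The narrowing direction explains the interleaved zeros in the expected array: each 32-bit float splits into two 16-bit halves, low half first, and small integers' float bit patterns have all-zero low halves. Sketch:
def split(f: Float): (Int, Int) = {
  val bits = java.lang.Float.floatToIntBits(f)
  (bits & 0xFFFF, bits >>> 16)  // (low half, high half)
}
split(1f) // (0x0000, 0x3F80); 0x3F80 read as float16 is 1.875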
TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) {
auto in = NDArrayFactory::create<float>('c', {4, 8, 64, 64});
auto w = NDArrayFactory::create<float>('c', {2, 2, 8, 2});

View File

@ -2387,6 +2387,25 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) {
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
auto threshold = NDArrayFactory::create<double>(2.0);
auto exp = NDArrayFactory::create<uint8_t>('c', {2, 3, 4}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1});
nd4j::ops::compare_and_bitpack op;
auto result = op.execute({&x, &threshold}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0);
// output->printIndexedBuffer("Packed to uint8");
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
delete result;
}
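As exercised here, the op emits an elementwise x > threshold mask as uint8 values; note the output shape {2, 3, 4} is unchanged, whereas TensorFlow's op of the same name additionally packs eight results into each output byte. A sketch of the expected mask:
val xs   = (-12 to 11).map(_.toDouble)        // the test input
val mask = xs.map(v => if (v > 2.0) 1 else 0) // threshold = 2.0
// 15 zeros (values -12..2) followed by 9 ones (values 3..11)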
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) {

View File

@ -276,6 +276,10 @@ object Implicits {
override def hasNegative: Boolean = false
}
case object --- extends IndexRange {
override def hasNegative: Boolean = false
}
implicit class IntRange(val underlying: Int) extends IndexNumberRange {
protected[nd4s] override def asRange(max: => Int): DRange =
DRange(underlying, underlying, true, 1, max)
@ -307,6 +311,10 @@ object Implicits {
def -> = IntRangeFrom(underlying)
}
implicit class IntRangeFromGen1(val underlying: Int) extends AnyVal {
def :: = IntRangeFromReverse(underlying)
}
implicit class IndexRangeWrapper(val underlying: Range) extends IndexNumberRange {
protected[nd4s] override def asRange(max: => Int): DRange =
DRange.from(underlying, max)
@ -377,17 +385,27 @@ object IndexNumberRange {
val endExclusive = if (endR >= 0) endR + diff else max + endR + diff
(start, endExclusive)
}
NDArrayIndex.interval(start, step, end, false)
}
}
sealed trait IndexRange {
/*sealed*/
trait IndexRange {
def hasNegative: Boolean
}
case class IntRangeFrom(underlying: Int) extends IndexRange {
def apply[T](a: T): (Int, T) = (underlying, a)
def apply[T](a: T): (Int, T) =
(underlying, a)
override def toString: String = s"$underlying->"
override def hasNegative: Boolean = false
}
case class IntRangeFromReverse(underlying: Int) extends IndexRange {
def apply[T](a: T): (T, Int) =
(a, underlying)
override def toString: String = s"$underlying->"

View File

@ -23,6 +23,12 @@ import org.slf4j.LoggerFactory
import _root_.scala.annotation.tailrec
package object ops {
case object :: extends IndexRange {
override def hasNegative: Boolean = false
}
}
trait SliceableNDArray[A <: INDArray] {
lazy val log = LoggerFactory.getLogger(classOf[SliceableNDArray[A]])
val underlying: A
@ -68,6 +74,8 @@ trait SliceableNDArray[A <: INDArray] {
@tailrec
def modifyTargetIndices(input: List[IndexRange], i: Int, acc: List[DRange]): List[DRange] = input match {
case ops.:: :: t =>
modifyTargetIndices(t, i + 1, DRange(0, originalShape(i), 1) :: acc)
case -> :: t =>
modifyTargetIndices(t, i + 1, DRange(0, originalShape(i), 1) :: acc)
case ---> :: t =>
@ -137,6 +145,9 @@ trait SliceableNDArray[A <: INDArray] {
case ---> :: t =>
val ellipsised = List.fill(originalShape.length - i - t.size)(->)
modifyTargetIndices(ellipsised ::: t, i, acc)
case --- :: t =>
val ellipsised = List.fill(originalShape.length - i - t.size)(->)
modifyTargetIndices(ellipsised ::: t, i, acc)
case IntRangeFrom(from: Int) :: t =>
val max = originalShape(i)
modifyTargetIndices(t, i + 1, IndexNumberRange.toNDArrayIndex(from, max, false, 1, max) :: acc)
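The new --- case mirrors --->: both splice in enough full ranges to cover the dimensions not matched by explicit indices, with the fill count given by rank - i - t.size. A one-line sketch of that rule:
def ellipsisFill(rank: Int, i: Int, tailSize: Int): Int = rank - i - tailSize
ellipsisFill(4, 1, 1) // == 2: '---' in (0, ---, 1) expands to two full ranges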

View File

@ -44,6 +44,9 @@ class SameDiffWrapper {
def bind(name: String, dataType: DataType, shape: Array[Long]): SDVariable =
sd.`var`(name, dataType, shape: _*)
def bind(data: INDArray): SDVariable =
sd.`var`("", data)
def bind(name: String, dataType: DataType, shape: Array[Int]): SDVariable =
sd.`var`(name, dataType, shape: _*)
@ -51,18 +54,39 @@ class SameDiffWrapper {
sd.placeHolder(name, dataType, shape: _*)
}
case class SDIndexWrapper(end: Long) {
def ::(start: Long): SDIndex =
SDIndex.interval(start, end)
}
case class SDIndexWrapper1(end: Int) {
def ::(start: Int): SDIndex =
SDIndex.interval(start, end)
}
object --- extends SDIndex {
val thisIndex: SDIndex = SDIndex.all()
}
class SDVariableWrapper {
var thisVariable: SDVariable = null
var isScalar: Boolean = false
val --- : SDIndex = SDIndex.all()
def this(variable: SDVariable) {
this()
thisVariable = variable
}
// Indexing
def apply(index: Long): SDVariable = thisVariable.get(SDIndex.point(index))
def apply(index: SDIndex*): SDVariable = thisVariable.get(index: _*)
// Arithmetic
def add(other: Double): Unit = thisVariable.add(other)
def *(other: SDVariable): SDVariable =

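The interval wrappers above lean on a Scala parsing rule: an operator whose name ends in ':' is right-associative, so the receiver of :: is the right-hand operand, and the wrapper is therefore built from the interval's end. A self-contained sketch of the desugaring (simplified names; the real method returns SDIndex.interval):
import scala.language.implicitConversions
final case class IntervalEnd(end: Long) {
  def ::(start: Long): (Long, Long) = (start, end) // stand-in for SDIndex.interval
}
implicit def toIntervalEnd(end: Long): IntervalEnd = IntervalEnd(end)
val iv = 0L :: 2L // desugars to IntervalEnd(2L).::(0L), i.e. (0, 2)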
View File

@ -15,9 +15,9 @@
******************************************************************************/
package org.nd4s.samediff.implicits
import org.nd4j.autodiff.samediff.{ SDVariable, SameDiff }
import org.nd4j.autodiff.samediff.{ SDIndex, SDVariable, SameDiff }
import org.nd4j.linalg.factory.Nd4j
import org.nd4s.samediff.{ SDVariableWrapper, SameDiffWrapper }
import org.nd4s.samediff.{ SDIndexWrapper, SDVariableWrapper, SameDiffWrapper }
object Implicits {
implicit def SameDiffToWrapper(sd: SameDiff): SameDiffWrapper =
@ -43,4 +43,20 @@ object Implicits {
result.isScalar = true
result
}
implicit def RangeToWrapper(start: Long): SDIndexWrapper = {
val result = new SDIndexWrapper(start)
result
}
implicit def LongToPoint(x: Long): SDIndex =
SDIndex.point(x)
implicit def IntRangeToWrapper(start: Int): SDIndexWrapper = {
val result = new SDIndexWrapper(start)
result
}
implicit def IntToPoint(x: Int): SDIndex =
SDIndex.point(x)
}
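With these conversions in scope, a bare integer index resolves through IntToPoint or LongToPoint, while the receiver of :: resolves through the wrapper conversions, so points and intervals coexist without ambiguity. Assuming the implicits above are imported:
import org.nd4j.autodiff.samediff.SDIndex
import org.nd4s.samediff.implicits.Implicits._
val p: SDIndex = 3      // IntToPoint, i.e. SDIndex.point(3)
val r: SDIndex = 0 :: 2 // IntRangeToWrapper wraps 2, then .::(0) gives interval(0, 2)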

View File

@ -48,6 +48,42 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
assert(extracted == expected)
}
it should "be able to extract a part of 2d matrix with alternative syntax" in {
val ndArray =
Array(
Array(1, 2, 3),
Array(4, 5, 6),
Array(7, 8, 9)
).mkNDArray(ordering)
val extracted = ndArray(1 :: 3, 0 :: 2)
val expected =
Array(
Array(4, 5),
Array(7, 8)
).mkNDArray(ordering)
assert(extracted == expected)
}
it should "be able to extract a part of 2d matrix with mixed syntax" in {
val ndArray =
Array(
Array(1, 2, 3),
Array(4, 5, 6),
Array(7, 8, 9)
).mkNDArray(ordering)
val extracted = ndArray(1 -> 3, 0 :: 2)
val expected =
Array(
Array(4, 5),
Array(7, 8)
).mkNDArray(ordering)
assert(extracted == expected)
}
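Both spellings denote the same half-open interval [1, 3), so the two tests extract the same 2x2 block. A quick equivalence sketch, reusing the mkNDArray helper from these tests:
val m = Array(
  Array(1, 2, 3),
  Array(4, 5, 6),
  Array(7, 8, 9)
).mkNDArray(ordering)
assert(m(1 -> 3, 0 :: 2) == m(1 :: 3, 0 :: 2)) // rows 1..2, cols 0..1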
it should "be able to extract a part of 2d matrix with double data" in {
val ndArray = (5 to 8).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C)
@ -171,6 +207,9 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
val ellipsised = ndArray(--->)
assert(ellipsised == ndArray)
val ellipsised1 = ndArray(---)
assert(ellipsised1 == ndArray)
}
it should "accept partially ellipsis indices" in {

View File

@ -20,6 +20,7 @@ import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
import org.nd4s.Implicits._
import org.nd4s.NDOrdering
import org.nd4s.samediff.implicits.Implicits._
import org.scalatest.{ FlatSpec, Matchers }
@ -205,9 +206,41 @@ class MathTest extends FlatSpec with Matchers {
implicit val sd = SameDiff.create
val arr = Nd4j.linspace(1, 100, 100).reshape('c', 10L, 10L)
val x = sd.`var`(arr)
val x = sd.bind(arr)
val y = new SDVariableWrapper(x)
x.get(SDIndex.point(0)).getArr shouldBe y(0).getArr
}
"SDVariable " should "be indexable in 2d" in {
implicit val sd = SameDiff.create
val arr = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 9).reshape(3, 3)
val x = sd.bind(arr)
x(0, ---).eval shouldBe x(SDIndex.point(0), SDIndex.all()).eval
val slice1 = x.get(SDIndex.interval(0, 2), SDIndex.all()).eval
val slice2 = x(0 :: 2, ---).eval
slice1 shouldBe slice2
}
"SDVariable " should "be indexable in 3d" in {
implicit val sd = SameDiff.create
val arr = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 18).reshape(3, 3, 2)
val x = sd.bind(arr)
x.get(SDIndex.all(), SDIndex.all(), SDIndex.all()).eval shouldBe x(---, ---, ---).eval
x.get(SDIndex.point(0), SDIndex.all(), SDIndex.all()).eval shouldBe x(0, ---, ---).eval
x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.all()).eval shouldBe x(0, 0, ---).eval
x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0, 0, 0).eval
x.get(SDIndex.interval(0, 2), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 :: 2, 0, 0).eval
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.interval(0, 2)).eval shouldBe x(0 :: 2,
0 :: 1,
0 :: 2).eval
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.all()).eval shouldBe x(0 :: 2, 0 :: 1, ---).eval
}
}
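Taken together, the additions give SDVariable a compact indexing DSL; a closing sketch, assuming the nd4s samediff implicits are in scope as in the tests above:
import org.nd4j.autodiff.samediff.SameDiff
import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.linalg.factory.Nd4j
import org.nd4s.samediff.implicits.Implicits._
val sd  = SameDiff.create
val arr = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 18).reshape(3, 3, 2)
val x   = sd.bind(arr)
val slice = x(0, ---, ---)    // point plus ellipsis
val block = x(0 :: 2, 0, ---) // interval, point, ellipsis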