Merge remote-tracking branch 'konduit/master'
commit 2d750b69e5
@@ -1422,6 +1422,7 @@ namespace nd4j {
        template <typename T>
        void p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const T value);
+       void p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, NDArray const& value);

        template <typename T>
@@ -4071,6 +4071,24 @@ void NDArray::p(const Nd4jLong i, const NDArray& scalar) {
    NDArray::registerPrimaryUse({this}, {&scalar});
}

+////////////////////////////////////////////////////////////////////////
+void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const NDArray& scalar) {
+
+    if(!scalar.isScalar())
+        throw std::invalid_argument("NDArray::p method: input array must be scalar!");
+    if (i >= _length)
+        throw std::invalid_argument("NDArray::p(i,j,k,l, NDArray_scalar): input index is out of array length!");
+
+    // void *p = reinterpret_cast<void *>(scalar.getBuffer());
+    Nd4jLong coords[4] = {i, j, k, l};
+    auto xOffset = shape::getOffset(getShapeInfo(), coords);
+
+    NDArray::preparePrimaryUse({this}, {&scalar}, true);
+    // BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
+    BUILD_SINGLE_SELECTOR(scalar.dataType(), templatedSet, (this->getBuffer(), xOffset, scalar.dataType(), scalar.getBuffer()), LIBND4J_TYPES);
+    NDArray::registerPrimaryUse({this}, {&scalar});
+}
+
//////////////////////////////////////////////////////////////////////////
void NDArray::addRowVector(const NDArray *row, NDArray *target) const {
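For context, a minimal usage sketch of the new four-index setter (hypothetical shapes; assumes the NDArrayFactory API used in the tests later in this diff):

    // Sketch, not part of the commit: set one element of a 4D array
    // from a 0-D scalar NDArray via the new p(i, j, k, l, NDArray) overload.
    auto arr    = NDArrayFactory::create<float>('c', {2, 3, 4, 5});
    auto scalar = NDArrayFactory::create<float>(42.f);  // 0-D scalar array
    arr.p(1, 2, 3, 4, scalar);                          // writes 42 at [1,2,3,4]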
@@ -77,7 +77,8 @@
        (27, LogicalOr) ,\
        (28, LogicalXor) ,\
        (29, LogicalNot) ,\
-       (30, LogicalAnd)
+       (30, LogicalAnd), \
+       (31, DivideNoNan)

// these ops return same data type as input
#define TRANSFORM_SAME_OPS \
@@ -243,8 +244,8 @@
        (42, LstmClip), \
        (43, TruncateMod) ,\
        (44, SquaredReverseSubtract) ,\
-       (45, ReversePow)
+       (45, ReversePow), \
+       (46, DivideNoNan)
@@ -378,7 +379,8 @@
        (34, AMaxPairwise), \
        (35, AMinPairwise) ,\
        (36, TruncateMod), \
-       (37, ReplaceNans)
+       (37, ReplaceNans), \
+       (38, DivideNoNan)
@@ -46,6 +46,7 @@ namespace nd4j {
        static BroadcastOpsTuple Add();
        static BroadcastOpsTuple Assign();
        static BroadcastOpsTuple Divide();
+       static BroadcastOpsTuple DivideNoNan();
        static BroadcastOpsTuple Multiply();
        static BroadcastOpsTuple Subtract();
    };
@@ -0,0 +1,57 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author George A. Shulinok <sgazeos@gmail.com>
//

#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_divide_no_nan)

#include <ops/declarable/generic/helpers/BroadcastHelper.h>
#include <ops/declarable/CustomOperations.h>

namespace nd4j {
    namespace ops {
        BROADCASTABLE_OP_IMPL(divide_no_nan, 0, 0) {
            auto x = INPUT_VARIABLE(0);
            auto y = INPUT_VARIABLE(1);
            auto z = OUTPUT_VARIABLE(0);

            BROADCAST_CHECK_EMPTY(x, y, z);

            REQUIRE_TRUE(!y->isB(), 0, "DIVIDE_NO_NAN OP: you can't divide by a bool array!");
            auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::DivideNoNan(), x, y, z);
            if (tZ == nullptr)
                return ND4J_STATUS_KERNEL_FAILURE;
            else if (tZ != z) {
                OVERWRITE_RESULT(tZ);
            }

            return Status::OK();
        }
        DECLARE_SYN(Div, divide);

        DECLARE_TYPES(divide_no_nan) {
            getOpDescriptor()
                ->setAllowedInputTypes(0, DataType::ANY)
                ->setAllowedInputTypes(1, DataType::ANY)
                ->setAllowedOutputTypes(0, DataType::INHERIT);
        }
    }
}

#endif
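In effect the op behaves like elementwise division, but yields 0 wherever the divisor is 0. A minimal sketch, mirroring the BroadcastDivideTest_3 test later in this diff:

    nd4j::ops::divide_no_nan op;
    auto x = NDArrayFactory::create<float>({6.f, 6.f, 6.f});
    auto y = NDArrayFactory::create<float>({3.f, 0.f, 2.f});
    auto res = op.execute({&x, &y}, {}, {});
    // res->at(0) holds {2.f, 0.f, 3.f}: no inf/nan where y == 0
    delete res;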
@@ -0,0 +1,92 @@
/* [Apache 2.0 license header, identical to the one above] */

//
// @author George A. Shulinok <sgazeos@gmail.com>
//

#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_bitcast)

#include <array/DataTypeUtils.h>
#include <ops/declarable/CustomOperations.h>

namespace nd4j {
    namespace ops {
        CUSTOM_OP_IMPL(bitcast, 1, 1, false, 0, 1) {
            auto input = INPUT_VARIABLE(0);
            auto output = OUTPUT_VARIABLE(0);
            // when empty - nothing to do
            if (input->isEmpty()) {
                REQUIRE_TRUE(output->isEmpty(), 0, "BITCAST: If input is empty, output array must also be empty.");
                return Status::OK();
            }
            // buffers for both input and output should be equal
            DataBuffer buf(input->buffer(), input->specialBuffer(), input->lengthOf() * input->sizeOfT(), input->dataType());
            *(output->dataBuffer()) = buf;

            return Status::OK();
        }
        DECLARE_SYN(BitCast, bitcast);

        DECLARE_SHAPE_FN(bitcast) {
            auto inShape = inputShape->at(0);
            auto inputRank = shape::rank(inShape);
            auto it = INT_ARG(0);
            DataType newType = DataTypeUtils::fromInt(it);
            DataType oldType = ArrayOptions::dataType(inShape);
            // correct output shape to conform with output data type
            auto inputSize = DataTypeUtils::sizeOf(oldType);
            auto outputSize = DataTypeUtils::sizeOf(newType);

            if (shape::length(inShape) == 0)
                return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, newType)));

            if (inputSize == outputSize) {
                // only the type should be changed
                return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, newType)));
            }
            else if (inputSize > outputSize) {
                // rank of output increases by 1, with inputSize / outputSize as the last dimension
                std::vector<Nd4jLong> shapeOf(inputRank + 1);
                int i;
                for (i = 0; i < inputRank; ++i) {
                    shapeOf[i] = inShape[i + 1];
                }
                shapeOf[i] = inputSize / outputSize;
                auto outputShape = ConstantShapeHelper::getInstance()->createShapeInfo(newType, shape::order(inShape), shapeOf);
                return SHAPELIST(outputShape);
            }
            REQUIRE_TRUE(shape::sizeAt(inShape, -1) == outputSize / inputSize, 0, "BITCAST: %llu < %llu. So the last dimension should be %llu, but %lld was given.", inputSize, outputSize, outputSize / inputSize, shape::sizeAt(inShape, -1));
            std::vector<Nd4jLong> shapeOf(inputRank - 1);

            for (auto i = 0; i < shapeOf.size(); ++i) {
                shapeOf[i] = inShape[i + 1];
            }

            auto outputShape = ConstantShapeHelper::getInstance()->createShapeInfo(newType, shape::order(inShape), shapeOf);
            return SHAPELIST(outputShape);
        }

        DECLARE_TYPES(bitcast) {
            getOpDescriptor()
                ->setAllowedInputTypes(nd4j::DataType::ANY)
                ->setAllowedOutputTypes(nd4j::DataType::ANY);
        }
    }
}

#endif
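As a worked example of the shape rule above (a sketch, not commit code): bitcasting a larger element type to a smaller one appends a trailing dimension of inputSize / outputSize, and the reverse direction consumes a trailing dimension of outputSize / inputSize:

    // float32 (4 bytes) -> uint8 (1 byte): rank grows by one, last dim = 4
    auto x = NDArrayFactory::create<float>('c', {2, 3});
    nd4j::ops::bitcast op;
    auto res = op.execute({&x}, {}, {(Nd4jLong) nd4j::DataType::UINT8});
    // res->at(0) has shape {2, 3, 4} and dtype uint8
    delete res;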
@@ -0,0 +1,98 @@
/* [Apache 2.0 license header, identical to the one above] */

//
// @author George A. Shulinok <sgazeos@gmail.com>
//

#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_adjust_contrast)

#include <ops/declarable/headers/parity_ops.h>
#include <NDArrayFactory.h>

namespace nd4j {
    namespace ops {

        CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, 1, 0) {

            auto input = INPUT_VARIABLE(0);
            auto output = OUTPUT_VARIABLE(0);

            const double factor = T_ARG(0);

            REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST: op expects rank of input array to be >= 3, but got %i instead", input->rankOf());
            REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1));
            // compute mean before
            // fill up axes vector first
            std::vector<int> axes(input->rankOf() - 1);
            for (auto i = 0; i < axes.size(); ++i)
                axes[i] = i;
            // mean as reduction for last dimension set
            auto mean = input->reduceAlongDims(reduce::Mean, axes);

            NDArray factorT(output->dataType(), block.launchContext()); // = NDArrayFactory::create(factor, block.launchContext());
            factorT.p(0, factor);
            // this is contrast calculation
            *output = (*input - mean) * factorT + mean;

            return Status::OK();
        }

        DECLARE_TYPES(adjust_contrast) {
            getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)
                ->setAllowedOutputTypes({ALL_FLOATS})
                ->setSameMode(true);
        }


        CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, 1, 0) {

            auto input = INPUT_VARIABLE(0);
            auto output = OUTPUT_VARIABLE(0);

            const double factor = T_ARG(0);

            REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST: op expects rank of input array to be >= 3, but got %i instead", input->rankOf());
            REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1));

            // compute mean before
            std::vector<int> axes(input->rankOf() - 1);
            for (auto i = 0; i < axes.size(); ++i)
                axes[i] = i;

            // mean as reduction for last dimension set
            auto mean = input->reduceAlongDims(reduce::Mean, axes);

            // result as (x - mean) * factor + mean
            std::unique_ptr<NDArray> temp(input->dup());
            input->applyTrueBroadcast(BroadcastOpsTuple::Subtract(), &mean, temp.get());
            temp->applyScalar(scalar::Multiply, factor);
            temp->applyTrueBroadcast(BroadcastOpsTuple::Add(), &mean, output);

            return Status::OK();
        }

        DECLARE_TYPES(adjust_contrast_v2) {
            getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)
                ->setAllowedOutputTypes({ALL_FLOATS})
                ->setSameMode(true);
        }

    }
}

#endif
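The transform both ops implement is z = (x - mean) * factor + mean, with the mean reduced per channel over all leading dimensions. A small worked sketch (illustrative values only, not commit code):

    // For one channel with values {1, 2, 3}: mean = 2, factor = 2
    //   z = (x - 2) * 2 + 2  ->  {0, 2, 4}
    // i.e. values move away from the channel mean by the given factor.
    auto image = NDArrayFactory::create<float>('c', {1, 3, 3, 3});  // last dim = 3 channels
    image.linspace(1.f);
    nd4j::ops::adjust_contrast_v2 op;
    auto res = op.execute({&image}, {2.0}, {});  // T arg 0 is the contrast factor
    delete res;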
@@ -0,0 +1,60 @@
/* [Apache 2.0 license header, identical to the one above] */

//
// @author sgazeos@gmail.com
//

#include <ops/declarable/generic/helpers/BroadcastHelper.h>
#include <ops/declarable/headers/parity_ops.h>
#include <ops/declarable/headers/datatypes.h>
#include <NDArrayFactory.h>

namespace nd4j {
    namespace ops {
        CUSTOM_OP_IMPL(compare_and_bitpack, 2, 1, false, 0, 0) {
            auto x = INPUT_VARIABLE(0);
            auto y = INPUT_VARIABLE(1);
            auto z = OUTPUT_VARIABLE(0);
            auto z0 = NDArrayFactory::create<bool>(x->ordering(), x->getShapeAsVector());
            BROADCAST_CHECK_EMPTY(x, y, (&z0));

            auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0);
            bitcast res;
            auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, false);
            if (tZ != &z0) {
                delete tZ;
            }

            return status;
        }

        DECLARE_TYPES(compare_and_bitpack) {
            getOpDescriptor()
                ->setAllowedInputTypes(0, DataType::ANY)
                ->setAllowedInputTypes(1, DataType::ANY)
                ->setAllowedOutputTypes(0, DataType::UINT8);
        }

        DECLARE_SHAPE_FN(compare_and_bitpack) {
            auto inShape = inputShape->at(0);
            DataType newType = DataType::UINT8;

            return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, newType)));
        }

    }
}
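As implemented here, the op computes the elementwise comparison x > threshold and stores the boolean result as uint8 with the same shape as x (by running the bitcast op over the bool intermediate). A minimal sketch, not commit code:

    auto x = NDArrayFactory::create<float>({0.1f, 0.5f, 0.9f});
    auto threshold = NDArrayFactory::create<float>(0.5f);
    nd4j::ops::compare_and_bitpack op;
    auto res = op.execute({&x, &threshold}, {}, {});
    // res->at(0) is uint8 {0, 0, 1}
    delete res;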
@@ -0,0 +1,49 @@
/* [Apache 2.0 license header, identical to the one above] */

//
// @author George A. Shulinok <sgazeos@gmail.com>
//

#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_draw_bounding_boxes)

#include <ops/declarable/headers/parity_ops.h>
#include <ops/declarable/helpers/image_draw_bounding_boxes.h>
namespace nd4j {
    namespace ops {
        OP_IMPL(draw_bounding_boxes, 3, 1, true) {

            auto images = INPUT_VARIABLE(0);
            auto boxes = INPUT_VARIABLE(1);
            auto colors = INPUT_VARIABLE(2);
            auto output = OUTPUT_VARIABLE(0);

            helpers::drawBoundingBoxesFunctor(block.launchContext(), images, boxes, colors, output);
            return ND4J_STATUS_OK;
        }

        DECLARE_TYPES(draw_bounding_boxes) {
            getOpDescriptor()
                ->setAllowedInputTypes(0, {HALF, FLOAT32}) // TF allows HALF and FLOAT32 only
                ->setAllowedInputTypes(1, {FLOAT32})       // as TF
                ->setAllowedInputTypes(2, {FLOAT32})       // as TF
                ->setAllowedOutputTypes({HALF, FLOAT32});  // TF allows HALF and FLOAT32 only
        }
    }
}

#endif
@@ -0,0 +1,71 @@
/* [Apache 2.0 license header, identical to the one above] */

//
// @author George Shulinok <sgazeos@gmail.com>, created on 08.10.2019
//

#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars_per_channel)

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/fake_quantization.h>
namespace nd4j {
    namespace ops {
        CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars_per_channel, 1, 1, true, 0, 0) {

            auto x = INPUT_VARIABLE(0);
            auto min = INPUT_VARIABLE(1);
            auto max = INPUT_VARIABLE(2);

            REQUIRE_TRUE(block.width() == 3 || block.getTArguments()->size() == 2, 0, "fake_quant_with_min_max_vars_per_channel: No minimum/maximum values provided by either input arrays or TArgs");
            auto depth = x->sizeAt(-1);
            REQUIRE_TRUE(min->rankOf() == 1 && max->rankOf() == 1 && min->lengthOf() == max->lengthOf(), 0,
                         "fake_quant_with_min_max_vars_per_channel: Min and Max should be 1D tensors with the same length");
            REQUIRE_TRUE(depth == min->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Min length should be"
                         " %lld, but %lld was given.", depth, min->lengthOf());

            REQUIRE_TRUE(depth == max->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Max length should be"
                         " %lld, but %lld was given.", depth, max->lengthOf());

            auto output = OUTPUT_VARIABLE(0);
            int numBits = 8;
            if (block.getIArguments() && block.getIArguments()->size())
                numBits = INT_ARG(0);
            bool narrowed = false;
            //INT_ARG(1);
            if (block.getIArguments()->size() == 2) {
                numBits = INT_ARG(0);
                narrowed = INT_ARG(1);
                REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars_per_channel: Number of bits"
                             " for quantization should be between 2 and 16, but %i"
                             " was given.", numBits);
            }
            helpers::fakeQuantWithMinMaxVarsPerChannel(block.launchContext(), x, min, max, numBits, narrowed, output);
            return ND4J_STATUS_OK;
        }

        DECLARE_TYPES(fake_quant_with_min_max_vars_per_channel) {
            getOpDescriptor()
                ->setAllowedOutputTypes({ALL_FLOATS})
                ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS});
        }

        DECLARE_SYN(fake_quant_with_min_max_args_per_channel, fake_quant_with_min_max_vars_per_channel);
    }
}

#endif
@@ -156,6 +156,18 @@ namespace nd4j {
        DECLARE_CUSTOM_OP(divide_bp, 3, 2, false, 0, 0);
        #endif

+       /**
+        * This is one of the auto-broadcastable operations. It accepts 2 operands, and the operation is applied based on their shapes:
+        * 1) if the shapes are equal it's a pairwise operation, and the result will have the same shape.
+        * 2) if shape X is scalar and shape Y is an array - the result will have shape equal to Y.
+        * 3) if shape X is an array and shape Y is scalar - the result will have shape equal to X.
+        * 4) if shapes X and Y are both arrays, but the shapes aren't equal - the result shape will be the broadcast result.
+        *
+        * This operation returns Z = Divide(X, Y), except that it returns 0 wherever Y = 0
+        */
+       #if NOT_EXCLUDED(OP_divide_no_nan)
+       DECLARE_BROADCASTABLE_OP(divide_no_nan, 0, 0);
+       #endif
        /**
         * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes:
         * 1) if shapes are equal that's pairwise operation, result will have the same shape.
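Rule 4 is the interesting case; a sketch using the shapes from BroadcastDivideTest_2 later in this diff:

    // {3,4,5,1} / {1,6} broadcasts to {3,4,5,6}: the trailing 1 in x is
    // stretched along y's last dimension, and y is tiled over the rest.
    auto x = NDArrayFactory::create<float>('c', {3, 4, 5, 1});
    auto y = NDArrayFactory::create<float>('c', {1, 6});
    nd4j::ops::divide_no_nan op;
    auto res = op.execute({&x, &y}, {}, {});
    // res->at(0) has shape {3, 4, 5, 6}
    delete res;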
@@ -99,6 +99,14 @@ namespace nd4j {
        #if NOT_EXCLUDED(OP_cast)
        DECLARE_CUSTOM_OP(cast, 1, 1, false, 0, 1);
        #endif
+       /**
+        * This operation changes the type of the input and modifies the output shape to conform with the given data type
+        *
+        * otherwise the same as the cast op above
+        * */
+       #if NOT_EXCLUDED(OP_bitcast)
+       DECLARE_CUSTOM_OP(bitcast, 1, 1, false, 0, 1);
+       #endif
    }
}
@@ -600,6 +600,22 @@ namespace nd4j {
        DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, 1, -2);
        #endif

+       /**
+        * This operation adjusts image contrast by the given factor ( z = (x - mean) * factor + mean )
+        * Input arrays:
+        * 0 - input array with rank >= 3, the last dimension must equal 3 (the channels dimension).
+        *
+        * T arguments:
+        * 0 - contrast factor
+        *
+        */
+       #if NOT_EXCLUDED(OP_adjust_contrast)
+       DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, 1, 0);
+       DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, 1, 0);
+       #endif

        /**
         * This operation rearranges data from depth into blocks of spatial data. This is the reverse transformation
@@ -1228,6 +1244,23 @@ namespace nd4j {
        DECLARE_CUSTOM_OP(extract_image_patches, 1, 1, false, 0, 7);
        #endif

+       /**
+        * draw_bounding_boxes op - modifies the input images by drawing frames of the given colors along the given boxes.
+        *
+        * input params:
+        * 0 - images tensor (4D) with shape {batch, width, height, channels}, where channels is 1 (BW image),
+        *     3 (RGB) or 4 (RGBA)
+        * 1 - boxes tensor (3D) with shape {batch, number_of_boxes, 4}, where the last dimension is encoded as
+        *     (y_min, x_min, y_max, x_max), all values between 0. and 1.
+        * 2 - colors tensor (2D) with shape {number_of_boxes, channels} -- the bordering color set (palette)
+        *
+        * output:
+        * 0 - 4D tensor with the same shape as images (input 0)
+        */
+       #if NOT_EXCLUDED(OP_draw_bounding_boxes)
+       DECLARE_OP(draw_bounding_boxes, 3, 1, true);
+       #endif

        /**
         * roll - op porting from numpy (https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.roll.html)
         *
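A usage sketch matching the Image_DrawBoundingBoxes_2 test later in this diff (single grayscale image, single box):

    auto images = NDArrayFactory::create<float>('c', {1, 9, 9, 1});
    auto boxes  = NDArrayFactory::create<float>('c', {1, 1, 4}, {0.2, 0.2, 0.7, 0.7});
    auto colors = NDArrayFactory::create<float>('c', {1, 1}, {0.95});
    images.linspace(1.1);
    nd4j::ops::draw_bounding_boxes op;
    auto res = op.execute({&images, &boxes, &colors}, {}, {});
    // the box frame pixels are overwritten with the color 0.95
    delete res;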
@@ -1715,6 +1748,39 @@ namespace nd4j {
        DECLARE_CONFIGURABLE_OP(fake_quant_with_min_max_vars, 3, 1, true, 0, -2);
        #endif

+       /**
+        * fake_quant_with_min_max_vars_per_channel - tf.quantization.fake_quant_with_min_max_vars_per_channel
+        *
+        * input params:
+        * 0 - NDArray (input) - at least 2D.
+        * 1 - 1D Tensor - min values (length equals the last dim of the input)
+        * 2 - 1D Tensor - max values (same length as min)
+        *
+        * int params (optional):
+        * 0 - num_bits (allowed interval [2, 16], default 8)
+        * 1 - narrow_range (default False)
+        *
+        * output:
+        * 0 - NDArray with the same shape as input
+        */
+       #if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars_per_channel)
+       DECLARE_CONFIGURABLE_OP(fake_quant_with_min_max_vars_per_channel, 3, 1, true, 0, -2);
+       #endif
+
+       /**
+        * compare_and_bitpack - compares the input with a threshold (greater-than) and stores the result as uint8
+        *
+        * input params:
+        * 0 - NDArray (input)
+        * 1 - 0D Tensor - threshold
+        *
+        * output:
+        * 0 - NDArray with the same shape as input and type uint8
+        */
+       #if NOT_EXCLUDED(OP_compare_and_bitpack)
+       DECLARE_CUSTOM_OP(compare_and_bitpack, 2, 1, false, 0, 0);
+       #endif
    }
}
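A usage sketch matching FakeQuantWithMinMaxVars_Test_3 later in this diff, with the optional int args spelled out:

    auto x   = NDArrayFactory::create<double>('c', {1, 2, 3, 1}, {-63.80, -63.75, -63.4, -63.5, 0.0, 0.1});
    auto min = NDArrayFactory::create<double>('c', {1}, {-63.65});  // one entry per channel
    auto max = NDArrayFactory::create<double>('c', {1}, {0.1});
    nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
    // int args {8, 0} would request 8 bits and narrow_range = false (the defaults)
    auto res = op.execute({&x, &min, &max}, {}, {});
    delete res;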
@@ -25,74 +25,89 @@ namespace nd4j {
namespace ops {
namespace helpers {

+    //
+    // nudge - nudged min max over scale
+    // scale = (Max - Min) / (quantMax - quantMin)
+    // quantMin = 0 or 1, quantMax = 2^b - 1 == (1 << b) - 1
+    //
+    template <typename T>
+    static void nudge(T min, T max, int quantMin, int quantMax, T* scale, T* nudgedMin, T* nudgedMax) {
+        // floating point instead of integers
+        T quantMaxF = static_cast<T>(quantMax);
+        T quantMinF = static_cast<T>(quantMin);
+        // compute scale
+        *scale = (max - min) / (quantMaxF - quantMinF);
+        // compute left bound point
+        auto zeroPointFromMin = quantMinF - min / *scale;
+        // bound zero point to conform with range [0 or 1, 2^b - 1]
+        uint16_t const nudged_zero_point = [zeroPointFromMin, quantMin, quantMax, quantMaxF, quantMinF] {
+            if (zeroPointFromMin < quantMinF) {
+                return static_cast<uint16_t>(quantMin);
+            }
+            if (zeroPointFromMin > quantMaxF) {
+                return static_cast<uint16_t>(quantMax);
+            }
+            return nd4j::math::nd4j_round<T,uint16_t>(zeroPointFromMin);
+        }();
+        // compute nudged min and max with computed nudged zero point
+        *nudgedMin = (quantMinF - nudged_zero_point) * (*scale);
+        *nudgedMax = (quantMaxF - nudged_zero_point) * (*scale);
+    }
+
+    template <typename T>
+    void fakeQuantWithMinMaxVarsPerChannel_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
+        int lowIntBound = narrowed ? 1 : 0;     // 0 or 1
+        int upperIntBound = (1 << numBits) - 1; // 2^b - 1
+        auto channels = input->sizeAt(-1);      // last dimension
+
+        PRAGMA_OMP_PARALLEL_FOR
+        for (auto i = 0; i < channels; i++) {
+            T scale, nudged_min, nudged_max;
+            // nudge min and max first, with scale computing
+            nudge<T>(min->t<T>(i), max->t<T>(i), lowIntBound, upperIntBound, &scale, &nudged_min, &nudged_max);
+            // slide using last dimension and process all for given channel
+            for (auto e = 0; e < input->lengthOf(); e += channels) {
+                T val = input->t<T>(e + i);
+                if (val <= nudged_min)
+                    val = nudged_min;
+                else if (val >= nudged_max)
+                    val = nudged_max;
+                // quantization itself
+                output->t<T>(e + i) = math::nd4j_floor<T,T>((val - nudged_min)/scale + T(0.5)) * scale + nudged_min;
+            }
+        }
+    }
+
    template <typename T>
    void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
        int lowIntBound = narrowed ? 1 : 0;
-       int upperIntBound = 1 << numBits - 1;
+       int upperIntBound = (1 << numBits) - 1;

-       const float quant_min_float = static_cast<float>(lowIntBound);
-       const float quant_max_float = static_cast<float>(upperIntBound);
-       T scale = (max->t<T>(0) - min->t<T>(0)) / (quant_max_float - quant_min_float);
-       const T zero_point_from_min = quant_min_float - min->e<T>(0) / scale;
-       const uint16_t nudged_zero_point = [zero_point_from_min, lowIntBound,
-                                           quant_min_float, upperIntBound,
-                                           quant_max_float] {
-           if (zero_point_from_min < quant_min_float) {
-               return static_cast<uint16_t>(lowIntBound);
-           }
-           if (zero_point_from_min > quant_max_float) {
-               return static_cast<uint16_t>(upperIntBound);
-           }
-           return static_cast<uint16_t>(roundf(zero_point_from_min));
-       }();
-
-       auto nudged_min = (quant_min_float - nudged_zero_point) * (scale);
-       auto nudged_max = (quant_max_float - nudged_zero_point) * (scale);
-       //input->applyScalar(scalar::CompareAndSet, nudged_max, clamped, nullptr); //.cwiseMin(nudged_max).cwiseMax(nudged_min);
-       //input->applyScalar(scalar::CompareAndSet, nudged_min, clamped, nullptr); //.cwiseMin(nudged_max).cwiseMax(nudged_min);
-       auto wiseMax = LAMBDA_T(x, nudged_min) {
-           if (x < nudged_min) {
-               return nudged_min;
-           }
-           return x;
-       };
-       auto wiseMin = LAMBDA_T(x, nudged_max) {
-           if (x > nudged_max) {
-               return nudged_max;
-           }
-           return x;
-       };
-       auto scaleTensor(*input); // = NDArrayFactory::create(input->ordering(), input->getShapeAsVector(), input->getWorkspace());
-       auto clamped(*input); // = NDArrayFactory::create(input->ordering(), input->getShapeAsVector(), input->getWorkspace());
-       scaleTensor.assign(scale);
-       input->applyLambda<T>(wiseMin, &clamped);
-       // const auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min);
-       clamped.applyLambda<T>(wiseMax, output);
-       // const auto clamped_shifted = clamped - nudged_min;
-       *output -= nudged_min;
-       // auto nudgedScale = scale;
-       (*output) /= scaleTensor;
-       (*output) += T(0.5f);
-       output->applyTransform(transform::Floor, nullptr, nullptr);
-       (*output) *= scaleTensor;
-       (*output) += nudged_min;
-       //output->printIndexedBuffer("FAKE QUANTED");
-       /*
-           const auto nudged_scale_repl = inputs.constant(nudged_scale);
-
-           const auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min);
-           const auto clamped_shifted = clamped - nudged_min;
-           *output = (clamped_shifted / nudged_scale_repl + 0.5f).floor() *
-                       nudged_scale_repl +
-                       nudged_min;
-       */
+       T nudgedMin, nudgedMax, scale;
+       // nudge with given min and max and compute scale and nudged min and max
+       nudge<T>(min->t<T>(0), max->t<T>(0), lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax);
+       // quantization as one
+       auto fakeQuantizationWithMinMax = LAMBDA_T(x, nudgedMin, nudgedMax, scale) {
+           T val = x; // bound value between nudged min and max
+           if (val < nudgedMin) {
+               val = nudgedMin;
+           }
+           else if (val > nudgedMax)
+               val = nudgedMax;
+           // convert value with scale and shift with nudged min
+           return (nd4j::math::nd4j_floor<T,T>((val - nudgedMin)/scale + T(0.5)) * scale + nudgedMin);
+       };
+
+       input->applyLambda<T>(fakeQuantizationWithMinMax, output);
    }

    void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
        BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVars_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES);
    }
+   void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
+       BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVarsPerChannel_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES);
+   }

    BUILD_SINGLE_TEMPLATE(template void fakeQuantWithMinMaxVars_, (NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output), FLOAT_TYPES);

}
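To make the nudging concrete, here is a worked pass with the values from FakeQuantWithMinMaxVars_Test_1 below (min = -63.65, max = 0.1, 8 bits, not narrowed); a sketch, not commit code:

    // nudge() with quantMin = 0, quantMax = 255:
    //   scale            = (0.1 - (-63.65)) / (255 - 0)  =  0.25
    //   zeroPointFromMin = 0 - (-63.65) / 0.25           =  254.6 -> rounds to 255
    //   nudgedMin        = (0   - 255) * 0.25            = -63.75
    //   nudgedMax        = (255 - 255) * 0.25            =  0.0
    // Inputs are then clamped to [-63.75, 0] and snapped to multiples of 0.25,
    // which is exactly the expected buffer in that test.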
@@ -0,0 +1,78 @@
/* [Apache 2.0 license header, identical to the one above] */

//
// @author sgazeos@gmail.com
//
#include <op_boilerplate.h>
#include <NDArray.h>

namespace nd4j {
namespace ops {
namespace helpers {

    void drawBoundingBoxesFunctor(nd4j::LaunchContext * context, NDArray* images, NDArray* boxes, NDArray* colors, NDArray* output) {
        // images - batch of 3D images with BW (last dim = 1), RGB (last dim = 3) or RGBA (last dim = 4) channel set
        // boxes - batch of 2D bounds with last dim (y_start, x_start, y_end, x_end) to compute i and j as
        //         floor((height - 1) * y_start) => rowStart, floor((height - 1) * y_end) => rowEnd
        //         floor((width - 1) * x_start) => colStart, floor((width - 1) * x_end) => colEnd
        //         height = images->sizeAt(1), width = images->sizeAt(2)
        // colors - colors for each box given
        // set up color for each box as frame
        auto batchSize = images->sizeAt(0);
        auto height = images->sizeAt(1);
        auto width = images->sizeAt(2);
        auto channels = images->sizeAt(3);
        //auto imageList = images->allTensorsAlongDimension({1, 2, 3}); // split images by batch
        // auto boxList = boxes->allTensorsAlongDimension({1, 2}); // split boxes by batch
        auto colorSet = colors->allTensorsAlongDimension({1});
        output->assign(images); // fill up all output with input images, then fill up boxes

        PRAGMA_OMP_PARALLEL_FOR
        for (auto b = 0; b < batchSize; ++b) { // loop by batch
            // auto image = imageList->at(b);

            for (auto c = 0; c < colorSet->size(); ++c) {
                // box with shape
                auto internalBox = (*boxes)(b, {0})(c, {0}); //internalBoxes->at(c);
                auto color = colorSet->at(c);
                auto rowStart = nd4j::math::nd4j_max(Nd4jLong(0), Nd4jLong((height - 1) * internalBox.e<float>(0)));
                auto rowEnd = nd4j::math::nd4j_min(Nd4jLong(height - 1), Nd4jLong((height - 1) * internalBox.e<float>(2)));
                auto colStart = nd4j::math::nd4j_max(Nd4jLong(0), Nd4jLong((width - 1) * internalBox.e<float>(1)));
                auto colEnd = nd4j::math::nd4j_min(Nd4jLong(width - 1), Nd4jLong((width - 1) * internalBox.e<float>(3)));
                for (auto y = rowStart; y <= rowEnd; y++) {
                    for (auto e = 0; e < color->lengthOf(); ++e) {
                        output->p(b, y, colStart, e, color->e(e));
                        output->p(b, y, colEnd, e, color->e(e));
                    }
                }
                for (auto x = colStart + 1; x < colEnd; x++) {
                    for (auto e = 0; e < color->lengthOf(); ++e) {
                        output->p(b, rowStart, x, e, color->e(e));
                        output->p(b, rowEnd, x, e, color->e(e));
                    }
                }
            }
            // delete internalBoxes;
        }
        delete colorSet;
        // delete imageList;
        // delete boxList;
    }

}
}
}
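The box-to-pixel mapping above can be checked against the Image_DrawBoundingBoxes_2 test: for a 9x9 image and box (0.2, 0.2, 0.7, 0.7),

    // rowStart = (Nd4jLong)((9 - 1) * 0.2) = 1,  rowEnd = (Nd4jLong)((9 - 1) * 0.7) = 5
    // colStart = 1, colEnd = 5 (the Nd4jLong cast truncates 1.6 and 5.6)
    // so the frame is drawn on rows/columns 1 through 5, matching the
    // positions of the 0.95 values in that test's expected output.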
@@ -33,65 +33,105 @@ namespace helpers {
    // narrowed - shrink is true
    // output - output tensor
    //
+   template <typename T>
+   static __host__ __device__ void
+   nudge(T min, T max, int quantMin, int quantMax, T* scale, T* nudgedMin, T* nudgedMax) {
+       T quantMaxF = static_cast<T>(quantMax);
+       T quantMinF = static_cast<T>(quantMin);
+       *scale = (max - min) / (quantMaxF - quantMinF);
+       auto zeroPointFromMin = quantMinF - min / *scale;
+       uint16_t const nudgedZeroPoint = [zeroPointFromMin, quantMin, quantMax, quantMaxF, quantMinF] {
+           if (zeroPointFromMin < quantMinF) {
+               return static_cast<uint16_t>(quantMin);
+           }
+           if (zeroPointFromMin > quantMaxF) {
+               return static_cast<uint16_t>(quantMax);
+           }
+           return nd4j::math::nd4j_round<T,uint16_t>(zeroPointFromMin);
+       }();
+       *nudgedMin = (quantMinF - nudgedZeroPoint) * (*scale);
+       *nudgedMax = (quantMaxF - nudgedZeroPoint) * (*scale);
+   }
+
    template <typename T>
    void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
        int lowIntBound = narrowed?1:0;
-       int upperIntBound = 1 << numBits - 1;
-       min->syncToHost();
+       int upperIntBound = (1 << numBits) - 1;
+       min->syncToHost(); // these are scalars, so nothing much happens here
        max->syncToHost();
-       const float quant_min_float = static_cast<float>(lowIntBound);
-       const float quant_max_float = static_cast<float>(upperIntBound);
-       T scale = (max->t<T>(0) - min->t<T>(0)) / (quant_max_float - quant_min_float);
-       const T zero_point_from_min = quant_min_float - min->t<T>(0) / scale;
+       T scale, nudgedMin, nudgedMax;
+       nudge(min->t<T>(0), max->t<T>(0), lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax);

-       const uint16_t nudged_zero_point = [zero_point_from_min, lowIntBound,
-                                           quant_min_float, upperIntBound,
-                                           quant_max_float] {
-           if (zero_point_from_min < quant_min_float) {
-               return static_cast<uint16_t>(lowIntBound);
-           }
-           if (zero_point_from_min > quant_max_float) {
-               return static_cast<uint16_t>(upperIntBound);
-           }
-           return static_cast<uint16_t>(roundf(zero_point_from_min));
-       }();
-
-       auto nudged_min = (quant_min_float - nudged_zero_point) * (scale);
-       auto nudged_max = (quant_max_float - nudged_zero_point) * (scale);
-
-       auto wiseMax = LAMBDA_T(x, nudged_min) {
-           if (x < nudged_min) {
-               return nudged_min;
-           }
-           return x;
-       };
-       auto wiseMin = LAMBDA_T(x, nudged_max) {
-           if (x > nudged_max) {
-               return nudged_max;
-           }
-           return x;
-       };
-
-       auto scaleTensor(*input);
-       auto clamped(*input);
-       scaleTensor.assign(scale);
-       input->applyLambda(wiseMin, &clamped);
-
-       clamped.applyLambda(wiseMax, output);
-       *output -= nudged_min;
-       (*output) /= scaleTensor;
-       (*output) += T(0.5f);
-       output->applyTransform(transform::Floor, nullptr, nullptr);
-       (*output) *= scaleTensor;
-       (*output) += nudged_min;
+       auto wiseMinMaxAndSoOn = LAMBDA_T(x, nudgedMin, nudgedMax, scale) {
+           T val = x;
+           if (x < nudgedMin) {
+               val = nudgedMin;
+           }
+           else if (x > nudgedMax) {
+               val = nudgedMax;
+           }
+           else
+               val = x;
+           return (math::nd4j_floor<T,T>((val - nudgedMin) / scale + T(0.5)) * scale + nudgedMin);
+       };
+
+       input->applyLambda(wiseMinMaxAndSoOn, output);
    }

+   template <typename T>
+   static __global__ void fakeQuantWithMinMaxKernel(T* input, Nd4jLong* inputShape, T* min, T* max,
+           int lowIntBound, int upperIntBound, Nd4jLong channels,
+           T* output, Nd4jLong* outputShape, Nd4jLong length) {
+       __shared__ int block;
+       if (threadIdx.x == 0) {
+           block = length / channels; // to loop with last dimension as block
+       }
+       __syncthreads();
+
+       for (auto i = blockIdx.x; i < (int)channels; i += gridDim.x) {
+           T scale, nudgedMin, nudgedMax;
+           nudge(min[i], max[i], lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax);
+           // loop over blocks to quantization between nudged min and max
+           for (auto b = threadIdx.x; b < block; b += blockDim.x) {
+               T val = input[shape::getIndexOffset(b * channels + i, inputShape)];
+               if (val < nudgedMin) {
+                   val = nudgedMin;
+               } else if (val > nudgedMax) {
+                   val = nudgedMax;
+               }
+               output[shape::getIndexOffset(b * channels + i, outputShape)] =
+                       (math::nd4j_floor<T, T>((val - nudgedMin) / scale + T(0.5)) * scale + nudgedMin);
+           };
+       }
+   }
+
+   template <typename T>
+   void fakeQuantWithMinMaxVarsPerChannel_(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
+       int lowIntBound = narrowed?1:0;
+       int upperIntBound = (1 << numBits) - 1;
+       auto channels = min->lengthOf();
+       auto length = input->lengthOf();
+       NDArray::prepareSpecialUse({output}, {min, max, input});
+       auto stream = context->getCudaStream();
+       T* inputBuf = input->dataBuffer()->specialAsT<T>();
+       T* outputBuf = output->dataBuffer()->specialAsT<T>();
+       T* minBuf = min->dataBuffer()->specialAsT<T>();
+       T* maxBuf = max->dataBuffer()->specialAsT<T>();
+       fakeQuantWithMinMaxKernel<<<128, 256, 256, *stream>>>(inputBuf, input->specialShapeInfo(),
+               minBuf, maxBuf, lowIntBound, upperIntBound, channels, outputBuf, output->specialShapeInfo(), length);
+       NDArray::registerSpecialUse({output}, {min, max, input});
+   }

    void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
        BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVars_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES);
    }
+   void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
+       BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVarsPerChannel_, (context, input, min, max, numBits, narrowed, output), FLOAT_TYPES);
+   }

    BUILD_SINGLE_TEMPLATE(template void fakeQuantWithMinMaxVars_, (NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output), FLOAT_TYPES);
+   BUILD_SINGLE_TEMPLATE(template void fakeQuantWithMinMaxVarsPerChannel_, (LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output), FLOAT_TYPES);

}
}
@@ -0,0 +1,100 @@
/* [Apache 2.0 license header, identical to the one above] */

//
// @author sgazeos@gmail.com
//
#include <op_boilerplate.h>
#include <NDArray.h>

namespace nd4j {
namespace ops {
namespace helpers {

    template <typename T>
    static __global__ void drawBoundingBoxesKernel(T const* images, Nd4jLong* imagesShape, T const* boxes,
            Nd4jLong* boxesShape, T const* colors, Nd4jLong* colorsShape, T* output, Nd4jLong* outputShape,
            Nd4jLong batchSize, Nd4jLong width, Nd4jLong height, Nd4jLong channels, Nd4jLong colorSetSize) {

        for (auto b = blockIdx.x; b < (int)batchSize; b += gridDim.x) { // loop by batch
            for (auto c = 0; c < colorSetSize; c++) {
                // box with shape
                auto internalBox = &boxes[b * colorSetSize * 4 + c * 4]; //(*boxes)(b, {0})(c, {0}); //internalBoxes->at(c);
                auto color = &colors[channels * c]; //colorSet->at(c);
                auto rowStart = nd4j::math::nd4j_max(Nd4jLong(0), Nd4jLong((height - 1) * internalBox[0]));
                auto rowEnd = nd4j::math::nd4j_min(Nd4jLong(height - 1), Nd4jLong((height - 1) * internalBox[2]));
                auto colStart = nd4j::math::nd4j_max(Nd4jLong(0), Nd4jLong((width - 1) * internalBox[1]));
                auto colEnd = nd4j::math::nd4j_min(Nd4jLong(width - 1), Nd4jLong((width - 1) * internalBox[3]));
                for (auto y = rowStart + threadIdx.x; y <= rowEnd; y += blockDim.x) {
                    for (auto e = 0; e < channels; ++e) {
                        Nd4jLong yMinPos[] = {b, y, colStart, e};
                        Nd4jLong yMaxPos[] = {b, y, colEnd, e};
                        auto zIndexYmin = shape::getOffset(outputShape, yMinPos);
                        auto zIndexYmax = shape::getOffset(outputShape, yMaxPos);
                        output[zIndexYmin] = color[e];
                        output[zIndexYmax] = color[e];
                    }
                }
                for (auto x = colStart + 1 + threadIdx.x; x < colEnd; x += blockDim.x) {
                    for (auto e = 0; e < channels; ++e) {
                        Nd4jLong xMinPos[] = {b, rowStart, x, e};
                        Nd4jLong xMaxPos[] = {b, rowEnd, x, e};
                        auto zIndexXmin = shape::getOffset(outputShape, xMinPos);
                        auto zIndexXmax = shape::getOffset(outputShape, xMaxPos);
                        output[zIndexXmin] = color[e];
                        output[zIndexXmax] = color[e];
                    }
                }
            }
        }

    }

    template <typename T>
    void drawBoundingBoxesH(nd4j::LaunchContext* context, NDArray const* images, NDArray const* boxes, NDArray const* colors, NDArray* output) {
        auto batchSize = images->sizeAt(0);
        auto height = images->sizeAt(1);
        auto width = images->sizeAt(2);
        auto channels = images->sizeAt(3);
        auto stream = context->getCudaStream();
        auto colorSetSize = colors->sizeAt(0);

        auto imagesBuf = images->getDataBuffer()->specialAsT<T>();
        auto boxesBuf = boxes->getDataBuffer()->specialAsT<T>();
        auto colorsBuf = colors->getDataBuffer()->specialAsT<T>();
        auto outputBuf = output->dataBuffer()->specialAsT<T>();
        drawBoundingBoxesKernel<<<batchSize > 128 ? 128 : batchSize, 256, 1024, *stream>>>(imagesBuf, images->getSpecialShapeInfo(),
                boxesBuf, boxes->getSpecialShapeInfo(), colorsBuf, colors->getSpecialShapeInfo(),
                outputBuf, output->specialShapeInfo(), batchSize, width, height, channels, colorSetSize);
    }

    void drawBoundingBoxesFunctor(nd4j::LaunchContext * context, NDArray* images, NDArray* boxes, NDArray* colors, NDArray* output) {
        // images - batch of 3D images with BW (last dim = 1), RGB (last dim = 3) or RGBA (last dim = 4) channel set
        // boxes - batch of 2D bounds with last dim (y_start, x_start, y_end, x_end) to compute i and j as
        //         floor((height - 1) * y_start) => rowStart, floor((height - 1) * y_end) => rowEnd
        //         floor((width - 1) * x_start) => colStart, floor((width - 1) * x_end) => colEnd
        //         height = images->sizeAt(1), width = images->sizeAt(2)
        // colors - colors for each box given
        // set up color for each box as frame
        NDArray::prepareSpecialUse({output}, {images, boxes, colors});
        output->assign(images);
        BUILD_SINGLE_SELECTOR(output->dataType(), drawBoundingBoxesH, (context, images, boxes, colors, output), FLOAT_TYPES);
        NDArray::registerSpecialUse({output}, {images, boxes, colors});
    }
    BUILD_SINGLE_TEMPLATE(template void drawBoundingBoxesH, (nd4j::LaunchContext* context, NDArray const* images, NDArray const* boxes, NDArray const* colors, NDArray* output), FLOAT_TYPES);
}
}
}
@@ -27,6 +27,7 @@ namespace ops {
namespace helpers {

    void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output);
+   void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output);
}
}
}
@@ -0,0 +1,34 @@
/* [Apache 2.0 license header, identical to the one above] */

//
// @author sgazeos@gmail.com
//
#ifndef __IMAGE_DRAW_BOUNDING_BOXES_H_HELPERS__
#define __IMAGE_DRAW_BOUNDING_BOXES_H_HELPERS__
#include <op_boilerplate.h>
#include <NDArray.h>

namespace nd4j {
namespace ops {
namespace helpers {

    void drawBoundingBoxesFunctor(nd4j::LaunchContext * context, NDArray* images, NDArray* boxes, NDArray* colors, NDArray* output);

}
}
}
#endif
@@ -37,6 +37,10 @@ namespace nd4j {
        return custom(nd4j::scalar::Divide, nd4j::pairwise::Divide, nd4j::broadcast::Divide);
    }

+   BroadcastOpsTuple BroadcastOpsTuple::DivideNoNan() {
+       return custom(nd4j::scalar::DivideNoNan, nd4j::pairwise::DivideNoNan, nd4j::broadcast::DivideNoNan);
+   }
+
    BroadcastOpsTuple BroadcastOpsTuple::Multiply() {
        return custom(nd4j::scalar::Multiply, nd4j::pairwise::Multiply, nd4j::broadcast::Multiply);
    }
@@ -360,6 +360,34 @@ namespace simdOps {
        }
    };

+   template <typename X, typename Y, typename Z>
+   class DivideNoNan {
+   public:
+       op_def static Z op(X d1, Y d2) {
+           if (d2 == (Y)0) return (Z)0;
+           return static_cast<Z>(d1 / d2);
+       }
+
+       op_def static Z op(X d1, Y d2, Z *params) {
+           if (d2 == (Y)0) return (Z)0;
+           return static_cast<Z>(d1 / d2);
+       }
+
+       op_def static Z op(X d1) {
+           return static_cast<Z>(d1);
+       }
+
+       // op for MetaOps
+       op_def static Z op(X d1, Y *params) {
+           if (params[0] == (Y)0) return (Z)0;
+           return static_cast<Z>(d1 / params[0]);
+       }
+
+       op_def static X startingValue() {
+           return static_cast<X>(1);
+       }
+   };
+
    template <typename X, typename Y, typename Z>
    class SafeDivide {
    public:
@@ -1194,6 +1194,41 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_1) {
    delete res;
}

+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests1, BroadcastDivideTest_2) {
+
+    auto x = NDArrayFactory::create<float>('c', {3, 4, 5, 1});
+    auto y = NDArrayFactory::create<float>('c', {1, 6});
+    auto exp = NDArrayFactory::create<float>('c', {3, 4, 5, 6});
+    x.assign(6);
+    y.assign(2);
+    exp.assign(3);
+
+    nd4j::ops::divide_no_nan div;
+    auto res = div.execute({&x, &y}, {}, {});
+
+    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+    ASSERT_TRUE(res->at(0)->equalsTo(exp));
+
+    delete res;
+}
+
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests1, BroadcastDivideTest_3) {
+
+    auto x = NDArrayFactory::create<float>({6,6,6,6,6});
+    auto y = NDArrayFactory::create<float>({3,3,0,3,3});
+    auto exp = NDArrayFactory::create<float>({2, 2, 0, 2, 2});
+
+    nd4j::ops::divide_no_nan div;
+    auto res = div.execute({&x, &y}, {}, {});
+
+    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+    ASSERT_TRUE(res->at(0)->equalsTo(exp));
+
+    delete res;
+}
+
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, BroadcastReverseDivideTest_1) {
@@ -2043,11 +2043,81 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_5) {
    delete results;
}

+////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) {
+    NDArray images = NDArrayFactory::create<float>('c', {2,4,5,3});
+    NDArray boxes = NDArrayFactory::create<float>('c', {2, 2, 4}, {
+        0. , 0. , 1. , 1. , 0.1, 0.2, 0.9, 0.8,
+        0.3, 0.3, 0.7, 0.7, 0.4, 0.4, 0.6, 0.6
+    });
+
+    NDArray colors = NDArrayFactory::create<float>('c', {2, 3}, {201., 202., 203., 127., 128., 129.});
+
+    //NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
+    NDArray expected = NDArrayFactory::create<float>('c', {2,4,5,3}, {
+        127., 128., 129., 127., 128., 129., 127., 128., 129., 127., 128., 129., 201., 202., 203.,
+        127., 128., 129., 19., 20., 21., 22., 23., 24., 127., 128., 129., 201., 202., 203.,
+        127., 128., 129., 127., 128., 129., 127., 128., 129., 127., 128., 129., 201., 202., 203.,
+        201., 202., 203., 201. ,202. ,203., 201., 202., 203., 201., 202., 203., 201., 202., 203.,
+
+        61., 62., 63., 201., 202., 203., 201., 202., 203., 70., 71., 72., 73., 74., 75.,
+        76., 77., 78., 127., 128., 129., 127., 128., 129., 85., 86., 87., 88., 89., 90.,
+        91., 92., 93., 201., 202., 203., 201., 202., 203., 100., 101., 102., 103., 104., 105.,
+        106., 107., 108., 109., 110., 111., 112., 113., 114., 115., 116., 117., 118., 119., 120.
+    });
+    images.linspace(1.);
+    nd4j::ops::draw_bounding_boxes op;
+    auto results = op.execute({&images, &boxes, &colors}, {}, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto result = results->at(0);
+    // result->printBuffer("Bounded boxes");
+    // expected.printBuffer("Bounded expec");
+    ASSERT_TRUE(expected.isSameShapeStrict(result));
+    ASSERT_TRUE(expected.equalsTo(result));
+
+    delete results;
+}
+
+////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_2) {
+    NDArray images = NDArrayFactory::create<float>('c', {1,9,9,1});
+    NDArray boxes = NDArrayFactory::create<float>('c', {1, 1, 4}, {0.2, 0.2, 0.7, 0.7});
+    NDArray colors = NDArrayFactory::create<float>('c', {1, 1}, {0.95});
+
+    //NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
+    NDArray expected = NDArrayFactory::create<float>('c', {1,9,9,1}, {
+         1.1 ,  2.1 ,  3.1 ,  4.1 ,  5.1 ,  6.1 ,  7.1 ,  8.1 ,  9.1 ,
+        10.1 ,  0.95,  0.95,  0.95,  0.95,  0.95, 16.1 , 17.1 , 18.1 ,
+        19.1 ,  0.95, 21.1 , 22.1 , 23.1 ,  0.95, 25.1 , 26.1 , 27.1 ,
+        28.1 ,  0.95, 30.1 , 31.1 , 32.1 ,  0.95, 34.1 , 35.1 , 36.1 ,
+        37.1 ,  0.95, 39.1 , 40.1 , 41.1 ,  0.95, 43.1 , 44.1 , 45.1 ,
+        46.1 ,  0.95,  0.95,  0.95,  0.95,  0.95, 52.1 , 53.1 , 54.1 ,
+        55.1 , 56.1 , 57.1 , 58.1 , 59.1 , 60.1 , 61.1 , 62.1 , 63.1 ,
+        64.1 , 65.1 , 66.1 , 67.1 , 68.1 , 69.1 , 70.1 , 71.1 , 72.1 ,
+        73.1 , 74.1 , 75.1 , 76.1 , 77.1 , 78.1 , 79.1 , 80.1 , 81.1 });
+    images.linspace(1.1);
+    nd4j::ops::draw_bounding_boxes op;
+    auto results = op.execute({&images, &boxes, &colors}, {}, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto result = results->at(0);
+    // result->syncToHost();
+    // result->printBuffer("Bounded boxes 2");
+    // expected.printBuffer("Bounded expec 2");
+    ASSERT_TRUE(expected.isSameShapeStrict(result));
+    ASSERT_TRUE(expected.equalsTo(result));
+
+    delete results;
+}
+
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {

    NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, nd4j::DataType::FLOAT32);
-   NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.251953f, 0.0f, 0.0f}, nd4j::DataType::FLOAT32);
+   NDArray exp('c', {2,3}, {-63.75, -63.75, -63.75, -63.5, 0., 0.}, nd4j::DataType::FLOAT32);
    NDArray min('c', {}, {-63.65f}, nd4j::DataType::FLOAT32);
    NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32);
@@ -2057,7 +2127,8 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto result = results->at(0);
-   // result->printIndexedBuffer("Quantized");
+   // result->printBuffer("Quantized");
+   // exp.printBuffer("Expected");
    ASSERT_TRUE(exp.isSameShapeStrict(result));
    ASSERT_TRUE(exp.equalsTo(result));
@ -2067,7 +2138,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_2) {

NDArray x = NDArrayFactory::create<double>('c', {2,3}, {-63.80, -63.75, -63.4, -63.5, 0.0, 0.1});
NDArray exp = NDArrayFactory::create<double>('c', {2,3}, {-63.75, -63.75, -63.5 , -63.5 , 0. , 0. });
NDArray min = NDArrayFactory::create<double>(-63.65);
NDArray max = NDArrayFactory::create<double>(0.1);

@ -2084,6 +2155,119 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_2) {
delete results;
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_3) {

NDArray x = NDArrayFactory::create<double>('c', {1,2,3,1}, {-63.80, -63.75, -63.4, -63.5, 0.0, 0.1});
NDArray exp = NDArrayFactory::create<double>('c', {1,2,3,1}, {-63.75, -63.75, -63.5 , -63.5 , 0. , 0. });
NDArray min = NDArrayFactory::create<double>('c', {1}, {-63.65});
NDArray max = NDArrayFactory::create<double>('c', {1}, {0.1});

nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

auto result = results->at(0);
// result->printIndexedBuffer("Quantized2");
ASSERT_TRUE(exp.isSameShapeStrict(result));
ASSERT_TRUE(exp.equalsTo(result));

delete results;
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) {

NDArray x = NDArrayFactory::create<float>('c', {2,4,5,3});
NDArray exp = NDArrayFactory::create<float>('c', {2,4,5,3},{
1.0588236, 1.9607843, 3.019608, 4.0588236, 5.098039, 6.039216, 7.0588236, 8.039216, 9.058824,
10.058824, 10.980392, 12.078432, 13.058824, 13.921569, 15.09804, 16.058825, 17.058825, 18.117647,
19.058825, 20., 21.137257, 22.058825, 22.941177, 23.882355, 25.058825, 26.078432, 26.901962,
28.058825, 29.019608, 29.92157, 31.058825, 31.960785, 32.941177, 34.058823, 35.09804, 35.960785,
37.058823, 38.039215, 38.980392, 40.058823, 40.980392, 42.000004, 43.058826, 43.92157, 45.01961,
45., 47.058823, 48.03922, 45., 50., 51.058826, 45., 50., 54.078434,
45., 50., 57.09804, 45., 50., 60.11765, 45., 50., 62.862747,
45., 50., 65.882355, 45., 50., 68.90196, 45., 50., 70.,
45., 50., 70., 45., 50., 70., 45., 50., 70.,
45., 50., 70., 45., 50., 70., 45., 50., 70.,
45., 50., 70., 45., 50., 70., 45., 50., 70.,
45., 50., 70., 45., 50., 70., 45., 50., 70.,
45., 50., 70., 45., 50., 70., 45., 50., 70.,
45., 50., 70.});
NDArray min = NDArrayFactory::create<float>({20., 20., 20.});
NDArray max = NDArrayFactory::create<float>({65., 70., 90.});
x.linspace(1.);
nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

auto result = results->at(0);
// result->printBuffer("Quantized per channels 4");
// exp.printBuffer("Quantized per channest E");
// auto diff = *result - exp;
// diff.printIndexedBuffer("Difference");
ASSERT_TRUE(exp.isSameShapeStrict(result));
ASSERT_TRUE(exp.equalsTo(result));

delete results;
}

TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) {
NDArray x = NDArrayFactory::create<float>('c', {2, 3, 5, 4});
NDArray exp = NDArrayFactory::create<float>('c', {2, 3, 5, 4},{
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
-16. , -15.058824 , -13.960785 , -13.0196085 ,
-11.92157 , -10.980392 , -10.039217 , -8.941177 ,
-8.000001 , -7.0588236 , -5.960785 , -5.0196085 ,
-3.9215698 , -2.9803925 , -2.039217 , -0.94117737,
0. , 0.94117737, 2.039215 , 2.9803925 ,
4.07843 , 5.0196075 , 5.960783 , 7.0588226 ,
8. , 8.941177 , 10.039215 , 10.980392 ,
12.07843 , 13.019608 , 13.960783 , 15.058823 ,
16. , 16.941177 , 18.039217 , 18.980392 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
20.07843 , 21.019608 , 21.960783 , 23.058823
});
NDArray min = NDArrayFactory::create<float>({-20., -19., -18., -17});
NDArray max = NDArrayFactory::create<float>({20., 21., 22., 23});
x.linspace(-60.);
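// per-channel quantization, channel 0: step = (20 - (-20)) / 255 ~= 0.15686; min/max nudge to
// roughly -19.9216 / +20.0784 (-127 and +128 steps), which is exactly where the expected
// output saturates for inputs far below -20 or above +20.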
nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results->status());

auto result = results->at(0);
// result->printBuffer("Quantized per channels 5");
// exp.printBuffer("Quantized per channest E");
// auto diff = *result - exp;
// diff.printIndexedBuffer("Difference");

ASSERT_TRUE(exp.isSameShapeStrict(result));
ASSERT_TRUE(exp.equalsTo(result));

delete results;
}

////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test1) {


@ -157,6 +157,104 @@ TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) {
delete result;
}

TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) {
auto x = NDArrayFactory::create<double>('c', {4,4,3});
auto e = NDArrayFactory::create<double>('c', {4,4,3}, {
-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5,
26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5,
50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5
});
x.linspace(1.);
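// adjust_contrast rescales each channel around its mean: out = (x - mean_c) * factor + mean_c.
// Channel 0 of linspace(1..48) has mean 23.5, so the first value maps to
// (1 - 23.5) * 2 + 23.5 = -21.5, matching e above.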
nd4j::ops::adjust_contrast op;
auto result = op.execute({&x}, {2.}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0);
// out->printIndexedBuffer("Adjusted Contrast");
ASSERT_TRUE(e.equalsTo(out));
delete result;
}

TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) {
auto x = NDArrayFactory::create<float>('c', {1, 4,4,3});
auto e = NDArrayFactory::create<float>('c', {1, 4,4,3}, {
-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5,
26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5,
50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5
});
x.linspace(1.);
nd4j::ops::adjust_contrast op;
auto result = op.execute({&x}, {2.}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0);
// out->printIndexedBuffer("Adjusted Contrast");
ASSERT_TRUE(e.equalsTo(out));
delete result;
}

TEST_F(DeclarableOpsTests15, Test_AdjustContrast_3) {
auto x = NDArrayFactory::create<float>('c', {1, 4,4,3});
auto e = NDArrayFactory::create<float>('c', {1, 4,4,3}, {
-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5,
26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5,
50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5
});
x.linspace(1.);
nd4j::ops::adjust_contrast_v2 op;
auto result = op.execute({&x}, {2.}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0);
// out->printIndexedBuffer("Adjusted Contrast");
ASSERT_TRUE(e.equalsTo(out));
delete result;
}

TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) {
auto x = NDArrayFactory::create<double>('c', {4, 4, 3});
auto e = NDArrayFactory::create<double>('c', {4, 4, 3}, {
-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5,
26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5,
50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5
});
x.linspace(1.);
nd4j::ops::adjust_contrast_v2 op;
auto result = op.execute({&x}, {2.}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0);
// out->printIndexedBuffer("Adjusted Contrast");
ASSERT_TRUE(e.equalsTo(out));
delete result;
}
TEST_F(DeclarableOpsTests15, Test_BitCast_1) {
auto x = NDArrayFactory::create<float>('c', {2, 2, 2});
auto e = NDArrayFactory::create<double>('c', {2, 2}, {2., 512., 8192., 131072.032 });
x.linspace(1.);
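// bitcast reinterprets the raw buffer rather than converting values: each pair of adjacent
// float32 values along the trailing axis of size 2 is read as the bit pattern of one float64,
// so the {2, 2, 2} float input collapses to the {2, 2} double output expected above.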
nd4j::ops::bitcast op;
auto result = op.execute({&x}, {}, {nd4j::DataType::DOUBLE}, {});
ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0);
// out->printIndexedBuffer("Casted result");
ASSERT_TRUE(e.equalsTo(out));
delete result;
}

TEST_F(DeclarableOpsTests15, Test_BitCast_2) {
auto x = NDArrayFactory::create<float>('c', {2, 4});
auto e = NDArrayFactory::create<float16>('c', {2, 4, 2}, {0, 1.875, 0, 2., 0, 2.125, 0, 2.25,
0, 2.312, 0, 2.375, 0, 2.438, 0., 2.5});
x.linspace(1.);
nd4j::ops::bitcast op;
auto result = op.execute({&x}, {}, {nd4j::DataType::HALF}, {});
ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0);
// out->printIndexedBuffer("Casted result");
ASSERT_TRUE(e.equalsTo(out));
delete result;
}

TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) {
auto in = NDArrayFactory::create<float>('c', {4, 8, 64, 64});
auto w = NDArrayFactory::create<float>('c', {2, 2, 8, 2});


@ -2387,6 +2387,25 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) {
delete result;
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) {

auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
auto threshold = NDArrayFactory::create<double>(2.0);
auto exp = NDArrayFactory::create<uint8_t>('c', {2, 3, 4}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1});
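// judging by this test's expected output, each element is compared against the scalar
// threshold: x > 2.0 yields 1, everything else 0, so only the values 3..11 set their bits
// (here the output keeps the input's {2, 3, 4} shape).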

nd4j::ops::compare_and_bitpack op;

auto result = op.execute({&x, &threshold}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0);
// output->printIndexedBuffer("Packed to uint8");
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
delete result;
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) {


@ -276,6 +276,10 @@ object Implicits {
override def hasNegative: Boolean = false
}

case object --- extends IndexRange {
override def hasNegative: Boolean = false
}
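// `---` is a new ellipsis index alongside the existing `--->`: in an extraction it expands
// to "all remaining dimensions", as the modifyTargetIndices cases further down show.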

implicit class IntRange(val underlying: Int) extends IndexNumberRange {
protected[nd4s] override def asRange(max: => Int): DRange =
DRange(underlying, underlying, true, 1, max)

@ -307,6 +311,10 @@ object Implicits {
def -> = IntRangeFrom(underlying)
}

implicit class IntRangeFromGen1(val underlying: Int) extends AnyVal {
def :: = IntRangeFromReverse(underlying)
}

implicit class IndexRangeWrapper(val underlying: Range) extends IndexNumberRange {
protected[nd4s] override def asRange(max: => Int): DRange =
DRange.from(underlying, max)

@ -377,17 +385,27 @@ object IndexNumberRange {
val endExclusive = if (endR >= 0) endR + diff else max + endR + diff
(start, endExclusive)
}

NDArrayIndex.interval(start, step, end, false)
}
}

/*sealed*/
trait IndexRange {
def hasNegative: Boolean
}

case class IntRangeFrom(underlying: Int) extends IndexRange {
def apply[T](a: T): (Int, T) =
(underlying, a)

override def toString: String = s"$underlying->"

override def hasNegative: Boolean = false
}

case class IntRangeFromReverse(underlying: Int) extends IndexRange {
def apply[T](a: T): (T, Int) =
(a, underlying)

override def toString: String = s"$underlying->"

@ -23,6 +23,12 @@ import org.slf4j.LoggerFactory

import _root_.scala.annotation.tailrec

package object ops {
case object :: extends IndexRange {
override def hasNegative: Boolean = false
}
}

trait SliceableNDArray[A <: INDArray] {
lazy val log = LoggerFactory.getLogger(classOf[SliceableNDArray[A]])
val underlying: A

@ -68,6 +74,8 @@ trait SliceableNDArray[A <: INDArray] {

@tailrec
def modifyTargetIndices(input: List[IndexRange], i: Int, acc: List[DRange]): List[DRange] = input match {
case ops.:: :: t =>
modifyTargetIndices(t, i + 1, DRange(0, originalShape(i), 1) :: acc)
case -> :: t =>
modifyTargetIndices(t, i + 1, DRange(0, originalShape(i), 1) :: acc)
case ---> :: t =>

@ -137,6 +145,9 @@ trait SliceableNDArray[A <: INDArray] {
case ---> :: t =>
val ellipsised = List.fill(originalShape.length - i - t.size)(->)
modifyTargetIndices(ellipsised ::: t, i, acc)
case --- :: t =>
val ellipsised = List.fill(originalShape.length - i - t.size)(->)
modifyTargetIndices(ellipsised ::: t, i, acc)
case IntRangeFrom(from: Int) :: t =>
val max = originalShape(i)
modifyTargetIndices(t, i + 1, IndexNumberRange.toNDArrayIndex(from, max, false, 1, max) :: acc)

@ -44,6 +44,9 @@ class SameDiffWrapper {
def bind(name: String, dataType: DataType, shape: Array[Long]): SDVariable =
sd.`var`(name, dataType, shape: _*)

def bind(data: INDArray): SDVariable =
sd.`var`("", data)

def bind(name: String, dataType: DataType, shape: Array[Int]): SDVariable =
sd.`var`(name, dataType, shape: _*)

@ -51,18 +54,39 @@ class SameDiffWrapper {
sd.placeHolder(name, dataType, shape: _*)
}

case class SDIndexWrapper(end: Long) {

def ::(start: Long): SDIndex =
SDIndex.interval(start, end)
}
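// Because `::` ends in a colon it binds right-associatively: `0 :: 2` desugars to
// `SDIndexWrapper(2).::(0)` (via the implicit Long/Int conversions added in Implicits below),
// producing SDIndex.interval(0, 2).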

case class SDIndexWrapper1(start: Int) {

def ::(end: Int): SDIndex =
SDIndex.interval(start, end)
}

object --- extends SDIndex {
val thisIndex: SDIndex = SDIndex.all()
}

class SDVariableWrapper {

var thisVariable: SDVariable = null
var isScalar: Boolean = false
val --- : SDIndex = SDIndex.all()

def this(variable: SDVariable) {
this
thisVariable = variable
}

// Indexing
def apply(index: Long): SDVariable = thisVariable.get(SDIndex.point(index))

def apply(index: SDIndex*): SDVariable = thisVariable.get(index: _*)
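// The varargs overload forwards straight to SDVariable.get, which is what lets expressions
// like x(0 :: 2, ---) mix interval, point and ellipsis indices in a single call.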

// Arithmetic
def add(other: Double): Unit = thisVariable.add(other)

def *(other: SDVariable): SDVariable =

@ -15,9 +15,9 @@
******************************************************************************/
package org.nd4s.samediff.implicits

import org.nd4j.autodiff.samediff.{ SDIndex, SDVariable, SameDiff }
import org.nd4j.linalg.factory.Nd4j
import org.nd4s.samediff.{ SDIndexWrapper, SDVariableWrapper, SameDiffWrapper }

object Implicits {
implicit def SameDiffToWrapper(sd: SameDiff): SameDiffWrapper =

@ -43,4 +43,20 @@ object Implicits {
result.isScalar = true
result
}

implicit def RangeToWrapper(start: Long): SDIndexWrapper = {
val result = new SDIndexWrapper(start)
result
}

implicit def LongToPoint(x: Long): SDIndex =
SDIndex.point(x)
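// A bare Long/Int index therefore converts to SDIndex.point, so x(0, ---) reads as
// point(0) on the first axis and "all remaining dimensions" on the rest.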

implicit def IntRangeToWrapper(start: Int): SDIndexWrapper = {
val result = new SDIndexWrapper(start)
result
}

implicit def IntToPoint(x: Int): SDIndex =
SDIndex.point(x)
}

@ -48,6 +48,42 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
assert(extracted == expected)
}

it should "be able to extract a part of 2d matrix with alternative syntax" in {
val ndArray =
Array(
Array(1, 2, 3),
Array(4, 5, 6),
Array(7, 8, 9)
).mkNDArray(ordering)

val extracted = ndArray(1 :: 3, 0 :: 2)

val expected =
Array(
Array(4, 5),
Array(7, 8)
).mkNDArray(ordering)
assert(extracted == expected)
}
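// `1 :: 3, 0 :: 2` behaves like an end-exclusive interval here: rows 1..2 and
// columns 0..1 of the 3x3 matrix, i.e. the 2x2 block {{4, 5}, {7, 8}}.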

it should "be able to extract a part of 2d matrix with mixed syntax" in {
val ndArray =
Array(
Array(1, 2, 3),
Array(4, 5, 6),
Array(7, 8, 9)
).mkNDArray(ordering)

val extracted = ndArray(1 -> 3, 0 :: 2)

val expected =
Array(
Array(4, 5),
Array(7, 8)
).mkNDArray(ordering)
assert(extracted == expected)
}

it should "be able to extract a part of 2d matrix with double data" in {
val ndArray = (5 to 8).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C)

@ -171,6 +207,9 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>

val ellipsised = ndArray(--->)
assert(ellipsised == ndArray)

val ellipsised1 = ndArray(---)
assert(ellipsised1 == ndArray)
}

it should "accept partially ellipsis indices" in {

@ -20,6 +20,7 @@ import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
import org.nd4s.Implicits._
import org.nd4s.NDOrdering
import org.nd4s.samediff.implicits.Implicits._
import org.scalatest.{ FlatSpec, Matchers }

@ -205,9 +206,41 @@ class MathTest extends FlatSpec with Matchers {
implicit val sd = SameDiff.create

val arr = Nd4j.linspace(1, 100, 100).reshape('c', 10L, 10L)
val x = sd.bind(arr)
val y = new SDVariableWrapper(x)

x.get(SDIndex.point(0)).getArr shouldBe y(0).getArr
}

"SDVariable " should "be indexable in 2d" in {
implicit val sd = SameDiff.create

val arr = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 9).reshape(3, 3)

val x = sd.bind(arr)

x(0, ---).eval shouldBe x(SDIndex.point(0), SDIndex.all()).eval

val slice1 = x.get(SDIndex.interval(0, 2), SDIndex.all()).eval
val slice2 = x(0 :: 2, ---).eval
slice1 shouldBe slice2
}

"SDVariable " should "be indexable in 3d" in {
implicit val sd = SameDiff.create

val arr = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 18).reshape(3, 3, 2)
val x = sd.bind(arr)

x.get(SDIndex.all(), SDIndex.all(), SDIndex.all()).eval shouldBe x(---, ---, ---).eval
x.get(SDIndex.point(0), SDIndex.all(), SDIndex.all()).eval shouldBe x(0, ---, ---).eval
x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.all()).eval shouldBe x(0, 0, ---).eval
x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0, 0, 0).eval

x.get(SDIndex.interval(0, 2), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 :: 2, 0, 0).eval
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.interval(0, 2)).eval shouldBe x(0 :: 2,
                                                                                              0 :: 1,
                                                                                              0 :: 2).eval
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.all()).eval shouldBe x(0 :: 2, 0 :: 1, ---).eval
}
}
