// File: cavis/libnd4j/include/ops/declarable/generic/parity_ops/bincount.cpp
/* ******************************************************************************
*
2019-06-06 14:21:15 +02:00
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
2021-02-01 13:31:45 +01:00
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
2019-06-06 14:21:15 +02:00
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_bincount)
#include <memory>

//#include <ops/declarable/headers/parity_ops.h>
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/weights.h>
namespace sd {
2019-06-06 14:21:15 +02:00
namespace ops {
DECLARE_TYPES(bincount) {
getOpDescriptor()
->setAllowedInputTypes(0, sd::DataType::INT32)
->setAllowedInputTypes(1, sd::DataType::ANY)
2019-06-06 14:21:15 +02:00
->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS});
}
CUSTOM_OP_IMPL(bincount, 1, 1, false, 0, 0) {
auto values = INPUT_VARIABLE(0);
2021-02-10 14:40:04 +01:00
2019-06-06 14:21:15 +02:00
NDArray *weights = nullptr;
int maxLength = -1;
int minLength = 0;
int maxIndex = values->argMax();
maxLength = values->e<int>(maxIndex) + 1;
if (block.numI() > 0) {
minLength = sd::math::nd4j_max(INT_ARG(0), 0);
2019-06-06 14:21:15 +02:00
if (block.numI() == 2)
maxLength = sd::math::nd4j_min(maxLength, INT_ARG(1));
2019-06-06 14:21:15 +02:00
}
if (block.width() == 2) { // the second argument is weights
weights = INPUT_VARIABLE(1);
2021-02-10 14:40:04 +01:00
if(weights->lengthOf() < 1) {
weights = NDArrayFactory::create_('c',values->getShapeAsVector(),values->dataType());
weights->assign(1);
} else if(weights->isScalar()) {
auto value = weights->asVectorT<int>();
weights = NDArrayFactory::create_('c',values->getShapeAsVector(),values->dataType());
weights->assign(value[0]);
}
2019-06-06 14:21:15 +02:00
REQUIRE_TRUE(values->isSameShape(weights), 0, "bincount: the input and weights shapes should be equals");
}
else if (block.width() == 3) { // the second argument is min and the third is max
auto min= INPUT_VARIABLE(1);
2021-02-10 14:40:04 +01:00
auto max = min;
if(INPUT_VARIABLE(2)->lengthOf() > 0) {
max = INPUT_VARIABLE(2);
}
2019-06-06 14:21:15 +02:00
minLength = min->e<int>(0);
maxLength = max->e<int>(0);
}
else if (block.width() > 3) {
auto min= INPUT_VARIABLE(2);
auto max = INPUT_VARIABLE(3);
minLength = min->e<int>(0);
2021-02-10 14:40:04 +01:00
if(INPUT_VARIABLE(2)->lengthOf() > 0) {
maxLength = max->e<int>(0);
}
else
maxLength = minLength;
2019-06-06 14:21:15 +02:00
weights = INPUT_VARIABLE(1);
2021-02-10 14:40:04 +01:00
if(weights->lengthOf() < 1) {
weights = NDArrayFactory::create_('c',values->getShapeAsVector(),values->dataType());
weights->assign(1);
} else if(weights->isScalar()) {
auto value = weights->asVectorT<int>();
weights = NDArrayFactory::create_('c',values->getShapeAsVector(),values->dataType());
weights->assign(value[0]);
}
2019-06-06 14:21:15 +02:00
REQUIRE_TRUE(values->isSameShape(weights), 0, "bincount: the input and weights shapes should be equals");
}
2021-02-10 14:40:04 +01:00
minLength = sd::math::nd4j_max(minLength, 0);
maxLength = sd::math::nd4j_min(maxLength, values->e<int>(maxIndex) + 1);
2019-06-06 14:21:15 +02:00
auto result = OUTPUT_VARIABLE(0);
result->assign(0.0f);
2021-02-10 14:40:04 +01:00
2019-06-06 14:21:15 +02:00
helpers::adjustWeights(block.launchContext(), values, weights, result, minLength, maxLength);
return Status::OK();
}
DECLARE_SHAPE_FN(bincount) {
2021-02-10 14:40:04 +01:00
auto shapeList = SHAPELIST();
2019-06-06 14:21:15 +02:00
auto in = INPUT_VARIABLE(0);
sd::DataType dtype = DataType::INT32;
2019-06-06 14:21:15 +02:00
if (block.width() > 1)
dtype = ArrayOptions::dataType(inputShape->at(1));
else if (block.numI() > 2)
dtype = (sd::DataType)INT_ARG(2);
2019-06-06 14:21:15 +02:00
int maxIndex = in->argMax();
int maxLength = in->e<int>(maxIndex) + 1;
int outLength = maxLength;
2021-02-10 14:40:04 +01:00
2019-06-06 14:21:15 +02:00
if (block.numI() > 0)
outLength = sd::math::nd4j_max(maxLength, INT_ARG(0));
2019-06-06 14:21:15 +02:00
2021-02-10 14:40:04 +01:00
if (block.numI() > 1)
outLength = sd::math::nd4j_min(outLength, INT_ARG(1));
2019-06-06 14:21:15 +02:00
2021-02-10 14:40:04 +01:00
2019-06-06 14:21:15 +02:00
if (block.width() == 3) { // the second argument is min and the third is max
2021-02-10 14:40:04 +01:00
auto min = INPUT_VARIABLE(1)->e<int>(0);
auto max = min;
if(INPUT_VARIABLE(2)->lengthOf() > 0) {
max = INPUT_VARIABLE(2)->e<int>(0);
}
outLength = sd::math::nd4j_max(maxLength, min);
outLength = sd::math::nd4j_min(outLength, max);
2019-06-06 14:21:15 +02:00
}
else if (block.width() > 3) {
auto min= INPUT_VARIABLE(2);
2021-02-10 14:40:04 +01:00
auto max = min;
if(INPUT_VARIABLE(3)->lengthOf() > 0) {
max = INPUT_VARIABLE(3);
}
outLength = sd::math::nd4j_max(maxLength, min->e<int>(0));
outLength = sd::math::nd4j_min(outLength, max->e<int>(0));
2019-06-06 14:21:15 +02:00
}
2021-02-10 14:40:04 +01:00
auto newshape = ConstantShapeHelper::getInstance().vectorShapeInfo(outLength, dtype);
2019-06-06 14:21:15 +02:00
2021-02-10 14:40:04 +01:00
shapeList->push_back(newshape);
2019-06-06 14:21:15 +02:00
return shapeList;
}
}
}
#endif