151 lines
5.1 KiB
C++
151 lines
5.1 KiB
C++
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
|
|
|
|
//
// @author raver119@gmail.com
//
|
|
|
|
#include <op_boilerplate.h>
|
|
#if NOT_EXCLUDED(OP_cumsum)
|
|
|
|
#include <ops/declarable/helpers/prefix.h>
|
|
#include <ops/declarable/CustomOperations.h>
|
|
|
|
namespace nd4j {
|
|
namespace ops {
|
|
|
|
CONFIGURABLE_OP_IMPL(cumsum, 1, 1, true, 0, 2) {
|
|
auto input = INPUT_VARIABLE(0);
|
|
auto output = OUTPUT_VARIABLE(0);
|
|
|
|
const bool exclusive = INT_ARG(0) == 1;
|
|
const bool reverse = INT_ARG(1) == 1;
|
|
|
|
REQUIRE_TRUE(input->dataType() == output->dataType(), 0, "CumSum: input and output data types must be equal");
|
|
|
|
if(input->isEmpty()){
|
|
//No-op
|
|
return Status::OK();
|
|
}
|
|
|
|
if (block.getIArguments()->size() == 2 && block.width() == 1) {
|
|
// all at once case
|
|
nd4j::ops::helpers::_prefix(block.launchContext(), scalar::Add, input, output, exclusive, reverse);
|
|
}
|
|
else {
|
|
std::vector<int> dims(block.numI() - 2);
|
|
|
|
if (block.width() == 1) {
|
|
|
|
for (int e = 0; e < block.numI() - 2; e++)
|
|
dims[e] = INT_ARG(e + 2);
|
|
}
|
|
else {
|
|
auto ax = INPUT_VARIABLE(1);
|
|
dims = ax->template asVectorT<int>();
|
|
}
|
|
|
|
for (int e = 0; e < dims.size(); e++)
|
|
if (dims[e] < 0)
|
|
dims[e] += input->rankOf();
|
|
|
|
nd4j::ops::helpers::_prefix(block.launchContext(), scalar::Add, input, output, dims, exclusive, reverse);
|
|
}
|
|
|
|
return Status::OK();
|
|
}
|
|
DECLARE_TYPES(cumsum) {
|
|
|
|
getOpDescriptor()
|
|
->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS})
|
|
->setAllowedInputTypes(1, {ALL_INTS})
|
|
->setAllowedOutputTypes({ALL_FLOATS})
|
|
->setSameMode(false);
|
|
}
|
|
|
|
CUSTOM_OP_IMPL(cumsum_bp, 2, -1, true, 0, 2) {
|
|
auto input = INPUT_VARIABLE(0);
|
|
auto axis = block.width() == 3 ? INPUT_VARIABLE(1) : nullptr;
|
|
auto gradOut = block.width() == 3 ? INPUT_VARIABLE(2) : INPUT_VARIABLE(1);
|
|
auto output = OUTPUT_VARIABLE(0);
|
|
// output->assign(gradOut);
|
|
const bool exclusive = INT_ARG(0) == 1;
|
|
const bool reverse = INT_ARG(1) == 1;
|
|
|
|
std::vector<int> dims;
|
|
|
|
if (block.width() > 2) {
|
|
dims = axis->template asVectorT<int>();
|
|
OUTPUT_VARIABLE(1)->assign(1.0f);
|
|
} else if (int newSize = (block.numI() - 2)) {
|
|
dims.resize(newSize);
|
|
|
|
for (int e = 0; e < newSize; e++)
|
|
dims[e] = INT_ARG(e + 2);
|
|
}
|
|
if (!exclusive && !reverse) {
|
|
if (dims.size())
|
|
nd4j::ops::helpers::_prefix(block.launchContext(), scalar::Add, gradOut, output, dims, false, true);
|
|
else
|
|
nd4j::ops::helpers::_prefix(block.launchContext(), scalar::Add, gradOut, output, false, true);
|
|
|
|
}
|
|
else if (!exclusive && reverse){
|
|
if (dims.size())
|
|
nd4j::ops::helpers::_prefix(block.launchContext(), scalar::Add, gradOut, output, dims, false, false);
|
|
else
|
|
nd4j::ops::helpers::_prefix(block.launchContext(), scalar::Add, gradOut, output, false, false);
|
|
}
|
|
else if (exclusive && !reverse) {
|
|
if (dims.size())
|
|
nd4j::ops::helpers::_prefix(block.launchContext(), scalar::Add, gradOut, output, dims, true, true);
|
|
else
|
|
nd4j::ops::helpers::_prefix(block.launchContext(), scalar::Add, gradOut, output, true, true);
|
|
}
|
|
else {
|
|
if (dims.size())
|
|
nd4j::ops::helpers::_prefix(block.launchContext(), scalar::Add, gradOut, output, dims, true, false);
|
|
else
|
|
nd4j::ops::helpers::_prefix(block.launchContext(), scalar::Add, gradOut, output, true, false);
|
|
}
|
|
|
|
return Status::OK();
|
|
}
|
|
DECLARE_TYPES(cumsum_bp) {
|
|
getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS});
|
|
getOpDescriptor()->setAllowedInputTypes(1, {ALL_FLOATS, ALL_INTS}); // axes can be set as the second param
|
|
getOpDescriptor()->setAllowedInputTypes(2, {ALL_FLOATS});
|
|
getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS});
|
|
}
|
|
|
|
DECLARE_SHAPE_FN(cumsum_bp) {
|
|
auto inp = inputShape->at(0);
|
|
Nd4jLong *newShapeX = nullptr;
|
|
COPY_SHAPE(inp, newShapeX);
|
|
|
|
if (block.width() == 2) {
|
|
return SHAPELIST(CONSTANT(newShapeX));
|
|
} else {
|
|
Nd4jLong *newShapeA = nullptr;
|
|
COPY_SHAPE(inputShape->at(1), newShapeA);
|
|
|
|
return SHAPELIST(CONSTANT(newShapeX), CONSTANT(newShapeA));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#endif |