diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/igamma.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/igamma.cpp new file mode 100644 index 000000000..6bd1c88ed --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/broadcastable/igamma.cpp @@ -0,0 +1,59 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author sgazeos@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_igamma) + +#include +#include + +namespace nd4j { + namespace ops { + BROADCASTABLE_OP_IMPL(igamma, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x,y,z); + + //REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!"); + +// auto tZ = BroadcastHelper::broadcastApply({scalar::IGamma, pairwise::IGamma, broadcast::IGamma}, x, y, z); + auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::IGamma(), x, y, z); + + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); + } + + DECLARE_TYPES(igamma) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/igammac.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/igammac.cpp new file mode 100644 index 000000000..89494dc4b --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/broadcastable/igammac.cpp @@ -0,0 +1,58 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author sgazeos@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_igammac) + +#include +#include + +namespace nd4j { + namespace ops { + BROADCASTABLE_OP_IMPL(igammac, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x,y,z); + + //REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!"); + +// auto tZ = BroadcastHelper::broadcastApply({scalar::IGammac, pairwise::IGammac, broadcast::IGammac}, x, y, z); + auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::IGammac(), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); + } + + DECLARE_TYPES(igammac) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 0652a398e..25fe3429a 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -537,6 +537,50 @@ TEST_F(DeclarableOpsTests10, atan2_test6) { delete result; } +////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, IGamma_Test1) { + + auto y = NDArrayFactory::create('c', {1, 3, 4}, {1.1 , 2.1 , 3.1 ,4.1 , 5.1 , 6.1 ,7.1 ,8.1 ,9.1 ,10.1,11.1 ,12.1}); + auto x = NDArrayFactory::create('c', { 4}, {1.2, 2.2, 3.2, 4.2}); + + auto exp = NDArrayFactory::create('c', {1,3,4}, { + 0.659917, 0.61757898, 0.59726304, 0.58478117, + 0.0066205109, 0.022211598, 0.040677428, 0.059117373, + 0.0000039433403, 0.000086064574, 0.000436067, 0.0012273735}); + + nd4j::ops::igamma op; + auto result = op.execute({&y, &x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + auto z = result->at(0); +// z->printBuffer("OUtput"); +// exp.printBuffer("EXpect"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, IGamma_Test2) { + + auto y = NDArrayFactory::create('c', {1, 3, 4}, {1.1 , 2.1 , 3.1 ,4.1 , 5.1 , 6.1 , + 7.1 ,8.1 ,9.1 ,10.1,11.1 ,12.1}); + auto x = NDArrayFactory::create('c', { 4}, {1.2, 2.2, 3.2, 4.2}); + auto exp = NDArrayFactory::create('c', {1,3,4}, {0.340083, 0.382421, 0.402737, 0.415221, + 0.993379, 0.977788, 0.959323, 0.940883, + 0.999996, 0.999914, 0.999564, 0.998773}); + + nd4j::ops::igammac op; + auto result = op.execute({&y, &x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + auto z = result->at(0); +// z->printBuffer("OUtput"); +// exp.printBuffer("EXpect"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, range_test10) {