Shugeo segment fix4 (#385)
* Added test and fixed error message for unsorted_segment_sqrt_n op. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed error message for unsorted_segment_* ops when 1 segment is given. Signed-off-by: shugeo <sgazeos@gmail.com>master
parent
1bec9a4f61
commit
a5db0e33be
|
@ -29,11 +29,11 @@ namespace sd {
|
|||
auto segmentedOutput = OUTPUT_NULLIFIED(0);
|
||||
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
|
||||
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_max: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
|
||||
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_max: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
|
||||
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_max: segment indexes array length should be equal to the input first dimension, but %ld != %ild.", idxSegments->lengthOf(), input->sizeAt(0));
|
||||
|
||||
Nd4jLong wrong;
|
||||
|
||||
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_max: segment indices should be in range [0, %i), but %i > %i",
|
||||
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_max: segment indices should be in range [0, %ld), but %ld != %ld",
|
||||
numOfClasses, wrong, numOfClasses);
|
||||
|
||||
helpers::unsortedSegmentMaxFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput);
|
||||
|
|
|
@ -30,11 +30,11 @@ namespace sd {
|
|||
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
|
||||
|
||||
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_mean: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
|
||||
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_mean: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
|
||||
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_mean: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0));
|
||||
|
||||
Nd4jLong wrong;
|
||||
|
||||
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_mean: segment indices should be in range [0, %i), but %i > %i",
|
||||
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_mean: segment indices should be in range [0, %ld), but %ld != %ld",
|
||||
numOfClasses, wrong, numOfClasses);
|
||||
|
||||
helpers::unsortedSegmentMeanFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput);
|
||||
|
|
|
@ -29,11 +29,11 @@ namespace sd {
|
|||
auto segmentedOutput = OUTPUT_NULLIFIED(0);
|
||||
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
|
||||
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_min: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
|
||||
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_min: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
|
||||
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_min: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0));
|
||||
|
||||
Nd4jLong wrong;
|
||||
|
||||
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_min: segment indices should be in range [0, %i), but %i > %i",
|
||||
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_min: segment indices should be in range [0, %ld), but %ld > %ld",
|
||||
numOfClasses, wrong, numOfClasses);
|
||||
|
||||
helpers::unsortedSegmentMinFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput);
|
||||
|
|
|
@ -29,11 +29,11 @@ namespace sd {
|
|||
auto segmentedOutput = OUTPUT_NULLIFIED(0);
|
||||
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
|
||||
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_prod: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
|
||||
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_prod: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
|
||||
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_prod: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0));
|
||||
|
||||
Nd4jLong wrong = 0;
|
||||
|
||||
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_prod: segment indices should be in range [0, %i), but %i > %i",
|
||||
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_prod: segment indices should be in range [0, %ld), but %ld != %ld",
|
||||
numOfClasses, wrong, numOfClasses);
|
||||
|
||||
helpers::unsortedSegmentProdFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput);
|
||||
|
|
|
@ -29,11 +29,11 @@ namespace sd {
|
|||
auto segmentedOutput = OUTPUT_NULLIFIED(0);
|
||||
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
|
||||
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sqrt_n: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
|
||||
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sqrt_n: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
|
||||
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sqrt_n: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0));
|
||||
|
||||
Nd4jLong wrong;
|
||||
|
||||
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sqrt_n: segment indices should be in range [0, %i), but %i > %i",
|
||||
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sqrt_n: segment indices should be in range [0, %ld), but %ld != %ld",
|
||||
numOfClasses, wrong, numOfClasses);
|
||||
|
||||
helpers::unsortedSegmentSqrtNFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput);
|
||||
|
|
|
@ -29,11 +29,11 @@ namespace sd {
|
|||
auto segmentedOutput = OUTPUT_NULLIFIED(0);
|
||||
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
|
||||
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sum: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
|
||||
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sum: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
|
||||
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sum: segment indexes array length should be equal to the input first dimension, but %ld != %ld", idxSegments->lengthOf(), input->sizeAt(0));
|
||||
|
||||
Nd4jLong wrong;
|
||||
|
||||
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sum: segment indices should be in range [0, %i), but %i > %i",
|
||||
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sum: segment indices should be in range [0, %ld), but %ld > %ld",
|
||||
numOfClasses, wrong, numOfClasses);
|
||||
|
||||
helpers::unsortedSegmentSumFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput);
|
||||
|
|
|
@ -1868,7 +1868,26 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_5) {
|
|||
// exp.printIndexedBuffer("Expect");
|
||||
ASSERT_TRUE(exp.equalsTo(result.at(0)));
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_6) {
|
||||
auto x = NDArrayFactory::create<double>({5,1,7,2,3,4,1,3});
|
||||
auto idx = NDArrayFactory::create<int>({0,0,0,1,2,2,3,3});
|
||||
//NDArray<double> exp({1.7320508075688772, 1., 1.4142135623730951, 1.4142135623730951});
|
||||
// auto exp = NDArrayFactory::create<double>({7.5055537, 2., 4.9497476, 2.828427});
|
||||
sd::ops::unsorted_segment_sqrt_n op;
|
||||
|
||||
try {
|
||||
auto result = op.evaluate({&x, &idx}, {}, {1});
|
||||
ASSERT_NE(result.status(), Status::OK());
|
||||
}
|
||||
catch (std::exception& err) {
|
||||
|
||||
}
|
||||
// result.at(0)->printIndexedBuffer("Output");
|
||||
// exp.printIndexedBuffer("Expect");
|
||||
//ASSERT_TRUE(exp.equalsTo(result.at(0)));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
Loading…
Reference in New Issue