Shugeo segment fix4 (#385)

* Added test and fixed error message for unsorted_segment_sqrt_n op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed error message for unsorted_segment_* ops when 1 segment is given.

Signed-off-by: shugeo <sgazeos@gmail.com>
master
shugeo 2020-04-20 09:04:35 +03:00 committed by GitHub
parent 1bec9a4f61
commit a5db0e33be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 32 additions and 13 deletions

View File

@ -29,11 +29,11 @@ namespace sd {
auto segmentedOutput = OUTPUT_NULLIFIED(0);
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_max: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_max: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_max: segment indexes array length should be equal to the input first dimension, but %ld != %ild.", idxSegments->lengthOf(), input->sizeAt(0));
Nd4jLong wrong;
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_max: segment indices should be in range [0, %i), but %i > %i",
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_max: segment indices should be in range [0, %ld), but %ld != %ld",
numOfClasses, wrong, numOfClasses);
helpers::unsortedSegmentMaxFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput);

View File

@ -30,11 +30,11 @@ namespace sd {
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_mean: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_mean: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_mean: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0));
Nd4jLong wrong;
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_mean: segment indices should be in range [0, %i), but %i > %i",
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_mean: segment indices should be in range [0, %ld), but %ld != %ld",
numOfClasses, wrong, numOfClasses);
helpers::unsortedSegmentMeanFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput);

View File

@ -29,11 +29,11 @@ namespace sd {
auto segmentedOutput = OUTPUT_NULLIFIED(0);
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_min: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_min: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_min: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0));
Nd4jLong wrong;
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_min: segment indices should be in range [0, %i), but %i > %i",
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_min: segment indices should be in range [0, %ld), but %ld > %ld",
numOfClasses, wrong, numOfClasses);
helpers::unsortedSegmentMinFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput);

View File

@ -29,11 +29,11 @@ namespace sd {
auto segmentedOutput = OUTPUT_NULLIFIED(0);
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_prod: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_prod: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_prod: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0));
Nd4jLong wrong = 0;
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_prod: segment indices should be in range [0, %i), but %i > %i",
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_prod: segment indices should be in range [0, %ld), but %ld != %ld",
numOfClasses, wrong, numOfClasses);
helpers::unsortedSegmentProdFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput);

View File

@ -29,11 +29,11 @@ namespace sd {
auto segmentedOutput = OUTPUT_NULLIFIED(0);
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sqrt_n: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sqrt_n: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sqrt_n: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0));
Nd4jLong wrong;
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sqrt_n: segment indices should be in range [0, %i), but %i > %i",
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sqrt_n: segment indices should be in range [0, %ld), but %ld != %ld",
numOfClasses, wrong, numOfClasses);
helpers::unsortedSegmentSqrtNFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput);

View File

@ -29,11 +29,11 @@ namespace sd {
auto segmentedOutput = OUTPUT_NULLIFIED(0);
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sum: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sum: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sum: segment indexes array length should be equal to the input first dimension, but %ld != %ld", idxSegments->lengthOf(), input->sizeAt(0));
Nd4jLong wrong;
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sum: segment indices should be in range [0, %i), but %i > %i",
REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sum: segment indices should be in range [0, %ld), but %ld > %ld",
numOfClasses, wrong, numOfClasses);
helpers::unsortedSegmentSumFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput);

View File

@ -1868,7 +1868,26 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_5) {
// exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(result.at(0)));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_6) {
auto x = NDArrayFactory::create<double>({5,1,7,2,3,4,1,3});
auto idx = NDArrayFactory::create<int>({0,0,0,1,2,2,3,3});
//NDArray<double> exp({1.7320508075688772, 1., 1.4142135623730951, 1.4142135623730951});
// auto exp = NDArrayFactory::create<double>({7.5055537, 2., 4.9497476, 2.828427});
sd::ops::unsorted_segment_sqrt_n op;
try {
auto result = op.evaluate({&x, &idx}, {}, {1});
ASSERT_NE(result.status(), Status::OK());
}
catch (std::exception& err) {
}
// result.at(0)->printIndexedBuffer("Output");
// exp.printIndexedBuffer("Expect");
//ASSERT_TRUE(exp.equalsTo(result.at(0)));
}
////////////////////////////////////////////////////////////////////////////////