Shugeo segment fix4 (#385)
* Added test and fixed error message for unsorted_segment_sqrt_n op. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed error message for unsorted_segment_* ops when 1 segment is given. Signed-off-by: shugeo <sgazeos@gmail.com>
This commit is contained in:
		
							parent
							
								
									1bec9a4f61
								
							
						
					
					
						commit
						a5db0e33be
					
				| @ -29,11 +29,11 @@ namespace sd { | ||||
|             auto segmentedOutput = OUTPUT_NULLIFIED(0); | ||||
|             Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0); | ||||
|             REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_max: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); | ||||
|             REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_max: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); | ||||
|             REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_max: segment indexes array length should be equal to the input first dimension, but %ld != %ild.", idxSegments->lengthOf(), input->sizeAt(0)); | ||||
| 
 | ||||
|             Nd4jLong wrong; | ||||
| 
 | ||||
|             REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_max: segment indices should be in range [0, %i), but %i > %i", | ||||
|             REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_max: segment indices should be in range [0, %ld), but %ld != %ld", | ||||
|                     numOfClasses, wrong, numOfClasses); | ||||
| 
 | ||||
|             helpers::unsortedSegmentMaxFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); | ||||
|  | ||||
| @ -30,11 +30,11 @@ namespace sd { | ||||
|             Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0); | ||||
| 
 | ||||
|             REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_mean: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); | ||||
|             REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_mean: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); | ||||
|             REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_mean: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0)); | ||||
| 
 | ||||
|             Nd4jLong wrong; | ||||
| 
 | ||||
|             REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_mean: segment indices should be in range [0, %i), but %i > %i", | ||||
|             REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_mean: segment indices should be in range [0, %ld), but %ld != %ld", | ||||
|                     numOfClasses, wrong, numOfClasses); | ||||
| 
 | ||||
|             helpers::unsortedSegmentMeanFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); | ||||
|  | ||||
| @ -29,11 +29,11 @@ namespace sd { | ||||
|             auto segmentedOutput = OUTPUT_NULLIFIED(0); | ||||
|             Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0); | ||||
|             REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_min: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); | ||||
|             REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_min: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); | ||||
|             REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_min: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0)); | ||||
| 
 | ||||
|             Nd4jLong wrong; | ||||
| 
 | ||||
|             REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_min: segment indices should be in range [0, %i), but %i > %i", | ||||
|             REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_min: segment indices should be in range [0, %ld), but %ld > %ld", | ||||
|                     numOfClasses, wrong, numOfClasses); | ||||
| 
 | ||||
|             helpers::unsortedSegmentMinFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); | ||||
|  | ||||
| @ -29,11 +29,11 @@ namespace sd { | ||||
|             auto segmentedOutput = OUTPUT_NULLIFIED(0); | ||||
|             Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0); | ||||
|             REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_prod: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); | ||||
|             REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_prod: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); | ||||
|             REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_prod: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0)); | ||||
| 
 | ||||
|             Nd4jLong wrong = 0; | ||||
| 
 | ||||
|             REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_prod: segment indices should be in range [0, %i), but %i > %i", | ||||
|             REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_prod: segment indices should be in range [0, %ld), but %ld != %ld", | ||||
|                     numOfClasses, wrong, numOfClasses); | ||||
| 
 | ||||
|             helpers::unsortedSegmentProdFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); | ||||
|  | ||||
| @ -29,11 +29,11 @@ namespace sd { | ||||
|             auto segmentedOutput = OUTPUT_NULLIFIED(0); | ||||
|             Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0); | ||||
|             REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sqrt_n: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); | ||||
|             REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sqrt_n: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); | ||||
|             REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sqrt_n: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0)); | ||||
| 
 | ||||
|             Nd4jLong wrong; | ||||
| 
 | ||||
|             REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sqrt_n: segment indices should be in range [0, %i), but %i > %i", | ||||
|             REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sqrt_n: segment indices should be in range [0, %ld), but %ld != %ld", | ||||
|                     numOfClasses, wrong, numOfClasses); | ||||
| 
 | ||||
|             helpers::unsortedSegmentSqrtNFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); | ||||
|  | ||||
| @ -29,11 +29,11 @@ namespace sd { | ||||
|             auto segmentedOutput = OUTPUT_NULLIFIED(0); | ||||
|             Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0); | ||||
|             REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sum: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); | ||||
|             REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sum: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); | ||||
|             REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sum: segment indexes array length should be equal to the input first dimension, but %ld != %ld", idxSegments->lengthOf(), input->sizeAt(0)); | ||||
| 
 | ||||
|             Nd4jLong wrong; | ||||
| 
 | ||||
|             REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sum: segment indices should be in range [0, %i), but %i > %i", | ||||
|             REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sum: segment indices should be in range [0, %ld), but %ld > %ld", | ||||
|                     numOfClasses, wrong, numOfClasses); | ||||
| 
 | ||||
|             helpers::unsortedSegmentSumFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); | ||||
|  | ||||
| @ -1868,7 +1868,26 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_5) { | ||||
|    // exp.printIndexedBuffer("Expect");
 | ||||
|     ASSERT_TRUE(exp.equalsTo(result.at(0))); | ||||
| 
 | ||||
|      | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_6) { | ||||
|     auto x = NDArrayFactory::create<double>({5,1,7,2,3,4,1,3}); | ||||
|     auto idx = NDArrayFactory::create<int>({0,0,0,1,2,2,3,3}); | ||||
|     //NDArray<double> exp({1.7320508075688772, 1.,      1.4142135623730951,        1.4142135623730951});
 | ||||
| //    auto exp = NDArrayFactory::create<double>({7.5055537, 2.,        4.9497476, 2.828427});
 | ||||
|     sd::ops::unsorted_segment_sqrt_n op; | ||||
| 
 | ||||
| try { | ||||
|     auto result = op.evaluate({&x, &idx}, {}, {1}); | ||||
|     ASSERT_NE(result.status(), Status::OK()); | ||||
| } | ||||
| catch (std::exception& err) { | ||||
| 
 | ||||
| } | ||||
|     // result.at(0)->printIndexedBuffer("Output");
 | ||||
|     // exp.printIndexedBuffer("Expect");
 | ||||
|     //ASSERT_TRUE(exp.equalsTo(result.at(0)));
 | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////////////
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user