Added implementation for adjust_contrast_v2 op and tests.
This commit is contained in:
		
							parent
							
								
									e06dfb5dcc
								
							
						
					
					
						commit
						1575c704ae
					
				| @ -61,6 +61,43 @@ DECLARE_TYPES(adjust_contrast) { | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
|     CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, 1, 0) { | ||||
| 
 | ||||
|         auto input  = INPUT_VARIABLE(0); | ||||
|         auto output = OUTPUT_VARIABLE(0); | ||||
| 
 | ||||
|         const double factor = T_ARG(0); | ||||
| 
 | ||||
|         REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST: op expects rank of input array to be >= 3, but got %i instead", input->rankOf()); | ||||
|         REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1)); | ||||
| 
 | ||||
|         // compute mean before
 | ||||
|         reduce_mean meanOp; | ||||
|         auto axes = NDArrayFactory::create<int>('c', {input->rankOf() - 1}, block.launchContext()); | ||||
|         for (int i = 0; i < input->rankOf() -1; i++) | ||||
|             axes.p(i, i); | ||||
| 
 | ||||
|         auto meanRes = meanOp.execute({input, &axes}, {}, {}); | ||||
|         REQUIRE_TRUE(meanRes->status() == Status::OK(), 0, "ADJUST_CONTRAST: op should be successful, but error code %i occured.", meanRes->status()); | ||||
|         auto mean = meanRes->at(0); | ||||
| //        NDArray factorT(output->dataType(), block.launchContext()); // = NDArrayFactory::create(factor, block.launchContext());
 | ||||
| //        factorT.p(0, factor);
 | ||||
|         // this is contrast calculation
 | ||||
|         std::unique_ptr<NDArray> temp(input->dup()); | ||||
|         input->applyTrueBroadcast(BroadcastOpsTuple::Subtract(), mean, temp.get()); | ||||
|         temp->applyScalar(scalar::Multiply, factor); | ||||
|         temp->applyTrueBroadcast(BroadcastOpsTuple::Add(), mean, output); | ||||
| //        *output = (*input - *mean) * factorT + *mean;
 | ||||
| 
 | ||||
|         delete meanRes; | ||||
|         return Status::OK(); | ||||
|     } | ||||
| 
 | ||||
|     DECLARE_TYPES(adjust_contrast_v2) { | ||||
|         getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY) | ||||
|                 ->setAllowedOutputTypes({ALL_FLOATS}) | ||||
|                 ->setSameMode(true); | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| } | ||||
|  | ||||
| @ -611,6 +611,7 @@ namespace nd4j { | ||||
|          */ | ||||
|         #if NOT_EXCLUDED(OP_adjust_contrast) | ||||
|         DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, 1, 0); | ||||
|         DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, 1, 0); | ||||
|         #endif | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -193,6 +193,42 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) { | ||||
|     delete result; | ||||
| } | ||||
| 
 | ||||
| TEST_F(DeclarableOpsTests15, Test_AdjustContrast_3) { | ||||
|     auto x = NDArrayFactory::create<float>('c', {1, 4,4,3}); | ||||
|     auto e = NDArrayFactory::create<float>('c', {1, 4,4,3}, { | ||||
|             -21.5, -20.5, -19.5,  -15.5, -14.5, -13.5,  -9.5,  -8.5,  -7.5,  -3.5,  -2.5,  -1.5, | ||||
|             2.5,   3.5,   4.5,    8.5,   9.5,  10.5,  14.5,  15.5,  16.5,  20.5,  21.5,  22.5, | ||||
|             26.5,  27.5,  28.5,   32.5,  33.5,  34.5,  38.5,  39.5,  40.5,  44.5,  45.5,  46.5, | ||||
|             50.5,  51.5,  52.5,   56.5,  57.5,  58.5,  62.5,  63.5,  64.5,  68.5,  69.5,  70.5 | ||||
|     }); | ||||
|     x.linspace(1.); | ||||
|     nd4j::ops::adjust_contrast_v2 op; | ||||
|     auto result = op.execute({&x}, {2.}, {}, {}); | ||||
|     ASSERT_EQ(Status::OK(), result->status()); | ||||
|     auto out = result->at(0); | ||||
| //    out->printIndexedBuffer("Adjusted Constrast");
 | ||||
|     ASSERT_TRUE(e.equalsTo(out)); | ||||
|     delete result; | ||||
| } | ||||
| 
 | ||||
| TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) { | ||||
|     auto x = NDArrayFactory::create<double>('c', {4, 4, 3}); | ||||
|     auto e = NDArrayFactory::create<double>('c', {4, 4, 3}, { | ||||
|             -21.5, -20.5, -19.5,  -15.5, -14.5, -13.5,  -9.5,  -8.5,  -7.5,  -3.5,  -2.5,  -1.5, | ||||
|             2.5,   3.5,   4.5,    8.5,   9.5,  10.5,  14.5,  15.5,  16.5,  20.5,  21.5,  22.5, | ||||
|             26.5,  27.5,  28.5,   32.5,  33.5,  34.5,  38.5,  39.5,  40.5,  44.5,  45.5,  46.5, | ||||
|             50.5,  51.5,  52.5,   56.5,  57.5,  58.5,  62.5,  63.5,  64.5,  68.5,  69.5,  70.5 | ||||
|     }); | ||||
|     x.linspace(1.); | ||||
|     nd4j::ops::adjust_contrast_v2 op; | ||||
|     auto result = op.execute({&x}, {2.}, {}, {}); | ||||
|     ASSERT_EQ(Status::OK(), result->status()); | ||||
|     auto out = result->at(0); | ||||
| //    out->printIndexedBuffer("Adjusted Constrast");
 | ||||
|     ASSERT_TRUE(e.equalsTo(out)); | ||||
|     delete result; | ||||
| } | ||||
| 
 | ||||
| TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) { | ||||
|     auto in = NDArrayFactory::create<float>('c', {4, 8, 64, 64}); | ||||
|     auto w = NDArrayFactory::create<float>('c', {2, 2, 8, 2}); | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user