| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | /*******************************************************************************
 | 
					
						
							|  |  |  |  * Copyright (c) 2015-2018 Skymind, Inc. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * This program and the accompanying materials are made available under the | 
					
						
							|  |  |  |  * terms of the Apache License, Version 2.0 which is available at | 
					
						
							|  |  |  |  * https://www.apache.org/licenses/LICENSE-2.0.
 | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  |  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | 
					
						
							|  |  |  |  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | 
					
						
							|  |  |  |  * License for the specific language governing permissions and limitations | 
					
						
							|  |  |  |  * under the License. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * SPDX-License-Identifier: Apache-2.0 | 
					
						
							|  |  |  |  ******************************************************************************/ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //  @author raver119@gmail.com
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
#include "testlayers.h"
#include <graph/GraphState.h>
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/LegacyTransformOp.h>
#include <ops/declarable/LegacyReduceOp.h>
#include <legacy/NativeOps.h>

#include <memory>
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  | using namespace sd; | 
					
						
							|  |  |  | using namespace sd::graph; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | class GraphStateTests : public testing::Test { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     GraphStateTests() { | 
					
						
							| 
									
										
										
										
											2020-06-06 15:26:55 +03:00
										 |  |  |         Environment::getInstance().setDebug(false); | 
					
						
							|  |  |  |         Environment::getInstance().setVerbose(false); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |     }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ~GraphStateTests() { | 
					
						
							| 
									
										
										
										
											2020-06-06 15:26:55 +03:00
										 |  |  |         Environment::getInstance().setDebug(false); | 
					
						
							|  |  |  |         Environment::getInstance().setVerbose(false); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /*
 | 
					
						
							|  |  |  |  * PLAN: | 
					
						
							|  |  |  |  * Create GraphState | 
					
						
							|  |  |  |  * Register Scope | 
					
						
							|  |  |  |  * Add few Ops to it | 
					
						
							| 
									
										
										
										
											2019-08-02 20:01:03 +03:00
 * Call a conditional op that refers to those scopes
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |  * Check results | 
					
						
							|  |  |  |  */ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | TEST_F(GraphStateTests, Basic_Tests_1) { | 
					
						
							| 
									
										
										
										
											2019-07-22 20:34:08 +09:00
										 |  |  |     auto state = (GraphState *) getGraphState(117L); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |     ASSERT_EQ(117L, state->id()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // this call will create scope internally
 | 
					
						
							|  |  |  |     state->registerScope(119); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |     sd::ops::add opA; | 
					
						
							|  |  |  |     sd::ops::LegacyTransformSameOp opB(transform::Neg); // simdOps::Neg
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     ArgumentsList argsA; | 
					
						
							|  |  |  |     ArgumentsList argsB; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     state->attachOpToScope(119, 1, &opA, argsA); | 
					
						
							|  |  |  |     state->attachOpToScope(119, 2, &opB, argsB); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     auto scope = state->getScope(119); | 
					
						
							|  |  |  |     ASSERT_TRUE(scope != nullptr); | 
					
						
							| 
									
										
										
										
											2019-08-02 20:01:03 +03:00
										 |  |  |     ASSERT_EQ(2, scope->size()); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-07-22 20:34:08 +09:00
										 |  |  |     deleteGraphState(state); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // just separate case for doubles wrapper in NativeOps, nothing else
 | 
					
						
							|  |  |  | TEST_F(GraphStateTests, Basic_Tests_2) { | 
					
						
							| 
									
										
										
										
											2019-07-22 20:34:08 +09:00
										 |  |  |     auto state = (GraphState *) getGraphState(117L); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |     ASSERT_EQ(117L, state->id()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // this call will create scope internally
 | 
					
						
							|  |  |  |     state->registerScope(119); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |     sd::ops::add opA; | 
					
						
							|  |  |  |     sd::ops::LegacyTransformSameOp opB(transform::Neg); // simdOps::Neg
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     ArgumentsList argsA; | 
					
						
							|  |  |  |     ArgumentsList argsB; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     state->attachOpToScope(119, 1, &opA, argsA); | 
					
						
							|  |  |  |     state->attachOpToScope(119, 2, &opB, argsB); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     auto scope = state->getScope(119); | 
					
						
							|  |  |  |     ASSERT_TRUE(scope != nullptr); | 
					
						
							| 
									
										
										
										
											2019-08-02 20:01:03 +03:00
										 |  |  |     ASSERT_EQ(2, scope->size()); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-07-22 20:34:08 +09:00
										 |  |  |     deleteGraphState(state); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-08-02 20:01:03 +03:00
										 |  |  | /*
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | TEST_F(GraphStateTests, Stateful_Execution_1) { | 
					
						
							| 
									
										
										
										
											2019-07-22 20:34:08 +09:00
										 |  |  |     auto state = getGraphState(117L); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     Nd4jLong scopes[] = {22, 33}; | 
					
						
							| 
									
										
										
										
											2019-07-22 20:34:08 +09:00
										 |  |  |     //auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
 | 
					
						
							|  |  |  |     auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     ASSERT_EQ(Status::THROW(), status); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-07-22 20:34:08 +09:00
										 |  |  |     deleteGraphState(state); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | TEST_F(GraphStateTests, Stateful_Execution_2) { | 
					
						
							| 
									
										
										
										
											2019-07-22 20:34:08 +09:00
										 |  |  |     auto state = (GraphState *) getGraphState(117L); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     state->registerScope(22); | 
					
						
							|  |  |  |     state->registerScope(33); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Nd4jLong scopes[] = {22, 33}; | 
					
						
							| 
									
										
										
										
											2019-07-22 20:34:08 +09:00
										 |  |  |     auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |     // it's no-op: just LogicScope
 | 
					
						
							|  |  |  |     ASSERT_EQ(Status::OK(), status); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-07-22 20:34:08 +09:00
										 |  |  |     deleteGraphState(state); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-08-02 20:01:03 +03:00
										 |  |  | // This test checks WHILE loop
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | TEST_F(GraphStateTests, Stateful_Execution_3) { | 
					
						
							|  |  |  |     auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4}); | 
					
						
							|  |  |  |     auto var1 = NDArrayFactory::create<float>(11.0f); | 
					
						
							|  |  |  |     auto var2 = NDArrayFactory::create<float>(2.0f); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     auto res0 = NDArrayFactory::create<float>('c', {2, 2}); | 
					
						
							|  |  |  |     auto res1 = NDArrayFactory::create<float>(0.0f); | 
					
						
							|  |  |  |     auto res2 = NDArrayFactory::create<float>(0.0f); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // registering our GraphState holder
 | 
					
						
							| 
									
										
										
										
											2019-07-22 20:34:08 +09:00
										 |  |  |     auto state = (GraphState *) getGraphState(117L); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     // we're prepping pointers to input/output buffers
 | 
					
						
							|  |  |  |     Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer(), (Nd4jPointer)var2.buffer()}; | 
					
						
							|  |  |  |     Nd4jPointer ptrShapes[] = {(Nd4jPointer) var0.shapeInfo(), (Nd4jPointer) var1.shapeInfo(), (Nd4jPointer)var2.shapeInfo()}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Nd4jPointer outBuffers[] = {(Nd4jPointer) res0.buffer(), (Nd4jPointer) res1.buffer(), (Nd4jPointer) res2.buffer()}; | 
					
						
							|  |  |  |     Nd4jPointer outShapes[] = {(Nd4jPointer) res0.shapeInfo(), (Nd4jPointer) res1.shapeInfo(), (Nd4jPointer) res2.shapeInfo()}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // conditional scope
 | 
					
						
							|  |  |  |     state->registerScope(22); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |     sd::ops::LegacyReduceSameOp op1(reduce::Sum); | 
					
						
							|  |  |  |     sd::ops::lt_scalar op2; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     // while sum(var0) < var1
 | 
					
						
							|  |  |  |     // this op takes sum
 | 
					
						
							|  |  |  |     ArgumentsList args1({{0, 0}}); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // this op compares result of sum to input variable 0:1
 | 
					
						
							|  |  |  |     ArgumentsList args2({{1, 0}, {0, 1}}); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     state->attachOpToScope(22, 1, &op1, args1); | 
					
						
							|  |  |  |     state->attachOpToScope(22, 2, &op2, args2); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // body scope
 | 
					
						
							|  |  |  |     state->registerScope(33); | 
					
						
							|  |  |  | 
 | 
					
						
    // var0 + var2 + var2
 | 
					
						
    // this op is var0 + var2
 | 
					
						
							|  |  |  |     ArgumentsList args3({{0, 0}, {0, 2}}); | 
					
						
							|  |  |  | 
 | 
					
						
    // this op is result of previous op + var2
 | 
					
						
							|  |  |  |     ArgumentsList args4({{3, 0}, {0, 2}}); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |     sd::ops::add op3; | 
					
						
							|  |  |  |     sd::ops::add op4; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     state->attachOpToScope(33, 3, &op3, args3); | 
					
						
							|  |  |  |     state->attachOpToScope(33, 4, &op4, args4); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Now we define RETURN, which returns 1 modified variable, and 2 unmodified variables
 | 
					
						
							|  |  |  |     ArgumentsList args5({{4, 0}, {0, 1}, {0, 2}}); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // so, at the end of body, initial variables will be updated
 | 
					
						
							|  |  |  |     state->defineReturn(33, 5, args5); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Nd4jLong scopes[] = {22, 33}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // we're executing while loop
 | 
					
						
							| 
									
										
										
										
											2019-07-22 20:34:08 +09:00
										 |  |  |     auto status = execCustomOpWithScope(nullptr, state, 0, scopes, 2, ptrBuffers, ptrShapes, 3, outBuffers, outShapes, 3); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |     ASSERT_EQ(Status::OK(), status); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // now we check provided result array
 | 
					
						
							|  |  |  |     float sum = res0.reduceNumber(reduce::Sum).e<float>(0); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-08-02 20:01:03 +03:00
										 |  |  |     // Expected result is {1, 2, 3, 4} + {2} elementwise + {2} elementwise, which gives { 5, 6, 7, 8}, and sum should be 26
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |     ASSERT_NEAR(26.0f, sum, 1e-5); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // nd4j_printf("0 ------------------\n","");
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-07-22 20:34:08 +09:00
										 |  |  |     deleteGraphState(state); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     // nd4j_printf("1 ------------------\n","");
 | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-08-02 20:01:03 +03:00
										 |  |  | // This test checks CONDITIONAL execution for FALSE
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | TEST_F(GraphStateTests, Stateful_Execution_4) { | 
					
						
							|  |  |  |     auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4}); | 
					
						
							|  |  |  |     auto var1 = NDArrayFactory::create<float>(5.0f); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     auto res0 = NDArrayFactory::create<float>('c', {2, 2}); | 
					
						
							|  |  |  |     auto res1 = NDArrayFactory::create<float>(0.0f); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     auto exp = NDArrayFactory::create<float>('c', {2, 2}, {-4, -3, -2, -1}); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // registering our GraphState holder
 | 
					
						
							| 
									
										
										
										
											2019-07-22 20:34:08 +09:00
										 |  |  |     auto state = (GraphState *) getGraphState(117L); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     // we're prepping pointers to input/output buffers
 | 
					
						
							|  |  |  |     Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()}; | 
					
						
							|  |  |  |     Nd4jPointer ptrShapes[] = {(Nd4jPointer) var0.shapeInfo(), (Nd4jPointer) var1.shapeInfo()}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Nd4jPointer outBuffers[] = {(Nd4jPointer) res0.buffer(), (Nd4jPointer) res1.buffer()}; | 
					
						
							|  |  |  |     Nd4jPointer outShapes[] = {(Nd4jPointer) res0.shapeInfo(), (Nd4jPointer) res1.shapeInfo()}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // conditional scope
 | 
					
						
							|  |  |  |     state->registerScope(22); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |     sd::ops::LegacyReduceSameOp op1(reduce::Sum); | 
					
						
							|  |  |  |     sd::ops::lt_scalar op2; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     // if sum(var0) < var1
 | 
					
						
							|  |  |  |     // this op takes sum
 | 
					
						
							|  |  |  |     ArgumentsList args1({{0, 0}}); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // this op compares result of sum to input variable 0:1
 | 
					
						
							|  |  |  |     ArgumentsList args2({{1, 0}, {0, 1}}); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     state->attachOpToScope(22, 1, &op1, args1); | 
					
						
							|  |  |  |     state->attachOpToScope(22, 2, &op2, args2); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // false scope
 | 
					
						
							|  |  |  |     state->registerScope(33); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ArgumentsList args3({{0, 0}, {0, 1}}); | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |     sd::ops::subtract op3; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |     state->attachOpToScope(33, 3, &op3, args3); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // return for false scope
 | 
					
						
							|  |  |  |     ArgumentsList args10({{3, 0}, {0, 1}}); | 
					
						
							|  |  |  |     state->defineReturn(33, 10, args10); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // true scope
 | 
					
						
							|  |  |  |     state->registerScope(44); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ArgumentsList args4({{0, 0}, {0, 1}}); | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |     sd::ops::add op4; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |     state->attachOpToScope(44, 4, &op4, args4); | 
					
						
							|  |  |  | 
 | 
					
						
    // return for true scope
 | 
					
						
							|  |  |  |     ArgumentsList args20({{4, 0}, {0, 1}}); | 
					
						
							|  |  |  |     state->defineReturn(44, 20, args20); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Nd4jLong scopes[] = {22, 33, 44}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // we're executing conditional op
 | 
					
						
							| 
									
										
										
										
											2019-07-22 20:34:08 +09:00
										 |  |  |     auto status = execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |     ASSERT_EQ(Status::OK(), status); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ASSERT_TRUE(exp.isSameShape(&res0)); | 
					
						
							|  |  |  |     ASSERT_TRUE(exp.equalsTo(&res0)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-07-22 20:34:08 +09:00
										 |  |  |     deleteGraphState(state); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-08-02 20:01:03 +03:00
										 |  |  | // This test checks CONDITIONAL execution for TRUE
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | TEST_F(GraphStateTests, Stateful_Execution_5) { | 
					
						
							|  |  |  |     auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4}); | 
					
						
							|  |  |  |     auto var1 = NDArrayFactory::create<float>(5.0f); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     auto res0 = NDArrayFactory::create<float>('c', {2, 2}); | 
					
						
							|  |  |  |     auto res1 = NDArrayFactory::create<float>(0.0f); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     auto exp = NDArrayFactory::create<float>('c', {2, 2}, {6, 7, 8, 9}); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // registering our GraphState holder
 | 
					
						
							| 
									
										
										
										
											2019-07-22 20:34:08 +09:00
										 |  |  |     auto state = (GraphState *) getGraphState(117L); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     // we're prepping pointers to input/output buffers
 | 
					
						
							|  |  |  |     Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()}; | 
					
						
							|  |  |  |     Nd4jPointer ptrShapes[] = {(Nd4jPointer) var0.shapeInfo(), (Nd4jPointer) var1.shapeInfo()}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Nd4jPointer outBuffers[] = {(Nd4jPointer) res0.buffer(), (Nd4jPointer) res1.buffer()}; | 
					
						
							|  |  |  |     Nd4jPointer outShapes[] = {(Nd4jPointer) res0.shapeInfo(), (Nd4jPointer) res1.shapeInfo()}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // conditional scope
 | 
					
						
							|  |  |  |     state->registerScope(22); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |     sd::ops::LegacyReduceSameOp op1(reduce::Sum); | 
					
						
							|  |  |  |     sd::ops::gt_scalar op2; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
    // if sum(var0) > var1
 | 
					
						
							|  |  |  |     // this op takes sum
 | 
					
						
							|  |  |  |     ArgumentsList args1({{0, 0}}); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // this op compares result of sum to input variable 0:1
 | 
					
						
							|  |  |  |     ArgumentsList args2({{1, 0}, {0, 1}}); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     state->attachOpToScope(22, 1, &op1, args1); | 
					
						
							|  |  |  |     state->attachOpToScope(22, 2, &op2, args2); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // false scope
 | 
					
						
							|  |  |  |     state->registerScope(33); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ArgumentsList args3({{0, 0}, {0, 1}}); | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |     sd::ops::subtract op3; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |     state->attachOpToScope(33, 3, &op3, args3); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // return for false scope
 | 
					
						
							|  |  |  |     ArgumentsList args10({{3, 0}, {0, 1}}); | 
					
						
							|  |  |  |     state->defineReturn(33, 10, args10); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // true scope
 | 
					
						
							|  |  |  |     state->registerScope(44); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ArgumentsList args4({{0, 0}, {0, 1}}); | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |     sd::ops::add op4; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |     state->attachOpToScope(44, 4, &op4, args4); | 
					
						
							|  |  |  | 
 | 
					
						
    // return for true scope
 | 
					
						
							|  |  |  |     ArgumentsList args20({{4, 0}, {0, 1}}); | 
					
						
							|  |  |  |     state->defineReturn(44, 20, args20); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Nd4jLong scopes[] = {22, 33, 44}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // we're executing conditional op
 | 
					
						
							| 
									
										
										
										
											2019-07-22 20:34:08 +09:00
										 |  |  |     auto status = execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |     ASSERT_EQ(Status::OK(), status); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ASSERT_TRUE(exp.isSameShape(&res0)); | 
					
						
							|  |  |  |     ASSERT_TRUE(exp.equalsTo(&res0)); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-07-22 20:34:08 +09:00
										 |  |  |     deleteGraphState(state); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | } | 
					
						
							| 
									
										
										
										
											2019-08-02 20:01:03 +03:00
										 |  |  | */ |