/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
					
						
							|  |  |  | 
 | 
					
						
//
// Created by raver119 on 16.10.2017.
//
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include <ops/declarable/LegacyScalarOp.h>
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  | #include <array/NDArrayFactory.h>
 | 
					
						
							|  |  |  | #include <graph/Status.h>
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  | namespace sd { | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |     namespace ops { | 
					
						
							|  |  |  |         LegacyScalarOp::LegacyScalarOp() : LegacyOp::LegacyOp(1) { | 
					
						
							| 
									
										
										
										
											2020-02-13 20:59:35 +03:00
										 |  |  |             this->getOpDescriptor()->allowInplace(true); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         LegacyScalarOp::LegacyScalarOp(int opNum)  : LegacyOp::LegacyOp(1, opNum){ | 
					
						
							| 
									
										
										
										
											2020-02-13 20:59:35 +03:00
										 |  |  |             this->getOpDescriptor()->allowInplace(true); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         LegacyOp* LegacyScalarOp::clone() { | 
					
						
							|  |  |  |             return new LegacyScalarOp(this->_opNum, *this->_scalar); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         LegacyScalarOp::LegacyScalarOp(int opNum, NDArray &scalar)  : LegacyOp::LegacyOp(1, opNum){ | 
					
						
							| 
									
										
										
										
											2019-12-20 21:35:39 +02:00
										 |  |  |             _scalar = new NDArray(scalar.dup(scalar.ordering())); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |         ShapeList *LegacyScalarOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             auto inShape = inputShape->at(0); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             Nd4jLong *newShape; | 
					
						
							|  |  |  |             COPY_SHAPE(inShape, newShape); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return SHAPELIST(CONSTANT(newShape)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Nd4jStatus LegacyScalarOp::validateAndExecute(Context &block) { | 
					
						
							|  |  |  |             auto x = INPUT_VARIABLE(0); | 
					
						
							|  |  |  |             auto z = OUTPUT_VARIABLE(0); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             ExtraArguments extras(*block.getTArguments()); | 
					
						
							|  |  |  |             PointersManager manager(block.launchContext(), "LegacyScalarOp"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if (block.width() > 1) { | 
					
						
							|  |  |  |                 auto y = INPUT_VARIABLE(1); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 NDArray::prepareSpecialUse({z}, {x, y}); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-09 08:06:14 +03:00
										 |  |  |                 NativeOpExecutioner::execScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), extras.argumentsAsT(z->dataType())); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-02-13 20:59:35 +03:00
										 |  |  |                 NDArray::registerSpecialUse({z}, {x, y}); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             } else if (block.getTArguments()->size() > 0) { | 
					
						
							|  |  |  |                 auto y = NDArrayFactory::create(x->dataType(), T_ARG(0), block.launchContext()); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                 x->applyScalarArr(static_cast<sd::scalar::Ops>(opNum), y, *z); | 
					
						
							| 
									
										
										
										
											2020-01-20 21:32:46 +03:00
										 |  |  |                 // NDArray::prepareSpecialUse({z}, {x, &y});
 | 
					
						
							| 
									
										
										
										
											2020-06-06 15:26:55 +03:00
										 |  |  |                 // NativeOpExecutioner::execScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.special(), extras.argumentsAsT(z->dataType(), 1));
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 manager.synchronize(); | 
					
						
							|  |  |  |             } else { | 
					
						
							|  |  |  |                 NDArray::prepareSpecialUse({z}, {x, _scalar}); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-09 08:06:14 +03:00
										 |  |  |                 NativeOpExecutioner::execScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), _scalar->buffer(), _scalar->shapeInfo(), _scalar->specialBuffer(), _scalar->specialShapeInfo(), extras.argumentsAsT(z->dataType())); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-02-13 20:59:35 +03:00
										 |  |  |                 NDArray::registerSpecialUse({z}, {x, _scalar}); | 
					
						
							|  |  |  |             } | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             return Status::OK(); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } |