| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | /*******************************************************************************
 | 
					
						
							|  |  |  |  * Copyright (c) 2015-2018 Skymind, Inc. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * This program and the accompanying materials are made available under the | 
					
						
							|  |  |  |  * terms of the Apache License, Version 2.0 which is available at | 
					
						
							|  |  |  |  * https://www.apache.org/licenses/LICENSE-2.0.
 | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  |  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | 
					
						
							|  |  |  |  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | 
					
						
							|  |  |  |  * License for the specific language governing permissions and limitations | 
					
						
							|  |  |  |  * under the License. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * SPDX-License-Identifier: Apache-2.0 | 
					
						
							|  |  |  |  ******************************************************************************/ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //  @author raver119@gmail.com
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #ifndef LIBND4J_BROADCAST_HELPER_H
 | 
					
						
							|  |  |  | #define LIBND4J_BROADCAST_HELPER_H
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include <NDArray.h>
 | 
					
						
							|  |  |  | #include <helpers/ShapeUtils.h>
 | 
					
						
							|  |  |  | #include <ops/BroadcastOpsTuple.h>
 | 
					
						
							|  |  |  | #include <ops/BroadcastBoolOpsTuple.h>
 | 
					
						
							|  |  |  | #include <NDArrayFactory.h>
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace nd4j { | 
					
						
							|  |  |  |     namespace ops { | 
					
						
							|  |  |  |         class BroadcastHelper { | 
					
						
							|  |  |  |         public:  | 
					
						
							|  |  |  |             static FORCEINLINE NDArray* broadcastApply(nd4j::BroadcastOpsTuple op, NDArray* x, NDArray* y, NDArray* z, ExtraArguments *extraArgs = nullptr) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if(x->isEmpty() || y->isEmpty()) { | 
					
						
							|  |  |  |                     if(!z->isEmpty()) | 
					
						
							|  |  |  |                         throw std::runtime_error("BroadcastHelper::broadcastApply: when some of input arrays (or both) is empty, output array must be empty as well !"); | 
					
						
							|  |  |  |                     return z; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 std::unique_ptr<NDArray> ptr; | 
					
						
							|  |  |  |                 if (!Environment::getInstance()->isExperimentalBuild()) { | 
					
						
							|  |  |  |                     if (y->dataType() != x->dataType()) { | 
					
						
							|  |  |  |                         y = y->cast(x->dataType()); | 
					
						
							|  |  |  |                         std::unique_ptr<NDArray> ptr2(y); | 
					
						
							|  |  |  |                         ptr.swap(ptr2); | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if (!x->isScalar() && !y->isScalar() && x->isSameShape(y)) { | 
					
						
							|  |  |  | 				    x->applyPairwiseTransform(op.p, y, z, nullptr); | 
					
						
							|  |  |  |                 } else if (!x->isScalar() && y->isScalar()) { | 
					
						
							|  |  |  |                     x->applyScalarArr(op.s, const_cast<const NDArray*>(y), z); | 
					
						
							|  |  |  |                 } else if (x->isScalar() && !y->isScalar()) { | 
					
						
							|  |  |  |                     if (z->isSameShape(y)) { | 
					
						
							|  |  |  |                         if (op.s == scalar::Add || op.s == scalar::Multiply ) { | 
					
						
							|  |  |  |                             y->applyScalarArr(op.s, x, z, nullptr); | 
					
						
							|  |  |  |                         } else if (op.s == scalar::SquaredSubtract) { | 
					
						
							|  |  |  |                             y->applyScalarArr(scalar::SquaredReverseSubtract, x, z, nullptr); | 
					
						
							|  |  |  |                         } else if (op.s == scalar::Subtract) { | 
					
						
							|  |  |  |                             y->applyScalarArr(scalar::ReverseSubtract, x, z, nullptr); | 
					
						
							|  |  |  |                         } else if (op.s == scalar::Divide) { | 
					
						
							|  |  |  |                             y->applyScalarArr(scalar::ReverseDivide, x, z, nullptr); | 
					
						
							|  |  |  |                         } else if (op.s == scalar::Pow) { | 
					
						
							|  |  |  |                             y->applyScalarArr(scalar::ReversePow, x, z, nullptr); | 
					
						
							|  |  |  |                         } else if (op.s == scalar::ReverseSubtract) { | 
					
						
							|  |  |  |                             y->applyScalarArr(scalar::Subtract, x, z, nullptr); | 
					
						
							|  |  |  |                         } else if (op.s == scalar::ReverseDivide) { | 
					
						
							|  |  |  |                             y->applyScalarArr(scalar::Divide, x, z, nullptr); | 
					
						
							|  |  |  |                         } else if (op.s == scalar::MaxPairwise || op.s == scalar::MinPairwise || op.s == scalar::AMaxPairwise || op.s == scalar::AMinPairwise) { | 
					
						
							|  |  |  |                             y->applyScalarArr(op.s, x, z, nullptr); | 
					
						
							|  |  |  |                         } else if (op.s == scalar::CopyPws) { | 
					
						
							|  |  |  |                             z->assign(y); | 
					
						
							|  |  |  |                         } else { | 
					
						
							|  |  |  |                             z->assign(x); | 
					
						
							|  |  |  |                             z->applyPairwiseTransform(op.p, *y, extraArgs); | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  |                         return z; | 
					
						
							|  |  |  |                     } else { | 
					
						
							|  |  |  |                         auto v = y->getShapeAsVector(); | 
					
						
							|  |  |  |                         auto tZ = NDArrayFactory::valueOf(v, y, y->ordering()); | 
					
						
							|  |  |  |                         tZ->applyPairwiseTransform(op.p, *y, extraArgs); | 
					
						
							|  |  |  |                         return tZ; | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 } else if (x->isScalar() && y->isScalar()) { // x->isScalar() && y->isScalar()
 | 
					
						
							|  |  |  | 				    x->applyScalarArr(op.s, const_cast<const NDArray*>(y), z, nullptr); | 
					
						
							|  |  |  | 			    } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { | 
					
						
							|  |  |  |                     x->applyTrueBroadcast(op, y, z, true, extraArgs); | 
					
						
							|  |  |  |                     return z; | 
					
						
							|  |  |  |                 } else { | 
					
						
							|  |  |  |                     auto sx = ShapeUtils::shapeAsString(x); | 
					
						
							|  |  |  |                     auto sy = ShapeUtils::shapeAsString(y); | 
					
						
							|  |  |  |                     nd4j_printf("Broadcast: shapes should be equal, or broadcastable. But got %s vs %s instead\n", sx.c_str(), sy.c_str()); | 
					
						
							|  |  |  |                     return nullptr; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 return z; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             static FORCEINLINE NDArray* broadcastApply(nd4j::BroadcastBoolOpsTuple op, NDArray* x, NDArray* y, NDArray* z, ExtraArguments *extraArgs = nullptr) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if(x->isEmpty() || y->isEmpty()) { | 
					
						
							|  |  |  |                     if(!z->isEmpty()) | 
					
						
							|  |  |  |                         throw std::runtime_error("BroadcastHelper::broadcastApply: when some of input arrays (or both) is empty, output array must be empty as well !"); | 
					
						
							|  |  |  |                     return z; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if (!x->isScalar() && !y->isScalar() && x->isSameShape(y)) { | 
					
						
							|  |  |  |                     x->applyPairwiseTransform(op.p, y, z, nullptr); | 
					
						
							| 
									
										
										
										
											2019-12-02 21:37:21 +03:00
										 |  |  |                 } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { | 
					
						
							|  |  |  |                     x->applyTrueBroadcast(op, y, z, true, extraArgs); | 
					
						
							|  |  |  |                     return z; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                 } else if (!x->isScalar() && y->isScalar()) { | 
					
						
							|  |  |  |                     x->applyScalarArr(op.s, const_cast<const NDArray*>(y), z); | 
					
						
							|  |  |  |                 } else if (x->isScalar() && !y->isScalar()) { | 
					
						
							|  |  |  |                     if (z->isSameShape(y)) { | 
					
						
							|  |  |  |                         //z->assign(x);
 | 
					
						
							|  |  |  |                         x->applyPairwiseTransform(op.p, y, z, extraArgs); | 
					
						
							|  |  |  |                         return z; | 
					
						
							|  |  |  |                     } else { | 
					
						
							|  |  |  |                         auto v = y->getShapeAsVector(); | 
					
						
							|  |  |  |                         auto tZ = NDArrayFactory::valueOf(v, y, y->ordering()); | 
					
						
							|  |  |  |                         //tZ->applyPairwiseTransform(op.p, *y, extraArgs);
 | 
					
						
							|  |  |  |                         return tZ; | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 } else if (x->isScalar() && y->isScalar()) { // x->isScalar() && y->isScalar()
 | 
					
						
							|  |  |  |                     x->applyScalarArr(op.s, const_cast<const NDArray*>(y), z, nullptr); | 
					
						
							|  |  |  |                 } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { | 
					
						
							|  |  |  |                     x->applyTrueBroadcast(op, y, z, true, extraArgs); | 
					
						
							|  |  |  |                     return z; | 
					
						
							|  |  |  |                 } else { | 
					
						
							|  |  |  |                     auto sx = ShapeUtils::shapeAsString(x); | 
					
						
							|  |  |  |                     auto sy = ShapeUtils::shapeAsString(y); | 
					
						
							|  |  |  |                     nd4j_printf("Broadcast: shapes should be equal, or broadcastable. But got %s vs %s instead\n", sx.c_str(), sy.c_str()); | 
					
						
							|  |  |  |                     return nullptr; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 return z; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         }; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #endif
 |