| 
									
										
										
										
											2021-02-01 21:31:45 +09:00
										 |  |  | /* ******************************************************************************
 | 
					
						
							|  |  |  |  * | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |  * | 
					
						
							|  |  |  |  * This program and the accompanying materials are made available under the | 
					
						
							|  |  |  |  * terms of the Apache License, Version 2.0 which is available at | 
					
						
							|  |  |  |  * https://www.apache.org/licenses/LICENSE-2.0.
 | 
					
						
							|  |  |  |  * | 
					
						
							| 
									
										
										
										
											2021-02-01 21:31:45 +09:00
										 |  |  |  *  See the NOTICE file distributed with this work for additional | 
					
						
							|  |  |  |  *  information regarding copyright ownership. | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |  * Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  |  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | 
					
						
							|  |  |  |  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | 
					
						
							|  |  |  |  * License for the specific language governing permissions and limitations | 
					
						
							|  |  |  |  * under the License. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * SPDX-License-Identifier: Apache-2.0 | 
					
						
							|  |  |  |  ******************************************************************************/ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //
 | 
					
						
							| 
									
										
										
										
											2019-09-02 16:25:58 +03:00
										 |  |  | // @author Yurii Shyrma (iuriish@yahoo.com)
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  | #include <array/ResultSet.h>
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | #include <ops/declarable/helpers/matrixSetDiag.h>
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | #include <execution/Threads.h>
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  | namespace sd { | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | namespace ops { | 
					
						
							|  |  |  | namespace helpers { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //////////////////////////////////////////////////////////////////////////
 | 
					
						
							| 
									
										
										
										
											2019-09-02 16:25:58 +03:00
										 |  |  | template<typename T> | 
					
						
							|  |  |  | void matrixSetDiag_(const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad) { | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-09-02 16:25:58 +03:00
										 |  |  |     // input and output are the same array (x == z) when zeroPad = true
 | 
					
						
							|  |  |  |     // xRank = zRank, xRank = yRank + 1
 | 
					
						
							|  |  |  |     // xLen = zLen
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-09-02 16:25:58 +03:00
										 |  |  |     const T* x = input.bufferAsT<T>(); | 
					
						
							|  |  |  |     const T* y = diagonal.bufferAsT<T>(); | 
					
						
							|  |  |  |           T* z = output.bufferAsT<T>(); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-09 08:06:14 +03:00
										 |  |  |     const Nd4jLong* xShapeInfo = input.shapeInfo(); | 
					
						
							|  |  |  |     const Nd4jLong* yShapeInfo = diagonal.shapeInfo(); | 
					
						
							|  |  |  |     const Nd4jLong* zShapeInfo = output.shapeInfo(); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-09-02 16:25:58 +03:00
										 |  |  |     const bool areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);    // shapes are definitely the same, but strides might not
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     const int xRank = input.rankOf(); | 
					
						
							|  |  |  |     const auto xLen = input.lengthOf(); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |     auto func = PRAGMA_THREADS_FOR { | 
					
						
							| 
									
										
										
										
											2020-03-11 15:21:59 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |         int coords[MAX_RANK]; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  |         for (Nd4jLong i = 0; i < xLen; ++i) { | 
					
						
							| 
									
										
										
										
											2020-03-11 15:21:59 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |             shape::index2coordsCPU(start, i, xShapeInfo, coords); | 
					
						
							| 
									
										
										
										
											2019-11-13 17:15:18 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             const auto xOffset = shape::getOffset(xShapeInfo, coords); | 
					
						
							|  |  |  |             const auto zOffset = areSameOffsets ? xOffset : shape::getOffset(zShapeInfo, coords); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             // condition to be on diagonal of innermost matrix
 | 
					
						
							|  |  |  |             if (coords[xRank - 2] == coords[xRank - 1]) | 
					
						
							|  |  |  |                 z[zOffset] = y[shape::getOffset(yShapeInfo, coords)]; | 
					
						
							|  |  |  |             else | 
					
						
							|  |  |  |                 z[zOffset] = zeroPad ? static_cast<T>(0) : x[xOffset]; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     }; | 
					
						
							| 
									
										
										
										
											2020-03-09 08:22:49 +03:00
										 |  |  |     samediff::Threads::parallel_for(func, 0, xLen); | 
					
						
							| 
									
										
										
										
											2019-09-02 16:25:58 +03:00
										 |  |  | } | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-09-02 16:25:58 +03:00
										 |  |  | //////////////////////////////////////////////////////////////////////////
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  | void matrixSetDiag(sd::LaunchContext* context, const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad) { | 
					
						
							| 
									
										
										
										
											2019-09-02 16:25:58 +03:00
										 |  |  |     BUILD_SINGLE_SELECTOR(input.dataType(), matrixSetDiag_, (input, diagonal, output, zeroPad), LIBND4J_TYPES); | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | } |