| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | /*******************************************************************************
 | 
					
						
							|  |  |  |  * Copyright (c) 2015-2018 Skymind, Inc. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * This program and the accompanying materials are made available under the | 
					
						
							|  |  |  |  * terms of the Apache License, Version 2.0 which is available at | 
					
						
							|  |  |  |  * https://www.apache.org/licenses/LICENSE-2.0.
 | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  |  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | 
					
						
							|  |  |  |  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | 
					
						
							|  |  |  |  * License for the specific language governing permissions and limitations | 
					
						
							|  |  |  |  * under the License. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * SPDX-License-Identifier: Apache-2.0 | 
					
						
							|  |  |  |  ******************************************************************************/ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //  @author raver119@gmail.com
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #ifndef DEV_TESTS_CONSTANTSHAPEHELPER_H
 | 
					
						
							|  |  |  | #define DEV_TESTS_CONSTANTSHAPEHELPER_H
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  | #include <system/dll.h>
 | 
					
						
							|  |  |  | #include <system/pointercast.h>
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | #include <map>
 | 
					
						
							|  |  |  | #include <mutex>
 | 
					
						
							|  |  |  | #include <vector>
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  | #include <array/ShapeDescriptor.h>
 | 
					
						
							| 
									
										
										
										
											2020-06-06 15:26:55 +03:00
										 |  |  | #include <array/ConstantShapeBuffer.h>
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | #include <memory/Workspace.h>
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  | #include <system/op_boilerplate.h>
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  | namespace sd { | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     class ND4J_EXPORT ConstantShapeHelper { | 
					
						
							|  |  |  |     private: | 
					
						
							|  |  |  |         std::mutex _mutex; | 
					
						
							| 
									
										
										
										
											2020-06-06 15:26:55 +03:00
										 |  |  |         std::vector<MAP_IMPL<ShapeDescriptor, ConstantShapeBuffer>> _cache; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         ConstantShapeHelper(); | 
					
						
							|  |  |  |     public: | 
					
						
							|  |  |  |         ~ConstantShapeHelper() = default; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-06 15:26:55 +03:00
										 |  |  |         static ConstantShapeHelper & getInstance(); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-06 15:26:55 +03:00
										 |  |  |       ConstantShapeBuffer& bufferForShapeInfo(sd::DataType dataType, char order, const std::vector<Nd4jLong> &shape); | 
					
						
							|  |  |  |       ConstantShapeBuffer& bufferForShapeInfo(const ShapeDescriptor &descriptor); | 
					
						
							|  |  |  |       ConstantShapeBuffer& bufferForShapeInfo(const Nd4jLong *shapeInfo); | 
					
						
							|  |  |  |       ConstantShapeBuffer& bufferForShapeInfo(sd::DataType dataType, char order, int rank, const Nd4jLong* shape); | 
					
						
							|  |  |  |       ConstantShapeBuffer& createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace = nullptr, const std::vector<int> &dimensions = {}); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-09 08:06:14 +03:00
										 |  |  |         const Nd4jLong* emptyShapeInfo(sd::DataType dataType); | 
					
						
							|  |  |  |         const Nd4jLong* scalarShapeInfo(sd::DataType dataType); | 
					
						
							|  |  |  |         const Nd4jLong* vectorShapeInfo(Nd4jLong length, sd::DataType dataType); | 
					
						
							|  |  |  |         const Nd4jLong* createShapeInfo(const ShapeDescriptor &descriptor); | 
					
						
							|  |  |  |         const Nd4jLong* createShapeInfo(sd::DataType dataType, char order, const std::vector<Nd4jLong> &shape); | 
					
						
							|  |  |  |         const Nd4jLong* createShapeInfo(sd::DataType dataType, char order, int rank, const Nd4jLong* shape); | 
					
						
							|  |  |  |         const Nd4jLong* createShapeInfo(sd::DataType dataType, const Nd4jLong* shapeInfo); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-09 08:06:14 +03:00
										 |  |  |         const Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, sd::memory::Workspace *workspace); | 
					
						
							|  |  |  |         const Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal = true); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         bool checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor); | 
					
						
							| 
									
										
										
										
											2019-07-22 14:00:24 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         /**
 | 
					
						
							|  |  |  |          * This method returns number of cached TAD shapes/offsets on specific device | 
					
						
							|  |  |  |          * @return | 
					
						
							|  |  |  |          */ | 
					
						
							|  |  |  |         FORCEINLINE int cachedEntriesForDevice(int deviceId) { | 
					
						
							|  |  |  |             if (deviceId > _cache.size()) | 
					
						
							|  |  |  |                 throw std::runtime_error("deviceId > number of actual devices"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return _cache[deviceId].size(); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         /**
 | 
					
						
							|  |  |  |          * This method returns total number of cached TAD shapes/offsets on all devices | 
					
						
							|  |  |  |          * @return | 
					
						
							|  |  |  |          */ | 
					
						
							|  |  |  |         FORCEINLINE int totalCachedEntries() { | 
					
						
							|  |  |  |             int total = 0; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             for (int e = 0; e < _cache.size(); e++) | 
					
						
							|  |  |  |                 total += _cache[e].size(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return total; | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |     }; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #endif //DEV_TESTS_CONSTANTSHAPEHELPER_H
 |