2019-06-06 14:21:15 +02:00
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "../ConstantTadHelper.h"
2020-03-02 10:49:41 +01:00
#include <memory>
#include <mutex>
#include <vector>
#include <helpers/TAD.h>
#include <helpers/ShapeUtils.h>
2020-06-06 14:26:55 +02:00
#include <array/ConstantOffsetsBuffer.h>
#include <array/PrimaryPointerDeallocator.h>
2019-06-06 14:21:15 +02:00
# ifndef __CUDABLAS__
2020-02-24 05:51:01 +01:00
2020-03-02 10:49:41 +01:00
namespace sd {
2019-06-06 14:21:15 +02:00
/**
 * Creates the helper with exactly one cache slot holding an empty
 * descriptor -> TadPack map. The non-CUDA tadForDimensions() in this file
 * always addresses slot 0.
 */
ConstantTadHelper::ConstantTadHelper() {
    _cache.emplace_back(MAP_IMPL<TadDescriptor, TadPack>{});
}
2020-06-06 14:26:55 +02:00
/**
 * Returns the single ConstantTadHelper instance, constructing it on first use.
 *
 * @return reference to the shared helper
 */
ConstantTadHelper &ConstantTadHelper::getInstance() {
    // Function-local static: C++11 guarantees one-time initialization,
    // even if the first calls happen concurrently.
    static ConstantTadHelper singleton;
    return singleton;
}
2019-09-03 21:02:02 +02:00
/**
 * Convenience overload: TAD pack along a single axis.
 *
 * @param originalShape       shapeInfo of the array being split into TADs
 * @param dimension           the single axis of the TAD
 * @param keepUnitiesInShape  forwarded unchanged to the descriptor
 * @return the (cached) TadPack for these parameters
 */
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
    // treat the lone axis as a one-element axis list
    int *axes = &dimension;
    const int numAxes = 1;
    return tadForDimensions(originalShape, axes, numAxes, keepUnitiesInShape);
}
2019-09-03 21:02:02 +02:00
/**
 * Convenience overload: TAD pack for a list of axes held in a std::vector.
 *
 * @param originalShape       shapeInfo of the array being split into TADs
 * @param dimensions          axes of the TAD
 * @param keepUnitiesInShape  forwarded unchanged to the descriptor
 * @return the (cached) TadPack for these parameters
 */
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
    // the pointer-based overload is declared with a non-const int*, hence the const_cast
    int *axes = const_cast<int *>(dimensions.data());
    const int numAxes = static_cast<int>(dimensions.size());
    return tadForDimensions(originalShape, axes, numAxes, keepUnitiesInShape);
}
2019-09-03 21:02:02 +02:00
/**
 * Builds the cache key from a raw shapeInfo plus axis array and delegates to
 * the descriptor-based overload.
 *
 * @param originalShape       shapeInfo of the array being split into TADs
 * @param dimensions          pointer to dimLength axis indices
 * @param dimLength           number of entries in dimensions
 * @param keepUnitiesInShape  forwarded unchanged to the descriptor
 * @return the (cached) TadPack for these parameters
 */
TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int *dimensions, int dimLength, const bool keepUnitiesInShape) {
    TadDescriptor key(originalShape, dimensions, dimLength, keepUnitiesInShape);
    return tadForDimensions(key);
}
2019-09-03 21:02:02 +02:00
/**
 * Builds the cache key from a ShapeDescriptor plus axis list and delegates to
 * the descriptor-based overload.
 *
 * @param descriptor          shape of the array being split into TADs
 * @param dimensions          axes of the TAD
 * @param keepUnitiesInShape  forwarded unchanged to the descriptor
 * @return the (cached) TadPack for these parameters
 */
TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) {
    TadDescriptor key(descriptor, dimensions, keepUnitiesInShape);
    return tadForDimensions(key);
}
2019-09-03 21:02:02 +02:00
/**
 * Returns the TadPack (sub-array shapeInfo + one offset per sub-array) for the
 * given descriptor, building it and storing it in the cache on first request.
 * The whole lookup/build runs under _mutex.
 *
 * @param descriptor  cache key: original shape, TAD axes and keep-unities flag
 * @return a copy of the cached TadPack for this descriptor
 */
TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
    // this (non-CUDA) implementation only ever uses cache slot 0
    const int deviceId = 0;

    std::lock_guard<std::mutex> lock(_mutex);

    auto &cache = _cache[deviceId];

    // fast path: a single lookup when the pack has already been built
    const auto cached = cache.find(descriptor);
    if (cached != cache.end())
        return cached->second;

    // no TadPack matches this descriptor yet - create one.
    // The buffer returned by toShapeInfo() is released with delete[]; owning it
    // through unique_ptr frees it even if an allocation or calculation below throws.
    std::unique_ptr<Nd4jLong[]> shapeInfoOwner(descriptor.originalShape().toShapeInfo());
    const auto shapeInfo = shapeInfoOwner.get();

    const int rank = shape::rank(shapeInfo);
    const std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(rank, descriptor.axis());
    const int numOfDimsToExclude = static_cast<int>(dimsToExclude.size());
    const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude);
    // sub-arrays keep the full rank when every dimension is excluded or unit dimensions are kept
    const int subArrRank = (rank == numOfDimsToExclude || descriptor.areUnitiesinShape()) ? rank : rank - numOfDimsToExclude;

    // shape of the sub-arrays (same for all of them) and one offset per sub-array
    auto sPtr = std::make_shared<PointerWrapper>(new Nd4jLong[shape::shapeInfoLength(subArrRank)], std::make_shared<PrimaryPointerDeallocator>());
    auto oPtr = std::make_shared<PointerWrapper>(new Nd4jLong[numOfSubArrs], std::make_shared<PrimaryPointerDeallocator>());

    if (numOfSubArrs > 0)
        shape::calcSubArrsShapeInfoAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr->pointerAsT<Nd4jLong>(), oPtr->pointerAsT<Nd4jLong>(), descriptor.areUnitiesinShape());

    ConstantShapeBuffer shapeBuffer(sPtr);
    ConstantOffsetsBuffer offsetsBuffer(oPtr);
    TadPack t(shapeBuffer, offsetsBuffer, numOfSubArrs);

    // store the new pack and return a copy of the cached entry
    return cache.emplace(descriptor, t).first->second;
}
2019-06-06 14:21:15 +02:00
}
# endif