2021-02-09 05:16:31 +01:00
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */
2019-06-15 13:34:34 +02:00
template < typename T >
2019-12-20 20:35:39 +01:00
void NDArray : : applyTriplewiseLambda ( NDArray & second , NDArray & third , const std : : function < T ( T , T , T ) > & func , NDArray & target ) {
2019-06-15 13:34:34 +02:00
if ( dataType ( ) ! = DataTypeUtils : : fromT < T > ( ) )
throw std : : runtime_error ( " NDArray::applyTriplewiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array! " ) ;
2019-12-20 20:35:39 +01:00
if ( dataType ( ) ! = second . dataType ( ) | | dataType ( ) ! = third . dataType ( ) | | dataType ( ) ! = target . dataType ( ) )
2019-06-15 13:34:34 +02:00
throw std : : runtime_error ( " NDArray::applyTriplewiseLambda<T> method: bother four arrays (this, second, third, target) should have the same type ! " ) ;
2019-12-20 20:35:39 +01:00
if ( this - > lengthOf ( ) ! = second . lengthOf ( ) | | this - > lengthOf ( ) ! = third . lengthOf ( ) | | ! this - > isSameShape ( second ) | | ! this - > isSameShape ( third ) ) {
2020-04-13 12:21:51 +02:00
nd4j_printf ( " applyTriplewiseLambda requires all operands to have the same shape \n " , " " ) ;
2019-06-15 13:34:34 +02:00
throw std : : runtime_error ( " Shapes mismach " ) ;
}
auto f = this - > bufferAsT < T > ( ) ;
2019-12-20 20:35:39 +01:00
auto s = second . bufferAsT < T > ( ) ;
auto t = third . bufferAsT < T > ( ) ;
auto z = target . bufferAsT < T > ( ) ;
2019-06-15 13:34:34 +02:00
2019-12-20 20:35:39 +01:00
if ( this - > ordering ( ) = = second . ordering ( ) & & this - > ordering ( ) = = third . ordering ( ) & & this - > ordering ( ) = = target . ordering ( ) & & ( this - > ews ( ) = = 1 & & target . ews ( ) = = 1 ) & & this - > ews ( ) = = second . ews ( ) & & this - > ews ( ) = = third . ews ( ) ) {
2019-06-15 13:34:34 +02:00
2019-11-13 15:15:18 +01:00
auto loop = PRAGMA_THREADS_FOR {
2020-02-20 09:43:26 +01:00
for ( auto e = start ; e < stop ; e + + )
2019-11-13 15:15:18 +01:00
z [ e ] = func ( f [ e ] , s [ e ] , t [ e ] ) ;
} ;
2020-03-09 06:22:49 +01:00
samediff : : Threads : : parallel_for ( loop , 0 , _length ) ;
2019-06-15 13:34:34 +02:00
} else {
if ( f = = z ) {
2019-11-13 15:15:18 +01:00
auto loop = PRAGMA_THREADS_FOR {
2020-02-20 09:43:26 +01:00
for ( auto e = start ; e < stop ; e + + ) {
2019-11-13 15:15:18 +01:00
auto tOffset = this - > getOffset ( e ) ;
2019-12-20 20:35:39 +01:00
auto uOffset = second . getOffset ( e ) ;
auto vOffset = third . getOffset ( e ) ;
2019-06-15 13:34:34 +02:00
2019-11-13 15:15:18 +01:00
f [ tOffset ] = func ( f [ tOffset ] , s [ uOffset ] , t [ vOffset ] ) ;
}
} ;
2019-06-15 13:34:34 +02:00
2020-03-09 06:22:49 +01:00
samediff : : Threads : : parallel_for ( loop , 0 , _length ) ;
2019-06-15 13:34:34 +02:00
} else {
2019-11-13 15:15:18 +01:00
auto loop = PRAGMA_THREADS_FOR {
2020-02-20 09:43:26 +01:00
for ( auto e = start ; e < stop ; e + + ) {
2019-11-13 15:15:18 +01:00
auto tOffset = this - > getOffset ( e ) ;
2019-12-20 20:35:39 +01:00
auto uOffset = second . getOffset ( e ) ;
auto vOffset = third . getOffset ( e ) ;
auto zOffset = target . getOffset ( e ) ;
2019-06-15 13:34:34 +02:00
2019-11-13 15:15:18 +01:00
z [ zOffset ] = func ( f [ tOffset ] , s [ uOffset ] , t [ vOffset ] ) ;
}
} ;
2019-06-15 13:34:34 +02:00
2020-03-09 06:22:49 +01:00
samediff : : Threads : : parallel_for ( loop , 0 , _length ) ;
2019-06-15 13:34:34 +02:00
}
}
}
2019-12-20 20:35:39 +01:00
template void NDArray : : applyTriplewiseLambda ( NDArray & second , NDArray & third , const std : : function < double ( double , double , double ) > & func , NDArray & target ) ;
template void NDArray : : applyTriplewiseLambda ( NDArray & second , NDArray & third , const std : : function < float ( float , float , float ) > & func , NDArray & target ) ;
template void NDArray : : applyTriplewiseLambda ( NDArray & second , NDArray & third , const std : : function < float16 ( float16 , float16 , float16 ) > & func , NDArray & target ) ;
template void NDArray : : applyTriplewiseLambda ( NDArray & second , NDArray & third , const std : : function < bfloat16 ( bfloat16 , bfloat16 , bfloat16 ) > & func , NDArray & target ) ;
template void NDArray : : applyTriplewiseLambda ( NDArray & second , NDArray & third , const std : : function < Nd4jLong ( Nd4jLong , Nd4jLong , Nd4jLong ) > & func , NDArray & target ) ;
template void NDArray : : applyTriplewiseLambda ( NDArray & second , NDArray & third , const std : : function < int ( int , int , int ) > & func , NDArray & target ) ;
template void NDArray : : applyTriplewiseLambda ( NDArray & second , NDArray & third , const std : : function < int16_t ( int16_t , int16_t , int16_t ) > & func , NDArray & target ) ;
template void NDArray : : applyTriplewiseLambda ( NDArray & second , NDArray & third , const std : : function < uint8_t ( uint8_t , uint8_t , uint8_t ) > & func , NDArray & target ) ;
template void NDArray : : applyTriplewiseLambda ( NDArray & second , NDArray & third , const std : : function < uint16_t ( uint16_t , uint16_t , uint16_t ) > & func , NDArray & target ) ;
template void NDArray : : applyTriplewiseLambda ( NDArray & second , NDArray & third , const std : : function < uint32_t ( uint32_t , uint32_t , uint32_t ) > & func , NDArray & target ) ;
template void NDArray : : applyTriplewiseLambda ( NDArray & second , NDArray & third , const std : : function < uint64_t ( uint64_t , uint64_t , uint64_t ) > & func , NDArray & target ) ;
template void NDArray : : applyTriplewiseLambda ( NDArray & second , NDArray & third , const std : : function < int8_t ( int8_t , int8_t , int8_t ) > & func , NDArray & target ) ;
template void NDArray : : applyTriplewiseLambda ( NDArray & second , NDArray & third , const std : : function < bool ( bool , bool , bool ) > & func , NDArray & target ) ;
2019-06-15 13:34:34 +02:00
//////////////////////////////////////////////////////////////////////////
template < typename T >
2019-12-20 20:35:39 +01:00
void NDArray : : applyPairwiseLambda ( const NDArray & other , const std : : function < T ( T , T ) > & func , NDArray & target ) {
2019-06-15 13:34:34 +02:00
if ( dataType ( ) ! = DataTypeUtils : : fromT < T > ( ) )
throw std : : runtime_error ( " NDArray::applyPairwiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array! " ) ;
2019-12-20 20:35:39 +01:00
if ( dataType ( ) ! = other . dataType ( ) | | dataType ( ) ! = target . dataType ( ) )
2019-06-15 13:34:34 +02:00
throw std : : runtime_error ( " NDArray::applyPairwiseLambda<T> method: all three arrays (this, other, target) must have the same type ! " ) ;
2019-12-20 20:35:39 +01:00
if ( this - > lengthOf ( ) ! = other . lengthOf ( ) ) {
2019-06-15 13:34:34 +02:00
nd4j_printf ( " applyPairwiseLambda requires both operands to have the same shape \n " , " " ) ;
throw std : : runtime_error ( " Shapes mismach " ) ;
}
auto f = this - > bufferAsT < T > ( ) ;
2019-12-20 20:35:39 +01:00
auto s = other . bufferAsT < T > ( ) ;
auto z = target . bufferAsT < T > ( ) ;
2019-06-15 13:34:34 +02:00
2019-12-20 20:35:39 +01:00
if ( this - > ordering ( ) = = other . ordering ( ) & & this - > ordering ( ) = = target . ordering ( ) & & ( this - > ews ( ) = = 1 & & target . ews ( ) = = 1 ) & & this - > ews ( ) = = other . ews ( ) ) {
2019-06-15 13:34:34 +02:00
2019-11-13 15:15:18 +01:00
auto loop = PRAGMA_THREADS_FOR {
2020-02-20 09:43:26 +01:00
for ( auto e = start ; e < stop ; e + + )
2019-11-13 15:15:18 +01:00
z [ e ] = func ( f [ e ] , s [ e ] ) ;
} ;
2020-03-09 06:22:49 +01:00
samediff : : Threads : : parallel_for ( loop , 0 , _length ) ;
2019-06-15 13:34:34 +02:00
} else {
if ( f = = z ) {
2019-11-13 15:15:18 +01:00
auto loop = PRAGMA_THREADS_FOR {
2020-02-20 09:43:26 +01:00
for ( auto e = start ; e < stop ; e + + ) {
2019-11-13 15:15:18 +01:00
auto xOffset = this - > getOffset ( e ) ;
2019-12-20 20:35:39 +01:00
auto yOffset = other . getOffset ( e ) ;
2019-06-15 13:34:34 +02:00
2019-11-13 15:15:18 +01:00
f [ xOffset ] = func ( f [ xOffset ] , s [ yOffset ] ) ;
}
} ;
2019-06-15 13:34:34 +02:00
2020-03-09 06:22:49 +01:00
samediff : : Threads : : parallel_for ( loop , 0 , _length ) ;
2019-06-15 13:34:34 +02:00
} else {
2019-11-13 15:15:18 +01:00
auto loop = PRAGMA_THREADS_FOR {
2020-02-20 09:43:26 +01:00
for ( auto e = start ; e < stop ; e + + ) {
2019-11-13 15:15:18 +01:00
auto xOffset = this - > getOffset ( e ) ;
2019-12-20 20:35:39 +01:00
auto yOffset = other . getOffset ( e ) ;
auto zOffset = target . getOffset ( e ) ;
2019-06-15 13:34:34 +02:00
2019-11-13 15:15:18 +01:00
z [ zOffset ] = func ( f [ xOffset ] , s [ yOffset ] ) ;
}
} ;
2019-06-15 13:34:34 +02:00
2020-03-09 06:22:49 +01:00
samediff : : Threads : : parallel_for ( loop , 0 , _length ) ;
2019-06-15 13:34:34 +02:00
}
}
}
2019-12-20 20:35:39 +01:00
template void NDArray : : applyPairwiseLambda ( const NDArray & other , const std : : function < double ( double , double ) > & func , NDArray & target ) ;
template void NDArray : : applyPairwiseLambda ( const NDArray & other , const std : : function < float ( float , float ) > & func , NDArray & target ) ;
template void NDArray : : applyPairwiseLambda ( const NDArray & other , const std : : function < float16 ( float16 , float16 ) > & func , NDArray & target ) ;
template void NDArray : : applyPairwiseLambda ( const NDArray & other , const std : : function < bfloat16 ( bfloat16 , bfloat16 ) > & func , NDArray & target ) ;
template void NDArray : : applyPairwiseLambda ( const NDArray & other , const std : : function < Nd4jLong ( Nd4jLong , Nd4jLong ) > & func , NDArray & target ) ;
template void NDArray : : applyPairwiseLambda ( const NDArray & other , const std : : function < int ( int , int ) > & func , NDArray & target ) ;
template void NDArray : : applyPairwiseLambda ( const NDArray & other , const std : : function < int16_t ( int16_t , int16_t ) > & func , NDArray & target ) ;
template void NDArray : : applyPairwiseLambda ( const NDArray & other , const std : : function < uint8_t ( uint8_t , uint8_t ) > & func , NDArray & target ) ;
template void NDArray : : applyPairwiseLambda ( const NDArray & other , const std : : function < uint16_t ( uint16_t , uint16_t ) > & func , NDArray & target ) ;
template void NDArray : : applyPairwiseLambda ( const NDArray & other , const std : : function < uint32_t ( uint32_t , uint32_t ) > & func , NDArray & target ) ;
template void NDArray : : applyPairwiseLambda ( const NDArray & other , const std : : function < uint64_t ( uint64_t , uint64_t ) > & func , NDArray & target ) ;
template void NDArray : : applyPairwiseLambda ( const NDArray & other , const std : : function < int8_t ( int8_t , int8_t ) > & func , NDArray & target ) ;
template void NDArray : : applyPairwiseLambda ( const NDArray & other , const std : : function < bool ( bool , bool ) > & func , NDArray & target ) ;
2019-06-15 13:34:34 +02:00
//////////////////////////////////////////////////////////////////////////
template < typename T >
2019-12-20 20:35:39 +01:00
void NDArray : : applyLambda ( const std : : function < T ( T ) > & func , NDArray & target ) {
2019-06-15 13:34:34 +02:00
if ( dataType ( ) ! = DataTypeUtils : : fromT < T > ( ) )
throw std : : runtime_error ( " NDArray::applyLambda<T> method: wrong template parameter T, its type should be the same as type of this array! " ) ;
2019-12-20 20:35:39 +01:00
if ( dataType ( ) ! = target . dataType ( ) )
2019-06-15 13:34:34 +02:00
throw std : : runtime_error ( " NDArray::applyLambda<T> method: types of this and target array should match ! " ) ;
auto f = this - > bufferAsT < T > ( ) ;
2019-12-20 20:35:39 +01:00
auto z = target . bufferAsT < T > ( ) ;
2019-06-15 13:34:34 +02:00
2019-12-20 20:35:39 +01:00
if ( this - > ordering ( ) = = target . ordering ( ) & & ( this - > ews ( ) = = 1 & & target . ews ( ) = = 1 ) ) {
2019-06-15 13:34:34 +02:00
2019-11-13 15:15:18 +01:00
auto loop = PRAGMA_THREADS_FOR {
2020-02-20 09:43:26 +01:00
for ( auto e = start ; e < stop ; e + + )
2019-11-13 15:15:18 +01:00
z [ e ] = func ( f [ e ] ) ;
} ;
2020-03-09 06:22:49 +01:00
samediff : : Threads : : parallel_for ( loop , 0 , _length ) ;
2019-06-15 13:34:34 +02:00
} else {
if ( f = = z ) {
2019-11-13 15:15:18 +01:00
auto loop = PRAGMA_THREADS_FOR {
2020-02-20 09:43:26 +01:00
for ( auto e = start ; e < stop ; e + + ) {
2019-11-13 15:15:18 +01:00
auto xOffset = this - > getOffset ( e ) ;
2019-06-15 13:34:34 +02:00
2019-11-13 15:15:18 +01:00
f [ xOffset ] = func ( f [ xOffset ] ) ;
}
} ;
2019-06-15 13:34:34 +02:00
2020-03-09 06:22:49 +01:00
samediff : : Threads : : parallel_for ( loop , 0 , _length ) ;
2019-06-15 13:34:34 +02:00
} else {
2019-11-13 15:15:18 +01:00
auto loop = PRAGMA_THREADS_FOR {
2020-02-20 09:43:26 +01:00
for ( auto e = start ; e < stop ; e + + ) {
2019-11-13 15:15:18 +01:00
auto xOffset = this - > getOffset ( e ) ;
2019-12-20 20:35:39 +01:00
auto zOffset = target . getOffset ( e ) ;
2019-06-15 13:34:34 +02:00
2019-11-13 15:15:18 +01:00
z [ zOffset ] = func ( f [ xOffset ] ) ;
}
} ;
2019-06-15 13:34:34 +02:00
2020-03-09 06:22:49 +01:00
samediff : : Threads : : parallel_for ( loop , 0 , _length ) ;
2019-06-15 13:34:34 +02:00
}
}
}
2019-12-20 20:35:39 +01:00
template void NDArray : : applyLambda ( const std : : function < double ( double ) > & func , NDArray & target ) ;
template void NDArray : : applyLambda ( const std : : function < float ( float ) > & func , NDArray & target ) ;
template void NDArray : : applyLambda ( const std : : function < float16 ( float16 ) > & func , NDArray & target ) ;
template void NDArray : : applyLambda ( const std : : function < bfloat16 ( bfloat16 ) > & func , NDArray & target ) ;
template void NDArray : : applyLambda ( const std : : function < Nd4jLong ( Nd4jLong ) > & func , NDArray & target ) ;
template void NDArray : : applyLambda ( const std : : function < int16_t ( int16_t ) > & func , NDArray & target ) ;
template void NDArray : : applyLambda ( const std : : function < int32_t ( int32_t ) > & func , NDArray & target ) ;
template void NDArray : : applyLambda ( const std : : function < uint8_t ( uint8_t ) > & func , NDArray & target ) ;
template void NDArray : : applyLambda ( const std : : function < uint16_t ( uint16_t ) > & func , NDArray & target ) ;
template void NDArray : : applyLambda ( const std : : function < uint32_t ( uint32_t ) > & func , NDArray & target ) ;
template void NDArray : : applyLambda ( const std : : function < uint64_t ( uint64_t ) > & func , NDArray & target ) ;
template void NDArray : : applyLambda ( const std : : function < int8_t ( int8_t ) > & func , NDArray & target ) ;
template void NDArray : : applyLambda ( const std : : function < bool ( bool ) > & func , NDArray & target ) ;
2019-06-15 13:34:34 +02:00
//////////////////////////////////////////////////////////////////////////
template < typename T >
2019-12-20 20:35:39 +01:00
void NDArray : : applyIndexedLambda ( const std : : function < T ( Nd4jLong , T ) > & func , NDArray & target ) {
2019-06-15 13:34:34 +02:00
if ( dataType ( ) ! = DataTypeUtils : : fromT < T > ( ) )
throw std : : runtime_error ( " NDArray::applyIndexedLambda<T> method: wrong template parameter T, its type should be the same as type of this array! " ) ;
2019-12-20 20:35:39 +01:00
if ( dataType ( ) ! = target . dataType ( ) )
2019-06-15 13:34:34 +02:00
throw std : : runtime_error ( " NDArray::applyIndexedLambda<T> method: types of this and target array should match ! " ) ;
auto f = this - > bufferAsT < T > ( ) ;
2019-12-20 20:35:39 +01:00
auto z = target . bufferAsT < T > ( ) ;
2019-06-15 13:34:34 +02:00
2019-12-20 20:35:39 +01:00
if ( this - > ordering ( ) = = target . ordering ( ) & & ( this - > ews ( ) = = 1 & & target . ews ( ) = = 1 ) ) {
2019-06-15 13:34:34 +02:00
2019-11-13 15:15:18 +01:00
auto loop = PRAGMA_THREADS_FOR {
2020-02-20 09:43:26 +01:00
for ( auto e = start ; e < stop ; e + + )
2019-11-13 15:15:18 +01:00
z [ e ] = func ( e , f [ e ] ) ;
} ;
2020-03-09 06:22:49 +01:00
samediff : : Threads : : parallel_for ( loop , 0 , _length ) ;
2019-06-15 13:34:34 +02:00
} else {
if ( f = = z ) {
2019-11-13 15:15:18 +01:00
auto loop = PRAGMA_THREADS_FOR {
2020-02-20 09:43:26 +01:00
for ( auto e = start ; e < stop ; e + + ) {
2019-11-13 15:15:18 +01:00
auto xOffset = this - > getOffset ( e ) ;
2019-06-15 13:34:34 +02:00
2019-11-13 15:15:18 +01:00
f [ xOffset ] = func ( e , f [ xOffset ] ) ;
}
} ;
2019-06-15 13:34:34 +02:00
2020-03-09 06:22:49 +01:00
samediff : : Threads : : parallel_for ( loop , 0 , _length ) ;
2019-06-15 13:34:34 +02:00
} else {
2019-11-13 15:15:18 +01:00
auto loop = PRAGMA_THREADS_FOR {
2020-02-20 09:43:26 +01:00
for ( auto e = start ; e < stop ; e + + ) {
2019-11-13 15:15:18 +01:00
auto xOffset = this - > getOffset ( e ) ;
2019-12-20 20:35:39 +01:00
auto zOffset = target . getOffset ( e ) ;
2019-06-15 13:34:34 +02:00
2019-11-13 15:15:18 +01:00
z [ zOffset ] = func ( e , f [ xOffset ] ) ;
}
} ;
2019-06-15 13:34:34 +02:00
2020-03-09 06:22:49 +01:00
samediff : : Threads : : parallel_for ( loop , 0 , _length ) ;
2019-06-15 13:34:34 +02:00
}
}
}
2019-12-20 20:35:39 +01:00
template void NDArray : : applyIndexedLambda ( const std : : function < double ( Nd4jLong , double ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedLambda ( const std : : function < float ( Nd4jLong , float ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedLambda ( const std : : function < float16 ( Nd4jLong , float16 ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedLambda ( const std : : function < bfloat16 ( Nd4jLong , bfloat16 ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedLambda ( const std : : function < Nd4jLong ( Nd4jLong , Nd4jLong ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedLambda ( const std : : function < int ( Nd4jLong , int ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedLambda ( const std : : function < int16_t ( Nd4jLong , int16_t ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedLambda ( const std : : function < uint8_t ( Nd4jLong , uint8_t ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedLambda ( const std : : function < uint16_t ( Nd4jLong , uint16_t ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedLambda ( const std : : function < uint32_t ( Nd4jLong , uint32_t ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedLambda ( const std : : function < uint64_t ( Nd4jLong , uint64_t ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedLambda ( const std : : function < int8_t ( Nd4jLong , int8_t ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedLambda ( const std : : function < bool ( Nd4jLong , bool ) > & func , NDArray & target ) ;
2019-06-15 13:34:34 +02:00
//////////////////////////////////////////////////////////////////////////
template < typename T >
2019-12-20 20:35:39 +01:00
void NDArray : : applyIndexedPairwiseLambda ( NDArray & other , const std : : function < T ( Nd4jLong , T , T ) > & func , NDArray & target ) {
2019-06-15 13:34:34 +02:00
if ( dataType ( ) ! = DataTypeUtils : : fromT < T > ( ) )
throw std : : runtime_error ( " NDArray::applyIndexedPairwiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array! " ) ;
2019-12-20 20:35:39 +01:00
if ( dataType ( ) ! = target . dataType ( ) )
2019-06-15 13:34:34 +02:00
throw std : : runtime_error ( " NDArray::applyIndexedPairwiseLambda<T> method: types of this and target array should match ! " ) ;
2019-12-20 20:35:39 +01:00
if ( this - > lengthOf ( ) ! = other . lengthOf ( ) ) {
2019-06-15 13:34:34 +02:00
nd4j_printf ( " applyIndexedPairwiseLambda requires both operands to have the same shape \n " , " " ) ;
throw std : : runtime_error ( " Shapes mismach " ) ;
}
auto f = this - > bufferAsT < T > ( ) ;
2019-12-20 20:35:39 +01:00
auto s = other . bufferAsT < T > ( ) ;
auto z = target . bufferAsT < T > ( ) ;
2019-06-15 13:34:34 +02:00
2019-12-20 20:35:39 +01:00
if ( this - > ordering ( ) = = other . ordering ( ) & & this - > ordering ( ) = = target . ordering ( ) & & ( this - > ews ( ) = = 1 & & target . ews ( ) = = 1 ) & & this - > ews ( ) = = other . ews ( ) ) {
2019-06-15 13:34:34 +02:00
2019-11-13 15:15:18 +01:00
auto loop = PRAGMA_THREADS_FOR {
2020-02-20 09:43:26 +01:00
for ( auto e = start ; e < stop ; e + + )
2019-11-13 15:15:18 +01:00
z [ e ] = func ( ( Nd4jLong ) e , f [ e ] , s [ e ] ) ;
} ;
2020-03-09 06:22:49 +01:00
samediff : : Threads : : parallel_for ( loop , 0 , _length ) ;
2019-06-15 13:34:34 +02:00
} else {
if ( f = = z ) {
2019-11-13 15:15:18 +01:00
auto loop = PRAGMA_THREADS_FOR {
2020-02-20 09:43:26 +01:00
for ( auto e = start ; e < stop ; e + + ) {
2019-11-13 15:15:18 +01:00
auto xOffset = this - > getOffset ( e ) ;
2019-12-20 20:35:39 +01:00
auto yOffset = other . getOffset ( e ) ;
2019-06-15 13:34:34 +02:00
2019-11-13 15:15:18 +01:00
f [ xOffset ] = func ( ( Nd4jLong ) e , f [ xOffset ] , s [ yOffset ] ) ;
}
} ;
2019-06-15 13:34:34 +02:00
2020-03-09 06:22:49 +01:00
samediff : : Threads : : parallel_for ( loop , 0 , _length ) ;
2019-06-15 13:34:34 +02:00
} else {
2019-11-13 15:15:18 +01:00
auto loop = PRAGMA_THREADS_FOR {
2020-02-20 09:43:26 +01:00
for ( auto e = start ; e < stop ; e + + ) {
2019-11-13 15:15:18 +01:00
auto xOffset = this - > getOffset ( e ) ;
2019-12-20 20:35:39 +01:00
auto yOffset = other . getOffset ( e ) ;
auto zOffset = target . getOffset ( e ) ;
2019-06-15 13:34:34 +02:00
2019-11-13 15:15:18 +01:00
z [ zOffset ] = func ( ( Nd4jLong ) e , f [ xOffset ] , s [ yOffset ] ) ;
}
} ;
2019-06-15 13:34:34 +02:00
2020-03-09 06:22:49 +01:00
samediff : : Threads : : parallel_for ( loop , 0 , _length ) ;
2019-06-15 13:34:34 +02:00
}
}
}
2019-12-20 20:35:39 +01:00
template void NDArray : : applyIndexedPairwiseLambda ( NDArray & other , const std : : function < double ( Nd4jLong , double , double ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedPairwiseLambda ( NDArray & other , const std : : function < float ( Nd4jLong , float , float ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedPairwiseLambda ( NDArray & other , const std : : function < float16 ( Nd4jLong , float16 , float16 ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedPairwiseLambda ( NDArray & other , const std : : function < bfloat16 ( Nd4jLong , bfloat16 , bfloat16 ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedPairwiseLambda ( NDArray & other , const std : : function < Nd4jLong ( Nd4jLong , Nd4jLong , Nd4jLong ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedPairwiseLambda ( NDArray & other , const std : : function < int ( Nd4jLong , int , int ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedPairwiseLambda ( NDArray & other , const std : : function < int16_t ( Nd4jLong , int16_t , int16_t ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedPairwiseLambda ( NDArray & other , const std : : function < uint8_t ( Nd4jLong , uint8_t , uint8_t ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedPairwiseLambda ( NDArray & other , const std : : function < uint16_t ( Nd4jLong , uint16_t , uint16_t ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedPairwiseLambda ( NDArray & other , const std : : function < uint32_t ( Nd4jLong , uint32_t , uint32_t ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedPairwiseLambda ( NDArray & other , const std : : function < uint64_t ( Nd4jLong , uint64_t , uint64_t ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedPairwiseLambda ( N DArray & other , const std : : function < int8_t ( Nd4jLong , int8_t , int8_t ) > & func , NDArray & target ) ;
template void NDArray : : applyIndexedPairwiseLambda ( NDArray & other , const std : : function < bool ( Nd4jLong , bool , bool ) > & func , NDArray & target ) ;