/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com), created on 12.12.2017
//
2020-03-02 10:49:41 +01:00
# include <system/op_boilerplate.h>
2019-06-06 14:21:15 +02:00
# if NOT_EXCLUDED(OP_zeta)
# include <ops/declarable/CustomOperations.h>
# include <ops/declarable/helpers/zeta.h>
2020-03-02 10:49:41 +01:00
namespace sd {
2019-06-06 14:21:15 +02:00
namespace ops {
CONFIGURABLE_OP_IMPL ( zeta , 2 , 1 , false , 0 , 0 ) {
auto x = INPUT_VARIABLE ( 0 ) ;
auto q = INPUT_VARIABLE ( 1 ) ;
auto output = OUTPUT_VARIABLE ( 0 ) ;
REQUIRE_TRUE ( x - > isSameShape ( q ) , 0 , " ZETA op: two input arrays must have the same shapes, bot got x=%s and q=%s ! " , ShapeUtils : : shapeAsString ( x ) . c_str ( ) , ShapeUtils : : shapeAsString ( q ) . c_str ( ) ) ;
2019-08-10 08:14:18 +02:00
Nd4jLong arrLen = x - > lengthOf ( ) ;
2019-06-06 14:21:15 +02:00
// FIXME: this should NOT be loop.
2019-08-10 08:14:18 +02:00
for ( Nd4jLong i = 0 ; i < arrLen ; + + i ) {
2019-06-06 14:21:15 +02:00
REQUIRE_TRUE ( x - > e < float > ( i ) > 1.f , 0 , " ZETA op: all elements of x array must be > 1 ! " ) ;
REQUIRE_TRUE ( q - > e < float > ( i ) > 0.f , 0 , " ZETA op: all elements of q array must be > 0 ! " ) ;
}
helpers : : zeta ( block . launchContext ( ) , * x , * q , * output ) ;
return Status : : OK ( ) ;
}
DECLARE_SYN ( Zeta , zeta ) ;
DECLARE_TYPES ( zeta ) {
getOpDescriptor ( )
- > setAllowedInputTypes ( { ALL_FLOATS } )
- > setAllowedOutputTypes ( { ALL_FLOATS } ) ;
}
}
}
# endif