57 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			57 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include <op_boilerplate.h>
 | 
						|
#if NOT_EXCLUDED(OP_Pow)
 | 
						|
 | 
						|
#include <ops/declarable/generic/helpers/BroadcastHelper.h>
 | 
						|
#include <ops/declarable/CustomOperations.h>
 | 
						|
 | 
						|
namespace nd4j {
 | 
						|
    namespace ops {
 | 
						|
        BROADCASTABLE_OP_IMPL(Pow, 0, 0) {
 | 
						|
            auto x = INPUT_VARIABLE(0);
 | 
						|
            auto y = INPUT_VARIABLE(1);
 | 
						|
            auto z = OUTPUT_VARIABLE(0);
 | 
						|
 | 
						|
            BROADCAST_CHECK_EMPTY(x,y,z);
 | 
						|
 | 
						|
            //REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!");
 | 
						|
 | 
						|
            auto tZ = BroadcastHelper::broadcastApply({scalar::Pow, pairwise::Pow, broadcast::Pow}, x, y, z);
 | 
						|
            if (tZ == nullptr)
 | 
						|
                return ND4J_STATUS_KERNEL_FAILURE;
 | 
						|
            else if (tZ != z) {
 | 
						|
                OVERWRITE_RESULT(tZ);
 | 
						|
            }
 | 
						|
 | 
						|
            return Status::OK();
 | 
						|
        }
 | 
						|
 | 
						|
        DECLARE_TYPES(Pow) {
 | 
						|
            getOpDescriptor()
 | 
						|
                ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS})
 | 
						|
                ->setAllowedInputTypes(1, {ALL_FLOATS, ALL_INTS})
 | 
						|
                ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS});
 | 
						|
        }
 | 
						|
    }
 | 
						|
}
 | 
						|
 | 
						|
#endif |