2019-06-06 14:21:15 +02:00
|
|
|
/*******************************************************************************
|
|
|
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
|
|
*
|
|
|
|
* This program and the accompanying materials are made available under the
|
|
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
|
|
*
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
|
|
* License for the specific language governing permissions and limitations
|
|
|
|
* under the License.
|
|
|
|
*
|
|
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
|
|
******************************************************************************/
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
#ifndef FUSED_OPS_H_
|
|
|
|
#define FUSED_OPS_H_
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
#include <system/pointercast.h>
|
|
|
|
#include <system/op_boilerplate.h>
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
#include <ops/ops.h>
|
|
|
|
|
|
|
|
namespace metaOps {
|
|
|
|
/**
|
|
|
|
* InvertedMetaOp shares the same idea as MetaOp, but op being applied to op.Y in pairwise/broadcast ops
|
|
|
|
*/
|
|
|
|
template<typename T, typename OpTypeA, typename OpTypeB>
|
|
|
|
class InvertedMetaOp {
|
|
|
|
public:
|
|
|
|
no_op_exec_special
|
|
|
|
no_op_exec_special_cuda
|
|
|
|
|
|
|
|
/*
|
|
|
|
* PREDICATE
|
|
|
|
*/
|
|
|
|
|
|
|
|
// scalar, transform, reduce, indexreduce entry
|
|
|
|
op_def static T op(T d1, T *params) {
|
|
|
|
/*
|
|
|
|
* We assume, that this method won't be EVER called
|
|
|
|
*/
|
|
|
|
printf("You should NEVER see this message in output\n");
|
|
|
|
return (T) 0.0f;
|
|
|
|
}
|
|
|
|
|
|
|
|
// PWT, broadcast entry. Predicate can be only scalar, transform
|
|
|
|
op_def static T op(T d1, T d2, T *params) {
|
|
|
|
Nd4jPointer *wrap = reinterpret_cast<Nd4jPointer *> (params);
|
|
|
|
T *paramsA = reinterpret_cast<T *> (wrap[0]);
|
|
|
|
T *paramsB = reinterpret_cast<T *> (wrap[1]);
|
|
|
|
|
|
|
|
return OpTypeB::op(OpTypeA::op(d1, d2, paramsA), paramsB);
|
|
|
|
}
|
|
|
|
|
|
|
|
/*
|
|
|
|
* POSTULATE
|
|
|
|
*/
|
|
|
|
|
|
|
|
// will be called for reduce, reduce3
|
|
|
|
op_def static T postProcess(T reduction, Nd4jLong n, T *params) {
|
|
|
|
/*
|
|
|
|
* We assume, that this method won't be EVER called
|
|
|
|
*/
|
|
|
|
printf("You should NEVER EVER see this message in output\n");
|
|
|
|
|
|
|
|
return (T) 0.0f;
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
* Special case here: MetaOp which consist of 2 operations.
|
|
|
|
*
|
|
|
|
* Predicate can be either scalar or transform, to process data before actual op call
|
|
|
|
* Postulate will be the scalar/transform, but will be applied to result of broadcast/reduce/reduce3
|
|
|
|
*/
|
|
|
|
template<typename T, typename OpTypeA, typename OpTypeB>
|
|
|
|
class MetaOp {
|
|
|
|
public:
|
|
|
|
no_op_exec_special
|
|
|
|
no_op_exec_special_cuda
|
|
|
|
|
|
|
|
/*
|
|
|
|
* PREDICATE
|
|
|
|
*/
|
|
|
|
|
|
|
|
meta_def static T startingValue(const T *input) {
|
|
|
|
return (T) 0.0f;
|
|
|
|
}
|
|
|
|
|
|
|
|
// scalar, transform, reduce, indexreduce entry
|
|
|
|
meta_def static T op(T d1, T *params) {
|
|
|
|
/*
|
|
|
|
* We assume, that params for MetaOp is a set of pointers to actual op A & B extraArgs
|
|
|
|
*/
|
|
|
|
Nd4jPointer *wrap = reinterpret_cast<Nd4jPointer *> (params);
|
|
|
|
T *paramsA = reinterpret_cast<T *> (wrap[0]);
|
|
|
|
T *paramsB = reinterpret_cast<T *> (wrap[1]);
|
|
|
|
|
|
|
|
return OpTypeB::op(OpTypeA::op(d1, paramsA), paramsB);
|
|
|
|
}
|
|
|
|
|
|
|
|
// PWT, broadcast entry. Predicate can be only scalar, transform
|
|
|
|
meta_def static T op(T d1, T d2, T *params) {
|
|
|
|
Nd4jPointer *wrap = reinterpret_cast<Nd4jPointer *> (params);
|
|
|
|
T *paramsA = reinterpret_cast<T *> (wrap[0]);
|
|
|
|
T *paramsB = reinterpret_cast<T *> (wrap[1]);
|
|
|
|
|
|
|
|
return OpTypeB::op(OpTypeA::op(d1, paramsA), d2, paramsB);
|
|
|
|
}
|
|
|
|
|
|
|
|
/*
|
|
|
|
* POSTULATE
|
|
|
|
*/
|
|
|
|
|
|
|
|
// will be called for reduce, reduce3
|
|
|
|
meta_def static T postProcess(T reduction, Nd4jLong n, T *params) {
|
|
|
|
Nd4jPointer *wrap = reinterpret_cast<Nd4jPointer *> (params);
|
|
|
|
T *paramsA = reinterpret_cast<T *> (wrap[0]);
|
|
|
|
T *paramsB = reinterpret_cast<T *> (wrap[1]);
|
|
|
|
|
|
|
|
return OpTypeB::op(OpTypeA::postProcess(reduction, n, paramsA), paramsB);
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template<typename T, typename OpTypeA, typename OpTypeB>
|
|
|
|
class ReduceMetaOp {
|
|
|
|
public:
|
|
|
|
no_op_exec_special
|
|
|
|
no_op_exec_special_cuda
|
|
|
|
|
|
|
|
meta_def static T startingValue(const T *input) {
|
|
|
|
return OpTypeB::startingValue(input);
|
|
|
|
}
|
|
|
|
|
|
|
|
meta_def static T merge(T old, T opOutput, T *params) {
|
|
|
|
Nd4jPointer *wrap = reinterpret_cast<Nd4jPointer *> (params);
|
|
|
|
// T *paramsA = reinterpret_cast<T *> (wrap[0]);
|
|
|
|
T *paramsB = reinterpret_cast<T *> (wrap[1]);
|
|
|
|
|
|
|
|
return OpTypeB::merge(old, opOutput, paramsB);
|
|
|
|
}
|
|
|
|
|
|
|
|
meta_def static T update(T old, T opOutput, T *params) {
|
|
|
|
Nd4jPointer *wrap = reinterpret_cast<Nd4jPointer *> (params);
|
|
|
|
//T *paramsA = reinterpret_cast<T *> (wrap[0]);
|
|
|
|
T *paramsB = reinterpret_cast<T *> (wrap[1]);
|
|
|
|
|
|
|
|
return OpTypeB::update(old, opOutput, paramsB);
|
|
|
|
}
|
|
|
|
|
|
|
|
meta_def static T op(T d1, T *params) {
|
|
|
|
Nd4jPointer *wrap = reinterpret_cast<Nd4jPointer *> (params);
|
|
|
|
T *paramsA = reinterpret_cast<T *> (wrap[0]);
|
|
|
|
T *paramsB = reinterpret_cast<T *> (wrap[1]);
|
|
|
|
|
|
|
|
return OpTypeB::op(OpTypeA::op(d1, paramsA), paramsB);
|
|
|
|
}
|
|
|
|
|
|
|
|
meta_def static T postProcess(T reduction, Nd4jLong n, T *params) {
|
|
|
|
Nd4jPointer *wrap = reinterpret_cast<Nd4jPointer *> (params);
|
|
|
|
// T *paramsA = reinterpret_cast<T *> (wrap[0]);
|
|
|
|
T *paramsB = reinterpret_cast<T *> (wrap[1]);
|
|
|
|
|
|
|
|
return OpTypeB::postProcess(reduction, n, paramsB);
|
|
|
|
}
|
|
|
|
};
|
|
|
|
}
|
|
|
|
|
|
|
|
#endif
|