parent
3093977c96
commit
285c6755d1
|
@ -49,6 +49,7 @@
|
||||||
#include <ops/declarable/headers/BarnesHutTsne.h>
|
#include <ops/declarable/headers/BarnesHutTsne.h>
|
||||||
#include <ops/declarable/headers/images.h>
|
#include <ops/declarable/headers/images.h>
|
||||||
#include <ops/declarable/headers/updaters.h>
|
#include <ops/declarable/headers/updaters.h>
|
||||||
|
#include <ops/declarable/headers/decoder.h>
|
||||||
#include <system/dll.h>
|
#include <system/dll.h>
|
||||||
#include <helpers/shape.h>
|
#include <helpers/shape.h>
|
||||||
#include <helpers/TAD.h>
|
#include <helpers/TAD.h>
|
||||||
|
|
|
@ -0,0 +1,111 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2021 Deeplearning4j Contributors
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author AbdelRauf
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <system/op_boilerplate.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <ops/declarable/helpers/ctc.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
CUSTOM_OP_IMPL(ctc_beam, 2, 3, false, 0, -2) {
|
||||||
|
|
||||||
|
auto logit = INPUT_VARIABLE(0);
|
||||||
|
auto sequence_length = INPUT_VARIABLE(1);
|
||||||
|
auto result_sequences = OUTPUT_VARIABLE(0);
|
||||||
|
auto result_probs = OUTPUT_VARIABLE(1);
|
||||||
|
auto result_sequences_length = OUTPUT_VARIABLE(2);
|
||||||
|
auto arg_size = block.getIArguments()->size();
|
||||||
|
auto normalize_logits = block.numB() > 0 ? B_ARG(0) : false;
|
||||||
|
|
||||||
|
int blank_index = arg_size>0 ? INT_ARG(0) : -1;
|
||||||
|
int beam_width = arg_size>1 ? INT_ARG(1) : 25;
|
||||||
|
int nbest_len = arg_size>2? INT_ARG(2): 1;
|
||||||
|
|
||||||
|
REQUIRE_TRUE(logit->rankOf()==3, 0, "Ctc Beam Search: logit Input fails to meet rank requirement {BATCH_LEN, MAX_FRAME_LEN, CLASS_LEN }: %i == 3 ", logit->rankOf());
|
||||||
|
REQUIRE_TRUE(sequence_length->rankOf()==1, 0, "Ctc Beam Search: sequence frame length (sequence_length) Input fails to meet rank requirement {BATCH_LEN}: %i == 1 ", sequence_length->rankOf());
|
||||||
|
|
||||||
|
REQUIRE_TRUE(result_sequences->rankOf()==3, 0, "Ctc Beam Search: result_sequences Output fails to meet rank requirement {BATCH_LEN, NBEST_LEN, MAX_FRAME_LEN }: %i == 3 ", result_sequences->rankOf());
|
||||||
|
REQUIRE_TRUE(result_probs->rankOf()==2, 0, "Ctc Beam Search: result_probs Output fails to meet rank requirement {BATCH_LEN, NBEST_LEN}: %i == 2 ", result_probs->rankOf());
|
||||||
|
REQUIRE_TRUE(result_sequences_length->rankOf()==2, 0, "Ctc Beam Search: result_sequences_length Output fails to meet rank requirement {BATCH_LEN, NBEST_LEN}: %i == 2 ", result_sequences_length->rankOf());
|
||||||
|
|
||||||
|
auto batchSize0 = logit->sizeAt(0);
|
||||||
|
auto batchSize1 = sequence_length->sizeAt(0);
|
||||||
|
auto batchSize2 = result_sequences->sizeAt(0);
|
||||||
|
auto batchSize3 = result_probs->sizeAt(0);
|
||||||
|
auto batchSize4 = result_sequences_length->sizeAt(0);
|
||||||
|
|
||||||
|
bool check_batches = (batchSize0 == batchSize1) && (batchSize2 == batchSize3);
|
||||||
|
check_batches = check_batches && (batchSize0 == batchSize4) && (batchSize0 == batchSize2);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(nbest_len>0 && nbest_len <=beam_width, 0, "Ctc Beam Search: nbest_len %i should be > 0 and <= %i", nbest_len, beam_width);
|
||||||
|
REQUIRE_TRUE(check_batches, 0, "Ctc Beam Search: All batch sizes should be %i", batchSize0);
|
||||||
|
auto max_t = logit->sizeAt(1);
|
||||||
|
REQUIRE_TRUE(result_sequences->sizeAt(1) == nbest_len && result_sequences->sizeAt(2) == max_t , 0, "Ctc Beam Search: shape of the result_sequences should be {%i, %i, %i} but got { %i, %i, %i}",
|
||||||
|
batchSize0, nbest_len, max_t, batchSize1, result_sequences->sizeAt(1), result_sequences->sizeAt(2));
|
||||||
|
REQUIRE_TRUE(result_probs->sizeAt(1) == nbest_len , 0, "Ctc Beam Search: shape of the result_probs should be {%i, %i} but got { %i, %i}",
|
||||||
|
batchSize0, nbest_len, batchSize3, result_sequences->sizeAt(1));
|
||||||
|
REQUIRE_TRUE(result_sequences_length->sizeAt(1) == nbest_len , 0, "Ctc Beam Search: shape of the result_sequences_length should be {%i, %i} but got { %i, %i}",
|
||||||
|
batchSize0, nbest_len, batchSize4, result_sequences_length->sizeAt(1));
|
||||||
|
REQUIRE_TRUE(result_sequences->ews()==1 && result_sequences->ordering()=='c', 0, "Ctc Beam Search: result_sequences output should be ews()==1 and c order: %d == ews(1) %c == order(c) ", result_sequences->ews(), result_sequences->ordering());
|
||||||
|
REQUIRE_TRUE(result_probs->ews()==1 && result_probs->ordering()=='c', 0, "Ctc Beam Search: result_probs output should be ews()==1 and c order: %d == ews(1) %c == order(c) ", result_probs->ews(), result_probs->ordering());
|
||||||
|
REQUIRE_TRUE(result_sequences_length->ews()==1 && result_sequences_length->ordering()=='c', 0, "Ctc Beam Search: result_sequences_length output should be ews()==1 and c order: %d == ews(1) %c == order(c) ", result_sequences_length->ews(), result_sequences_length->ordering());
|
||||||
|
|
||||||
|
sd::ops::helpers::beamSearch(*logit, *sequence_length, *result_sequences, *result_probs, *result_sequences_length, blank_index, beam_width, nbest_len, normalize_logits);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
DECLARE_TYPES(ctc_beam) {
|
||||||
|
getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})
|
||||||
|
->setAllowedInputTypes(1,{ALL_INDICES})
|
||||||
|
->setAllowedOutputTypes(0, {ALL_INDICES})
|
||||||
|
->setAllowedOutputTypes(1, {ALL_FLOATS})
|
||||||
|
->setAllowedOutputTypes(2, {ALL_INDICES});
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
DECLARE_SHAPE_FN(ctc_beam) {
|
||||||
|
auto logitShapeInfo = inputShape->at(0);
|
||||||
|
auto sequenceShapeInfo = inputShape->at(1);
|
||||||
|
auto arg_size = block.getIArguments()->size();
|
||||||
|
|
||||||
|
auto nbest_len = arg_size>2? INT_ARG(2): 1;
|
||||||
|
|
||||||
|
REQUIRE_TRUE(logitShapeInfo[0] ==3 , 0, "Ctc Beam Search: logit Input fails to meet rank requirement {BATCH_LEN, MAX_FRAME_LEN, CLASS_LEN }: %i == 3",
|
||||||
|
logitShapeInfo[0]);
|
||||||
|
|
||||||
|
auto batch_size = shape::shapeOf(logitShapeInfo)[0] ;
|
||||||
|
auto max_t = shape::shapeOf(logitShapeInfo)[1] ;
|
||||||
|
|
||||||
|
auto dtype_float = ArrayOptions::dataType(logitShapeInfo);
|
||||||
|
auto dtype_index = ArrayOptions::dataType(sequenceShapeInfo);
|
||||||
|
|
||||||
|
auto output0 = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(dtype_index, 'c', {batch_size, nbest_len, max_t}));
|
||||||
|
auto output1 = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(dtype_float, 'c', {batch_size, nbest_len}));
|
||||||
|
auto output2 = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(dtype_index, 'c', {batch_size, nbest_len}));
|
||||||
|
return SHAPELIST(output0, output1, output2);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -21,7 +21,7 @@
|
||||||
|
|
||||||
#include <system/op_boilerplate.h>
|
#include <system/op_boilerplate.h>
|
||||||
#include <ops/declarable/CustomOperations.h>
|
#include <ops/declarable/CustomOperations.h>
|
||||||
#include <ops/declarable/helpers/ctcLoss.h>
|
#include <ops/declarable/helpers/ctc.h>
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -43,16 +43,16 @@ CUSTOM_OP_IMPL(ctc_loss, 4, 1, false, 0, 1) {
|
||||||
REQUIRE_TRUE(targetLabelLengths->rankOf()==1, 0, "CtcLoss: target label length fails to meet rank requirement (batch_size): %i == 1 ", targetLabelLengths->rankOf());
|
REQUIRE_TRUE(targetLabelLengths->rankOf()==1, 0, "CtcLoss: target label length fails to meet rank requirement (batch_size): %i == 1 ", targetLabelLengths->rankOf());
|
||||||
REQUIRE_TRUE(logitInputLengths->rankOf()==1, 0, "CtcLoss: logit Input lengths fails to meet rank requirement (batch_size): %i == 1 ", logitInputLengths->rankOf());
|
REQUIRE_TRUE(logitInputLengths->rankOf()==1, 0, "CtcLoss: logit Input lengths fails to meet rank requirement (batch_size): %i == 1 ", logitInputLengths->rankOf());
|
||||||
|
|
||||||
int batchSize0 = targetLabels->sizeAt(0);
|
auto batchSize0 = targetLabels->sizeAt(0);
|
||||||
int batchSize1 = logitInput->sizeAt(0);
|
auto batchSize1 = logitInput->sizeAt(0);
|
||||||
int batchSize2 = targetLabelLengths->sizeAt(0);
|
auto batchSize2 = targetLabelLengths->sizeAt(0);
|
||||||
int batchSize3 = logitInputLengths->sizeAt(0);
|
auto batchSize3 = logitInputLengths->sizeAt(0);
|
||||||
int batchSize4 = outputLosses->sizeAt(0);
|
auto batchSize4 = outputLosses->sizeAt(0);
|
||||||
|
|
||||||
bool check_batches = (batchSize0 == batchSize1) && (batchSize2 == batchSize3);
|
bool check_batches = (batchSize0 == batchSize1) && (batchSize2 == batchSize3);
|
||||||
check_batches = check_batches && (batchSize0 == batchSize4) && (batchSize0 == batchSize2);
|
check_batches = check_batches && (batchSize0 == batchSize4) && (batchSize0 == batchSize2);
|
||||||
|
|
||||||
REQUIRE_TRUE(check_batches, 0, "CtcLoss: All batch sizes should be equal %i", batchSize0);
|
REQUIRE_TRUE(check_batches, 0, "CtcLoss: All batch sizes should be %i", batchSize0);
|
||||||
REQUIRE_TRUE(outputLosses->isSameShape(targetLabelLengths), 0, "CtcLoss: wrong shape of output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(targetLabelLengths).c_str(), ShapeUtils::shapeAsString(outputLosses).c_str());
|
REQUIRE_TRUE(outputLosses->isSameShape(targetLabelLengths), 0, "CtcLoss: wrong shape of output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(targetLabelLengths).c_str(), ShapeUtils::shapeAsString(outputLosses).c_str());
|
||||||
|
|
||||||
auto emptyGradients = NDArrayFactory::empty<float>();
|
auto emptyGradients = NDArrayFactory::empty<float>();
|
||||||
|
@ -95,16 +95,16 @@ CUSTOM_OP_IMPL(ctc_loss_grad, 4, 1, false, 0, 1) {
|
||||||
REQUIRE_TRUE(targetLabelLengths->rankOf()==1, 0, "CtcLoss: target label length fails to meet rank requirement (batch_size): %i == 1 ", targetLabelLengths->rankOf());
|
REQUIRE_TRUE(targetLabelLengths->rankOf()==1, 0, "CtcLoss: target label length fails to meet rank requirement (batch_size): %i == 1 ", targetLabelLengths->rankOf());
|
||||||
REQUIRE_TRUE(logitInputLengths->rankOf()==1, 0, "CtcLoss: logit Input lengths fails to meet rank requirement (batch_size): %i == 1 ", logitInputLengths->rankOf());
|
REQUIRE_TRUE(logitInputLengths->rankOf()==1, 0, "CtcLoss: logit Input lengths fails to meet rank requirement (batch_size): %i == 1 ", logitInputLengths->rankOf());
|
||||||
|
|
||||||
int batchSize0 = targetLabels->sizeAt(0);
|
auto batchSize0 = targetLabels->sizeAt(0);
|
||||||
int batchSize1 = logitInput->sizeAt(0);
|
auto batchSize1 = logitInput->sizeAt(0);
|
||||||
int batchSize2 = targetLabelLengths->sizeAt(0);
|
auto batchSize2 = targetLabelLengths->sizeAt(0);
|
||||||
int batchSize3 = logitInputLengths->sizeAt(0);
|
auto batchSize3 = logitInputLengths->sizeAt(0);
|
||||||
int batchSize4 = outputGradients->sizeAt(0);
|
auto batchSize4 = outputGradients->sizeAt(0);
|
||||||
|
|
||||||
bool check_batches = (batchSize0 == batchSize1) && (batchSize2 == batchSize3);
|
bool check_batches = (batchSize0 == batchSize1) && (batchSize2 == batchSize3);
|
||||||
check_batches = check_batches && (batchSize0 == batchSize4) && (batchSize0 == batchSize2);
|
check_batches = check_batches && (batchSize0 == batchSize4) && (batchSize0 == batchSize2);
|
||||||
|
|
||||||
REQUIRE_TRUE(check_batches, 0, "CtcLoss Gradient: All batch sizes should be equal %i", batchSize0);
|
REQUIRE_TRUE(check_batches, 0, "CtcLoss Gradient: All batch sizes should be %i", batchSize0);
|
||||||
REQUIRE_TRUE(outputGradients->isSameShape(logitInput), 0, "CtcLoss Gradient: wrong shape of output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(logitInput).c_str(), ShapeUtils::shapeAsString(outputGradients).c_str());
|
REQUIRE_TRUE(outputGradients->isSameShape(logitInput), 0, "CtcLoss Gradient: wrong shape of output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(logitInput).c_str(), ShapeUtils::shapeAsString(outputGradients).c_str());
|
||||||
|
|
||||||
auto emptyLoss = NDArrayFactory::empty<float>();
|
auto emptyLoss = NDArrayFactory::empty<float>();
|
||||||
|
|
|
@ -0,0 +1,65 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author AbdelRauf
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef LIBND4J_HEADERS_DECODER_H
|
||||||
|
#define LIBND4J_HEADERS_DECODER_H
|
||||||
|
|
||||||
|
#include <ops/declarable/headers/common.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Implementation of CTC beam search
|
||||||
|
*
|
||||||
|
* Input arrays:
|
||||||
|
* 0: logits - logits NDArray logit NDArray {BATCH_LEN, MAX_FRAME_LEN, CLASS_LEN }. It should include a blank label as well. type float
|
||||||
|
* 1: sequence_length - NDArray {BATCH_LEN} length of frames. type integer
|
||||||
|
*
|
||||||
|
* Input integer arguments (IArgs):
|
||||||
|
* 0: blank_index the index of the blank label in logits. default is last class. CLASS_LEN-1
|
||||||
|
* 1: beam_width the width of the beam search. default is 25
|
||||||
|
* 2: nbest_len the number of top best results that should be returned. default is 1
|
||||||
|
* NOTE: if it is > beam_width it will be defaulted to beam_width size.
|
||||||
|
* Input bool argument (BArgs):
|
||||||
|
* 0: normalize_logit when its true it will normalize logits. by default it is assumed logit contains already normalized log-probabilities
|
||||||
|
* Output array:
|
||||||
|
* 0: result_sequences NDArray {BATCH_LEN, NBEST, MAX_FRAME_LEN} result sequences.
|
||||||
|
* NOTE: result_sequences NdArray should be c order and have ews == 1. type integer
|
||||||
|
* 1: result_probs NDArray {BATCH_LEN, NBEST} negative log probabilities for each sequence. type float
|
||||||
|
* NOTE: result_probs NdArray should be c order and have ews == 1
|
||||||
|
* 2: result_sequence_length NDArray {BATCH_LEN, NBEST} the length of the each sequence. type integer
|
||||||
|
* NOTE: result_sequence_length NdArray should be c order and have ews == 1
|
||||||
|
*
|
||||||
|
* NOTE:
|
||||||
|
* maximum value of integer indexing type should be >= CLASS_LEN to make sense. And also it should consider frame lengthes as well.
|
||||||
|
* For now this case is mostly fine as only Indexing types are allowed as integer.
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_ctc_beam)
|
||||||
|
DECLARE_CUSTOM_OP(ctc_beam, 2, 3, false, 0, -2);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -365,7 +365,8 @@ namespace ops {
|
||||||
*
|
*
|
||||||
* Input arrays:
|
* Input arrays:
|
||||||
* 0: labels - labels NDArray {BATCH_LEN, MAX_TARGET_LEN}, type integer
|
* 0: labels - labels NDArray {BATCH_LEN, MAX_TARGET_LEN}, type integer
|
||||||
* 1: logits - logits NDArray {BATCH_LEN, FRAME_LEN, CLASS_LEN }. log softmax of rnn output. It should include a blank label as well, type float
|
* 1: logits - logits NDArray {BATCH_LEN, FRAME_LEN, CLASS_LEN }. It should include a blank label as well, type float
|
||||||
|
* NOTE: we expect normalized logits (softmax normalized logarithm values for logits).
|
||||||
* 2: targetLabelLengths - Length of label sequence in labels NDArray {BATCH_LEN}, type integer
|
* 2: targetLabelLengths - Length of label sequence in labels NDArray {BATCH_LEN}, type integer
|
||||||
* 3: logitsLengths - Length of input sequence in logits NDArray {BATCH_LEN}, type integer
|
* 3: logitsLengths - Length of input sequence in logits NDArray {BATCH_LEN}, type integer
|
||||||
*
|
*
|
||||||
|
|
|
@ -25,7 +25,7 @@
|
||||||
#include <execution/Threads.h>
|
#include <execution/Threads.h>
|
||||||
#include <execution/ThreadPool.h>
|
#include <execution/ThreadPool.h>
|
||||||
#include <helpers/LoopsCoordsHelper.h>
|
#include <helpers/LoopsCoordsHelper.h>
|
||||||
#include <ops/declarable/helpers/ctcLoss.h>
|
#include <ops/declarable/helpers/ctc.h>
|
||||||
|
|
||||||
namespace sd
|
namespace sd
|
||||||
{
|
{
|
||||||
|
@ -34,26 +34,11 @@ namespace sd
|
||||||
namespace helpers
|
namespace helpers
|
||||||
{
|
{
|
||||||
|
|
||||||
//choose ptr[index*element_stride]
|
|
||||||
template <bool Strided, typename Type>
|
|
||||||
typename std::enable_if<Strided == true, Type &>::type
|
|
||||||
element(Type *ptr, int index, int element_stride)
|
|
||||||
{
|
|
||||||
return ptr[index * element_stride];
|
|
||||||
}
|
|
||||||
|
|
||||||
//choose ptr[index] assuming element_stride is 1
|
|
||||||
template <bool Strided, typename Type>
|
|
||||||
typename std::enable_if<Strided == false, Type &>::type
|
|
||||||
element(Type *ptr, int index, int element_stride)
|
|
||||||
{
|
|
||||||
return ptr[index];
|
|
||||||
}
|
|
||||||
|
|
||||||
template <bool IsLogPStrided = false, bool IsLblStrided = false, typename Type, typename IndexType>
|
template <bool IsLogPStrided = false, bool IsLblStrided = false, typename Type, typename IndexType>
|
||||||
Type forward(Type *alphaPtr, const Nd4jLong &incA, const Type *logP, const Nd4jLong &incP, const IndexType *lbl, const Nd4jLong &lenSB, const Nd4jLong &lenT, const int &blankIndex, int elwiseP = 1, int elwiseS = 1)
|
Type forward(Type *alphaPtr, const Nd4jLong &incA, const Type *logP, const Nd4jLong &incP, const IndexType *lbl, const Nd4jLong &lenSB, const Nd4jLong &lenT, const int &blankIndex, int elwiseP = 1, int elwiseS = 1)
|
||||||
{
|
{
|
||||||
Type negInf = -DataTypeUtils::infOrMax<Type>();
|
Type negInf = negative_infinity<Type>();
|
||||||
//initialize alphas at t=0
|
//initialize alphas at t=0
|
||||||
alphaPtr[0] = element<IsLogPStrided>(logP, blankIndex, elwiseP);
|
alphaPtr[0] = element<IsLogPStrided>(logP, blankIndex, elwiseP);
|
||||||
//alphaPtr[1] =logP[lbl[0]];
|
//alphaPtr[1] =logP[lbl[0]];
|
||||||
|
@ -82,23 +67,17 @@ namespace sd
|
||||||
// {t-1,s}
|
// {t-1,s}
|
||||||
Type alphaS = alphaPrevPtr[s];
|
Type alphaS = alphaPrevPtr[s];
|
||||||
Type alphaS_1 = s > 0 ? alphaPrevPtr[s - 1] : negInf;
|
Type alphaS_1 = s > 0 ? alphaPrevPtr[s - 1] : negInf;
|
||||||
Type cMax = std::max(alphaS, alphaS_1);
|
|
||||||
//logP[currentInd] or logP[currentInd*elwiseP]
|
//logP[currentInd] or logP[currentInd*elwiseP]
|
||||||
auto currentProb = element<IsLogPStrided>(logP, currentInd, elwiseP);
|
auto currentProb = element<IsLogPStrided>(logP, currentInd, elwiseP);
|
||||||
// if blank or the same as previous
|
// if blank or the same as previous
|
||||||
if (s > 1 && currentInd != blankIndex && currentInd != element<IsLblStrided>(lbl, ind - 1, elwiseS))
|
if (s > 1 && currentInd != blankIndex && currentInd != element<IsLblStrided>(lbl, ind - 1, elwiseS))
|
||||||
{
|
{
|
||||||
Type alphaS_2 = alphaPrevPtr[s - 2];
|
Type alphaS_2 = alphaPrevPtr[s - 2];
|
||||||
cMax = std::max(cMax, alphaS_2);
|
alphaPtr[s] = log_sum_exp(alphaS, alphaS_1, alphaS_2) + currentProb;
|
||||||
if (cMax == negInf)
|
|
||||||
cMax = 0;
|
|
||||||
alphaPtr[s] = std::log(std::exp(alphaS - cMax) + std::exp(alphaS_1 - cMax) + std::exp(alphaS_2 - cMax)) + cMax + currentProb;
|
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
if (cMax == negInf)
|
alphaPtr[s] = log_sum_exp(alphaS, alphaS_1) + currentProb;
|
||||||
cMax = 0;
|
|
||||||
alphaPtr[s] = std::log(std::exp(alphaS - cMax) + std::exp(alphaS_1 - cMax)) + cMax + currentProb;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,8 +88,7 @@ namespace sd
|
||||||
}
|
}
|
||||||
auto logP0 = alphaPrevPtr[lenSB - 1];
|
auto logP0 = alphaPrevPtr[lenSB - 1];
|
||||||
auto logP1 = alphaPrevPtr[lenSB - 2];
|
auto logP1 = alphaPrevPtr[lenSB - 2];
|
||||||
auto cMax = std::max(logP0, logP1);
|
return -log_sum_exp(logP0, logP1 );
|
||||||
return -(std::log(std::exp(logP0 - cMax) + std::exp(logP1 - cMax)) + cMax);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//#undef CALCULATE_ALL_IN_ONE_FRAME_LOOP
|
//#undef CALCULATE_ALL_IN_ONE_FRAME_LOOP
|
||||||
|
@ -121,7 +99,7 @@ namespace sd
|
||||||
int elwiseP = 1, int elwiseS = 1, int elwiseG = 1)
|
int elwiseP = 1, int elwiseS = 1, int elwiseG = 1)
|
||||||
{
|
{
|
||||||
|
|
||||||
Type negInf = -DataTypeUtils::infOrMax<Type>();
|
Type negInf = negative_infinity<Type>();
|
||||||
Nd4jLong lenSB = 2 * lenS + 1;
|
Nd4jLong lenSB = 2 * lenS + 1;
|
||||||
auto origBetta = bettaPtr;
|
auto origBetta = bettaPtr;
|
||||||
auto origLogP = logP;
|
auto origLogP = logP;
|
||||||
|
@ -197,23 +175,17 @@ namespace sd
|
||||||
// {t-1,s}
|
// {t-1,s}
|
||||||
Type bettaS = bettaPrevPtr[s];
|
Type bettaS = bettaPrevPtr[s];
|
||||||
Type bettaS_1 = s < lenSB - 1 ? bettaPrevPtr[s + 1] : negInf;
|
Type bettaS_1 = s < lenSB - 1 ? bettaPrevPtr[s + 1] : negInf;
|
||||||
Type cMax = std::max(bettaS, bettaS_1);
|
|
||||||
//logP[currentInd]
|
//logP[currentInd]
|
||||||
auto currentProb = element<IsLogPStrided>(logP, currentInd, elwiseP);
|
auto currentProb = element<IsLogPStrided>(logP, currentInd, elwiseP);
|
||||||
// if blank or the same as previous
|
// if blank or the same as previous
|
||||||
if (s < lenSB - 2 && currentInd != blankIndex && currentInd != element<IsLblStrided>(lbl, ind + 1, elwiseS))
|
if (s < lenSB - 2 && currentInd != blankIndex && currentInd != element<IsLblStrided>(lbl, ind + 1, elwiseS))
|
||||||
{
|
{
|
||||||
Type bettaS_2 = bettaPrevPtr[s + 2];
|
Type bettaS_2 = bettaPrevPtr[s + 2];
|
||||||
cMax = std::max(cMax, bettaS_2);
|
bettaPtr[s] = log_sum_exp(bettaS, bettaS_1, bettaS_2) + currentProb;
|
||||||
if (cMax == negInf)
|
|
||||||
cMax = 0;
|
|
||||||
bettaPtr[s] = std::log(std::exp(bettaS - cMax) + std::exp(bettaS_1 - cMax) + std::exp(bettaS_2 - cMax)) + cMax + currentProb;
|
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
if (cMax == negInf)
|
bettaPtr[s] = log_sum_exp(bettaS, bettaS_1) + currentProb;
|
||||||
cMax = 0;
|
|
||||||
bettaPtr[s] = std::log(std::exp(bettaS - cMax) + std::exp(bettaS_1 - cMax)) + cMax + currentProb;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
|
#if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
|
||||||
|
@ -262,8 +234,7 @@ namespace sd
|
||||||
|
|
||||||
auto logBP0 = bettaPrevPtr[0];
|
auto logBP0 = bettaPrevPtr[0];
|
||||||
auto logBP1 = bettaPrevPtr[1];
|
auto logBP1 = bettaPrevPtr[1];
|
||||||
auto bcMax = std::max(logBP0, logBP1);
|
auto blogLoss = -log_sum_exp(logBP0, logBP1);
|
||||||
auto blogLoss = -(std::log(std::exp(logBP0 - bcMax) + std::exp(logBP1 - bcMax)) + bcMax);
|
|
||||||
|
|
||||||
#if !defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
|
#if !defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP)
|
||||||
//alpha*betta
|
//alpha*betta
|
||||||
|
@ -289,8 +260,7 @@ namespace sd
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
Type cMax = std::max(currentGrad, alphaBettaS);
|
currentGrad = log_sum_exp(currentGrad, alphaBettaS);
|
||||||
currentGrad = std::log(std::exp(currentGrad - cMax) + std::exp(alphaBettaS - cMax)) + cMax;
|
|
||||||
}
|
}
|
||||||
//alphaPtr[s] = alphaBettaS;
|
//alphaPtr[s] = alphaBettaS;
|
||||||
}
|
}
|
||||||
|
@ -345,7 +315,7 @@ namespace sd
|
||||||
auto bufferPtr = bufferArr.bufferAsT<Type>();
|
auto bufferPtr = bufferArr.bufferAsT<Type>();
|
||||||
auto incA = bufferArr.stridesOf()[1];
|
auto incA = bufferArr.stridesOf()[1];
|
||||||
auto bettaBufferPtr = bufferPtr + bufferArr.stridesOf()[0];
|
auto bettaBufferPtr = bufferPtr + bufferArr.stridesOf()[0];
|
||||||
Type negInf = -DataTypeUtils::infOrMax<Type>();
|
Type negInf = negative_infinity<Type>();
|
||||||
|
|
||||||
#if 1
|
#if 1
|
||||||
if (gradPtr)
|
if (gradPtr)
|
||||||
|
@ -421,7 +391,8 @@ namespace sd
|
||||||
elwiseLL = logLosses.stridesOf()[0];
|
elwiseLL = logLosses.stridesOf()[0];
|
||||||
logLossPtr = logLosses.bufferAsT<Type>();
|
logLossPtr = logLosses.bufferAsT<Type>();
|
||||||
}
|
}
|
||||||
|
//defaulting blankIndex to the last class if its incorrect or -1
|
||||||
|
if (blankIndex > maxLenS || blankIndex < 0) blankIndex = maxLenS - 1;
|
||||||
auto func = [logP, batchP, incP, elwiseP, lenK, lenTPtr, lenSPtr, logLossPtr, lblPtr, maxLenT, maxLenS,
|
auto func = [logP, batchP, incP, elwiseP, lenK, lenTPtr, lenSPtr, logLossPtr, lblPtr, maxLenT, maxLenS,
|
||||||
batchLbl, blankIndex, elwiseT, elwiseLL, elwiseSLen, elwiseS, &gradients](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
|
batchLbl, blankIndex, elwiseT, elwiseLL, elwiseSLen, elwiseS, &gradients](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
|
||||||
Type *gradPtr = nullptr;
|
Type *gradPtr = nullptr;
|
||||||
|
@ -450,7 +421,7 @@ namespace sd
|
||||||
lenS = lenS > maxLenS ? maxLenS : lenS;
|
lenS = lenS > maxLenS ? maxLenS : lenS;
|
||||||
if (lenS <= 0 || lenT <= 0)
|
if (lenS <= 0 || lenT <= 0)
|
||||||
{
|
{
|
||||||
resultLoss = -DataTypeUtils::infOrMax<Type>();
|
resultLoss = negative_infinity<Type>();
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
@ -475,7 +446,7 @@ namespace sd
|
||||||
lenS = lenS > maxLenS ? maxLenS : lenS;
|
lenS = lenS > maxLenS ? maxLenS : lenS;
|
||||||
if (lenS <= 0 || lenT <= 0)
|
if (lenS <= 0 || lenT <= 0)
|
||||||
{
|
{
|
||||||
resultLoss = -DataTypeUtils::infOrMax<Type>();
|
resultLoss = negative_infinity<Type>();
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
@ -495,11 +466,11 @@ namespace sd
|
||||||
|
|
||||||
void ctcLoss(graph::Context& block, const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex){
|
void ctcLoss(graph::Context& block, const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex){
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(logits.dataType(), targetLabels.dataType(), ctc_loss_, (logits, targetLabels, logitsLengths, targetLabelLengths, logLosses, gradients, blankIndex), FLOAT_TYPES, INDEXING_TYPES);
|
BUILD_DOUBLE_SELECTOR(logits.dataType(), targetLabels.dataType(), ctc_loss_, (logits, targetLabels, logitsLengths, targetLabelLengths, logLosses, gradients, blankIndex), FLOAT_TYPES, INDEXING_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template void ctc_loss_, (const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex), FLOAT_TYPES, INDEXING_TYPES);
|
BUILD_DOUBLE_TEMPLATE(template void ctc_loss_, (const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex), FLOAT_TYPES, INDEXING_TYPES);
|
||||||
|
|
||||||
|
|
||||||
} // namespace helpers
|
} // namespace helpers
|
||||||
|
|
|
@ -0,0 +1,153 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2021 Deeplearning4j Contributors
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author AbdelRauf
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef LIBND4J_HELPERS_CTCLOSS_H
|
||||||
|
#define LIBND4J_HELPERS_CTCLOSS_H
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/helpers.h>
|
||||||
|
#include <graph/Context.h>
|
||||||
|
#include <type_traits>
|
||||||
|
#include <math/platformmath.h>
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
//#define LOGIT_SOFTMAX_NORMALIZATION 1
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
constexpr T negative_infinity()
|
||||||
|
{
|
||||||
|
return -DataTypeUtils::infOrMax<T>();
|
||||||
|
}
|
||||||
|
|
||||||
|
//choose ptr[index*element_stride]
|
||||||
|
template <bool HasStride, typename Type>
|
||||||
|
typename std::enable_if<HasStride == true, Type &>::type
|
||||||
|
element(Type *ptr, int index, int element_stride)
|
||||||
|
{
|
||||||
|
return ptr[index * element_stride];
|
||||||
|
}
|
||||||
|
|
||||||
|
//choose ptr[index] assuming element_stride is 1
|
||||||
|
template <bool HasStride, typename Type>
|
||||||
|
typename std::enable_if<HasStride == false, Type &>::type
|
||||||
|
element(Type *ptr, int index, int element_stride)
|
||||||
|
{
|
||||||
|
return ptr[index];
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T local_log(T x)
|
||||||
|
{
|
||||||
|
if (x > 0)
|
||||||
|
{
|
||||||
|
return (sd::math::p_log<T>(x));
|
||||||
|
}
|
||||||
|
return (negative_infinity<T>());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T log_sum_exp(T x1, T x2)
|
||||||
|
{
|
||||||
|
//substituting this : std::log(std::exp(arg1 - cMax) + std::exp(arg2 - cMax)) + cMax
|
||||||
|
//if arg1==cMax : std::log(1 + std::exp(arg2 - cMax)) + cMax
|
||||||
|
if (x1 >= x2)
|
||||||
|
{
|
||||||
|
//x1 is max
|
||||||
|
return (x1 + local_log(1 + sd::math::p_exp<T>(x2 - x1)));
|
||||||
|
}
|
||||||
|
//x2 is max
|
||||||
|
return (x2 + local_log(1 + sd::math::p_exp<T>(x1 - x2)));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T log_sum_exp(T arg1, T arg2, T arg3)
|
||||||
|
{
|
||||||
|
auto c_max = std::max(arg1, arg2);
|
||||||
|
c_max = std::max(c_max, arg3);
|
||||||
|
if (negative_infinity<T>() == c_max)
|
||||||
|
{
|
||||||
|
c_max = 0;
|
||||||
|
}
|
||||||
|
return sd::math::p_log(sd::math::p_exp(arg1 - c_max) + sd::math::p_exp(arg2 - c_max) + sd::math::p_exp(arg3 - c_max)) + c_max;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <bool HasElementStride, typename Type, typename IndexType>
|
||||||
|
Type softmax_normalization_term(const Type* log_p, const uint64_t len_c, const uint64_t element_stride)
|
||||||
|
{
|
||||||
|
Type max_p;
|
||||||
|
for (auto c = 0; c < len_c; ++c) {
|
||||||
|
max_p = std::max(max_p, element<HasElementStride>(log_p, c, element_stride));
|
||||||
|
}
|
||||||
|
// Get normalization term of softmax: log(sum(exp(logit[j]-max_p))).
|
||||||
|
Type logsumexp = Type(0.0);
|
||||||
|
for (auto c = 0; c < len_c; ++c) {
|
||||||
|
logsumexp += sd::math::p_exp(element<HasElementStride>(log_p, c, element_stride) - max_p);
|
||||||
|
}
|
||||||
|
logsumexp = sd::math::p_log(logsumexp);
|
||||||
|
return max_p + logsumexp;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Implementation of CTC loss function
|
||||||
|
* References:
|
||||||
|
Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
|
||||||
|
with Recurrent Neural Networks:
|
||||||
|
[Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
|
||||||
|
([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
|
||||||
|
*
|
||||||
|
* @param block Context
|
||||||
|
* @param logits NDArray {BATCH_LEN, MAX_FRAME_LEN, CLASS_LEN }. It should include a blank label as well.
|
||||||
|
* NOTE: log softmax of rnn output. so we expect softmax normalized
|
||||||
|
* @param targetLabels NDArray {BATCH_LEN, MAX_TARGET_LEN}
|
||||||
|
* @param logitsLengths NDArray {BATCH_LEN} Length of input sequence in logits
|
||||||
|
* @param targetLabelLengths NDArray {BATCH_LEN} Length of label sequence in labels
|
||||||
|
* @param logLosses NDArray {BATCH_LEN} or EMPTY. if empty it will be skipped. negative log probabilities of loss
|
||||||
|
* @param gradients NDArray {BATCH_LEN, MAX_FRAME_LEN, CLASS_LEN } or EMPTY. gradients
|
||||||
|
* @param blankIndex index of the blank label in logits
|
||||||
|
*/
|
||||||
|
void ctcLoss(graph::Context& block, const NDArray &logitsInput, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex);
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Implementation of CTC beam search
|
||||||
|
*
|
||||||
|
* @param logit NDArray {BATCH_LEN, MAX_FRAME_LEN, CLASS_LEN }. log probabilities. It should include a blank label as well.
|
||||||
|
* @param sequence_length NDArray {BATCH_LEN} length of frames. type integer
|
||||||
|
* @param result_sequences NDArray {BATCH_LEN, NBEST, MAX_FRAME_LEN} result sequences.
|
||||||
|
* NOTE: result_sequences NdArray should be c order and have ews == 1. type integer.
|
||||||
|
* @param result_probs NDArray {BATCH_LEN, NBEST} negative log probabilities for each sequence.
|
||||||
|
* NOTE: result_probs NdArray should be c order and have ews == 1
|
||||||
|
* @param result_sequences_length NDArray {BATCH_LEN, NBEST} the length of each sequence in result_sequences.
|
||||||
|
* NOTE: result_sequences_length NdArray should be c order and have ews == 1
|
||||||
|
* @param blank_index the index of the blank label in logits
|
||||||
|
* @param beam_width the width of the beam search.
|
||||||
|
* @param nbest_len the number of top best results that should be returned. if it is greather than beam_width it will be defaulted to beam_width size.
|
||||||
|
* @param normalize_logits when its true it will normalize logits. by default it is assumed logit contains already normalized log-probabilities
|
||||||
|
* NOTE:
|
||||||
|
* maximum value of integer type should be >= CLASS_LEN to make sense. And also user should consider frame lengthes as well.
|
||||||
|
*/
|
||||||
|
void beamSearch(const NDArray& logit, const NDArray& sequence_length, NDArray& result_sequences, NDArray& result_probs, NDArray& result_sequences_length, int blank_index, int beam_width , int nbest_len, bool normalize_logits);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#endif
|
|
@ -1,55 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2021 Deeplearning4j Contributors
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
*******************************************************************************/
|
|
||||||
|
|
||||||
//
|
|
||||||
// @author AbdelRauf
|
|
||||||
//
|
|
||||||
|
|
||||||
#ifndef LIBND4J_HELPERS_CTCLOSS_H
|
|
||||||
#define LIBND4J_HELPERS_CTCLOSS_H
|
|
||||||
|
|
||||||
#include <ops/declarable/helpers/helpers.h>
|
|
||||||
#include <graph/Context.h>
|
|
||||||
|
|
||||||
namespace sd {
|
|
||||||
namespace ops {
|
|
||||||
namespace helpers {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Implementation of CTC loss function
|
|
||||||
* References:
|
|
||||||
Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
|
|
||||||
with Recurrent Neural Networks:
|
|
||||||
[Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
|
|
||||||
([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
|
|
||||||
*
|
|
||||||
* @param block Context
|
|
||||||
* @param logits NDArray {BATCH_LEN, FRAME_LEN, CLASS_LEN }. log softmax of rnn output. It should include a blank label as well.
|
|
||||||
* @param targetLabels NDArray {BATCH_LEN, MAX_TARGET_LEN}
|
|
||||||
* @param logitsLengths NDArray {BATCH_LEN} Length of input sequence in logits
|
|
||||||
* @param targetLabelLengths NDArray {BATCH_LEN} Length of label sequence in labels
|
|
||||||
* @param logLosses NDArray {BATCH_LEN} or EMPTY. if empty it will be skipped. negative log probabilities of loss
|
|
||||||
* @param gradients NDArray {BATCH_LEN, FRAME_LEN, CLASS_LEN } or EMPTY. gradients
|
|
||||||
* @param blankIndex index of the blank label in logits
|
|
||||||
*/
|
|
||||||
void ctcLoss(graph::Context& block, const NDArray &logitsInput, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex);
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
#endif // LIBND4J_ADDBIAS_H
|
|
|
@ -25,7 +25,7 @@
|
||||||
#include <execution/Threads.h>
|
#include <execution/Threads.h>
|
||||||
#include <execution/ThreadPool.h>
|
#include <execution/ThreadPool.h>
|
||||||
#include <helpers/LoopsCoordsHelper.h>
|
#include <helpers/LoopsCoordsHelper.h>
|
||||||
#include <ops/declarable/helpers/ctcLoss.h>
|
#include <ops/declarable/helpers/ctc.h>
|
||||||
|
|
||||||
namespace sd
|
namespace sd
|
||||||
{
|
{
|
||||||
|
|
|
@ -0,0 +1,718 @@
|
||||||
|
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2021 Deeplearning4j Contributors
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author AbdelRauf
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <limits>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
#include <numeric>
|
||||||
|
#include <cmath>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <execution/ThreadPool.h>
|
||||||
|
#include <helpers/LoopsCoordsHelper.h>
|
||||||
|
#include <ops/declarable/helpers/ctc.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct BeamProb
|
||||||
|
{
|
||||||
|
T total = negative_infinity<T>();
|
||||||
|
T non_blank = negative_infinity<T>();
|
||||||
|
T blank = negative_infinity<T>(); //log(1)
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T, typename T2 = void>
|
||||||
|
struct DefaultInvalid
|
||||||
|
{
|
||||||
|
static constexpr T value = T();
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct DefaultInvalid<T, typename std::enable_if<std::is_integral<T>::value>::type>
|
||||||
|
{
|
||||||
|
static constexpr T value = static_cast<T>(-1);
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct SequenceNode
|
||||||
|
{
|
||||||
|
//intrusive double links
|
||||||
|
SequenceNode<T>* prev = nullptr;
|
||||||
|
SequenceNode<T>* next = nullptr;
|
||||||
|
|
||||||
|
//sequence prefix/parent
|
||||||
|
SequenceNode<T>* prefix = nullptr;
|
||||||
|
|
||||||
|
T value = DefaultInvalid<T>::value;
|
||||||
|
|
||||||
|
int state = 0;
|
||||||
|
|
||||||
|
void markAsFullyExtended()
|
||||||
|
{
|
||||||
|
state |= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
void increaseRef()
|
||||||
|
{
|
||||||
|
//we will have just two copies in bad case. so just or
|
||||||
|
state = state | 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
void decreaseRef()
|
||||||
|
{
|
||||||
|
//we will have just two cases in bad case, so just remove that
|
||||||
|
state = state & (-2);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool safeToRemove()
|
||||||
|
{
|
||||||
|
|
||||||
|
if (state & 1) return false;
|
||||||
|
|
||||||
|
decreaseRef();
|
||||||
|
//we do not want to remove parent nodes in our case. otherwise just returning state<=1 is ok
|
||||||
|
return state == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isFullyExtended() const { return state & 1; }
|
||||||
|
};
|
||||||
|
|
||||||
|
/***
|
||||||
|
* Sequence container.
|
||||||
|
*
|
||||||
|
* NOTE: it is not thread-safe
|
||||||
|
*
|
||||||
|
* Extend path - O(1)
|
||||||
|
* Remove path - O(1)
|
||||||
|
* Generating Sequence with backtracking prefix: O(n)
|
||||||
|
*
|
||||||
|
* Note: Sequence container is implemented primitively and only usable within this task.
|
||||||
|
* As it does not behave as a fully capable tree. some cases should be handled manually
|
||||||
|
*
|
||||||
|
* Here is special cases that should be handled manually to exploit tree/graph behaviour:
|
||||||
|
*
|
||||||
|
* Extending new path value:
|
||||||
|
*
|
||||||
|
* To extend the path one need to give path and value and in return get new_path:
|
||||||
|
* new_path = container.extendPath ( path, new_value );
|
||||||
|
*
|
||||||
|
* Also note that:
|
||||||
|
* SequenceContainer has already default empty path as a beginning point for paths.
|
||||||
|
* So as an initial node one should use it.
|
||||||
|
* initial_path = container.getEmptyPath();
|
||||||
|
*
|
||||||
|
* Adding new path that could be already in container:
|
||||||
|
*
|
||||||
|
* Assume we have two paths that can overlap in next step
|
||||||
|
* 1st path: node#0() -> node#1(1) => generated sequence {},{1}
|
||||||
|
* 2nd path: node#0() -> node#1(1) -> node#2(2) => generated sequence {},{1}, {2}
|
||||||
|
*
|
||||||
|
* While extending the first path with value (2). it will be:
|
||||||
|
*
|
||||||
|
* node#0() -> node#0(1) -> node#( either new or old)(2) => generated sequence {},{1}, {2}
|
||||||
|
*
|
||||||
|
* For some tasks its not desired to have additional node that will generate the same sequence.
|
||||||
|
* For example:
|
||||||
|
* Assume you wanted to use it as sequence entry in map with just (entry->prefix, entry->value).
|
||||||
|
* so in that case having different paths is not correct and will not be unique in map.
|
||||||
|
*
|
||||||
|
* there is not direct way to handle that in our container other than searching.
|
||||||
|
* So one should look for the node with prefix node#1(1) and value(2) and return that node instead of adding new one
|
||||||
|
|
||||||
|
* Fortunately, for our beam search case:
|
||||||
|
*
|
||||||
|
* we need only look for such overlapped cases within the candidates list.
|
||||||
|
* which makes it easy to determine them beforehand while finding and marking overlapped cases. instead of looking for it in SequenceContainer
|
||||||
|
*
|
||||||
|
* Removing the same nodes multiple times:
|
||||||
|
* It is fast to remove nodes. As nodes can be stored externally One should follow this rule:
|
||||||
|
*
|
||||||
|
* One should not remove the same node twice as it will lead to double free. as Nodes are pointers the same applies to removing a copy
|
||||||
|
*
|
||||||
|
* There could be cases where you would like to store copy of nodes. in that cases you can use below method to be able to safely remove:
|
||||||
|
* node should have mutable method named safeToRemove().
|
||||||
|
* Basic implementation will be decreasing reference/copy counts and returning true if it is safe to delete
|
||||||
|
*
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
class SequenceContainer
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
SequenceContainer() : count_(1)
|
||||||
|
{
|
||||||
|
empty_path = new SequenceNode<T>();
|
||||||
|
current_ = empty_path;
|
||||||
|
}
|
||||||
|
|
||||||
|
SequenceContainer(const SequenceContainer& s) = delete;
|
||||||
|
|
||||||
|
SequenceContainer(SequenceContainer&& other) noexcept
|
||||||
|
{
|
||||||
|
this->current_ = other.current_;
|
||||||
|
other.current_ = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
SequenceContainer& operator=(const SequenceContainer& other) = delete;
|
||||||
|
|
||||||
|
SequenceContainer& operator=(SequenceContainer&& other) noexcept
|
||||||
|
{
|
||||||
|
if (this != other)
|
||||||
|
{
|
||||||
|
clear();
|
||||||
|
this->current_ = other.current_;
|
||||||
|
this->count_ = other.count_;
|
||||||
|
other.current_ = nullptr;
|
||||||
|
other.count_ = 0;
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
SequenceNode<T>* getEmptyPath()
|
||||||
|
{
|
||||||
|
return current_;
|
||||||
|
}
|
||||||
|
|
||||||
|
SequenceNode<T>* extendPath(SequenceNode<T>* prefix, T value)
|
||||||
|
{
|
||||||
|
auto new_node = new SequenceNode<T>();
|
||||||
|
|
||||||
|
new_node->value = value;
|
||||||
|
new_node->prefix = prefix;
|
||||||
|
//add in the holder
|
||||||
|
new_node->next = nullptr;
|
||||||
|
new_node->prev = current_;
|
||||||
|
/*std::cout << "add " << (long long)new_node << std::endl;
|
||||||
|
print_seq1(new_node);*/
|
||||||
|
if (current_) current_->next = new_node;
|
||||||
|
|
||||||
|
current_ = new_node;
|
||||||
|
count_++;
|
||||||
|
return new_node;
|
||||||
|
}
|
||||||
|
|
||||||
|
void remove(SequenceNode<T>* seq)
|
||||||
|
{
|
||||||
|
if (seq == nullptr) return;
|
||||||
|
|
||||||
|
if (!seq->safeToRemove()) return;
|
||||||
|
|
||||||
|
SequenceNode<T>* previous = seq->prev;
|
||||||
|
SequenceNode<T>* next = seq->next;
|
||||||
|
if (previous) previous->next = next;
|
||||||
|
if (next) next->prev = previous;
|
||||||
|
|
||||||
|
if (current_ == seq)
|
||||||
|
{
|
||||||
|
current_ = previous;
|
||||||
|
}
|
||||||
|
//std::cout << "remove " << (long long)seq << " " << std::endl;
|
||||||
|
//print_seq1(seq);
|
||||||
|
delete seq;
|
||||||
|
count_--;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::vector<T> getSequence(SequenceNode<T>* seq, size_t reserve_size = 1024)
|
||||||
|
{
|
||||||
|
std::vector<T> ret;
|
||||||
|
ret.reserve(reserve_size);
|
||||||
|
SequenceNode<T>* backtrack = seq;
|
||||||
|
while (backtrack)
|
||||||
|
{
|
||||||
|
ret.push_back(backtrack->value);
|
||||||
|
backtrack = backtrack->prefix;
|
||||||
|
}
|
||||||
|
if (ret.size() > 1)
|
||||||
|
{
|
||||||
|
//remove last default node
|
||||||
|
ret.pop_back();
|
||||||
|
//reverse
|
||||||
|
std::reverse(std::begin(ret), std::end(ret));
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
void clear()
|
||||||
|
{
|
||||||
|
//destruct all nodes
|
||||||
|
SequenceNode<T>* del = current_;
|
||||||
|
//int i = 0;
|
||||||
|
while (del)
|
||||||
|
{
|
||||||
|
//++i;
|
||||||
|
SequenceNode<T>* temp = del->prev;
|
||||||
|
delete del;
|
||||||
|
del = temp;
|
||||||
|
}
|
||||||
|
current_ = nullptr;
|
||||||
|
//assert(count_==i);
|
||||||
|
}
|
||||||
|
|
||||||
|
~SequenceContainer()
|
||||||
|
{
|
||||||
|
clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
SequenceNode<T>* current_ = nullptr;
|
||||||
|
|
||||||
|
SequenceNode<T>* empty_path = nullptr;
|
||||||
|
|
||||||
|
int count_ = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
struct BeamEntry
|
||||||
|
{
|
||||||
|
SequenceNode<U>* sequence{};
|
||||||
|
BeamProb<T> prob;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
struct BeamEntryEx
|
||||||
|
{
|
||||||
|
BeamEntry<T, U> entry;
|
||||||
|
//keep indices for lookUp
|
||||||
|
int index_as_child = -1;
|
||||||
|
int index_as_parent = -1;
|
||||||
|
int children_count = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
struct LookUpEntry
|
||||||
|
{
|
||||||
|
U last_c; //this is is the same as node->value. just we added for the speed
|
||||||
|
SequenceNode<U>* node = nullptr;
|
||||||
|
int next_beam_index = -1; //index inside next_beam array
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
bool compare_beam_prob(const BeamEntry<T, U>& i1, const BeamEntry<T, U>& i2)
|
||||||
|
{
|
||||||
|
return (i1.prob.total > i2.prob.total);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
T pr(const int c, const BeamProb<T>& beam_prob, const SequenceNode<U>* seq, const T prob)
|
||||||
|
{
|
||||||
|
return seq->value == c ? beam_prob.blank + prob : beam_prob.total + prob;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<bool HasElementStride = false, typename Type, typename IndexType>
|
||||||
|
void inner_beam_search(const Type* log_p, const uint64_t inc_p, IndexType* result_sequence, const uint64_t inc_res_seq,
|
||||||
|
const uint64_t max_len_t, Type* result_prob, IndexType* result_seq_length, uint64_t len_t,
|
||||||
|
const uint64_t len_c, const int blank_index, int beam_width, int nbest_len, bool normalize_logits, const uint64_t element_stride = 1L)
|
||||||
|
{
|
||||||
|
|
||||||
|
using BeamEntryType = BeamEntry<Type, IndexType>;
|
||||||
|
using BeamEntryTypeEx = BeamEntryEx<Type, IndexType>;
|
||||||
|
|
||||||
|
if (beam_width < 1) beam_width = 1;
|
||||||
|
if (nbest_len > beam_width) nbest_len = beam_width;
|
||||||
|
//if len_t is greater than max_len_t truncate it
|
||||||
|
len_t = len_t > max_len_t ? max_len_t : len_t;
|
||||||
|
|
||||||
|
SequenceContainer<IndexType> sequence_container;
|
||||||
|
BeamEntryType empty;
|
||||||
|
empty.prob.blank = 0;
|
||||||
|
empty.prob.total = log_sum_exp(empty.prob.blank, empty.prob.non_blank);
|
||||||
|
empty.sequence = sequence_container.getEmptyPath();
|
||||||
|
|
||||||
|
//vectors: we will use it as array, here
|
||||||
|
std::vector<BeamEntryTypeEx> last_beams;
|
||||||
|
std::vector<BeamEntryType> next_beams;
|
||||||
|
last_beams.resize(beam_width);
|
||||||
|
//as we skip blank indexes the count is beam_width * len_c
|
||||||
|
next_beams.resize(beam_width * len_c);
|
||||||
|
last_beams[0].entry = empty;
|
||||||
|
last_beams[0].index_as_child = -1;
|
||||||
|
last_beams[0].index_as_parent = -1;
|
||||||
|
last_beams[0].children_count = 0;
|
||||||
|
auto last_beam_size = 1;
|
||||||
|
|
||||||
|
// lookupContainer:
|
||||||
|
// it will keep sorted entries. so we will just move and compare the entry
|
||||||
|
// in each step there will be overlapped cases
|
||||||
|
// the size of overlapped cases in last_beam[0:beam_width]:
|
||||||
|
// as we have beam_width size in each step after sort and pruning
|
||||||
|
// there is at least one item who will not have any parent
|
||||||
|
// and for the rest (beam_width-1) it will check has_parent_in_container() ? 1 : 0
|
||||||
|
// so maximum size of overlapped pairs is beam_width-1
|
||||||
|
|
||||||
|
std::vector<LookUpEntry<Type, IndexType>> lookUp;
|
||||||
|
lookUp.resize(beam_width - 1);
|
||||||
|
|
||||||
|
//additional storage to sort overlapped case by classes
|
||||||
|
std::vector<std::pair<IndexType, int >> child_class_sorter_help;
|
||||||
|
child_class_sorter_help.resize(beam_width - 1);
|
||||||
|
Type norm_offset = 0;
|
||||||
|
|
||||||
|
for (uint64_t t = 0; t < len_t; t++)
|
||||||
|
{
|
||||||
|
auto next_beam_size = 0;
|
||||||
|
if (normalize_logits){
|
||||||
|
norm_offset = softmax_normalization_term<HasElementStride, Type, IndexType>(log_p, len_c, element_stride);
|
||||||
|
}
|
||||||
|
for (auto j = 0; j < last_beam_size; j++)
|
||||||
|
{
|
||||||
|
SequenceNode<IndexType>* seq = last_beams[j].entry.sequence;
|
||||||
|
auto& cur_prob = last_beams[j].entry.prob;
|
||||||
|
//if len(seq) > 0 then
|
||||||
|
const auto log_p_blank = element<HasElementStride>(log_p, blank_index, element_stride);
|
||||||
|
Type blank_prob, non_blank_prob;
|
||||||
|
//log_p[seq->value]
|
||||||
|
non_blank_prob = seq->value != -1 ? (element<HasElementStride>(log_p, seq->value, element_stride) + cur_prob.non_blank) : negative_infinity<Type>();
|
||||||
|
blank_prob = log_p_blank + cur_prob.total;
|
||||||
|
|
||||||
|
if (normalize_logits){
|
||||||
|
non_blank_prob = non_blank_prob - norm_offset;
|
||||||
|
blank_prob = blank_prob - norm_offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto look_up_beam_index = -1;
|
||||||
|
|
||||||
|
if (last_beams[j].index_as_child != -1)
|
||||||
|
{
|
||||||
|
//check entry
|
||||||
|
look_up_beam_index = lookUp[last_beams[j].index_as_child].next_beam_index;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (look_up_beam_index == -1)
|
||||||
|
{
|
||||||
|
BeamEntryType entry;
|
||||||
|
entry.sequence = seq;
|
||||||
|
entry.prob.blank = blank_prob;
|
||||||
|
entry.prob.non_blank = non_blank_prob;
|
||||||
|
entry.prob.total = log_sum_exp(blank_prob, non_blank_prob);
|
||||||
|
next_beams[next_beam_size] = entry;
|
||||||
|
//map if its overlapped one. in this case just being child is enough
|
||||||
|
if (last_beams[j].index_as_child != -1)
|
||||||
|
{
|
||||||
|
lookUp[last_beams[j].index_as_child].next_beam_index = next_beam_size;
|
||||||
|
}
|
||||||
|
++next_beam_size;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
//note: here we took as ref &
|
||||||
|
auto& entry_prob = next_beams[look_up_beam_index].prob;
|
||||||
|
entry_prob.blank = log_sum_exp(entry_prob.blank, blank_prob);
|
||||||
|
entry_prob.non_blank = log_sum_exp(entry_prob.non_blank, non_blank_prob);
|
||||||
|
entry_prob.total = log_sum_exp(entry_prob.blank, entry_prob.non_blank);
|
||||||
|
}
|
||||||
|
//check to see if it is overlapped parent
|
||||||
|
auto start_index = last_beams[j].index_as_parent;
|
||||||
|
auto end_index = last_beams[j].index_as_parent + last_beams[j].children_count;
|
||||||
|
|
||||||
|
for (int c = 0; c < len_c; c++)
|
||||||
|
{
|
||||||
|
if (c == blank_index) continue;
|
||||||
|
|
||||||
|
const auto prob = element<HasElementStride>(log_p, c, element_stride);//log_p[c];
|
||||||
|
|
||||||
|
non_blank_prob = pr(c, cur_prob, seq, prob);
|
||||||
|
if(normalize_logits) non_blank_prob = non_blank_prob - norm_offset;
|
||||||
|
//extend by new character
|
||||||
|
auto look_up_beam_index_ex = -1;
|
||||||
|
int found_index = -1;
|
||||||
|
|
||||||
|
//get index within array if its that class index
|
||||||
|
if (start_index < end_index && lookUp[start_index].last_c == c){
|
||||||
|
look_up_beam_index_ex = lookUp[start_index].next_beam_index;
|
||||||
|
|
||||||
|
found_index = start_index;
|
||||||
|
++start_index;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (look_up_beam_index_ex == -1)
|
||||||
|
{
|
||||||
|
BeamEntryType entry;
|
||||||
|
SequenceNode<IndexType>* extended_sequence;
|
||||||
|
if (found_index!=-1)
|
||||||
|
{
|
||||||
|
extended_sequence = lookUp[found_index].node;
|
||||||
|
//assing next_beam_index for lookup
|
||||||
|
lookUp[found_index].next_beam_index = next_beam_size;
|
||||||
|
extended_sequence->increaseRef();
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
extended_sequence = sequence_container.extendPath(seq, c);
|
||||||
|
}
|
||||||
|
entry.prob.non_blank = non_blank_prob;
|
||||||
|
entry.prob.total = non_blank_prob;
|
||||||
|
entry.sequence = extended_sequence;
|
||||||
|
next_beams[next_beam_size] = entry;
|
||||||
|
|
||||||
|
++next_beam_size;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
auto& entry_prob = next_beams[look_up_beam_index_ex].prob;
|
||||||
|
entry_prob.non_blank = log_sum_exp(entry_prob.non_blank, non_blank_prob);
|
||||||
|
entry_prob.total = log_sum_exp(entry_prob.total, non_blank_prob);
|
||||||
|
}
|
||||||
|
} //iteration over classes
|
||||||
|
|
||||||
|
//mark it as extended
|
||||||
|
seq->markAsFullyExtended();
|
||||||
|
|
||||||
|
} //iteration over beams
|
||||||
|
|
||||||
|
log_p += inc_p;
|
||||||
|
|
||||||
|
last_beam_size = std::min(next_beam_size, beam_width);
|
||||||
|
#if !defined(NTH_ELEMENT)
|
||||||
|
//sort next beams to get candidates
|
||||||
|
std::partial_sort(std::begin(next_beams),
|
||||||
|
std::begin(next_beams) + last_beam_size,
|
||||||
|
std::begin(next_beams) + next_beam_size, compare_beam_prob<Type, IndexType>);
|
||||||
|
|
||||||
|
#else
|
||||||
|
std::nth_element(std::begin(next_beams),
|
||||||
|
std::begin(next_beams) + last_beam_size,
|
||||||
|
std::begin(next_beams) + next_beam_size, compare_beam_prob<Type, IndexType>);
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
if (t < len_t)
|
||||||
|
{
|
||||||
|
//copy top beams
|
||||||
|
for (int j = 0; j < last_beam_size; j++)
|
||||||
|
{
|
||||||
|
last_beams[j].entry = next_beams[j];
|
||||||
|
last_beams[j].index_as_child = -1;
|
||||||
|
last_beams[j].index_as_parent = -1;
|
||||||
|
last_beams[j].children_count = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
//delete sequences from the sequence_holder to decrease memory
|
||||||
|
for (auto j = beam_width; j < next_beam_size; j++)
|
||||||
|
{
|
||||||
|
sequence_container.remove(next_beams[j].sequence);
|
||||||
|
}
|
||||||
|
|
||||||
|
//check overlapping cases and create lookUp with sorted classes as well
|
||||||
|
int look_up_index = 0;
|
||||||
|
for (auto j = 0; j < last_beam_size; j++)
|
||||||
|
{
|
||||||
|
//if it is not parent node then there is not any need to check
|
||||||
|
if (last_beams[j].entry.sequence->isFullyExtended())
|
||||||
|
{
|
||||||
|
auto parent_seq=last_beams[j].entry.sequence;
|
||||||
|
int children_count = 0;
|
||||||
|
for (int k = 0; k < last_beam_size; k++)
|
||||||
|
{
|
||||||
|
auto current = last_beams[k].entry.sequence;
|
||||||
|
if (current->prefix == parent_seq)
|
||||||
|
{
|
||||||
|
child_class_sorter_help[children_count].first = current->value;
|
||||||
|
child_class_sorter_help[children_count].second = k ;
|
||||||
|
++children_count ;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (children_count > 0)
|
||||||
|
{
|
||||||
|
|
||||||
|
//sort by class
|
||||||
|
if(children_count<2){
|
||||||
|
//
|
||||||
|
if (children_count > 1 && child_class_sorter_help[0].first > child_class_sorter_help[1].first)
|
||||||
|
{
|
||||||
|
std::swap(child_class_sorter_help[0], child_class_sorter_help[1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
std::sort(std::begin(child_class_sorter_help), std::begin(child_class_sorter_help) + children_count,
|
||||||
|
[](const std::pair<int, int>& left, const std::pair<int, int>& right) {
|
||||||
|
return left.first < right.first;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
last_beams[j].index_as_parent = look_up_index;
|
||||||
|
last_beams[j].children_count = children_count;
|
||||||
|
|
||||||
|
for (int l = 0; l < children_count; l++)
|
||||||
|
{
|
||||||
|
|
||||||
|
int c = child_class_sorter_help[l].first;
|
||||||
|
int k = child_class_sorter_help[l].second;
|
||||||
|
//std::cout << c <<" , " << k << std::endl;
|
||||||
|
last_beams[k].index_as_child = look_up_index;
|
||||||
|
auto seq = last_beams[k].entry.sequence;
|
||||||
|
lookUp[look_up_index].last_c = c;
|
||||||
|
lookUp[look_up_index].node = seq;
|
||||||
|
lookUp[look_up_index].next_beam_index = -1;
|
||||||
|
//next one
|
||||||
|
++look_up_index;
|
||||||
|
}
|
||||||
|
}//add sorted lookUps
|
||||||
|
|
||||||
|
}
|
||||||
|
} //overlap_direction identified to speed up lookUp
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}//iterate over t
|
||||||
|
#if defined(NTH_ELEMENT)
|
||||||
|
//use sort for n elements as only nth_element was used
|
||||||
|
std::sort(std::begin(next_beams), std::begin(next_beams) + last_beam_size, compare_beam_prob<Type, IndexType>);
|
||||||
|
#endif
|
||||||
|
//store nbest results
|
||||||
|
if (nbest_len <= last_beam_size) {
|
||||||
|
for (int j = 0; j < nbest_len; j++)
|
||||||
|
{
|
||||||
|
auto top = next_beams[j];
|
||||||
|
auto result_vector = SequenceContainer<IndexType>::getSequence(top.sequence, len_t);
|
||||||
|
const auto seq_size = result_vector.size();
|
||||||
|
|
||||||
|
result_prob[j] = top.prob.total;
|
||||||
|
result_seq_length[j] = seq_size;
|
||||||
|
//copy sequence
|
||||||
|
for (auto s = 0; s < seq_size; s++)
|
||||||
|
{
|
||||||
|
result_sequence[s] = result_vector[s];
|
||||||
|
}
|
||||||
|
|
||||||
|
result_sequence += inc_res_seq;
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
for (int j = 0; j < nbest_len; j++)
|
||||||
|
{
|
||||||
|
result_prob[j] = negative_infinity<Type>();
|
||||||
|
result_seq_length[j] = 0;;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename Type, typename IndexType = int>
|
||||||
|
void
|
||||||
|
beamSearch_(const NDArray& logit, const NDArray& sequence_length, NDArray& result_sequences, NDArray& result_probs, NDArray& result_sequences_length, int blank_index, int beam_width, int nbest_len, bool normalize_logits )
|
||||||
|
{
|
||||||
|
|
||||||
|
const auto shapes = logit.shapeOf();
|
||||||
|
const auto strides = logit.stridesOf();
|
||||||
|
const auto rank = logit.rankOf();
|
||||||
|
|
||||||
|
const IndexType* len_t_ptr = nullptr;
|
||||||
|
uint64_t element_stride_t = 1;
|
||||||
|
|
||||||
|
//checks before
|
||||||
|
if (rank < 2) return;
|
||||||
|
auto batch_len = rank > 2 ? shapes[0] : 1;
|
||||||
|
auto max_len_t = shapes[rank - 2];
|
||||||
|
auto len_c = shapes[rank - 1];
|
||||||
|
|
||||||
|
if (len_c < 1 || max_len_t < 1) return;
|
||||||
|
//defaulting blankIndex to the last class if its incorrect or -1
|
||||||
|
if (blank_index > len_c || blank_index < 0) blank_index = static_cast<int>(len_c) - 1;
|
||||||
|
if (sequence_length.rankOf() == 1 && sequence_length.shapeOf()[0] == batch_len)
|
||||||
|
{
|
||||||
|
len_t_ptr = sequence_length.bufferAsT<IndexType>();
|
||||||
|
element_stride_t = sequence_length.stridesOf()[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
//strides
|
||||||
|
auto batch_stride = rank > 2 ? strides[0] : 0;
|
||||||
|
auto inc_p = strides[rank - 2];
|
||||||
|
auto element_stride = logit.stridesOf()[rank - 1];
|
||||||
|
|
||||||
|
auto logits_ptr = logit.bufferAsT<Type>();
|
||||||
|
|
||||||
|
#if defined(ASSERT_INNER)
|
||||||
|
//result_probs should be [batch_len, nbest_len]
|
||||||
|
assert(result_probs.ews() == 1 && result_probs.rankOf() == 2 && result_probs.shapeOf()[0] == batch_len && result_probs.shapeOf()[1] == nbest_len);
|
||||||
|
//result sequence should be [batch_len, nbest_len, max_len_t]
|
||||||
|
assert(result_sequences.ews() == 1 && result_sequences.rankOf() == 3 && result_sequences.shapeOf()[0] == batch_len && result_sequences.shapeOf()[1] == nbest_len
|
||||||
|
&& result_sequences.shapeOf()[2] == max_len_t);
|
||||||
|
#endif
|
||||||
|
auto result_seq_ptr = result_sequences.bufferAsT<IndexType>();
|
||||||
|
auto result_probs_ptr = result_probs.bufferAsT<Type>();
|
||||||
|
auto result_seq_length_ptr = result_sequences_length.bufferAsT<IndexType>();
|
||||||
|
|
||||||
|
const auto batch_stride_res = result_sequences.stridesOf()[0];
|
||||||
|
const auto inc_res = result_sequences.stridesOf()[1];
|
||||||
|
const auto batch_stride_res_prob = result_probs.stridesOf()[0];
|
||||||
|
const auto batch_stride_res_seq_length = result_sequences_length.stridesOf()[0];
|
||||||
|
auto func = [max_len_t, len_c, batch_stride, inc_p, element_stride, element_stride_t, logits_ptr, len_t_ptr, blank_index, beam_width, normalize_logits,
|
||||||
|
nbest_len, result_seq_ptr, result_seq_length_ptr, result_probs_ptr, batch_stride_res, inc_res, batch_stride_res_prob, batch_stride_res_seq_length]
|
||||||
|
(uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void
|
||||||
|
{
|
||||||
|
|
||||||
|
auto ptr = logits_ptr + start * batch_stride;
|
||||||
|
|
||||||
|
if (element_stride == 1)
|
||||||
|
{
|
||||||
|
//choose ews one
|
||||||
|
for (auto b = start; b < stop; b += increment)
|
||||||
|
{
|
||||||
|
auto prob_ptr = &(result_probs_ptr[b * batch_stride_res_prob]);
|
||||||
|
auto seq_length_ptr = &(result_seq_length_ptr[b * batch_stride_res_seq_length]);
|
||||||
|
auto seq_ptr = &(result_seq_ptr[b * batch_stride_res]);
|
||||||
|
|
||||||
|
auto len_t = len_t_ptr ? len_t_ptr[b * element_stride_t] : max_len_t;
|
||||||
|
inner_beam_search<false, Type, IndexType>(ptr, inc_p, seq_ptr, inc_res, max_len_t, prob_ptr, seq_length_ptr, len_t, len_c, blank_index, beam_width, nbest_len, normalize_logits);
|
||||||
|
|
||||||
|
ptr += batch_stride;
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// element with stride case
|
||||||
|
for (auto b = start; b < stop; b += increment)
|
||||||
|
{
|
||||||
|
auto prob_ptr = &(result_probs_ptr[b * batch_stride_res_prob]);
|
||||||
|
auto seq_length_ptr = &(result_seq_length_ptr[b * batch_stride_res_seq_length]);
|
||||||
|
auto seq_ptr = &(result_seq_ptr[b * batch_stride_res]);
|
||||||
|
|
||||||
|
auto len_t = len_t_ptr ? len_t_ptr[b * element_stride_t] : max_len_t;
|
||||||
|
inner_beam_search<false, Type, IndexType>(ptr, inc_p, seq_ptr, inc_res, max_len_t, prob_ptr, seq_length_ptr, len_t, len_c, blank_index, beam_width, nbest_len, normalize_logits, element_stride);
|
||||||
|
|
||||||
|
ptr += batch_stride;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
samediff::Threads::parallel_for(func, 0, batch_len, 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void beamSearch(const NDArray& logit, const NDArray& sequence_length, NDArray& result_sequences, NDArray& result_probs, NDArray& result_sequences_length, int blank_index, int beam_width , int nbest_len, bool normalize_logits = true){
|
||||||
|
|
||||||
|
BUILD_DOUBLE_SELECTOR(logit.dataType(), result_sequences.dataType(), beamSearch_, (logit, sequence_length, result_sequences, result_probs, result_sequences_length, blank_index, beam_width , nbest_len, normalize_logits), FLOAT_TYPES, INDEXING_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void beamSearch_, (const NDArray& logit, const NDArray& sequence_length, NDArray& result_sequences, NDArray& result_probs, NDArray& result_sequences_length, int blank_index, int beam_width , int nbest_len, bool normalize_logits), FLOAT_TYPES, INDEXING_TYPES);
|
||||||
|
|
||||||
|
}}}
|
|
@ -0,0 +1,31 @@
|
||||||
|
cmake_minimum_required(VERSION 3.15)
|
||||||
|
project(tests_cpu)
|
||||||
|
|
||||||
|
# Download and unpack googletest at configure time
|
||||||
|
configure_file(CMakeLists.txt.in googletest-download/CMakeLists.txt)
|
||||||
|
execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" .
|
||||||
|
RESULT_VARIABLE result
|
||||||
|
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/googletest-download )
|
||||||
|
if(result)
|
||||||
|
message(FATAL_ERROR "CMake step for googletest failed: ${result}")
|
||||||
|
endif()
|
||||||
|
execute_process(COMMAND ${CMAKE_COMMAND} --build .
|
||||||
|
RESULT_VARIABLE result
|
||||||
|
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/googletest-download )
|
||||||
|
if(result)
|
||||||
|
message(FATAL_ERROR "Build step for googletest failed: ${result}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Prevent overriding the parent project's compiler/linker
|
||||||
|
# settings on Windows
|
||||||
|
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
|
||||||
|
|
||||||
|
# Add googletest directly to our build. This defines
|
||||||
|
# the gtest and gtest_main targets.
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/googletest-src
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}/googletest-build
|
||||||
|
EXCLUDE_FROM_ALL)
|
||||||
|
|
||||||
|
set(gtest_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/googletest-src)
|
||||||
|
#add_subdirectory(libnd4j_tests)
|
||||||
|
add_subdirectory(layers_tests)
|
|
@ -0,0 +1,16 @@
|
||||||
|
cmake_minimum_required(VERSION 2.8.2)
|
||||||
|
|
||||||
|
project(googletest-download NONE)
|
||||||
|
|
||||||
|
include(ExternalProject)
|
||||||
|
ExternalProject_Add(googletest
|
||||||
|
GIT_REPOSITORY https://github.com/google/googletest.git
|
||||||
|
GIT_TAG release-1.10.0
|
||||||
|
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/googletest-src"
|
||||||
|
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/googletest-build"
|
||||||
|
CMAKE_ARGS ""
|
||||||
|
CONFIGURE_COMMAND ""
|
||||||
|
BUILD_COMMAND ""
|
||||||
|
INSTALL_COMMAND ""
|
||||||
|
TEST_COMMAND ""
|
||||||
|
)
|
|
@ -0,0 +1,47 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver119 on 04.08.17.
|
||||||
|
//
|
||||||
|
//
|
||||||
|
#include "testlayers.h"
|
||||||
|
/*
|
||||||
|
#include "DenseLayerTests.cpp"
|
||||||
|
#include "NDArrayTests.cpp"
|
||||||
|
#include "VariableSpaceTests.cpp"
|
||||||
|
#include "VariableTests.cpp"
|
||||||
|
#include "DeclarableOpsTests.cpp"
|
||||||
|
#include "HashUtilsTests.cpp"
|
||||||
|
#include "WorkspaceTests.cpp"
|
||||||
|
#include "ConvolutionTests.cpp"
|
||||||
|
#include "TadTests.cpp"
|
||||||
|
#include "StashTests.cpp"
|
||||||
|
#include "SessionLocalTests.cpp"
|
||||||
|
#include "GraphTests.cpp"
|
||||||
|
#include "FlatBuffersTests.cpp"
|
||||||
|
*/
|
||||||
|
///////
|
||||||
|
|
||||||
|
//#include "CyclicTests.h"
|
||||||
|
// #include "ProtoBufTests.cpp"
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
testing::InitGoogleTest(&argc, argv);
|
||||||
|
return RUN_ALL_TESTS();
|
||||||
|
}
|
|
@ -0,0 +1,111 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver119 on 13.01.2018.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <array/ArrayOptions.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayOptionsTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
Nd4jLong shape[8] = {2, 5, 5, 5, 1, 0, 1, 99};
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(ArrayOptionsTests, TestShape_Basic_0) {
|
||||||
|
shape[5] = 1;
|
||||||
|
|
||||||
|
ASSERT_TRUE(ArrayOptions::isNewFormat(shape));
|
||||||
|
ASSERT_FALSE(ArrayOptions::isSparseArray(shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(ArrayOptionsTests, TestShape_Basic_1) {
|
||||||
|
shape[5] = 2;
|
||||||
|
|
||||||
|
|
||||||
|
ASSERT_TRUE(ArrayOptions::isNewFormat(shape));
|
||||||
|
ASSERT_TRUE(ArrayOptions::isSparseArray(shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ArrayOptionsTests, TestShape_Basic_2) {
|
||||||
|
shape[5] = 258;
|
||||||
|
|
||||||
|
ASSERT_TRUE(ArrayOptions::isNewFormat(shape));
|
||||||
|
|
||||||
|
ASSERT_TRUE(ArrayOptions::isSparseArray(shape));
|
||||||
|
ASSERT_EQ(SpaceType::CONTINUOUS, ArrayOptions::spaceType(shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ArrayOptionsTests, TestShape_Basic_3) {
|
||||||
|
ASSERT_EQ(0, shape::extra(shape));
|
||||||
|
|
||||||
|
ASSERT_EQ(SpaceType::CONTINUOUS, ArrayOptions::spaceType(shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ArrayOptionsTests, TestShape_Basic_4) {
|
||||||
|
|
||||||
|
ArrayOptions::setPropertyBits(shape, {ARRAY_HALF, ARRAY_QUANTIZED});
|
||||||
|
|
||||||
|
auto dtype = ArrayOptions::dataType(shape);
|
||||||
|
|
||||||
|
ASSERT_FALSE(ArrayOptions::isSparseArray(shape));
|
||||||
|
ASSERT_TRUE(sd::DataType::HALF == ArrayOptions::dataType(shape));
|
||||||
|
ASSERT_EQ(sd::ArrayType::DENSE, ArrayOptions::arrayType(shape));
|
||||||
|
ASSERT_EQ(sd::SpaceType::QUANTIZED, ArrayOptions::spaceType(shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ArrayOptionsTests, TestShape_Basic_5) {
|
||||||
|
ArrayOptions::setPropertyBits(shape, {ARRAY_SPARSE, ARRAY_INT, ARRAY_CSC});
|
||||||
|
|
||||||
|
ASSERT_TRUE(ArrayOptions::isSparseArray(shape));
|
||||||
|
ASSERT_TRUE(sd::DataType::INT32 == ArrayOptions::dataType(shape));
|
||||||
|
ASSERT_EQ(sd::SparseType::CSC, ArrayOptions::sparseType(shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ArrayOptionsTests, TestShape_Basic_6) {
|
||||||
|
ArrayOptions::setPropertyBits(shape, {ARRAY_EMPTY, ARRAY_INT, ARRAY_CSC});
|
||||||
|
|
||||||
|
ASSERT_EQ(sd::ArrayType::EMPTY, ArrayOptions::arrayType(shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ArrayOptionsTests, TestShape_Basic_7) {
|
||||||
|
ArrayOptions::setDataType(shape, sd::DataType::FLOAT32);
|
||||||
|
ArrayOptions::setDataType(shape, sd::DataType::FLOAT32);
|
||||||
|
|
||||||
|
ASSERT_EQ(sd::DataType::FLOAT32, ArrayOptions::dataType(shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ArrayOptionsTests, TestShape_Basic_8) {
|
||||||
|
ArrayOptions::setDataType(shape, sd::DataType::DOUBLE);
|
||||||
|
ArrayOptions::setDataType(shape, sd::DataType::FLOAT32);
|
||||||
|
|
||||||
|
ASSERT_EQ(sd::DataType::FLOAT32, ArrayOptions::dataType(shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ArrayOptionsTests, TestShape_Basic_9) {
|
||||||
|
ArrayOptions::setDataType(shape, sd::DataType::FLOAT32);
|
||||||
|
ArrayOptions::setDataType(shape, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
ASSERT_EQ(sd::DataType::DOUBLE, ArrayOptions::dataType(shape));
|
||||||
|
}
|
|
@ -0,0 +1,243 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <ops/ops.h>
|
||||||
|
#include <helpers/GradCheck.h>
|
||||||
|
#include <helpers/RandomLauncher.h>
|
||||||
|
#include <exceptions/cuda_exception.h>
|
||||||
|
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
|
||||||
|
|
||||||
|
class AtomicTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
AtomicTests() {
|
||||||
|
//
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static _CUDA_G void multiplyKernel(void *vbuffer, uint64_t length, void *vresult) {
|
||||||
|
auto buffer = reinterpret_cast<T*>(vbuffer);
|
||||||
|
auto result = reinterpret_cast<T*>(vresult);
|
||||||
|
|
||||||
|
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
for (auto e = tid; e < length; e += gridDim.x * blockDim.x) {
|
||||||
|
auto rem = e % 4;
|
||||||
|
auto i = (e - rem) / 4;
|
||||||
|
|
||||||
|
sd::math::atomics::nd4j_atomicMul<T>(&result[i], buffer[e]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static void multiplyLauncher(void *vbuffer, uint64_t length, void *vresult) {
|
||||||
|
multiplyKernel<T><<<256, 256, 1024, *sd::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult);
|
||||||
|
auto err = cudaStreamSynchronize(*sd::LaunchContext::defaultContext()->getCudaStream());
|
||||||
|
if (err != 0)
|
||||||
|
throw sd::cuda_exception::build("multiply failed", err);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static _CUDA_G void sumKernel(void *vbuffer, uint64_t length, void *vresult) {
|
||||||
|
auto buffer = reinterpret_cast<T*>(vbuffer);
|
||||||
|
auto result = reinterpret_cast<T*>(vresult);
|
||||||
|
|
||||||
|
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
for (auto e = tid; e < length; e += gridDim.x * blockDim.x) {
|
||||||
|
auto rem = e % 4;
|
||||||
|
auto i = (e - rem) / 4;
|
||||||
|
|
||||||
|
sd::math::atomics::nd4j_atomicAdd<T>(&result[i], buffer[e]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static void sumLauncher(void *vbuffer, uint64_t length, void *vresult) {
|
||||||
|
sumKernel<T><<<256, 256, 1024, *sd::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult);
|
||||||
|
auto err = cudaStreamSynchronize(*sd::LaunchContext::defaultContext()->getCudaStream());
|
||||||
|
if (err != 0)
|
||||||
|
throw sd::cuda_exception::build("sum failed", err);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static _CUDA_G void subKernel(void *vbuffer, uint64_t length, void *vresult) {
|
||||||
|
auto buffer = reinterpret_cast<T*>(vbuffer);
|
||||||
|
auto result = reinterpret_cast<T*>(vresult);
|
||||||
|
|
||||||
|
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
for (auto e = tid; e < length; e += gridDim.x * blockDim.x) {
|
||||||
|
auto rem = e % 4;
|
||||||
|
auto i = (e - rem) / 4;
|
||||||
|
|
||||||
|
sd::math::atomics::nd4j_atomicSub<T>(&result[i], buffer[e]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static void subLauncher(void *vbuffer, uint64_t length, void *vresult) {
|
||||||
|
subKernel<T><<<256, 256, 1024, *sd::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult);
|
||||||
|
auto err = cudaStreamSynchronize(*sd::LaunchContext::defaultContext()->getCudaStream());
|
||||||
|
if (err != 0)
|
||||||
|
throw sd::cuda_exception::build("sub failed", err);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static _CUDA_G void divKernel(void *vbuffer, uint64_t length, void *vresult) {
|
||||||
|
auto buffer = reinterpret_cast<T*>(vbuffer);
|
||||||
|
auto result = reinterpret_cast<T*>(vresult);
|
||||||
|
|
||||||
|
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
for (auto e = tid; e < length; e += gridDim.x * blockDim.x) {
|
||||||
|
auto rem = e % 4;
|
||||||
|
auto i = (e - rem) / 4;
|
||||||
|
|
||||||
|
sd::math::atomics::nd4j_atomicDiv<T>(&result[i], buffer[e]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static void divLauncher(void *vbuffer, uint64_t length, void *vresult) {
|
||||||
|
divKernel<T><<<256, 256, 1024, *sd::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult);
|
||||||
|
auto err = cudaStreamSynchronize(*sd::LaunchContext::defaultContext()->getCudaStream());
|
||||||
|
if (err != 0)
|
||||||
|
throw sd::cuda_exception::build("div failed", err);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void multiplyHost(NDArray &input, NDArray &output) {
|
||||||
|
BUILD_SINGLE_SELECTOR(input.dataType(), multiplyLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), NUMERIC_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void sumHost(NDArray &input, NDArray &output) {
|
||||||
|
BUILD_SINGLE_SELECTOR(input.dataType(), sumLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), NUMERIC_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void subHost(NDArray &input, NDArray &output) {
|
||||||
|
BUILD_SINGLE_SELECTOR(input.dataType(), subLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void divHost(NDArray &input, NDArray &output) {
|
||||||
|
BUILD_SINGLE_SELECTOR(input.dataType(), divLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AtomicTests, test_multiply) {
|
||||||
|
std::vector<sd::DataType> dtypes = {sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::INT16, sd::DataType::HALF};
|
||||||
|
|
||||||
|
for (auto t:dtypes) {
|
||||||
|
nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str());
|
||||||
|
NDArray input('c', {4, 25}, t);
|
||||||
|
NDArray output('c', {input.lengthOf() / 4}, t);
|
||||||
|
NDArray exp = output.ulike();
|
||||||
|
|
||||||
|
input.assign(2);
|
||||||
|
output.assign(2);
|
||||||
|
exp.assign(32);
|
||||||
|
|
||||||
|
multiplyHost(input, output);
|
||||||
|
ASSERT_EQ(exp, output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AtomicTests, test_multiply_2) {
|
||||||
|
std::vector<sd::DataType> dtypes = {sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::HALF, sd::DataType::BFLOAT16};
|
||||||
|
|
||||||
|
for (auto t:dtypes) {
|
||||||
|
nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str());
|
||||||
|
NDArray input('c', {4, 25}, t);
|
||||||
|
NDArray output('c', {input.lengthOf() / 4}, t);
|
||||||
|
NDArray exp = output.ulike();
|
||||||
|
|
||||||
|
input.assign(1.5);
|
||||||
|
output.assign(2);
|
||||||
|
exp.assign(10.125);
|
||||||
|
|
||||||
|
multiplyHost(input, output);
|
||||||
|
// output.printBuffer("multiply 2");
|
||||||
|
ASSERT_EQ(exp, output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AtomicTests, test_sum) {
|
||||||
|
std::vector<sd::DataType> dtypes = {sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::BFLOAT16, sd::DataType::HALF, sd::DataType::INT16};
|
||||||
|
|
||||||
|
for (auto t:dtypes) {
|
||||||
|
nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str());
|
||||||
|
NDArray input('c', {4, 25}, t);
|
||||||
|
NDArray output('c', {input.lengthOf() / 4}, t);
|
||||||
|
NDArray exp = output.ulike();
|
||||||
|
|
||||||
|
input.assign(1);
|
||||||
|
output.assign(1);
|
||||||
|
exp.assign(5);
|
||||||
|
|
||||||
|
sumHost(input, output);
|
||||||
|
// output.printIndexedBuffer("Sum");
|
||||||
|
ASSERT_EQ(exp, output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AtomicTests, test_sub) {
|
||||||
|
std::vector<sd::DataType> dtypes = {sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::HALF};
|
||||||
|
|
||||||
|
for (auto t:dtypes) {
|
||||||
|
nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str());
|
||||||
|
NDArray input('c', {4, 25}, t);
|
||||||
|
NDArray output('c', {input.lengthOf() / 4}, t);
|
||||||
|
NDArray exp = output.ulike();
|
||||||
|
|
||||||
|
input.assign(1);
|
||||||
|
output.assign(5);
|
||||||
|
exp.assign(1);
|
||||||
|
|
||||||
|
subHost(input, output);
|
||||||
|
// output.printBuffer("Sub");
|
||||||
|
|
||||||
|
ASSERT_EQ(exp, output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AtomicTests, test_div) {
|
||||||
|
std::vector<sd::DataType> dtypes = {sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::BFLOAT16, sd::DataType::HALF};
|
||||||
|
|
||||||
|
for (auto t:dtypes) {
|
||||||
|
nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str());
|
||||||
|
NDArray input('c', {4, 25}, t);
|
||||||
|
NDArray output('c', {input.lengthOf() / 4}, t);
|
||||||
|
NDArray exp = output.ulike();
|
||||||
|
|
||||||
|
input.assign(2);
|
||||||
|
output.assign(32);
|
||||||
|
exp.assign(2);
|
||||||
|
|
||||||
|
divHost(input, output);
|
||||||
|
// output.printBuffer("Div");
|
||||||
|
ASSERT_EQ(exp, output);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,222 @@
|
||||||
|
/*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <ops/ops.h>
|
||||||
|
#include <helpers/GradCheck.h>
|
||||||
|
#include <helpers/RandomLauncher.h>
|
||||||
|
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
AttentionTests() {
|
||||||
|
printf("\n");
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(AttentionTests, basic_dot_product_attention) {
|
||||||
|
auto keys = NDArrayFactory::create<float>('c', {10, 4, 3});
|
||||||
|
auto values = NDArrayFactory::create<float>('c', {10, 4, 3});
|
||||||
|
auto queries = NDArrayFactory::create<float>('c', {10, 4, 1});
|
||||||
|
|
||||||
|
sd::ops::dot_product_attention op;
|
||||||
|
auto result = op.evaluate({&queries, &keys, &values}, {1, 0});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
//Ignored: AB 2019/05/21 - Segmentation fault on on linux-ppc64le-cpu - https://github.com/deeplearning4j/deeplearning4j/issues/7657
|
||||||
|
TEST_F(AttentionTests, basic_dot_product_attention_bp) {
|
||||||
|
auto keys = NDArrayFactory::create<float>('c', {10, 4, 3});
|
||||||
|
auto values = NDArrayFactory::create<float>('c', {10, 4, 3});
|
||||||
|
auto queries = NDArrayFactory::create<float>('c', {10, 4, 1});
|
||||||
|
auto eps = NDArrayFactory::create<float>('c', {10, 4, 1});
|
||||||
|
|
||||||
|
sd::ops::dot_product_attention_bp op;
|
||||||
|
auto result = op.execute({&queries, &keys, &values, &eps}, {}, {1, 0}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
TEST_F(AttentionTests, basic_dot_product_attention_with_weights) {
|
||||||
|
auto keys = NDArrayFactory::create<float>('c', {10, 4, 3});
|
||||||
|
auto values = NDArrayFactory::create<float>('c', {10, 4, 3});
|
||||||
|
auto queries = NDArrayFactory::create<float>('c', {10, 4, 1});
|
||||||
|
|
||||||
|
sd::ops::dot_product_attention op;
|
||||||
|
auto result = op.evaluate({&queries, &keys, &values}, {1, 1});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AttentionTests, basic_dot_product_attention_with_mask) {
|
||||||
|
auto keys = NDArrayFactory::create<float>('c', {10, 4, 3});
|
||||||
|
auto values = NDArrayFactory::create<float>('c', {10, 4, 3});
|
||||||
|
auto queries = NDArrayFactory::create<float>('c', {10, 4, 1});
|
||||||
|
auto mask = NDArrayFactory::create<float>('c', {10, 3});
|
||||||
|
mask.assign(1.);
|
||||||
|
|
||||||
|
sd::ops::dot_product_attention op;
|
||||||
|
auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
//AB 2019/05/28 - Segfault on ppc64le - See issue #7657
|
||||||
|
TEST_F(AttentionTests, basic_dot_product_attention_bp_with_mask) {
|
||||||
|
auto keys = NDArrayFactory::create<float>('c', {10, 4, 3});
|
||||||
|
auto values = NDArrayFactory::create<float>('c', {10, 4, 3});
|
||||||
|
auto queries = NDArrayFactory::create<float>('c', {10, 4, 1});
|
||||||
|
auto eps = NDArrayFactory::create<float>('c', {10, 4, 1});
|
||||||
|
auto mask = NDArrayFactory::create<float>('c', {10, 3});
|
||||||
|
mask.assign(1.);
|
||||||
|
|
||||||
|
sd::ops::dot_product_attention_bp op;
|
||||||
|
auto result = op.execute({&queries, &keys, &values, &eps, &mask}, {}, {1, 0}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
TEST_F(AttentionTests, multi_head_input_dot_product_attention_with_mask) {
|
||||||
|
auto keys = NDArrayFactory::create<float>('c', {2, 5, 4, 3});
|
||||||
|
auto values = NDArrayFactory::create<float>('c', {2, 5, 4, 3});
|
||||||
|
auto queries = NDArrayFactory::create<float>('c', {2, 5, 4, 1});
|
||||||
|
auto mask = NDArrayFactory::create<float>('c', {2, 3});
|
||||||
|
mask.assign(1.);
|
||||||
|
|
||||||
|
sd::ops::dot_product_attention op;
|
||||||
|
auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
//AB 2019/05/30 - Segfault on ppc64le - See issue #7657
|
||||||
|
TEST_F(AttentionTests, multi_head_input_dot_product_attention_bp_with_mask) {
|
||||||
|
auto keys = NDArrayFactory::create<float>('c', {2, 5, 4, 3});
|
||||||
|
auto values = NDArrayFactory::create<float>('c', {2, 5, 4, 3});
|
||||||
|
auto queries = NDArrayFactory::create<float>('c', {2, 5, 4, 1});
|
||||||
|
auto eps = NDArrayFactory::create<float>('c', {2, 5, 4, 1});
|
||||||
|
auto mask = NDArrayFactory::create<float>('c', {2, 3});
|
||||||
|
mask.assign(1.);
|
||||||
|
|
||||||
|
sd::ops::dot_product_attention_bp op;
|
||||||
|
auto result = op.execute({&queries, &keys, &values, &eps, &mask}, {}, {1, 0}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(AttentionTests, basic_multi_head_dot_product_attention) {
|
||||||
|
auto keys = NDArrayFactory::create<float>('c', {10, 4, 5});
|
||||||
|
auto values = NDArrayFactory::create<float>('c', {10, 4, 5});
|
||||||
|
auto queries = NDArrayFactory::create<float>('c', {10, 4, 2});
|
||||||
|
|
||||||
|
auto Wk = NDArrayFactory::create<float>('c', {2, 3, 4});
|
||||||
|
auto Wv = NDArrayFactory::create<float>('c', {2, 3, 4});
|
||||||
|
auto Wq = NDArrayFactory::create<float>('c', {2, 3, 4});
|
||||||
|
auto Wo = NDArrayFactory::create<float>('c', {2* 3, 4});
|
||||||
|
|
||||||
|
sd::ops::multi_head_dot_product_attention op;
|
||||||
|
auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo}, {1, 0});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
//AB 2019/05/30 - Other attention BP tests are segfaulting on ppc64le - disabling this pre-emptively - See issue #7657
|
||||||
|
TEST_F(AttentionTests, basic_multi_head_dot_product_bp_attention) {
|
||||||
|
auto keys = NDArrayFactory::create<float>('c', {10, 4, 5});
|
||||||
|
auto values = NDArrayFactory::create<float>('c', {10, 4, 5});
|
||||||
|
auto queries = NDArrayFactory::create<float>('c', {10, 4, 2});
|
||||||
|
|
||||||
|
auto Wk = NDArrayFactory::create<float>('c', {2, 3, 4});
|
||||||
|
auto Wv = NDArrayFactory::create<float>('c', {2, 3, 4});
|
||||||
|
auto Wq = NDArrayFactory::create<float>('c', {2, 3, 4});
|
||||||
|
auto Wo = NDArrayFactory::create<float>('c', {2* 3, 7});
|
||||||
|
|
||||||
|
auto eps = NDArrayFactory::create<float>('c', {10, 7, 2});
|
||||||
|
|
||||||
|
|
||||||
|
sd::ops::multi_head_dot_product_attention_bp op;
|
||||||
|
auto result = op.execute({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &eps}, {}, {1, 0}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
TEST_F(AttentionTests, basic_multi_head_dot_product_attention_with_mask) {
|
||||||
|
auto keys = NDArrayFactory::create<float>('c', {10, 4, 5});
|
||||||
|
auto values = NDArrayFactory::create<float>('c', {10, 4, 5});
|
||||||
|
auto queries = NDArrayFactory::create<float>('c', {10, 4, 2});
|
||||||
|
|
||||||
|
auto Wk = NDArrayFactory::create<float>('c', {2, 3, 4});
|
||||||
|
auto Wv = NDArrayFactory::create<float>('c', {2, 3, 4});
|
||||||
|
auto Wq = NDArrayFactory::create<float>('c', {2, 3, 4});
|
||||||
|
auto Wo = NDArrayFactory::create<float>('c', {2* 3, 4});
|
||||||
|
|
||||||
|
auto mask = NDArrayFactory::create<float>('c', {10, 5});
|
||||||
|
mask.assign(1.);
|
||||||
|
|
||||||
|
|
||||||
|
sd::ops::multi_head_dot_product_attention op;
|
||||||
|
auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &mask}, {1, 0});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
//AB 2019/05/30 - Other attention BP tests are segfaulting on ppc64le - disabling this pre-emptively - See issue #7657
|
||||||
|
TEST_F(AttentionTests, basic_multi_head_dot_product_bp_attention_with_mask) {
|
||||||
|
auto keys = NDArrayFactory::create<float>('c', {10, 4, 5});
|
||||||
|
auto values = NDArrayFactory::create<float>('c', {10, 4, 5});
|
||||||
|
auto queries = NDArrayFactory::create<float>('c', {10, 4, 2});
|
||||||
|
|
||||||
|
auto Wk = NDArrayFactory::create<float>('c', {2, 3, 4});
|
||||||
|
auto Wv = NDArrayFactory::create<float>('c', {2, 3, 4});
|
||||||
|
auto Wq = NDArrayFactory::create<float>('c', {2, 3, 4});
|
||||||
|
auto Wo = NDArrayFactory::create<float>('c', {2* 3, 7});
|
||||||
|
|
||||||
|
auto eps = NDArrayFactory::create<float>('c', {10, 7, 2});
|
||||||
|
|
||||||
|
auto mask = NDArrayFactory::create<float>('c', {10, 5});
|
||||||
|
mask.assign(1.);
|
||||||
|
|
||||||
|
|
||||||
|
sd::ops::multi_head_dot_product_attention_bp op;
|
||||||
|
auto result = op.execute({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &eps, &mask}, {}, {1, 0}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
*/
|
|
@ -0,0 +1,51 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver119 on 13.01.2018.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::ops;
|
||||||
|
using namespace sd::graph;
|
||||||
|
|
||||||
|
class BackpropTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(BackpropTests, Test_Add_1) {
|
||||||
|
|
||||||
|
NDArray x('c', {2, 3, 4}, sd::DataType::FLOAT32);
|
||||||
|
NDArray y('c', {3, 4}, sd::DataType::FLOAT32);
|
||||||
|
NDArray e('c', {2, 3, 4}, sd::DataType::FLOAT32);
|
||||||
|
|
||||||
|
sd::ops::add_bp op;
|
||||||
|
auto result = op.evaluate({&x, &y, &e});
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto eps = result.at(0);
|
||||||
|
auto grad = result.at(1);
|
||||||
|
|
||||||
|
ASSERT_TRUE(x.isSameShape(eps));
|
||||||
|
ASSERT_TRUE(y.isSameShape(grad));
|
||||||
|
}
|
|
@ -0,0 +1,78 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver119 on 10.11.2017.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <legacy/NativeOps.h>
|
||||||
|
#include <helpers/BitwiseUtils.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::graph;
|
||||||
|
|
||||||
|
class BitwiseUtilsTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
// oviously, this test will fail on big-endian machines, but who cares
|
||||||
|
TEST_F(BitwiseUtilsTests, Test_Runtime_Endianess_1) {
|
||||||
|
bool isBE = BitwiseUtils::isBE();
|
||||||
|
|
||||||
|
ASSERT_FALSE(isBE);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BitwiseUtilsTests, Test_ValueBit_1) {
|
||||||
|
int idx = BitwiseUtils::valueBit(1);
|
||||||
|
|
||||||
|
ASSERT_EQ(0, idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BitwiseUtilsTests, Test_ValueBit_2) {
|
||||||
|
int idx = BitwiseUtils::valueBit(2);
|
||||||
|
|
||||||
|
ASSERT_EQ(1, idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BitwiseUtilsTests, Test_ValueBits_1) {
|
||||||
|
std::vector<int> expected({1, 1});
|
||||||
|
while (expected.size() < 32)
|
||||||
|
expected.push_back(0);
|
||||||
|
|
||||||
|
std::vector<int> result = BitwiseUtils::valueBits(3);
|
||||||
|
|
||||||
|
ASSERT_EQ(32, result.size());
|
||||||
|
ASSERT_EQ(expected, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BitwiseUtilsTests, Test_ValueBits_2) {
|
||||||
|
int value = 48;
|
||||||
|
int flipped = BitwiseUtils::flip_bits(value);
|
||||||
|
|
||||||
|
ASSERT_NE(value, flipped);
|
||||||
|
|
||||||
|
auto o = BitwiseUtils::valueBits(value);
|
||||||
|
auto f = BitwiseUtils::valueBits(flipped);
|
||||||
|
|
||||||
|
for (int e = 0; e < o.size(); e++)
|
||||||
|
ASSERT_NE(o.at(e), f.at(e));
|
||||||
|
}
|
|
@ -0,0 +1,150 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver119 on 13.10.2017.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::ops;
|
||||||
|
using namespace sd::graph;
|
||||||
|
|
||||||
|
class BooleanOpsTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(BooleanOpsTests, LtTest_1) {
|
||||||
|
auto x = NDArrayFactory::create_(1.0f);
|
||||||
|
auto y = NDArrayFactory::create_(2.0f);
|
||||||
|
|
||||||
|
sd::ops::lt_scalar op;
|
||||||
|
|
||||||
|
|
||||||
|
ASSERT_TRUE(op.verify({x, y}));
|
||||||
|
|
||||||
|
delete x;
|
||||||
|
delete y;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BooleanOpsTests, LtTest_2) {
|
||||||
|
auto x = NDArrayFactory::create_(2.0f);
|
||||||
|
auto y = NDArrayFactory::create_(1.0f);
|
||||||
|
|
||||||
|
sd::ops::lt_scalar op;
|
||||||
|
|
||||||
|
|
||||||
|
ASSERT_FALSE(op.verify({x, y}));
|
||||||
|
|
||||||
|
delete x;
|
||||||
|
delete y;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BooleanOpsTests, Is_non_decreasing_1) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {2 , 2}, {1, 2, 4, 4});
|
||||||
|
|
||||||
|
sd::ops::is_non_decreasing op;
|
||||||
|
|
||||||
|
ASSERT_TRUE(op.verify({&x}));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BooleanOpsTests, Is_non_decreasing_2) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {2 , 2}, {1, 2, 4, 3});
|
||||||
|
|
||||||
|
sd::ops::is_non_decreasing op;
|
||||||
|
|
||||||
|
ASSERT_FALSE(op.verify({&x}));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BooleanOpsTests, Is_strictly_increasing_1) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {2 , 2}, {1, 2, 4, 5});
|
||||||
|
|
||||||
|
sd::ops::is_strictly_increasing op;
|
||||||
|
|
||||||
|
ASSERT_TRUE(op.verify({&x}));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BooleanOpsTests, Is_strictly_increasing_2) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {2 , 2}, {1, 2, 3, 3});
|
||||||
|
|
||||||
|
sd::ops::is_strictly_increasing op;
|
||||||
|
|
||||||
|
ASSERT_FALSE(op.verify({&x}));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BooleanOpsTests, Is_strictly_increasing_3) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {2 , 2}, {1, 2, 4, 3});
|
||||||
|
|
||||||
|
sd::ops::is_strictly_increasing op;
|
||||||
|
|
||||||
|
ASSERT_FALSE(op.verify({&x}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BooleanOpsTests, Is_strictly_increasing_5) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {64, 512});
|
||||||
|
x.linspace(1.0);
|
||||||
|
|
||||||
|
sd::ops::is_strictly_increasing op;
|
||||||
|
|
||||||
|
ASSERT_TRUE(op.verify({&x}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BooleanOpsTests, Is_strictly_increasing_6) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {64, 512});
|
||||||
|
x.linspace(1.0);
|
||||||
|
|
||||||
|
x.p(18, 1000323.f);
|
||||||
|
|
||||||
|
sd::ops::is_strictly_increasing op;
|
||||||
|
|
||||||
|
ASSERT_FALSE(op.verify({&x}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BooleanOpsTests, Is_numeric_tensor_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {2 , 2}, {1.f, 2.f, 4.f, 3.f});
|
||||||
|
|
||||||
|
sd::ops::is_numeric_tensor op;
|
||||||
|
|
||||||
|
ASSERT_TRUE(op.verify({&x}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BooleanOpsTests, test_where_1) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {6}, { 1, -3, 4, 8, -2, 5 });
|
||||||
|
auto y = NDArrayFactory::create<double>('c', {6}, { 2, -3, 1, 1, -2, 1 });
|
||||||
|
auto e = NDArrayFactory::create<double>('c', {3}, { 4, 8, 5 });
|
||||||
|
|
||||||
|
sd::ops::choose op;
|
||||||
|
|
||||||
|
auto result = op.evaluate({&x, &y}, {3});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
//z->printIndexedBuffer("z");
|
||||||
|
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,857 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver119 on 23.11.17.
|
||||||
|
//
|
||||||
|
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <graph/Graph.h>
|
||||||
|
#include <graph/Node.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::graph;
|
||||||
|
|
||||||
|
class BroadcastableOpsTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Add_1) {
|
||||||
|
|
||||||
|
NDArray x('c', {5, 5}, sd::DataType::FLOAT32);
|
||||||
|
NDArray y('c', {1, 5}, sd::DataType::FLOAT32);
|
||||||
|
NDArray exp('c', {5, 5}, sd::DataType::FLOAT32);
|
||||||
|
x.linspace(1);
|
||||||
|
y.linspace(1);
|
||||||
|
exp.linspace(1);
|
||||||
|
|
||||||
|
//exp.printIndexedBuffer("E B");
|
||||||
|
|
||||||
|
exp.applyBroadcast(broadcast::Add, {1}, y, exp);
|
||||||
|
|
||||||
|
sd::ops::add op;
|
||||||
|
auto result = op.evaluate({&x, &y});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
//exp.printIndexedBuffer("E A");
|
||||||
|
//z->printIndexedBuffer("Z");
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Multiply_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {1, 5});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
x.linspace(1);
|
||||||
|
y.linspace(1);
|
||||||
|
exp.linspace(1);
|
||||||
|
|
||||||
|
exp.applyBroadcast(broadcast::Multiply, {1}, y, exp);
|
||||||
|
|
||||||
|
sd::ops::multiply op;
|
||||||
|
auto result = op.evaluate({&x, &y});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_SquaredSubtract_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {1, 5});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
x.linspace(1);
|
||||||
|
y.linspace(1);
|
||||||
|
exp.linspace(1);
|
||||||
|
|
||||||
|
exp.applyBroadcast(broadcast::SquaredSubtract, {1}, y, exp);
|
||||||
|
|
||||||
|
|
||||||
|
sd::ops::squaredsubtract op;
|
||||||
|
auto result = op.evaluate({&x, &y});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_ScalarBroadcast_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 1}, {1});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {1, 3}, {0, 1, 2});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {1,3}, {1, 0, -1});
|
||||||
|
|
||||||
|
sd::ops::subtract op;
|
||||||
|
auto result = op.evaluate({&x, &y});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_ScalarBroadcast_2) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 1}, {1});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {1, 3}, {0, 1, 2});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {1,3}, {1, 2, 3});
|
||||||
|
|
||||||
|
sd::ops::add op;
|
||||||
|
auto result = op.evaluate({&x, &y});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Maximum_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {2, 3}, {1, 2, 1, 2, 3, 2});
|
||||||
|
auto row = NDArrayFactory::create<float>('c', {1, 3}, {2, 2, 2});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 3}, {2, 2, 2, 2, 3, 2});
|
||||||
|
|
||||||
|
sd::ops::maximum op;
|
||||||
|
auto result = op.evaluate({&x, &row});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Minimum_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {2, 3}, {1, 2, 1, 2, 3, 2});
|
||||||
|
auto col = NDArrayFactory::create<float>('c', {2, 1}, {2, 1});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 3}, {1, 2, 1, 1, 1, 1});
|
||||||
|
|
||||||
|
sd::ops::minimum op;
|
||||||
|
auto result = op.evaluate({&x, &col});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Shape_1) {
|
||||||
|
sd::ops::minimum op;
|
||||||
|
|
||||||
|
Nd4jLong shapeX[] = {2, 2, 5, 5, 1, 8192, 1, 99};
|
||||||
|
Nd4jLong shapeY[] = {2, 2, 5, 5, 1, 8192, 1, 99};
|
||||||
|
ShapeList inputShape({shapeX, shapeY});
|
||||||
|
VariableSpace vs;
|
||||||
|
Context ctx(1, &vs, false);
|
||||||
|
|
||||||
|
auto shapes = op.calculateOutputShape(&inputShape, ctx);
|
||||||
|
|
||||||
|
auto shapeZ = shapes->at(0);
|
||||||
|
ASSERT_TRUE(shape::shapeEquals(shapeX, shapeZ));
|
||||||
|
|
||||||
|
delete shapes;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Shape_2) {
|
||||||
|
sd::ops::minimum op;
|
||||||
|
|
||||||
|
const Nd4jLong shapeX[] = {2, 1, 1, 1, 1, 8192, 1, 99};
|
||||||
|
const Nd4jLong shapeY[] = {2, 2, 5, 5, 1, 8192, 1, 99};
|
||||||
|
ShapeList inputShape({shapeX, shapeY});
|
||||||
|
VariableSpace vs;
|
||||||
|
Context ctx(1, &vs, false);
|
||||||
|
|
||||||
|
auto shapes = op.calculateOutputShape(&inputShape, ctx);
|
||||||
|
|
||||||
|
auto shapeZ = shapes->at(0);
|
||||||
|
ASSERT_TRUE(shape::shapeEquals(shapeY, shapeZ));
|
||||||
|
|
||||||
|
delete shapes;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Shape_3) {
|
||||||
|
sd::ops::minimum op;
|
||||||
|
|
||||||
|
const Nd4jLong shapeX[] = {2, 5, 3, 1, 1, 8192, 1, 99};
|
||||||
|
const Nd4jLong shapeY[] = {2, 1, 3, 3, 1, 8192, 1, 99};
|
||||||
|
ShapeList inputShape({shapeX, shapeY});
|
||||||
|
VariableSpace vs;
|
||||||
|
Context ctx(1, &vs, false);
|
||||||
|
|
||||||
|
auto shapes = op.calculateOutputShape(&inputShape, ctx);
|
||||||
|
|
||||||
|
auto shapeZ = shapes->at(0);
|
||||||
|
ASSERT_TRUE(shape::shapeEquals(shapeX, shapeZ));
|
||||||
|
|
||||||
|
delete shapes;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Shape_4) {
|
||||||
|
sd::ops::minimum op;
|
||||||
|
|
||||||
|
const Nd4jLong shapeX[] = {2, 5, 3, 1, 1, 8192, 1, 99};
|
||||||
|
const Nd4jLong shapeY[] = {2, 5, 1, 1, 1, 8192, 1, 99};
|
||||||
|
ShapeList inputShape({shapeX, shapeY});
|
||||||
|
VariableSpace vs;
|
||||||
|
Context ctx(1, &vs, false);
|
||||||
|
|
||||||
|
auto shapes = op.calculateOutputShape(&inputShape, ctx);
|
||||||
|
|
||||||
|
auto shapeZ = shapes->at(0);
|
||||||
|
ASSERT_TRUE(shape::shapeEquals(shapeX, shapeZ));
|
||||||
|
|
||||||
|
delete shapes;
|
||||||
|
}
|
||||||
|
|
||||||
|
// (2,1,3) + (4,3) = (2,4,3)
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Shape_5) {
|
||||||
|
sd::ops::minimum op;
|
||||||
|
|
||||||
|
const Nd4jLong shapeX[] = {3, 2, 1, 3, 3, 3, 1, 8192, 1, 99};
|
||||||
|
const Nd4jLong shapeY[] = {2, 4, 3, 3, 1, 8192, 1, 99};
|
||||||
|
const Nd4jLong shapeE[] = {3, 2, 4, 3, 12, 3, 1, 8192, 1, 99};
|
||||||
|
ShapeList inputShape({shapeX, shapeY});
|
||||||
|
VariableSpace vs;
|
||||||
|
Context ctx(1, &vs, false);
|
||||||
|
|
||||||
|
auto shapes = op.calculateOutputShape(&inputShape, ctx);
|
||||||
|
|
||||||
|
auto shapeZ = shapes->at(0);
|
||||||
|
ASSERT_TRUE(shape::shapeEquals(shapeE, shapeZ));
|
||||||
|
|
||||||
|
delete shapes;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Scalar_Add_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
||||||
|
auto y = NDArrayFactory::create<float>(2.0f);
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {3, 4, 5, 6});
|
||||||
|
|
||||||
|
sd::ops::add op;
|
||||||
|
auto result = op.evaluate({&x, &y});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Inplace_Output_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {2, 1, 3});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {4, 3});
|
||||||
|
auto o = NDArrayFactory::create<float>('c', {2, 4, 3});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {2, 4, 3});
|
||||||
|
auto buffO1 = reinterpret_cast<float *>(o.buffer());
|
||||||
|
y.assign(1.0f);
|
||||||
|
e.assign(1.0f);
|
||||||
|
|
||||||
|
sd::ops::add op;
|
||||||
|
auto result = op.execute({&x, &y}, {&o}, {}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result);
|
||||||
|
|
||||||
|
auto buffO2 = reinterpret_cast<float *>(o.buffer());
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.isSameShape(o));
|
||||||
|
ASSERT_TRUE(e.equalsTo(o));
|
||||||
|
|
||||||
|
ASSERT_TRUE(buffO1 == buffO2);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Subtract_1) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>(1.0f);
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {2}, {0.0f, 1.0f});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {2}, {1.0f, 0.0f});
|
||||||
|
|
||||||
|
auto z = x - y;
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Subtract_2) {
|
||||||
|
auto x = NDArrayFactory::create<float>(1.0f);
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {2}, {0.0f, 1.0f});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {2}, {1.0f, 0.0f});
|
||||||
|
|
||||||
|
sd::ops::subtract op;
|
||||||
|
auto result = op.evaluate({&x, &y});
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Subtract_3) {
|
||||||
|
auto x = NDArrayFactory::create<float>(1.0f);
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {2}, {0.0f, 1.0f});
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {2}, {0.0f, 0.0f});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {2}, {1.0f, 0.0f});
|
||||||
|
|
||||||
|
sd::ops::subtract op;
|
||||||
|
auto result = op.execute({&x, &y}, {&z}, {}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), result);
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Subtract_4) {
|
||||||
|
auto x = NDArrayFactory::create<float>(1.0f);
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {2}, {0.0f, 1.0f});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {2}, {1.0f, 0.0f});
|
||||||
|
|
||||||
|
auto z = x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y);
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Subtract_5) {
|
||||||
|
auto x = NDArrayFactory::create<float>(1.0f);
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {2}, {0.0f, 1.0f});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {2}, {-1., 0.});
|
||||||
|
|
||||||
|
auto z = y - x;
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Subtract_6) {
|
||||||
|
auto x = NDArrayFactory::create<float>(1.0f);
|
||||||
|
auto y = NDArrayFactory::create<float>(4.f);
|
||||||
|
auto e = NDArrayFactory::create<float>(3.f);
|
||||||
|
|
||||||
|
auto z = y - x;
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Subtract_7) {
|
||||||
|
auto x = NDArrayFactory::create<float>(1.0f);
|
||||||
|
auto y = NDArrayFactory::create<float>(4.f);
|
||||||
|
auto e = NDArrayFactory::create<float>(-3.f);
|
||||||
|
|
||||||
|
auto z = x - y;
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Add_2) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>(1.0f);
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {2}, {0.0f, 1.0f});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {2}, {1.f, 2.f});
|
||||||
|
|
||||||
|
auto z = x + y;
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Add_3) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>(1.0f);
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {2}, {0.0f, 1.0f});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {2}, {1.f, 2.f});
|
||||||
|
|
||||||
|
auto z = y + x;
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Add_4) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>(1.0f);
|
||||||
|
auto y = NDArrayFactory::create<float>(4.f);
|
||||||
|
auto e = NDArrayFactory::create<float>(5.f);
|
||||||
|
|
||||||
|
auto z = x + y;
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Add_5) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>(1.0f);
|
||||||
|
auto y = NDArrayFactory::create<float>(4.f);
|
||||||
|
auto e = NDArrayFactory::create<float>(5.f);
|
||||||
|
|
||||||
|
auto z = y + x;
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Multiply_2) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>(2.0f);
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {2}, {3.f, 4.f});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {2}, {6.f, 8.f});
|
||||||
|
|
||||||
|
auto z = y * x;
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Multiply_3) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>(2.0f);
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {2}, {3.f, 4.f});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {2}, {6.f, 8.f});
|
||||||
|
|
||||||
|
auto z = x * y;
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Multiply_4) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>(2.0f);
|
||||||
|
auto y = NDArrayFactory::create<float>(4.f);
|
||||||
|
auto e = NDArrayFactory::create<float>(8.f);
|
||||||
|
|
||||||
|
auto z = y * x;
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Multiply_5) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>(2.0f);
|
||||||
|
auto y = NDArrayFactory::create<float>(4.f);
|
||||||
|
auto e = NDArrayFactory::create<float>(8.f);
|
||||||
|
|
||||||
|
auto z = x * y;
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Multiply_6) {
|
||||||
|
auto x = NDArrayFactory::create<float>(2.0f);
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {1}, {4.f});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {1}, {8.f});
|
||||||
|
|
||||||
|
auto z = x * y;
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Multiply_7) {
|
||||||
|
auto x = NDArrayFactory::create<float>(2.0f);
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {1}, {4.f});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {1}, {8.f});
|
||||||
|
|
||||||
|
sd::ops::multiply op;
|
||||||
|
auto result = op.evaluate({&x, &y});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, Test_Multiply_8) {
|
||||||
|
auto x = NDArrayFactory::create<float>(2.0f);
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {1, 1}, {4.f});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {1, 1}, {8.f});
|
||||||
|
|
||||||
|
sd::ops::multiply op;
|
||||||
|
auto result = op.evaluate({&x, &y});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_add_1) {
|
||||||
|
|
||||||
|
NDArray x('c', {4}, {1,1,1,1});
|
||||||
|
NDArray y('c', {1,4}, {1,2,3,4});
|
||||||
|
NDArray z('c', {1,4}, sd::DataType::DOUBLE);
|
||||||
|
NDArray exp('c', {1,4}, {2,3,4,5}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
sd::ops::add op;
|
||||||
|
auto status = op.execute({&x, &y}, {&z});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
ASSERT_TRUE(z.equalsTo(exp));
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_equals_1) {
|
||||||
|
|
||||||
|
NDArray x('c', {1,4}, {1,2,3,4});
|
||||||
|
NDArray y('c', {3,4}, {0,0,0,0, 1,2,3,4, 1,2,3,4});
|
||||||
|
NDArray z('c', {3,4}, sd::DataType::BOOL);
|
||||||
|
NDArray exp('c', {3,4}, {0,0,0,0, 1,1,1,1, 1,1,1,1}, sd::DataType::BOOL);
|
||||||
|
|
||||||
|
sd::ops::equals op;
|
||||||
|
auto status = op.execute({&x, &y}, {&z});
|
||||||
|
// z.printIndexedBuffer();
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
ASSERT_TRUE(z.equalsTo(exp));
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_empty_1) {
|
||||||
|
|
||||||
|
NDArray y('c', {3,4}, {0,0,0,0, 1,2,3,4, 1,2,3,4});
|
||||||
|
NDArray x(sd::DataType::DOUBLE, y.getContext(), false);
|
||||||
|
NDArray z(sd::DataType::DOUBLE, y.getContext(), false);
|
||||||
|
NDArray zExp(sd::DataType::DOUBLE, y.getContext(), false);
|
||||||
|
|
||||||
|
sd::ops::multiply op;
|
||||||
|
auto status = op.execute({&x, &y}, {&z}, {}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
ASSERT_TRUE(z.isSameShape(zExp));
|
||||||
|
ASSERT_TRUE(z.equalsTo(zExp));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_empty_2) {
|
||||||
|
|
||||||
|
NDArray y('c', {1,4}, {1,2,3,4});
|
||||||
|
NDArray x = NDArrayFactory::create<double>('c', {0, 4});
|
||||||
|
NDArray e = NDArrayFactory::create<double>('c', {0, 4});;
|
||||||
|
|
||||||
|
sd::ops::multiply op;
|
||||||
|
auto status = op.execute({&x, &y}, {&x}, {}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
ASSERT_TRUE(e.isSameShape(x));
|
||||||
|
ASSERT_TRUE(e.equalsTo(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_empty_3) {
|
||||||
|
|
||||||
|
NDArray x = NDArrayFactory::create<float>('c', {1, 0, 2});
|
||||||
|
NDArray y('c', {}, std::vector<double>{0.1}, sd::DataType::FLOAT32);
|
||||||
|
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
|
||||||
|
|
||||||
|
sd::ops::maximum op;
|
||||||
|
auto result = op.evaluate({&x, &y});
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
|
ASSERT_TRUE(e.equalsTo(*z));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_empty_4) {
|
||||||
|
|
||||||
|
NDArray x = NDArrayFactory::create<float>('c', {1, 0, 1});
|
||||||
|
NDArray y = NDArrayFactory::create<float>('c', {1, 0, 2});
|
||||||
|
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
|
||||||
|
|
||||||
|
sd::ops::maximum op;
|
||||||
|
auto result = op.evaluate({&x, &y});
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
|
ASSERT_TRUE(e.equalsTo(*z));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_empty_5) {
|
||||||
|
|
||||||
|
NDArray x = NDArrayFactory::create<float>('c', {1, 0, 1});
|
||||||
|
NDArray y = NDArrayFactory::create<float>('c', {1, 0, 2});
|
||||||
|
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
|
||||||
|
|
||||||
|
sd::ops::realdiv op;
|
||||||
|
auto result = op.evaluate({&x, &y});
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
|
ASSERT_TRUE(e.equalsTo(*z));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_empty_6) {
|
||||||
|
|
||||||
|
NDArray x = NDArrayFactory::create<float>('c', {1, 0, 1});
|
||||||
|
NDArray y = NDArrayFactory::create<float>('c', {1, 2}, {2, 2});
|
||||||
|
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
|
||||||
|
|
||||||
|
sd::ops::realdiv op;
|
||||||
|
auto result = op.evaluate({&x, &y});
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
|
ASSERT_TRUE(e.equalsTo(*z));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_empty_7) {
|
||||||
|
|
||||||
|
NDArray x = NDArrayFactory::create<float>('c', {1, 0, 2, 1});
|
||||||
|
NDArray y = NDArrayFactory::create<float>('c', {1, 2, 0});
|
||||||
|
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2, 0});;
|
||||||
|
|
||||||
|
sd::ops::realdiv op;
|
||||||
|
auto result = op.evaluate({&x, &y});
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
|
ASSERT_TRUE(e.equalsTo(*z));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_bool_empty_1) {
|
||||||
|
|
||||||
|
NDArray y('c', {3,4}, {0,0,0,0, 1,2,3,4, 1,2,3,4});
|
||||||
|
NDArray x(sd::DataType::DOUBLE, y.getContext(), false);
|
||||||
|
NDArray z(sd::DataType::BOOL, y.getContext(), false);
|
||||||
|
NDArray zExp(sd::DataType::BOOL, y.getContext(), false);
|
||||||
|
|
||||||
|
sd::ops::greater op;
|
||||||
|
auto status = op.execute({&x, &y}, {&z}, {}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
ASSERT_TRUE(z.isSameShape(zExp));
|
||||||
|
ASSERT_TRUE(z.equalsTo(zExp));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_bool_empty_2) {
|
||||||
|
|
||||||
|
NDArray y('c', {1,4}, {1,2,3,4});
|
||||||
|
NDArray x = NDArrayFactory::create<double>('c', {0, 4});
|
||||||
|
NDArray e = NDArrayFactory::create<bool>('c', {0, 4});;
|
||||||
|
|
||||||
|
|
||||||
|
sd::ops::greater op;
|
||||||
|
auto result = op.evaluate({&x, &y});
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
// z->printShapeInfo("z");
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
|
ASSERT_TRUE(e.equalsTo(*z));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_bool_1) {
|
||||||
|
|
||||||
|
NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32);
|
||||||
|
NDArray y('c', {2, 2}, sd::DataType::FLOAT32);
|
||||||
|
NDArray z('c', {3, 2, 2}, sd::DataType::BOOL);
|
||||||
|
NDArray e('c', {3, 2, 2}, sd::DataType::BOOL);
|
||||||
|
|
||||||
|
x.assign(4.f);
|
||||||
|
y.assign(2.f);
|
||||||
|
e.assign(true);
|
||||||
|
|
||||||
|
sd::ops::greater op;
|
||||||
|
|
||||||
|
auto status = op.execute({&x, &y}, {&z});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
|
||||||
|
// z.printIndexedBuffer("Z");
|
||||||
|
|
||||||
|
ASSERT_TRUE(z.isSameShape(e));
|
||||||
|
ASSERT_TRUE(z.equalsTo(e));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_bool_2) {
|
||||||
|
|
||||||
|
NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32);
|
||||||
|
NDArray y('c', {2, 2}, sd::DataType::FLOAT32);
|
||||||
|
NDArray z('c', {3, 2, 2}, sd::DataType::BOOL);
|
||||||
|
NDArray e('c', {3, 2, 2}, sd::DataType::BOOL);
|
||||||
|
|
||||||
|
x.assign(1.f);
|
||||||
|
y.assign(2.f);
|
||||||
|
e.assign(false);
|
||||||
|
|
||||||
|
sd::ops::equals op;
|
||||||
|
|
||||||
|
auto status = op.execute({&x, &y}, {&z}, {}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
|
||||||
|
// z.printIndexedBuffer("Z");
|
||||||
|
|
||||||
|
ASSERT_TRUE(z.isSameShape(e));
|
||||||
|
ASSERT_TRUE(z.equalsTo(e));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_bool_3) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<int>(0);
|
||||||
|
auto y = NDArrayFactory::create<int>('c', {3}, {2, 1, 2});
|
||||||
|
NDArray z('c', {3}, sd::DataType::BOOL);
|
||||||
|
NDArray e('c', {3}, sd::DataType::BOOL);
|
||||||
|
|
||||||
|
e.assign(true);
|
||||||
|
|
||||||
|
sd::ops::less op;
|
||||||
|
auto status = op.execute({&x, &y}, {&z}, {}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
|
||||||
|
// z.printIndexedBuffer("Z");
|
||||||
|
|
||||||
|
ASSERT_TRUE(z.isSameShape(e));
|
||||||
|
ASSERT_TRUE(z.equalsTo(e));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_2) {
|
||||||
|
NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32);
|
||||||
|
NDArray y('c', {2, 2}, sd::DataType::FLOAT32);
|
||||||
|
NDArray z('c', {3, 2, 2}, sd::DataType::FLOAT32);
|
||||||
|
NDArray e('c', {3, 2, 2}, sd::DataType::FLOAT32);
|
||||||
|
|
||||||
|
x = 4.f;
|
||||||
|
y = 2.f;
|
||||||
|
e = -2.f;
|
||||||
|
|
||||||
|
sd::ops::reversesubtract op; // z = y - x;
|
||||||
|
|
||||||
|
auto status = op.execute({&x, &y}, {&z}, {}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
|
||||||
|
// z.printIndexedBuffer("Z");
|
||||||
|
|
||||||
|
ASSERT_TRUE(z.isSameShape(e));
|
||||||
|
ASSERT_TRUE(z.equalsTo(e));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_3) {
|
||||||
|
auto x = NDArrayFactory::create<int>(0);
|
||||||
|
auto y = NDArrayFactory::create<int>('c', {3}, {2, 1, 2});
|
||||||
|
NDArray z('c', {3}, sd::DataType::INT32);
|
||||||
|
auto e = NDArrayFactory::create<int>('c', {3}, {2, 1, 2});
|
||||||
|
|
||||||
|
sd::ops::add op;
|
||||||
|
auto status = op.execute({&x, &y}, {&z}, {}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
|
||||||
|
// z.printIndexedBuffer("Z");
|
||||||
|
|
||||||
|
ASSERT_TRUE(z.isSameShape(e));
|
||||||
|
ASSERT_TRUE(z.equalsTo(e));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, test_bert_multiply_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {4, 128, 1});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {4, 1, 128});
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {4, 128, 128});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {4, 128, 128});
|
||||||
|
|
||||||
|
x.assign(0.f);
|
||||||
|
y.assign(1.f);
|
||||||
|
z.assign(119.f);
|
||||||
|
e.assign(0.f);
|
||||||
|
/*
|
||||||
|
Context ctx(1);
|
||||||
|
ctx.setInputArray(0, &x);
|
||||||
|
ctx.setInputArray(1, &y);
|
||||||
|
ctx.setOutputArray(0, &z);
|
||||||
|
|
||||||
|
sd::ops::multiply op;
|
||||||
|
auto status = op.execute(&ctx);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
z.printIndexedBuffer();
|
||||||
|
*/
|
||||||
|
|
||||||
|
x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z);
|
||||||
|
|
||||||
|
//z.printIndexedBuffer();
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, test_bert_multiply_2) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {4, 128, 1});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {768});
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {4, 128, 768});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {4, 128, 768});
|
||||||
|
|
||||||
|
x.assign(1.f);
|
||||||
|
y.assign(2.f);
|
||||||
|
z.assign(119.f);
|
||||||
|
e.assign(2.f);
|
||||||
|
|
||||||
|
x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
|
@ -0,0 +1,68 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by agibsonccc on 1/19/17.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testinclude.h"
|
||||||
|
#include <loops/broadcasting.h>
|
||||||
|
|
||||||
|
class BroadcastMultiDimTest : public testing::Test {
|
||||||
|
public:
|
||||||
|
int dimensions[2] = {0,2};
|
||||||
|
Nd4jLong inputShapeBuffer[10] = {3,2,3,5,15,5,1,8192,1,99};
|
||||||
|
float inputData[30] = {1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0,16.0,17.0,18.0,19.0,20.0,21.0,22.0,23.0,24.0,25.0,26.0,27.0,28.0,29.0,30.0};
|
||||||
|
float dataAssertion[30] = {1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0,16.0,17.0,18.0,0.0,0.0,21.0,22.0,23.0,0.0,0.0,26.0,27.0,28.0,0.0,0.0};
|
||||||
|
float result[30] = {0.0};
|
||||||
|
float broadcastData[10] = {1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0};
|
||||||
|
Nd4jLong broadcastShapeInfo[8] = {2,2,5,5,1,8192,1,99};
|
||||||
|
int opNum = 2;
|
||||||
|
int dimensionLength = 2;
|
||||||
|
};
|
||||||
|
|
||||||
|
#ifndef __CUDABLAS__
|
||||||
|
|
||||||
|
TEST_F(BroadcastMultiDimTest,MultimDimTest) {
|
||||||
|
auto tad = new shape::TAD();
|
||||||
|
tad->init(inputShapeBuffer,dimensions,dimensionLength);
|
||||||
|
tad->createTadOnlyShapeInfo();
|
||||||
|
tad-> createOffsets();
|
||||||
|
functions::broadcast::Broadcast<float, float, float>::exec(
|
||||||
|
opNum,
|
||||||
|
inputData, //x
|
||||||
|
inputShapeBuffer, //xShapeInfo
|
||||||
|
broadcastData, //y
|
||||||
|
broadcastShapeInfo, //yShapeInfo
|
||||||
|
result, //result
|
||||||
|
inputShapeBuffer, //resultShapeInfo
|
||||||
|
dimensions, //dimension
|
||||||
|
dimensionLength, //dimensionLength
|
||||||
|
tad->tadOnlyShapeInfo, //tadShapeInfo
|
||||||
|
tad->tadOffsets, //tadOffset
|
||||||
|
tad->tadOnlyShapeInfo, //tadShapeInfoZ
|
||||||
|
tad->tadOffsets, sd::LoopKind::COMMON, 0, tad->numTads); //tadOffsetZ
|
||||||
|
|
||||||
|
for(int i = 0; i < 30; i++) {
|
||||||
|
ASSERT_EQ(dataAssertion[i],result[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
delete tad;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,171 @@
|
||||||
|
include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
|
||||||
|
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
|
||||||
|
if(LINUX)
|
||||||
|
link_directories(/usr/local/lib)
|
||||||
|
link_directories(/usr/lib)
|
||||||
|
link_directories(/lib)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||||
|
|
||||||
|
if(APPLE)
|
||||||
|
message("Using apple")
|
||||||
|
link_directories(/usr/local/lib)
|
||||||
|
link_directories(/usr/lib)
|
||||||
|
link_directories(/lib)
|
||||||
|
endif()
|
||||||
|
if(WIN32)
|
||||||
|
get_property(dirs DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES)
|
||||||
|
foreach(dir ${dirs})
|
||||||
|
message(STATUS "dir='${dir}'")
|
||||||
|
endforeach()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (SD_CUDA)
|
||||||
|
find_package(CUDA)
|
||||||
|
message("Tests CUDA include directory: ${CUDA_INCLUDE_DIRS}")
|
||||||
|
include_directories(${CUDA_INCLUDE_DIRS})
|
||||||
|
add_definitions(-D__CUDABLAS__=true)
|
||||||
|
|
||||||
|
if(WIN32)
|
||||||
|
message("CUDA on Windows: enabling /EHsc")
|
||||||
|
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /FS")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
string( TOLOWER "${COMPUTE}" COMPUTE_CMP )
|
||||||
|
if ("${COMPUTE_CMP}" STREQUAL "all")
|
||||||
|
CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH_FLAGS "Common")
|
||||||
|
elseif("${COMPUTE_CMP}" STREQUAL "auto")
|
||||||
|
CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH_FLAGS "Auto")
|
||||||
|
elseif(COMPUTE_CMP MATCHES "^[0-9]+$")
|
||||||
|
#matches USER COMPUTE old way
|
||||||
|
set(CUDA_ARCH_FLAGS "-gencode arch=compute_${COMPUTE},code=sm_${COMPUTE} ")
|
||||||
|
else()
|
||||||
|
#matches numbers NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX
|
||||||
|
#NAME: Fermi Kepler Maxwell Kepler+Tegra Kepler+Tesla Maxwell+Tegra Pascal
|
||||||
|
#NUM: 2.0 2.1 3.0 3.2 3.5 3.7 5.0 5.2 5.3 6.0 6.2 et cetera
|
||||||
|
CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH_FLAGS "${COMPUTE}")
|
||||||
|
endif()
|
||||||
|
# list to spaces
|
||||||
|
string (REPLACE ";" " " CUDA_ARCH_FLAGS "${CUDA_ARCH_FLAGS}")
|
||||||
|
|
||||||
|
set(CMAKE_CUDA_FLAGS " ${CMAKE_CUDA_FLAGS} -DCUDA_VERSION_MAJOR=${CUDA_VERSION_MAJOR} ${EXPM} -w --cudart=static --expt-extended-lambda -Xfatbin -compress-all ${CUDA_ARCH_FLAGS}")
|
||||||
|
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# -fsanitize=address
|
||||||
|
# -fsanitize=leak
|
||||||
|
if (APPLE)
|
||||||
|
set(CMAKE_CXX_FLAGS " -fPIC -D__APPLE_OS__=true")
|
||||||
|
elseif(WIN32)
|
||||||
|
if (SD_CPU)
|
||||||
|
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -fPIC -mmmx -msse -msse2 -msse3 -mssse3 -msse4.1 -msse4.2 -msse4 -mavx -mavx2 -O3")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (SD_CPU AND LINUX)
|
||||||
|
set(CMAKE_CXX_FLAGS " -fPIC")
|
||||||
|
endif()
|
||||||
|
else()
|
||||||
|
set(CMAKE_CXX_FLAGS " -fPIC")
|
||||||
|
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
|
||||||
|
IF(${SD_ARCH} MATCHES "arm*")
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=${SD_ARCH}")
|
||||||
|
else()
|
||||||
|
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
|
||||||
|
|
||||||
|
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64*")
|
||||||
|
set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -mcpu=native")
|
||||||
|
else()
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -mtune=native")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
if (SD_CPU AND SD_SANITIZE)
|
||||||
|
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address")
|
||||||
|
else()
|
||||||
|
# CUDA?
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
|
||||||
|
# tests are always compiled with all ops included
|
||||||
|
SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSD_ALL_OPS=true -DBUILD_TESTS=true")
|
||||||
|
|
||||||
|
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
|
||||||
|
# using Clang
|
||||||
|
SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE}")
|
||||||
|
elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Intel")
|
||||||
|
# using Intel C++
|
||||||
|
SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE} -fp-model fast")
|
||||||
|
elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC")
|
||||||
|
# using Visual Studio C++
|
||||||
|
|
||||||
|
elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
||||||
|
# using GCC
|
||||||
|
SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fmax-errors=2")
|
||||||
|
|
||||||
|
if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND ${CMAKE_SYSTEM_NAME} MATCHES "Linux" AND NOT(MINGW))
|
||||||
|
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -Wl,-export-dynamic")
|
||||||
|
SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -export-dynamic")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
IF(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
|
||||||
|
include_directories("/usr/include")
|
||||||
|
include_directories("/usr/local/include")
|
||||||
|
ENDIF(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
|
||||||
|
|
||||||
|
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9)
|
||||||
|
message(FATAL_ERROR "You need at least GCC 4.9")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
||||||
|
find_package(OpenMP)
|
||||||
|
endif()
|
||||||
|
if (OPENMP_FOUND)
|
||||||
|
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
|
||||||
|
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
|
||||||
|
else()
|
||||||
|
message("OPENMP NOT FOUND")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (SD_CPU)
|
||||||
|
file(GLOB_RECURSE TEST_SOURCES false ./*.cpp ./*.h)
|
||||||
|
elseif (SD_CUDA)
|
||||||
|
file(GLOB_RECURSE TEST_SOURCES false ./*.cpp ./*.cu ./*.h)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Filter out any source files from */CMakeFiles/* paths. these tend to cause problems such a multiple main definitions.
|
||||||
|
set (EXCLUDE_DIR "/CMakeFiles/")
|
||||||
|
foreach (TMP_PATH ${TEST_SOURCES})
|
||||||
|
string (FIND ${TMP_PATH} ${EXCLUDE_DIR} EXCLUDE_DIR_FOUND)
|
||||||
|
if (NOT ${EXCLUDE_DIR_FOUND} EQUAL -1)
|
||||||
|
list (REMOVE_ITEM TEST_SOURCES ${TMP_PATH})
|
||||||
|
endif ()
|
||||||
|
endforeach(TMP_PATH)
|
||||||
|
|
||||||
|
if (SD_CPU)
|
||||||
|
if (NOT BLAS_LIBRARIES)
|
||||||
|
set(BLAS_LIBRARIES "")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
add_executable(runtests ${TEST_SOURCES})
|
||||||
|
target_link_libraries(runtests samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} ${ARMCOMPUTE_LIBRARIES} gtest gtest_main)
|
||||||
|
elseif(SD_CUDA)
|
||||||
|
|
||||||
|
add_executable(runtests ${TEST_SOURCES})
|
||||||
|
|
||||||
|
if (WIN32)
|
||||||
|
message("MSVC runtime for tests: ${MSVC_RT_LIB}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# applies to windows only
|
||||||
|
set_property(TARGET runtests PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
|
||||||
|
set_property(TARGET gtest PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
|
||||||
|
set_property(TARGET gtest_main PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$<CONFIG:Debug>:Debug>")
|
||||||
|
|
||||||
|
if (HAVE_CUDNN)
|
||||||
|
message("CUDNN library: ${CUDNN}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
target_link_libraries(runtests samediff_obj ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusolver_LIBRARY} ${CUDNN} ${MKLDNN} gtest gtest_main)
|
||||||
|
endif()
|
|
@ -0,0 +1,96 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by agibsonccc on 3/30/17.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testinclude.h"
|
||||||
|
#include <string>
|
||||||
|
#include <legacy/NativeOps.h>
|
||||||
|
|
||||||
|
class FileTest : public testing::Test {
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
class LoadFromStringTest : public testing::Test {
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
class HeaderTest : public testing::Test {
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(HeaderTest, test_dataTypes_1) {
|
||||||
|
std::string header("0NUMPY6789{'descr': '>f4");
|
||||||
|
|
||||||
|
|
||||||
|
ASSERT_EQ(sd::DataType::FLOAT32, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(HeaderTest, test_dataTypes_2) {
|
||||||
|
std::string header("0NUMPY6789{'descr': '>f8");
|
||||||
|
|
||||||
|
|
||||||
|
ASSERT_EQ(sd::DataType::DOUBLE, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(HeaderTest, test_dataTypes_3) {
|
||||||
|
std::string header("0NUMPY6789{'descr': '<i4");
|
||||||
|
|
||||||
|
|
||||||
|
ASSERT_EQ(sd::DataType::INT32, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(HeaderTest, test_dataTypes_4) {
|
||||||
|
std::string header("0NUMPY6789{'descr': '>u2");
|
||||||
|
|
||||||
|
|
||||||
|
ASSERT_EQ(sd::DataType::UINT16, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
TEST_F(FileTest,T) {
|
||||||
|
cnpy::NpyArray npy = cnpy::npyLoad(std::string("/home/agibsonccc/code/libnd4j/test.npy"));
|
||||||
|
ASSERT_FALSE(npy.fortranOrder);
|
||||||
|
|
||||||
|
ASSERT_EQ(2,npy.shape[0]);
|
||||||
|
ASSERT_EQ(2,npy.shape[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LoadFromStringTest,PathTest) {
|
||||||
|
char *loaded = cnpy::loadFile("/home/agibsonccc/code/libnd4j/test.npy");
|
||||||
|
cnpy::NpyArray loadedArr = cnpy::loadNpyFromPointer(loaded);
|
||||||
|
ASSERT_FALSE(loadedArr.fortranOrder);
|
||||||
|
ASSERT_EQ(2,loadedArr.shape[0]);
|
||||||
|
ASSERT_EQ(2,loadedArr.shape[1]);
|
||||||
|
double *data = reinterpret_cast<double *>(loadedArr.data);
|
||||||
|
ASSERT_EQ(1.0,data[0]);
|
||||||
|
ASSERT_EQ(2.0,data[1]);
|
||||||
|
ASSERT_EQ(3.0,data[2]);
|
||||||
|
ASSERT_EQ(4.0,data[3]);
|
||||||
|
Nd4jPointer pointer = reinterpret_cast<Nd4jPointer >(&loadedArr);
|
||||||
|
int *shapeBuffer = shape::shapeBufferOfNpy(loadedArr);
|
||||||
|
Nd4jPointer pointer1 = dataPointForNumpy(loaded);
|
||||||
|
delete[] shapeBuffer;
|
||||||
|
|
||||||
|
double *data2 = reinterpret_cast<double *>(pointer1);
|
||||||
|
delete[] loaded;
|
||||||
|
}
|
||||||
|
|
||||||
|
*/
|
|
@ -0,0 +1,334 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver119 on 16.10.2017.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <graph/Graph.h>
|
||||||
|
#include <graph/GraphExecutioner.h>
|
||||||
|
#include <graph/Node.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::graph;
|
||||||
|
|
||||||
|
class ConditionalTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
ConditionalTests(){
|
||||||
|
//Environment::getInstance().setVerbose(true);
|
||||||
|
//Environment::getInstance().setDebug(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
~ConditionalTests(){
|
||||||
|
//Environment::getInstance().setVerbose(false);
|
||||||
|
//Environment::getInstance().setDebug(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(ConditionalTests, BasicTests_1) {
|
||||||
|
Graph graph;
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::valueOf({2, 2}, 1.0f);
|
||||||
|
auto y0 = NDArrayFactory::valueOf({2, 2}, 5.0f);
|
||||||
|
auto y1 = NDArrayFactory::valueOf({2, 2}, -5.0f);
|
||||||
|
auto scalar = NDArrayFactory::create_(1.0f);
|
||||||
|
|
||||||
|
auto variableSpace = graph.getVariableSpace();
|
||||||
|
|
||||||
|
variableSpace->putVariable(-1, x);
|
||||||
|
variableSpace->putVariable(-2, y0);
|
||||||
|
variableSpace->putVariable(-3, y1);
|
||||||
|
variableSpace->putVariable(-4, scalar);
|
||||||
|
|
||||||
|
|
||||||
|
auto scopeCondition = new Node(OpType_LOGIC, logic::Scope, 1);
|
||||||
|
scopeCondition->setName("scopeCondition");
|
||||||
|
|
||||||
|
auto scopeFalse = new Node(OpType_LOGIC, logic::Scope, 2);
|
||||||
|
scopeFalse->setName("scopeFalse");
|
||||||
|
|
||||||
|
auto scopeTrue = new Node(OpType_LOGIC, logic::Scope, 3);
|
||||||
|
scopeTrue->setName("scopeTrue");
|
||||||
|
|
||||||
|
auto nodeF = new Node(OpType_PAIRWISE, pairwise::Add, 5, {-1, -2});
|
||||||
|
nodeF->setScopeInfo(2, "scopeFalse");
|
||||||
|
|
||||||
|
auto nodeT = new Node(OpType_PAIRWISE, pairwise::Subtract, 6, {-1, -2});
|
||||||
|
nodeT->setScopeInfo(3, "scopeTrue");
|
||||||
|
|
||||||
|
auto nodeC0 = new Node(OpType_REDUCE_SAME, reduce::Sum, 7, {-1});
|
||||||
|
nodeC0->setScopeInfo(1, "scopeCondition");
|
||||||
|
|
||||||
|
sd::ops::eq_scalar op;
|
||||||
|
auto nodeC1 = new Node(&op, 8, {7, -4});
|
||||||
|
nodeC1->setScopeInfo(1, "scopeCondition");
|
||||||
|
|
||||||
|
graph.addNode(scopeCondition);
|
||||||
|
graph.addNode(scopeFalse);
|
||||||
|
graph.addNode(scopeTrue);
|
||||||
|
graph.addNode(nodeF);
|
||||||
|
graph.addNode(nodeT);
|
||||||
|
graph.addNode(nodeC0);
|
||||||
|
graph.addNode(nodeC1);
|
||||||
|
|
||||||
|
// at this point graph should ounly have Nodes referring to the Scopes: condition scope, true scope and false scope
|
||||||
|
ASSERT_EQ(3, graph.totalNodes());
|
||||||
|
|
||||||
|
// now we're adding Condition op, that'll take all of those in
|
||||||
|
auto nodeCondition = new Node(OpType_LOGIC, logic::Conditional, 10, {1, 2, 3});
|
||||||
|
graph.addNode(nodeCondition);
|
||||||
|
|
||||||
|
ASSERT_EQ(4, graph.totalNodes());
|
||||||
|
|
||||||
|
Nd4jStatus status = GraphExecutioner::execute(&graph);
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
|
||||||
|
ASSERT_TRUE(variableSpace->hasVariable(10, 0));
|
||||||
|
auto conditionalResult = variableSpace->getVariable(10, 0)->getNDArray();
|
||||||
|
ASSERT_NE(nullptr, conditionalResult);
|
||||||
|
|
||||||
|
ASSERT_NEAR(6.0, conditionalResult->meanNumber().e<double>(0), 1e-5);
|
||||||
|
}
|
||||||
|
#ifdef GRAPH_FILES_OK
|
||||||
|
/**
|
||||||
|
* Condition is False
|
||||||
|
*/
|
||||||
|
TEST_F(ConditionalTests, Flat_Test_1) {
|
||||||
|
sd::ops::identity op0;
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simpleif_0_1.fb");
|
||||||
|
auto varSpace = graph->getVariableSpace();
|
||||||
|
//varSpace->getVariable(1)->getNDArray()->assign(2.0);
|
||||||
|
//varSpace->getVariable(2)->getNDArray()->assign(0.0);
|
||||||
|
|
||||||
|
//graph->printOut();
|
||||||
|
|
||||||
|
auto status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_TRUE(varSpace->hasVariable(15));
|
||||||
|
|
||||||
|
auto z = varSpace->getVariable(15)->getNDArray();
|
||||||
|
|
||||||
|
ASSERT_NE(nullptr, z);
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {-2, -2, -2, -2});
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Condition is True
|
||||||
|
*/
|
||||||
|
TEST_F(ConditionalTests, Flat_Test_2) {
|
||||||
|
Environment::getInstance().setDebug(true);
|
||||||
|
Environment::getInstance().setVerbose(true);
|
||||||
|
sd::ops::identity op0;
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simpleif_0.fb");
|
||||||
|
auto varSpace = graph->getVariableSpace();
|
||||||
|
varSpace->getVariable(1)->getNDArray()->assign(-1.0);
|
||||||
|
|
||||||
|
graph->printOut();
|
||||||
|
|
||||||
|
auto status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_TRUE(varSpace->hasVariable(15));
|
||||||
|
|
||||||
|
auto z = varSpace->getVariable(15)->getNDArray();
|
||||||
|
|
||||||
|
ASSERT_NE(nullptr, z);
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {1, 1, 1, 1});
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Condition is false here, so there loop will be skipped
|
||||||
|
*/
|
||||||
|
TEST_F(ConditionalTests, Flat_Test_3) {
|
||||||
|
sd::ops::identity op0;
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_0_3.fb");
|
||||||
|
auto varSpace = graph->getVariableSpace();
|
||||||
|
varSpace->getVariable(1)->getNDArray()->assign(1.0);
|
||||||
|
|
||||||
|
//graph->printOut();
|
||||||
|
|
||||||
|
auto status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_TRUE(varSpace->hasVariable(17));
|
||||||
|
|
||||||
|
auto z = varSpace->getVariable(17)->getNDArray();
|
||||||
|
|
||||||
|
ASSERT_NE(nullptr, z);
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {1, 1, 1, 1});
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* just one cycle in body
|
||||||
|
*/
|
||||||
|
TEST_F(ConditionalTests, Flat_Test_4) {
|
||||||
|
sd::ops::identity op0;
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_0_4.fb");
|
||||||
|
auto varSpace = graph->getVariableSpace();
|
||||||
|
varSpace->getVariable(2)->getNDArray()->assign(4.0);
|
||||||
|
|
||||||
|
//graph->printOut();
|
||||||
|
|
||||||
|
auto status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_TRUE(varSpace->hasVariable(17));
|
||||||
|
|
||||||
|
auto z = varSpace->getVariable(17)->getNDArray();
|
||||||
|
|
||||||
|
ASSERT_NE(nullptr, z);
|
||||||
|
|
||||||
|
// 0.0 + 2.0 = 2.0 in each element
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {2, 2, 2, 2});
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* just two cycles in body
|
||||||
|
*/
|
||||||
|
TEST_F(ConditionalTests, Flat_Test_5) {
|
||||||
|
sd::ops::identity op0;
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_0_4.fb");
|
||||||
|
auto varSpace = graph->getVariableSpace();
|
||||||
|
varSpace->getVariable(2)->getNDArray()->assign(9.0);
|
||||||
|
|
||||||
|
//graph->printOut();
|
||||||
|
|
||||||
|
auto status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_TRUE(varSpace->hasVariable(17));
|
||||||
|
|
||||||
|
auto z = varSpace->getVariable(17)->getNDArray();
|
||||||
|
|
||||||
|
ASSERT_NE(nullptr, z);
|
||||||
|
|
||||||
|
// 0.0 + 2.0 + 2.0 = 4.0 in each element
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {4, 4, 4, 4});
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* While loop with multiple variables
|
||||||
|
*/
|
||||||
|
TEST_F(ConditionalTests, Flat_Test_6) {
|
||||||
|
sd::ops::identity op0;
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_1.fb");
|
||||||
|
auto varSpace = graph->getVariableSpace();
|
||||||
|
varSpace->getVariable(1)->getNDArray()->assign(-4.0f);
|
||||||
|
varSpace->getVariable(2)->getNDArray()->assign(1.0f);
|
||||||
|
|
||||||
|
//graph->printOut();
|
||||||
|
|
||||||
|
auto status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_TRUE(varSpace->hasVariable(25));
|
||||||
|
|
||||||
|
auto z = varSpace->getVariable(25)->getNDArray();
|
||||||
|
|
||||||
|
ASSERT_NE(nullptr, z);
|
||||||
|
|
||||||
|
//z->printIndexedBuffer();
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {-1, -1, -1, -1});
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ConditionalTests, Flat_Test_7) {
|
||||||
|
sd::ops::identity op0;
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_1.fb");
|
||||||
|
auto varSpace = graph->getVariableSpace();
|
||||||
|
varSpace->getVariable(1)->getNDArray()->assign(-9.0f);
|
||||||
|
varSpace->getVariable(2)->getNDArray()->assign(1.0f);
|
||||||
|
|
||||||
|
//graph->printOut();
|
||||||
|
|
||||||
|
auto status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_TRUE(varSpace->hasVariable(25));
|
||||||
|
|
||||||
|
auto z = varSpace->getVariable(25)->getNDArray();
|
||||||
|
|
||||||
|
ASSERT_NE(nullptr, z);
|
||||||
|
|
||||||
|
//z->printIndexedBuffer();
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {-3, -3, -3, -3});
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This test checks nested while execution
|
||||||
|
*/
|
||||||
|
TEST_F(ConditionalTests, Flat_Test_8) {
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_nested.fb");
|
||||||
|
auto varSpace = graph->getVariableSpace();
|
||||||
|
//graph->printOut();
|
||||||
|
|
||||||
|
auto status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_TRUE(varSpace->hasVariable(52));
|
||||||
|
|
||||||
|
auto z = varSpace->getVariable(52)->getNDArray();
|
||||||
|
|
||||||
|
ASSERT_NE(nullptr, z);
|
||||||
|
|
||||||
|
//val exp = Nd4j.create(2, 2).assign(15.0);
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {15, 15, 15, 15});
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
#endif
|
|
@ -0,0 +1,351 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <helpers/ConstantShapeHelper.h>
|
||||||
|
#include <array/ShapeDescriptor.h>
|
||||||
|
#include <array/ConstantDataBuffer.h>
|
||||||
|
#include <helpers/PointersManager.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::ops;
|
||||||
|
using namespace sd::graph;
|
||||||
|
|
||||||
|
class ConstantShapeHelperTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
class ConstantHelperTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
class ConstantTadHelperTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(ConstantShapeHelperTests, test_cachedAmount_1) {
|
||||||
|
auto ttlBefore = ConstantShapeHelper::getInstance().totalCachedEntries();
|
||||||
|
|
||||||
|
auto arrayA = NDArrayFactory::create<bool>('c', {7, 11, 17, 23, 31, 43});
|
||||||
|
|
||||||
|
auto ttlMiddle = ConstantShapeHelper::getInstance().totalCachedEntries();
|
||||||
|
|
||||||
|
auto arrayB = NDArrayFactory::create<bool>('c', {7, 11, 17, 23, 31, 43});
|
||||||
|
|
||||||
|
auto ttlAfter = ConstantShapeHelper::getInstance().totalCachedEntries();
|
||||||
|
|
||||||
|
ASSERT_TRUE(ttlBefore <= ttlMiddle);
|
||||||
|
ASSERT_EQ(ttlMiddle, ttlAfter);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ConstantTadHelperTests, test_cachedAmount_1) {
|
||||||
|
auto arrayA = NDArrayFactory::create<bool>('c', {7, 11, 17, 23, 31, 43});
|
||||||
|
auto ttlBefore = ConstantTadHelper::getInstance().totalCachedEntries();
|
||||||
|
|
||||||
|
auto packAA = ConstantTadHelper::getInstance().tadForDimensions(arrayA.shapeInfo(), {3, 4});
|
||||||
|
|
||||||
|
auto ttlMiddle = ConstantTadHelper::getInstance().totalCachedEntries();
|
||||||
|
|
||||||
|
auto packAB = ConstantTadHelper::getInstance().tadForDimensions(arrayA.shapeInfo(), {3, 4});
|
||||||
|
|
||||||
|
auto ttlAfter = ConstantTadHelper::getInstance().totalCachedEntries();
|
||||||
|
|
||||||
|
ASSERT_TRUE(ttlBefore <= ttlMiddle);
|
||||||
|
ASSERT_EQ(ttlMiddle, ttlAfter);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ConstantShapeHelperTests, basic_test_1) {
|
||||||
|
auto ptr = ShapeBuilders::createShapeInfo(sd::DataType::BFLOAT16, 'f', {5, 10, 15});
|
||||||
|
ShapeDescriptor descriptor(ptr);
|
||||||
|
ShapeDescriptor descriptor2(ptr);
|
||||||
|
|
||||||
|
ASSERT_EQ(descriptor, descriptor2);
|
||||||
|
|
||||||
|
ASSERT_EQ(1, descriptor.ews());
|
||||||
|
ASSERT_EQ(3, descriptor.rank());
|
||||||
|
ASSERT_EQ('f', descriptor.order());
|
||||||
|
ASSERT_EQ(sd::DataType::BFLOAT16, descriptor.dataType());
|
||||||
|
ASSERT_FALSE(descriptor.isEmpty());
|
||||||
|
|
||||||
|
ASSERT_FALSE(ConstantShapeHelper::getInstance().checkBufferExistenceForShapeInfo(descriptor));
|
||||||
|
|
||||||
|
auto buffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor);
|
||||||
|
|
||||||
|
ASSERT_TRUE(ConstantShapeHelper::getInstance().checkBufferExistenceForShapeInfo(descriptor));
|
||||||
|
|
||||||
|
auto buffer2 = ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor2);
|
||||||
|
|
||||||
|
|
||||||
|
ASSERT_TRUE(buffer.primary() != nullptr);
|
||||||
|
ASSERT_TRUE(buffer.primary() == buffer2.primary());
|
||||||
|
ASSERT_TRUE(buffer.special() == buffer2.special());
|
||||||
|
|
||||||
|
delete []ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ConstantShapeHelperTests, stress_test_1) {
|
||||||
|
|
||||||
|
for (auto x = 0; x < 1000; x++) {
|
||||||
|
auto ptr = ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', {5, x + 10, x + 1});
|
||||||
|
ShapeDescriptor descriptor(ptr);
|
||||||
|
ConstantShapeHelper::getInstance().createShapeInfo(descriptor);
|
||||||
|
delete [] ptr;
|
||||||
|
}
|
||||||
|
ShapeDescriptor aShape(sd::DataType::FLOAT32, 'c', {(Nd4jLong)5, (Nd4jLong)382, (Nd4jLong)373});
|
||||||
|
// nd4j_printf("%d\n", ConstantShapeHelper::getInstance().cachedEntriesForDevice(0));
|
||||||
|
|
||||||
|
auto timeStart = std::chrono::system_clock::now();
|
||||||
|
ASSERT_TRUE(ConstantShapeHelper::getInstance().checkBufferExistenceForShapeInfo(aShape));
|
||||||
|
auto timeEnd = std::chrono::system_clock::now();
|
||||||
|
|
||||||
|
auto outerTime = std::chrono::duration_cast<std::chrono::nanoseconds>(timeEnd - timeStart).count();
|
||||||
|
nd4j_printf("Total time (us) %lld\n", outerTime);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ConstantShapeHelperTests, basic_test_3) {
|
||||||
|
auto array = NDArrayFactory::create_<float>('c', {128});
|
||||||
|
|
||||||
|
ASSERT_TRUE(array->shapeInfo() != nullptr);
|
||||||
|
|
||||||
|
#ifdef __CUDABLAS__
|
||||||
|
ASSERT_TRUE(array->specialShapeInfo() != nullptr);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
delete array;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(ConstantShapeHelperTests, basic_test_4) {
|
||||||
|
auto array = NDArrayFactory::create_<float>('c', {128, 256});
|
||||||
|
|
||||||
|
auto dup = new NDArray(array->dup('f'));
|
||||||
|
|
||||||
|
ASSERT_TRUE(dup->shapeInfo() != nullptr);
|
||||||
|
|
||||||
|
#ifdef __CUDABLAS__
|
||||||
|
ASSERT_TRUE(dup->specialShapeInfo() != nullptr);
|
||||||
|
PointersManager manager(sd::LaunchContext ::defaultContext(), "test");
|
||||||
|
// manager.printDevContentOnDev<Nd4jLong>(dup->special(), shape::shapeInfoLength(2), 0);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
delete array;
|
||||||
|
delete dup;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(ConstantShapeHelperTests, basic_test_5) {
|
||||||
|
|
||||||
|
auto arrayA = NDArrayFactory::create<int>(1);
|
||||||
|
auto arrayB = NDArrayFactory::create_<float>('c', {128, 256});
|
||||||
|
|
||||||
|
//arrayA.printShapeInfo("A");
|
||||||
|
//arrayB->printShapeInfo("B");
|
||||||
|
ASSERT_EQ(0, arrayA.rankOf());
|
||||||
|
ASSERT_EQ(2, arrayB->rankOf());
|
||||||
|
ASSERT_NE(arrayA.dataType(), arrayB->dataType());
|
||||||
|
|
||||||
|
delete arrayB;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ConstantShapeHelperTests, basic_test_6) {
|
||||||
|
ShapeDescriptor descriptorA(sd::DataType::INT32, 'c', {});
|
||||||
|
ShapeDescriptor descriptorB(sd::DataType::FLOAT32, 'c', {10, 10});
|
||||||
|
|
||||||
|
// ASSERT_FALSE(descriptorA < descriptorB);
|
||||||
|
// ASSERT_TRUE(descriptorB < descriptorA);
|
||||||
|
|
||||||
|
ASSERT_TRUE(descriptorA < descriptorB);
|
||||||
|
ASSERT_FALSE(descriptorB < descriptorA);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ConstantShapeHelperTests, basic_test_7) {
|
||||||
|
auto array = NDArrayFactory::create_<float>('c', {32, 256});
|
||||||
|
|
||||||
|
IndicesList indices({NDIndex::all(), NDIndex::interval(0,1)});
|
||||||
|
auto strided = array->subarray(indices);
|
||||||
|
strided.assign(1.0f);
|
||||||
|
|
||||||
|
//strided->printIndexedBuffer("column");
|
||||||
|
|
||||||
|
delete array;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ConstantHelperTests, basic_test_1) {
|
||||||
|
|
||||||
|
ConstantDescriptor descriptor({1, 2, 3});
|
||||||
|
|
||||||
|
ConstantDataBuffer* fBuffer = ConstantHelper::getInstance().constantBuffer(descriptor, sd::DataType::FLOAT32);
|
||||||
|
auto fPtr = fBuffer->primaryAsT<float>();
|
||||||
|
|
||||||
|
ASSERT_NEAR(1.f, fPtr[0], 1e-5);
|
||||||
|
ASSERT_NEAR(2.f, fPtr[1], 1e-5);
|
||||||
|
ASSERT_NEAR(3.f, fPtr[2], 1e-5);
|
||||||
|
|
||||||
|
auto iBuffer = ConstantHelper::getInstance().constantBuffer(descriptor, sd::DataType::INT32);
|
||||||
|
auto iPtr = iBuffer->primaryAsT<int>();
|
||||||
|
|
||||||
|
ASSERT_EQ(1, iPtr[0]);
|
||||||
|
ASSERT_EQ(2, iPtr[1]);
|
||||||
|
ASSERT_EQ(3, iPtr[2]);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ConstantHelperTests, basic_test_2) {
|
||||||
|
|
||||||
|
double array[] = {1., 2., 3.};
|
||||||
|
ConstantDescriptor descriptor(array, 3);
|
||||||
|
|
||||||
|
ConstantDataBuffer* fBuffer = ConstantHelper::getInstance().constantBuffer(descriptor, sd::DataType::FLOAT32);
|
||||||
|
auto fPtr = fBuffer->primaryAsT<float>();
|
||||||
|
|
||||||
|
ASSERT_NEAR(1.f, fPtr[0], 1e-5);
|
||||||
|
ASSERT_NEAR(2.f, fPtr[1], 1e-5);
|
||||||
|
ASSERT_NEAR(3.f, fPtr[2], 1e-5);
|
||||||
|
|
||||||
|
auto iBuffer = ConstantHelper::getInstance().constantBuffer(descriptor, sd::DataType::INT32);
|
||||||
|
auto iPtr = iBuffer->primaryAsT<int>();
|
||||||
|
|
||||||
|
ASSERT_EQ(1, iPtr[0]);
|
||||||
|
ASSERT_EQ(2, iPtr[1]);
|
||||||
|
ASSERT_EQ(3, iPtr[2]);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(ConstantShapeHelperTests, ShapeDescriptor_1) {
|
||||||
|
|
||||||
|
Nd4jLong shapeInfo1[] = {4, 2, 5, 5, 2, 25, 5, 1, 50, 8192, 0, 99};
|
||||||
|
Nd4jLong shapeInfo2[] = {4, 2, 5, 5, 2, 50, 10, 2, 1, 8192, 1, 99};
|
||||||
|
|
||||||
|
ShapeDescriptor descr1(shapeInfo1);
|
||||||
|
ShapeDescriptor descr2(shapeInfo2);
|
||||||
|
|
||||||
|
ASSERT_FALSE(descr1 == descr2);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ConstantShapeHelperTests, ShapeDescriptor_validation) {
|
||||||
|
|
||||||
|
//for c order
|
||||||
|
std::vector<Nd4jLong> shape{ 2,3,4,5 };
|
||||||
|
std::vector<Nd4jLong> incorrectStride1{ 20,20,5,1 };
|
||||||
|
std::vector<Nd4jLong> incorrectStride2{ 60,20,5,5 };
|
||||||
|
std::vector<Nd4jLong> correctStride1{ 60,20,5,1 };
|
||||||
|
std::vector<Nd4jLong> correctStride2{ 300,100,25,5 };
|
||||||
|
std::vector<Nd4jLong> correctStride3{ 800, 200, 40, 5 };
|
||||||
|
|
||||||
|
auto shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'c', shape, incorrectStride1, 1);
|
||||||
|
ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_INCORRECT_STRIDES);
|
||||||
|
shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'c', shape, correctStride1, 1);
|
||||||
|
ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_OK);
|
||||||
|
shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'c', shape, incorrectStride2, 1);
|
||||||
|
ASSERT_TRUE(shapeDesc.validate() == (SHAPE_DESC_INCORRECT_STRIDES | SHAPE_DESC_INCORRECT_EWS));
|
||||||
|
shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'c', shape, correctStride2, 1);
|
||||||
|
ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_INCORRECT_EWS);
|
||||||
|
shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'c', shape, correctStride2, 5);
|
||||||
|
ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_OK);
|
||||||
|
shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'c', shape, correctStride3, 1);
|
||||||
|
ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_INCORRECT_EWS);
|
||||||
|
shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'c', shape, correctStride3, 0);
|
||||||
|
ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_OK);
|
||||||
|
|
||||||
|
//order f
|
||||||
|
std::reverse(std::begin(shape), std::end(shape));
|
||||||
|
std::reverse(std::begin(incorrectStride1), std::end(incorrectStride1));
|
||||||
|
std::reverse(std::begin(incorrectStride2), std::end(incorrectStride2));
|
||||||
|
std::reverse(std::begin(correctStride1), std::end(correctStride1));
|
||||||
|
std::reverse(std::begin(correctStride2), std::end(correctStride2));
|
||||||
|
std::reverse(std::begin(correctStride3), std::end(correctStride3));
|
||||||
|
|
||||||
|
shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'f', shape, incorrectStride1, 1);
|
||||||
|
ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_INCORRECT_STRIDES);
|
||||||
|
shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'f', shape, correctStride1, 1);
|
||||||
|
ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_OK);
|
||||||
|
shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'f', shape, incorrectStride2, 1);
|
||||||
|
ASSERT_TRUE(shapeDesc.validate() == (SHAPE_DESC_INCORRECT_STRIDES | SHAPE_DESC_INCORRECT_EWS));
|
||||||
|
shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'f', shape, correctStride2, 1);
|
||||||
|
ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_INCORRECT_EWS);
|
||||||
|
shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'f', shape, correctStride2, 5);
|
||||||
|
ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_OK);
|
||||||
|
shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'f', shape, correctStride3, 1);
|
||||||
|
ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_INCORRECT_EWS);
|
||||||
|
shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'f', shape, correctStride3, 0);
|
||||||
|
ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_OK);
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> shape1;
|
||||||
|
shape1.resize(MAX_RANK+1);
|
||||||
|
shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'f', shape1, correctStride3, 0);
|
||||||
|
ASSERT_TRUE( (shapeDesc.validate() & SHAPE_DESC_INCORRECT_RANK) == SHAPE_DESC_INCORRECT_RANK);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ConstantShapeHelperTests, ShapeDescriptor_paddedBuffer) {
|
||||||
|
|
||||||
|
constexpr int n = 2;
|
||||||
|
constexpr int c = 3;
|
||||||
|
constexpr int h = 4;
|
||||||
|
constexpr int w = 5;
|
||||||
|
constexpr int n_pad = 2;
|
||||||
|
constexpr int c_pad = 3;
|
||||||
|
constexpr int h_pad = 4;
|
||||||
|
constexpr int w_pad = 5;
|
||||||
|
char orders[] = { 'c', 'f' };
|
||||||
|
|
||||||
|
for (auto& order : orders) {
|
||||||
|
auto shapeDesc1 = ShapeDescriptor::paddedBufferDescriptor(DataType::FLOAT32, order, { n, c, h, w }, { n_pad, c_pad, h_pad, w_pad });
|
||||||
|
auto shapeDesc2 = ShapeDescriptor(DataType::FLOAT32, order, { n + n_pad, c + c_pad, h + h_pad, w + w_pad });
|
||||||
|
auto shapeDesc3 = ShapeDescriptor::paddedBufferDescriptor(DataType::FLOAT32, order, { n, c, h, w }, { n_pad, c_pad });
|
||||||
|
auto shapeDesc4 = ShapeDescriptor(DataType::FLOAT32, order, { n + n_pad, c + c_pad, h, w });
|
||||||
|
auto shapeDesc5 = ShapeDescriptor::paddedBufferDescriptor(DataType::FLOAT32, order, { n, c, h, w }, { 0, 0, h_pad, w_pad });
|
||||||
|
auto shapeDesc6 = ShapeDescriptor(DataType::FLOAT32, order, { n, c , h + h_pad, w + w_pad });
|
||||||
|
|
||||||
|
ASSERT_TRUE(shapeDesc1.validate() == SHAPE_DESC_OK);
|
||||||
|
ASSERT_TRUE(shapeDesc2.validate() == SHAPE_DESC_OK);
|
||||||
|
ASSERT_TRUE(shapeDesc3.validate() == SHAPE_DESC_OK);
|
||||||
|
ASSERT_TRUE(shapeDesc4.validate() == SHAPE_DESC_OK);
|
||||||
|
ASSERT_TRUE(shapeDesc5.validate() == SHAPE_DESC_OK);
|
||||||
|
ASSERT_TRUE(shapeDesc6.validate() == SHAPE_DESC_OK);
|
||||||
|
|
||||||
|
ASSERT_TRUE(shapeDesc1.allocLength() == shapeDesc2.allocLength());
|
||||||
|
ASSERT_TRUE(shapeDesc3.allocLength() == shapeDesc4.allocLength());
|
||||||
|
ASSERT_TRUE(shapeDesc5.allocLength() == shapeDesc6.allocLength());
|
||||||
|
|
||||||
|
const auto& v1 = shapeDesc1.strides();
|
||||||
|
const auto& v2 = shapeDesc2.strides();
|
||||||
|
const auto& v3 = shapeDesc3.strides();
|
||||||
|
const auto& v4 = shapeDesc4.strides();
|
||||||
|
const auto& v5 = shapeDesc5.strides();
|
||||||
|
const auto& v6 = shapeDesc6.strides();
|
||||||
|
|
||||||
|
for (int i = 0; i < v1.size(); i++) {
|
||||||
|
ASSERT_TRUE(v1[i] == v2[i]);
|
||||||
|
}
|
||||||
|
for (int i = 0; i < v3.size(); i++) {
|
||||||
|
ASSERT_TRUE(v3[i] == v4[i]);
|
||||||
|
}
|
||||||
|
for (int i = 0; i < v5.size(); i++) {
|
||||||
|
ASSERT_TRUE(v5[i] == v6[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,358 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver119 on 30.10.2017.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::ops;
|
||||||
|
using namespace sd::graph;
|
||||||
|
|
||||||
|
class ContextTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(ContextTests, Basic_Test_1) {
|
||||||
|
VariableSpace variableSpace;
|
||||||
|
|
||||||
|
auto _20 = NDArrayFactory::create_<float>('c', {2, 2});
|
||||||
|
auto _21 = NDArrayFactory::create_<float>('c', {2, 2});
|
||||||
|
|
||||||
|
_20->assign(1.0f);
|
||||||
|
_21->assign(2.0f);
|
||||||
|
|
||||||
|
variableSpace.putVariable(2, 0, _20);
|
||||||
|
variableSpace.putVariable(2, 1, _21);
|
||||||
|
|
||||||
|
Context block(1, &variableSpace);
|
||||||
|
|
||||||
|
block.pickInput(2, 0);
|
||||||
|
block.pickInput(2, 1);
|
||||||
|
|
||||||
|
ASSERT_EQ(2, block.inputs()->size());
|
||||||
|
ASSERT_EQ(2, block.width());
|
||||||
|
|
||||||
|
ASSERT_TRUE(variableSpace.hasVariable(2, 0));
|
||||||
|
ASSERT_TRUE(variableSpace.hasVariable(2, 1));
|
||||||
|
|
||||||
|
ASSERT_NEAR(1.0f, block.variable(0)->getNDArray()->meanNumber().e<float>(0), 1e-5);
|
||||||
|
ASSERT_NEAR(2.0f, block.variable(1)->getNDArray()->meanNumber().e<float>(0), 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(ContextTests, Basic_Test_2) {
|
||||||
|
VariableSpace variableSpace;
|
||||||
|
|
||||||
|
auto _20 = NDArrayFactory::create_<float>('c', {2, 2});
|
||||||
|
auto _21 = NDArrayFactory::create_<float>('c', {2, 2});
|
||||||
|
|
||||||
|
_20->assign(1.0f);
|
||||||
|
_21->assign(2.0f);
|
||||||
|
|
||||||
|
variableSpace.putVariable(-1, _20);
|
||||||
|
variableSpace.putVariable(-2, _21);
|
||||||
|
|
||||||
|
Context block(1, &variableSpace);
|
||||||
|
|
||||||
|
block.pickInput(-1);
|
||||||
|
block.pickInput(-2);
|
||||||
|
|
||||||
|
ASSERT_EQ(2, block.inputs()->size());
|
||||||
|
ASSERT_EQ(2, block.width());
|
||||||
|
|
||||||
|
ASSERT_TRUE(variableSpace.hasVariable(-1));
|
||||||
|
ASSERT_TRUE(variableSpace.hasVariable(-2));
|
||||||
|
|
||||||
|
ASSERT_NEAR(1.0f, block.variable(0)->getNDArray()->meanNumber().e<float>(0), 1e-5);
|
||||||
|
ASSERT_NEAR(2.0f, block.variable(1)->getNDArray()->meanNumber().e<float>(0), 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(ContextTests, Basic_Test_3) {
|
||||||
|
VariableSpace variableSpace;
|
||||||
|
|
||||||
|
Context ctx(1, &variableSpace);
|
||||||
|
|
||||||
|
auto _20 = NDArrayFactory::create_<float>('c', {2, 2});
|
||||||
|
|
||||||
|
ctx.pushNDArrayToVariableSpace(1, 1, _20);
|
||||||
|
|
||||||
|
ASSERT_TRUE(variableSpace.hasVariable(1, 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(ContextTests, Basic_Test_4) {
|
||||||
|
VariableSpace variableSpace;
|
||||||
|
|
||||||
|
Context ctx(1, &variableSpace);
|
||||||
|
|
||||||
|
auto _20 = NDArrayFactory::create_<float>('c', {2, 2});
|
||||||
|
_20->linspace(1);
|
||||||
|
|
||||||
|
auto _21 = NDArrayFactory::create_<float>('c', {2, 2});
|
||||||
|
_21->linspace(10);
|
||||||
|
|
||||||
|
ctx.pushNDArrayToVariableSpace(1, 1, _20);
|
||||||
|
|
||||||
|
ASSERT_TRUE(variableSpace.hasVariable(1, 1));
|
||||||
|
|
||||||
|
ctx.pushNDArrayToVariableSpace(1, 1, _21);
|
||||||
|
|
||||||
|
auto vA = ctx.variable(1, 1);
|
||||||
|
|
||||||
|
ASSERT_TRUE(vA->getNDArray()->equalsTo(_21));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ContextTests, Basic_Test_5) {
|
||||||
|
VariableSpace variableSpace;
|
||||||
|
|
||||||
|
Context ctx(1, &variableSpace);
|
||||||
|
|
||||||
|
auto _20 = NDArrayFactory::create_<float>('c', {2, 2});
|
||||||
|
_20->linspace(1);
|
||||||
|
|
||||||
|
auto exp = new NDArray(_20->dup());
|
||||||
|
|
||||||
|
ctx.pushNDArrayToVariableSpace(1, 1, _20);
|
||||||
|
|
||||||
|
ASSERT_TRUE(variableSpace.hasVariable(1, 1));
|
||||||
|
|
||||||
|
ctx.pushNDArrayToVariableSpace(1, 1, _20);
|
||||||
|
|
||||||
|
auto vA = ctx.variable(1, 1);
|
||||||
|
|
||||||
|
ASSERT_TRUE(vA->getNDArray() == _20);
|
||||||
|
|
||||||
|
ASSERT_TRUE(vA->getNDArray()->equalsTo(exp));
|
||||||
|
|
||||||
|
delete exp;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(ContextTests, Basic_Test_6) {
|
||||||
|
VariableSpace variableSpace;
|
||||||
|
|
||||||
|
Context ctx(1, &variableSpace);
|
||||||
|
|
||||||
|
auto v0 = ctx.ensureVariable();
|
||||||
|
auto v1 = ctx.ensureVariable(1);
|
||||||
|
|
||||||
|
ASSERT_TRUE(variableSpace.hasVariable(1, 0));
|
||||||
|
ASSERT_TRUE(variableSpace.hasVariable(1, 1));
|
||||||
|
|
||||||
|
auto var0 = variableSpace.getVariable(1, 0);
|
||||||
|
auto var1 = variableSpace.getVariable(1, 1);
|
||||||
|
|
||||||
|
ASSERT_TRUE(v0 == var0);
|
||||||
|
ASSERT_TRUE(v1 == var1);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(ContextTests, Basic_Test_7) {
|
||||||
|
VariableSpace variableSpace;
|
||||||
|
|
||||||
|
Context ctx(1, &variableSpace);
|
||||||
|
|
||||||
|
auto v0 = ctx.ensureVariable();
|
||||||
|
auto v1 = ctx.ensureVariable(1);
|
||||||
|
|
||||||
|
ASSERT_TRUE(variableSpace.hasVariable(1, 0));
|
||||||
|
ASSERT_TRUE(variableSpace.hasVariable(1, 1));
|
||||||
|
|
||||||
|
auto var0 = variableSpace.getVariable(1, 0);
|
||||||
|
auto var1 = variableSpace.getVariable(1, 1);
|
||||||
|
|
||||||
|
ASSERT_TRUE(v0 == var0);
|
||||||
|
ASSERT_TRUE(v1 == var1);
|
||||||
|
|
||||||
|
|
||||||
|
auto _10 = NDArrayFactory::create_<float>('c', {2, 2});
|
||||||
|
_10->linspace(1);
|
||||||
|
|
||||||
|
auto _11 = NDArrayFactory::create_<float>('c', {2, 2});
|
||||||
|
_11->linspace(10);
|
||||||
|
|
||||||
|
ctx.pushNDArrayToVariableSpace(1, 0, _10);
|
||||||
|
ctx.pushNDArrayToVariableSpace(1, 1, _11);
|
||||||
|
|
||||||
|
auto z0 = variableSpace.getVariable(1, 0);
|
||||||
|
auto z1 = variableSpace.getVariable(1, 1);
|
||||||
|
|
||||||
|
ASSERT_TRUE(v0 == z0);
|
||||||
|
ASSERT_TRUE(v1 == z1);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ContextTests, Basic_Test_8) {
|
||||||
|
VariableSpace variableSpace;
|
||||||
|
|
||||||
|
Context ctx(1, &variableSpace);
|
||||||
|
|
||||||
|
auto _10 = NDArrayFactory::create_<float>('c', {2, 2});
|
||||||
|
_10->linspace(1);
|
||||||
|
|
||||||
|
auto _11 = NDArrayFactory::create_<float>('c', {2, 2});
|
||||||
|
_11->linspace(10);
|
||||||
|
|
||||||
|
ctx.pushNDArrayToVariableSpace(1, 0, _10);
|
||||||
|
ctx.pushNDArrayToVariableSpace(1, 1, _11);
|
||||||
|
|
||||||
|
auto z0 = variableSpace.getVariable(1, 0);
|
||||||
|
auto z1 = variableSpace.getVariable(1, 1);
|
||||||
|
|
||||||
|
auto v0 = ctx.ensureVariable();
|
||||||
|
auto v1 = ctx.ensureVariable(1);
|
||||||
|
|
||||||
|
ASSERT_TRUE(v0 == z0);
|
||||||
|
ASSERT_TRUE(v1 == z1);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(ContextTests, Basic_Test_9) {
|
||||||
|
VariableSpace variableSpace;
|
||||||
|
|
||||||
|
auto in = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
|
||||||
|
Context ctx(1, &variableSpace, true);
|
||||||
|
ctx.pushNDArrayToVariableSpace(1, 1, &in, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ContextTests, Basic_Test_10) {
|
||||||
|
VariableSpace variableSpace;
|
||||||
|
|
||||||
|
Context ctx(119, &variableSpace);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(ContextTests, Prototype_Test_1) {
|
||||||
|
ContextPrototype prototype(nullptr, 119, true);
|
||||||
|
prototype.pickInput(12, 3);
|
||||||
|
prototype.pickInput(12, 4);
|
||||||
|
|
||||||
|
prototype.getTArguments()->push_back(2.0);
|
||||||
|
prototype.getTArguments()->push_back(-2.0);
|
||||||
|
|
||||||
|
prototype.getIArguments()->push_back(17);
|
||||||
|
prototype.getIArguments()->push_back(119);
|
||||||
|
|
||||||
|
Context ctx(&prototype, nullptr);
|
||||||
|
|
||||||
|
ASSERT_EQ(ctx.nodeId(), prototype.nodeId());
|
||||||
|
ASSERT_EQ(ctx.isInplace(), prototype.isInplace());
|
||||||
|
|
||||||
|
ASSERT_EQ(2, ctx.inputs()->size());
|
||||||
|
ASSERT_EQ(2, ctx.getTArguments()->size());
|
||||||
|
ASSERT_EQ(2, ctx.getIArguments()->size());
|
||||||
|
|
||||||
|
ASSERT_EQ(2.0, ctx.getTArguments()->at(0));
|
||||||
|
ASSERT_EQ(-2.0, ctx.getTArguments()->at(1));
|
||||||
|
|
||||||
|
ASSERT_EQ(17, ctx.getIArguments()->at(0));
|
||||||
|
ASSERT_EQ(119, ctx.getIArguments()->at(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(ContextTests, Prototype_Test_2) {
|
||||||
|
ContextPrototype prototype(nullptr, 119, false);
|
||||||
|
prototype.setOpNum(179);
|
||||||
|
|
||||||
|
Context ctx(&prototype, nullptr);
|
||||||
|
|
||||||
|
ASSERT_EQ(ctx.isInplace(), prototype.isInplace());
|
||||||
|
ASSERT_EQ(ctx.opNum(), prototype.opNum());
|
||||||
|
|
||||||
|
ASSERT_EQ(0, ctx.inputs()->size());
|
||||||
|
ASSERT_EQ(0, ctx.getTArguments()->size());
|
||||||
|
ASSERT_EQ(0, ctx.getIArguments()->size());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ContextTests, test_short_context_1) {
|
||||||
|
auto array0 = NDArrayFactory::create<float>('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
|
||||||
|
auto array1 = NDArrayFactory::create<float>('c', {3, 2}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f});
|
||||||
|
Context ctx(1);
|
||||||
|
|
||||||
|
ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.specialBuffer(), array0.specialShapeInfo());
|
||||||
|
ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.specialBuffer(), array1.specialShapeInfo());
|
||||||
|
|
||||||
|
ASSERT_EQ(2, ctx.width());
|
||||||
|
|
||||||
|
auto input0 = ctx.array(0);
|
||||||
|
ASSERT_TRUE(input0 != nullptr);
|
||||||
|
|
||||||
|
auto input1 = ctx.array(1);
|
||||||
|
ASSERT_TRUE(input1 != nullptr);
|
||||||
|
|
||||||
|
ASSERT_TRUE(input0->buffer() == array0.buffer());
|
||||||
|
ASSERT_TRUE(input0->shapeInfo() == array0.shapeInfo());
|
||||||
|
|
||||||
|
ASSERT_TRUE(input0->specialBuffer() == array0.specialBuffer());
|
||||||
|
ASSERT_TRUE(input0->specialShapeInfo() == array0.specialShapeInfo());
|
||||||
|
|
||||||
|
ASSERT_TRUE(input1->buffer() == array1.buffer());
|
||||||
|
ASSERT_TRUE(input1->shapeInfo() == array1.shapeInfo());
|
||||||
|
|
||||||
|
ASSERT_TRUE(input1->specialBuffer() == array1.specialBuffer());
|
||||||
|
ASSERT_TRUE(input1->specialShapeInfo() == array1.specialShapeInfo());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ContextTests, test_short_context_2) {
|
||||||
|
auto array0 = NDArrayFactory::create<float>('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
|
||||||
|
auto array1 = NDArrayFactory::create<float>('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {3, 2});
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3, 2}, {2.f, 4.f, 6.f, 8.f, 10.f, 12.f});
|
||||||
|
Context ctx(1);
|
||||||
|
|
||||||
|
ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.specialBuffer(), array0.specialShapeInfo());
|
||||||
|
ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.specialBuffer(), array1.specialShapeInfo());
|
||||||
|
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
|
||||||
|
|
||||||
|
ASSERT_EQ(2, ctx.width());
|
||||||
|
|
||||||
|
sd::ops::add op;
|
||||||
|
op.execute(&ctx);
|
||||||
|
|
||||||
|
ASSERT_EQ(exp, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ContextTests, test_short_context_3) {
|
||||||
|
auto array0 = NDArrayFactory::create<float>('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
|
||||||
|
auto array1 = NDArrayFactory::create<float>('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3, 2}, {2.f, 4.f, 6.f, 8.f, 10.f, 12.f});
|
||||||
|
Context ctx(1);
|
||||||
|
|
||||||
|
ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.specialBuffer(), array0.specialShapeInfo());
|
||||||
|
ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.specialBuffer(), array1.specialShapeInfo());
|
||||||
|
|
||||||
|
ASSERT_EQ(2, ctx.width());
|
||||||
|
|
||||||
|
sd::ops::add op;
|
||||||
|
op.execute(&ctx);
|
||||||
|
|
||||||
|
ASSERT_EQ(1, ctx.fastpath_out().size());
|
||||||
|
|
||||||
|
auto z = ctx.fastpath_out()[0];
|
||||||
|
|
||||||
|
ASSERT_EQ(exp, *z);
|
||||||
|
}
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,150 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <initializer_list>
|
||||||
|
#include <array/NDArrayFactory.h>
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <execution/Engine.h>
|
||||||
|
|
||||||
|
#ifdef HAVE_CUDNN
|
||||||
|
|
||||||
|
#include <ops/declarable/platform/cudnn/cudnnUtils.h>
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
|
||||||
|
class CuDnnTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
static void printer(std::initializer_list<sd::ops::platforms::PlatformHelper*> helpers) {
|
||||||
|
|
||||||
|
for (auto v:helpers) {
|
||||||
|
nd4j_printf("Initialized [%s]\n", v->name().c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(CuDnnTests, helpers_includer) {
|
||||||
|
// we need this block, to make sure all helpers are still available within binary, and not optimized out by linker
|
||||||
|
#ifdef HAVE_CUDNN
|
||||||
|
sd::ops::platforms::PLATFORM_conv2d_ENGINE_CUDA conv2d;
|
||||||
|
sd::ops::platforms::PLATFORM_conv2d_bp_ENGINE_CUDA conv2d_bp;
|
||||||
|
sd::ops::platforms::PLATFORM_conv3dnew_ENGINE_CUDA conv3dnew;
|
||||||
|
sd::ops::platforms::PLATFORM_conv3dnew_bp_ENGINE_CUDA conv3dnew_bp;
|
||||||
|
sd::ops::platforms::PLATFORM_depthwise_conv2d_ENGINE_CUDA depthwise_conv2d;
|
||||||
|
sd::ops::platforms::PLATFORM_depthwise_conv2d_bp_ENGINE_CUDA depthwise_conv2d_bp;
|
||||||
|
sd::ops::platforms::PLATFORM_batchnorm_ENGINE_CUDA batchnorm;
|
||||||
|
sd::ops::platforms::PLATFORM_batchnorm_bp_ENGINE_CUDA batchnorm_bp;
|
||||||
|
sd::ops::platforms::PLATFORM_avgpool2d_ENGINE_CUDA avgpool2d;
|
||||||
|
sd::ops::platforms::PLATFORM_avgpool2d_bp_ENGINE_CUDA avgpool2d_bp;
|
||||||
|
sd::ops::platforms::PLATFORM_maxpool2d_ENGINE_CUDA maxpool2d;
|
||||||
|
sd::ops::platforms::PLATFORM_maxpool2d_bp_ENGINE_CUDA maxpool2d_bp;
|
||||||
|
sd::ops::platforms::PLATFORM_avgpool3dnew_ENGINE_CUDA avgpool3dnew;
|
||||||
|
sd::ops::platforms::PLATFORM_avgpool3dnew_bp_ENGINE_CUDA avgpool3dnew_bp;
|
||||||
|
sd::ops::platforms::PLATFORM_maxpool3dnew_ENGINE_CUDA maxpool3dnew;
|
||||||
|
sd::ops::platforms::PLATFORM_maxpool3dnew_bp_ENGINE_CUDA maxpool3dnew_bp;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
printer({&conv2d});
|
||||||
|
printer({&conv2d_bp});
|
||||||
|
printer({&conv3dnew});
|
||||||
|
printer({&conv3dnew_bp});
|
||||||
|
printer({&depthwise_conv2d});
|
||||||
|
printer({&depthwise_conv2d_bp});
|
||||||
|
printer({&batchnorm});
|
||||||
|
printer({&batchnorm_bp});
|
||||||
|
printer({&avgpool2d});
|
||||||
|
printer({&avgpool2d_bp});
|
||||||
|
printer({&maxpool2d});
|
||||||
|
printer({&maxpool2d_bp});
|
||||||
|
printer({&avgpool3dnew});
|
||||||
|
printer({&avgpool3dnew_bp});
|
||||||
|
printer({&maxpool3dnew});
|
||||||
|
printer({&maxpool3dnew_bp});
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(CuDnnTests, mixed_helpers_test_1) {
|
||||||
|
#if defined(HAVE_CUDNN) && defined (HAVE_MKLDNN)
|
||||||
|
nd4j_printf("Mixed platforms test\n", "");
|
||||||
|
|
||||||
|
|
||||||
|
int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
|
||||||
|
int oH=2,oW=2;
|
||||||
|
int paddingMode = 0; // 1-SAME, 0-VALID;
|
||||||
|
int dataFormat = 0; // 1-NHWC, 0-NCHW
|
||||||
|
|
||||||
|
auto input = NDArrayFactory::create<float>('c', {bS, iC, iH, iW});
|
||||||
|
auto weights = NDArrayFactory::create<float>('c', {oC, iC, kH, kW});
|
||||||
|
auto bias = NDArrayFactory::create<float>('c', {oC}, {1,2,3});
|
||||||
|
|
||||||
|
auto expOutput = NDArrayFactory::create<float>('c', {bS, oC, oH, oW}, {61.f, 61.f, 61.f, 61.f, 177.2f, 177.2f, 177.2f, 177.2f, 293.4f, 293.4f, 293.4f, 293.4f, 61.f, 61.f, 61.f, 61.f, 177.2f, 177.2f, 177.2f, 177.2f, 293.4f, 293.4f, 293.4f, 293.4f});
|
||||||
|
auto zCUDA = expOutput.like();
|
||||||
|
auto zMKL = expOutput.like();
|
||||||
|
|
||||||
|
input = 2.;
|
||||||
|
weights.linspace(0.1, 0.1);
|
||||||
|
weights.permutei({2,3,1,0});
|
||||||
|
|
||||||
|
input.syncToHost();
|
||||||
|
weights.syncToHost();
|
||||||
|
bias.syncToHost();
|
||||||
|
|
||||||
|
sd::ops::conv2d op;
|
||||||
|
|
||||||
|
// cuDNN part
|
||||||
|
Context cuda(1);
|
||||||
|
cuda.setTargetEngine(samediff::Engine::ENGINE_CUDA);
|
||||||
|
cuda.setInputArray(0, &input);
|
||||||
|
cuda.setInputArray(1, &weights);
|
||||||
|
cuda.setInputArray(2, &bias);
|
||||||
|
cuda.setOutputArray(0, &zCUDA);
|
||||||
|
cuda.setIArguments({kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
|
||||||
|
auto statusCUDA = op.execute(&cuda);
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), statusCUDA);
|
||||||
|
ASSERT_EQ(expOutput, zCUDA);
|
||||||
|
|
||||||
|
// MKL-DNN part
|
||||||
|
Context mkl(1);
|
||||||
|
mkl.setTargetEngine(samediff::Engine::ENGINE_CPU);
|
||||||
|
mkl.setInputArray(0, &input);
|
||||||
|
mkl.setInputArray(1, &weights);
|
||||||
|
mkl.setInputArray(2, &bias);
|
||||||
|
mkl.setOutputArray(0, &zMKL);
|
||||||
|
mkl.setIArguments({kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
|
||||||
|
auto statusMKL = op.execute(&mkl);
|
||||||
|
|
||||||
|
zMKL.tickWriteHost();
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), statusMKL);
|
||||||
|
ASSERT_EQ(expOutput, zMKL);
|
||||||
|
#endif
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,76 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <array/ExtraArguments.h>
|
||||||
|
#include <array>
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
|
||||||
|
class CudaExtraArgumentsTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
CudaExtraArgumentsTests() {
|
||||||
|
printf("\n");
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(CudaExtraArgumentsTests, Basic_Test_1) {
|
||||||
|
ExtraArguments args({1.0, 2.0, 3.0});
|
||||||
|
|
||||||
|
float ef[] = {1.f, 2.f, 3.f};
|
||||||
|
double ed[] = {1., 2., 3.};
|
||||||
|
|
||||||
|
auto ptrFloat = reinterpret_cast<float *>(args.argumentsAsT<float>());
|
||||||
|
auto ptrDouble = reinterpret_cast<double *>(args.argumentsAsT<double>());
|
||||||
|
ASSERT_TRUE(ptrFloat != nullptr);
|
||||||
|
ASSERT_TRUE(ptrDouble != nullptr);
|
||||||
|
|
||||||
|
auto tmpFloat = new float[3];
|
||||||
|
auto tmpDouble = new double[3];
|
||||||
|
|
||||||
|
cudaMemcpy(tmpFloat, ptrFloat, 3 * sizeof(float), cudaMemcpyDeviceToHost);
|
||||||
|
cudaMemcpy(tmpDouble, ptrDouble, 3 * sizeof(double), cudaMemcpyDeviceToHost);
|
||||||
|
|
||||||
|
for (int e = 0; e < 3; e++) {
|
||||||
|
ASSERT_NEAR(ef[e], tmpFloat[e], 1e-5f);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int e = 0; e < 3; e++) {
|
||||||
|
ASSERT_NEAR(ed[e], tmpDouble[e], 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
|
delete[] tmpFloat;
|
||||||
|
delete[] tmpDouble;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(CudaExtraArgumentsTests, Basic_Test_2) {
|
||||||
|
ExtraArguments args;
|
||||||
|
|
||||||
|
auto ptrInt = args.argumentsAsT<int>();
|
||||||
|
ASSERT_TRUE(ptrInt == nullptr);
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver on 11/26/2018.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <helpers/CudaLaunchHelper.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::graph;
|
||||||
|
|
||||||
|
class CudaLaunchHelperTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(CudaLaunchHelperTests, test_reduction_blocks_1) {
|
||||||
|
ASSERT_EQ(1, CudaLaunchHelper::getReductionBlocks(512));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(CudaLaunchHelperTests, test_reduction_blocks_2) {
|
||||||
|
ASSERT_EQ(1, CudaLaunchHelper::getReductionBlocks(121));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(CudaLaunchHelperTests, test_reduction_blocks_3) {
|
||||||
|
ASSERT_EQ(2, CudaLaunchHelper::getReductionBlocks(513));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(CudaLaunchHelperTests, test_reduction_blocks_4) {
|
||||||
|
ASSERT_EQ(3, CudaLaunchHelper::getReductionBlocks(1225));
|
||||||
|
}
|
|
@ -0,0 +1,80 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <graph/Context.h>
|
||||||
|
#include <graph/Node.h>
|
||||||
|
#include <graph/Variable.h>
|
||||||
|
#include <graph/VariableSpace.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
#include <ops/declarable/helpers/col2im.h>
|
||||||
|
#include <helpers/RandomLauncher.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::graph;
|
||||||
|
using namespace sd::memory;
|
||||||
|
|
||||||
|
class DataBufferTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(DataBufferTests, test_alloc_limit_1) {
|
||||||
|
if (!Environment::getInstance().isCPU())
|
||||||
|
return;
|
||||||
|
|
||||||
|
auto deviceId = AffinityManager::currentDeviceId();
|
||||||
|
auto odLimit = MemoryCounter::getInstance().deviceLimit(deviceId);
|
||||||
|
auto ogLimit = MemoryCounter::getInstance().groupLimit(MemoryType::HOST);
|
||||||
|
auto odUse = MemoryCounter::getInstance().allocatedDevice(deviceId);
|
||||||
|
auto ogUse = MemoryCounter::getInstance().allocatedGroup(MemoryType::HOST);
|
||||||
|
|
||||||
|
auto limitSize = odUse + (150 * 1024 * 1024);
|
||||||
|
auto allocSize = 100000000;
|
||||||
|
|
||||||
|
MemoryCounter::getInstance().setDeviceLimit(deviceId, odLimit + limitSize);
|
||||||
|
MemoryCounter::getInstance().setGroupLimit(MemoryType::HOST, odLimit + limitSize);
|
||||||
|
|
||||||
|
DataBuffer buffer(allocSize, DataType::INT32);
|
||||||
|
|
||||||
|
// separately testing per-device limits and group limits
|
||||||
|
ASSERT_EQ(odUse + allocSize, MemoryCounter::getInstance().allocatedDevice(deviceId));
|
||||||
|
ASSERT_EQ(ogUse + allocSize, MemoryCounter::getInstance().allocatedGroup(MemoryType::HOST));
|
||||||
|
|
||||||
|
|
||||||
|
// setting smaller limits, to make sure next allocation fails with OOM exception
|
||||||
|
MemoryCounter::getInstance().setDeviceLimit(deviceId, allocSize - 100);
|
||||||
|
MemoryCounter::getInstance().setGroupLimit(MemoryType::HOST, allocSize - 100);
|
||||||
|
|
||||||
|
try {
|
||||||
|
DataBuffer bufferFailed(allocSize, DataType::INT32);
|
||||||
|
ASSERT_TRUE(false);
|
||||||
|
} catch (allocation_exception &e) {
|
||||||
|
// we expect exception here
|
||||||
|
}
|
||||||
|
|
||||||
|
// restore original limits, so subsequent tests do not fail
|
||||||
|
MemoryCounter::getInstance().setDeviceLimit(deviceId, odLimit);
|
||||||
|
MemoryCounter::getInstance().setGroupLimit(MemoryType::HOST, odLimit);
|
||||||
|
}
|
|
@ -0,0 +1,91 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <graph/Context.h>
|
||||||
|
#include <graph/Node.h>
|
||||||
|
#include <graph/Variable.h>
|
||||||
|
#include <graph/VariableSpace.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
#include <ops/declarable/helpers/col2im.h>
|
||||||
|
#include <helpers/RandomLauncher.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::graph;
|
||||||
|
using namespace sd::memory;
|
||||||
|
|
||||||
|
class DataBufferTestsCuda : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
TEST_F(DataBufferTestsCuda, test_alloc_limit_1) {
|
||||||
|
auto deviceId = AffinityManager::currentDeviceId();
|
||||||
|
|
||||||
|
auto odLimit = MemoryCounter::getInstance().deviceLimit(deviceId);
|
||||||
|
|
||||||
|
auto opLimit = MemoryCounter::getInstance().groupLimit(MemoryType::HOST);
|
||||||
|
auto osLimit = MemoryCounter::getInstance().groupLimit(MemoryType::DEVICE);
|
||||||
|
|
||||||
|
auto odUse = MemoryCounter::getInstance().allocatedDevice(deviceId);
|
||||||
|
|
||||||
|
auto opUse = MemoryCounter::getInstance().allocatedGroup(MemoryType::HOST);
|
||||||
|
auto osUse = MemoryCounter::getInstance().allocatedGroup(MemoryType::DEVICE);
|
||||||
|
|
||||||
|
auto limitSize = odUse + 150000000;
|
||||||
|
auto allocSize = 100000000;
|
||||||
|
|
||||||
|
MemoryCounter::getInstance().setDeviceLimit(deviceId, odLimit + limitSize);
|
||||||
|
MemoryCounter::getInstance().setGroupLimit(MemoryType::HOST, opLimit + limitSize);
|
||||||
|
MemoryCounter::getInstance().setGroupLimit(MemoryType::DEVICE, osLimit + limitSize);
|
||||||
|
|
||||||
|
DataBuffer buffer(allocSize, DataType::INT32, nullptr, true);
|
||||||
|
|
||||||
|
// separately testing per-device limits and group limits
|
||||||
|
ASSERT_EQ(odUse + allocSize, MemoryCounter::getInstance().allocatedDevice(deviceId));
|
||||||
|
ASSERT_EQ(opUse + allocSize, MemoryCounter::getInstance().allocatedGroup(MemoryType::HOST));
|
||||||
|
ASSERT_EQ(osUse + allocSize, MemoryCounter::getInstance().allocatedGroup(MemoryType::DEVICE));
|
||||||
|
|
||||||
|
// setting smaller limits, to make sure next allocation fails with OOM exception
|
||||||
|
MemoryCounter::getInstance().setDeviceLimit(deviceId, allocSize - 100);
|
||||||
|
MemoryCounter::getInstance().setGroupLimit(MemoryType::DEVICE, allocSize - 100);
|
||||||
|
|
||||||
|
|
||||||
|
// this allocation should fail, since we're allocating too much
|
||||||
|
try {
|
||||||
|
DataBuffer bufferFailed(allocSize + 1, DataType::INT32);
|
||||||
|
ASSERT_TRUE(false);
|
||||||
|
} catch (allocation_exception &e) {
|
||||||
|
// we expect exception here
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
|
||||||
|
// restore original limits, so subsequent tests do not fail
|
||||||
|
MemoryCounter::getInstance().setDeviceLimit(deviceId, odLimit);
|
||||||
|
MemoryCounter::getInstance().setGroupLimit(MemoryType::HOST, opLimit);
|
||||||
|
MemoryCounter::getInstance().setGroupLimit(MemoryType::DEVICE, osLimit);
|
||||||
|
}
|
||||||
|
*/
|
|
@ -0,0 +1,158 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <graph/Context.h>
|
||||||
|
#include <graph/Node.h>
|
||||||
|
#include <graph/Variable.h>
|
||||||
|
#include <graph/VariableSpace.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
#include <ops/declarable/helpers/col2im.h>
|
||||||
|
#include <helpers/RandomLauncher.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::graph;
|
||||||
|
|
||||||
|
class DataTypesValidationTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(DataTypesValidationTests, Basic_Test_1) {
|
||||||
|
auto input = NDArrayFactory::create<int8_t>('c', {1, 1, 1, 4});
|
||||||
|
auto weights = NDArrayFactory::create<int8_t>('c', {1, 1, 1, 4});
|
||||||
|
auto exp = NDArrayFactory::create<double>('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.});
|
||||||
|
|
||||||
|
weights.assign(2.0);
|
||||||
|
input.linspace(1);
|
||||||
|
|
||||||
|
sd::ops::conv2d op;
|
||||||
|
auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_VALIDATION, result.status());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DataTypesValidationTests, Basic_Test_2) {
|
||||||
|
auto input = NDArrayFactory::create<float16>('c', {1, 1, 1, 4});
|
||||||
|
auto weights = NDArrayFactory::create<float16>('c', {1, 1, 1, 4});
|
||||||
|
auto exp = NDArrayFactory::create<float16>('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.});
|
||||||
|
|
||||||
|
weights.assign(2.0);
|
||||||
|
input.linspace(1);
|
||||||
|
|
||||||
|
sd::ops::conv2d op;
|
||||||
|
auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(DataTypesValidationTests, Basic_Test_3) {
|
||||||
|
auto input = NDArrayFactory::create<bfloat16>('c', {1, 1, 1, 4});
|
||||||
|
auto weights = NDArrayFactory::create<bfloat16>('c', {1, 1, 1, 4});
|
||||||
|
auto exp = NDArrayFactory::create<bfloat16>('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.});
|
||||||
|
auto out = NDArrayFactory::create<bfloat16>('c', {1, 4, 1, 4});
|
||||||
|
|
||||||
|
weights.assign(2.0);
|
||||||
|
input.linspace(1);
|
||||||
|
|
||||||
|
sd::ops::conv2d op;
|
||||||
|
auto result = op.execute({&input, &weights}, {&out}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result);
|
||||||
|
|
||||||
|
ASSERT_EQ(exp, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DataTypesValidationTests, Basic_Test_4) {
|
||||||
|
auto input = NDArrayFactory::create<int8_t>('c', {1, 1, 1, 4});
|
||||||
|
auto weights = NDArrayFactory::create<float16>('c', {1, 1, 1, 4});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.});
|
||||||
|
auto out = NDArrayFactory::create<int8_t>('c', {1, 4, 1, 4});
|
||||||
|
|
||||||
|
weights.assign(2.0);
|
||||||
|
input.linspace(1);
|
||||||
|
|
||||||
|
sd::ops::conv2d op;
|
||||||
|
auto result = op.execute({&input, &weights}, {&out}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_VALIDATION, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DataTypesValidationTests, test_bfloat16_rand_1) {
|
||||||
|
auto x = NDArrayFactory::create<bfloat16>('c', {5, 10});
|
||||||
|
RandomGenerator gen(119, 120);
|
||||||
|
RandomLauncher::fillUniform(LaunchContext::defaultContext(), gen, &x, 1, 6);
|
||||||
|
|
||||||
|
ASSERT_TRUE(x.sumNumber().e<float>(0) != 0.f);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DataTypesValidationTests, test_bfloat16_rand_2) {
|
||||||
|
auto x = NDArrayFactory::create<bfloat16>('c', {5, 10});
|
||||||
|
RandomGenerator gen(119, 120);
|
||||||
|
RandomLauncher::fillGaussian(LaunchContext::defaultContext(), gen, &x, 0, 1);
|
||||||
|
|
||||||
|
ASSERT_TRUE(x.sumNumber().e<float>(0) != 0.f);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DataTypesValidationTests, cast_1) {
|
||||||
|
|
||||||
|
float16 x = static_cast<float16>(1.f);
|
||||||
|
float y = static_cast<float16>(x);
|
||||||
|
|
||||||
|
ASSERT_TRUE(static_cast<float16>(1.f) == x);
|
||||||
|
ASSERT_TRUE(y == static_cast<float>(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DataTypesValidationTests, test_bits_hamming_distance_1) {
|
||||||
|
auto x = NDArrayFactory::create<int>('c', {3}, {0b01011000, 0b01011111, 0b01111110});
|
||||||
|
auto y = NDArrayFactory::create<int>('c', {3}, {0b00010110, 0b01011000, 0b01011000});
|
||||||
|
auto z = NDArrayFactory::create<uint64_t>(0);
|
||||||
|
|
||||||
|
Context ctx(1);
|
||||||
|
ctx.setInputArray(0, &x);
|
||||||
|
ctx.setInputArray(1, &y);
|
||||||
|
ctx.setOutputArray(0, &z);
|
||||||
|
|
||||||
|
sd::ops::bits_hamming_distance op;
|
||||||
|
auto status = op.execute(&ctx);
|
||||||
|
ASSERT_NE(Status::OK(), status);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DataTypesValidationTests, test_bits_hamming_distance_2) {
|
||||||
|
auto x = NDArrayFactory::create<int>('c', {3}, {0b01011000, 0b01011111, 0b01111110});
|
||||||
|
auto y = NDArrayFactory::create<int>('c', {3}, {0b00010110, 0b01011000, 0b01011000});
|
||||||
|
auto z = NDArrayFactory::create<Nd4jLong>(0);
|
||||||
|
|
||||||
|
Context ctx(1);
|
||||||
|
ctx.setInputArray(0, &x);
|
||||||
|
ctx.setInputArray(1, &y);
|
||||||
|
ctx.setOutputArray(0, &z);
|
||||||
|
|
||||||
|
sd::ops::bits_hamming_distance op;
|
||||||
|
auto status = op.execute(&ctx);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,94 @@
|
||||||
|
/*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <ops/ops.h>
|
||||||
|
#include <helpers/GradCheck.h>
|
||||||
|
#include <array>
|
||||||
|
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
|
||||||
|
|
||||||
|
class DeclarableOpsTests17 : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
DeclarableOpsTests17() {
|
||||||
|
printf("\n");
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests17, test_sparse_to_dense_1) {
|
||||||
|
auto values = NDArrayFactory::create<float>({1.f, 2.f, 3.f});
|
||||||
|
auto shape = NDArrayFactory::create<Nd4jLong>({3, 3});
|
||||||
|
auto ranges = NDArrayFactory::create<Nd4jLong>({0,0, 1,1, 2,2});
|
||||||
|
auto def = NDArrayFactory::create<float>(0.f);
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {1.f,0.f,0.f, 0.f,2.f,0.f, 0.f,0.f,3.f});
|
||||||
|
|
||||||
|
|
||||||
|
sd::ops::compat_sparse_to_dense op;
|
||||||
|
auto result = op.evaluate({&ranges, &shape, &values, &def});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) {
|
||||||
|
auto values = NDArrayFactory::string({3}, {"alpha", "beta", "gamma"});
|
||||||
|
auto shape = NDArrayFactory::create<Nd4jLong>({3, 3});
|
||||||
|
auto ranges = NDArrayFactory::create<Nd4jLong>({0,0, 1,1, 2,2});
|
||||||
|
auto def = NDArrayFactory::string("d");
|
||||||
|
auto exp = NDArrayFactory::string( {3, 3}, {"alpha","d","d", "d","beta","d", "d","d","gamma"});
|
||||||
|
|
||||||
|
|
||||||
|
sd::ops::compat_sparse_to_dense op;
|
||||||
|
auto result = op.evaluate({&ranges, &shape, &values, &def});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests17, test_compat_string_split_1) {
|
||||||
|
auto x = NDArrayFactory::string( {2}, {"first string", "second"});
|
||||||
|
auto delimiter = NDArrayFactory::string(" ");
|
||||||
|
|
||||||
|
auto exp0 = NDArrayFactory::create<Nd4jLong>({0,0, 0,1, 1,0});
|
||||||
|
auto exp1 = NDArrayFactory::string( {3}, {"first", "string", "second"});
|
||||||
|
|
||||||
|
sd::ops::compat_string_split op;
|
||||||
|
auto result = op.evaluate({&x, &delimiter});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
ASSERT_EQ(2, result.size());
|
||||||
|
|
||||||
|
auto z0 = result.at(0);
|
||||||
|
auto z1 = result.at(1);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp0.isSameShape(z0));
|
||||||
|
ASSERT_TRUE(exp1.isSameShape(z1));
|
||||||
|
|
||||||
|
ASSERT_EQ(exp0, *z0);
|
||||||
|
ASSERT_EQ(exp1, *z1);
|
||||||
|
|
||||||
|
}
|
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,427 @@
|
||||||
|
/*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <ops/ops.h>
|
||||||
|
#include <helpers/GradCheck.h>
|
||||||
|
#include <array>
|
||||||
|
#include <helpers/RandomLauncher.h>
|
||||||
|
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
|
||||||
|
|
||||||
|
class DeclarableOpsTests19 : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
DeclarableOpsTests19() {
|
||||||
|
printf("\n");
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests19, test_argmax_maxint_vector_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {3}, {0.1f, 0.5f, 0.7f});
|
||||||
|
auto z = NDArrayFactory::create<Nd4jLong>(0);
|
||||||
|
auto e = NDArrayFactory::create<Nd4jLong>(2);
|
||||||
|
|
||||||
|
sd::ops::argmax op;
|
||||||
|
auto status = op.execute({&x}, {&z}, {DataTypeUtils::max<int>()});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests19, test_threshold_encode_1) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {3}, {1.5, 2.5, -3.5});
|
||||||
|
auto exp_encoded = NDArrayFactory::create<int>('c', {7}, {3, 3, 1056964608, 0, 1, 2, -3});
|
||||||
|
auto exp_gradients = NDArrayFactory::create<double>('c', {3}, {1.0, 2.0, -3.0});
|
||||||
|
|
||||||
|
sd::ops::encode_threshold op;
|
||||||
|
auto result = op.evaluate({&x}, {0.5});
|
||||||
|
|
||||||
|
auto gradients = result.at(0);
|
||||||
|
auto encoded = result.at(1);
|
||||||
|
|
||||||
|
//encoded->printIndexedBuffer("ENC");
|
||||||
|
|
||||||
|
ASSERT_EQ(exp_encoded, *encoded);
|
||||||
|
ASSERT_EQ(exp_gradients, x);
|
||||||
|
|
||||||
|
// FIXME: we need to add a way to declare individual inplace outputs
|
||||||
|
//ASSERT_EQ(exp_gradients, *gradients);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests19, test_threshold_encode_2) {
|
||||||
|
for (int length = 5; length < 35; length++) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {10000});
|
||||||
|
auto exp_gradients = NDArrayFactory::create<double>('c', {10000});
|
||||||
|
|
||||||
|
for (int e = 0; e < length; e++) {
|
||||||
|
x.p(e, 2e-3);
|
||||||
|
exp_gradients.p(e, 1e-3);
|
||||||
|
}
|
||||||
|
|
||||||
|
sd::ops::encode_threshold op;
|
||||||
|
auto result = op.evaluate({&x}, {1e-3});
|
||||||
|
|
||||||
|
auto encoded = result.at(1);
|
||||||
|
|
||||||
|
ASSERT_EQ(length + 4, encoded->lengthOf());
|
||||||
|
ASSERT_EQ(exp_gradients, x);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests19, test_threshold_encode_boundary_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {6});
|
||||||
|
x = 1.0f;
|
||||||
|
|
||||||
|
sd::ops::encode_threshold op;
|
||||||
|
auto result = op.evaluate({&x}, {1.0}, {3});
|
||||||
|
|
||||||
|
auto gradients = result.at(0);
|
||||||
|
auto encoded = result.at(1);
|
||||||
|
|
||||||
|
ASSERT_EQ(7, encoded->lengthOf());
|
||||||
|
ASSERT_EQ(3, x.sumNumber().e<int>(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests19, test_threshold_encode_boundary_2) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1000});
|
||||||
|
x = 1.0f;
|
||||||
|
|
||||||
|
sd::ops::encode_threshold op;
|
||||||
|
auto result = op.evaluate({&x}, {1.0}, {100});
|
||||||
|
|
||||||
|
auto gradients = result.at(0);
|
||||||
|
auto encoded = result.at(1);
|
||||||
|
|
||||||
|
ASSERT_EQ(104, encoded->lengthOf());
|
||||||
|
|
||||||
|
ASSERT_EQ(900, x.sumNumber().e<int>(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests19, test_threshold_decode_1) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {3}, {1.0, 2.0, -3.0});
|
||||||
|
auto y = NDArrayFactory::create<int>('c', {7}, {3, 3, 1056964608, 0, 1, 2, -3});
|
||||||
|
auto exp_gradients = NDArrayFactory::create<double>('c', {3}, {1.5, 2.5, -3.5});
|
||||||
|
|
||||||
|
sd::ops::decode_threshold op;
|
||||||
|
auto status = op.execute({&x, &y}, {&x});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
ASSERT_EQ(exp_gradients, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests19, test_bitmap_encode_1) {
|
||||||
|
auto initial = NDArrayFactory::create<float>('c', {6}, {0.0f, 0.0f, 1e-3f, -1e-3f, 0.0f, 0.0f});
|
||||||
|
auto exp_0 = initial.like();
|
||||||
|
auto exp_1 = initial.dup();
|
||||||
|
auto exp_c = NDArrayFactory::create<int>(2L);
|
||||||
|
|
||||||
|
sd::ops::encode_bitmap enc;
|
||||||
|
auto enc_result = enc.evaluate({&initial}, {1e-3f});
|
||||||
|
ASSERT_EQ(Status::OK(), enc_result.status());
|
||||||
|
|
||||||
|
//initial.printIndexedBuffer("initial");
|
||||||
|
ASSERT_EQ(exp_0, initial);
|
||||||
|
|
||||||
|
auto encoded = enc_result.at(1);
|
||||||
|
auto counter = enc_result.at(2);
|
||||||
|
|
||||||
|
//encoded->printIndexedBuffer("encoded");
|
||||||
|
|
||||||
|
ASSERT_EQ(exp_c, *counter);
|
||||||
|
|
||||||
|
sd::ops::decode_bitmap dec;
|
||||||
|
auto status = dec.execute({&initial, encoded}, {&initial});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
|
||||||
|
//initial.printIndexedBuffer();
|
||||||
|
|
||||||
|
ASSERT_EQ(exp_1, initial);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests19, test_bitmap_encode_decode) {
|
||||||
|
auto initial = NDArrayFactory::create<float>('c', {256000});
|
||||||
|
initial = 1.0f;
|
||||||
|
auto exp = initial.dup();
|
||||||
|
auto neg = initial.like();
|
||||||
|
neg = 0.5f;
|
||||||
|
|
||||||
|
sd::ops::encode_bitmap enc;
|
||||||
|
auto enc_result = enc.evaluate({&initial}, {0.5f});
|
||||||
|
auto encoded = enc_result.at(1);
|
||||||
|
|
||||||
|
// checking equality of all encoded bits
|
||||||
|
for (int e = 5; e < encoded->lengthOf() - 1; e++) {
|
||||||
|
if (encoded->e<int>(e) != encoded->e<int>(e - 1))
|
||||||
|
nd4j_printf("Non equal encoded values at E[%i]: %i;\n", e, encoded->e<int>(e));
|
||||||
|
}
|
||||||
|
|
||||||
|
ASSERT_NE(exp, initial);
|
||||||
|
ASSERT_EQ(neg, initial);
|
||||||
|
|
||||||
|
sd::ops::decode_bitmap dec;
|
||||||
|
auto status = dec.execute({&initial, encoded}, {&initial});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
// checking equality of all dedoded bits
|
||||||
|
for (int e = 0; e < initial.lengthOf(); e++) {
|
||||||
|
auto f = initial.e<float>(e);
|
||||||
|
if (f != 1.0f)
|
||||||
|
nd4j_printf("initial[%i] = %f\n", e, f);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
ASSERT_EQ(exp, initial);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests19, test_threshold_encode_decode) {
|
||||||
|
auto initial = NDArrayFactory::create<float>('c', {256000});
|
||||||
|
initial = 1.0f;
|
||||||
|
auto exp = initial.dup();
|
||||||
|
auto neg = initial.like();
|
||||||
|
neg = 0.5f;
|
||||||
|
|
||||||
|
sd::ops::encode_threshold enc;
|
||||||
|
auto enc_result = enc.evaluate({&initial}, {0.5f});
|
||||||
|
auto encoded = enc_result.at(1);
|
||||||
|
|
||||||
|
ASSERT_EQ(256000 + 4, encoded->lengthOf());
|
||||||
|
ASSERT_NE(exp, initial);
|
||||||
|
|
||||||
|
for (int e = 0; e < initial.lengthOf(); e++) {
|
||||||
|
auto f = initial.e<float>(e);
|
||||||
|
if (f != 0.5f) {
|
||||||
|
nd4j_printf("initial[%i] = %f\n", e, f);
|
||||||
|
throw std::runtime_error("");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ASSERT_EQ(neg, initial);
|
||||||
|
|
||||||
|
// checking equality of all encoded bits
|
||||||
|
//for (int e = 5; e < encoded->lengthOf() - 1; e++) {
|
||||||
|
//if (encoded->e<int>(e) != encoded->e<int>(e - 1) + 1)
|
||||||
|
//nd4j_printf("Non equal encoded values at E[%i]: %i;\n", e, encoded->e<int>(e));
|
||||||
|
//}
|
||||||
|
|
||||||
|
sd::ops::decode_threshold dec;
|
||||||
|
auto status = dec.execute({&initial, encoded}, {&initial});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
// checking equality of all dedoded bits
|
||||||
|
for (int e = 0; e < initial.lengthOf(); e++) {
|
||||||
|
auto f = initial.e<float>(e);
|
||||||
|
if (f != 1.0f)
|
||||||
|
nd4j_printf("initial[%i] = %f\n", e, f);
|
||||||
|
}
|
||||||
|
|
||||||
|
ASSERT_EQ(exp, initial);
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef _RELEASE
|
||||||
|
TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) {
|
||||||
|
// [2,1,135079944,1,1,8192,1,99]
|
||||||
|
constexpr int sizeX= 10*1000*1000;
|
||||||
|
auto initial = NDArrayFactory::create<float>('c', {1, sizeX});
|
||||||
|
initial = 1.0f;
|
||||||
|
auto exp = initial.dup();
|
||||||
|
auto neg = initial.like();
|
||||||
|
neg = 0.5f;
|
||||||
|
|
||||||
|
sd::ops::encode_threshold enc;
|
||||||
|
auto enc_result = enc.evaluate({&initial}, {0.5f});
|
||||||
|
auto encoded = enc_result.at(1);
|
||||||
|
|
||||||
|
ASSERT_EQ(sizeX + 4, encoded->lengthOf());
|
||||||
|
ASSERT_NE(exp, initial);
|
||||||
|
/*
|
||||||
|
for (int e = 0; e < initial.lengthOf(); e++) {
|
||||||
|
auto f = initial.e<float>(e);
|
||||||
|
if (f != 0.5f) {
|
||||||
|
nd4j_printf("initial[%i] = %f\n", e, f);
|
||||||
|
throw std::runtime_error("");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
ASSERT_EQ(neg, initial);
|
||||||
|
|
||||||
|
// checking equality of all encoded bits
|
||||||
|
//for (int e = 5; e < encoded->lengthOf() - 1; e++) {
|
||||||
|
//if (encoded->e<int>(e) != encoded->e<int>(e - 1) + 1)
|
||||||
|
//nd4j_printf("Non equal encoded values at E[%i]: %i;\n", e, encoded->e<int>(e));
|
||||||
|
//}
|
||||||
|
|
||||||
|
sd::ops::decode_threshold dec;
|
||||||
|
auto status = dec.execute({&initial, encoded}, {&initial});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
// checking equality of all dedoded bits
|
||||||
|
/*
|
||||||
|
for (int e = 0; e < initial.lengthOf(); e++) {
|
||||||
|
auto f = initial.e<float>(e);
|
||||||
|
if (f != 1.0f)
|
||||||
|
nd4j_printf("initial[%i] = %f\n", e, f);
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
ASSERT_EQ(exp, initial);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests19, test_matmul_ccc) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
|
||||||
|
z.assign(100.f);
|
||||||
|
e.assign(110.f);
|
||||||
|
x.assign(1.0f);
|
||||||
|
y.assign(1.0f);
|
||||||
|
|
||||||
|
sd::ops::matmul op;
|
||||||
|
auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests19, test_matmul_fcf) {
|
||||||
|
auto x = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
auto e = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
auto z = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
|
||||||
|
z.assign(100.f);
|
||||||
|
e.assign(110.f);
|
||||||
|
x.assign(1.0f);
|
||||||
|
y.assign(1.0f);
|
||||||
|
|
||||||
|
sd::ops::matmul op;
|
||||||
|
auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests19, test_matmul_cff) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
auto y = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
auto e = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
auto z = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
|
||||||
|
z.assign(100.f);
|
||||||
|
e.assign(110.f);
|
||||||
|
x.assign(1.0f);
|
||||||
|
y.assign(1.0f);
|
||||||
|
|
||||||
|
sd::ops::matmul op;
|
||||||
|
auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests19, test_matmul_ccf) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
auto e = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
auto z = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
|
||||||
|
z.assign(100.f);
|
||||||
|
e.assign(110.f);
|
||||||
|
x.assign(1.0f);
|
||||||
|
y.assign(1.0f);
|
||||||
|
|
||||||
|
sd::ops::matmul op;
|
||||||
|
auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests19, test_matmul_fff) {
|
||||||
|
auto x = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
auto y = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
auto e = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
auto z = NDArrayFactory::create<float>('f', {10, 10});
|
||||||
|
|
||||||
|
z.assign(100.f);
|
||||||
|
e.assign(110.f);
|
||||||
|
x.assign(1.0f);
|
||||||
|
y.assign(1.0f);
|
||||||
|
|
||||||
|
sd::ops::matmul op;
|
||||||
|
auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests19, test_conv1d_bp_1) {
|
||||||
|
/*
|
||||||
|
DynamicCustomOp op = DynamicCustomOp.builder("conv1d_bp")
|
||||||
|
.addInputs(
|
||||||
|
Nd4j.create(DataType.FLOAT, 2,2,12),
|
||||||
|
Nd4j.create(DataType.FLOAT, 3,2,3),
|
||||||
|
Nd4j.create(DataType.FLOAT, 2,3,6)
|
||||||
|
)
|
||||||
|
.addOutputs(
|
||||||
|
Nd4j.create(DataType.FLOAT, 2,2,12),
|
||||||
|
Nd4j.create(DataType.FLOAT, 3,2,3))
|
||||||
|
.addIntegerArguments(3,2,0,1,2,0)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
Nd4j.exec(op);
|
||||||
|
*/
|
||||||
|
|
||||||
|
auto t = NDArrayFactory::create<float>('c', {2, 2, 12});
|
||||||
|
auto u = NDArrayFactory::create<float>('c', {3, 2, 3});
|
||||||
|
auto v = NDArrayFactory::create<float>('c', {2, 3, 6});
|
||||||
|
|
||||||
|
sd::ops::conv1d_bp op;
|
||||||
|
auto result = op.evaluate({&t, &u, &v}, {3, 2, 0, 1, 2,0});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests19, test_squeeze_1) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {3, 4, 1});
|
||||||
|
auto e = NDArrayFactory::create<double>('c', {3, 4});
|
||||||
|
int axis = 2;
|
||||||
|
|
||||||
|
sd::ops::squeeze op;
|
||||||
|
auto status = op.execute({&x}, {&e}, {axis});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
}
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,78 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <ops/ops.h>
|
||||||
|
#include <helpers/GradCheck.h>
|
||||||
|
#include <chrono>
|
||||||
|
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
|
||||||
|
|
||||||
|
class DeclarableOpsTestsCuda1 : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
DeclarableOpsTestsCuda1() {
|
||||||
|
printf("\n");
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTestsCuda1, Test_CHOOSE_SCALAR_LARGE) {
|
||||||
|
double inputData[150] = {
|
||||||
|
0, 0.51, 0.68, 0.69, 0.86, 0.91, 0.96, 0.97, 0.97, 1.03, 1.13, 1.16, 1.16, 1.17, 1.19, 1.25, 1.25, 1.26, 1.27, 1.28, 1.29, 1.29, 1.29, 1.30, 1.31, 1.32, 1.33, 1.33, 1.35, 1.35, 1.36, 1.37, 1.38, 1.40, 1.41, 1.42, 1.43, 1.44, 1.44, 1.45, 1.45, 1.47, 1.47, 1.51, 1.51, 1.51, 1.52, 1.53, 1.56, 1.57, 1.58, 1.59, 1.61, 1.62, 1.63, 1.63, 1.64, 1.64, 1.66, 1.66, 1.67, 1.67, 1.70, 1.70, 1.70, 1.72, 1.72, 1.72, 1.72, 1.73, 1.74, 1.74, 1.76, 1.76, 1.77, 1.77, 1.80, 1.80, 1.81, 1.82, 1.83, 1.83, 1.84, 1.84, 1.84, 1.85, 1.85, 1.85, 1.86, 1.86, 1.87, 1.88, 1.89, 1.89, 1.89, 1.89, 1.89, 1.91, 1.91, 1.91, 1.92, 1.94, 1.95, 1.97, 1.98, 1.98, 1.98, 1.98, 1.98, 1.99, 2, 2, 2.01, 2.01, 2.02, 2.03, 2.03, 2.03, 2.04, 2.04, 2.05, 2.06, 2.07, 2.08, 2.08, 2.08, 2.08, 2.09, 2.09, 2.10, 2.10, 2.11, 2.11, 2.11, 2.12, 2.12, 2.13, 2.13, 2.14, 2.14, 2.14, 2.14, 2.15, 2.15, 2.16, 2.16, 2.16, 2.16, 2.16, 2.17
|
||||||
|
};
|
||||||
|
|
||||||
|
auto precursor = NDArrayFactory::create<double>(inputData,'c',{1,149});
|
||||||
|
NDArray x(nullptr, precursor.specialBuffer(), precursor.shapeInfo());
|
||||||
|
|
||||||
|
sd::ops::choose op;
|
||||||
|
//greater than test
|
||||||
|
auto result = op.evaluate({&x}, {0.0},{3});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(1);
|
||||||
|
|
||||||
|
ASSERT_EQ(148,z->e<double>(0));
|
||||||
|
//ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
TEST_F(DeclarableOpsTestsCuda1, Test_Reverse_TAD_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 3, 608, 608});
|
||||||
|
auto z = x.like();
|
||||||
|
x.linspace(1.0f);
|
||||||
|
|
||||||
|
sd::ops::reverse op;
|
||||||
|
auto timeStart = std::chrono::system_clock::now();
|
||||||
|
auto status = op.execute({&x}, {&z}, {}, {1}, {});
|
||||||
|
auto timeEnd = std::chrono::system_clock::now();
|
||||||
|
auto outerTime = std::chrono::duration_cast<std::chrono::microseconds> (timeEnd - timeStart).count();
|
||||||
|
nd4j_printf("exec time: %lld us\n", outerTime);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
}
|
||||||
|
*/
|
|
@ -0,0 +1,256 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver on 6/18/2018.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
// #include <array/NDArrayList.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
EmptyTests() {
|
||||||
|
printf("\n");
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, Test_Create_Empty_1) {
|
||||||
|
auto empty = NDArrayFactory::empty_<float>();
|
||||||
|
ASSERT_TRUE(empty->isEmpty());
|
||||||
|
|
||||||
|
ASSERT_EQ(0, empty->lengthOf());
|
||||||
|
ASSERT_TRUE(empty->buffer() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_TRUE(shape::isEmpty(empty->shapeInfo()));
|
||||||
|
|
||||||
|
delete empty;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, Test_Create_Empty_2) {
|
||||||
|
auto empty = NDArrayFactory::empty<float>();
|
||||||
|
ASSERT_TRUE(empty.isEmpty());
|
||||||
|
|
||||||
|
ASSERT_EQ(0, empty.lengthOf());
|
||||||
|
ASSERT_TRUE(empty.buffer() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_TRUE(shape::isEmpty(empty.shapeInfo()));
|
||||||
|
ASSERT_TRUE(empty.isEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, Test_Concat_1) {
|
||||||
|
// auto empty = NDArrayFactory::empty_<float>();
|
||||||
|
auto empty = new NDArray('c', {0}, sd::DataType::FLOAT32);//NDArrayFactory::create_<float>('c', {(Nd4jLong)0}};
|
||||||
|
auto vector = NDArrayFactory::create_<float>('c', {1}, {1.0f});
|
||||||
|
|
||||||
|
ASSERT_TRUE(empty->isEmpty());
|
||||||
|
|
||||||
|
sd::ops::concat op;
|
||||||
|
auto result = op.evaluate({empty, vector}, {}, {0});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
// z->printShapeInfo("z shape");
|
||||||
|
// z->printIndexedBuffer("z buffr");
|
||||||
|
|
||||||
|
ASSERT_EQ(*vector, *z);
|
||||||
|
|
||||||
|
delete empty;
|
||||||
|
delete vector;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, Test_Concat_2) {
|
||||||
|
auto empty = new NDArray('c', {0}, sd::DataType::FLOAT32); //NDArrayFactory::empty_<float>();
|
||||||
|
auto scalar1 = NDArrayFactory::create_<float>('c', {1}, {1.0f});
|
||||||
|
auto scalar2 = NDArrayFactory::create_<float>('c', {1}, {2.0f});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2}, {1.f, 2.f});
|
||||||
|
|
||||||
|
ASSERT_TRUE(empty->isEmpty());
|
||||||
|
|
||||||
|
sd::ops::concat op;
|
||||||
|
auto result = op.evaluate({empty, scalar1, scalar2}, {}, {0});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
// z->printShapeInfo("z shape");
|
||||||
|
// z->printIndexedBuffer("z buffr");
|
||||||
|
|
||||||
|
ASSERT_EQ(exp, *z);
|
||||||
|
|
||||||
|
delete empty;
|
||||||
|
delete scalar1;
|
||||||
|
delete scalar2;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, Test_Concat_3) {
|
||||||
|
auto empty = NDArrayFactory::empty<float>(); //NDArrayFactory::empty_<float>();
|
||||||
|
auto scalar1 = NDArrayFactory::create<float>(1.0f);
|
||||||
|
auto scalar2 = NDArrayFactory::create<float>(2.0f);
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2}, {1.f, 2.f});
|
||||||
|
|
||||||
|
ASSERT_TRUE(empty.isEmpty());
|
||||||
|
|
||||||
|
sd::ops::concat op;
|
||||||
|
auto result = op.evaluate({&empty, &scalar1, &scalar2}, {}, {0});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(exp, *z);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, Test_Concat_4) {
|
||||||
|
auto empty = NDArrayFactory::empty<float>(); //NDArrayFactory::empty_<float>();
|
||||||
|
auto scalar1 = NDArrayFactory::create<float>(1.0f);
|
||||||
|
auto scalar2 = NDArrayFactory::create<float>(2.0f);
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2}, {1.f, 2.f});
|
||||||
|
|
||||||
|
ASSERT_TRUE(empty.isEmpty());
|
||||||
|
|
||||||
|
sd::ops::concat op;
|
||||||
|
auto result = op.evaluate({&scalar1, &empty, &scalar2}, {}, {0});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(exp, *z);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, Test_dup_1) {
|
||||||
|
auto empty = NDArrayFactory::empty<int>();
|
||||||
|
auto dup = new NDArray(empty.dup());
|
||||||
|
|
||||||
|
ASSERT_TRUE(dup->isEmpty());
|
||||||
|
ASSERT_EQ(empty, *dup);
|
||||||
|
|
||||||
|
delete dup;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, test_empty_scatter_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5});
|
||||||
|
auto indices = NDArrayFactory::create<int>('c', {0});
|
||||||
|
auto updates = NDArrayFactory::create<float>('c', {0});
|
||||||
|
|
||||||
|
x.linspace(1.0f);
|
||||||
|
|
||||||
|
sd::ops::scatter_upd op;
|
||||||
|
auto result = op.evaluate({&x, &indices, &updates}, {}, {}, {true});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
ASSERT_EQ(x, *z);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, test_empty_scatter_2) {
|
||||||
|
NDArray x ('c', {5}, sd::DataType::FLOAT32);
|
||||||
|
NDArray z ('c', {5}, sd::DataType::FLOAT32);
|
||||||
|
auto indices = NDArrayFactory::create<int>('c', {0});
|
||||||
|
auto updates = NDArrayFactory::create<float>('c', {0});
|
||||||
|
|
||||||
|
x.linspace(1.0f);
|
||||||
|
|
||||||
|
sd::ops::scatter_upd op;
|
||||||
|
auto status = op.execute({&x, &indices, &updates}, {&z}, {}, {}, {true});
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_EQ(x, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, test_shaped_empty_1) {
|
||||||
|
auto empty = NDArrayFactory::create<float>('c', {2, 0, 3});
|
||||||
|
std::vector<Nd4jLong> shape = {2, 0, 3};
|
||||||
|
|
||||||
|
ASSERT_EQ(sd::DataType::FLOAT32, empty.dataType());
|
||||||
|
ASSERT_EQ(0, empty.lengthOf());
|
||||||
|
ASSERT_TRUE(empty.isEmpty());
|
||||||
|
ASSERT_EQ(shape, empty.getShapeAsVector());
|
||||||
|
ASSERT_EQ(3, empty.rankOf());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, test_shaped_empty_2) {
|
||||||
|
auto empty = NDArrayFactory::create<float>('c', {0, 3});
|
||||||
|
std::vector<Nd4jLong> shape = {0, 3};
|
||||||
|
|
||||||
|
ASSERT_EQ(sd::DataType::FLOAT32, empty.dataType());
|
||||||
|
ASSERT_EQ(0, empty.lengthOf());
|
||||||
|
ASSERT_TRUE(empty.isEmpty());
|
||||||
|
ASSERT_EQ(shape, empty.getShapeAsVector());
|
||||||
|
ASSERT_EQ(2, empty.rankOf());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, test_shaped_empty_3) {
|
||||||
|
auto empty = NDArrayFactory::create<float>('c', {0});
|
||||||
|
std::vector<Nd4jLong> shape = {0};
|
||||||
|
|
||||||
|
ASSERT_EQ(sd::DataType::FLOAT32, empty.dataType());
|
||||||
|
ASSERT_EQ(0, empty.lengthOf());
|
||||||
|
ASSERT_TRUE(empty.isEmpty());
|
||||||
|
ASSERT_EQ(shape, empty.getShapeAsVector());
|
||||||
|
ASSERT_EQ(1, empty.rankOf());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, test_shaped_empty_4) {
|
||||||
|
const auto shape = ConstantShapeHelper::getInstance().vectorShapeInfo(0, sd::DataType::FLOAT32);
|
||||||
|
NDArray array(shape, true, sd::LaunchContext::defaultContext());
|
||||||
|
std::vector<Nd4jLong> shapeOf({0});
|
||||||
|
|
||||||
|
ASSERT_TRUE(array.isEmpty());
|
||||||
|
ASSERT_EQ(1, array.rankOf());
|
||||||
|
ASSERT_EQ(shapeOf, array.getShapeAsVector());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, test_empty_matmul_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {0, 1});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {1, 0});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {0, 0});
|
||||||
|
|
||||||
|
sd::ops::matmul op;
|
||||||
|
auto result = op.evaluate({&x, &y}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, test_empty_matmul_2) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 0, 4});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {1, 4, 0});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {1, 0, 0});
|
||||||
|
|
||||||
|
sd::ops::matmul op;
|
||||||
|
auto result = op.evaluate({&x, &y}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
}
|
|
@ -0,0 +1,68 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <array/ExtraArguments.h>
|
||||||
|
#include <array>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
|
||||||
|
class ExtraArgumentsTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
ExtraArgumentsTests() {
|
||||||
|
printf("\n");
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(ExtraArgumentsTests, Basic_Test_1) {
|
||||||
|
if (!Environment::getInstance().isCPU())
|
||||||
|
return;
|
||||||
|
|
||||||
|
ExtraArguments args({1.0, 2.0, 3.0});
|
||||||
|
|
||||||
|
float ef[] = {1.f, 2.f, 3.f};
|
||||||
|
double ed[] = {1., 2., 3.};
|
||||||
|
|
||||||
|
auto ptrFloat = reinterpret_cast<float *>(args.argumentsAsT<float>());
|
||||||
|
auto ptrDouble = reinterpret_cast<double *>(args.argumentsAsT<double>());
|
||||||
|
ASSERT_TRUE(ptrFloat != nullptr);
|
||||||
|
ASSERT_TRUE(ptrDouble != nullptr);
|
||||||
|
|
||||||
|
for (int e = 0; e < 3; e++) {
|
||||||
|
ASSERT_NEAR(ef[e], ptrFloat[e], 1e-5f);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int e = 0; e < 3; e++) {
|
||||||
|
ASSERT_NEAR(ed[e], ptrDouble[e], 1e-5);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(ExtraArgumentsTests, Basic_Test_2) {
|
||||||
|
ExtraArguments args;
|
||||||
|
|
||||||
|
auto ptrInt = args.argumentsAsT<int>();
|
||||||
|
ASSERT_TRUE(ptrInt == nullptr);
|
||||||
|
}
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,104 @@
|
||||||
|
/*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <array/NDArrayFactory.h>
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <graph/Stash.h>
|
||||||
|
#include <graph/FlatUtils.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
|
||||||
|
class FlatUtilsTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(FlatUtilsTests, flat_float_serde_1) {
|
||||||
|
auto array = NDArrayFactory::create<float>('c', {4}, {1.f, 2.f, 3.f, 4.f});
|
||||||
|
|
||||||
|
flatbuffers::FlatBufferBuilder builder(1024);
|
||||||
|
auto flatArray = FlatUtils::toFlatArray(builder, array);
|
||||||
|
builder.Finish(flatArray);
|
||||||
|
|
||||||
|
|
||||||
|
auto pfArray = GetFlatArray(builder.GetBufferPointer());
|
||||||
|
|
||||||
|
auto restored = FlatUtils::fromFlatArray(pfArray);
|
||||||
|
|
||||||
|
ASSERT_EQ(array, *restored);
|
||||||
|
|
||||||
|
delete restored;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FlatUtilsTests, flat_int_serde_1) {
|
||||||
|
auto array = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});
|
||||||
|
|
||||||
|
flatbuffers::FlatBufferBuilder builder(1024);
|
||||||
|
auto flatArray = FlatUtils::toFlatArray(builder, array);
|
||||||
|
builder.Finish(flatArray);
|
||||||
|
|
||||||
|
|
||||||
|
auto pfArray = GetFlatArray(builder.GetBufferPointer());
|
||||||
|
|
||||||
|
auto restored = FlatUtils::fromFlatArray(pfArray);
|
||||||
|
|
||||||
|
ASSERT_EQ(array, *restored);
|
||||||
|
|
||||||
|
delete restored;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FlatUtilsTests, flat_bool_serde_1) {
|
||||||
|
auto array = NDArrayFactory::create<bool>('c', {4}, {true, false, true, false});
|
||||||
|
|
||||||
|
flatbuffers::FlatBufferBuilder builder(1024);
|
||||||
|
auto flatArray = FlatUtils::toFlatArray(builder, array);
|
||||||
|
builder.Finish(flatArray);
|
||||||
|
|
||||||
|
|
||||||
|
auto pfArray = GetFlatArray(builder.GetBufferPointer());
|
||||||
|
|
||||||
|
auto restored = FlatUtils::fromFlatArray(pfArray);
|
||||||
|
|
||||||
|
ASSERT_EQ(array, *restored);
|
||||||
|
|
||||||
|
delete restored;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FlatUtilsTests, flat_string_serde_1) {
|
||||||
|
auto array = NDArrayFactory::string( {3}, {"alpha", "beta", "gamma"});
|
||||||
|
|
||||||
|
flatbuffers::FlatBufferBuilder builder(1024);
|
||||||
|
auto flatArray = FlatUtils::toFlatArray(builder, array);
|
||||||
|
builder.Finish(flatArray);
|
||||||
|
|
||||||
|
|
||||||
|
auto pfArray = GetFlatArray(builder.GetBufferPointer());
|
||||||
|
|
||||||
|
auto restored = FlatUtils::fromFlatArray(pfArray);
|
||||||
|
|
||||||
|
ASSERT_EQ(array, *restored);
|
||||||
|
|
||||||
|
delete restored;
|
||||||
|
}
|
|
@ -0,0 +1,105 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver119 on 29.11.17.
|
||||||
|
//
|
||||||
|
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <flatbuffers/flatbuffers.h>
|
||||||
|
#include <graph/generated/node_generated.h>
|
||||||
|
#include <graph/generated/graph_generated.h>
|
||||||
|
#include <graph/Node.h>
|
||||||
|
#include <graph/Graph.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <ops/declarable/DeclarableOp.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::graph;
|
||||||
|
|
||||||
|
class GraphExecutionerTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
#ifdef GRAPH_TESTS_OK
|
||||||
|
TEST_F(GraphExecutionerTests, Test_Implicit_Output_1) {
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_slice.fb");
|
||||||
|
graph->buildGraph();
|
||||||
|
|
||||||
|
auto outputs = graph->fetchOutputs();
|
||||||
|
|
||||||
|
ASSERT_EQ(1, outputs->size());
|
||||||
|
|
||||||
|
auto var0 = outputs->at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(7, var0->id());
|
||||||
|
ASSERT_EQ(0, var0->index());
|
||||||
|
|
||||||
|
delete outputs;
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(GraphExecutionerTests, Test_Implicit_Output_2) {
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb");
|
||||||
|
graph->buildGraph();
|
||||||
|
|
||||||
|
auto outputs = graph->fetchOutputs();
|
||||||
|
|
||||||
|
ASSERT_EQ(1, outputs->size());
|
||||||
|
|
||||||
|
auto var0 = outputs->at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(3, var0->id());
|
||||||
|
ASSERT_EQ(0, var0->index());
|
||||||
|
|
||||||
|
delete outputs;
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(GraphExecutionerTests, Test_Implicit_Output_3) {
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3}, {3, 3, 3});
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb");
|
||||||
|
auto status = GraphExecutioner::execute(graph);
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
|
||||||
|
auto outputs = graph->fetchOutputs();
|
||||||
|
|
||||||
|
ASSERT_EQ(1, outputs->size());
|
||||||
|
|
||||||
|
auto var0 = outputs->at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(3, var0->id());
|
||||||
|
ASSERT_EQ(0, var0->index());
|
||||||
|
|
||||||
|
auto array = var0->getNDArray();
|
||||||
|
|
||||||
|
ASSERT_TRUE(array != nullptr);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(array));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(array));
|
||||||
|
|
||||||
|
delete outputs;
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
#endif
|
|
@ -0,0 +1,88 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver119 on 11.12.17.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <graph/GraphHolder.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::ops;
|
||||||
|
using namespace sd::graph;
|
||||||
|
|
||||||
|
class GraphHolderTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(GraphHolderTests, SimpleTests_1) {
|
||||||
|
Graph graph;
|
||||||
|
Nd4jLong graphId = 119;
|
||||||
|
GraphHolder::getInstance().registerGraph(graphId, &graph);
|
||||||
|
|
||||||
|
ASSERT_TRUE(GraphHolder::getInstance().hasGraph(graphId));
|
||||||
|
|
||||||
|
GraphHolder::getInstance().forgetGraph(graphId);
|
||||||
|
|
||||||
|
ASSERT_FALSE(GraphHolder::getInstance().hasGraph(graphId));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(GraphHolderTests, SimpleTests_2) {
|
||||||
|
auto graph = new Graph;
|
||||||
|
Nd4jLong graphId = 117;
|
||||||
|
GraphHolder::getInstance().registerGraph(graphId, graph);
|
||||||
|
|
||||||
|
ASSERT_TRUE(GraphHolder::getInstance().hasGraph(graphId));
|
||||||
|
|
||||||
|
auto graph2 = GraphHolder::getInstance().cloneGraph(graphId);
|
||||||
|
|
||||||
|
ASSERT_TRUE(graph != graph2);
|
||||||
|
ASSERT_TRUE(graph2 != nullptr);
|
||||||
|
|
||||||
|
GraphHolder::getInstance().forgetGraph(graphId);
|
||||||
|
|
||||||
|
ASSERT_FALSE(GraphHolder::getInstance().hasGraph(graphId));
|
||||||
|
|
||||||
|
delete graph;
|
||||||
|
delete graph2;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(GraphHolderTests, SimpleTests_3) {
|
||||||
|
auto graph = new Graph;
|
||||||
|
Nd4jLong graphId = 117;
|
||||||
|
GraphHolder::getInstance().registerGraph(graphId, graph);
|
||||||
|
|
||||||
|
ASSERT_TRUE(GraphHolder::getInstance().hasGraph(graphId));
|
||||||
|
|
||||||
|
auto graph2 = GraphHolder::getInstance().cloneGraph(graphId);
|
||||||
|
|
||||||
|
ASSERT_TRUE(graph != graph2);
|
||||||
|
ASSERT_TRUE(graph2 != nullptr);
|
||||||
|
|
||||||
|
GraphHolder::getInstance().dropGraph(graphId);
|
||||||
|
|
||||||
|
ASSERT_FALSE(GraphHolder::getInstance().hasGraph(graphId));
|
||||||
|
|
||||||
|
|
||||||
|
delete graph2;
|
||||||
|
}
|
|
@ -0,0 +1,266 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <graph/RandomGenerator.h>
|
||||||
|
#include <array/DataTypeUtils.h>
|
||||||
|
#include <graph/Graph.h>
|
||||||
|
#include <array>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::graph;
|
||||||
|
|
||||||
|
class GraphRandomGeneratorTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(GraphRandomGeneratorTests, Reproducibility_Test_1) {
|
||||||
|
sd::graph::RandomGenerator g0(119);
|
||||||
|
sd::graph::RandomGenerator g1(119);
|
||||||
|
|
||||||
|
auto i0 = g0.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
auto i1 = g1.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
|
||||||
|
ASSERT_EQ(i0, i1);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GraphRandomGeneratorTests, Reproducibility_Test_2) {
|
||||||
|
sd::graph::RandomGenerator g0(119);
|
||||||
|
sd::graph::RandomGenerator g1(117);
|
||||||
|
|
||||||
|
auto i0 = g0.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
auto i1 = g1.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
|
||||||
|
ASSERT_NE(i0, i1);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GraphRandomGeneratorTests, Reproducibility_Test_3) {
|
||||||
|
sd::graph::RandomGenerator g0(119, 5);
|
||||||
|
sd::graph::RandomGenerator g1(119, 10);
|
||||||
|
|
||||||
|
auto i0 = g0.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
auto i1 = g1.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
|
||||||
|
ASSERT_NE(i0, i1);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GraphRandomGeneratorTests, Reproducibility_Test_4) {
|
||||||
|
sd::graph::RandomGenerator g0(119, 5);
|
||||||
|
sd::graph::RandomGenerator g1(117, 5);
|
||||||
|
|
||||||
|
auto i0 = g0.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
auto i1 = g1.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
|
||||||
|
ASSERT_NE(i0, i1);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GraphRandomGeneratorTests, Sequential_Test_1) {
|
||||||
|
sd::graph::RandomGenerator g0(119, 5);
|
||||||
|
sd::graph::RandomGenerator g1(119, 5);
|
||||||
|
|
||||||
|
auto v0 = g0.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
auto v1 = g1.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
g0.rewindH(200);
|
||||||
|
auto r0 = g0.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
auto r1 = g1.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
|
||||||
|
// values after rewind aren't equal
|
||||||
|
ASSERT_NE(r0, v0);
|
||||||
|
|
||||||
|
// two generators must give the same output
|
||||||
|
ASSERT_EQ(v0, v1);
|
||||||
|
|
||||||
|
// but not after one of them was rewinded
|
||||||
|
ASSERT_NE(r1, r0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GraphRandomGeneratorTests, Sequential_Test_2) {
|
||||||
|
sd::graph::RandomGenerator g0(119, 5);
|
||||||
|
sd::graph::RandomGenerator g1(119, 5);
|
||||||
|
|
||||||
|
auto v0 = g0.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
auto v1 = g1.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
g0.rewindH(200);
|
||||||
|
g1.rewindH(199);
|
||||||
|
auto r0 = g0.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
auto r1 = g1.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
|
||||||
|
// values after rewind aren't equal
|
||||||
|
ASSERT_NE(r0, v0);
|
||||||
|
|
||||||
|
// two generators must give the same output
|
||||||
|
ASSERT_EQ(v0, v1);
|
||||||
|
|
||||||
|
// but not after they was rewinded with different number of elements
|
||||||
|
ASSERT_NE(r1, r0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GraphRandomGeneratorTests, Sequential_Test_3) {
|
||||||
|
sd::graph::RandomGenerator g0(119, 5);
|
||||||
|
sd::graph::RandomGenerator g1(119, 5);
|
||||||
|
|
||||||
|
auto v0 = g0.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
auto v1 = g1.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
g0.rewindH(200);
|
||||||
|
g1.rewindH(200);
|
||||||
|
auto r0 = g0.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
auto r1 = g1.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
|
||||||
|
// values after rewind aren't equal
|
||||||
|
ASSERT_NE(r0, v0);
|
||||||
|
|
||||||
|
// two generators must give the same output
|
||||||
|
ASSERT_EQ(v0, v1);
|
||||||
|
|
||||||
|
// and here output must be equal as well
|
||||||
|
ASSERT_EQ(r1, r0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GraphRandomGeneratorTests, Sequential_Test_4) {
|
||||||
|
sd::graph::RandomGenerator g0(119, 5);
|
||||||
|
sd::graph::RandomGenerator g1(119, 5);
|
||||||
|
|
||||||
|
auto v0 = g0.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
auto v1 = g1.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
g0.rewindH(200);
|
||||||
|
g1.rewindH(200);
|
||||||
|
auto r0 = g0.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
auto r1 = g1.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
g0.rewindH(200);
|
||||||
|
g1.rewindH(200);
|
||||||
|
auto z0 = g0.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
auto z1 = g1.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
g0.rewindH(201);
|
||||||
|
g1.rewindH(199);
|
||||||
|
auto y0 = g0.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
auto y1 = g1.relativeT<int>(15, 0, DataTypeUtils::max<int>());
|
||||||
|
|
||||||
|
// values after rewind aren't equal
|
||||||
|
ASSERT_NE(r0, v0);
|
||||||
|
|
||||||
|
// two generators must give the same output
|
||||||
|
ASSERT_EQ(v0, v1);
|
||||||
|
|
||||||
|
// and here output must be equal as well
|
||||||
|
ASSERT_EQ(r0, r1);
|
||||||
|
|
||||||
|
ASSERT_EQ(z0, z1);
|
||||||
|
|
||||||
|
ASSERT_NE(r0, z0);
|
||||||
|
ASSERT_NE(r1, z1);
|
||||||
|
|
||||||
|
ASSERT_NE(y0, z0);
|
||||||
|
ASSERT_NE(y1, z1);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//#ifndef __clang__
|
||||||
|
|
||||||
|
TEST_F(GraphRandomGeneratorTests, Long_Test_1) {
|
||||||
|
sd::graph::RandomGenerator g0(119, 5);
|
||||||
|
sd::graph::RandomGenerator g1(119, 5);
|
||||||
|
|
||||||
|
std::array<Nd4jLong, 10000> z0, z1, z2, z3;
|
||||||
|
|
||||||
|
for (int e = 0; e < z0.size(); e++) {
|
||||||
|
z0[e] = g0.relativeT<Nd4jLong>(e);
|
||||||
|
z1[e] = g1.relativeT<Nd4jLong>(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
g0.rewindH(z0.size());
|
||||||
|
g1.rewindH(z0.size());
|
||||||
|
|
||||||
|
for (int e = 0; e < z0.size(); e++) {
|
||||||
|
z2[e] = g0.relativeT<Nd4jLong>(e);
|
||||||
|
z3[e] = g1.relativeT<Nd4jLong>(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
// these sequences should be equal
|
||||||
|
ASSERT_EQ(z0, z1);
|
||||||
|
ASSERT_EQ(z2, z3);
|
||||||
|
|
||||||
|
// these sequences should be different due to rewind
|
||||||
|
ASSERT_NE(z0, z3);
|
||||||
|
|
||||||
|
// we'll be counting values > MAX_INT here
|
||||||
|
int maxes = 0;
|
||||||
|
|
||||||
|
for (int e = 0; e < z0.size(); e++) {
|
||||||
|
auto v = z0[e];
|
||||||
|
|
||||||
|
// we don't want any negatives here
|
||||||
|
ASSERT_TRUE(v > 0);
|
||||||
|
|
||||||
|
if (v > DataTypeUtils::max<int>())
|
||||||
|
maxes++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// and now we're ensuring there ARE values above MAX_INT
|
||||||
|
ASSERT_NE(0, maxes);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(GraphRandomGeneratorTests, FloatingPoint_Test_1) {
|
||||||
|
sd::graph::RandomGenerator g0(119, 5);
|
||||||
|
sd::graph::RandomGenerator g1(119, 5);
|
||||||
|
|
||||||
|
std::array<double, 100> z0, z1, z2, z3;
|
||||||
|
|
||||||
|
for (int e = 0; e < z0.size(); e++) {
|
||||||
|
z0[e] = g0.relativeT<double>(e, -1.0, 1.0);
|
||||||
|
z1[e] = g1.relativeT<double>(e, -1.0, 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
g0.rewindH(z0.size());
|
||||||
|
g1.rewindH(z0.size());
|
||||||
|
|
||||||
|
for (int e = 0; e < z0.size(); e++) {
|
||||||
|
z2[e] = g0.relativeT<double>(e, -1.0, 1.0);
|
||||||
|
z3[e] = g1.relativeT<double>(e, -1.0, 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// these sequences should be equal
|
||||||
|
ASSERT_EQ(z0, z1);
|
||||||
|
ASSERT_EQ(z2, z3);
|
||||||
|
|
||||||
|
// these sequences should be different due to rewind
|
||||||
|
ASSERT_NE(z0, z3);
|
||||||
|
|
||||||
|
// we'll count negatives as well
|
||||||
|
int negs = 0;
|
||||||
|
|
||||||
|
// make sure every value stays within distribution borders
|
||||||
|
for (int e = 0; e < z0.size(); e++) {
|
||||||
|
auto v = z0[e];
|
||||||
|
if (!(v >= -1.0 && v <= 1.0)) {
|
||||||
|
nd4j_printf("Failed at idx [%i]: %f\n", e, (float) v);
|
||||||
|
ASSERT_TRUE(v >= -1.0 && v <= 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (v < 0.0)
|
||||||
|
negs++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// there should be negatives
|
||||||
|
ASSERT_TRUE(negs > 0);
|
||||||
|
|
||||||
|
// and positives
|
||||||
|
ASSERT_NE(z0.size(), negs);
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,351 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <graph/GraphState.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <ops/declarable/LegacyTransformOp.h>
|
||||||
|
#include <ops/declarable/LegacyReduceOp.h>
|
||||||
|
#include <legacy/NativeOps.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::graph;
|
||||||
|
|
||||||
|
class GraphStateTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
GraphStateTests() {
|
||||||
|
Environment::getInstance().setDebug(false);
|
||||||
|
Environment::getInstance().setVerbose(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
~GraphStateTests() {
|
||||||
|
Environment::getInstance().setDebug(false);
|
||||||
|
Environment::getInstance().setVerbose(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
* PLAN:
|
||||||
|
* Create GraphState
|
||||||
|
* Register Scope
|
||||||
|
* Add few Ops to it
|
||||||
|
* Call conditional, that refers to scopes
|
||||||
|
* Check results
|
||||||
|
*/
|
||||||
|
|
||||||
|
TEST_F(GraphStateTests, Basic_Tests_1) {
|
||||||
|
auto state = (GraphState *) getGraphState(117L);
|
||||||
|
ASSERT_EQ(117L, state->id());
|
||||||
|
|
||||||
|
// this call will create scope internally
|
||||||
|
state->registerScope(119);
|
||||||
|
|
||||||
|
sd::ops::add opA;
|
||||||
|
sd::ops::LegacyTransformSameOp opB(transform::Neg); // simdOps::Neg
|
||||||
|
|
||||||
|
ArgumentsList argsA;
|
||||||
|
ArgumentsList argsB;
|
||||||
|
|
||||||
|
state->attachOpToScope(119, 1, &opA, argsA);
|
||||||
|
state->attachOpToScope(119, 2, &opB, argsB);
|
||||||
|
|
||||||
|
auto scope = state->getScope(119);
|
||||||
|
ASSERT_TRUE(scope != nullptr);
|
||||||
|
ASSERT_EQ(2, scope->size());
|
||||||
|
|
||||||
|
deleteGraphState(state);
|
||||||
|
}
|
||||||
|
|
||||||
|
// just separate case for doubles wrapper in NativeOps, nothing else
|
||||||
|
TEST_F(GraphStateTests, Basic_Tests_2) {
|
||||||
|
auto state = (GraphState *) getGraphState(117L);
|
||||||
|
ASSERT_EQ(117L, state->id());
|
||||||
|
|
||||||
|
// this call will create scope internally
|
||||||
|
state->registerScope(119);
|
||||||
|
|
||||||
|
sd::ops::add opA;
|
||||||
|
sd::ops::LegacyTransformSameOp opB(transform::Neg); // simdOps::Neg
|
||||||
|
|
||||||
|
ArgumentsList argsA;
|
||||||
|
ArgumentsList argsB;
|
||||||
|
|
||||||
|
state->attachOpToScope(119, 1, &opA, argsA);
|
||||||
|
state->attachOpToScope(119, 2, &opB, argsB);
|
||||||
|
|
||||||
|
auto scope = state->getScope(119);
|
||||||
|
ASSERT_TRUE(scope != nullptr);
|
||||||
|
ASSERT_EQ(2, scope->size());
|
||||||
|
|
||||||
|
deleteGraphState(state);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
TEST_F(GraphStateTests, Stateful_Execution_1) {
|
||||||
|
auto state = getGraphState(117L);
|
||||||
|
|
||||||
|
Nd4jLong scopes[] = {22, 33};
|
||||||
|
//auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
|
||||||
|
auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::THROW(), status);
|
||||||
|
|
||||||
|
deleteGraphState(state);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GraphStateTests, Stateful_Execution_2) {
|
||||||
|
auto state = (GraphState *) getGraphState(117L);
|
||||||
|
|
||||||
|
state->registerScope(22);
|
||||||
|
state->registerScope(33);
|
||||||
|
|
||||||
|
Nd4jLong scopes[] = {22, 33};
|
||||||
|
auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
|
||||||
|
// it's no-op: just LogicScope
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
deleteGraphState(state);
|
||||||
|
}
|
||||||
|
|
||||||
|
// This test checks WHILE loop
|
||||||
|
TEST_F(GraphStateTests, Stateful_Execution_3) {
|
||||||
|
auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
||||||
|
auto var1 = NDArrayFactory::create<float>(11.0f);
|
||||||
|
auto var2 = NDArrayFactory::create<float>(2.0f);
|
||||||
|
|
||||||
|
auto res0 = NDArrayFactory::create<float>('c', {2, 2});
|
||||||
|
auto res1 = NDArrayFactory::create<float>(0.0f);
|
||||||
|
auto res2 = NDArrayFactory::create<float>(0.0f);
|
||||||
|
|
||||||
|
// registering our GraphState holder
|
||||||
|
auto state = (GraphState *) getGraphState(117L);
|
||||||
|
|
||||||
|
// we're prepping pointers to input/output buffers
|
||||||
|
Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer(), (Nd4jPointer)var2.buffer()};
|
||||||
|
Nd4jPointer ptrShapes[] = {(Nd4jPointer) var0.shapeInfo(), (Nd4jPointer) var1.shapeInfo(), (Nd4jPointer)var2.shapeInfo()};
|
||||||
|
|
||||||
|
Nd4jPointer outBuffers[] = {(Nd4jPointer) res0.buffer(), (Nd4jPointer) res1.buffer(), (Nd4jPointer) res2.buffer()};
|
||||||
|
Nd4jPointer outShapes[] = {(Nd4jPointer) res0.shapeInfo(), (Nd4jPointer) res1.shapeInfo(), (Nd4jPointer) res2.shapeInfo()};
|
||||||
|
|
||||||
|
// conditional scope
|
||||||
|
state->registerScope(22);
|
||||||
|
|
||||||
|
sd::ops::LegacyReduceSameOp op1(reduce::Sum);
|
||||||
|
sd::ops::lt_scalar op2;
|
||||||
|
|
||||||
|
// while sum(var0) < var1
|
||||||
|
// this op takes sum
|
||||||
|
ArgumentsList args1({{0, 0}});
|
||||||
|
|
||||||
|
// this op compares result of sum to input variable 0:1
|
||||||
|
ArgumentsList args2({{1, 0}, {0, 1}});
|
||||||
|
|
||||||
|
state->attachOpToScope(22, 1, &op1, args1);
|
||||||
|
state->attachOpToScope(22, 2, &op2, args2);
|
||||||
|
|
||||||
|
// body scope
|
||||||
|
state->registerScope(33);
|
||||||
|
|
||||||
|
// var0 + var1 + var1
|
||||||
|
// this op is var0 + var1
|
||||||
|
ArgumentsList args3({{0, 0}, {0, 2}});
|
||||||
|
|
||||||
|
// this op is result of previous op + 1
|
||||||
|
ArgumentsList args4({{3, 0}, {0, 2}});
|
||||||
|
|
||||||
|
sd::ops::add op3;
|
||||||
|
sd::ops::add op4;
|
||||||
|
|
||||||
|
state->attachOpToScope(33, 3, &op3, args3);
|
||||||
|
state->attachOpToScope(33, 4, &op4, args4);
|
||||||
|
|
||||||
|
// Now we define RETURN, which returns 1 modified variable, and 2 unmodified variables
|
||||||
|
ArgumentsList args5({{4, 0}, {0, 1}, {0, 2}});
|
||||||
|
|
||||||
|
// so, at the end of body, initial variables will be updated
|
||||||
|
state->defineReturn(33, 5, args5);
|
||||||
|
|
||||||
|
Nd4jLong scopes[] = {22, 33};
|
||||||
|
|
||||||
|
// we're executing while loop
|
||||||
|
auto status = execCustomOpWithScope(nullptr, state, 0, scopes, 2, ptrBuffers, ptrShapes, 3, outBuffers, outShapes, 3);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
// now we check provided result array
|
||||||
|
float sum = res0.reduceNumber(reduce::Sum).e<float>(0);
|
||||||
|
|
||||||
|
// Expected result is {1, 2, 3, 4} + {2} elementwise + {2} elementwise, which gives { 5, 6, 7, 8}, and sum should be 26
|
||||||
|
ASSERT_NEAR(26.0f, sum, 1e-5);
|
||||||
|
|
||||||
|
// nd4j_printf("0 ------------------\n","");
|
||||||
|
|
||||||
|
deleteGraphState(state);
|
||||||
|
|
||||||
|
// nd4j_printf("1 ------------------\n","");
|
||||||
|
}
|
||||||
|
|
||||||
|
// This test checks CONDITIONAL execution for FALSE
|
||||||
|
TEST_F(GraphStateTests, Stateful_Execution_4) {
|
||||||
|
auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
||||||
|
auto var1 = NDArrayFactory::create<float>(5.0f);
|
||||||
|
|
||||||
|
auto res0 = NDArrayFactory::create<float>('c', {2, 2});
|
||||||
|
auto res1 = NDArrayFactory::create<float>(0.0f);
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {-4, -3, -2, -1});
|
||||||
|
|
||||||
|
|
||||||
|
// registering our GraphState holder
|
||||||
|
auto state = (GraphState *) getGraphState(117L);
|
||||||
|
|
||||||
|
// we're prepping pointers to input/output buffers
|
||||||
|
Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()};
|
||||||
|
Nd4jPointer ptrShapes[] = {(Nd4jPointer) var0.shapeInfo(), (Nd4jPointer) var1.shapeInfo()};
|
||||||
|
|
||||||
|
Nd4jPointer outBuffers[] = {(Nd4jPointer) res0.buffer(), (Nd4jPointer) res1.buffer()};
|
||||||
|
Nd4jPointer outShapes[] = {(Nd4jPointer) res0.shapeInfo(), (Nd4jPointer) res1.shapeInfo()};
|
||||||
|
|
||||||
|
// conditional scope
|
||||||
|
state->registerScope(22);
|
||||||
|
|
||||||
|
sd::ops::LegacyReduceSameOp op1(reduce::Sum);
|
||||||
|
sd::ops::lt_scalar op2;
|
||||||
|
|
||||||
|
// if sum(var0) < var1
|
||||||
|
// this op takes sum
|
||||||
|
ArgumentsList args1({{0, 0}});
|
||||||
|
|
||||||
|
// this op compares result of sum to input variable 0:1
|
||||||
|
ArgumentsList args2({{1, 0}, {0, 1}});
|
||||||
|
|
||||||
|
state->attachOpToScope(22, 1, &op1, args1);
|
||||||
|
state->attachOpToScope(22, 2, &op2, args2);
|
||||||
|
|
||||||
|
// false scope
|
||||||
|
state->registerScope(33);
|
||||||
|
|
||||||
|
ArgumentsList args3({{0, 0}, {0, 1}});
|
||||||
|
sd::ops::subtract op3;
|
||||||
|
state->attachOpToScope(33, 3, &op3, args3);
|
||||||
|
|
||||||
|
// return for false scope
|
||||||
|
ArgumentsList args10({{3, 0}, {0, 1}});
|
||||||
|
state->defineReturn(33, 10, args10);
|
||||||
|
|
||||||
|
// true scope
|
||||||
|
state->registerScope(44);
|
||||||
|
|
||||||
|
ArgumentsList args4({{0, 0}, {0, 1}});
|
||||||
|
sd::ops::add op4;
|
||||||
|
state->attachOpToScope(44, 4, &op4, args4);
|
||||||
|
|
||||||
|
// return for false scope
|
||||||
|
ArgumentsList args20({{4, 0}, {0, 1}});
|
||||||
|
state->defineReturn(44, 20, args20);
|
||||||
|
|
||||||
|
|
||||||
|
Nd4jLong scopes[] = {22, 33, 44};
|
||||||
|
|
||||||
|
// we're executing conditional op
|
||||||
|
auto status = execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(&res0));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(&res0));
|
||||||
|
|
||||||
|
|
||||||
|
deleteGraphState(state);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// This test checks CONDITIONAL execution for TRUE
|
||||||
|
TEST_F(GraphStateTests, Stateful_Execution_5) {
|
||||||
|
auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
||||||
|
auto var1 = NDArrayFactory::create<float>(5.0f);
|
||||||
|
|
||||||
|
auto res0 = NDArrayFactory::create<float>('c', {2, 2});
|
||||||
|
auto res1 = NDArrayFactory::create<float>(0.0f);
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {6, 7, 8, 9});
|
||||||
|
|
||||||
|
|
||||||
|
// registering our GraphState holder
|
||||||
|
auto state = (GraphState *) getGraphState(117L);
|
||||||
|
|
||||||
|
// we're prepping pointers to input/output buffers
|
||||||
|
Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()};
|
||||||
|
Nd4jPointer ptrShapes[] = {(Nd4jPointer) var0.shapeInfo(), (Nd4jPointer) var1.shapeInfo()};
|
||||||
|
|
||||||
|
Nd4jPointer outBuffers[] = {(Nd4jPointer) res0.buffer(), (Nd4jPointer) res1.buffer()};
|
||||||
|
Nd4jPointer outShapes[] = {(Nd4jPointer) res0.shapeInfo(), (Nd4jPointer) res1.shapeInfo()};
|
||||||
|
|
||||||
|
// conditional scope
|
||||||
|
state->registerScope(22);
|
||||||
|
|
||||||
|
sd::ops::LegacyReduceSameOp op1(reduce::Sum);
|
||||||
|
sd::ops::gt_scalar op2;
|
||||||
|
|
||||||
|
// if sum(var0) < var1
|
||||||
|
// this op takes sum
|
||||||
|
ArgumentsList args1({{0, 0}});
|
||||||
|
|
||||||
|
// this op compares result of sum to input variable 0:1
|
||||||
|
ArgumentsList args2({{1, 0}, {0, 1}});
|
||||||
|
|
||||||
|
state->attachOpToScope(22, 1, &op1, args1);
|
||||||
|
state->attachOpToScope(22, 2, &op2, args2);
|
||||||
|
|
||||||
|
// false scope
|
||||||
|
state->registerScope(33);
|
||||||
|
|
||||||
|
ArgumentsList args3({{0, 0}, {0, 1}});
|
||||||
|
sd::ops::subtract op3;
|
||||||
|
state->attachOpToScope(33, 3, &op3, args3);
|
||||||
|
|
||||||
|
// return for false scope
|
||||||
|
ArgumentsList args10({{3, 0}, {0, 1}});
|
||||||
|
state->defineReturn(33, 10, args10);
|
||||||
|
|
||||||
|
// true scope
|
||||||
|
state->registerScope(44);
|
||||||
|
|
||||||
|
ArgumentsList args4({{0, 0}, {0, 1}});
|
||||||
|
sd::ops::add op4;
|
||||||
|
state->attachOpToScope(44, 4, &op4, args4);
|
||||||
|
|
||||||
|
// return for false scope
|
||||||
|
ArgumentsList args20({{4, 0}, {0, 1}});
|
||||||
|
state->defineReturn(44, 20, args20);
|
||||||
|
|
||||||
|
|
||||||
|
Nd4jLong scopes[] = {22, 33, 44};
|
||||||
|
|
||||||
|
// we're executing conditional op
|
||||||
|
auto status = execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(&res0));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(&res0));
|
||||||
|
|
||||||
|
deleteGraphState(state);
|
||||||
|
}
|
||||||
|
*/
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,45 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver119 on 02.09.17.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <helpers/helper_hash.h>
|
||||||
|
|
||||||
|
class HashUtilsTests : public testing::Test {
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(HashUtilsTests, TestEquality1) {
|
||||||
|
std::string str("Conv2D");
|
||||||
|
|
||||||
|
Nd4jLong hash1 = sd::ops::HashHelper::getInstance().getLongHash(str);
|
||||||
|
ASSERT_EQ(-1637140380760460323L, hash1);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(HashUtilsTests, TestEquality2) {
|
||||||
|
std::string str("switch");
|
||||||
|
|
||||||
|
Nd4jLong hash1 = sd::ops::HashHelper::getInstance().getLongHash(str);
|
||||||
|
ASSERT_EQ(-1988317239813741487L, hash1);
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,429 @@
|
||||||
|
/*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <helpers/HessenbergAndSchur.h>
|
||||||
|
#include <helpers/EigenValsAndVecs.h>
|
||||||
|
#include <helpers/FullPivLU.h>
|
||||||
|
#include <ops/declarable/helpers/triangular_solve.h>
|
||||||
|
#include <helpers/Sqrtm.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
|
||||||
|
class HelpersTests2 : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
HelpersTests2() {
|
||||||
|
|
||||||
|
std::cout<<std::endl<<std::flush;
|
||||||
|
}
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
// #ifndef __CUDABLAS__
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(HelpersTests2, Hessenberg_1) {
|
||||||
|
|
||||||
|
|
||||||
|
NDArray x1('c', {1,4}, {14,17,3,1}, sd::DataType::DOUBLE);
|
||||||
|
NDArray x2('c', {1,1}, {14}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expQ('c', {1,1}, {1}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
ops::helpers::Hessenberg<double> hess1(x1);
|
||||||
|
ASSERT_TRUE(hess1._H.isSameShape(&x1));
|
||||||
|
ASSERT_TRUE(hess1._H.equalsTo(&x1));
|
||||||
|
ASSERT_TRUE(hess1._Q.isSameShape(&expQ));
|
||||||
|
ASSERT_TRUE(hess1._Q.equalsTo(&expQ));
|
||||||
|
|
||||||
|
ops::helpers::Hessenberg<double> hess2(x2);
|
||||||
|
ASSERT_TRUE(hess2._H.isSameShape(&x2));
|
||||||
|
ASSERT_TRUE(hess2._H.equalsTo(&x2));
|
||||||
|
ASSERT_TRUE(hess2._Q.isSameShape(&expQ));
|
||||||
|
ASSERT_TRUE(hess2._Q.equalsTo(&expQ));
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(HelpersTests2, Hessenberg_2) {
|
||||||
|
|
||||||
|
NDArray x('c', {2,2}, {1.5,-2,17,5}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expQ('c', {2,2}, {1,0,0,1}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
ops::helpers::Hessenberg<double> hess(x);
|
||||||
|
|
||||||
|
// hess._H.printBuffer();
|
||||||
|
|
||||||
|
ASSERT_TRUE(hess._H.isSameShape(&x));
|
||||||
|
ASSERT_TRUE(hess._H.equalsTo(&x));
|
||||||
|
|
||||||
|
ASSERT_TRUE(hess._Q.isSameShape(&expQ));
|
||||||
|
ASSERT_TRUE(hess._Q.equalsTo(&expQ));
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(HelpersTests2, Hessenberg_3) {
|
||||||
|
|
||||||
|
NDArray x('c', {3,3}, {33,24,-48,57,12.5,-3,1.1,10,-5.2}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expH('c', {3,3}, {33, -23.06939, -48.45414, -57.01061, 12.62845, 3.344058, 0, -9.655942, -5.328448}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expQ('c', {3,3}, {1,0,0,0, -0.99981, -0.019295, 0, -0.019295,0.99981}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
ops::helpers::Hessenberg<double> hess(x);
|
||||||
|
|
||||||
|
ASSERT_TRUE(hess._H.isSameShape(&expH));
|
||||||
|
ASSERT_TRUE(hess._H.equalsTo(&expH));
|
||||||
|
|
||||||
|
ASSERT_TRUE(hess._Q.isSameShape(&expQ));
|
||||||
|
ASSERT_TRUE(hess._Q.equalsTo(&expQ));
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(HelpersTests2, Hessenberg_4) {
|
||||||
|
|
||||||
|
NDArray x('c', {4,4}, {0.33 ,-7.25 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,7.59 ,3.44 ,2.24 ,-6.82 ,-1.15 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expH('c', {4,4}, {0.33, 0.4961181, 3.51599, 9.017665, -7.792702, 4.190221, 6.500328, 5.438888, 0, 3.646734, 0.4641911, -7.635502, 0,0, 5.873535, 5.105588}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expQ('c', {4,4}, {1,0,0,0, 0,-0.171956, 0.336675, -0.925787, 0,-0.973988,0.0826795, 0.210976, 0, 0.147574, 0.937984,0.3137}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
ops::helpers::Hessenberg<double> hess(x);
|
||||||
|
|
||||||
|
ASSERT_TRUE(hess._H.isSameShape(&expH));
|
||||||
|
ASSERT_TRUE(hess._H.equalsTo(&expH));
|
||||||
|
|
||||||
|
ASSERT_TRUE(hess._Q.isSameShape(&expQ));
|
||||||
|
ASSERT_TRUE(hess._Q.equalsTo(&expQ));
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(HelpersTests2, Hessenberg_5) {
|
||||||
|
|
||||||
|
NDArray x('c', {10,10}, {6.9 ,4.8 ,9.5 ,3.1 ,6.5 ,5.8 ,-0.9 ,-7.3 ,-8.1 ,3.0 ,0.1 ,9.9 ,-3.2 ,6.4 ,6.2 ,-7.0 ,5.5 ,-2.2 ,-4.0 ,3.7 ,-3.6 ,9.0 ,-1.4 ,-2.4 ,1.7 ,
|
||||||
|
-6.1 ,-4.2 ,-2.5 ,-5.6 ,-0.4 ,0.4 ,9.1 ,-2.1 ,-5.4 ,7.3 ,3.6 ,-1.7 ,-5.7 ,-8.0 ,8.8 ,-3.0 ,-0.5 ,1.1 ,10.0 ,8.0 ,0.8 ,1.0 ,7.5 ,3.5 ,-1.8 ,
|
||||||
|
0.3 ,-0.6 ,-6.3 ,-4.5 ,-1.1 ,1.8 ,0.6 ,9.6 ,9.2 ,9.7 ,-2.6 ,4.3 ,-3.4 ,0.0 ,-6.7 ,5.0 ,10.5 ,1.5 ,-7.8 ,-4.1 ,-5.3 ,-5.0 ,2.0 ,-4.4 ,-8.4 ,
|
||||||
|
6.0 ,-9.4 ,-4.8 ,8.2 ,7.8 ,5.2 ,-9.5 ,-3.9 ,0.2 ,6.8 ,5.7 ,-8.5 ,-1.9 ,-0.3 ,7.4 ,-8.7 ,7.2 ,1.3 ,6.3 ,-3.7 ,3.9 ,3.3 ,-6.0 ,-9.1 ,5.9}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expH('c', {10,10}, {6.9, 6.125208, -8.070945, 7.219828, -9.363308, 2.181236, 5.995414, 3.892612, 4.982657, -2.088574,-12.6412, 1.212547, -6.449684, 5.162879, 0.4341714, -5.278079, -2.624011, -2.03615, 11.39619, -3.034842,
|
||||||
|
0, -12.71931, 10.1146, 6.494434, -1.062934, 5.668906, -4.672953, -9.319893, -2.023392, 6.090341,0,0, 7.800521, -1.46286, 1.484626, -10.58252, -3.492978, 2.42187, 5.470045, 1.877265,
|
||||||
|
0,0,0, 14.78259,-0.3147726, -5.74874, -0.377823, 3.310056, 2.242614, -5.111574,0,0,0,0, -9.709131, 3.885072, 6.762626, 4.509144, 2.390195, -4.991013,
|
||||||
|
0,0,0,0,0, 8.126269, -12.32529, 9.030151, 1.390931, 0.8634045,0,0,0,0,0,0, -12.99477, 9.574299,-0.3098022, 4.910835,0,0,0,0,0,0,0, 14.75256, 18.95723, -5.054717,0,0,0,0,0,0,0,0, -4.577715, -5.440827,}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expQ('c', {10,10}, {1,0,0,0,0,0,0,0,0,0,0,-0.0079106,-0.38175,-0.39287,-0.26002,-0.44102,-0.071516,0.12118,0.64392,0.057562,
|
||||||
|
0,0.28478,0.0058784,0.3837,-0.47888,0.39477,0.0036847,-0.24678,0.3229,0.47042,0,-0.031643,-0.61277,0.087648,0.12014,0.47648,-0.5288,0.060599,0.021434,-0.30102,
|
||||||
|
0,0.23732,-0.17801,-0.31809,-0.31267,0.27595,0.30134,0.64555,-0.33392,0.13363,0,-0.023732,-0.40236,0.43089,-0.38692,-0.5178,-0.03957,-0.081667,-0.47515,-0.0077949,
|
||||||
|
0,0.20568,-0.0169,0.36962,0.49669,-0.22475,-0.22199,0.50075,0.10454,0.46112,0,0.41926,0.30243,-0.3714,-0.16795,-0.12969,-0.67572,-0.1205,-0.26047,0.10407,
|
||||||
|
0,-0.41135,-0.28357,-0.33858,0.18836,0.083822,-0.0068213,-0.30161,-0.24956,0.66327,0,0.68823,-0.33616,-0.12129,0.36163,-0.063256,0.34198,-0.37564,-0.048196,-0.058948}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
ops::helpers::Hessenberg<double> hess(x);
|
||||||
|
|
||||||
|
ASSERT_TRUE(hess._H.isSameShape(&expH));
|
||||||
|
ASSERT_TRUE(hess._H.equalsTo(&expH));
|
||||||
|
|
||||||
|
ASSERT_TRUE(hess._Q.isSameShape(&expQ));
|
||||||
|
ASSERT_TRUE(hess._Q.equalsTo(&expQ));
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(HelpersTests2, Schur_1) {
|
||||||
|
|
||||||
|
NDArray x('c', {3,3}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
NDArray expT('c', {3,3}, {-2.5, -2, 1, 0, 1.5, -2, 3, 4, 5}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expU('c', {3,3}, {0.3, 0.2,-0.1, 0,-0.1, 0.2, -0.3,-0.4, 0.5}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
ops::helpers::Schur<double> schur(x);
|
||||||
|
schur._T.linspace(-3, 1);
|
||||||
|
schur._U.linspace(-0.3, 0.1);
|
||||||
|
|
||||||
|
schur.splitTwoRows(1, 0.5);
|
||||||
|
|
||||||
|
ASSERT_TRUE(schur._T.isSameShape(&expT));
|
||||||
|
ASSERT_TRUE(schur._T.equalsTo(&expT));
|
||||||
|
|
||||||
|
ASSERT_TRUE(schur._U.isSameShape(&expU));
|
||||||
|
ASSERT_TRUE(schur._U.equalsTo(&expU));
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(HelpersTests2, Schur_2) {
|
||||||
|
|
||||||
|
NDArray x('c', {3,3}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
NDArray shift('c', {3}, sd::DataType::DOUBLE);
|
||||||
|
NDArray exp1('c', {3}, {1,-3,0}, sd::DataType::DOUBLE);
|
||||||
|
NDArray exp2('c', {3}, {3, 3,-7}, sd::DataType::DOUBLE);
|
||||||
|
NDArray exp3('c', {3}, {0.964,0.964,0.964}, sd::DataType::DOUBLE);
|
||||||
|
NDArray exp1T('c', {3,3}, {-3,-2,-1,0,1,2,3,4,5}, sd::DataType::DOUBLE);
|
||||||
|
NDArray exp2T('c', {3,3}, {-8,-2,-1,0,-4,2,3,4,0}, sd::DataType::DOUBLE);
|
||||||
|
NDArray exp3T('c', {3,3}, {-9.464102,-2,-1,0,-5.464102,2,3,4,-1.464102,}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
ops::helpers::Schur<double> schur(x);
|
||||||
|
// schur._U.linspace(-0.3, 0.1); // doesn't matter
|
||||||
|
|
||||||
|
schur._T.linspace(-3, 1);
|
||||||
|
double expShift =0;
|
||||||
|
schur.calcShift(1, 5, expShift, shift);
|
||||||
|
ASSERT_TRUE(schur._T.equalsTo(&exp1T));
|
||||||
|
ASSERT_TRUE(shift.isSameShape(&exp1));
|
||||||
|
ASSERT_TRUE(shift.equalsTo(&exp1));
|
||||||
|
ASSERT_TRUE(expShift == 0);
|
||||||
|
|
||||||
|
schur._T.linspace(-3, 1);
|
||||||
|
expShift = 0;
|
||||||
|
schur.calcShift(2, 10, expShift, shift);
|
||||||
|
ASSERT_TRUE(schur._T.equalsTo(&exp2T));
|
||||||
|
ASSERT_TRUE(shift.isSameShape(&exp2));
|
||||||
|
ASSERT_TRUE(shift.equalsTo(&exp2));
|
||||||
|
ASSERT_TRUE(expShift == 5);
|
||||||
|
|
||||||
|
schur._T.linspace(-3, 1);
|
||||||
|
expShift = 0;
|
||||||
|
schur.calcShift(2, 30, expShift, shift);
|
||||||
|
ASSERT_TRUE(schur._T.equalsTo(&exp3T));
|
||||||
|
ASSERT_TRUE(shift.isSameShape(&exp3));
|
||||||
|
ASSERT_TRUE(shift.equalsTo(&exp3));
|
||||||
|
ASSERT_TRUE((6.4641-0.00001) < expShift && expShift < (6.4641+0.00001));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(HelpersTests2, Schur_3) {
|
||||||
|
|
||||||
|
NDArray x('c', {2,2}, {1.5,-2,17,5}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expU('c', {2,2}, {1,0,0,1}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
ops::helpers::Schur<double> schur(x);
|
||||||
|
|
||||||
|
ASSERT_TRUE(schur._T.isSameShape(&x));
|
||||||
|
ASSERT_TRUE(schur._T.equalsTo(&x));
|
||||||
|
|
||||||
|
ASSERT_TRUE(schur._U.isSameShape(&expU));
|
||||||
|
ASSERT_TRUE(schur._U.equalsTo(&expU));
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(HelpersTests2, Schur_4) {
|
||||||
|
|
||||||
|
NDArray x('c', {3,3}, {33,24,-48,57,12.5,-3,1.1,10,-5.2}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expT('c', {3,3}, {53.73337,-20.21406,-50.44809,0,-27.51557, 26.74307,0,0,14.0822}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expU('c', {3,3}, {-0.5848506, 0.7185352, 0.3763734,-0.7978391,-0.5932709,-0.1071558,-0.1462962, 0.3629555,-0.9202504}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
ops::helpers::Schur<double> schur(x);
|
||||||
|
|
||||||
|
ASSERT_TRUE(schur._T.isSameShape(&expT));
|
||||||
|
ASSERT_TRUE(schur._T.equalsTo(&expT));
|
||||||
|
|
||||||
|
ASSERT_TRUE(schur._U.isSameShape(&expU));
|
||||||
|
ASSERT_TRUE(schur._U.equalsTo(&expU));
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(HelpersTests2, Schur_5) {
|
||||||
|
|
||||||
|
NDArray x('c', {4,4}, {0.33 ,-7.25 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,7.59 ,3.44 ,2.24 ,-6.82 ,-1.15 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expT('c', {4,4}, {6.940177,7.201107,2.523849,-8.534745,-3.109643,5.289615,-2.940507,9.330303, 0,0,-0.1740346, 7.19851,0,0, -2.870214, -1.965758}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expU('c', {4,4}, {-0.2602141, 0.8077556,-0.3352316,-0.4091935,0.3285353,-0.4395489,-0.4714875,-0.6903338,0.7536921, 0.3005626,-0.3910435, 0.4343908,-0.5062621, -0.252962,-0.7158242, 0.4090287}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
ops::helpers::Schur<double> schur(x);
|
||||||
|
|
||||||
|
ASSERT_TRUE(schur._T.isSameShape(&expT));
|
||||||
|
ASSERT_TRUE(schur._T.equalsTo(&expT));
|
||||||
|
|
||||||
|
ASSERT_TRUE(schur._U.isSameShape(&expU));
|
||||||
|
ASSERT_TRUE(schur._U.equalsTo(&expU));
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
/*
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(HelpersTests2, Schur_6) {
|
||||||
|
|
||||||
|
NDArray x('c', {10,10}, {6.9 ,4.8 ,9.5 ,3.1 ,6.5 ,5.8 ,-0.9 ,-7.3 ,-8.1 ,3.0 ,0.1 ,9.9 ,-3.2 ,6.4 ,6.2 ,-7.0 ,5.5 ,-2.2 ,-4.0 ,3.7 ,-3.6 ,9.0 ,-1.4 ,-2.4 ,1.7 ,
|
||||||
|
-6.1 ,-4.2 ,-2.5 ,-5.6 ,-0.4 ,0.4 ,9.1 ,-2.1 ,-5.4 ,7.3 ,3.6 ,-1.7 ,-5.7 ,-8.0 ,8.8 ,-3.0 ,-0.5 ,1.1 ,10.0 ,8.0 ,0.8 ,1.0 ,7.5 ,3.5 ,-1.8 ,
|
||||||
|
0.3 ,-0.6 ,-6.3 ,-4.5 ,-1.1 ,1.8 ,0.6 ,9.6 ,9.2 ,9.7 ,-2.6 ,4.3 ,-3.4 ,0.0 ,-6.7 ,5.0 ,10.5 ,1.5 ,-7.8 ,-4.1 ,-5.3 ,-5.0 ,2.0 ,-4.4 ,-8.4 ,
|
||||||
|
6.0 ,-9.4 ,-4.8 ,8.2 ,7.8 ,5.2 ,-9.5 ,-3.9 ,0.2 ,6.8 ,5.7 ,-8.5 ,-1.9 ,-0.3 ,7.4 ,-8.7 ,7.2 ,1.3 ,6.3 ,-3.7 ,3.9 ,3.3 ,-6.0 ,-9.1 ,5.9}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expT('c', {10,10}, {-13.78982, 6.072464, 0.3021194, -8.455495,-0.3047058, 4.033153, 2.610364, 2.80607, -2.735616, 0.3040549,-2.188506, -12.38324, -1.167179, -4.539672, -19.08546, 1.752401,-0.1354974,-0.2747422,-0.3270464, -5.070936,
|
||||||
|
0,0,0.5067366, 7.930223,-0.6465996, 8.659522, 1.283713, 4.551415, 12.7736, 3.4812,0,0,-9.858142, -2.905068, -6.474159, -6.247967, 0.4720073, -10.49523, 3.617189, -4.941627,
|
||||||
|
0,0,0,0,9.461626, -4.896166, 9.339704, 4.640336, 16.8626, 2.056027,0,0,0,0,6.479812, 8.462862, 7.386285, -4.123457, -5.817095, -2.633641,0,0,0,0,0,0,13.46667, -4.907281, 4.602204, 5.198035,
|
||||||
|
0,0,0,0,0,0, 7.176822, 16.93311, 2.195036, 1.346086,0,0,0,0,0,0,0,0, 16.86979, -3.052473,0,0,0,0,0,0,0,0,0, -5.52268}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
// NDArray expT('c', {10,10}, {-13.78982, 6.072464, 0.1926198, -8.458698,-0.3047363, 4.033151, 2.610336, 2.806096, -2.735616, 0.3040549,-2.188506, -12.38324, -1.225857, -4.52418, -19.08548, 1.752257,-0.1354946,-0.2747435,-0.3270464, -5.070936,
|
||||||
|
// 0,0, 0.4812058, 7.886377,-0.7304318, 8.577898, 1.289673, 4.415163, 12.81936, 3.416929,0,0, -9.901988, -2.879537, -6.465196, -6.359608, 0.455452, -10.55328, 3.451505, -4.986284,
|
||||||
|
// 0,0,0,0, 9.461614, -4.896159, 9.339602, 4.64046, 16.86265, 2.056047,0,0,0,0, 6.47982, 8.462874, 7.386396, -4.123349, -5.816967, -2.633626,
|
||||||
|
// 0,0,0,0,0,0, 13.46665, -4.907315, 4.602182, 5.198022,0,0,0,0,0,0, 7.176788, 16.93313, 2.195081, 1.346137,0,0,0,0,0,0,0,0, 16.86979, -3.052473,0,0,0,0,0,0,0,0,0, -5.52268}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
NDArray expU('c', {10,10}, {0.1964177, 0.2165192, -0.2138164, 0.4083154, -0.1872303, -0.5087223, 0.5529025, -0.2996174,-0.08772947, 0.07126534,-0.1906247, -0.223588, 0.3574755, 0.4245914, -0.3885589,-0.07328949, -0.4176507, -0.1885168, -0.4476957, 0.1971104,
|
||||||
|
-0.2219015, 0.3084187, 0.1069209, -0.4905009, -0.3517786, 0.1446875, 0.121738, -0.3772941, 0.1232591, 0.5353205,-0.4766346, 0.6158252, -0.1529085, 0.04780914, 0.1274182, -0.1219211, -0.3123289, -0.2219282,-0.07613826, -0.429201,
|
||||||
|
0.2577533, -0.3356205, -0.225358, -0.1540796, 0.3155174, -0.1904664, -0.3567101, -0.6831458, 0.1244646, 0.03383783, -0.45597, -0.3350697, 0.06824276, -0.2861978,-0.06724917, -0.7046481, 0.01664764, 0.2270567, 0.2003283,-0.01544937,
|
||||||
|
0.122865, 0.1516775, -0.4446453, -0.2338583, 0.1633447, -0.193498, -0.198088, 0.3170272, -0.5869794, 0.4013553, 0.347383, 0.3666581, 0.6890763,-0.05797414, 0.3630058, -0.319958, -0.1071812, 0.06162044, 0.03171228, 0.1275262,
|
||||||
|
-0.2986812, 0.05382598, -0.1484276, 0.4936468, 0.362756, 0.05858297, -0.1055183, 0.1090384, 0.4217073, 0.5534347, 0.3864388, 0.2085926, -0.204135, 0.05230855, -0.5290207, -0.1548485, -0.4670302, 0.2205726, 0.4380318,-0.01626632}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
ops::helpers::Schur<double> schur(x);
|
||||||
|
|
||||||
|
ASSERT_TRUE(schur._T.isSameShape(&expT));
|
||||||
|
ASSERT_TRUE(schur._T.equalsTo(&expT, 1e-3));
|
||||||
|
|
||||||
|
ASSERT_TRUE(schur._U.isSameShape(&expU));
|
||||||
|
ASSERT_TRUE(schur._U.equalsTo(&expU));
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(HelpersTests2, EigenValsAndVecs_1) {
|
||||||
|
|
||||||
|
NDArray x('c', {2,2}, {1.5,-2,17,5}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expVals('c', {2,2}, {3.25,5.562149, 3.25,-5.562149}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expVecs('c', {2,2,2}, {-0.3094862,-0.0973726, -0.3094862,0.0973726,0,0.9459053, 0,-0.9459053}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
ops::helpers::EigenValsAndVecs<double> eig(x);
|
||||||
|
|
||||||
|
ASSERT_TRUE(eig._Vals.isSameShape(&expVals));
|
||||||
|
ASSERT_TRUE(eig._Vals.equalsTo(&expVals));
|
||||||
|
|
||||||
|
ASSERT_TRUE(eig._Vecs.isSameShape(&expVecs));
|
||||||
|
ASSERT_TRUE(eig._Vecs.equalsTo(&expVecs));
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(HelpersTests2, EigenValsAndVecs_2) {
|
||||||
|
|
||||||
|
NDArray x('c', {3,3}, {33,24,-48,57,12.5,-3,1.1,10,-5.2}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expVals('c', {3,2}, {53.73337,0, -27.51557,0, 14.0822,0}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expVecs('c', {3,3,2}, {-0.5848506,0,0.5560778,0,-0.04889745,0,-0.7978391,0,-0.7683444,0,-0.8855156,0,-0.1462962,0,0.3168979,0,-0.4620293,0}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
ops::helpers::EigenValsAndVecs<double> eig(x);
|
||||||
|
|
||||||
|
ASSERT_TRUE(eig._Vals.isSameShape(&expVals));
|
||||||
|
ASSERT_TRUE(eig._Vals.equalsTo(&expVals));
|
||||||
|
|
||||||
|
ASSERT_TRUE(eig._Vecs.isSameShape(&expVecs));
|
||||||
|
ASSERT_TRUE(eig._Vecs.equalsTo(&expVecs));
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(HelpersTests2, EigenValsAndVecs_3) {
|
||||||
|
|
||||||
|
NDArray x('c', {4,4}, {0.33 ,-7.25 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,7.59 ,3.44 ,2.24 ,-6.82 ,-1.15 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expVals('c', {4,2}, {6.114896,4.659591,6.114896,-4.659591, -1.069896,4.45631,-1.069896,-4.45631}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expVecs('c', {4,4,2}, {-0.2141303,0.4815241,-0.2141303,-0.4815241, 0.1035092,-0.4270603, 0.1035092,0.4270603, 0.2703519,-0.2892722, 0.2703519,0.2892722, -0.5256817,0.044061, -0.5256817,-0.044061,
|
||||||
|
0.6202137,0.05521234,0.6202137,-0.05521234, -0.5756007,0.3932209,-0.5756007,-0.3932209,-0.4166034,-0.0651337, -0.4166034,0.0651337, -0.1723716,0.1138941,-0.1723716,-0.1138941}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
ops::helpers::EigenValsAndVecs<double> eig(x);
|
||||||
|
|
||||||
|
ASSERT_TRUE(eig._Vals.isSameShape(&expVals));
|
||||||
|
ASSERT_TRUE(eig._Vals.equalsTo(&expVals));
|
||||||
|
|
||||||
|
ASSERT_TRUE(eig._Vecs.isSameShape(&expVecs));
|
||||||
|
ASSERT_TRUE(eig._Vecs.equalsTo(&expVecs));
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(HelpersTests2, EigenValsAndVecs_4) {
|
||||||
|
|
||||||
|
NDArray x('c', {10,10}, {6.9 ,4.8 ,9.5 ,3.1 ,6.5 ,5.8 ,-0.9 ,-7.3 ,-8.1 ,3.0 ,0.1 ,9.9 ,-3.2 ,6.4 ,6.2 ,-7.0 ,5.5 ,-2.2 ,-4.0 ,3.7 ,-3.6 ,9.0 ,-1.4 ,-2.4 ,1.7 ,
|
||||||
|
-6.1 ,-4.2 ,-2.5 ,-5.6 ,-0.4 ,0.4 ,9.1 ,-2.1 ,-5.4 ,7.3 ,3.6 ,-1.7 ,-5.7 ,-8.0 ,8.8 ,-3.0 ,-0.5 ,1.1 ,10.0 ,8.0 ,0.8 ,1.0 ,7.5 ,3.5 ,-1.8 ,
|
||||||
|
0.3 ,-0.6 ,-6.3 ,-4.5 ,-1.1 ,1.8 ,0.6 ,9.6 ,9.2 ,9.7 ,-2.6 ,4.3 ,-3.4 ,0.0 ,-6.7 ,5.0 ,10.5 ,1.5 ,-7.8 ,-4.1 ,-5.3 ,-5.0 ,2.0 ,-4.4 ,-8.4 ,
|
||||||
|
6.0 ,-9.4 ,-4.8 ,8.2 ,7.8 ,5.2 ,-9.5 ,-3.9 ,0.2 ,6.8 ,5.7 ,-8.5 ,-1.9 ,-0.3 ,7.4 ,-8.7 ,7.2 ,1.3 ,6.3 ,-3.7 ,3.9 ,3.3 ,-6.0 ,-9.1 ,5.9}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expVals('c', {10,2}, { -13.08653,3.577011,-13.08653,-3.577011, -1.199166,8.675665,-1.199166,-8.675665,8.962244,
|
||||||
|
5.610424, 8.962244,-5.610424, 15.19989,5.675794, 15.19989,-5.675794,16.86979,0,-5.52268,0}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expVecs('c', {10,10,2}, {0.1652385,0.1439317, 0.1652385,-0.1439317, -0.198272,0.207306, -0.198272,-0.207306, 0.1861466,-0.4599919, 0.1861466,0.4599919, 0.09384053,-0.4889922, 0.09384053,0.4889922, -0.6153314,0, -0.2180209,0,
|
||||||
|
-0.1603652,-0.1466119, -0.1603652,0.1466119, 0.2817409,0.3301842, 0.2817409,-0.3301842, 0.09747303,-0.2218182, 0.09747303,0.2218182, 0.2318273,-0.3355113, 0.2318273,0.3355113, -0.4828878,0, -0.1451126,0,
|
||||||
|
-0.1866771,0.1220412, -0.1866771,-0.1220412, 0.08937842,-0.3025104, 0.08937842,0.3025104, 0.2783766,0.2258364, 0.2783766,-0.2258364, -0.1413997,-0.09596012, -0.1413997,0.09596012, -0.2286925,0, 0.3290011,0,
|
||||||
|
-0.4009741,0.238131, -0.4009741,-0.238131, -0.02772353,0.1338458, -0.02772353,-0.1338458, 0.09030543,-0.2222453, 0.09030543,0.2222453, 0.2565825,-0.2275446, 0.2565825,0.2275446, -0.2855937,0, -0.3950544,0,
|
||||||
|
0.2168379,-0.1301121, 0.2168379,0.1301121, -0.165433,-0.1220125, -0.165433,0.1220125, -0.2685605,0.008133055,-0.2685605,-0.008133055, 0.1929395,-0.1194659, 0.1929395,0.1194659, 0.2206467,0, 0.3289105,0,
|
||||||
|
-0.3835898,-0.2478813, -0.3835898,0.2478813, 0.1923005,-0.01036433, 0.1923005,0.01036433, -0.1711637,-0.3548358, -0.1711637,0.3548358, 0.2888441,0.09625169, 0.2888441,-0.09625169, 0.2595426,0, -0.1288072,0,
|
||||||
|
0.1033616,0.09839151, 0.1033616,-0.09839151, -0.3080167,-0.1624564, -0.3080167,0.1624564,-0.03972293,-0.03967309, -0.03972293,0.03967309, 0.1965443,0.3025898, 0.1965443,-0.3025898, 0.04587166,0, 0.499261,0,
|
||||||
|
0.2922398,0.2461792, 0.2922398,-0.2461792, 0.2769633,-0.2745029, 0.2769633,0.2745029, 0.1034687,-0.002947149, 0.1034687,0.002947149, -0.02611308,0.1658046, -0.02611308,-0.1658046, 0.2351063,0, -0.3787892,0,
|
||||||
|
-0.2512689,-0.02169855, -0.2512689,0.02169855, -0.01481625,0.4376404, -0.01481625,-0.4376404, -0.2298635,-0.2360671, -0.2298635,0.2360671, 0.11004,-0.1467444, 0.11004,0.1467444, 0.1501568,0, 0.340117,0,
|
||||||
|
0.325096,0.1712822, 0.325096,-0.1712822, -0.2412035,-0.09236849, -0.2412035,0.09236849, 0.3894343,-0.08673087, 0.3894343,0.08673087, 0.3125305,0.07128152, 0.3125305,-0.07128152, -0.2415555,0, 0.1841298,0,}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
ops::helpers::EigenValsAndVecs<double> eig(x);
|
||||||
|
|
||||||
|
ASSERT_TRUE(eig._Vals.isSameShape(&expVals));
|
||||||
|
ASSERT_TRUE(eig._Vals.equalsTo(&expVals));
|
||||||
|
|
||||||
|
ASSERT_TRUE(eig._Vecs.isSameShape(&expVecs));
|
||||||
|
ASSERT_TRUE(eig._Vecs.equalsTo(&expVecs));
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(HelpersTests2, fullPivLU_1) {
|
||||||
|
|
||||||
|
NDArray a('c', {4,4}, {0.33 ,-7.25 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,7.59 ,3.44 ,2.24 ,-6.82 ,-1.15 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE);
|
||||||
|
NDArray b('c', {4,1}, {-5.,10,9,1}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
NDArray x = b.ulike();
|
||||||
|
|
||||||
|
NDArray expX('c', {4,1}, {0.8527251, -0.2545784, -1.076495, -0.8526268}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
ops::helpers::FullPivLU<double>::solve(a,b,x);
|
||||||
|
|
||||||
|
ASSERT_TRUE(x.equalsTo(&expX));
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(HelpersTests2, fullPivLU_2) {
|
||||||
|
|
||||||
|
NDArray a('c', {4,4}, {0.33 ,-7.25 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,7.59 ,3.44 ,2.24 ,-6.82 ,-1.15 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE);
|
||||||
|
NDArray b('c', {4,2}, {-5.,10,9,1,1.5,-2,17,5}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
NDArray x = b.ulike();
|
||||||
|
|
||||||
|
NDArray expX('c', {4,2}, {1.462913, 1.835338, 0.4083664, -2.163816, -3.344481, -3.739225, 0.5156383,0.01624954}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
ops::helpers::FullPivLU<double>::solve(a,b,x);
|
||||||
|
|
||||||
|
ASSERT_TRUE(x.equalsTo(&expX));
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(HelpersTests2, fullPivLU_3) {
|
||||||
|
|
||||||
|
NDArray a1('c', {4,3}, {0.33 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,2.24 ,-6.82 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE);
|
||||||
|
NDArray a2('c', {3,4}, {0.33 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,2.24 ,-6.82 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE);
|
||||||
|
NDArray b1('c', {4,2}, {-5.,10,9,1,1.5,-2,17,5}, sd::DataType::DOUBLE);
|
||||||
|
NDArray b2('c', {3,2}, {-5.,10,9,1,1.5,-2}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
NDArray expX1('c', {3,2}, {0.9344955,-0.5841325, 0.8768102, 1.029137, -1.098021, 1.360152}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expX2('c', {4,2}, {0.3536033,0.5270184,0,0,-0.8292221,0.967515,0.01827441,2.856337}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
NDArray x1 = expX1.ulike();
|
||||||
|
ops::helpers::FullPivLU<double>::solve(a1,b1,x1);
|
||||||
|
ASSERT_TRUE(x1.equalsTo(&expX1));
|
||||||
|
|
||||||
|
NDArray x2 = expX2.ulike();
|
||||||
|
ops::helpers::FullPivLU<double>::solve(a2,b2,x2);
|
||||||
|
ASSERT_TRUE(x2.equalsTo(&expX2));
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(HelpersTests2, fullPivLU_4) {
|
||||||
|
|
||||||
|
NDArray a('c', {10,10}, {6.9 ,4.8 ,9.5 ,3.1 ,6.5 ,5.8 ,-0.9 ,-7.3 ,-8.1 ,3.0 ,0.1 ,9.9 ,-3.2 ,6.4 ,6.2 ,-7.0 ,5.5 ,-2.2 ,-4.0 ,3.7 ,-3.6 ,9.0 ,-1.4 ,-2.4 ,1.7 ,
|
||||||
|
-6.1 ,-4.2 ,-2.5 ,-5.6 ,-0.4 ,0.4 ,9.1 ,-2.1 ,-5.4 ,7.3 ,3.6 ,-1.7 ,-5.7 ,-8.0 ,8.8 ,-3.0 ,-0.5 ,1.1 ,10.0 ,8.0 ,0.8 ,1.0 ,7.5 ,3.5 ,-1.8 ,
|
||||||
|
0.3 ,-0.6 ,-6.3 ,-4.5 ,-1.1 ,1.8 ,0.6 ,9.6 ,9.2 ,9.7 ,-2.6 ,4.3 ,-3.4 ,0.0 ,-6.7 ,5.0 ,10.5 ,1.5 ,-7.8 ,-4.1 ,-5.3 ,-5.0 ,2.0 ,-4.4 ,-8.4 ,
|
||||||
|
6.0 ,-9.4 ,-4.8 ,8.2 ,7.8 ,5.2 ,-9.5 ,-3.9 ,0.2 ,6.8 ,5.7 ,-8.5 ,-1.9 ,-0.3 ,7.4 ,-8.7 ,7.2 ,1.3 ,6.3 ,-3.7 ,3.9 ,3.3 ,-6.0 ,-9.1 ,5.9}, sd::DataType::DOUBLE);
|
||||||
|
NDArray b('c', {10,2}, {-5.,10,9,1,1.5,-2,17,5,3.6,0.12, -3.1,2.27,-0.5,27.3,8.9,5,-7,8,-9,10}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
NDArray x = b.ulike();
|
||||||
|
|
||||||
|
NDArray expX('c', {10,2}, {-0.697127, 2.58257, 2.109721,3.160622,-2.217796, -3.275736,-0.5752479, 2.475356,1.996841, -1.928947,
|
||||||
|
2.213154,3.541014, 0.7104885, -1.981451,-3.297972,-0.4720612, 3.672657, 0.9161028, -2.322383, -1.784493}, sd::DataType::DOUBLE);
|
||||||
|
|
||||||
|
ops::helpers::FullPivLU<double>::solve(a,b,x);
|
||||||
|
|
||||||
|
ASSERT_TRUE(x.equalsTo(&expX));
|
||||||
|
}
|
|
@ -0,0 +1,472 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver119 on 31.10.2017.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <legacy/NativeOps.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::graph;
|
||||||
|
|
||||||
|
class IndexingTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(IndexingTests, StridedSlice_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {3, 3, 3});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {1, 1, 3});
|
||||||
|
exp.p(0, 25.f);
|
||||||
|
exp.p(1, 26.f);
|
||||||
|
exp.p(2, 27.f);
|
||||||
|
|
||||||
|
x.linspace(1);
|
||||||
|
auto begin = NDArrayFactory::create<int>({2,2, 0});
|
||||||
|
auto end = NDArrayFactory::create<int>({3,3,3});
|
||||||
|
auto strides = NDArrayFactory::create<int>({1,1,1});
|
||||||
|
|
||||||
|
|
||||||
|
sd::ops::strided_slice op;
|
||||||
|
|
||||||
|
auto result = op.evaluate({&x, &begin, &end, &strides}, {}, {0,0,0,0,0}); //, 2,2,0, 3,3,3, 1,1,1});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(IndexingTests, StridedSlice_2) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5, 5, 5});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 3, 3}, {86.f, 87.f, 88.f, 91.f, 92.f, 93.f, 96.f, 97.f, 98.f, 111.f, 112.f, 113.f, 116.f, 117.f, 118.f, 121.f, 122.f, 123.f});
|
||||||
|
|
||||||
|
x.linspace(1);
|
||||||
|
|
||||||
|
sd::ops::strided_slice op;
|
||||||
|
|
||||||
|
auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,1});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(IndexingTests, StridedSlice_3) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5, 5, 5});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 3, 2}, {86.f, 88.f, 91.f, 93.f, 96.f, 98.f, 111.f, 113.f, 116.f, 118.f, 121.f, 123.f});
|
||||||
|
|
||||||
|
x.linspace(1);
|
||||||
|
|
||||||
|
sd::ops::strided_slice op;
|
||||||
|
|
||||||
|
auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,2});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(IndexingTests, SimpleSlice_1) {
|
||||||
|
|
||||||
|
auto input = NDArrayFactory::create<float>('c', {3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {1, 1, 3});
|
||||||
|
exp.p(0, 3.0f);
|
||||||
|
exp.p(1, 3.0f);
|
||||||
|
exp.p(2, 3.0f);
|
||||||
|
|
||||||
|
sd::ops::slice op;
|
||||||
|
|
||||||
|
auto result = op.evaluate({&input}, {}, {1,0,0, 1,1,3});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(IndexingTests, SimpleSlice_2) {
|
||||||
|
auto input = NDArrayFactory::create<float>('c', {3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {1, 2, 3});
|
||||||
|
exp.p(0, 3.0f);
|
||||||
|
exp.p(1, 3.0f);
|
||||||
|
exp.p(2, 3.0f);
|
||||||
|
exp.p(3, 4.0f);
|
||||||
|
exp.p(4, 4.0f);
|
||||||
|
exp.p(5, 4.0f);
|
||||||
|
|
||||||
|
sd::ops::slice op;
|
||||||
|
|
||||||
|
auto result = op.evaluate({&input}, {}, {1,0,0, 1,2,3});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(IndexingTests, SimpleSlice_3) {
|
||||||
|
auto input = NDArrayFactory::create<float>('c', {3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 1, 3});
|
||||||
|
exp.p(0, 3.0f);
|
||||||
|
exp.p(1, 3.0f);
|
||||||
|
exp.p(2, 3.0f);
|
||||||
|
exp.p(3, 5.0f);
|
||||||
|
exp.p(4, 5.0f);
|
||||||
|
exp.p(5, 5.0f);
|
||||||
|
|
||||||
|
sd::ops::slice op;
|
||||||
|
|
||||||
|
auto result = op.evaluate({&input}, {}, {1,0,0, 2,1,3});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(IndexingTests, SimpleSlice_4) {
|
||||||
|
auto input = NDArrayFactory::create<double>('c', {3, 2, 3}, {1.0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
||||||
|
auto start = NDArrayFactory::create<double>('c', {3}, {1.0, 0.0, 0.0});
|
||||||
|
auto stop = NDArrayFactory::create<double>('c', {3}, {2.0, 1.0, 3.0});
|
||||||
|
auto exp = NDArrayFactory::create<double>('c', {2, 1, 3}, {3.0, 3.0, 3.0, 5.0, 5.0, 5.0});
|
||||||
|
|
||||||
|
sd::ops::slice op;
|
||||||
|
|
||||||
|
auto result = op.evaluate({&input, &start, &stop});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(IndexingTests, MaskedSlice_0) {
|
||||||
|
auto matrix = NDArrayFactory::create<float>('c', {3, 5});
|
||||||
|
auto tads = matrix.allTensorsAlongDimension({1});
|
||||||
|
for (int e = 0; e < tads.size(); e++) {
|
||||||
|
tads.at(e)->assign((float) (e+1));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {1, 5});
|
||||||
|
exp.assign(2.0f);
|
||||||
|
|
||||||
|
sd::ops::strided_slice op;
|
||||||
|
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 2, 1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
// z->printShapeInfo("z");
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(IndexingTests, MaskedSlice_00) {
|
||||||
|
auto matrix = NDArrayFactory::create<float>('c', {3, 5});
|
||||||
|
auto tads = matrix.allTensorsAlongDimension({1});
|
||||||
|
for (int e = 0; e < tads.size(); e++) {
|
||||||
|
tads.at(e)->assign((float) (e+1));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {1, 2}, {2, 2});
|
||||||
|
|
||||||
|
|
||||||
|
sd::ops::strided_slice op;
|
||||||
|
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 1, 2, 3, 1, 1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(IndexingTests, MaskedSlice_1) {
|
||||||
|
auto matrix = NDArrayFactory::create<float>('c', {3, 5});
|
||||||
|
auto tads = matrix.allTensorsAlongDimension({1});
|
||||||
|
for (int e = 0; e < tads.size(); e++) {
|
||||||
|
tads.at(e)->assign((float) (e+1));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {5});
|
||||||
|
exp.assign(2.0f);
|
||||||
|
|
||||||
|
sd::ops::strided_slice op;
|
||||||
|
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 2, 1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
// z->printShapeInfo("z");
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(IndexingTests, MaskedSlice_2) {
|
||||||
|
|
||||||
|
auto matrix = NDArrayFactory::create<float>('c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {4.000000f, 4.200000f, 4.300000f, 5.000000f, 5.200000f, 5.300000f, 6.000000f, 6.200000f, 6.300000f});
|
||||||
|
|
||||||
|
// output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5)
|
||||||
|
sd::ops::strided_slice op;
|
||||||
|
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 0, 0, 3, 3, 3, 1, 1, 1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(IndexingTests, MaskedSlice_3) {
|
||||||
|
|
||||||
|
auto matrix = NDArrayFactory::create<float>('c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 3}, { 4.f, 4.2f, 4.3f, 7.f, 7.2f, 7.3f});
|
||||||
|
|
||||||
|
// output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5)
|
||||||
|
sd::ops::strided_slice op;
|
||||||
|
auto result = op.evaluate({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3, 1, 1, 1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(IndexingTests, MaskedSlice_4) {
|
||||||
|
|
||||||
|
auto matrix = NDArrayFactory::create<float>('c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3}, { 4.f, 4.2f, 4.3f});
|
||||||
|
|
||||||
|
// output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5)
|
||||||
|
sd::ops::strided_slice op;
|
||||||
|
auto result = op.evaluate({&matrix}, {}, {0,0,0,0, 3, 1, 0, 0, 3, 3, 3, 1, 1, 1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(IndexingTests, Live_Slice_1) {
|
||||||
|
auto matrix = NDArrayFactory::create<float>('c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3}, { 4.f, 4.2f, 4.3f});
|
||||||
|
|
||||||
|
auto begin = NDArrayFactory::create<float>('c', {3}, {1.0f, 0.0f, 0.0f});
|
||||||
|
auto end = NDArrayFactory::create<float>('c', {3}, {3.0f, 3.0f, 3.0f});
|
||||||
|
auto stride = NDArrayFactory::create<float>('c', {3}, {1.0f, 1.0f, 1.0f});
|
||||||
|
|
||||||
|
// output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5)
|
||||||
|
sd::ops::strided_slice op;
|
||||||
|
auto result = op.evaluate({&matrix, &begin, &end, &stride}, {}, {0,0,0,0,3});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
// z->printShapeInfo("z shape");
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(IndexingTests, Test_StridedSlice_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 2}, {5.f, 2.f});
|
||||||
|
auto a = NDArrayFactory::create<float>('c', {1}, {0.f});
|
||||||
|
auto b = NDArrayFactory::create<float>('c', {1}, {1.f});
|
||||||
|
auto c = NDArrayFactory::create<float>('c', {1}, {1.f});
|
||||||
|
auto exp = NDArrayFactory::create<float>({5.0f, 2});
|
||||||
|
|
||||||
|
sd::ops::strided_slice op;
|
||||||
|
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(IndexingTests, Test_StridedSlice_2) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {2, 3}, {1, 2, 3, 4, 5, 6});
|
||||||
|
auto a = NDArrayFactory::create<float>('c', {2}, {1, 1});
|
||||||
|
auto b = NDArrayFactory::create<float>('c', {2}, {2, 2});
|
||||||
|
auto c = NDArrayFactory::create<float>('c', {2}, {1, 1});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {1}, {5.0});
|
||||||
|
|
||||||
|
sd::ops::strided_slice op;
|
||||||
|
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
// z->printIndexedBuffer("Z");
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(IndexingTests, Test_StridedSlice_3) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {2, 3}, {1, 2, 3, 4, 5, 6});
|
||||||
|
auto a = NDArrayFactory::create<float>('c', {2}, {1, 2});
|
||||||
|
auto b = NDArrayFactory::create<float>('c', {2}, {2, 3});
|
||||||
|
auto c = NDArrayFactory::create<float>('c', {2}, {1, 1});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {1}, {6.0});
|
||||||
|
|
||||||
|
sd::ops::strided_slice op;
|
||||||
|
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(IndexingTests, Test_StridedSlice_4) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 2}, {5, 2});
|
||||||
|
auto a = NDArrayFactory::create<float>('c', {1}, {0.});
|
||||||
|
auto b = NDArrayFactory::create<float>('c', {1}, {1});
|
||||||
|
auto c = NDArrayFactory::create<float>('c', {1}, {1});
|
||||||
|
auto exp = NDArrayFactory::create<float>({5.0f, 2});
|
||||||
|
|
||||||
|
sd::ops::strided_slice op;
|
||||||
|
auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1});
|
||||||
|
// auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1, 0, 1, 1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
//z->printIndexedBuffer("Z");
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(IndexingTests, Test_Subarray_Strided_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3, 2}, {1, 3, 4, 6, 7, 9});
|
||||||
|
auto sub = x({0,0,0, 0,3,2}, true, true);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(sub));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(sub));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/*
|
||||||
|
TEST_F(IndexingTests, MaskedSlice_5) {
|
||||||
|
|
||||||
|
auto matrix('c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f});
|
||||||
|
auto exp('c', {2, 3}, { 4.f, 4.2f, 4.3f, 7.f, 7.2f, 7.3f});
|
||||||
|
|
||||||
|
// output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5)
|
||||||
|
sd::ops::strided_slice<float> op;
|
||||||
|
auto result = op.execute({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
*/
|
|
@ -0,0 +1,89 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <helpers/PointersManager.h>
|
||||||
|
#include <array/ExtraArguments.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <array>
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::ops;
|
||||||
|
|
||||||
|
class JavaInteropCudaTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {3, 5});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {3, 5});
|
||||||
|
x.assign(1.f);
|
||||||
|
e.assign(2.f);
|
||||||
|
|
||||||
|
sd::ops::add op;
|
||||||
|
Context context(1);
|
||||||
|
|
||||||
|
context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer());
|
||||||
|
context.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
|
||||||
|
context.setInputArray(1, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo());
|
||||||
|
|
||||||
|
context.setOutputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
|
||||||
|
|
||||||
|
PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_1");
|
||||||
|
execCustomOp2(nullptr, op.getOpHash(), &context);
|
||||||
|
|
||||||
|
pm.synchronize();
|
||||||
|
|
||||||
|
ASSERT_EQ(e, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_2) {
|
||||||
|
NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32);
|
||||||
|
NDArray y('c', {2, 2}, sd::DataType::FLOAT32);
|
||||||
|
NDArray z('c', {3, 2, 2}, sd::DataType::BOOL);
|
||||||
|
NDArray e('c', {3, 2, 2}, sd::DataType::BOOL);
|
||||||
|
|
||||||
|
x.assign(1.f);
|
||||||
|
y.assign(2.f);
|
||||||
|
e.assign(false);
|
||||||
|
|
||||||
|
sd::ops::equals op;
|
||||||
|
Context context(1);
|
||||||
|
|
||||||
|
context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer());
|
||||||
|
context.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
|
||||||
|
context.setInputArray(1, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo());
|
||||||
|
|
||||||
|
context.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
|
||||||
|
|
||||||
|
PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_2");
|
||||||
|
execCustomOp2(nullptr, op.getOpHash(), &context);
|
||||||
|
|
||||||
|
pm.synchronize();
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,221 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <array/ExtraArguments.h>
|
||||||
|
#include <array>
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
|
||||||
|
class LambdaTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
LambdaTests() {
|
||||||
|
printf("\n");
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Lambda>
|
||||||
|
__global__ void runLambda(double *input, double *output, Nd4jLong length, Lambda lambda) {
|
||||||
|
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
for (Nd4jLong e = tid; e < length; e += gridDim.x * blockDim.x) {
|
||||||
|
output[e] = lambda(input[e]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void launcher(cudaStream_t *stream, double *input, double *output, Nd4jLong length) {
|
||||||
|
//auto f = [] __host__ __device__ (double x) -> double {
|
||||||
|
// return x + 1.;
|
||||||
|
//};
|
||||||
|
auto f = LAMBDA_D(x) {
|
||||||
|
return x+1.;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
runLambda<<<128, 128, 128, *stream>>>(input, output, length, f);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(LambdaTests, test_basic_1) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {5});
|
||||||
|
auto e = NDArrayFactory::create<double>('c', {5}, {1., 1., 1., 1., 1.});
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
//x.applyLambda<double>(f, nullptr);
|
||||||
|
launcher(LaunchContext::defaultContext()->getCudaStream(), (double *)x.specialBuffer(), (double *)x.specialBuffer(), x.lengthOf());
|
||||||
|
auto res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream());
|
||||||
|
ASSERT_EQ(0, res);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
void test(NDArray &x) {
|
||||||
|
auto f = LAMBDA_D(x) {
|
||||||
|
return x+1.;
|
||||||
|
};
|
||||||
|
|
||||||
|
x.applyLambda(f, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void test2(NDArray &x) {
|
||||||
|
auto f = LAMBDA_T(x) {
|
||||||
|
return x+1.;
|
||||||
|
};
|
||||||
|
|
||||||
|
x.applyLambda(f, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
void testPairwise(NDArray &x, NDArray &y) {
|
||||||
|
auto f = LAMBDA_DD(x, y) {
|
||||||
|
return x + y +1.;
|
||||||
|
};
|
||||||
|
|
||||||
|
x.applyPairwiseLambda(y, f, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
void testTriplewise(NDArray &i, NDArray &j, NDArray &k) {
|
||||||
|
auto f = LAMBDA_DDD(i, j, k) {
|
||||||
|
return i + j + k + 2.;
|
||||||
|
};
|
||||||
|
|
||||||
|
i.applyTriplewiseLambda(j, k, f, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
void testIndexed(NDArray &x) {
|
||||||
|
auto f = ILAMBDA_D(x) {
|
||||||
|
return _idx + 1.;
|
||||||
|
};
|
||||||
|
|
||||||
|
x.applyIndexedLambda(f, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
void testIndexedPairwise(NDArray &x, NDArray &y) {
|
||||||
|
auto f = ILAMBDA_DD(x, y) {
|
||||||
|
return _idx + x + y +1.;
|
||||||
|
};
|
||||||
|
|
||||||
|
x.applyIndexedPairwiseLambda(y, f, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LambdaTests, test_basic_2) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {5});
|
||||||
|
auto e = NDArrayFactory::create<double>('c', {5}, {1., 1., 1., 1., 1.});
|
||||||
|
|
||||||
|
test(x);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LambdaTests, test_basic_3) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f});
|
||||||
|
|
||||||
|
test(x);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LambdaTests, test_basic_4) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f});
|
||||||
|
|
||||||
|
test2<float>(x);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LambdaTests, test_basic_5) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {5}, {1., 1., 1., 1., 1.});
|
||||||
|
auto y = NDArrayFactory::create<double>('c', {5}, {2., 2., 2., 2., 2.});
|
||||||
|
auto e = NDArrayFactory::create<double>('c', {5}, {4., 4., 4., 4., 4.});
|
||||||
|
|
||||||
|
testPairwise(x, y);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LambdaTests, test_basic_6) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {5});
|
||||||
|
auto e = NDArrayFactory::create<double>('c', {5}, {1., 2., 3., 4., 5.});
|
||||||
|
|
||||||
|
testIndexed(x);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LambdaTests, test_basic_7) {
|
||||||
|
auto w = NDArrayFactory::create<double>('c', {5}, {0., 0., 0., 0., 0.});
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {5}, {1., 1., 1., 1., 1.});
|
||||||
|
auto y = NDArrayFactory::create<double>('c', {5}, {2., 2., 2., 2., 2.});
|
||||||
|
auto e = NDArrayFactory::create<double>('c', {5}, {5., 5., 5., 5., 5.});
|
||||||
|
|
||||||
|
testTriplewise(w, x, y);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, w);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LambdaTests, test_basic_8) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {5}, {1., 1., 1., 1., 1.});
|
||||||
|
auto y = NDArrayFactory::create<double>('c', {5}, {2., 2., 2., 2., 2.});
|
||||||
|
auto e = NDArrayFactory::create<double>('c', {5}, {4., 5., 6., 7., 8.});
|
||||||
|
|
||||||
|
testIndexedPairwise(x, y);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void testPairwiseMy(NDArray &x, NDArray &y, NDArray &z) {
|
||||||
|
|
||||||
|
auto f = LAMBDA_TT(x, y){
|
||||||
|
return sd::math::nd4j_max<T>(x, (T)0.f)
|
||||||
|
- x * y
|
||||||
|
+ sd::math::nd4j_log<T,T>((T)1.f
|
||||||
|
+ sd::math::nd4j_exp<T,T>(-sd::math::nd4j_abs(x)));
|
||||||
|
};
|
||||||
|
|
||||||
|
x.applyPairwiseLambda(y, f, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(LambdaTests, test_basic_9) {
|
||||||
|
|
||||||
|
NDArray labels('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
|
||||||
|
NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE);
|
||||||
|
NDArray output('c', {2,3,4}, sd::DataType::DOUBLE);
|
||||||
|
NDArray expected('c', {2,3,4}, {0.744397, 0.598139, 0.554355, 0.913015, 0.474077, 1.037488, 0.403186, 1.171101, 0.341154, 1.313262, 0.287335, 1.463282, 0.241008, 1.620417, 0.201413, 1.783901, 0.167786, 1.952978, 2.039387, 0.126928, 0.115520, 2.305083, 0.095545, 2.486836});
|
||||||
|
|
||||||
|
logits.linspace(0.1, 0.1);
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&output}, {&logits, &labels});
|
||||||
|
testPairwiseMy<double>(logits, labels, output);
|
||||||
|
NDArray::registerSpecialUse({&output}, {&logits, &labels});
|
||||||
|
|
||||||
|
// output.printBuffer(nullptr, -1, true);
|
||||||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
}
|
|
@ -0,0 +1,127 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
|
#include <loops/reduce3.h>
|
||||||
|
#include <ops/declarable/LegacyTransformOp.h>
|
||||||
|
#include <ops/declarable/LegacyPairwiseTransformOp.h>
|
||||||
|
#include <ops/declarable/LegacyScalarOp.h>
|
||||||
|
#include <ops/declarable/LegacyReduceSameOp.h>
|
||||||
|
#include <ops/declarable/LegacyReduceFloatOp.h>
|
||||||
|
#include <ops/declarable/LegacyIndexReduceOp.h>
|
||||||
|
#include <ops/declarable/LegacyBroadcastOp.h>
|
||||||
|
#include <helpers/TAD.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <thread>
|
||||||
|
#include <execution/AffinityManager.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::ops;
|
||||||
|
|
||||||
|
class LaunchContextCudaTests : public testing::Test {
|
||||||
|
//
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
void acquireContext(int threadId, int &deviceId) {
|
||||||
|
deviceId = AffinityManager::currentDeviceId();
|
||||||
|
|
||||||
|
nd4j_printf("Creating thread: [%i]; assigned deviceId: [%i];\n", threadId, deviceId);
|
||||||
|
|
||||||
|
auto lc = LaunchContext::defaultContext();
|
||||||
|
nd4j_printf("LC: [%p]\n", lc);
|
||||||
|
|
||||||
|
nd4j_printf("reductionPtr: [%p]; stream: [%p];\n", lc->getReductionPointer(), lc->getCudaStream());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LaunchContextCudaTests, basic_test_1) {
|
||||||
|
int deviceA, deviceB;
|
||||||
|
std::thread threadA(acquireContext, 0, std::ref(deviceA));
|
||||||
|
std::thread threadB(acquireContext, 1, std::ref(deviceB));
|
||||||
|
|
||||||
|
threadA.join();
|
||||||
|
threadB.join();
|
||||||
|
nd4j_printf("All threads joined\n","");
|
||||||
|
|
||||||
|
if (AffinityManager::numberOfDevices() > 1)
|
||||||
|
ASSERT_NE(deviceA, deviceB);
|
||||||
|
}
|
||||||
|
|
||||||
|
void fillArray(int tid, std::vector<NDArray*> &arrays) {
|
||||||
|
auto array = NDArrayFactory::create_<int>('c', {3, 10});
|
||||||
|
nd4j_printf("Array created on device [%i]\n", AffinityManager::currentDeviceId());
|
||||||
|
array->assign(tid);
|
||||||
|
arrays[tid] = array;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LaunchContextCudaTests, basic_test_2) {
|
||||||
|
std::vector<NDArray*> arrays(2);
|
||||||
|
|
||||||
|
std::thread threadA(fillArray, 0, std::ref(arrays));
|
||||||
|
std::thread threadB(fillArray, 1, std::ref(arrays));
|
||||||
|
|
||||||
|
threadA.join();
|
||||||
|
threadB.join();
|
||||||
|
|
||||||
|
for (int e = 0; e < 2; e++) {
|
||||||
|
auto array = arrays[e];
|
||||||
|
ASSERT_EQ(e, array->e<int>(0));
|
||||||
|
|
||||||
|
delete array;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void initAffinity(int tid, std::vector<int> &aff) {
|
||||||
|
auto affinity = AffinityManager::currentDeviceId();
|
||||||
|
aff[tid] = affinity;
|
||||||
|
nd4j_printf("Thread [%i] affined with device [%i]\n", tid, affinity);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LaunchContextCudaTests, basic_test_3) {
|
||||||
|
auto totalThreads = AffinityManager::numberOfDevices() * 4;
|
||||||
|
nd4j_printf("Total threads: %i\n", totalThreads);
|
||||||
|
std::vector<int> affinities(totalThreads);
|
||||||
|
|
||||||
|
for (int e = 0; e < totalThreads; e++) {
|
||||||
|
std::thread thread(initAffinity, e, std::ref(affinities));
|
||||||
|
|
||||||
|
thread.join();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> hits(AffinityManager::numberOfDevices());
|
||||||
|
std::fill(hits.begin(), hits.end(), 0);
|
||||||
|
|
||||||
|
// we need to make sure all threads were attached to "valid" devices
|
||||||
|
for (int e = 0; e < totalThreads; e++) {
|
||||||
|
auto aff = affinities[e];
|
||||||
|
ASSERT_TRUE(aff >= 0 && aff < AffinityManager::numberOfDevices());
|
||||||
|
|
||||||
|
hits[aff]++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// now we check if all devices got some threads
|
||||||
|
for (int e = 0; e < AffinityManager::numberOfDevices(); e++) {
|
||||||
|
ASSERT_GT(hits[e], 0);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,114 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
|
#include <loops/reduce3.h>
|
||||||
|
#include <ops/declarable/LegacyTransformOp.h>
|
||||||
|
#include <ops/declarable/LegacyPairwiseTransformOp.h>
|
||||||
|
#include <ops/declarable/LegacyScalarOp.h>
|
||||||
|
#include <ops/declarable/LegacyReduceSameOp.h>
|
||||||
|
#include <ops/declarable/LegacyReduceFloatOp.h>
|
||||||
|
#include <ops/declarable/LegacyIndexReduceOp.h>
|
||||||
|
#include <ops/declarable/LegacyBroadcastOp.h>
|
||||||
|
#include <helpers/TAD.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::ops;
|
||||||
|
|
||||||
|
class LegacyOpsCudaTests : public testing::Test {
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsCudaTests, test_sortTad_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {3, 5}, {1.f, 3.f, 0.f, 2.f, 4.f,
|
||||||
|
6.f, 5.f, 9.f, 7.f, 8.f,
|
||||||
|
10.f, 11.f, 14.f, 12.f, 13.f});
|
||||||
|
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {3, 5}, {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f});
|
||||||
|
|
||||||
|
int axis = 1;
|
||||||
|
auto packX = ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), axis);
|
||||||
|
|
||||||
|
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
||||||
|
|
||||||
|
x.syncToDevice();
|
||||||
|
sortTad(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), &axis, 1, packX.platformShapeInfo(), packX.platformOffsets(), false);
|
||||||
|
x.tickWriteDevice();
|
||||||
|
|
||||||
|
ASSERT_EQ(e, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsCudaTests, test_sort_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {4}, {4.f, 2.f, 1.f, 3.f});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {4}, {1.f, 2.f, 3.f, 4.f});
|
||||||
|
|
||||||
|
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&x}, {&x});
|
||||||
|
::sort(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), false);
|
||||||
|
NDArray::registerSpecialUse({&x});
|
||||||
|
|
||||||
|
ASSERT_EQ(e, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsCudaTests, test_sort_2) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {4}, {4.f, 2.f, 1.f, 3.f});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {4}, {4.f, 3.f, 2.f, 1.f});
|
||||||
|
|
||||||
|
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&x}, {&x});
|
||||||
|
::sort(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), true);
|
||||||
|
NDArray::registerSpecialUse({&x});
|
||||||
|
|
||||||
|
ASSERT_EQ(e, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsCudaTests, test_sort_3) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {4}, {0.5, 0.4, 0.1, 0.2});
|
||||||
|
auto e = NDArrayFactory::create<double>('c', {4}, {0.1, 0.2, 0.4, 0.5});
|
||||||
|
|
||||||
|
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&x}, {&x});
|
||||||
|
::sort(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), false);
|
||||||
|
NDArray::registerSpecialUse({&x});
|
||||||
|
|
||||||
|
ASSERT_EQ(e, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsCudaTests, test_sort_4) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {4}, {7, 4, 9, 2});
|
||||||
|
auto e = NDArrayFactory::create<double>('c', {4}, {2, 4, 7, 9});
|
||||||
|
|
||||||
|
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&x}, {&x});
|
||||||
|
::sort(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), false);
|
||||||
|
NDArray::registerSpecialUse({&x});
|
||||||
|
|
||||||
|
ASSERT_EQ(e, x);
|
||||||
|
}
|
|
@ -0,0 +1,770 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver119 on 16.10.2017.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
|
#include <loops/reduce3.h>
|
||||||
|
#include <ops/declarable/LegacyTransformOp.h>
|
||||||
|
#include <ops/declarable/LegacyPairwiseTransformOp.h>
|
||||||
|
#include <ops/declarable/LegacyScalarOp.h>
|
||||||
|
#include <ops/declarable/LegacyReduceSameOp.h>
|
||||||
|
#include <ops/declarable/LegacyReduceFloatOp.h>
|
||||||
|
#include <ops/declarable/LegacyIndexReduceOp.h>
|
||||||
|
#include <ops/declarable/LegacyBroadcastOp.h>
|
||||||
|
#include <helpers/TAD.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::ops;
|
||||||
|
|
||||||
|
class LegacyOpsTests : public testing::Test {
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, TransformTests_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
x.assign(1.0);
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {5,5});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
exp.assign(-1.0);
|
||||||
|
|
||||||
|
sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg
|
||||||
|
auto status = op.execute({&x}, {&z}, {}, {}, {});
|
||||||
|
ASSERT_EQ(status, ND4J_STATUS_OK);
|
||||||
|
//z.printIndexedBuffer("Output NEG");
|
||||||
|
ASSERT_TRUE(z.equalsTo(&exp));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, TransformTests_2) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
x.assign(1.0);
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
exp.assign(-1.0);
|
||||||
|
|
||||||
|
sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg
|
||||||
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(1, result.size());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, Reciprocal_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
x.assign(2.0f);
|
||||||
|
|
||||||
|
auto ethalon = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
ethalon.assign(0.5f);
|
||||||
|
|
||||||
|
sd::ops::LegacyTransformSameOp op(transform::Reciprocal); // Reciprocal
|
||||||
|
Nd4jStatus status = op.execute({&x}, {&x}, {}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
ASSERT_TRUE(ethalon.equalsTo(&x));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, PWT_Tests_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
x.assign(2.0);
|
||||||
|
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
y.assign(3.0);
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
exp.assign(6.0);
|
||||||
|
|
||||||
|
sd::ops::LegacyPairwiseTransformOp op(pairwise::Multiply); // Multiply
|
||||||
|
Nd4jStatus status = op.execute({&x, &y}, {&x}, {}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.equalsTo(&x));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, PWT_Tests_2) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
x.assign(2.0);
|
||||||
|
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
y.assign(3.0);
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
exp.assign(6.0);
|
||||||
|
|
||||||
|
sd::ops::LegacyPairwiseTransformOp op(pairwise::Multiply); // Multiply
|
||||||
|
auto result = op.evaluate({&x, &y}, {}, {});
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
//z->printBuffer("Z");
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, Scalar_Test_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
x.assign(2.0);
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
exp.assign(7.0);
|
||||||
|
|
||||||
|
sd::ops::LegacyScalarOp op(scalar::Add);
|
||||||
|
op.execute({&x}, {&x}, {5.0}, {}, {}); //
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.equalsTo(&x));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, Scalar_Test_2) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
x.assign(2.0);
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
exp.assign(7.0);
|
||||||
|
|
||||||
|
auto y = NDArrayFactory::create<float>(5.0f);
|
||||||
|
|
||||||
|
sd::ops::LegacyScalarOp op(scalar::Add, y);
|
||||||
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, ReduceTests_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
x.assign(1.0);
|
||||||
|
int opNum = reduce::Sum;
|
||||||
|
sd::ops::LegacyReduceSameOp op(opNum);
|
||||||
|
|
||||||
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(1, result.size());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
// z->printBuffer("ReduceTest1");
|
||||||
|
ASSERT_TRUE(z->isScalar());
|
||||||
|
ASSERT_NEAR(x.sumNumber().e<float>(0), z->e<float>(0), 1e-5f);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, ReduceTests_2) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
x.assign(1.0);
|
||||||
|
|
||||||
|
sd::ops::LegacyReduceSameOp op(reduce::Sum);
|
||||||
|
auto axis = NDArrayFactory::create<Nd4jLong>('c', {1}, {1});
|
||||||
|
auto result = op.evaluate({&x, &axis}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(1, result.size());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
auto exp = x.reduceAlongDimension(reduce::Sum, {1});
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, ReduceTests_3) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {3, 5});
|
||||||
|
x.linspace(1);
|
||||||
|
auto indices = NDArrayFactory::create<int>('c', {1,1}, {1});
|
||||||
|
|
||||||
|
|
||||||
|
sd::ops::LegacyReduceSameOp op(reduce::Sum);
|
||||||
|
auto result = op.evaluate({&x, &indices}, {}, {});
|
||||||
|
auto z = result.at(0);
|
||||||
|
auto exp = x.reduceAlongDimension(reduce::Sum,{1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, ReduceTests_4) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {2, 3, 5});
|
||||||
|
x.linspace(1);
|
||||||
|
auto indices = NDArrayFactory::create<int>('c', {1, 1}, {1});
|
||||||
|
|
||||||
|
|
||||||
|
sd::ops::LegacyReduceSameOp op(reduce::Sum);
|
||||||
|
auto result = op.evaluate({&x, &indices}, {}, {}, {true});
|
||||||
|
auto z = result.at(0);
|
||||||
|
auto exp = x.reduceAlongDimension(reduce::Sum, {1}, true);
|
||||||
|
// indices.printShapeInfo("Indices shape");
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
// z->printIndexedBuffer("Output reduce 4");
|
||||||
|
// exp.printIndexedBuffer("Expected reduce 4");
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, ReduceTests_5) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
x.assign(1.0);
|
||||||
|
int opNum = reduce::Mean;
|
||||||
|
sd::ops::LegacyReduceFloatOp op(opNum);
|
||||||
|
|
||||||
|
auto result = op.evaluate({&x});
|
||||||
|
|
||||||
|
ASSERT_EQ(1, result.size());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
// z->printBuffer("ReduceTest1");
|
||||||
|
ASSERT_TRUE(z->isScalar());
|
||||||
|
ASSERT_NEAR(x.meanNumber().e<float>(0), z->e<float>(0), 1e-5f);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, ReduceTests_6) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
x.assign(1.0);
|
||||||
|
auto axis = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
sd::ops::LegacyReduceFloatOp op(reduce::Mean);
|
||||||
|
|
||||||
|
auto result = op.evaluate({&x, &axis}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(1, result.size());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
auto exp = x.reduceAlongDimension(reduce::Mean, {1});
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, ReduceTests_7) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {3, 5});
|
||||||
|
x.linspace(1);
|
||||||
|
auto indices = NDArrayFactory::create<int>('c', {1,1}, {1});
|
||||||
|
|
||||||
|
|
||||||
|
sd::ops::LegacyReduceFloatOp op(reduce::Mean);
|
||||||
|
auto result = op.evaluate({&x, &indices}, {}, {});
|
||||||
|
auto z = result.at(0);
|
||||||
|
auto exp = x.reduceAlongDimension(reduce::Mean,{1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, ReduceTests_8) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {2, 3, 5});
|
||||||
|
x.linspace(1);
|
||||||
|
auto indices = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
|
||||||
|
|
||||||
|
sd::ops::LegacyReduceFloatOp op(reduce::Mean);
|
||||||
|
auto result = op.evaluate({&x, &indices}, {}, {}, {true});
|
||||||
|
auto z = result.at(0);
|
||||||
|
auto exp = x.reduceAlongDimension(reduce::Mean, {1}, true);
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
// z->printIndexedBuffer("Reduce8 output");
|
||||||
|
// z->printShapeInfo("Reduce8 shape");
|
||||||
|
// exp.printShapeInfo("Reduce8 expected shape");
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, IndexReduceTests_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
x.linspace(1);
|
||||||
|
|
||||||
|
sd::ops::LegacyIndexReduceOp op(indexreduce::IndexMax);
|
||||||
|
|
||||||
|
auto result = op.evaluate({&x}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(1, result.size());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(z->isScalar());
|
||||||
|
ASSERT_EQ(24, z->e<int>(0));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, IndexReduceTests_2) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
auto indices = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
x.linspace(1);
|
||||||
|
auto exp = NDArrayFactory::create<Nd4jLong>({4,4,4,4,4});
|
||||||
|
sd::ops::LegacyIndexReduceOp op(indexreduce::IndexMax);
|
||||||
|
|
||||||
|
auto result = op.evaluate({&x, &indices}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(1, result.size());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
// z->printIndexedBuffer("Hello indexreduce2");
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
//ASSERT_EQ(4, z->e<int>(0));
|
||||||
|
//ASSERT_EQ(4, z->e<int>(1));
|
||||||
|
//ASSERT_EQ(4, z->e<int>(2));
|
||||||
|
//ASSERT_EQ(4, z->e<int>(3));
|
||||||
|
//ASSERT_EQ(4, z->e<int>(4));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, BroadcastingTests_1) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {5, 5});
|
||||||
|
x.assign(0.0f);
|
||||||
|
|
||||||
|
auto row = NDArrayFactory::create<double>('c', {1, 5});
|
||||||
|
row.linspace(1);
|
||||||
|
auto axis = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
sd::ops::LegacyBroadcastOp op(broadcast::Add);
|
||||||
|
Nd4jStatus status = op.execute({&x, &row, &axis}, {&x}, {}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
|
||||||
|
auto list = x.allTensorsAlongDimension({1});
|
||||||
|
// x.printIndexedBuffer("Output broadcast");
|
||||||
|
// list->at(0)->printIndexedBuffer("Column 0:");
|
||||||
|
for (int e = 0; e < list.size(); e++)
|
||||||
|
ASSERT_TRUE(row.equalsTo(list.at(e)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, BroadcastingTests_2) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {5}, {1, 1, 1, 1, 1});
|
||||||
|
auto y = NDArrayFactory::create<double>('c', {10, 5});
|
||||||
|
auto e = NDArrayFactory::create<double>('c', {10, 5});
|
||||||
|
y.assign(3.0);
|
||||||
|
e.assign(4.0);
|
||||||
|
|
||||||
|
int axis = 1;
|
||||||
|
|
||||||
|
// shape::printShapeInfoLinear("tad shape", tad.tadOnlyShapeInfo);
|
||||||
|
auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), {axis});
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&y}, {&x});
|
||||||
|
|
||||||
|
NativeOpExecutioner::execInverseBroadcast(LaunchContext::defaultContext(), broadcast::Add, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), &axis, 1, packY.platformShapeInfo(), packY.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&y}, {&x});
|
||||||
|
|
||||||
|
ASSERT_EQ(e, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, PowDerivative_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
x.assign(3.f);
|
||||||
|
exp.assign(6.f);
|
||||||
|
|
||||||
|
float p = 2.0f;
|
||||||
|
|
||||||
|
x.applyScalar(scalar::PowDerivative, p, x);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.equalsTo(&x));
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifndef __CUDABLAS__
|
||||||
|
TEST_F(LegacyOpsTests, reduce3_1) {
|
||||||
|
|
||||||
|
Nd4jLong yShape[2] = {4,4};
|
||||||
|
Nd4jLong xShape[1] = {4};
|
||||||
|
float y[16] ={1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16};
|
||||||
|
float x[4] = {1,2,3,4};
|
||||||
|
int dimension[1] = {1};
|
||||||
|
int dimensionLength = 1;
|
||||||
|
int opNum = 1;
|
||||||
|
float extraVals[1] = {0};
|
||||||
|
float result[4] = {0.0,0.0,0.0,0.0};
|
||||||
|
|
||||||
|
std::vector<int> dim = {1};
|
||||||
|
|
||||||
|
auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, yShape);
|
||||||
|
auto xShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 1, xShape);
|
||||||
|
|
||||||
|
//int *tadShapeBuffer = shape::computeResultShape(shapeBuffer,dimension,dimensionLength);
|
||||||
|
auto tadShapeBuffer = sd::ShapeUtils::evalReduceShapeInfo('c', dim, shapeBuffer, false, true, nullptr);
|
||||||
|
functions::reduce3::Reduce3<float, float>::exec(opNum, x, xShapeBuffer, extraVals, y, shapeBuffer, result, tadShapeBuffer, dimension, dimensionLength, 0, 4);
|
||||||
|
|
||||||
|
float distancesAssertion[4] = {0.0,8.0,16.0,24.0};
|
||||||
|
for(int i = 0; i < 4; i++)
|
||||||
|
ASSERT_NEAR(distancesAssertion[i],result[i], 1e-5);
|
||||||
|
|
||||||
|
delete[] shapeBuffer;
|
||||||
|
delete[] xShapeBuffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, Reduce3_2) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5, 5});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {5});
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {5});
|
||||||
|
|
||||||
|
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
dim.syncToHost();
|
||||||
|
|
||||||
|
sd::LaunchContext* context = sd::LaunchContext::defaultContext();
|
||||||
|
|
||||||
|
Nd4jPointer* extraPointers = nullptr;
|
||||||
|
#ifdef __CUDABLAS__
|
||||||
|
extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), {1});
|
||||||
|
auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), {1});
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer zBuf(z.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimBuf(dim.dataBuffer());
|
||||||
|
|
||||||
|
execReduce3Tad(extraPointers, reduce3::CosineSimilarity,
|
||||||
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
|
nullptr, &yBuf, y.shapeInfo(), y.specialShapeInfo(),
|
||||||
|
&zBuf, z.shapeInfo(), z.specialShapeInfo(),
|
||||||
|
&dimBuf, dim.shapeInfo(), dim.specialShapeInfo(),
|
||||||
|
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&z}, {&x, &y, &dim});
|
||||||
|
|
||||||
|
delete []extraPointers;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, Reduce3_3) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {3, 5}, {-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, -0.77555125951,
|
||||||
|
-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673,
|
||||||
|
0.62955373525, -0.31357592344, 1.03362500667, -0.59279078245, 1.1914824247});
|
||||||
|
|
||||||
|
auto y = NDArrayFactory::create<double>('c', {5}, {-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673});
|
||||||
|
auto e = NDArrayFactory::create<double>('c', {3}, {0.577452, 0.0, 1.80182});
|
||||||
|
auto z = NDArrayFactory::create<double>('c', {3});
|
||||||
|
|
||||||
|
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
dim.syncToHost();
|
||||||
|
|
||||||
|
sd::LaunchContext* context = sd::LaunchContext::defaultContext();
|
||||||
|
|
||||||
|
Nd4jPointer* extraPointers = nullptr;
|
||||||
|
#ifdef __CUDABLAS__
|
||||||
|
extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), {1});
|
||||||
|
auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), {1});
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer zBuf(z.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimBuf(dim.dataBuffer());
|
||||||
|
|
||||||
|
execReduce3Tad(extraPointers, reduce3::CosineDistance,
|
||||||
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
|
nullptr,
|
||||||
|
&yBuf, y.shapeInfo(), y.specialShapeInfo(),
|
||||||
|
&zBuf, z.shapeInfo(), z.specialShapeInfo(),
|
||||||
|
&dimBuf, dim.shapeInfo(), dim.specialShapeInfo(),
|
||||||
|
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
NDArray::registerSpecialUse({&z}, {&x, &y, &dim});
|
||||||
|
delete []extraPointers;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, Reduce3_4) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {3, 5}, {-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, -0.77555125951,
|
||||||
|
-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673,
|
||||||
|
0.62955373525, -0.31357592344, 1.03362500667, -0.59279078245, 1.1914824247});
|
||||||
|
|
||||||
|
auto y = NDArrayFactory::create<double>('c', {1, 5}, {-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673});
|
||||||
|
auto e = NDArrayFactory::create<double>('c', {1, 3}, {0.577452, 0.0, 1.80182});
|
||||||
|
auto z = NDArrayFactory::create<double>('c', {1, 3});
|
||||||
|
|
||||||
|
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
dim.syncToHost();
|
||||||
|
|
||||||
|
sd::LaunchContext* context = sd::LaunchContext::defaultContext();
|
||||||
|
|
||||||
|
Nd4jPointer* extraPointers = nullptr;
|
||||||
|
#ifdef __CUDABLAS__
|
||||||
|
extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), {1});
|
||||||
|
auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), {1});
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer zBuf(z.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimBuf(dim.dataBuffer());
|
||||||
|
|
||||||
|
execReduce3Tad(extraPointers, reduce3::CosineDistance,
|
||||||
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
|
nullptr,
|
||||||
|
&yBuf, y.shapeInfo(), y.specialShapeInfo(),
|
||||||
|
&zBuf, z.shapeInfo(), z.specialShapeInfo(),
|
||||||
|
&dimBuf, dim.shapeInfo(), dim.specialShapeInfo(),
|
||||||
|
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
||||||
|
|
||||||
|
// z.printIndexedBuffer("z");
|
||||||
|
NDArray::registerSpecialUse({&z}, {&x, &y, &dim});
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
delete []extraPointers;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, Reduce3_5) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {3, 5}, {-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, -0.77555125951,
|
||||||
|
-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673,
|
||||||
|
0.62955373525, -0.31357592344, 1.03362500667, -0.59279078245, 1.1914824247});
|
||||||
|
|
||||||
|
auto y = NDArrayFactory::create<double>('c', {1, 5}, {-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673});
|
||||||
|
auto e = NDArrayFactory::create<double>('c', {1, 3}, {0.577452, 0.0, 1.80182});
|
||||||
|
auto z = NDArrayFactory::create<double>('c', {1, 3});
|
||||||
|
|
||||||
|
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
dim.syncToHost();
|
||||||
|
|
||||||
|
sd::LaunchContext* context = sd::LaunchContext::defaultContext();
|
||||||
|
|
||||||
|
Nd4jPointer* extraPointers = nullptr;
|
||||||
|
#ifdef __CUDABLAS__
|
||||||
|
extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), {1});
|
||||||
|
auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), {1});
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
|
||||||
|
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer zBuf(z.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimBuf(dim.dataBuffer());
|
||||||
|
|
||||||
|
execReduce3Tad(extraPointers, reduce3::CosineDistance,
|
||||||
|
&xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
|
nullptr,
|
||||||
|
&yBuf, y.shapeInfo(), y.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(),
|
||||||
|
&dimBuf, dim.shapeInfo(), dim.specialShapeInfo(),
|
||||||
|
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&z}, {&x, &y, &dim});
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
delete []extraPointers;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, test_Reduce3_All_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1000, 100});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {1, 100});
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {1000, 1});
|
||||||
|
auto dim = NDArrayFactory::create<int>('c', {1}, {-1});
|
||||||
|
|
||||||
|
auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), -1);
|
||||||
|
auto tadPackY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), -1);
|
||||||
|
|
||||||
|
sd::LaunchContext* context = sd::LaunchContext::defaultContext();
|
||||||
|
|
||||||
|
Nd4jPointer* extraPointers = nullptr;
|
||||||
|
#ifdef __CUDABLAS__
|
||||||
|
extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&z}, {&x, &y});
|
||||||
|
|
||||||
|
OpaqueDataBuffer xBuf(x.dataBuffer());
|
||||||
|
OpaqueDataBuffer yBuf(y.dataBuffer());
|
||||||
|
OpaqueDataBuffer zBuf(z.dataBuffer());
|
||||||
|
OpaqueDataBuffer dimBuf(dim.dataBuffer());
|
||||||
|
|
||||||
|
execReduce3All(extraPointers, reduce3::EuclideanDistance, &xBuf, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
|
nullptr, &yBuf, y.shapeInfo(), y.specialShapeInfo(),
|
||||||
|
&zBuf, z.shapeInfo(), z.specialShapeInfo(),
|
||||||
|
&dimBuf, dim.shapeInfo(), dim.specialShapeInfo(),
|
||||||
|
tadPackX.platformShapeInfo(), tadPackX.platformOffsets(),
|
||||||
|
tadPackY.platformShapeInfo(), tadPackY.platformOffsets());
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&z}, {&x, &y});
|
||||||
|
|
||||||
|
delete []extraPointers;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, test_inverse_broadcast_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {4}, {2.0f, 2.0f, 2.0f, 2.0f});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {3, 4});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {3, 4});
|
||||||
|
e.assign(2.0f);
|
||||||
|
|
||||||
|
auto tadPackY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), 1);
|
||||||
|
|
||||||
|
y.tickWriteDevice();
|
||||||
|
|
||||||
|
NativeOpExecutioner::execInverseBroadcast(LaunchContext::defaultContext(), broadcast::Add,
|
||||||
|
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||||
|
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
||||||
|
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
||||||
|
nullptr, 0,
|
||||||
|
tadPackY.platformShapeInfo(), tadPackY.platformOffsets(),
|
||||||
|
tadPackY.platformShapeInfo(), tadPackY.platformOffsets());
|
||||||
|
|
||||||
|
ASSERT_EQ(e, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, test_inverse_broadcast_2) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {4}, {2.0f, 2.0f, 2.0f, 2.0f});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {3, 4});
|
||||||
|
auto z = NDArrayFactory::create<bool>('c', {3, 4});
|
||||||
|
auto e = NDArrayFactory::create<bool>('c', {3, 4});
|
||||||
|
e.assign(false);
|
||||||
|
|
||||||
|
auto row = y(1, {0});
|
||||||
|
row.assign(2.0f);
|
||||||
|
|
||||||
|
auto erow = e(1, {0});
|
||||||
|
erow.assign(true);
|
||||||
|
|
||||||
|
auto tadPackY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), 1);
|
||||||
|
|
||||||
|
z.tickWriteDevice();
|
||||||
|
|
||||||
|
NativeOpExecutioner::execInverseBroadcastBool(LaunchContext::defaultContext(), broadcast::BoolOps::EqualTo,
|
||||||
|
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||||
|
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
||||||
|
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
||||||
|
nullptr,
|
||||||
|
nullptr, 0,
|
||||||
|
tadPackY.platformShapeInfo(), tadPackY.platformOffsets(),
|
||||||
|
tadPackY.platformShapeInfo(), tadPackY.platformOffsets());
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, test_legacy_reduce_empty_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {2, 0, 3});
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {2, 3});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {2, 3});
|
||||||
|
|
||||||
|
int dim = 1;
|
||||||
|
|
||||||
|
NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Sum,
|
||||||
|
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||||
|
nullptr,
|
||||||
|
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
||||||
|
&dim, 1);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, test_legacy_reduce_empty_2) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {2, 0, 3});
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {2, 3});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {2, 3});
|
||||||
|
e.assign(std::numeric_limits<float>::infinity());
|
||||||
|
|
||||||
|
int dim = 1;
|
||||||
|
|
||||||
|
NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Min, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, 1);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, test_legacy_reduce_empty_3) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {2, 0, 3});
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {2, 3});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {2, 3});
|
||||||
|
e.assign(-std::numeric_limits<float>::infinity());
|
||||||
|
|
||||||
|
int dim = 1;
|
||||||
|
|
||||||
|
NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Max, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, 1);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, test_legacy_reduce_empty_4) {
|
||||||
|
if (!Environment::getInstance().isCPU())
|
||||||
|
return;
|
||||||
|
int a = 0;
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 0, 2});
|
||||||
|
auto d = NDArrayFactory::create<int>('c', {1}, {a});
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {0, 2});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {0, 2});
|
||||||
|
|
||||||
|
InteropDataBuffer xdb(x.dataBuffer());
|
||||||
|
InteropDataBuffer ddb(d.dataBuffer());
|
||||||
|
InteropDataBuffer zdb(z.dataBuffer());
|
||||||
|
|
||||||
|
|
||||||
|
::execReduceSame2(nullptr, reduce::SameOps::Sum,
|
||||||
|
&xdb, x.shapeInfo(), x.specialShapeInfo(),
|
||||||
|
nullptr,
|
||||||
|
&zdb, z.shapeInfo(), z.specialShapeInfo(),
|
||||||
|
&ddb, d.shapeInfo(), d.specialShapeInfo());
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, test_legacy_transform_float_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 0, 4});
|
||||||
|
|
||||||
|
NativeOpExecutioner::execTransformFloat(LaunchContext::defaultContext(), transform::FloatOps::RSqrt, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, nullptr);
|
||||||
|
}
|
|
@ -0,0 +1,663 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <graph/GraphExecutioner.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::ops;
|
||||||
|
|
||||||
|
class ListOperationsTests : public testing::Test {
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(ListOperationsTests, BasicTest_Write_1) {
|
||||||
|
NDArrayList list(5);
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {128});
|
||||||
|
x.linspace(1);
|
||||||
|
|
||||||
|
sd::ops::write_list op;
|
||||||
|
|
||||||
|
auto result = op.execute(&list, {&x}, {}, {1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
ASSERT_EQ(1, list.elements());
|
||||||
|
|
||||||
|
auto result2 = op.execute(&list, {&x}, {}, {2});
|
||||||
|
|
||||||
|
ASSERT_EQ(2, list.elements());
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ListOperationsTests, BasicTest_Stack_1) {
|
||||||
|
NDArrayList list(10);
|
||||||
|
auto exp = NDArrayFactory::create<double>('c', {10, 100});
|
||||||
|
auto tads = exp.allTensorsAlongDimension({1});
|
||||||
|
for (int e = 0; e < 10; e++) {
|
||||||
|
auto row = NDArrayFactory::create_<double>('c', {100});
|
||||||
|
row->assign((double) e);
|
||||||
|
list.write(e, row);
|
||||||
|
tads.at(e)->assign(row);
|
||||||
|
}
|
||||||
|
|
||||||
|
sd::ops::stack_list op;
|
||||||
|
|
||||||
|
auto result = op.execute(&list, {}, {}, {1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
// z->printShapeInfo();
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ListOperationsTests, BasicTest_UnStackList_1) {
|
||||||
|
NDArrayList list(0, true);
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {10, 100});
|
||||||
|
auto tads = x.allTensorsAlongDimension({1});
|
||||||
|
for (int e = 0; e < 10; e++) {
|
||||||
|
auto row = NDArrayFactory::create_<double>('c', {100});
|
||||||
|
row->assign((double) e);
|
||||||
|
//list.write(e, row);
|
||||||
|
tads.at(e)->assign(row);
|
||||||
|
delete row;
|
||||||
|
}
|
||||||
|
|
||||||
|
sd::ops::unstack_list op;
|
||||||
|
|
||||||
|
auto result = op.execute(&list, {&x}, {}, {0});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
ASSERT_EQ(list.elements(), 10);
|
||||||
|
|
||||||
|
// auto z = result.at(0);
|
||||||
|
// z->printShapeInfo("The first of");
|
||||||
|
// ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
// ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
for (int e = 0; e < 10; e++) {
|
||||||
|
auto row = list.read(e);
|
||||||
|
ASSERT_TRUE(row->equalsTo(tads.at(e)));
|
||||||
|
//list.write(e, row);
|
||||||
|
delete row;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
//TEST_F(ListOperationsTests, BasicTest_UnStackList_2) {
|
||||||
|
//// NDArrayList list(0, true);
|
||||||
|
// auto x = NDArrayFactory::create<double>('c', {10, 100});
|
||||||
|
// auto tads = x.allTensorsAlongDimension({1});
|
||||||
|
// for (int e = 0; e < 10; e++) {
|
||||||
|
// auto row = NDArrayFactory::create_<double>('c', {100});
|
||||||
|
// row->assign((double) e);
|
||||||
|
// //list.write(e, row);
|
||||||
|
// tads->at(e)->assign(row);
|
||||||
|
// delete row;
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// sd::ops::unstack_list op;
|
||||||
|
//
|
||||||
|
// auto result = op.execute(nullptr, {&x}, {}, {0});
|
||||||
|
//
|
||||||
|
// ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
// ASSERT_EQ(result->size(), 10);
|
||||||
|
//
|
||||||
|
// // auto z = result.at(0);
|
||||||
|
//// z->printShapeInfo("The first of");
|
||||||
|
//// ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
//// ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
// for (int e = 0; e < 10; e++) {
|
||||||
|
// auto row = result.at(e);
|
||||||
|
// ASSERT_TRUE(row->equalsTo(tads->at(e)));
|
||||||
|
// //list.write(e, row);
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
//
|
||||||
|
// delete tads;
|
||||||
|
//}
|
||||||
|
|
||||||
|
TEST_F(ListOperationsTests, BasicTest_Read_1) {
|
||||||
|
NDArrayList list(10);
|
||||||
|
auto exp = NDArrayFactory::create<double>('c', {1, 100});
|
||||||
|
exp.assign(4.0f);
|
||||||
|
|
||||||
|
for (int e = 0; e < 10; e++) {
|
||||||
|
auto row = NDArrayFactory::create_<double>('c', {1, 100});
|
||||||
|
row->assign((double) e);
|
||||||
|
list.write(e, new NDArray(row->dup()));
|
||||||
|
|
||||||
|
delete row;
|
||||||
|
}
|
||||||
|
|
||||||
|
sd::ops::read_list op;
|
||||||
|
|
||||||
|
auto result = op.execute(&list, {}, {}, {4});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ListOperationsTests, BasicTest_Pick_1) {
|
||||||
|
NDArrayList list(10);
|
||||||
|
auto exp = NDArrayFactory::create<double>('c', {4, 100});
|
||||||
|
|
||||||
|
for (int e = 0; e < 10; e++) {
|
||||||
|
auto row = NDArrayFactory::create_<double>('c', {100});
|
||||||
|
row->assign((double) e);
|
||||||
|
list.write(e, new NDArray(row->dup()));
|
||||||
|
|
||||||
|
delete row;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tads = exp.allTensorsAlongDimension({1});
|
||||||
|
tads.at(0)->assign(1.0f);
|
||||||
|
tads.at(1)->assign(1.0f);
|
||||||
|
tads.at(2)->assign(3.0f);
|
||||||
|
tads.at(3)->assign(3.0f);
|
||||||
|
|
||||||
|
|
||||||
|
sd::ops::pick_list op;
|
||||||
|
auto result = op.execute(&list, {}, {}, {1, 1, 3, 3});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ListOperationsTests, BasicTest_Size_1) {
|
||||||
|
NDArrayList list(10);
|
||||||
|
auto exp = NDArrayFactory::create<int>(10);
|
||||||
|
for (int e = 0; e < 10; e++) {
|
||||||
|
auto row = NDArrayFactory::create_<double>('c', {100});
|
||||||
|
row->assign((double) e);
|
||||||
|
list.write(e, new NDArray(row->dup()));
|
||||||
|
|
||||||
|
delete row;
|
||||||
|
}
|
||||||
|
|
||||||
|
sd::ops::size_list op;
|
||||||
|
|
||||||
|
auto result = op.execute(&list, {}, {}, {1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ListOperationsTests, BasicTest_Create_1) {
|
||||||
|
auto matrix = NDArrayFactory::create<double>('c', {3, 2});
|
||||||
|
matrix.linspace(1);
|
||||||
|
|
||||||
|
sd::ops::create_list op;
|
||||||
|
|
||||||
|
auto result = op.execute(nullptr, {&matrix}, {}, {1, 1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
// we return flow as well
|
||||||
|
ASSERT_EQ(1, result.size());
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ListOperationsTests, BasicTest_Split_1) {
|
||||||
|
NDArrayList list(0, true);
|
||||||
|
|
||||||
|
auto exp0 = NDArrayFactory::create<double>('c', {2, 5});
|
||||||
|
auto exp1 = NDArrayFactory::create<double>('c', {3, 5});
|
||||||
|
auto exp2 = NDArrayFactory::create<double>('c', {5, 5});
|
||||||
|
|
||||||
|
auto matrix = NDArrayFactory::create<double>('c', {10, 5});
|
||||||
|
|
||||||
|
auto lengths = NDArrayFactory::create<int>('c', {3});
|
||||||
|
lengths.p(0, 2);
|
||||||
|
lengths.p(1, 3);
|
||||||
|
lengths.p(2, 5);
|
||||||
|
|
||||||
|
auto tads = matrix.allTensorsAlongDimension({1});
|
||||||
|
|
||||||
|
auto tads0 = exp0.allTensorsAlongDimension({1});
|
||||||
|
auto tads1 = exp1.allTensorsAlongDimension({1});
|
||||||
|
auto tads2 = exp2.allTensorsAlongDimension({1});
|
||||||
|
|
||||||
|
int cnt0 = 0;
|
||||||
|
int cnt1 = 0;
|
||||||
|
int cnt2 = 0;
|
||||||
|
for (int e = 0; e < 10; e++) {
|
||||||
|
auto row = NDArrayFactory::create_<double>('c', {5});
|
||||||
|
row->assign((double) e);
|
||||||
|
tads.at(e)->assign(row);
|
||||||
|
|
||||||
|
if (e < 2)
|
||||||
|
tads0.at(cnt0++)->assign(row);
|
||||||
|
else if (e < 5)
|
||||||
|
tads1.at(cnt1++)->assign(row);
|
||||||
|
else
|
||||||
|
tads2.at(cnt2++)->assign(row);
|
||||||
|
|
||||||
|
delete row;
|
||||||
|
}
|
||||||
|
|
||||||
|
sd::ops::split_list op;
|
||||||
|
auto result = op.execute(&list, {&matrix, &lengths}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
ASSERT_EQ(3, list.height());
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp0.isSameShape(list.readRaw(0)));
|
||||||
|
ASSERT_TRUE(exp0.equalsTo(list.readRaw(0)));
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp1.isSameShape(list.readRaw(1)));
|
||||||
|
ASSERT_TRUE(exp1.equalsTo(list.readRaw(1)));
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp2.isSameShape(list.readRaw(2)));
|
||||||
|
ASSERT_TRUE(exp2.equalsTo(list.readRaw(2)));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ListOperationsTests, BasicTest_Scatter_1) {
|
||||||
|
NDArrayList list(0, true);
|
||||||
|
auto s = NDArrayFactory::create<double>(0.0);
|
||||||
|
|
||||||
|
auto matrix = NDArrayFactory::create<double>('c', {10, 5});
|
||||||
|
auto tads = matrix.allTensorsAlongDimension({1});
|
||||||
|
for (int e = 0; e < 10; e++) {
|
||||||
|
auto row = NDArrayFactory::create_<double>('c', {1, 5});
|
||||||
|
row->assign((double) e);
|
||||||
|
tads.at(e)->assign(row);
|
||||||
|
|
||||||
|
delete row;
|
||||||
|
}
|
||||||
|
auto indices = NDArrayFactory::create<double>('c', {1, 10});
|
||||||
|
for (int e = 0; e < matrix.rows(); e++)
|
||||||
|
indices.p(e, 9 - e);
|
||||||
|
|
||||||
|
sd::ops::scatter_list op;
|
||||||
|
auto result = op.execute(&list, {&indices, &matrix, &s}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
for (int e = 0; e < 10; e++) {
|
||||||
|
auto row = tads.at(9 - e);
|
||||||
|
auto chunk = list.readRaw(e);
|
||||||
|
|
||||||
|
ASSERT_TRUE(chunk->isSameShape(row));
|
||||||
|
|
||||||
|
ASSERT_TRUE(chunk->equalsTo(row));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ListOperationsTests, BasicTest_Clone_1) {
|
||||||
|
auto list = new NDArrayList(0, true);
|
||||||
|
|
||||||
|
VariableSpace variableSpace;
|
||||||
|
auto var = new Variable(nullptr, nullptr, -1, 0);
|
||||||
|
var->setNDArrayList(list);
|
||||||
|
|
||||||
|
variableSpace.putVariable(-1, var);
|
||||||
|
variableSpace.trackList(list);
|
||||||
|
|
||||||
|
Context block(1, &variableSpace);
|
||||||
|
block.pickInput(-1);
|
||||||
|
|
||||||
|
sd::ops::clone_list op;
|
||||||
|
|
||||||
|
ASSERT_TRUE(list == block.variable(0)->getNDArrayList());
|
||||||
|
|
||||||
|
auto result = op.execute(&block);
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result);
|
||||||
|
|
||||||
|
auto resVar = variableSpace.getVariable(1);
|
||||||
|
|
||||||
|
auto resList = resVar->getNDArrayList();
|
||||||
|
|
||||||
|
ASSERT_TRUE( resList != nullptr);
|
||||||
|
|
||||||
|
ASSERT_TRUE(list->equals(*resList));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ListOperationsTests, BasicTest_Gather_1) {
|
||||||
|
NDArrayList list(0, true);
|
||||||
|
for (int e = 0; e < 10; e++) {
|
||||||
|
auto row = NDArrayFactory::create_<double>('c', {3});
|
||||||
|
row->assign((double) e);
|
||||||
|
list.write(e, new NDArray(row->dup()));
|
||||||
|
|
||||||
|
delete row;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<double>('c', {10, 3});
|
||||||
|
auto tads = exp.allTensorsAlongDimension({1});
|
||||||
|
for (int e = 0; e < 10; e++) {
|
||||||
|
auto tad = tads.at(9 - e);
|
||||||
|
tad->assign(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto indices = NDArrayFactory::create<double>('c', {1, 10});
|
||||||
|
indices.linspace(9, -1);
|
||||||
|
|
||||||
|
sd::ops::gather_list op;
|
||||||
|
auto result = op.execute(&list, {&indices}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
ASSERT_EQ(1, result.size());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
|
//exp.printIndexedBuffer("e");
|
||||||
|
//z->printIndexedBuffer("z");
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ListOperationsTests, GraphTests_Sequential_1) {
|
||||||
|
Graph graph;
|
||||||
|
|
||||||
|
auto matrix = NDArrayFactory::create_<float>('c', {3, 3});
|
||||||
|
auto tads = matrix->allTensorsAlongDimension({1});
|
||||||
|
for (int e = 0; e < tads.size(); e++) {
|
||||||
|
tads.at(e)->assign((float) (e+1));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3, 3});
|
||||||
|
auto tadsExp = exp.allTensorsAlongDimension({1});
|
||||||
|
tadsExp.at(0)->assign(0.f);
|
||||||
|
tadsExp.at(1)->assign(-1.f);
|
||||||
|
tadsExp.at(2)->assign(-2.f);
|
||||||
|
|
||||||
|
auto indices = NDArrayFactory::valueOf<int>({3}, 1, 'c');
|
||||||
|
//indices->linspace(0);
|
||||||
|
|
||||||
|
|
||||||
|
auto variableSpace = graph.getVariableSpace();
|
||||||
|
variableSpace->putVariable(-1, matrix);
|
||||||
|
variableSpace->putVariable(-2, indices);
|
||||||
|
|
||||||
|
|
||||||
|
auto nodeA = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1});
|
||||||
|
|
||||||
|
// creating list
|
||||||
|
sd::ops::create_list opB;
|
||||||
|
auto nodeB = new Node(&opB, 2, {1},{},{}, 0.0f, {}, {0, 1});
|
||||||
|
//nodeB->setCustomOp(&opB);
|
||||||
|
|
||||||
|
// filling list with matrix
|
||||||
|
sd::ops::split_list opC;
|
||||||
|
auto nodeC = new Node(&opC, 3, {2, 1, -2});
|
||||||
|
//nodeC->setCustomOp(&opC);
|
||||||
|
|
||||||
|
// reading chunks from List. We're adding op number 3 in inputs, to ensure graph will execute this node after split
|
||||||
|
sd::ops::read_list opD;
|
||||||
|
auto nodeD0 = new Node(&opD, 5, {2, 3}, {},{}, 0.0f, {}, {0});
|
||||||
|
auto nodeD1 = new Node(&opD, 6, {2, 3}, {},{}, 0.0f, {}, {1});
|
||||||
|
auto nodeD2 = new Node(&opD, 7, {2, 3}, {},{}, 0.0f, {}, {2});
|
||||||
|
//nodeD0->setCustomOp(&opD);
|
||||||
|
//nodeD1->setCustomOp(&opD);
|
||||||
|
//nodeD2->setCustomOp(&opD);
|
||||||
|
|
||||||
|
// using OneMinus on each chunk separately
|
||||||
|
auto nodeE0 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 10, {5});
|
||||||
|
auto nodeE1 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 11, {6});
|
||||||
|
auto nodeE2 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 12, {7});
|
||||||
|
|
||||||
|
// writing chunks back to the List
|
||||||
|
sd::ops::write_list opF;
|
||||||
|
auto nodeF0 = new Node(&opF, 15, {2, 10}, {},{}, 0.0f, {}, {0});
|
||||||
|
auto nodeF1 = new Node(&opF, 16, {2, 11}, {},{}, 0.0f, {}, {1});
|
||||||
|
auto nodeF2 = new Node(&opF, 17, {2, 12}, {},{}, 0.0f, {}, {2});
|
||||||
|
|
||||||
|
// nodeF0->setCustomOp(&opF);
|
||||||
|
// nodeF1->setCustomOp(&opF);
|
||||||
|
// nodeF2->setCustomOp(&opF);
|
||||||
|
|
||||||
|
// now we're stacking chunks back to matrix state
|
||||||
|
sd::ops::stack_list opG;
|
||||||
|
auto nodeG = new Node(&opG, 20, {2, 15, 16, 17});
|
||||||
|
//auto nodeG = new Node<float>(OpType_CUSTOM, 0, 20, {2});
|
||||||
|
|
||||||
|
// nodeG->setCustomOp(&opG);
|
||||||
|
|
||||||
|
|
||||||
|
graph.addNode(nodeA);
|
||||||
|
graph.addNode(nodeB);
|
||||||
|
graph.addNode(nodeC);
|
||||||
|
graph.addNode(nodeD0);
|
||||||
|
graph.addNode(nodeD1);
|
||||||
|
graph.addNode(nodeD2);
|
||||||
|
graph.addNode(nodeE0);
|
||||||
|
graph.addNode(nodeE1);
|
||||||
|
graph.addNode(nodeE2);
|
||||||
|
|
||||||
|
graph.addNode(nodeF0);
|
||||||
|
graph.addNode(nodeF1);
|
||||||
|
graph.addNode(nodeF2);
|
||||||
|
|
||||||
|
graph.addNode(nodeG);
|
||||||
|
|
||||||
|
// let's also validate structural integrity
|
||||||
|
graph.buildGraph();
|
||||||
|
|
||||||
|
ASSERT_EQ(0, nodeA->getLayer());
|
||||||
|
ASSERT_EQ(1, nodeB->getLayer());
|
||||||
|
ASSERT_EQ(2, nodeC->getLayer());
|
||||||
|
|
||||||
|
ASSERT_EQ(3, nodeD0->getLayer());
|
||||||
|
ASSERT_EQ(3, nodeD1->getLayer());
|
||||||
|
ASSERT_EQ(3, nodeD2->getLayer());
|
||||||
|
|
||||||
|
ASSERT_EQ(4, nodeE0->getLayer());
|
||||||
|
ASSERT_EQ(4, nodeE1->getLayer());
|
||||||
|
ASSERT_EQ(4, nodeE2->getLayer());
|
||||||
|
|
||||||
|
ASSERT_EQ(5, nodeF0->getLayer());
|
||||||
|
ASSERT_EQ(5, nodeF1->getLayer());
|
||||||
|
ASSERT_EQ(5, nodeF2->getLayer());
|
||||||
|
|
||||||
|
ASSERT_EQ(6, nodeG->getLayer());
|
||||||
|
|
||||||
|
auto result = GraphExecutioner::execute(&graph);
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result);
|
||||||
|
|
||||||
|
ASSERT_TRUE(variableSpace->hasVariable(2));
|
||||||
|
auto list = variableSpace->getVariable(2)->getNDArrayList();
|
||||||
|
|
||||||
|
ASSERT_TRUE(list != nullptr);
|
||||||
|
|
||||||
|
ASSERT_EQ(3, list->height());
|
||||||
|
ASSERT_EQ(3, list->elements());
|
||||||
|
|
||||||
|
|
||||||
|
ASSERT_TRUE(variableSpace->hasVariable(20));
|
||||||
|
|
||||||
|
auto stack = variableSpace->getVariable(20)->getNDArray();
|
||||||
|
|
||||||
|
ASSERT_TRUE(stack != nullptr);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(stack));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(stack));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(ListOperationsTests, GraphTests_Sequential_2) {
|
||||||
|
Graph graph;
|
||||||
|
|
||||||
|
auto scalar = NDArrayFactory::create_<double>(0.0f);
|
||||||
|
auto matrix = NDArrayFactory::create_<double>('c', {3, 3});
|
||||||
|
auto tads = matrix->allTensorsAlongDimension({1});
|
||||||
|
for (int e = 0; e < tads.size(); e++) {
|
||||||
|
tads.at(e)->assign((float) (e+1));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<double>('c', {3, 3});
|
||||||
|
auto tadsExp = exp.allTensorsAlongDimension({1});
|
||||||
|
tadsExp.at(0)->assign(0.f);
|
||||||
|
tadsExp.at(1)->assign(-1.f);
|
||||||
|
tadsExp.at(2)->assign(-2.f);
|
||||||
|
|
||||||
|
//auto indices = NDArray<float>::valueOf({1, 3}, 1.0f, 'c');
|
||||||
|
auto indices = NDArrayFactory::create_<double>('c', {1, 3});
|
||||||
|
indices->linspace(0);
|
||||||
|
|
||||||
|
|
||||||
|
auto variableSpace = graph.getVariableSpace();
|
||||||
|
variableSpace->putVariable(-1, matrix);
|
||||||
|
variableSpace->putVariable(-2, indices);
|
||||||
|
variableSpace->putVariable(-3, scalar);
|
||||||
|
|
||||||
|
|
||||||
|
auto nodeA = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1});
|
||||||
|
|
||||||
|
// creating list
|
||||||
|
sd::ops::create_list opB;
|
||||||
|
auto nodeB = new Node(&opB, 2, {1},{},{}, 0.0f, {}, {0, 1});
|
||||||
|
// nodeB->setCustomOp(&opB);
|
||||||
|
|
||||||
|
// filling list with matrix
|
||||||
|
sd::ops::scatter_list opC;
|
||||||
|
auto nodeC = new Node(&opC, 3, {2, -2, 1, -3});
|
||||||
|
|
||||||
|
//nodeC->setCustomOp(&opC);
|
||||||
|
|
||||||
|
sd::ops::read_list opD;
|
||||||
|
auto nodeD0 = new Node(&opD, 5, {2, 3}, {},{}, 0.0f, {}, {0});
|
||||||
|
auto nodeD1 = new Node(&opD, 6, {2, 3, 15}, {},{}, 0.0f, {}, {1});
|
||||||
|
auto nodeD2 = new Node(&opD, 7, {2, 3, 16}, {},{}, 0.0f, {}, {2});
|
||||||
|
|
||||||
|
// nodeD0->setCustomOp(&opD);
|
||||||
|
// nodeD1->setCustomOp(&opD);
|
||||||
|
// nodeD2->setCustomOp(&opD);
|
||||||
|
|
||||||
|
|
||||||
|
// using OneMinus on each chunk separately
|
||||||
|
auto nodeE0 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 10, {5});
|
||||||
|
auto nodeE1 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 11, {6});
|
||||||
|
auto nodeE2 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 12, {7});
|
||||||
|
|
||||||
|
// writing chunks back to the List
|
||||||
|
sd::ops::write_list opF;
|
||||||
|
auto nodeF0 = new Node(&opF, 15, {2, 10}, {},{}, 0.0f, {}, {0});
|
||||||
|
auto nodeF1 = new Node(&opF, 16, {2, 11}, {},{}, 0.0f, {}, {1});
|
||||||
|
auto nodeF2 = new Node(&opF, 17, {2, 12}, {},{}, 0.0f, {}, {2});
|
||||||
|
|
||||||
|
// nodeF0->setCustomOp(&opF);
|
||||||
|
// nodeF1->setCustomOp(&opF);
|
||||||
|
// nodeF2->setCustomOp(&opF);
|
||||||
|
|
||||||
|
// now we're gathering chunks back to matrix state
|
||||||
|
sd::ops::pick_list opG;
|
||||||
|
auto nodeG = new Node(&opG, 20, {2, -2, 15, 16, 17});
|
||||||
|
//auto nodeG = new Node<float>(OpType_CUSTOM, 0, 20, {2});
|
||||||
|
|
||||||
|
//nodeG->setCustomOp(&opG);
|
||||||
|
|
||||||
|
graph.addNode(nodeA);
|
||||||
|
graph.addNode(nodeB);
|
||||||
|
graph.addNode(nodeC);
|
||||||
|
graph.addNode(nodeD0);
|
||||||
|
graph.addNode(nodeD1);
|
||||||
|
graph.addNode(nodeD2);
|
||||||
|
graph.addNode(nodeE0);
|
||||||
|
graph.addNode(nodeE1);
|
||||||
|
graph.addNode(nodeE2);
|
||||||
|
|
||||||
|
graph.addNode(nodeF0);
|
||||||
|
graph.addNode(nodeF1);
|
||||||
|
graph.addNode(nodeF2);
|
||||||
|
|
||||||
|
graph.addNode(nodeG);
|
||||||
|
|
||||||
|
// let's also validate structural integrity
|
||||||
|
graph.buildGraph();
|
||||||
|
|
||||||
|
ASSERT_EQ(0, nodeA->getLayer());
|
||||||
|
ASSERT_EQ(1, nodeB->getLayer());
|
||||||
|
ASSERT_EQ(2, nodeC->getLayer());
|
||||||
|
|
||||||
|
ASSERT_EQ(3, nodeD0->getLayer());
|
||||||
|
ASSERT_EQ(4, nodeE0->getLayer());
|
||||||
|
ASSERT_EQ(5, nodeF0->getLayer());
|
||||||
|
|
||||||
|
ASSERT_EQ(6, nodeD1->getLayer());
|
||||||
|
ASSERT_EQ(7, nodeE1->getLayer());
|
||||||
|
ASSERT_EQ(8, nodeF1->getLayer());
|
||||||
|
|
||||||
|
ASSERT_EQ(9, nodeD2->getLayer());
|
||||||
|
ASSERT_EQ(10, nodeE2->getLayer());
|
||||||
|
ASSERT_EQ(11, nodeF2->getLayer());
|
||||||
|
|
||||||
|
ASSERT_EQ(12, nodeG->getLayer());
|
||||||
|
|
||||||
|
|
||||||
|
auto result = GraphExecutioner::execute(&graph);
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result);
|
||||||
|
|
||||||
|
ASSERT_TRUE(variableSpace->hasVariable(2));
|
||||||
|
auto list = variableSpace->getVariable(2)->getNDArrayList();
|
||||||
|
|
||||||
|
ASSERT_TRUE(list != nullptr);
|
||||||
|
|
||||||
|
ASSERT_EQ(3, list->height());
|
||||||
|
ASSERT_EQ(3, list->elements());
|
||||||
|
|
||||||
|
ASSERT_TRUE(variableSpace->hasVariable(20));
|
||||||
|
|
||||||
|
auto stack = variableSpace->getVariable(20)->getNDArray();
|
||||||
|
|
||||||
|
ASSERT_TRUE(stack != nullptr);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(stack));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(stack));
|
||||||
|
}
|
|
@ -0,0 +1,225 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Abdelrauf
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <helpers/LoopsCoordsHelper.h>
|
||||||
|
#include <type_traits>
|
||||||
|
using namespace sd;
|
||||||
|
|
||||||
|
class LoopCoordsHelper : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
template<size_t Rank, size_t rankIndex = 0, bool Last_Index_Faster = true>
|
||||||
|
FORCEINLINE
|
||||||
|
typename std::enable_if<(Rank - 1 == rankIndex), bool>::type
|
||||||
|
eq_strides(CoordsState<Rank - 1>& cbs, const Nd4jLong* strides) {
|
||||||
|
return STRIDE(cbs, rankIndex) == strides[rankIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
template<size_t Rank, size_t rankIndex = 0>
|
||||||
|
FORCEINLINE
|
||||||
|
typename std::enable_if<(Rank - 1 != rankIndex), bool>::type
|
||||||
|
eq_strides(CoordsState<Rank - 1>& cbs, const Nd4jLong* strides) {
|
||||||
|
return STRIDE(cbs, rankIndex) == strides[rankIndex] && eq_strides<Rank, rankIndex + 1>(cbs, strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<size_t Rank, size_t rankIndex = 0, bool Last_Index_Faster = true>
|
||||||
|
FORCEINLINE
|
||||||
|
typename std::enable_if<(Rank - 1 == rankIndex), bool>::type
|
||||||
|
eq_zip_strides(ZipCoordsState<Rank - 1>& cbs, const Nd4jLong* strides1, const Nd4jLong* strides2) {
|
||||||
|
return ZIP_STRIDE1(cbs, rankIndex) == strides1[rankIndex] && ZIP_STRIDE2(cbs, rankIndex) == strides2[rankIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
template<size_t Rank, size_t rankIndex = 0>
|
||||||
|
FORCEINLINE
|
||||||
|
typename std::enable_if<(Rank - 1 != rankIndex), bool>::type
|
||||||
|
eq_zip_strides(ZipCoordsState<Rank - 1>& cbs, const Nd4jLong* strides1, const Nd4jLong* strides2) {
|
||||||
|
return ZIP_STRIDE1(cbs, rankIndex) == strides1[rankIndex] && ZIP_STRIDE2(cbs, rankIndex) == strides2[rankIndex]
|
||||||
|
&& eq_zip_strides<Rank, rankIndex + 1>(cbs, strides1, strides2);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(LoopCoordsHelper, Init_Tests) {
|
||||||
|
|
||||||
|
constexpr size_t test_Index = 131;
|
||||||
|
constexpr size_t Rank = 5;
|
||||||
|
|
||||||
|
Nd4jLong shape[Rank] = { 3, 5 ,7, 8, 9};
|
||||||
|
Nd4jLong multiply_st[] = { 2,3,3,5,6,7,9,3 };
|
||||||
|
Nd4jLong strides_c[Rank] ;
|
||||||
|
Nd4jLong strides_f[Rank];
|
||||||
|
|
||||||
|
Nd4jLong coords[Rank];
|
||||||
|
Nd4jLong coords_f[Rank];
|
||||||
|
|
||||||
|
strides_f[0] = multiply_st[0] * shape[0];
|
||||||
|
strides_c[Rank-1] = multiply_st[Rank-1] * shape[Rank-1];
|
||||||
|
|
||||||
|
for (int i = 1; i < Rank; i++) {
|
||||||
|
strides_f[i] = strides_f[i - 1] * multiply_st[i] * shape[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = Rank-2; i >=0; i--) {
|
||||||
|
strides_c[i] = strides_c[i+1] * multiply_st[i] * shape[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
//init our base coords
|
||||||
|
index2coords_C(test_Index, Rank, shape, coords);
|
||||||
|
index2coords_F(test_Index, Rank, shape, coords_f);
|
||||||
|
|
||||||
|
|
||||||
|
size_t offset_calc = offset_from_coords(strides_c, coords, Rank);
|
||||||
|
size_t offset_calc_f = offset_from_coords(strides_f, coords_f, Rank);
|
||||||
|
|
||||||
|
CoordsState<Rank-1> cts;
|
||||||
|
CoordsState<Rank-1> cts_f;
|
||||||
|
|
||||||
|
ZipCoordsState<Rank-1> zcts;
|
||||||
|
ZipCoordsState<Rank-1> zcts_f;
|
||||||
|
|
||||||
|
size_t offset = init_coords<Rank>(cts, test_Index, shape, strides_c);
|
||||||
|
size_t offset_f = init_coords<Rank,0,false>(cts_f, test_Index, shape, strides_f);
|
||||||
|
|
||||||
|
zip_size_t zoffset = init_coords<Rank>(zcts, test_Index, shape, strides_c, strides_c);
|
||||||
|
zip_size_t zoffset_f = init_coords<Rank, 0, false>(zcts_f, test_Index, shape, strides_f, strides_f);
|
||||||
|
|
||||||
|
ASSERT_TRUE(eq_coords<Rank>(cts, coords));
|
||||||
|
ASSERT_TRUE(eq_coords<Rank>(cts_f, coords_f));
|
||||||
|
|
||||||
|
ASSERT_TRUE(eq_zip_coords<Rank>(zcts, coords));
|
||||||
|
ASSERT_TRUE(eq_zip_coords<Rank>(zcts_f, coords_f));
|
||||||
|
|
||||||
|
ASSERT_TRUE(eq_strides<Rank>(cts,strides_c));
|
||||||
|
ASSERT_TRUE(eq_strides<Rank>(cts_f,strides_f));
|
||||||
|
|
||||||
|
ASSERT_TRUE(eq_zip_strides<Rank>(zcts, strides_c, strides_c));
|
||||||
|
ASSERT_TRUE(eq_zip_strides<Rank>(zcts_f, strides_f, strides_f));
|
||||||
|
|
||||||
|
|
||||||
|
ASSERT_EQ(offset , offset_calc);
|
||||||
|
ASSERT_EQ(zoffset.first , offset_calc);
|
||||||
|
ASSERT_EQ(zoffset.second , offset_calc);
|
||||||
|
ASSERT_EQ(offset_f , offset_calc_f);
|
||||||
|
ASSERT_EQ(zoffset_f.first , offset_calc_f);
|
||||||
|
ASSERT_EQ(zoffset_f.second , offset_calc_f);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LoopCoordsHelper, Increment_Use_Tests) {
|
||||||
|
|
||||||
|
|
||||||
|
constexpr size_t Rank = 4;
|
||||||
|
|
||||||
|
Nd4jLong shape[Rank] = { 3, 5 ,7, 8 };
|
||||||
|
Nd4jLong multiply_st[] = { 2,3,3,5,6,7,9,3 };
|
||||||
|
Nd4jLong strides_c[Rank];
|
||||||
|
Nd4jLong strides_f[Rank];
|
||||||
|
|
||||||
|
Nd4jLong coords[Rank] = {};
|
||||||
|
Nd4jLong coords_f[Rank] = {};
|
||||||
|
Nd4jLong coords2[Rank] = {};
|
||||||
|
Nd4jLong coords2_f[Rank] = {};
|
||||||
|
Nd4jLong zcoords2[Rank] = {};
|
||||||
|
Nd4jLong zcoords2_f[Rank] = {};
|
||||||
|
|
||||||
|
strides_f[0] = multiply_st[0] * shape[0];
|
||||||
|
strides_c[Rank - 1] = multiply_st[Rank - 1] * shape[Rank - 1];
|
||||||
|
|
||||||
|
for (int i = 1; i < Rank; i++) {
|
||||||
|
strides_f[i] = strides_f[i - 1] * multiply_st[i] * shape[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = Rank - 2; i >= 0; i--) {
|
||||||
|
strides_c[i] = strides_c[i + 1] * multiply_st[i] * shape[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
int total = 1;
|
||||||
|
for (int i = 0; i < Rank; i++) {
|
||||||
|
total *= shape[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
CoordsState<Rank - 1> cts;
|
||||||
|
CoordsState<Rank - 1> cts_f;
|
||||||
|
|
||||||
|
ZipCoordsState<Rank - 1> zcts;
|
||||||
|
ZipCoordsState<Rank - 1> zcts_f;
|
||||||
|
|
||||||
|
size_t offset = init_coords<Rank>(cts, 0, shape, strides_c);
|
||||||
|
size_t offset_f = init_coords<Rank, 0, false>(cts_f, 0, shape, strides_f);
|
||||||
|
|
||||||
|
zip_size_t zoffset = init_coords<Rank>(zcts, 0, shape, strides_c, strides_c);
|
||||||
|
zip_size_t zoffset_f = init_coords<Rank, 0, false>(zcts_f, 0, shape, strides_f, strides_f);
|
||||||
|
|
||||||
|
size_t offset2 = 0;
|
||||||
|
size_t offset2_f = 0;
|
||||||
|
zip_size_t zoffset2 = {};
|
||||||
|
zip_size_t zoffset2_f = {};
|
||||||
|
|
||||||
|
for (int j = 0; j < total; j++) {
|
||||||
|
|
||||||
|
|
||||||
|
index2coords_C(j, Rank, shape, coords);
|
||||||
|
index2coords_F(j, Rank, shape, coords_f);
|
||||||
|
|
||||||
|
size_t offset_calc = offset_from_coords(strides_c, coords, Rank);
|
||||||
|
size_t offset_calc_f = offset_from_coords(strides_f, coords_f, Rank);
|
||||||
|
|
||||||
|
|
||||||
|
ASSERT_TRUE(eq_coords<Rank>(cts, coords));
|
||||||
|
ASSERT_TRUE(eq_coords<Rank>(cts_f, coords_f));
|
||||||
|
|
||||||
|
ASSERT_TRUE(eq_zip_coords<Rank>(zcts, coords));
|
||||||
|
ASSERT_TRUE(eq_zip_coords<Rank>(zcts_f, coords_f));
|
||||||
|
|
||||||
|
ASSERT_EQ(offset, offset_calc);
|
||||||
|
ASSERT_EQ(zoffset.first, offset_calc);
|
||||||
|
ASSERT_EQ(zoffset.second, offset_calc);
|
||||||
|
ASSERT_EQ(offset_f, offset_calc_f);
|
||||||
|
ASSERT_EQ(zoffset_f.first, offset_calc_f);
|
||||||
|
ASSERT_EQ(zoffset_f.second, offset_calc_f);
|
||||||
|
|
||||||
|
|
||||||
|
ASSERT_EQ(offset2, offset_calc);
|
||||||
|
ASSERT_EQ(zoffset2.first, offset_calc);
|
||||||
|
ASSERT_EQ(zoffset2.second, offset_calc);
|
||||||
|
ASSERT_EQ(offset2_f, offset_calc_f);
|
||||||
|
ASSERT_EQ(zoffset2_f.first, offset_calc_f);
|
||||||
|
ASSERT_EQ(zoffset2_f.second, offset_calc_f);
|
||||||
|
|
||||||
|
offset = inc_coords<Rank>(cts, offset);
|
||||||
|
offset_f = inc_coords<Rank,0,false>(cts_f, offset_f);
|
||||||
|
zoffset = inc_coords<Rank>(zcts, zoffset);
|
||||||
|
zoffset_f = inc_coords<Rank, 0, false>(zcts_f, zoffset_f);
|
||||||
|
|
||||||
|
offset2 = inc_coords(shape,strides_c, coords2, offset2, Rank);
|
||||||
|
offset2_f = inc_coords<false>(shape, strides_f, coords2_f, offset2_f, Rank);
|
||||||
|
zoffset2 = inc_coords(shape, strides_c, strides_c, zcoords2, zoffset2, Rank);
|
||||||
|
zoffset2_f = inc_coords<false>(shape, strides_f, strides_f, zcoords2_f, zoffset2_f, Rank);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver119 on 11.10.2017.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <memory/MemoryReport.h>
|
||||||
|
#include <memory/MemoryUtils.h>
|
||||||
|
#include "testlayers.h"
|
||||||
|
|
||||||
|
using namespace sd::memory;
|
||||||
|
|
||||||
|
class MemoryUtilsTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(MemoryUtilsTests, BasicRetrieve_1) {
|
||||||
|
MemoryReport reportA;
|
||||||
|
MemoryReport reportB;
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
if (1 > 0)
|
||||||
|
return;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
MemoryUtils::retrieveMemoryStatistics(reportA);
|
||||||
|
|
||||||
|
|
||||||
|
ASSERT_NE(reportA, reportB);
|
||||||
|
}
|
|
@ -0,0 +1,111 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifdef HAVE_MKLDNN
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <initializer_list>
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <ops/declarable/platform/mkldnn/mkldnnUtils.h>
|
||||||
|
#include <array/NDArrayFactory.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
|
||||||
|
class MklDnnTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
static void printer(std::initializer_list<sd::ops::platforms::PlatformHelper*> helpers) {
|
||||||
|
|
||||||
|
for (auto v:helpers) {
|
||||||
|
nd4j_printf("Initialized [%s]\n", v->name().c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(MklDnnTests, helpers_includer) {
|
||||||
|
// we need this block, to make sure all helpers are still available within binary, and not optimized out by linker
|
||||||
|
sd::ops::platforms::PLATFORM_conv2d_ENGINE_CPU conv2d;
|
||||||
|
sd::ops::platforms::PLATFORM_conv2d_bp_ENGINE_CPU conv2d_bp;
|
||||||
|
|
||||||
|
sd::ops::platforms::PLATFORM_conv2d_ENGINE_CPU conv3d;
|
||||||
|
sd::ops::platforms::PLATFORM_conv2d_bp_ENGINE_CPU conv3d_bp;
|
||||||
|
|
||||||
|
sd::ops::platforms::PLATFORM_avgpool2d_ENGINE_CPU avgpool2d;
|
||||||
|
sd::ops::platforms::PLATFORM_avgpool2d_bp_ENGINE_CPU avgpool2d_bp;
|
||||||
|
|
||||||
|
sd::ops::platforms::PLATFORM_maxpool2d_ENGINE_CPU maxpool2d;
|
||||||
|
sd::ops::platforms::PLATFORM_maxpool2d_bp_ENGINE_CPU maxpool2d_bp;
|
||||||
|
|
||||||
|
sd::ops::platforms::PLATFORM_avgpool3dnew_ENGINE_CPU avgpool3d;
|
||||||
|
sd::ops::platforms::PLATFORM_avgpool3dnew_bp_ENGINE_CPU avgpool3d_bp;
|
||||||
|
|
||||||
|
sd::ops::platforms::PLATFORM_maxpool3dnew_ENGINE_CPU maxpool3d;
|
||||||
|
sd::ops::platforms::PLATFORM_maxpool3dnew_bp_ENGINE_CPU maxpool3d_bp;
|
||||||
|
|
||||||
|
sd::ops::platforms::PLATFORM_lrn_ENGINE_CPU lrn;
|
||||||
|
|
||||||
|
sd::ops::platforms::PLATFORM_batchnorm_ENGINE_CPU batchnorm;
|
||||||
|
|
||||||
|
sd::ops::platforms::PLATFORM_matmul_ENGINE_CPU matmul;
|
||||||
|
|
||||||
|
sd::ops::platforms::PLATFORM_softmax_ENGINE_CPU softmax;
|
||||||
|
|
||||||
|
sd::ops::platforms::PLATFORM_softmax_bp_ENGINE_CPU softmax_bp;
|
||||||
|
|
||||||
|
sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh;
|
||||||
|
|
||||||
|
sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh_bp;
|
||||||
|
|
||||||
|
sd::ops::platforms::PLATFORM_xw_plus_b_ENGINE_CPU xw_plus_b;
|
||||||
|
|
||||||
|
sd::ops::platforms::PLATFORM_xw_plus_b_bp_ENGINE_CPU xw_plus_b_bp;
|
||||||
|
|
||||||
|
|
||||||
|
printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax, &softmax_bp, &tanh, &tanh_bp, &xw_plus_b, &xw_plus_b_bp });
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MklDnnTests, test_tanh_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>(1.0f);
|
||||||
|
auto z = NDArrayFactory::create<float>(0.0f);
|
||||||
|
|
||||||
|
sd::ops::tanh op;
|
||||||
|
auto status = op.execute({&x}, {&z});
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MklDnnTests, test_tanh_2) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1}, {1.0f});
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {1}, {0.0f});
|
||||||
|
|
||||||
|
sd::ops::tanh op;
|
||||||
|
auto status = op.execute({&x}, {&z});
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,57 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver on 5/13/2018.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <legacy/NativeOps.h>
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::graph;
|
||||||
|
|
||||||
|
class MmapTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(MmapTests, Test_Basic_Mmap_1) {
|
||||||
|
// FIXME: we must adopt this for CUDA as well
|
||||||
|
if (!Environment::getInstance().isCPU())
|
||||||
|
return;
|
||||||
|
|
||||||
|
// just 10GB
|
||||||
|
Nd4jLong size = 100000L;
|
||||||
|
|
||||||
|
std::ofstream ofs("file", std::ios::binary | std::ios::out);
|
||||||
|
ofs.seekp(size + 1024L);
|
||||||
|
ofs.write("", 1);
|
||||||
|
ofs.close();
|
||||||
|
|
||||||
|
auto result = mmapFile(nullptr, "file", size);
|
||||||
|
|
||||||
|
ASSERT_FALSE(result == nullptr);
|
||||||
|
|
||||||
|
munmapFile(nullptr, result, size);
|
||||||
|
|
||||||
|
remove("file");
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,72 @@
|
||||||
|
/*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <array/ArrayOptions.h>
|
||||||
|
#include <execution/AffinityManager.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <array/NDArrayFactory.h>
|
||||||
|
#include <ops/declarable/headers/broadcastable.h>
|
||||||
|
#include <helpers/MmulHelper.h>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
|
||||||
|
class MultiDeviceTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
void createArrays(int limit, std::vector<NDArray*> &arrays) {
|
||||||
|
auto deviceId = AffinityManager::currentDeviceId();
|
||||||
|
auto numDevices = AffinityManager::numberOfDevices();
|
||||||
|
|
||||||
|
for (int e = 0; e < limit; e++) {
|
||||||
|
auto value = deviceId * limit + e;
|
||||||
|
arrays[value] = NDArrayFactory::create_<float>('c', {10});
|
||||||
|
arrays[value]->assign(value);
|
||||||
|
//nd4j_printf("device_%i; value: [%i]; mean: [%f]\n", deviceId, value, arrays[value]->meanNumber().e<float>(0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MultiDeviceTests, test_multi_device_migration_1) {
|
||||||
|
auto deviceId = AffinityManager::currentDeviceId();
|
||||||
|
auto numDevices = AffinityManager::numberOfDevices();
|
||||||
|
auto numArrays = 10;
|
||||||
|
std::vector<NDArray*> arrays(numDevices * numArrays);
|
||||||
|
|
||||||
|
// filling list of arrays on multiple threads
|
||||||
|
for (int e = 0; e < numDevices; e++) {
|
||||||
|
std::thread t1(createArrays, numArrays, std::ref(arrays));
|
||||||
|
|
||||||
|
t1.join();
|
||||||
|
}
|
||||||
|
|
||||||
|
// at this moment all arrays are build, so we can test migration
|
||||||
|
for (int e = 0; e < arrays.size(); e++) {
|
||||||
|
ASSERT_NEAR((float) e, arrays[e]->meanNumber().e<float>(0), 1e-5f);
|
||||||
|
delete arrays[e];
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,208 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <array/NDArrayFactory.h>
|
||||||
|
#include <graph/Context.h>
|
||||||
|
#include <graph/Node.h>
|
||||||
|
#include <graph/Variable.h>
|
||||||
|
#include <graph/VariableSpace.h>
|
||||||
|
#include <execution/LaunchContext.h>
|
||||||
|
#include <ops/specials_cuda.h>
|
||||||
|
#include <helpers/TAD.h>
|
||||||
|
|
||||||
|
#include <cuda.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::graph;
|
||||||
|
|
||||||
|
class NDArrayConstructorsTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(NDArrayConstructorsTests, test_constructor_1) {
|
||||||
|
auto x = NDArrayFactory::empty_<float>();
|
||||||
|
|
||||||
|
ASSERT_TRUE(x->buffer() == nullptr);
|
||||||
|
ASSERT_TRUE(x->specialBuffer() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_FALSE(x->shapeInfo() == nullptr);
|
||||||
|
ASSERT_FALSE(x->specialShapeInfo() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_TRUE(x->isActualOnDeviceSide());
|
||||||
|
ASSERT_TRUE(x->isActualOnHostSide());
|
||||||
|
|
||||||
|
delete x;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NDArrayConstructorsTests, test_constructor_2) {
|
||||||
|
auto x = NDArrayFactory::vector<float>(5, 1.0f);
|
||||||
|
|
||||||
|
|
||||||
|
ASSERT_FALSE(x->buffer() == nullptr);
|
||||||
|
ASSERT_FALSE(x->specialBuffer() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_FALSE(x->shapeInfo() == nullptr);
|
||||||
|
ASSERT_FALSE(x->specialShapeInfo() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_TRUE(x->isActualOnDeviceSide());
|
||||||
|
ASSERT_FALSE(x->isActualOnHostSide());
|
||||||
|
|
||||||
|
delete x;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NDArrayConstructorsTests, test_constructor_3) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c',{5, 5});
|
||||||
|
|
||||||
|
ASSERT_TRUE(x.buffer() == nullptr);
|
||||||
|
ASSERT_FALSE(x.specialBuffer() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_FALSE(x.shapeInfo() == nullptr);
|
||||||
|
ASSERT_FALSE(x.specialShapeInfo() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_TRUE(x.isActualOnDeviceSide());
|
||||||
|
ASSERT_FALSE(x.isActualOnHostSide());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NDArrayConstructorsTests, test_constructor_4) {
|
||||||
|
auto x = NDArrayFactory::create(sd::DataType::FLOAT32, 1.0f);
|
||||||
|
|
||||||
|
ASSERT_FALSE(x.buffer() == nullptr);
|
||||||
|
ASSERT_FALSE(x.specialBuffer() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_FALSE(x.shapeInfo() == nullptr);
|
||||||
|
ASSERT_FALSE(x.specialShapeInfo() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_TRUE(x.isActualOnDeviceSide());
|
||||||
|
ASSERT_TRUE(x.isActualOnHostSide());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NDArrayConstructorsTests, test_constructor_5) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c',{2, 2}, {1, 2, 3, 4});
|
||||||
|
|
||||||
|
ASSERT_TRUE(x.buffer() == nullptr);
|
||||||
|
ASSERT_FALSE(x.specialBuffer() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_FALSE(x.shapeInfo() == nullptr);
|
||||||
|
ASSERT_FALSE(x.specialShapeInfo() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_TRUE(x.isActualOnDeviceSide());
|
||||||
|
ASSERT_FALSE(x.isActualOnHostSide());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NDArrayConstructorsTests, test_constructor_6) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {2, 2}, {1, 2, 3, 4});
|
||||||
|
NDArray y(x);
|
||||||
|
|
||||||
|
ASSERT_TRUE(y.buffer() == nullptr);
|
||||||
|
ASSERT_FALSE(y.specialBuffer() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_FALSE(y.shapeInfo() == nullptr);
|
||||||
|
ASSERT_FALSE(y.specialShapeInfo() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_TRUE(y.isActualOnDeviceSide());
|
||||||
|
ASSERT_FALSE(y.isActualOnHostSide());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NDArrayConstructorsTests, test_constructor_7) {
|
||||||
|
auto x = NDArrayFactory::create<float>(1.0f);
|
||||||
|
|
||||||
|
ASSERT_FALSE(x.buffer() == nullptr);
|
||||||
|
ASSERT_FALSE(x.specialBuffer() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_FALSE(x.shapeInfo() == nullptr);
|
||||||
|
ASSERT_FALSE(x.specialShapeInfo() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_TRUE(x.isActualOnDeviceSide());
|
||||||
|
ASSERT_TRUE(x.isActualOnHostSide());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NDArrayConstructorsTests, test_constructor_8) {
|
||||||
|
auto x = NDArrayFactory::create_<double>('c',{2, 2}, {1, 2, 3, 4});
|
||||||
|
|
||||||
|
ASSERT_TRUE(x->buffer() == nullptr);
|
||||||
|
ASSERT_FALSE(x->specialBuffer() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_FALSE(x->shapeInfo() == nullptr);
|
||||||
|
ASSERT_FALSE(x->specialShapeInfo() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_TRUE(x->isActualOnDeviceSide());
|
||||||
|
ASSERT_FALSE(x->isActualOnHostSide());
|
||||||
|
|
||||||
|
delete x;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NDArrayConstructorsTests, test_constructor_9) {
|
||||||
|
auto x = NDArrayFactory::create_<double>('c',{2, 2});
|
||||||
|
|
||||||
|
ASSERT_TRUE(x->buffer() == nullptr);
|
||||||
|
ASSERT_FALSE(x->specialBuffer() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_FALSE(x->shapeInfo() == nullptr);
|
||||||
|
ASSERT_FALSE(x->specialShapeInfo() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_TRUE(x->isActualOnDeviceSide());
|
||||||
|
ASSERT_FALSE(x->isActualOnHostSide());
|
||||||
|
|
||||||
|
delete x;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NDArrayConstructorsTests, test_linspace_1) {
|
||||||
|
auto x = NDArrayFactory::linspace<float>(1.0f, 10.0f, 20);
|
||||||
|
|
||||||
|
ASSERT_FALSE(x->buffer() == nullptr);
|
||||||
|
ASSERT_FALSE(x->specialBuffer() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_FALSE(x->shapeInfo() == nullptr);
|
||||||
|
ASSERT_FALSE(x->specialShapeInfo() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_TRUE(x->isActualOnDeviceSide());
|
||||||
|
ASSERT_TRUE(x->isActualOnHostSide());
|
||||||
|
|
||||||
|
delete x;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NDArrayConstructorsTests, test_constructor_10) {
|
||||||
|
|
||||||
|
NDArray scalar1(sd::DataType::DOUBLE); // scalar1 = 0
|
||||||
|
NDArray scalar2('c', {}, std::vector<double>{0});
|
||||||
|
|
||||||
|
ASSERT_TRUE(scalar1.isActualOnDeviceSide());
|
||||||
|
ASSERT_TRUE(!scalar1.isActualOnHostSide());
|
||||||
|
ASSERT_TRUE(scalar2.isActualOnDeviceSide());
|
||||||
|
ASSERT_TRUE(scalar2.isActualOnHostSide());
|
||||||
|
|
||||||
|
ASSERT_TRUE(scalar2.equalsTo(scalar1));
|
||||||
|
|
||||||
|
ASSERT_TRUE(scalar1.isActualOnDeviceSide());
|
||||||
|
ASSERT_TRUE(!scalar1.isActualOnHostSide());
|
||||||
|
ASSERT_TRUE(scalar2.isActualOnDeviceSide());
|
||||||
|
ASSERT_TRUE(scalar2.isActualOnHostSide());
|
||||||
|
|
||||||
|
ASSERT_TRUE(scalar1.buffer() == nullptr);
|
||||||
|
ASSERT_TRUE(scalar1.specialBuffer() != nullptr);
|
||||||
|
ASSERT_TRUE(scalar1.shapeInfo() != nullptr);
|
||||||
|
ASSERT_TRUE(scalar1.specialShapeInfo() != nullptr);
|
||||||
|
ASSERT_TRUE(scalar1.lengthOf() == 1);
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,75 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <array/NDArrayList.h>
|
||||||
|
#include "testlayers.h"
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
|
||||||
|
class NDArrayListTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(NDArrayListTests, BasicTests_1) {
|
||||||
|
NDArrayList list(false);
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, list.write(1, new NDArray(x.dup())));
|
||||||
|
|
||||||
|
//ASSERT_EQ(ND4J_STATUS_DOUBLE_WRITE, list.write(1, &y));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NDArrayListTests, BasicTests_2) {
|
||||||
|
NDArrayList list(false);
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {1, 7});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, list.write(1, new NDArray(x.dup())));
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_BAD_INPUT, list.write(0, &y));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(NDArrayListTests, Test_Stack_UnStack_1) {
|
||||||
|
auto input = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
input.linspace(1);
|
||||||
|
|
||||||
|
NDArrayList list(false);
|
||||||
|
|
||||||
|
list.unstack(&input, 0);
|
||||||
|
|
||||||
|
ASSERT_EQ(10, list.elements());
|
||||||
|
|
||||||
|
auto array = list.stack();
|
||||||
|
|
||||||
|
ASSERT_TRUE(input.isSameShape(array));
|
||||||
|
|
||||||
|
ASSERT_TRUE(input.equalsTo(array));
|
||||||
|
|
||||||
|
delete array;
|
||||||
|
}
|
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,476 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <ops/ops.h>
|
||||||
|
#include <helpers/GradCheck.h>
|
||||||
|
#include <helpers/RandomLauncher.h>
|
||||||
|
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
|
||||||
|
|
||||||
|
class NlpTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
NlpTests() {
|
||||||
|
printf("\n");
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(NlpTests, basic_sg_hs_test_1) {
|
||||||
|
auto exp0 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
auto exp1 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
|
||||||
|
exp0.assign(0.01001f);
|
||||||
|
exp1.assign(0.020005f);
|
||||||
|
|
||||||
|
auto target = NDArrayFactory::create<int>(0);
|
||||||
|
auto ngStarter = NDArrayFactory::empty<int>();
|
||||||
|
auto indices = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
auto codes = NDArrayFactory::create<int8_t>('c', {1});
|
||||||
|
auto syn0 = NDArrayFactory::create<float>('c', {100, 10});
|
||||||
|
auto syn1 = NDArrayFactory::create<float>('c', {100, 10});
|
||||||
|
auto syn1Neg = NDArrayFactory::empty<float>();
|
||||||
|
auto expTable = NDArrayFactory::create<float>('c', {10000});
|
||||||
|
auto negTable = NDArrayFactory::empty<float>();
|
||||||
|
auto neu1e = NDArrayFactory::create<float>('c', {10});
|
||||||
|
|
||||||
|
syn0.assign(0.01);
|
||||||
|
syn1.assign(0.02);
|
||||||
|
expTable.assign(0.5);
|
||||||
|
|
||||||
|
auto alpha = NDArrayFactory::create<double>(0.001);
|
||||||
|
auto randomValue = NDArrayFactory::create<Nd4jLong>(1L);
|
||||||
|
auto inferenceVector = NDArrayFactory::empty<float>();
|
||||||
|
|
||||||
|
sd::ops::skipgram op;
|
||||||
|
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto row0 = syn0({0,1, 0,0}, true);
|
||||||
|
auto row1 = syn1({1,2, 0,0}, true);
|
||||||
|
|
||||||
|
ASSERT_EQ(exp0, row0);
|
||||||
|
ASSERT_EQ(exp1, row1);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NlpTests, basic_sg_hs_test_2) {
|
||||||
|
auto exp0 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
auto exp1 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
auto exp2 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
|
||||||
|
exp0.assign(0.01f);
|
||||||
|
exp1.assign(0.020005f);
|
||||||
|
exp2.assign(0.019995f);
|
||||||
|
|
||||||
|
auto target = NDArrayFactory::create<int>(0);
|
||||||
|
auto ngStarter = NDArrayFactory::empty<int>();
|
||||||
|
auto indices = NDArrayFactory::create<int>('c', {2}, {1, 2});
|
||||||
|
auto codes = NDArrayFactory::create<int8_t>('c', {2}, {0, 1});
|
||||||
|
auto syn0 = NDArrayFactory::create<float>('c', {100, 10});
|
||||||
|
auto syn1 = NDArrayFactory::create<float>('c', {100, 10});
|
||||||
|
auto syn1Neg = NDArrayFactory::empty<float>();
|
||||||
|
auto expTable = NDArrayFactory::create<float>('c', {10000});
|
||||||
|
auto negTable = NDArrayFactory::empty<float>();
|
||||||
|
auto neu1e = NDArrayFactory::create<float>('c', {10});
|
||||||
|
|
||||||
|
syn0.assign(0.01);
|
||||||
|
syn1.assign(0.02);
|
||||||
|
expTable.assign(0.5);
|
||||||
|
|
||||||
|
auto alpha = NDArrayFactory::create<double>(0.001);
|
||||||
|
auto randomValue = NDArrayFactory::create<Nd4jLong>(1L);
|
||||||
|
auto inferenceVector = NDArrayFactory::empty<float>();
|
||||||
|
|
||||||
|
sd::ops::skipgram op;
|
||||||
|
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto row0 = syn0({0,1, 0,0}, true);
|
||||||
|
auto row1 = syn1({1,2, 0,0}, true);
|
||||||
|
auto row2 = syn1({2,3, 0,0}, true);
|
||||||
|
|
||||||
|
ASSERT_EQ(exp0, row0);
|
||||||
|
ASSERT_EQ(exp1, row1);
|
||||||
|
ASSERT_EQ(exp2, row2);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NlpTests, basic_sg_hs_test_3) {
|
||||||
|
auto exp0 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
auto exp1 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
auto exp2 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
|
||||||
|
exp0.assign(0.01f);
|
||||||
|
exp1.assign(0.020005f);
|
||||||
|
exp2.assign(0.019995f);
|
||||||
|
|
||||||
|
auto target = NDArrayFactory::create<int>(0);
|
||||||
|
auto ngStarter = NDArrayFactory::empty<int>();
|
||||||
|
auto indices0 = NDArrayFactory::create<int>('c', {3}, {1, 2, 3});
|
||||||
|
auto indices1 = NDArrayFactory::create<int>('c', {3}, {3, 1, 2});
|
||||||
|
auto codes00 = NDArrayFactory::create<int8_t>('c', {3}, {0, 1, 1});
|
||||||
|
auto codes01 = NDArrayFactory::create<int8_t>('c', {3}, {1, 0, 1});
|
||||||
|
auto syn00 = NDArrayFactory::create<float>('c', {100, 10});
|
||||||
|
auto syn01 = NDArrayFactory::create<float>('c', {100, 10});
|
||||||
|
auto syn10 = NDArrayFactory::create<float>('c', {100, 10});
|
||||||
|
auto syn11 = NDArrayFactory::create<float>('c', {100, 10});
|
||||||
|
auto syn1Neg = NDArrayFactory::empty<float>();
|
||||||
|
auto expTable = NDArrayFactory::create<float>('c', {10000});
|
||||||
|
auto negTable = NDArrayFactory::empty<float>();
|
||||||
|
auto neu1e = NDArrayFactory::create<float>('c', {10});
|
||||||
|
|
||||||
|
RandomGenerator rng(119L, 198L);
|
||||||
|
RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &syn00, 0.0, 1.0);
|
||||||
|
RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &syn10, 0.0, 1.0);
|
||||||
|
|
||||||
|
syn01.assign(syn00);
|
||||||
|
syn11.assign(syn10);
|
||||||
|
expTable.assign(0.5);
|
||||||
|
|
||||||
|
auto alpha = NDArrayFactory::create<double>(0.001);
|
||||||
|
auto randomValue = NDArrayFactory::create<Nd4jLong>(1L);
|
||||||
|
auto inferenceVector = NDArrayFactory::empty<float>();
|
||||||
|
|
||||||
|
sd::ops::skipgram op;
|
||||||
|
auto result0 = op.evaluate({&target, &ngStarter, &indices0, &codes00, &syn00, &syn10, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
|
||||||
|
auto result1 = op.evaluate({&target, &ngStarter, &indices1, &codes01, &syn01, &syn11, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true);
|
||||||
|
ASSERT_EQ(Status::OK(), result0.status());
|
||||||
|
|
||||||
|
auto row00 = syn00({0,1, 0,0}, true);
|
||||||
|
auto row01 = syn01({0,1, 0,0}, true);
|
||||||
|
auto row1 = syn10({1,2, 0,0}, true);
|
||||||
|
auto row2 = syn11({1,2, 0,0}, true);
|
||||||
|
|
||||||
|
ASSERT_EQ(row2, row1);
|
||||||
|
ASSERT_EQ(row00, row01);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NlpTests, basic_sg_hs_ns_test_1) {
|
||||||
|
auto target = NDArrayFactory::create<int>(0);
|
||||||
|
auto ngStarter = NDArrayFactory::create<int>(1);
|
||||||
|
auto indices = NDArrayFactory::create<int>('c', {5}, {1, 2, 3, 4, 5});
|
||||||
|
auto codes = NDArrayFactory::create<int8_t>('c', {5}, {1, 1, 0, 1, 1});
|
||||||
|
auto syn0 = NDArrayFactory::create<float>('c', {100, 150});
|
||||||
|
auto syn1 = NDArrayFactory::create<float>('c', {100, 150});
|
||||||
|
auto syn1Neg = NDArrayFactory::create<float>('c', {100, 150});
|
||||||
|
auto expTable = NDArrayFactory::create<float>('c', {1000});
|
||||||
|
auto negTable = NDArrayFactory::create<float>('c', {1000});
|
||||||
|
auto neu1e = NDArrayFactory::create<float>('c', {10});
|
||||||
|
negTable.linspace(1.0);
|
||||||
|
|
||||||
|
auto alpha = NDArrayFactory::create<double>(1.25);
|
||||||
|
auto randomValue = NDArrayFactory::create<Nd4jLong>(119L);
|
||||||
|
auto inferenceVector = NDArrayFactory::empty<float>();
|
||||||
|
|
||||||
|
sd::ops::skipgram op;
|
||||||
|
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {3}, {false}, {}, true);
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NlpTests, basic_sg_ns_test_1) {
|
||||||
|
auto exp0 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
|
||||||
|
exp0.assign(0.01);
|
||||||
|
|
||||||
|
auto target = NDArrayFactory::create<int>(1);
|
||||||
|
auto ngStarter = NDArrayFactory::create<int>(3);
|
||||||
|
auto indices = NDArrayFactory::empty<int>();
|
||||||
|
auto codes = NDArrayFactory::empty<int8_t>();
|
||||||
|
auto syn0 = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
auto syn1 = NDArrayFactory::empty<float>();
|
||||||
|
auto syn1Neg = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
auto expTable = NDArrayFactory::create<float>('c', {1000});
|
||||||
|
auto negTable = NDArrayFactory::create<float>('c', {1000});
|
||||||
|
auto neu1e = NDArrayFactory::create<float>('c', {10});
|
||||||
|
|
||||||
|
auto syn1Neg2 = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
|
||||||
|
syn0.assign(0.01);
|
||||||
|
syn1.assign(0.02);
|
||||||
|
syn1Neg.assign(0.03);
|
||||||
|
syn1Neg2.assign(0.03);
|
||||||
|
expTable.assign(0.5);
|
||||||
|
|
||||||
|
auto alpha = NDArrayFactory::create<double>(0.001);
|
||||||
|
auto randomValue = NDArrayFactory::create<Nd4jLong>(2L);
|
||||||
|
auto inferenceVector = NDArrayFactory::empty<float>();
|
||||||
|
|
||||||
|
sd::ops::skipgram op;
|
||||||
|
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {1, 1}, {false}, {}, true);
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto row0 = syn0({1,2, 0,0}, true);
|
||||||
|
|
||||||
|
ASSERT_EQ(exp0, row0);
|
||||||
|
ASSERT_FALSE(syn1Neg2.equalsTo(syn1Neg, 1e-6));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NlpTests, basic_cb_hs_test_1) {
|
||||||
|
auto exp0 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
auto exp1 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
auto exp2 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
|
||||||
|
exp0.assign(0.0095f);
|
||||||
|
exp1.assign(0.019875f);
|
||||||
|
exp2.assign(0.02f);
|
||||||
|
|
||||||
|
auto target = NDArrayFactory::create<int>(0);
|
||||||
|
auto ngStarter = NDArrayFactory::empty<int>();
|
||||||
|
auto context = NDArrayFactory::create<int>('c', {3}, {0, 1, 2});
|
||||||
|
auto locked = NDArrayFactory::create<int>('c', {3});
|
||||||
|
auto indices = NDArrayFactory::create<int>('c', {2}, {4, 5});
|
||||||
|
auto codes = NDArrayFactory::create<int8_t>('c', {2}, {1, 1});
|
||||||
|
auto syn0 = NDArrayFactory::create<float>('c', {100, 10});
|
||||||
|
auto syn1 = NDArrayFactory::create<float>('c', {100, 10});
|
||||||
|
auto syn1Neg = NDArrayFactory::empty<float>();
|
||||||
|
auto expTable = NDArrayFactory::create<float>('c', {10000});
|
||||||
|
auto negTable = NDArrayFactory::empty<float>();
|
||||||
|
auto numWords = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
|
||||||
|
syn0.assign(0.01);
|
||||||
|
syn1.assign(0.02);
|
||||||
|
expTable.assign(0.5);
|
||||||
|
|
||||||
|
auto alpha = NDArrayFactory::create<double>(0.025);
|
||||||
|
auto randomValue = NDArrayFactory::create<Nd4jLong>(2L);
|
||||||
|
auto inferenceVector = NDArrayFactory::empty<float>();
|
||||||
|
|
||||||
|
sd::ops::cbow op;
|
||||||
|
auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true);
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto row_s0_0 = syn0({0,1, 0,0}, true);
|
||||||
|
auto row_s0_1 = syn0({1,2, 0,0}, true);
|
||||||
|
auto row_s0_2 = syn0({2,3, 0,0}, true);
|
||||||
|
|
||||||
|
auto row_s1_4 = syn1({4,5, 0,0}, true);
|
||||||
|
auto row_s1_5 = syn1({5,6, 0,0}, true);
|
||||||
|
auto row_s1_6 = syn1({6,7, 0,0}, true);
|
||||||
|
|
||||||
|
ASSERT_EQ(exp0, row_s0_0);
|
||||||
|
ASSERT_EQ(exp0, row_s0_1);
|
||||||
|
ASSERT_EQ(exp0, row_s0_2);
|
||||||
|
|
||||||
|
ASSERT_EQ(exp1, row_s1_4);
|
||||||
|
ASSERT_EQ(exp1, row_s1_5);
|
||||||
|
ASSERT_EQ(exp2, row_s1_6);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NlpTests, basic_cb_ns_test_1) {
|
||||||
|
auto exp0 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
auto exp1 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
auto exp2 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
|
||||||
|
exp0.assign(0.0096265625);
|
||||||
|
exp1.assign(0.01);
|
||||||
|
exp2.assign(0.030125f);
|
||||||
|
|
||||||
|
auto target = NDArrayFactory::create<int>(0);
|
||||||
|
auto ngStarter = NDArrayFactory::create<int>(6);
|
||||||
|
auto context = NDArrayFactory::create<int>('c', {3}, {0, 1, 2});
|
||||||
|
auto locked = NDArrayFactory::create<int>('c', {3});
|
||||||
|
auto indices = NDArrayFactory::empty<int>();
|
||||||
|
auto codes = NDArrayFactory::empty<int8_t>();
|
||||||
|
auto syn0 = NDArrayFactory::create<float>('c', {100, 10});
|
||||||
|
auto syn1 = NDArrayFactory::create<float>('c', {100, 10});
|
||||||
|
auto syn1Neg = NDArrayFactory::create<float>('c', {100, 10});
|
||||||
|
auto expTable = NDArrayFactory::create<float>('c', {10000});
|
||||||
|
auto negTable = NDArrayFactory::create<float>('c', {100000});
|
||||||
|
auto numWords = NDArrayFactory::create<int>('c', {2}, {1, 2});
|
||||||
|
|
||||||
|
syn0.assign(0.01);
|
||||||
|
syn1.assign(0.02);
|
||||||
|
syn1Neg.assign(0.03);
|
||||||
|
expTable.assign(0.5);
|
||||||
|
|
||||||
|
auto alpha = NDArrayFactory::create<double>(0.025);
|
||||||
|
auto randomValue = NDArrayFactory::create<Nd4jLong>(2L);
|
||||||
|
auto inferenceVector = NDArrayFactory::empty<float>();
|
||||||
|
|
||||||
|
sd::ops::cbow op;
|
||||||
|
auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {1, 2, 0}, {true}, {}, true);
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto row_s0_0 = syn0({0,1, 0,0}, true);
|
||||||
|
auto row_s0_1 = syn0({1,2, 0,0}, true);
|
||||||
|
auto row_s0_2 = syn0({2,3, 0,0}, true);
|
||||||
|
|
||||||
|
auto row_s1_4 = syn1({4,5, 0,0}, true);
|
||||||
|
auto row_s1_5 = syn1({5,6, 0,0}, true);
|
||||||
|
auto row_s1_6 = syn1Neg({6,7, 0,0}, true);
|
||||||
|
|
||||||
|
|
||||||
|
ASSERT_EQ(exp0, row_s0_0);
|
||||||
|
ASSERT_EQ(exp0, row_s0_1);
|
||||||
|
ASSERT_EQ(exp0, row_s0_2);
|
||||||
|
ASSERT_EQ(exp2, row_s1_6);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NlpTests, test_sg_hs_batch_1) {
|
||||||
|
auto exp0 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
auto exp1 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
auto exp2 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
|
||||||
|
exp0.assign(0.01f);
|
||||||
|
exp1.assign(0.020005f);
|
||||||
|
exp2.assign(0.019995f);
|
||||||
|
|
||||||
|
auto target = NDArrayFactory::create<int>('c', {2}, {0, 5});
|
||||||
|
auto ngStarter = NDArrayFactory::empty<int>();
|
||||||
|
auto indices = NDArrayFactory::create<int>('c', {2, 2}, {1, 2, 3, 4});
|
||||||
|
auto codes = NDArrayFactory::create<int8_t>('c', {2, 2}, {0, 1, 1, 1});
|
||||||
|
auto syn0 = NDArrayFactory::create<float>('c', {100, 10});
|
||||||
|
auto syn1 = NDArrayFactory::create<float>('c', {100, 10});
|
||||||
|
auto syn1Neg = NDArrayFactory::empty<float>();
|
||||||
|
auto expTable = NDArrayFactory::create<float>('c', {10000});
|
||||||
|
auto negTable = NDArrayFactory::empty<float>();
|
||||||
|
|
||||||
|
auto alpha = NDArrayFactory::create<double>('c', {2}, {0.001, 0.024});
|
||||||
|
auto randomValue = NDArrayFactory::create<Nd4jLong>('c', {2}, {1L, 3L});
|
||||||
|
auto inferenceVector = NDArrayFactory::empty<float>();
|
||||||
|
auto neu1e = NDArrayFactory::create<float>('c', {2, 10});
|
||||||
|
|
||||||
|
syn0.assign(0.01);
|
||||||
|
syn1.assign(0.02);
|
||||||
|
expTable.assign(0.5);
|
||||||
|
|
||||||
|
sd::ops::skipgram op;
|
||||||
|
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false, true}, {}, true);
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto row0 = syn0({0,1, 0,0}, true);
|
||||||
|
auto row1 = syn1({1,2, 0,0}, true);
|
||||||
|
auto row2 = syn1({2,3, 0,0}, true);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp0.equalsTo(row0, 1e-6));
|
||||||
|
ASSERT_TRUE(exp1.equalsTo(row1, 1e-6));
|
||||||
|
ASSERT_TRUE(exp2.equalsTo(row2, 1e-6));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NlpTests, test_sg_ns_batch_1) {
|
||||||
|
auto exp0 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
auto exp1 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
auto exp2 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
|
||||||
|
exp0.assign(0.01f);
|
||||||
|
exp1.assign(0.020005f);
|
||||||
|
exp2.assign(0.019995f);
|
||||||
|
|
||||||
|
auto target = NDArrayFactory::create<int>('c', {2}, {0, 5});
|
||||||
|
auto ngStarter = NDArrayFactory::create<int>('c', {2}, {3, 8});
|
||||||
|
auto indices = NDArrayFactory::empty<int>();
|
||||||
|
auto codes = NDArrayFactory::empty<int8_t>();
|
||||||
|
auto syn0 = NDArrayFactory::create<float>('c', {100, 10});
|
||||||
|
auto syn1Neg = NDArrayFactory::create<float>('c', {100, 10});
|
||||||
|
auto syn1 = NDArrayFactory::empty<float>();
|
||||||
|
auto expTable = NDArrayFactory::create<float>('c', {10000});
|
||||||
|
auto negTable = NDArrayFactory::create<float>('c', {100000});
|
||||||
|
|
||||||
|
auto alpha = NDArrayFactory::create<double>('c', {2}, {0.001, 0.024});
|
||||||
|
auto randomValue = NDArrayFactory::create<Nd4jLong>('c', {2}, {1L, 3L});
|
||||||
|
auto inferenceVector = NDArrayFactory::empty<float>();
|
||||||
|
auto neu1e = NDArrayFactory::create<float>('c', {2, 10});
|
||||||
|
|
||||||
|
syn0.assign(0.01);
|
||||||
|
syn1.assign(0.02);
|
||||||
|
expTable.assign(0.5);
|
||||||
|
negTable.linspace(0.0);
|
||||||
|
|
||||||
|
sd::ops::skipgram op;
|
||||||
|
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {4, 5}, {false, true}, {}, true);
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NlpTests, test_cbow_hs_batch_1) {
|
||||||
|
#ifdef __CUDABLAS__
|
||||||
|
return ;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
auto target = NDArrayFactory::create<int>(0);
|
||||||
|
auto ngStarter = NDArrayFactory::empty<int>();
|
||||||
|
auto context = NDArrayFactory::create<int>('c', {2, 3}, {0, 1, 2, 100, 101, 102});
|
||||||
|
auto locked = NDArrayFactory::create<int>('c', {2, 3});
|
||||||
|
auto indices = NDArrayFactory::create<int>('c', {2, 2}, {4, 5, 40, 50});
|
||||||
|
auto codes = NDArrayFactory::create<int8_t>('c', {2, 2}, {1, 1, 1, 1});
|
||||||
|
auto syn0 = NDArrayFactory::create<float>('c', {244, 10});
|
||||||
|
auto syn1 = NDArrayFactory::create<float>('c', {244, 10});
|
||||||
|
auto syn1Neg = NDArrayFactory::empty<float>();
|
||||||
|
auto expTable = NDArrayFactory::create<float>('c', {10000});
|
||||||
|
auto negTable = NDArrayFactory::empty<float>();
|
||||||
|
auto numWords = NDArrayFactory::create<int>('c', {2}, {1, 2});
|
||||||
|
|
||||||
|
syn0.assign(0.01);
|
||||||
|
syn1.assign(0.02);
|
||||||
|
expTable.assign(0.5);
|
||||||
|
|
||||||
|
auto alpha = NDArrayFactory::create<double>('c', {2}, {0.025, 0.025});
|
||||||
|
auto randomValue = NDArrayFactory::create<Nd4jLong>('c', {2}, {2L, 2L});
|
||||||
|
auto inferenceVector = NDArrayFactory::empty<float>();
|
||||||
|
|
||||||
|
sd::ops::cbow op;
|
||||||
|
auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true);
|
||||||
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
|
auto exp0 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
auto exp1 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
auto exp2 = NDArrayFactory::create<float>('c', {1, 10});
|
||||||
|
|
||||||
|
exp0.assign(0.0095f);
|
||||||
|
exp1.assign(0.019875f);
|
||||||
|
exp2.assign(0.02f);
|
||||||
|
|
||||||
|
auto row_s0_0 = syn0({0,1, 0,0}, true);
|
||||||
|
auto row_s0_1 = syn0({1,2, 0,0}, true);
|
||||||
|
auto row_s0_2 = syn0({2,3, 0,0}, true);
|
||||||
|
|
||||||
|
auto row_s1_4 = syn1({4,5, 0,0}, true);
|
||||||
|
auto row_s1_5 = syn1({5,6, 0,0}, true);
|
||||||
|
auto row_s1_6 = syn1({6,7, 0,0}, true);
|
||||||
|
|
||||||
|
ASSERT_EQ(exp0, row_s0_0);
|
||||||
|
ASSERT_EQ(exp0, row_s0_1);
|
||||||
|
ASSERT_EQ(exp0, row_s0_2);
|
||||||
|
ASSERT_EQ(exp1, row_s1_4);
|
||||||
|
ASSERT_EQ(exp1, row_s1_5);
|
||||||
|
ASSERT_EQ(exp2, row_s1_6);
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,75 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver119 on 21.02.18.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <graph/Variable.h>
|
||||||
|
#include <flatbuffers/flatbuffers.h>
|
||||||
|
#include <ops/declarable/headers/broadcastable.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::graph;
|
||||||
|
|
||||||
|
class NodeTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(NodeTests, Test_Dtype_Conversion_1) {
|
||||||
|
auto nodeA = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1}, {2});
|
||||||
|
|
||||||
|
auto nd = nodeA->asT<double>();
|
||||||
|
auto nf = nd->asT<float>();
|
||||||
|
|
||||||
|
ASSERT_EQ(nodeA->id(), nf->id());
|
||||||
|
ASSERT_EQ(*nodeA->name(), *nf->name());
|
||||||
|
ASSERT_EQ(nodeA->getOpClass(), nf->getOpClass());
|
||||||
|
ASSERT_EQ(nodeA->opType(), nf->opType());
|
||||||
|
ASSERT_EQ(nodeA->opNum(), nf->opNum());
|
||||||
|
|
||||||
|
delete nodeA;
|
||||||
|
delete nd;
|
||||||
|
delete nf;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(NodeTests, Test_Dtype_Conversion_2) {
|
||||||
|
sd::ops::add opA;
|
||||||
|
|
||||||
|
//auto nodeA = new Node(OpType_CUSTOM, 0, 1, {-1}, {2});
|
||||||
|
auto nodeA = new Node(&opA, 1, {-1}, {2});
|
||||||
|
//nodeA->setCustomOp(&op);
|
||||||
|
|
||||||
|
auto nd = nodeA->asT<double>();
|
||||||
|
auto nf = nd->asT<float>();
|
||||||
|
|
||||||
|
ASSERT_EQ(nodeA->id(), nf->id());
|
||||||
|
ASSERT_EQ(*nodeA->name(), *nf->name());
|
||||||
|
// ASSERT_EQ(nodeA->getOpClass(), nf->getOpClass());
|
||||||
|
ASSERT_EQ(nodeA->opType(), nf->opType());
|
||||||
|
ASSERT_EQ(nodeA->opNum(), nf->opNum());
|
||||||
|
ASSERT_EQ(nodeA->getCustomOp()->getOpHash(), nf->getCustomOp()->getOpHash());
|
||||||
|
|
||||||
|
delete nodeA;
|
||||||
|
delete nd;
|
||||||
|
delete nf;
|
||||||
|
}
|
|
@ -0,0 +1,125 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver119 on 30.06.18.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <helpers/OmpLaunchHelper.h>
|
||||||
|
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::graph;
|
||||||
|
|
||||||
|
class OmpLaunchHelperTests : public testing::Test {
|
||||||
|
private:
|
||||||
|
int ewt = 0;
|
||||||
|
public:
|
||||||
|
OmpLaunchHelperTests() {
|
||||||
|
this->ewt = Environment::getInstance().elementwiseThreshold();
|
||||||
|
Environment::getInstance().setElementwiseThreshold(1000);
|
||||||
|
};
|
||||||
|
|
||||||
|
~OmpLaunchHelperTests() {
|
||||||
|
Environment::getInstance().setElementwiseThreshold(this->ewt);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(OmpLaunchHelperTests, Test_BetterSpan_1) {
|
||||||
|
auto span = OmpLaunchHelper::betterSpan(1000, 4);
|
||||||
|
ASSERT_EQ(250, span);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OmpLaunchHelperTests, Test_BetterSpan_2) {
|
||||||
|
auto span = OmpLaunchHelper::betterSpan(1001, 4);
|
||||||
|
ASSERT_EQ(251, span);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OmpLaunchHelperTests, Test_BetterSpan_3) {
|
||||||
|
auto span = OmpLaunchHelper::betterSpan(1002, 4);
|
||||||
|
ASSERT_EQ(251, span);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OmpLaunchHelperTests, Test_BetterSpan_5) {
|
||||||
|
auto span = OmpLaunchHelper::betterSpan(1003, 4);
|
||||||
|
ASSERT_EQ(251, span);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OmpLaunchHelperTests, Test_BetterSpan_6) {
|
||||||
|
auto span = OmpLaunchHelper::betterSpan(1004, 4);
|
||||||
|
ASSERT_EQ(251, span);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(OmpLaunchHelperTests, Test_BetterThreads_1) {
|
||||||
|
auto n = OmpLaunchHelper::betterThreads(4000, 6);
|
||||||
|
ASSERT_EQ(4, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OmpLaunchHelperTests, Test_BetterThreads_2) {
|
||||||
|
auto n = OmpLaunchHelper::betterThreads(12000, 6);
|
||||||
|
ASSERT_EQ(6, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OmpLaunchHelperTests, Test_BetterThreads_3) {
|
||||||
|
auto n = OmpLaunchHelper::betterThreads(899, 6);
|
||||||
|
ASSERT_EQ(1, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OmpLaunchHelperTests, test_tad_threads_1) {
|
||||||
|
Nd4jLong numTads = 16;
|
||||||
|
Nd4jLong tadLength = 16;
|
||||||
|
|
||||||
|
// nd4j_printf("TT: [%i]; ET: [%i];\n", Environment::getInstance().tadThreshold(), Environment::getInstance().elementwiseThreshold());
|
||||||
|
ASSERT_EQ(1, OmpLaunchHelper::tadThreads(tadLength, numTads));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OmpLaunchHelperTests, test_tad_threads_2) {
|
||||||
|
if (omp_get_max_threads() <= 1)
|
||||||
|
return;
|
||||||
|
|
||||||
|
Nd4jLong numTads = 2;
|
||||||
|
Nd4jLong tadLength = Environment::getInstance().elementwiseThreshold();
|
||||||
|
|
||||||
|
ASSERT_EQ(2, OmpLaunchHelper::tadThreads(tadLength, numTads));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OmpLaunchHelperTests, test_tad_threads_3) {
|
||||||
|
Nd4jLong numTads = 2;
|
||||||
|
Nd4jLong tadLength = 128;
|
||||||
|
|
||||||
|
ASSERT_EQ(1, OmpLaunchHelper::tadThreads(tadLength, numTads));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OmpLaunchHelperTests, test_tad_threads_4) {
|
||||||
|
Nd4jLong numTads = 4;
|
||||||
|
Nd4jLong tadLength = 64;
|
||||||
|
|
||||||
|
ASSERT_EQ(1, OmpLaunchHelper::tadThreads(tadLength, numTads));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OmpLaunchHelperTests, test_tad_threads_5) {
|
||||||
|
auto exp = omp_get_max_threads();
|
||||||
|
|
||||||
|
Nd4jLong numTads = exp;
|
||||||
|
Nd4jLong tadLength = Environment::getInstance().elementwiseThreshold();
|
||||||
|
|
||||||
|
ASSERT_EQ(exp, OmpLaunchHelper::tadThreads(tadLength, numTads));
|
||||||
|
}
|
|
@ -0,0 +1,390 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver119 on 11.10.2017.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <vector>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <ops/declarable/OpTuple.h>
|
||||||
|
#include <ops/declarable/OpRegistrator.h>
|
||||||
|
#include <graph/GraphExecutioner.h>
|
||||||
|
#include <memory/MemoryReport.h>
|
||||||
|
#include <memory/MemoryUtils.h>
|
||||||
|
#include <helpers/MmulHelper.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::ops;
|
||||||
|
|
||||||
|
class OneOffTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(OneOffTests, test_avg_pool_3d_1) {
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/avg_pooling3d.fb");
|
||||||
|
|
||||||
|
ASSERT_TRUE(graph != nullptr);
|
||||||
|
|
||||||
|
// graph->printOut();
|
||||||
|
|
||||||
|
Nd4jStatus status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OneOffTests, test_non2d_0A_1) {
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/non2d_0A.fb");
|
||||||
|
|
||||||
|
ASSERT_TRUE(graph != nullptr);
|
||||||
|
|
||||||
|
// graph->printOut();
|
||||||
|
|
||||||
|
Nd4jStatus status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
TEST_F(OneOffTests, test_assert_scalar_float32_1) {
|
||||||
|
sd::ops::Assert op;
|
||||||
|
sd::ops::identity op1;
|
||||||
|
sd::ops::noop op2;
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/scalar_float32.fb");
|
||||||
|
|
||||||
|
ASSERT_TRUE(graph != nullptr);
|
||||||
|
|
||||||
|
graph->printOut();
|
||||||
|
|
||||||
|
Nd4jStatus status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
delete graph;
|
||||||
|
}*/
|
||||||
|
|
||||||
|
TEST_F(OneOffTests, test_assert_scalar_float32_2) {
|
||||||
|
sd::ops::Assert op;
|
||||||
|
sd::ops::identity op1;
|
||||||
|
sd::ops::noop op2;
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/assertsomething.fb");
|
||||||
|
|
||||||
|
ASSERT_TRUE(graph != nullptr);
|
||||||
|
|
||||||
|
// graph->printOut();
|
||||||
|
|
||||||
|
Nd4jStatus status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(OneOffTests, test_pad_1D_1) {
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {7}, {10.f,0.778786f, 0.801198f, 0.724375f, 0.230894f, 0.727141f,10.f});
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/pad_1D.fb");
|
||||||
|
|
||||||
|
ASSERT_TRUE(graph != nullptr);
|
||||||
|
|
||||||
|
// graph->printOut();
|
||||||
|
|
||||||
|
Nd4jStatus status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(4));
|
||||||
|
|
||||||
|
auto z = graph->getVariableSpace()->getVariable(4)->getNDArray();
|
||||||
|
ASSERT_TRUE(z != nullptr);
|
||||||
|
|
||||||
|
// z->printIndexedBuffer("z");
|
||||||
|
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
TEST_F(OneOffTests, test_scatter_nd_update_1) {
|
||||||
|
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {10, 7}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.20446908f, 0.37918627f, 0.99792874f, 0.71881700f, 0.18677747f,
|
||||||
|
0.78299069f, 0.55216062f, 0.40746713f, 0.92128086f, 0.57195139f, 0.44686234f, 0.30861020f, 0.31026053f, 0.09293187f,
|
||||||
|
1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.95073712f, 0.45613325f, 0.95149803f, 0.88341522f, 0.54366302f, 0.50060666f, 0.39031255f,
|
||||||
|
1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f,
|
||||||
|
1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f});
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/scatter_nd_update.fb");
|
||||||
|
ASSERT_TRUE(graph != nullptr);
|
||||||
|
|
||||||
|
graph->printOut();
|
||||||
|
|
||||||
|
Nd4jStatus status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6));
|
||||||
|
|
||||||
|
auto z = graph->getVariableSpace()->getVariable(6)->getNDArray();
|
||||||
|
ASSERT_TRUE(z != nullptr);
|
||||||
|
|
||||||
|
z->printIndexedBuffer("z");
|
||||||
|
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
TEST_F(OneOffTests, test_conv2d_nhwc_failed_1) {
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {1, 5, 5, 6}, {0.55744928f, 0.76827729f, 1.09401524f, 0.00000000f, 0.00000000f, 0.00000000f, 0.56373537f, 0.90029907f, 0.78997850f, 0.00000000f, 0.00000000f, 0.00000000f, 0.14252824f, 0.95961076f, 0.87750554f, 0.00000000f, 0.00000000f, 0.00000000f, 0.44874173f, 0.99537718f, 1.17154264f, 0.00000000f, 0.00000000f, 0.00000000f, 0.60377145f, 0.79939061f, 0.56031001f, 0.00000000f, 0.00000000f, 0.00000000f, 0.52975273f, 0.90678585f, 0.73763013f, 0.00000000f, 0.00000000f, 0.00000000f, 0.22146404f, 0.82499605f, 0.47222072f, 0.00000000f, 0.00000000f, 0.00000000f, 0.42772964f, 0.39793295f, 0.71436501f, 0.00000000f, 0.00000000f, 0.00000000f, 0.48836520f, 1.01658893f, 0.74419701f, 0.00000000f, 0.00000000f, 0.00000000f, 0.78984612f, 0.94083673f, 0.83841157f, 0.00000000f, 0.00000000f, 0.00000000f, 0.40448499f, 0.67732805f, 0.75499672f, 0.00000000f, 0.00000000f, 0.00000000f, 0.43675962f, 0.79476535f, 0.72976631f, 0.00000000f, 0.00000000f, 0.00000000f, 0.58808053f, 0.65222591f, 0.72552216f, 0.00000000f, 0.00000000f, 0.00000000f, 0.37445742f, 1.22581339f, 1.05341125f, 0.00000000f, 0.00000000f, 0.00000000f, 0.30095795f, 0.59941679f, 0.63323414f, 0.00000000f, 0.00000000f, 0.00000000f, 0.24199286f, 1.02546394f, 0.69537812f, 0.00000000f, 0.00000000f, 0.00000000f, 0.23628944f, 0.90791851f, 1.01209974f, 0.00000000f, 0.00000000f, 0.00000000f, 0.62740159f, 0.56518674f, 0.76692569f, 0.00000000f, 0.00000000f, 0.00000000f, 0.13327584f, 0.32628393f, 0.10280430f, 0.00000000f, 0.00000000f, 0.00000000f, 0.42691272f, 0.25625113f, 0.30524066f, 0.00000000f, 0.00000000f, 0.00000000f, 0.17797673f, 0.84179950f, 0.80061519f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00199084f, 0.51838887f, 0.43932241f, 0.00000000f, 0.00000000f, 0.00000000f, 0.16684581f, 0.50822425f, 0.48668745f, 0.00000000f, 0.00000000f, 0.00000000f, 0.16749343f, 0.93093169f, 0.86871749f, 0.00000000f, 0.00000000f, 0.00000000f, 0.17486368f, 0.44460732f, 0.44499981f, 0.00000000f, 0.00000000f, 0.00000000f});
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/channels_last_b1_k2_s1_d1_SAME_crelu.fb");
|
||||||
|
ASSERT_TRUE(graph != nullptr);
|
||||||
|
|
||||||
|
// graph->printOut();
|
||||||
|
|
||||||
|
Nd4jStatus status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(9));
|
||||||
|
|
||||||
|
auto z = graph->getVariableSpace()->getVariable(9)->getNDArray();
|
||||||
|
ASSERT_TRUE(z != nullptr);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OneOffTests, test_tensor_array_1) {
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {2, 3}, {0.77878559f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f});
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_array_close_sz1_float32_nodynamic_noname_noshape.fb");
|
||||||
|
ASSERT_TRUE(graph != nullptr);
|
||||||
|
|
||||||
|
// graph->printOut();
|
||||||
|
|
||||||
|
Nd4jStatus status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(5));
|
||||||
|
|
||||||
|
auto z = graph->getVariableSpace()->getVariable(5)->getNDArray();
|
||||||
|
ASSERT_TRUE(z != nullptr);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OneOffTests, test_tensor_array_2) {
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {2, 3}, {0.77878559f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f});
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_array_split_sz1_float32_nodynamic_noname_noshape.fb");
|
||||||
|
ASSERT_TRUE(graph != nullptr);
|
||||||
|
|
||||||
|
// graph->printOut();
|
||||||
|
|
||||||
|
Nd4jStatus status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6));
|
||||||
|
|
||||||
|
auto z = graph->getVariableSpace()->getVariable(6)->getNDArray();
|
||||||
|
ASSERT_TRUE(z != nullptr);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OneOffTests, test_tensor_array_3) {
|
||||||
|
auto e = NDArrayFactory::create<int>('c', {3, 2, 3}, {7, 2, 9, 4, 3, 3, 8, 7, 0, 0, 6, 8, 7, 9, 0, 1, 1, 4});
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_array_stack_sz3-1_int32_dynamic_name_shape.fb");
|
||||||
|
ASSERT_TRUE(graph != nullptr);
|
||||||
|
|
||||||
|
// graph->printOut();
|
||||||
|
|
||||||
|
|
||||||
|
Nd4jStatus status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(15));
|
||||||
|
|
||||||
|
auto z = graph->getVariableSpace()->getVariable(15)->getNDArray();
|
||||||
|
ASSERT_TRUE(z != nullptr);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OneOffTests, test_tensor_array_4) {
|
||||||
|
auto e = NDArrayFactory::create<Nd4jLong>('c', {2, 3}, {4, 3, 1, 1, 1, 0});
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_array_unstack_sz1_int64_nodynamic_noname_shape2-3.fb");
|
||||||
|
ASSERT_TRUE(graph != nullptr);
|
||||||
|
|
||||||
|
// graph->printOut();
|
||||||
|
|
||||||
|
|
||||||
|
Nd4jStatus status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(11));
|
||||||
|
|
||||||
|
auto z = graph->getVariableSpace()->getVariable(11)->getNDArray();
|
||||||
|
ASSERT_TRUE(z != nullptr);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OneOffTests, test_assert_4) {
|
||||||
|
auto e = NDArrayFactory::create<Nd4jLong>('c', {2, 2}, {1, 1, 1, 1});
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/assert_type_rank2_int64.fb");
|
||||||
|
ASSERT_TRUE(graph != nullptr);
|
||||||
|
|
||||||
|
// graph->printOut();
|
||||||
|
|
||||||
|
|
||||||
|
Nd4jStatus status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1));
|
||||||
|
|
||||||
|
auto z = graph->getVariableSpace()->getVariable(1)->getNDArray();
|
||||||
|
ASSERT_TRUE(z != nullptr);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TEST_F(OneOffTests, test_cond_true_1) {
|
||||||
|
// auto e = NDArrayFactory::create<float>('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f});
|
||||||
|
|
||||||
|
// auto graph = GraphExecutioner::importFromFlatBuffers("./resources/cond_true.fb");
|
||||||
|
// ASSERT_TRUE(graph != nullptr);
|
||||||
|
|
||||||
|
// graph->printOut();
|
||||||
|
|
||||||
|
|
||||||
|
// Nd4jStatus status = GraphExecutioner::execute(graph);
|
||||||
|
// ASSERT_EQ(Status::OK(), status);
|
||||||
|
// ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6));
|
||||||
|
|
||||||
|
// auto z = graph->getVariableSpace()->getVariable(6)->getNDArray();
|
||||||
|
// ASSERT_TRUE(z != nullptr);
|
||||||
|
|
||||||
|
// z->printIndexedBuffer("z buffer");
|
||||||
|
|
||||||
|
// ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
// delete graph;
|
||||||
|
// }
|
||||||
|
|
||||||
|
/*
|
||||||
|
TEST_F(OneOffTests, test_cond_false_1) {
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f});
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/cond_false.fb");
|
||||||
|
ASSERT_TRUE(graph != nullptr);
|
||||||
|
|
||||||
|
graph->printOut();
|
||||||
|
|
||||||
|
|
||||||
|
Nd4jStatus status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6));
|
||||||
|
|
||||||
|
auto z = graph->getVariableSpace()->getVariable(6)->getNDArray();
|
||||||
|
ASSERT_TRUE(z != nullptr);
|
||||||
|
|
||||||
|
z->printIndexedBuffer("z buffer");
|
||||||
|
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
TEST_F(OneOffTests, test_identity_n_2) {
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {2, 3}, {0.77878559f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f});
|
||||||
|
|
||||||
|
sd::ops::identity_n op;
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/identity_n_2.fb");
|
||||||
|
ASSERT_TRUE(graph != nullptr);
|
||||||
|
|
||||||
|
// graph->printOut();
|
||||||
|
|
||||||
|
|
||||||
|
Nd4jStatus status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1));
|
||||||
|
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1, 1));
|
||||||
|
|
||||||
|
auto z = graph->getVariableSpace()->getVariable(1)->getNDArray();
|
||||||
|
ASSERT_TRUE(z != nullptr);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OneOffTests, test_non2d_1) {
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {1, 1}, {5.42746449f});
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/non2d_1.fb");
|
||||||
|
ASSERT_TRUE(graph != nullptr);
|
||||||
|
|
||||||
|
// graph->printOut();
|
||||||
|
|
||||||
|
Nd4jStatus status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(3));
|
||||||
|
|
||||||
|
auto z = graph->getVariableSpace()->getVariable(3)->getNDArray();
|
||||||
|
ASSERT_TRUE(z != nullptr);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
|
||||||
|
delete graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OneOffTests, test_reduce_all_1) {
|
||||||
|
auto e = NDArrayFactory::create<bool>('c', {1, 4}, {true, false, false, false});
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_all_rank2_d0_keep.fb");
|
||||||
|
ASSERT_TRUE(graph != nullptr);
|
||||||
|
|
||||||
|
// graph->printOut();
|
||||||
|
|
||||||
|
Nd4jStatus status = GraphExecutioner::execute(graph);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1));
|
||||||
|
|
||||||
|
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(2));
|
||||||
|
auto in = graph->getVariableSpace()->getVariable(2)->getNDArray();
|
||||||
|
|
||||||
|
|
||||||
|
auto z = graph->getVariableSpace()->getVariable(1)->getNDArray();
|
||||||
|
ASSERT_TRUE(z != nullptr);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
|
||||||
|
delete graph;
|
||||||
|
}
|
|
@ -0,0 +1,71 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver119 on 15.12.17.
|
||||||
|
//
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <graph/Graph.h>
|
||||||
|
#include <chrono>
|
||||||
|
#include <graph/Node.h>
|
||||||
|
#include <helpers/OpTracker.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::ops;
|
||||||
|
using namespace sd::graph;
|
||||||
|
|
||||||
|
class OpTrackerTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
int numIterations = 10;
|
||||||
|
int poolSize = 10;
|
||||||
|
|
||||||
|
OpTrackerTests() {
|
||||||
|
printf("\n");
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(OpTrackerTests, Test_Existence_1) {
|
||||||
|
sd::_loader loader;
|
||||||
|
|
||||||
|
// nd4j_printf("Groups: %i; Operations: %i\n", OpTracker::getInstance().totalGroups(), OpTracker::getInstance().totalOperations());
|
||||||
|
|
||||||
|
ASSERT_TRUE(OpTracker::getInstance().totalGroups() > 0);
|
||||||
|
ASSERT_TRUE(OpTracker::getInstance().totalOperations() > 0);
|
||||||
|
|
||||||
|
OpTracker::getInstance().exportOperations();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OpTrackerTests, Test_Ops_List_1) {
|
||||||
|
sd::ops::less op;
|
||||||
|
auto vec = OpRegistrator::getInstance().getAllHashes();
|
||||||
|
|
||||||
|
// nd4j_printf("Total ops: %lld\n", vec.size());
|
||||||
|
// nd4j_printf("Less hash: %lld\n", op.getOpHash());
|
||||||
|
|
||||||
|
for (const auto &v: vec) {
|
||||||
|
if (v == 5484196977525668316L) {
|
||||||
|
auto op = OpRegistrator::getInstance().getOperation(v);
|
||||||
|
// nd4j_printf("OpName: %s\n", op->getOpName()->c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by raver119 on 11.10.2017.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
#include <ops/declarable/OpTuple.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::ops;
|
||||||
|
|
||||||
|
class OpTupleTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(OpTupleTests, DirectConstructorTest1) {
|
||||||
|
auto alpha = NDArrayFactory::create_<float>('c', {1, 2});
|
||||||
|
auto beta = NDArrayFactory::create_<float>('c', {1, 2});
|
||||||
|
OpTuple tuple("dummy", {alpha, beta}, {12.0f}, {1,2, 3});
|
||||||
|
|
||||||
|
ASSERT_EQ("dummy", tuple._opName);
|
||||||
|
ASSERT_EQ(2, tuple._inputs.size());
|
||||||
|
ASSERT_EQ(0, tuple._outputs.size());
|
||||||
|
ASSERT_EQ(1, tuple._tArgs.size());
|
||||||
|
ASSERT_EQ(3, tuple._iArgs.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OpTupleTests, BuilderTest1) {
|
||||||
|
auto alpha = NDArrayFactory::create_<float>('c', {1, 2});
|
||||||
|
auto beta = NDArrayFactory::create_<float>('c', {1, 2});
|
||||||
|
OpTuple tuple("dummy");
|
||||||
|
tuple.addInput(alpha)
|
||||||
|
->addInput(beta)
|
||||||
|
->setTArgs({12.0f})
|
||||||
|
->setIArgs({1, 2, 3});
|
||||||
|
|
||||||
|
|
||||||
|
ASSERT_EQ("dummy", tuple._opName);
|
||||||
|
ASSERT_EQ(2, tuple._inputs.size());
|
||||||
|
ASSERT_EQ(0, tuple._outputs.size());
|
||||||
|
ASSERT_EQ(1, tuple._tArgs.size());
|
||||||
|
ASSERT_EQ(3, tuple._iArgs.size());
|
||||||
|
}
|
|
@ -0,0 +1,52 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// Created by agibsonccc on 1/17/17.
|
||||||
|
//
|
||||||
|
#include "testinclude.h"
|
||||||
|
#include <loops/reduce3.h>
|
||||||
|
|
||||||
|
class EqualsTest : public testing::Test {
|
||||||
|
public:
|
||||||
|
const Nd4jLong firstShapeBuffer[8] = {2,1,2,1,1,0,1,102};
|
||||||
|
float data[2] = {1.0f, 7.0f};
|
||||||
|
const Nd4jLong secondShapeBuffer[8] = {2,2,1,6,1,0,6,99};
|
||||||
|
float dataSecond[12] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
|
||||||
|
int opNum = 4;
|
||||||
|
float extraArgs[1] = {1e-6f};
|
||||||
|
int dimension[1] = {2147483647};
|
||||||
|
int dimensionLength = 1;
|
||||||
|
};
|
||||||
|
|
||||||
|
#ifndef __CUDABLAS__
|
||||||
|
|
||||||
|
TEST_F(EqualsTest,Eps) {
|
||||||
|
auto val = sd::NDArrayFactory::create(0.0f);
|
||||||
|
functions::reduce3::Reduce3<float, float>::execScalar(opNum,
|
||||||
|
data,
|
||||||
|
firstShapeBuffer,
|
||||||
|
extraArgs,
|
||||||
|
dataSecond,
|
||||||
|
secondShapeBuffer,
|
||||||
|
val.buffer(),
|
||||||
|
val.shapeInfo());
|
||||||
|
ASSERT_TRUE(val.e<float>(0) < 0.5);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,146 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <graph/Graph.h>
|
||||||
|
#include <chrono>
|
||||||
|
#include <graph/Node.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <graph/profiling/GraphProfilingHelper.h>
|
||||||
|
#include <loops/type_conversions.h>
|
||||||
|
#include <helpers/threshold.h>
|
||||||
|
#include <helpers/MmulHelper.h>
|
||||||
|
#include <ops/ops.h>
|
||||||
|
#include <helpers/OmpLaunchHelper.h>
|
||||||
|
#include <helpers/GradCheck.h>
|
||||||
|
#include <ops/declarable/helpers/im2col.h>
|
||||||
|
#include <helpers/Loops.h>
|
||||||
|
#include <helpers/RandomLauncher.h>
|
||||||
|
|
||||||
|
#include <helpers/BenchmarkHelper.h>
|
||||||
|
#include <ops/declarable/helpers/scatter.h>
|
||||||
|
#include <helpers/ConstantShapeHelper.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <array>
|
||||||
|
#include <performance/benchmarking/FullBenchmarkSuit.h>
|
||||||
|
#include <performance/benchmarking/LightBenchmarkSuit.h>
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/legacy_helpers.h>
|
||||||
|
#include <execution/ThreadPool.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::graph;
|
||||||
|
|
||||||
|
class PerformanceTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
int numIterations = 100;
|
||||||
|
|
||||||
|
PerformanceTests() {
|
||||||
|
samediff::ThreadPool::getInstance();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
#ifdef RELEASE_BUILD
|
||||||
|
|
||||||
|
TEST_F(PerformanceTests, test_matmul_c_f_1) {
|
||||||
|
int iterations = 500;
|
||||||
|
std::vector<ino64_t> valuesC, valuesF;
|
||||||
|
for (int e = 0; e < iterations; e++) {
|
||||||
|
auto xc = NDArrayFactory::create<float>('c', {512, 2048});
|
||||||
|
auto yc = NDArrayFactory::create<float>('c', {2048, 512});
|
||||||
|
auto zc = NDArrayFactory::create<float>('c', {512, 512});
|
||||||
|
|
||||||
|
auto xf = NDArrayFactory::create<float>('f', {512, 2048});
|
||||||
|
auto yf = NDArrayFactory::create<float>('f', {2048, 512});
|
||||||
|
auto zf = NDArrayFactory::create<float>('f', {512, 512});
|
||||||
|
|
||||||
|
auto warm = xc.like();
|
||||||
|
warm.linspace(1.0);
|
||||||
|
|
||||||
|
//zc.linspace(1.0);
|
||||||
|
//zf.linspace(1.0);
|
||||||
|
|
||||||
|
sd::ops::matmul op;
|
||||||
|
|
||||||
|
auto timeStartF = std::chrono::system_clock::now();
|
||||||
|
|
||||||
|
op.execute({&xf, &yf}, {&zf});
|
||||||
|
|
||||||
|
auto timeEndF = std::chrono::system_clock::now();
|
||||||
|
auto outerTimeF = std::chrono::duration_cast<std::chrono::nanoseconds>(timeEndF - timeStartF).count();
|
||||||
|
|
||||||
|
|
||||||
|
auto timeStartC = std::chrono::system_clock::now();
|
||||||
|
|
||||||
|
op.execute({&xc, &yc}, {&zc});
|
||||||
|
|
||||||
|
auto timeEndC = std::chrono::system_clock::now();
|
||||||
|
auto outerTimeC = std::chrono::duration_cast<std::chrono::nanoseconds>(timeEndC - timeStartC).count();
|
||||||
|
|
||||||
|
valuesF.emplace_back(outerTimeF);
|
||||||
|
valuesC.emplace_back(outerTimeC);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::sort(valuesC.begin(), valuesC.end());
|
||||||
|
std::sort(valuesF.begin(), valuesF.end());
|
||||||
|
|
||||||
|
|
||||||
|
nd4j_printf("Median time C: [%lld]; Median time F: [%lld];", valuesC[valuesC.size() / 2], valuesF[valuesF.size() / 2]);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(PerformanceTests, test_maxpooling2d_1) {
|
||||||
|
std::vector<Nd4jLong> valuesX;
|
||||||
|
// auto x = NDArrayFactory::create<float>('c', {32, 3, 224, 224});
|
||||||
|
// auto z = NDArrayFactory::create<float>('c', {32, 3, 224, 224});
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {8, 3, 64, 64});
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {8, 3, 64, 64});
|
||||||
|
x.linspace(1.0f);
|
||||||
|
Nd4jLong k = 5;
|
||||||
|
|
||||||
|
|
||||||
|
Nd4jLong iArgs[] {k,k, 1,1, 0,0, 1,1, 1};
|
||||||
|
Context ctx(1);
|
||||||
|
ctx.setInputArray(0, &x);
|
||||||
|
ctx.setOutputArray(0, &z);
|
||||||
|
ctx.setIArguments(iArgs, 9);
|
||||||
|
|
||||||
|
sd::ops::maxpool2d op;
|
||||||
|
|
||||||
|
for (int i = 0; i < numIterations; i++) {
|
||||||
|
auto timeStart = std::chrono::system_clock::now();
|
||||||
|
|
||||||
|
op.execute(&ctx);
|
||||||
|
|
||||||
|
auto timeEnd = std::chrono::system_clock::now();
|
||||||
|
auto outerTime = std::chrono::duration_cast<std::chrono::nanoseconds>(timeEnd - timeStart).count();
|
||||||
|
valuesX.emplace_back(outerTime);
|
||||||
|
|
||||||
|
if ((i + 1) % 1000 == 0)
|
||||||
|
nd4j_printf("Iteration %i finished...\n", i + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::sort(valuesX.begin(), valuesX.end());
|
||||||
|
nd4j_printf("Execution time: %lld; Min: %lld; Max: %lld;\n", valuesX[valuesX.size() / 2], valuesX[0], valuesX[valuesX.size() - 1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,94 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver110@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <graph/Graph.h>
|
||||||
|
#include <chrono>
|
||||||
|
#include <graph/Node.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <graph/profiling/GraphProfilingHelper.h>
|
||||||
|
#include <loops/type_conversions.h>
|
||||||
|
#include <helpers/threshold.h>
|
||||||
|
#include <helpers/MmulHelper.h>
|
||||||
|
#include <ops/ops.h>
|
||||||
|
#include <helpers/OmpLaunchHelper.h>
|
||||||
|
#include <helpers/GradCheck.h>
|
||||||
|
#include <ops/declarable/helpers/im2col.h>
|
||||||
|
#include <helpers/Loops.h>
|
||||||
|
#include <helpers/RandomLauncher.h>
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
|
||||||
|
#include <helpers/BenchmarkHelper.h>
|
||||||
|
#include <ops/declarable/helpers/scatter.h>
|
||||||
|
#include <helpers/ConstantShapeHelper.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <array>
|
||||||
|
#include <performance/benchmarking/FullBenchmarkSuit.h>
|
||||||
|
#include <performance/benchmarking/LightBenchmarkSuit.h>
|
||||||
|
#include <random>
|
||||||
|
#include <ops/declarable/helpers/legacy_helpers.h>
|
||||||
|
#include <ops/declarable/helpers/addBias.h>
|
||||||
|
#include <ops/declarable/helpers/axis.h>
|
||||||
|
#include <ops/declarable/helpers/reductions.h>
|
||||||
|
#include <helpers/LoopsCoordsHelper.h>
|
||||||
|
|
||||||
|
using namespace sd;
|
||||||
|
using namespace sd::graph;
|
||||||
|
|
||||||
|
class PrimitivesTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
PrimitivesTests() {
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(PrimitivesTests, test_mod_1) {
|
||||||
|
int ix = 7;
|
||||||
|
int iy = 3;
|
||||||
|
|
||||||
|
|
||||||
|
auto v = simdOps::Mod<int, int, int>::op(ix, iy);
|
||||||
|
|
||||||
|
ASSERT_EQ(7 % 3, v);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(PrimitivesTests, test_mod_2) {
|
||||||
|
float ix = 7.f;
|
||||||
|
float iy = 3.f;
|
||||||
|
|
||||||
|
|
||||||
|
auto e = sd::math::nd4j_fmod<float, float, float>(ix, iy);
|
||||||
|
auto v = simdOps::Mod<float, float, float>::op(ix, iy);
|
||||||
|
|
||||||
|
ASSERT_NEAR(e, v, 1e-5f);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(PrimitivesTests, test_mod_3) {
|
||||||
|
float ix = 7.f;
|
||||||
|
float iy = 0.f;
|
||||||
|
|
||||||
|
|
||||||
|
auto e = sd::math::nd4j_fmod<float, float, float>(ix, iy);
|
||||||
|
auto v = simdOps::Mod<float, float, float>::op(ix, iy);
|
||||||
|
|
||||||
|
// absence of SIGFPE will be a good enough
|
||||||
|
}
|
|
@ -0,0 +1,112 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <graph/GraphExecutioner.h>
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
|
using namespace sd::graph;
|
||||||
|
|
||||||
|
class ProtoBufTests : public testing::Test {
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(ProtoBufTests, TestBinaryLoad1) {
|
||||||
|
GOOGLE_PROTOBUF_VERIFY_VERSION;
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner<float>::importFromTensorFlow("../../../tests/resources/tensorflow_inception_graph.pb");
|
||||||
|
|
||||||
|
ASSERT_FALSE(graph == nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ProtoBufTests, TestTextLoad1) {
|
||||||
|
GOOGLE_PROTOBUF_VERIFY_VERSION;
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner<float>::importFromTensorFlow("../../../tests/resources/max_graph.pb.txt");
|
||||||
|
|
||||||
|
ASSERT_FALSE(graph == nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(ProtoBufTests, TestTextLoad2) {
|
||||||
|
GOOGLE_PROTOBUF_VERIFY_VERSION;
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner<float>::importFromTensorFlow("../../../tests/resources/max_add_2.pb.txt");
|
||||||
|
|
||||||
|
ASSERT_FALSE(graph == nullptr);
|
||||||
|
|
||||||
|
ASSERT_EQ(2, graph->getVariableSpace()->externalEntries());
|
||||||
|
|
||||||
|
auto var0 = graph->getVariableSpace()->getVariable(new std::string("zeros"));
|
||||||
|
auto var1 = graph->getVariableSpace()->getVariable(new std::string("ones"));
|
||||||
|
|
||||||
|
|
||||||
|
// first we're veryfying variable states
|
||||||
|
ASSERT_TRUE(var0 != nullptr);
|
||||||
|
ASSERT_TRUE(var1 != nullptr);
|
||||||
|
|
||||||
|
ASSERT_TRUE(var0->getNDArray() != nullptr);
|
||||||
|
ASSERT_TRUE(var1->getNDArray() != nullptr);
|
||||||
|
|
||||||
|
ASSERT_EQ(12, var0->getNDArray()->lengthOf());
|
||||||
|
ASSERT_EQ(12, var1->getNDArray()->lengthOf());
|
||||||
|
|
||||||
|
ASSERT_NEAR(0.0f, var0->getNDArray()->reduceNumber<simdOps::Sum<float>>(), 1e-5);
|
||||||
|
ASSERT_NEAR(12.0f, var1->getNDArray()->reduceNumber<simdOps::Sum<float>>(), 1e-5);
|
||||||
|
ASSERT_NEAR(1.0f, var1->getNDArray()->reduceNumber<simdOps::Mean<float>>(), 1e-5);
|
||||||
|
|
||||||
|
|
||||||
|
// now we're veryfying op graph
|
||||||
|
ASSERT_EQ(1, graph->totalNodes());
|
||||||
|
|
||||||
|
GraphExecutioner<float>::execute(graph);
|
||||||
|
|
||||||
|
ASSERT_NEAR(12.0f, var0->getNDArray()->reduceNumber<simdOps::Sum<float>>(), 1e-5);
|
||||||
|
ASSERT_NEAR(1.0f, var0->getNDArray()->reduceNumber<simdOps::Mean<float>>(), 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(ProtoBufTests, TestTextLoad3) {
|
||||||
|
GOOGLE_PROTOBUF_VERIFY_VERSION;
|
||||||
|
|
||||||
|
auto graph = GraphExecutioner<float>::importFromTensorFlow("../../../tests/resources/max_multiply.pb.txt");
|
||||||
|
|
||||||
|
ASSERT_FALSE(graph == nullptr);
|
||||||
|
|
||||||
|
ASSERT_EQ(2, graph->getVariableSpace()->externalEntries());
|
||||||
|
|
||||||
|
auto var0 = graph->getVariableSpace()->getVariable(new std::string("Placeholder"));
|
||||||
|
auto var1 = graph->getVariableSpace()->getVariable(new std::string("Placeholder_1"));
|
||||||
|
|
||||||
|
ASSERT_TRUE(var0 != nullptr);
|
||||||
|
ASSERT_TRUE(var1 != nullptr);
|
||||||
|
|
||||||
|
// we expect both variables to be set to null here
|
||||||
|
ASSERT_TRUE(var0->getNDArray() == nullptr);
|
||||||
|
ASSERT_TRUE(var1->getNDArray() == nullptr);
|
||||||
|
|
||||||
|
// now we're veryfying op graph
|
||||||
|
ASSERT_EQ(1, graph->totalNodes());
|
||||||
|
}
|
||||||
|
*/
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue