2019-06-06 14:21:15 +02:00
|
|
|
/*******************************************************************************
|
|
|
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
|
|
*
|
|
|
|
* This program and the accompanying materials are made available under the
|
|
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
|
|
*
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
|
|
* License for the specific language governing permissions and limitations
|
|
|
|
* under the License.
|
|
|
|
*
|
|
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
|
|
******************************************************************************/
|
|
|
|
|
|
|
|
//
|
|
|
|
// Created by agibsonccc on 2/21/16.
|
|
|
|
//
|
|
|
|
|
|
|
|
#define __STDC_CONSTANT_MACROS
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
#include <legacy/NativeOps.h>
|
|
|
|
#include "legacy/NativeOpExecutioner.h"
|
|
|
|
#include <array/NDArray.h>
|
|
|
|
#include <graph/GraphExecutioner.h>
|
2019-06-06 14:21:15 +02:00
|
|
|
#include <graph/GraphHolder.h>
|
2020-03-02 10:49:41 +01:00
|
|
|
#include <math/templatemath.h>
|
2019-06-06 14:21:15 +02:00
|
|
|
#include <types/float8.h>
|
|
|
|
#include <loops/type_conversions.h>
|
|
|
|
#include <helpers/helper_ptrmap.h>
|
|
|
|
#include <helpers/logger.h>
|
2020-03-02 10:49:41 +01:00
|
|
|
#include <system/pointercast.h>
|
|
|
|
#include <system/pairwise_util.h>
|
2019-06-06 14:21:15 +02:00
|
|
|
#include <types/types.h>
|
|
|
|
#include <ops/declarable/helpers/transforms.h>
|
|
|
|
#include <exceptions/allocation_exception.h>
|
2019-11-13 15:04:59 +01:00
|
|
|
#include <helpers/BlasHelper.h>
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
|
|
|
|
#include <fcntl.h>
|
|
|
|
#include <stdio.h>
|
|
|
|
#include <stdlib.h>
|
|
|
|
#ifndef _WIN32
|
|
|
|
#include <unistd.h>
|
|
|
|
#include <sys/mman.h>
|
|
|
|
#else
|
|
|
|
#include <io.h>
|
|
|
|
#include <helpers/mman.h>
|
|
|
|
#endif
|
|
|
|
#include <sys/types.h>
|
|
|
|
|
|
|
|
#include <ops/declarable/CustomOperations.h>
|
|
|
|
#include <errno.h>
|
|
|
|
|
|
|
|
|
|
|
|
// Global name pointer. Zero-initialized (null); it is not referenced in this
// part of the file, so whoever sets it lives elsewhere.
char *name = nullptr;

// Companion flag for `name`; starts out false.
bool nameSet = false;

// True only when this translation unit is compiled with __ND4J_EXPERIMENTAL__.
bool experimentalSupport =
#ifdef __ND4J_EXPERIMENTAL__
        true;
#else
        false;
#endif
|
|
|
|
|
|
|
|
#include <ops/specials.h>
|
2020-03-02 10:49:41 +01:00
|
|
|
#include <system/Environment.h>
|
|
|
|
#include <helpers/TAD.h>
|
2019-06-06 14:21:15 +02:00
|
|
|
#include <ops/declarable/OpRegistrator.h>
|
|
|
|
#include <graph/Context.h>
|
|
|
|
#include <graph/ResultWrapper.h>
|
|
|
|
#include <helpers/DebugHelper.h>
|
|
|
|
#include <helpers/ConstantTadHelper.h>
|
2019-07-12 07:21:15 +02:00
|
|
|
#include <performance/benchmarking/BenchmarkSuit.h>
|
|
|
|
#include <performance/benchmarking/FullBenchmarkSuit.h>
|
|
|
|
#include <performance/benchmarking/LightBenchmarkSuit.h>
|
2019-11-13 15:04:59 +01:00
|
|
|
#include <execution/Threads.h>
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-09-11 20:50:28 +02:00
|
|
|
#ifdef CPU_FEATURES
|
|
|
|
#include <cpuinfo_x86.h>
|
|
|
|
#endif
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
using namespace sd;
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void setElementThreshold(int num) {
|
2019-06-06 14:21:15 +02:00
|
|
|
if (num > 0)
|
2020-06-06 14:26:55 +02:00
|
|
|
sd::Environment::getInstance().setElementwiseThreshold(num);
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void setTADThreshold(int num) {
|
2019-06-06 14:21:15 +02:00
|
|
|
if (num > 0)
|
2020-06-06 14:26:55 +02:00
|
|
|
sd::Environment::getInstance().setTadThreshold(num);
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
*
|
|
|
|
* @param opNum
|
|
|
|
* @param hX
|
|
|
|
* @param hXShapeInfo
|
|
|
|
* @param extraParams
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
void execIndexReduceScalar(Nd4jPointer *extraPointers,
|
2019-06-06 14:21:15 +02:00
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-01-04 11:27:50 +01:00
|
|
|
NativeOpExecutioner::execIndexReduceScalar(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo);
|
2019-08-26 18:57:51 +02:00
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
*
|
|
|
|
* @param opNum
|
|
|
|
* @param hX
|
|
|
|
* @param hXShapeInfo
|
|
|
|
* @param extraParams
|
|
|
|
* @param hZ
|
|
|
|
* @param hZShapeInfo
|
|
|
|
* @param dimension
|
|
|
|
* @param dimensionLength
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
void execIndexReduce(Nd4jPointer *extraPointers,int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-01-04 11:27:50 +01:00
|
|
|
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
2019-08-26 18:57:51 +02:00
|
|
|
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-06-06 14:26:55 +02:00
|
|
|
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension,
|
2019-08-26 18:57:51 +02:00
|
|
|
dimensionLength);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
auto hTADShapeInfo = tadPack.primaryShapeInfo();
|
|
|
|
auto hTADOffsets = tadPack.primaryOffsets();
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-01-04 11:27:50 +01:00
|
|
|
auto hz = reinterpret_cast<Nd4jLong *>(dbZ->primary());
|
2019-08-26 18:57:51 +02:00
|
|
|
|
|
|
|
NativeOpExecutioner::execIndexReduce(nullptr, opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
|
|
|
extraParams,
|
|
|
|
hz,
|
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo,
|
|
|
|
dimension,
|
|
|
|
dimensionLength,
|
|
|
|
hTADShapeInfo,
|
|
|
|
hTADOffsets);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
*
|
|
|
|
* @param opNum
|
|
|
|
* @param hX
|
|
|
|
* @param hXShapeInfo
|
|
|
|
* @param hY
|
|
|
|
* @param hYShapeInfo
|
|
|
|
* @param hZ
|
|
|
|
* @param hZShapeInfo
|
|
|
|
* @param dimension
|
|
|
|
* @param dimensionLength
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
void execBroadcast(Nd4jPointer *extraPointers,
|
2019-06-06 14:21:15 +02:00
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
|
2020-01-04 11:27:50 +01:00
|
|
|
try {
|
|
|
|
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
2020-05-09 07:06:14 +02:00
|
|
|
auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
2019-08-26 18:57:51 +02:00
|
|
|
|
2020-06-06 14:26:55 +02:00
|
|
|
auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
|
|
|
auto tadPackZ = sd::ConstantTadHelper::getInstance().tadForDimensions(hZShapeInfo, dimension, dimensionLength);
|
2019-08-26 18:57:51 +02:00
|
|
|
|
|
|
|
auto hTADShapeInfo = tadPackX.primaryShapeInfo();
|
|
|
|
auto hTADOffsets = tadPackX.primaryOffsets();
|
|
|
|
auto hTADShapeInfoZ = tadPackZ.primaryShapeInfo();
|
|
|
|
auto hTADOffsetsZ = tadPackZ.primaryOffsets();
|
|
|
|
|
|
|
|
NativeOpExecutioner::execBroadcast(nullptr,
|
|
|
|
opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbY->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hYShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbY->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dYShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(), hZShapeInfo,
|
|
|
|
dbZ->special(), dZShapeInfo,
|
2019-08-26 18:57:51 +02:00
|
|
|
dimension,
|
|
|
|
dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ, hTADOffsetsZ);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void execBroadcastBool(Nd4jPointer *extraPointers,
|
2019-06-06 14:21:15 +02:00
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
2019-11-21 13:43:03 +01:00
|
|
|
void *extraParams,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-01-04 11:27:50 +01:00
|
|
|
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
2020-05-09 07:06:14 +02:00
|
|
|
auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
2019-08-26 18:57:51 +02:00
|
|
|
|
2020-06-06 14:26:55 +02:00
|
|
|
auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
|
|
|
auto tadPackZ = sd::ConstantTadHelper::getInstance().tadForDimensions(hZShapeInfo, dimension, dimensionLength);
|
2019-08-26 18:57:51 +02:00
|
|
|
|
|
|
|
auto hTADShapeInfo = tadPackX.primaryShapeInfo();
|
|
|
|
auto hTADOffsets = tadPackX.primaryOffsets();
|
|
|
|
auto hTADShapeInfoZ = tadPackZ.primaryShapeInfo();
|
|
|
|
auto hTADOffsetsZ = tadPackZ.primaryOffsets();
|
|
|
|
|
|
|
|
NativeOpExecutioner::execBroadcastBool(nullptr,
|
|
|
|
opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbY->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hYShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbY->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dYShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(), hZShapeInfo,
|
|
|
|
dbZ->special(), dZShapeInfo,
|
2019-11-21 13:43:03 +01:00
|
|
|
extraParams,
|
2019-08-26 18:57:51 +02:00
|
|
|
dimension,
|
|
|
|
dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ,
|
|
|
|
hTADOffsetsZ);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
*
|
|
|
|
* @param opNum
|
|
|
|
* @param hX
|
|
|
|
* @param hXShapeInfo
|
|
|
|
* @param hY
|
|
|
|
* @param hYShapeInfo
|
|
|
|
* @param hZ
|
|
|
|
* @param hZShapeInfo
|
|
|
|
* @param extraParams
|
|
|
|
* @param n
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
void execPairwiseTransform(
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jPointer *extraPointers,
|
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
|
|
|
NativeOpExecutioner::execPairwiseTransform(nullptr,
|
|
|
|
opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbY->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hYShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbY->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dYShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo,
|
|
|
|
extraParams);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void execPairwiseTransformBool(
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jPointer *extraPointers,
|
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams) {
|
2019-08-26 18:57:51 +02:00
|
|
|
|
|
|
|
try {
|
|
|
|
NativeOpExecutioner::execPairwiseBoolTransform(nullptr,
|
|
|
|
opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbY->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hYShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbY->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dYShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo,
|
|
|
|
extraParams);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
*
|
|
|
|
* @param opNum
|
|
|
|
* @param hX
|
|
|
|
* @param hXShapeInfo
|
|
|
|
* @param extraParams
|
|
|
|
* @param hZ
|
|
|
|
* @param hZShapeInfo
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
void execReduceFloat(
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jPointer *extraPointers,
|
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) {
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
|
|
|
NativeOpExecutioner::execReduceFloatScalar(nullptr,
|
|
|
|
opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
|
|
|
extraParams,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void execReduceSame(
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jPointer *extraPointers,
|
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) {
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
|
|
|
NativeOpExecutioner::execReduceSameScalar(nullptr,
|
|
|
|
opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
|
|
|
extraParams,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void execReduceBool(
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jPointer *extraPointers,
|
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
|
|
|
NativeOpExecutioner::execReduceBoolScalar(nullptr,
|
|
|
|
opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
|
|
|
extraParams,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void execReduceLong(
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jPointer *extraPointers,
|
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
|
|
|
NativeOpExecutioner::execReduceLongScalar(nullptr,
|
|
|
|
opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
|
|
|
extraParams,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
*
|
|
|
|
* @param opNum
|
|
|
|
* @param hX
|
|
|
|
* @param hXShapeInfo
|
|
|
|
* @param extraParams
|
|
|
|
* @param hZ
|
|
|
|
* @param hZShapeInfo
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
void execReduceFloat2(Nd4jPointer *extraPointers,
|
2019-06-06 14:21:15 +02:00
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-01-04 11:27:50 +01:00
|
|
|
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
2020-05-09 07:06:14 +02:00
|
|
|
auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
2019-08-26 18:57:51 +02:00
|
|
|
|
2020-06-06 14:26:55 +02:00
|
|
|
auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
2019-08-26 18:57:51 +02:00
|
|
|
|
|
|
|
auto hTADShapeInfo = tadPackX.primaryShapeInfo();
|
|
|
|
auto hTADOffsets = tadPackX.primaryOffsets();
|
|
|
|
|
|
|
|
NativeOpExecutioner::execReduceFloat(nullptr, opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
|
|
|
extraParams,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo,
|
|
|
|
dimension,
|
|
|
|
dimensionLength,
|
|
|
|
hTADShapeInfo,
|
|
|
|
hTADOffsets);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void execReduceBool2(Nd4jPointer *extraPointers,
|
2019-06-06 14:21:15 +02:00
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-01-04 11:27:50 +01:00
|
|
|
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
2020-05-09 07:06:14 +02:00
|
|
|
auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-06-06 14:26:55 +02:00
|
|
|
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension,
|
2019-08-26 18:57:51 +02:00
|
|
|
dimensionLength);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
auto hTADShapeInfo = tadPack.primaryShapeInfo();
|
|
|
|
auto hTADOffsets = tadPack.primaryOffsets();
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
NativeOpExecutioner::execReduceBool(nullptr, opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
|
|
|
extraParams,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo,
|
|
|
|
dimension,
|
|
|
|
dimensionLength,
|
|
|
|
hTADShapeInfo,
|
|
|
|
hTADOffsets);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void execReduceSame2(Nd4jPointer *extraPointers,
|
2019-06-06 14:21:15 +02:00
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-01-04 11:27:50 +01:00
|
|
|
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
2019-08-26 18:57:51 +02:00
|
|
|
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-06-06 14:26:55 +02:00
|
|
|
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension,
|
2019-08-26 18:57:51 +02:00
|
|
|
dimensionLength);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
auto hTADShapeInfo = tadPack.primaryShapeInfo();
|
|
|
|
auto hTADOffsets = tadPack.primaryOffsets();
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
NativeOpExecutioner::execReduceSame(nullptr, opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
|
|
|
extraParams,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo,
|
|
|
|
dimension,
|
|
|
|
dimensionLength,
|
|
|
|
hTADShapeInfo,
|
|
|
|
hTADOffsets);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void execReduceLong2(Nd4jPointer *extraPointers,
|
2019-06-06 14:21:15 +02:00
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-01-04 11:27:50 +01:00
|
|
|
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
2019-08-26 18:57:51 +02:00
|
|
|
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-06-06 14:26:55 +02:00
|
|
|
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
auto hTADShapeInfo = tadPack.primaryShapeInfo();
|
|
|
|
auto hTADOffsets = tadPack.primaryOffsets();
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
NativeOpExecutioner::execReduceLong(nullptr, opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
|
|
|
extraParams,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo,
|
|
|
|
dimension,
|
|
|
|
dimensionLength,
|
|
|
|
hTADShapeInfo,
|
|
|
|
hTADOffsets);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
*
|
|
|
|
* @param opNum
|
|
|
|
* @param hX
|
|
|
|
* @param hXShapeInfo
|
|
|
|
* @param extraParamsVals
|
|
|
|
* @param hY
|
|
|
|
* @param hYShapeInfo
|
|
|
|
* @param hZ
|
|
|
|
* @param hZShapeInfo
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
void execReduce3(Nd4jPointer *extraPointers,
|
2019-06-06 14:21:15 +02:00
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-01-04 11:27:50 +01:00
|
|
|
NativeOpExecutioner::execReduce3(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbY->primary(), hYShapeInfo,
|
|
|
|
dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo);
|
2019-08-26 18:57:51 +02:00
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
*
|
|
|
|
* @param opNum
|
|
|
|
* @param hX
|
|
|
|
* @param hXShapeInfo
|
|
|
|
* @param extraParamsVals
|
|
|
|
* @param hY
|
|
|
|
* @param hYShapeInfo
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-01-04 11:27:50 +01:00
|
|
|
NativeOpExecutioner::execReduce3Scalar(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbY->primary(),
|
|
|
|
hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo);
|
2019-08-26 18:57:51 +02:00
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
/**
|
|
|
|
*
|
|
|
|
* @param opNum
|
|
|
|
* @param hX
|
|
|
|
* @param hXShapeInfo
|
|
|
|
* @param extraParamsVals
|
|
|
|
* @param hY
|
|
|
|
* @param hYShapeInfo
|
|
|
|
* @param hZ
|
|
|
|
* @param hZShapeInfo
|
|
|
|
* @param dimension
|
|
|
|
* @param dimensionLength
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
void execReduce3Tad(Nd4jPointer *extraPointers,
|
2019-06-06 14:21:15 +02:00
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape,
|
|
|
|
const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets,
|
|
|
|
const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-01-04 11:27:50 +01:00
|
|
|
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
2020-05-09 07:06:14 +02:00
|
|
|
auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
2019-08-26 18:57:51 +02:00
|
|
|
|
|
|
|
if (extraPointers == nullptr || extraPointers[2] == 0) {
|
2020-01-04 11:27:50 +01:00
|
|
|
NativeOpExecutioner::execReduce3(LaunchContext::defaultContext(), opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo,
|
|
|
|
extraParams, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets,
|
|
|
|
yTadOnlyShapeInfo, yTadOffsets);
|
|
|
|
} else {
|
|
|
|
// going tad-way
|
2020-06-06 14:26:55 +02:00
|
|
|
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension,
|
2019-08-26 18:57:51 +02:00
|
|
|
dimensionLength);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
auto hTADShapeInfo = tadPack.primaryShapeInfo();
|
|
|
|
auto hTADOffsets = tadPack.primaryOffsets();
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-01-04 11:27:50 +01:00
|
|
|
NativeOpExecutioner::execReduce3TAD(LaunchContext::defaultContext(), opNum, dbX->primary(), hXShapeInfo, dbX->special(),
|
|
|
|
dXShapeInfo, extraParams, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(),
|
|
|
|
hZShapeInfo, dbZ->special(), dZShapeInfo, dimension, dimensionLength, hTADShapeInfo,
|
2019-08-26 18:57:51 +02:00
|
|
|
hTADOffsets, nullptr, nullptr);
|
|
|
|
}
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2019-09-04 13:41:08 +02:00
|
|
|
/**
 * Version-compatibility query; by its name, for the BLAS library in use.
 *
 * This implementation performs no comparison: major, minor and build are ignored and
 * every version is accepted.
 *
 * @return always true
 */
bool isBlasVersionMatches(int major, int minor, int build) {
    const bool anyVersionAccepted = true;
    return anyVersionAccepted;
}
|
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
/**
|
|
|
|
*
|
|
|
|
* @param opNum
|
|
|
|
* @param hX
|
|
|
|
* @param hXShapeInfo
|
|
|
|
* @param hZ
|
|
|
|
* @param hZShapeInfo
|
|
|
|
* @param hScalar
|
|
|
|
* @param extraParams
|
|
|
|
* @param n
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
void execScalar(
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jPointer *extraPointers,
|
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbScalar, const Nd4jLong *hScalarShapeInfo, const Nd4jLong *dScalarShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
|
|
|
NativeOpExecutioner::execScalar(nullptr,
|
|
|
|
opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-06-06 14:21:15 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-06-06 14:21:15 +02:00
|
|
|
dXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(),
|
2019-06-06 14:21:15 +02:00
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-06-06 14:21:15 +02:00
|
|
|
dZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbScalar->primary(),
|
2019-06-06 14:21:15 +02:00
|
|
|
hScalarShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbScalar->special(),
|
2019-06-06 14:21:15 +02:00
|
|
|
dScalarShapeInfo,
|
2019-08-26 18:57:51 +02:00
|
|
|
extraParams);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
void execScalarBool(
|
|
|
|
Nd4jPointer *extraPointers,
|
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbScalar, const Nd4jLong *hScalarShapeInfo, const Nd4jLong *dScalarShapeInfo,
|
2019-08-26 18:57:51 +02:00
|
|
|
void *extraParams) {
|
|
|
|
try {
|
|
|
|
NativeOpExecutioner::execScalarBool(nullptr,
|
|
|
|
opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbScalar->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hScalarShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbScalar->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dScalarShapeInfo,
|
|
|
|
extraParams);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
*
|
|
|
|
* @param opNum
|
|
|
|
* @param hX
|
|
|
|
* @param hXShapeInfo
|
|
|
|
* @param extraParams
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
void execSummaryStatsScalar(Nd4jPointer *extraPointers,
|
2019-06-06 14:21:15 +02:00
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
bool biasCorrected) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
|
|
|
NativeOpExecutioner::execSummaryStatsScalar(nullptr,
|
|
|
|
opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
|
|
|
extraParams,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo,
|
|
|
|
biasCorrected);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
/**
|
|
|
|
*
|
|
|
|
* @param opNum
|
|
|
|
* @param hX
|
|
|
|
* @param hXShapeInfo
|
|
|
|
* @param extraParams
|
|
|
|
* @param hZ
|
|
|
|
* @param hZShapeInfo
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
void execSummaryStats(Nd4jPointer *extraPointers,
|
2019-06-06 14:21:15 +02:00
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
bool biasCorrected) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
|
|
|
NativeOpExecutioner::execSummaryStats(nullptr,
|
|
|
|
opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
|
|
|
extraParams,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo,
|
|
|
|
biasCorrected);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
/**
|
|
|
|
*
|
|
|
|
* @param opNum
|
|
|
|
* @param hX
|
|
|
|
* @param hXShapeInfo
|
|
|
|
* @param extraParams
|
|
|
|
* @param hZ
|
|
|
|
* @param hZShapeInfo
|
|
|
|
* @param dimension
|
|
|
|
* @param dimensionLength
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
void execSummaryStatsTad(Nd4jPointer *extraPointers,
|
2019-06-06 14:21:15 +02:00
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape,
|
2019-06-06 14:21:15 +02:00
|
|
|
bool biasCorrected,
|
2020-05-09 07:06:14 +02:00
|
|
|
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-01-04 11:27:50 +01:00
|
|
|
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
2019-08-26 18:57:51 +02:00
|
|
|
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
NativeOpExecutioner::execSummaryStats(nullptr,
|
|
|
|
opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
|
|
|
extraParams,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo,
|
|
|
|
dimension,
|
|
|
|
dimensionLength,
|
|
|
|
tadShapeInfo,
|
|
|
|
tadOffsets,
|
|
|
|
biasCorrected);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
*
|
|
|
|
* @param opNum
|
|
|
|
* @param hX
|
|
|
|
* @param hXShapeInfo
|
|
|
|
* @param hZ
|
|
|
|
* @param hZShapeInfo
|
|
|
|
* @param extraParams
|
|
|
|
* @param n
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
void execTransformFloat(
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jPointer *extraPointers,
|
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
|
|
|
NativeOpExecutioner::execTransformFloat(nullptr,
|
|
|
|
opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo,
|
|
|
|
extraParams,
|
|
|
|
nullptr,
|
|
|
|
nullptr);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void execTransformSame(
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jPointer *extraPointers,
|
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
|
|
|
NativeOpExecutioner::execTransformSame(nullptr,
|
|
|
|
opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo,
|
|
|
|
extraParams,
|
|
|
|
nullptr,
|
|
|
|
nullptr);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void execTransformBool(
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jPointer *extraPointers,
|
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
|
|
|
NativeOpExecutioner::execTransformBool(nullptr,
|
|
|
|
opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo,
|
|
|
|
extraParams,
|
|
|
|
nullptr,
|
|
|
|
nullptr);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void execTransformAny(
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jPointer *extraPointers,
|
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
|
|
|
NativeOpExecutioner::execTransformAny(nullptr,
|
|
|
|
opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo,
|
|
|
|
extraParams,
|
|
|
|
nullptr,
|
|
|
|
nullptr);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void execTransformStrict(
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jPointer *extraPointers,
|
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
|
|
|
NativeOpExecutioner::execTransformStrict(nullptr,
|
|
|
|
opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo,
|
|
|
|
extraParams,
|
|
|
|
nullptr,
|
|
|
|
nullptr);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void execReduce3All(Nd4jPointer *extraPointers,
|
2019-06-06 14:21:15 +02:00
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParamsVals,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape,
|
|
|
|
const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets,
|
|
|
|
const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets) {
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-01-04 11:27:50 +01:00
|
|
|
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
2020-05-09 07:06:14 +02:00
|
|
|
auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
|
2020-01-04 11:27:50 +01:00
|
|
|
NativeOpExecutioner::execReduce3All(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParamsVals, dbY->primary(),
|
|
|
|
hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, dimension,
|
2019-08-26 18:57:51 +02:00
|
|
|
dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* Concatneate multi array of the same shape together
|
|
|
|
* along a particular dimension
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
void specialConcat(
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jPointer *extraPointers,
|
|
|
|
int dimension,
|
|
|
|
int numArrays,
|
|
|
|
Nd4jPointer *data,
|
|
|
|
Nd4jPointer *inputShapeInfo,
|
|
|
|
void *hZ,
|
2020-05-09 07:06:14 +02:00
|
|
|
Nd4jLong const* hZShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jPointer *tadPointers,
|
|
|
|
Nd4jPointer *offsetPointers) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-03-02 10:49:41 +01:00
|
|
|
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
BUILD_SINGLE_SELECTOR(zType, sd::SpecialMethods,::concatCpuGeneric(dimension, numArrays, data, inputShapeInfo, hZ, hZShapeInfo), LIBND4J_TYPES);
|
2019-08-26 18:57:51 +02:00
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
/**
 * Placeholder kept only for JNI compatibility: the Java side expects this method to
 * exist, but on this (CPU) implementation there is nothing to initialize.
 */
void initializeDevicesAndFunctions() {
    // intentionally a no-op
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Passes the given function-pointer table to the sd::BlasHelper singleton.
 *
 * @param functions table forwarded unchanged to BlasHelper::initializeFunctions
 */
void initializeFunctions(Nd4jPointer *functions) {
    // all of the actual work lives in BlasHelper
    sd::BlasHelper::getInstance().initializeFunctions(functions);
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* This method acquires memory chunk of requested size on host side
|
|
|
|
*
|
|
|
|
* @param pointer pointer that'll be used for allocation
|
|
|
|
* @param memorySize memory size, in bytes
|
|
|
|
* @param flags optional parameter
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) {
|
2019-11-13 15:04:59 +01:00
|
|
|
return reinterpret_cast<Nd4jPointer>(new int8_t[memorySize]);
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* This method acquires memory chunk of requested size on specified device
|
|
|
|
*
|
|
|
|
* PLEASE NOTE: This method is NOT supported and has NO effect in CPU-based backend.
|
|
|
|
*
|
|
|
|
* @param pointer pointer that'll be used for allocation
|
|
|
|
* @param memorySize memory size, in bytes
|
|
|
|
* @param ptrToDeviceId pointer to deviceId. For cuda that's just and int, for OpenCL that's pointer to device_id, etc
|
|
|
|
* @param flags optional parameter
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) {
|
2019-06-06 14:21:15 +02:00
|
|
|
// not supported
|
|
|
|
return 0L;
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* This method releases previously allocated host memory space
|
|
|
|
*
|
|
|
|
* @param pointer pointer that'll be freed
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
int freeHost(Nd4jPointer pointer) {
|
2019-11-13 15:04:59 +01:00
|
|
|
delete[] reinterpret_cast<int8_t *>(pointer);
|
2019-06-06 14:21:15 +02:00
|
|
|
return 1L;
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* This method releases previously allocated memory space on device
|
|
|
|
*
|
|
|
|
* PLEASE NOTE: This method is NOT supported and has NO effect in CPU-based backend.
|
|
|
|
*
|
|
|
|
* @param pointer pointer that'll be freed
|
|
|
|
* @param ptrToDeviceId pointer to deviceId.
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
int freeDevice(Nd4jPointer pointer, int deviceId) {
|
2019-06-06 14:21:15 +02:00
|
|
|
// not supported
|
|
|
|
return 0L;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
* Returns the maximum number open mp threads
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
int ompGetMaxThreads() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return omp_get_max_threads();
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* Returns the number open mp threads
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
int ompGetNumThreads() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return omp_get_num_threads();
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* Sets the number of openmp threads
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
void setOmpNumThreads(int threads) {
|
2019-06-06 14:21:15 +02:00
|
|
|
omp_set_num_threads(threads);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Stub: this implementation creates no context.
 *
 * @return always a null pointer
 */
Nd4jPointer createContext() {
    return nullptr;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Stub: this implementation creates no stream.
 *
 * @return always a null pointer
 */
Nd4jPointer createStream() {
    return nullptr;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Stub: this implementation creates no event.
 *
 * @return always a null pointer
 */
Nd4jPointer createEvent() {
    return nullptr;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Stub: reports a major device version of 0 for any deviceId.
 */
int getDeviceMajor(int deviceId) {
    const int noMajorVersion = 0;
    return noMajorVersion;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Stub: reports a minor device version of 0 for any deviceId.
 */
int getDeviceMinor(int deviceId) {
    const int noMinorVersion = 0;
    return noMinorVersion;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Stub: nothing is registered; both arguments are ignored.
 *
 * @return always 0
 */
int registerEvent(Nd4jPointer event, Nd4jPointer stream) {
    return 0;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Stub: device selection is ignored by this implementation.
 *
 * @return always 0
 */
int setDevice(int deviceId) {
    return 0;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Stub: reports 0 bytes of free device memory for any deviceId.
 */
Nd4jLong getDeviceFreeMemory(int deviceId) {
    return 0;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Stub: reports 0 bytes of free memory for the default device.
 */
Nd4jLong getDeviceFreeMemoryDefault() {
    return 0;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Stub: reports 0 bytes of total device memory for any deviceId.
 */
Nd4jLong getDeviceTotalMemory(int deviceId) {
    return 0;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Stub: performs no copy; all arguments are ignored.
 *
 * @return always 0
 */
int memcpySync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) {
    return 0;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Stub: performs no copy; all arguments are ignored.
 *
 * @return always 0
 */
int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) {
    return 0;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Stub: writes nothing; all arguments are ignored.
 *
 * @return always 0
 */
int memsetSync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) {
    return 0;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Stub: writes nothing; all arguments are ignored.
 *
 * @return always 0
 */
int memsetAsync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) {
    return 0;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Stub: there is no event to destroy; the argument is ignored.
 *
 * @return always 0
 */
int destroyEvent(Nd4jPointer event) {
    return 0;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Stub: there is no stream to synchronize; the argument is ignored.
 *
 * @return always 0
 */
int streamSynchronize(Nd4jPointer stream) {
    return 0;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Stub: there is no event to wait on; the argument is ignored.
 *
 * @return always 0
 */
int eventSynchronize(Nd4jPointer event) {
    return 0;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Stub: reports zero available devices.
 */
int getAvailableDevices() {
    const int deviceCount = 0;
    return deviceCount;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void enableDebugMode(bool reallyEnable) {
|
2020-06-06 14:26:55 +02:00
|
|
|
sd::Environment::getInstance().setDebug(reallyEnable);
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void enableVerboseMode(bool reallyEnable) {
|
2020-06-06 14:26:55 +02:00
|
|
|
sd::Environment::getInstance().setVerbose(reallyEnable);
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * No-op in this implementation; gridSize is ignored.
 */
void setGridLimit(int gridSize) {
    // intentionally empty
}
|
|
|
|
|
2020-05-09 07:06:14 +02:00
|
|
|
/**
 * Builds a heap-allocated TadPack for the given shape and dimensions.
 *
 * The pack's contents come from ConstantTadHelper::tadForDimensions. The returned object
 * is allocated with `new`, so the caller takes ownership of it.
 *
 * Error handling (reported via the default LaunchContext's error reference, code 1 +
 * message; nothing is thrown out of this entry point):
 *  - if tadForDimensions throws, the freshly allocated, default-constructed pack is
 *    returned, which is the same behaviour as before;
 *  - if allocating the pack itself fails, nullptr is returned. Previously that exception
 *    escaped because `new` sat outside the try block.
 *
 * @param hXShapeInfo     host shape info of the source array
 * @param dimension       dimension indices
 * @param dimensionLength number of entries in dimension
 * @return owning pointer to the pack, or nullptr if allocation failed
 */
sd::TadPack* tadOnlyShapeInfo(Nd4jLong const* hXShapeInfo, int *dimension, int dimensionLength) {
    sd::TadPack *pack = nullptr;
    try {
        pack = new sd::TadPack();
        *pack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength);
    } catch (std::exception &e) {
        sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
        sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
    }

    return pack;
}
|
|
|
|
|
2020-05-09 07:06:14 +02:00
|
|
|
/**
 * Returns the primary (host-side) TAD shape info held by a TadPack.
 *
 * @param pack TadPack handle (e.g. one returned by tadOnlyShapeInfo()); dereferenced
 *             unconditionally, so it must not be null
 * @return the pack's primaryShapeInfo(), exposed read-only
 */
Nd4jLong const* getPrimaryShapeInfo(sd::TadPack* pack) {
    // No const_cast needed: the return type is pointer-to-const, so casting const away
    // only for it to be added back was pointless and obscured intent.
    return pack->primaryShapeInfo();
}
|
2020-05-09 07:06:14 +02:00
|
|
|
|
|
|
|
/**
 * Returns the primary (host-side) TAD offsets held by a TadPack.
 *
 * @param pack TadPack handle (e.g. one returned by tadOnlyShapeInfo()); dereferenced
 *             unconditionally, so it must not be null
 * @return the pack's primaryOffsets(), exposed read-only
 */
Nd4jLong const* getPrimaryOffsets(sd::TadPack* pack) {
    // No const_cast needed: the value is returned as pointer-to-const.
    return pack->primaryOffsets();
}
|
2020-05-09 07:06:14 +02:00
|
|
|
|
|
|
|
/**
 * Returns the special (device-side) TAD shape info held by a TadPack.
 *
 * @param pack TadPack handle (e.g. one returned by tadOnlyShapeInfo()); dereferenced
 *             unconditionally, so it must not be null
 * @return the pack's specialShapeInfo(), exposed read-only
 */
Nd4jLong const* getSpecialShapeInfo(sd::TadPack* pack) {
    // No const_cast needed: the value is returned as pointer-to-const.
    return pack->specialShapeInfo();
}
|
2020-05-09 07:06:14 +02:00
|
|
|
|
|
|
|
/**
 * Returns the special (device-side) TAD offsets held by a TadPack.
 *
 * @param pack TadPack handle (e.g. one returned by tadOnlyShapeInfo()); dereferenced
 *             unconditionally, so it must not be null
 * @return the pack's specialOffsets(), exposed read-only
 */
Nd4jLong const* getSpecialOffsets(sd::TadPack* pack) {
    // No const_cast needed: the value is returned as pointer-to-const.
    return pack->specialOffsets();
}
|
2020-05-09 07:06:14 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Returns how many TADs a TadPack describes.
 *
 * @param pack TadPack handle; dereferenced unconditionally, so it must not be null
 */
Nd4jLong getNumberOfTads(sd::TadPack* pack) {
    const auto tadCount = pack->numberOfTads();
    return tadCount;
}
|
2020-05-09 07:06:14 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Returns the length of the shape-info buffer held by a TadPack.
 *
 * @param pack TadPack handle; dereferenced unconditionally, so it must not be null
 */
int getShapeInfoLength(sd::TadPack* pack) {
    const auto infoLength = pack->shapeInfoLength();
    return infoLength;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * No-op in this implementation: nothing is copied and all arguments are ignored.
 *
 * @return always 0
 */
int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) {
    return 0;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * No-op in this implementation: there is no constant space to expose.
 *
 * @return always a null pointer
 */
Nd4jPointer getConstantSpace() {
    return nullptr;
}
|
|
|
|
|
|
|
|
template<typename T>
|
|
|
|
void pullRowsGeneric(void *vx,
|
2020-05-09 07:06:14 +02:00
|
|
|
Nd4jLong const* hXShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *vz,
|
2020-05-09 07:06:14 +02:00
|
|
|
Nd4jLong const* hZShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
const int n,
|
2020-05-09 07:06:14 +02:00
|
|
|
Nd4jLong const* indexes,
|
|
|
|
Nd4jLong const* tadShapeInfo,
|
|
|
|
Nd4jLong const* tadOffsets,
|
|
|
|
Nd4jLong const* zTadShapeInfo,
|
|
|
|
Nd4jLong const* zTadOffsets) {
|
2019-06-06 14:21:15 +02:00
|
|
|
auto hX = reinterpret_cast<T *>(vx);
|
|
|
|
auto hZ = reinterpret_cast<T *>(vz);
|
|
|
|
|
|
|
|
const auto xEWS = shape::elementWiseStride(tadShapeInfo);
|
|
|
|
const auto zEWS = shape::elementWiseStride(zTadShapeInfo);
|
|
|
|
const auto tadLength = shape::length(tadShapeInfo);
|
|
|
|
|
|
|
|
int elementsPerThread = n / TAD_THRESHOLD;
|
2020-03-02 10:49:41 +01:00
|
|
|
int _threads = sd::math::nd4j_max<int>(1, elementsPerThread);
|
2020-06-06 14:26:55 +02:00
|
|
|
_threads = sd::math::nd4j_min<int>(_threads, sd::Environment::getInstance().maxThreads());
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-11-13 15:04:59 +01:00
|
|
|
auto func = PRAGMA_THREADS_FOR {
|
2020-02-20 09:43:26 +01:00
|
|
|
for (auto idx = start; idx < stop; idx++) {
|
2019-11-13 15:04:59 +01:00
|
|
|
auto xTadOffsetForBlock = tadOffsets[indexes[idx]];
|
|
|
|
auto zTadOffsetForBlock = zTadOffsets[idx];
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-11-13 15:04:59 +01:00
|
|
|
auto rX = hX + xTadOffsetForBlock;
|
|
|
|
auto rZ = hZ + zTadOffsetForBlock;
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-11-13 15:04:59 +01:00
|
|
|
if (xEWS == 1 && zEWS == 1) {
|
|
|
|
PRAGMA_OMP_SIMD
|
2020-02-28 15:04:45 +01:00
|
|
|
for (Nd4jLong i = 0; i < tadLength; i++) {
|
2019-11-13 15:04:59 +01:00
|
|
|
rZ[i] = rX[i];
|
|
|
|
}
|
|
|
|
} else if (xEWS >= 1 && zEWS >= 1) {
|
|
|
|
PRAGMA_OMP_SIMD
|
2020-02-28 15:04:45 +01:00
|
|
|
for (Nd4jLong i = 0; i < tadLength; i++) {
|
2019-11-13 15:04:59 +01:00
|
|
|
rZ[i * zEWS] = rX[i * xEWS];
|
|
|
|
}
|
|
|
|
} else {
|
2020-02-28 15:04:45 +01:00
|
|
|
for (Nd4jLong i = 0; i < tadLength; i++) {
|
2019-11-13 15:04:59 +01:00
|
|
|
auto xOffset = xTadOffsetForBlock + shape::getIndexOffset(i, tadShapeInfo);
|
|
|
|
auto zOffset = zTadOffsetForBlock + shape::getIndexOffset(i, zTadShapeInfo);
|
|
|
|
hZ[zOffset] = hX[xOffset];
|
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
}
|
2019-11-13 15:04:59 +01:00
|
|
|
};
|
|
|
|
|
2020-03-09 06:22:49 +01:00
|
|
|
samediff::Threads::parallel_tad(func, 0, n, 1, _threads);
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void pullRows(Nd4jPointer *extraPointers,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jLong n,
|
2020-05-09 07:06:14 +02:00
|
|
|
Nd4jLong* indexes,
|
|
|
|
Nd4jLong const* tadShapeInfo,
|
|
|
|
Nd4jLong const* tadOffsets,
|
|
|
|
Nd4jLong const* zTadShapeInfo,
|
|
|
|
Nd4jLong const* zTadOffsets) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-03-02 10:49:41 +01:00
|
|
|
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-01-04 11:27:50 +01:00
|
|
|
BUILD_SINGLE_SELECTOR(xType, pullRowsGeneric, (dbX->primary(), hXShapeInfo, dbZ->primary(), hZShapeInfo, n, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets), LIBND4J_TYPES);
|
2019-08-26 18:57:51 +02:00
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
template<typename T>
|
|
|
|
void tearGeneric(void *vx,
|
2020-05-09 07:06:14 +02:00
|
|
|
Nd4jLong const* hXShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jPointer *targets,
|
2020-05-09 07:06:14 +02:00
|
|
|
Nd4jLong const* hZShapeInfo,
|
|
|
|
Nd4jLong const* tadShapeInfo,
|
|
|
|
Nd4jLong const* tadOffsets) {
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
auto hX = reinterpret_cast<T *>(vx);
|
|
|
|
|
|
|
|
const auto tadLength = shape::length(tadShapeInfo);
|
|
|
|
auto tadEWS = shape::elementWiseStride(tadShapeInfo);
|
|
|
|
auto zEWS = shape::elementWiseStride(hZShapeInfo);
|
|
|
|
auto numTads = shape::length(hXShapeInfo) / tadLength;
|
|
|
|
|
2019-11-13 15:04:59 +01:00
|
|
|
auto func = PRAGMA_THREADS_FOR {
|
2020-02-20 09:43:26 +01:00
|
|
|
for (auto i = start; i < stop; i++) {
|
2019-11-13 15:04:59 +01:00
|
|
|
auto hZ = reinterpret_cast<T *>(targets[i]);
|
|
|
|
auto s = hX + tadOffsets[i];
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-11-13 15:04:59 +01:00
|
|
|
if (zEWS == 1 && tadEWS == 1) {
|
|
|
|
PRAGMA_OMP_SIMD
|
|
|
|
for (Nd4jLong j = 0; j < tadLength; j++) {
|
|
|
|
hZ[j] = s[j];
|
|
|
|
}
|
|
|
|
} else if (zEWS > 0 && tadEWS > 0) {
|
|
|
|
PRAGMA_OMP_SIMD
|
|
|
|
for (Nd4jLong j = 0; j < tadLength; j++) {
|
|
|
|
hZ[j * zEWS] = s[j * tadEWS];
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
for (Nd4jLong j = 0; j < tadLength; j++)
|
|
|
|
hZ[shape::getIndexOffset(j, hZShapeInfo)] = s[shape::getIndexOffset(j, tadShapeInfo)];
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
}
|
2019-11-13 15:04:59 +01:00
|
|
|
};
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-09 06:22:49 +01:00
|
|
|
samediff::Threads::parallel_tad(func,0, numTads);
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void tear(Nd4jPointer *extraPointers,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jPointer *targets,
|
2020-05-09 07:06:14 +02:00
|
|
|
Nd4jLong const* hZShapeInfo,
|
|
|
|
Nd4jLong const* tadShapeInfo,
|
|
|
|
Nd4jLong const* tadOffsets) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-03-02 10:49:41 +01:00
|
|
|
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-01-04 11:27:50 +01:00
|
|
|
BUILD_SINGLE_SELECTOR(xType, tearGeneric, (dbX->primary(), hXShapeInfo, targets, hZShapeInfo, tadShapeInfo, tadOffsets), LIBND4J_TYPES);
|
2019-08-26 18:57:51 +02:00
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void average(Nd4jPointer *extras,
|
2020-05-09 07:06:14 +02:00
|
|
|
Nd4jPointer *hX, const Nd4jLong *hXShapeInfo,
|
|
|
|
Nd4jPointer *dX, const Nd4jLong *dXShapeInfo,
|
|
|
|
void *z, const Nd4jLong *hZShapeInfo,
|
|
|
|
void *dz, const Nd4jLong *dZShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
int n,
|
|
|
|
Nd4jLong length,
|
|
|
|
bool propagate) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-03-02 10:49:41 +01:00
|
|
|
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::averageGeneric(hX, z, hZShapeInfo, n, length, propagate), LIBND4J_TYPES);
|
2019-08-26 18:57:51 +02:00
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void accumulate(Nd4jPointer *extras,
|
2020-05-09 07:06:14 +02:00
|
|
|
Nd4jPointer *hX, Nd4jLong const* hXShapeInfo,
|
|
|
|
Nd4jPointer *dX, Nd4jLong const* dXShapeInfo,
|
|
|
|
void *hz, Nd4jLong const* hZShapeInfo,
|
|
|
|
void *dz, Nd4jLong const* dZShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
int n,
|
|
|
|
Nd4jLong length) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-03-02 10:49:41 +01:00
|
|
|
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::accumulateGeneric(hX, hz, hZShapeInfo, n, length), LIBND4J_TYPES);
|
2019-08-26 18:57:51 +02:00
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Peer-to-peer toggle. This implementation ignores `enable` and does nothing.
 */
void enableP2P(bool enable) {
    // intentionally left empty
}
|
|
|
|
|
2020-05-09 07:06:14 +02:00
|
|
|
|
|
|
|
|
|
|
|
/**
 * Threshold encoding, phase 1. Not implemented: the body is empty and all
 * arguments are ignored.
 */
void encodeThresholdP1(Nd4jPointer *extraPointers, void *hX, Nd4jLong const* hXShapeInfo, Nd4jLong N, int *dz, float threshold) {
    // TODO: to be implemented
}
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Threshold encoding, phase 2 (int variant). Not implemented: the body is
 * empty and all arguments are ignored.
 */
void encodeThresholdP2Int(Nd4jPointer *extraPointers, int *hX, Nd4jLong N, int *dz) {
    // TODO: to be implemented
}
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Threshold encoding, phase 3. Not implemented: the body is empty. `offsets`
 * in particular is not consulted by this backend.
 */
void encodeThresholdP3(Nd4jPointer *extraPointers, void *hX, Nd4jLong const* hXShapeInfo, int *offsets, Nd4jLong N, int *dz) {
    // TODO: to be implemented
}
|
|
|
|
|
|
|
|
/**
 * Threshold decoding. Not implemented: the body is empty and all arguments are
 * ignored.
 */
void decodeThreshold(Nd4jPointer *extraPointers, void *hX, Nd4jLong N, void *dz, const Nd4jLong *hZShapeInfo) {
    // TODO: to be implemented
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Peer-to-peer availability query.
 * @return always true on the CPU backend
 */
bool isP2PAvailable() {
    constexpr bool cpuBackendReportsP2P = true;
    return cpuBackendReportsP2P;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Peer-to-peer check. This implementation does nothing.
 */
void checkP2P() {
    // intentionally left empty
}
|
|
|
|
|
2020-05-09 07:06:14 +02:00
|
|
|
/**
 * C API entry: decodes the bitmap-encoded buffer hX into dz by delegating to
 * NativeOpExecutioner::decodeBitmap.
 *
 * Like the other entry points in this file, exceptions are caught here and
 * recorded in the default LaunchContext's error reference, so they never
 * unwind across the C API boundary.
 */
void decodeBitmap(Nd4jPointer *extraPointers, void *hX, Nd4jLong N, void *dz, Nd4jLong const* hZShapeInfo) {
    try {
        NativeOpExecutioner::decodeBitmap(hX, N, dz, hZShapeInfo);
    } catch (std::exception &e) {
        sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
        sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
    }
}
|
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
template<typename T>
|
2020-05-09 07:06:14 +02:00
|
|
|
void shuffleGeneric(void **hX, Nd4jLong * const*hXShapeInfo, void **dz, Nd4jLong * const* hZShapeInfo, int N, int *shuffleMap, Nd4jLong * const* tadOnlyShapeInfo, Nd4jLong * const* tadOffsets) {
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
auto dX = reinterpret_cast<T **>(hX);
|
|
|
|
auto dZ = reinterpret_cast<T **>(dz);
|
|
|
|
|
2019-11-13 15:04:59 +01:00
|
|
|
auto func = PRAGMA_THREADS_FOR {
|
2020-02-20 09:43:26 +01:00
|
|
|
for (auto f = start; f < stop; f++) {
|
2019-11-13 15:04:59 +01:00
|
|
|
auto hX = reinterpret_cast<T *>(dX[f]);
|
|
|
|
//auto hZ = reinterpret_cast<T *>(dZ[f]);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-11-13 15:04:59 +01:00
|
|
|
auto xShapeInfo = hXShapeInfo[f];
|
|
|
|
auto tadOffset = reinterpret_cast<Nd4jLong *>(tadOffsets[f]);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
|
2019-11-13 15:04:59 +01:00
|
|
|
const auto tadLength = shape::length(tadOnlyShapeInfo[f]);
|
|
|
|
auto tadEWS = shape::elementWiseStride(tadOnlyShapeInfo[f]);
|
|
|
|
auto tadRank = shape::rank(tadOnlyShapeInfo[f]);
|
|
|
|
auto numTads = shape::length(hXShapeInfo[f]) / tadLength;
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-11-13 15:04:59 +01:00
|
|
|
auto tadShape = shape::shapeOf(tadOnlyShapeInfo[f]);
|
|
|
|
auto tadStride = shape::stride(tadOnlyShapeInfo[f]);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-11-13 15:04:59 +01:00
|
|
|
if (shape::rank(xShapeInfo) == 1) {
|
|
|
|
auto xLength = shape::length(xShapeInfo);
|
|
|
|
auto ews = shape::elementWiseStride(xShapeInfo);
|
|
|
|
for (Nd4jLong r = 0; r < xLength; r++) {
|
|
|
|
auto swapIdx = shuffleMap[r];
|
|
|
|
if (swapIdx < 0)
|
|
|
|
continue;
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::math::nd4j_swap<T>(hX[r * ews], hX[swapIdx * ews]);
|
2019-11-13 15:04:59 +01:00
|
|
|
}
|
|
|
|
} else {
|
|
|
|
for (Nd4jLong r = 0; r < numTads; r++) {
|
|
|
|
if (shuffleMap[r] < 0)
|
|
|
|
continue;
|
|
|
|
|
|
|
|
auto oldOffset = tadOffset[r];
|
|
|
|
auto newOffset = tadOffset[shuffleMap[r]];
|
|
|
|
|
|
|
|
auto rX = hX + oldOffset;
|
|
|
|
auto rY = hX + newOffset;
|
|
|
|
|
|
|
|
if (tadEWS == 1) {
|
|
|
|
for (Nd4jLong i = 0; i < tadLength; i++) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::math::nd4j_swap<T>(rX[i], rY[i]);
|
2019-11-13 15:04:59 +01:00
|
|
|
}
|
|
|
|
} else {
|
|
|
|
for (Nd4jLong i = 0; i < tadLength; i++) {
|
|
|
|
auto offset = shape::getIndexOffset(i, tadOnlyShapeInfo[f]);
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::math::nd4j_swap<T>(hX[offset + oldOffset], hX[offset + newOffset]);
|
2019-11-13 15:04:59 +01:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2019-11-13 15:04:59 +01:00
|
|
|
};
|
|
|
|
|
2020-03-09 06:22:49 +01:00
|
|
|
samediff::Threads::parallel_tad(func, 0, N);
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void shuffle(Nd4jPointer *extras,
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jPointer *hX, Nd4jPointer *hXShapeInfo,
|
|
|
|
Nd4jPointer *dX, Nd4jPointer *dXShapeInfo,
|
|
|
|
Nd4jPointer *hz, Nd4jPointer *hZShapeInfo,
|
|
|
|
Nd4jPointer *dz, Nd4jPointer *dZShapeInfo,
|
|
|
|
int N,
|
|
|
|
int *shuffleMap,
|
|
|
|
Nd4jPointer *tadShapeInfo,
|
|
|
|
Nd4jPointer *tadOffsets) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-05-09 07:06:14 +02:00
|
|
|
auto xShape = reinterpret_cast<Nd4jLong * const*>(hXShapeInfo);
|
|
|
|
auto zShape = reinterpret_cast<Nd4jLong * const*>(hZShapeInfo);
|
|
|
|
auto tadOnlyShapeInfo = reinterpret_cast<Nd4jLong * const*>(tadShapeInfo);
|
|
|
|
auto tadOffset = reinterpret_cast<Nd4jLong * const*>(tadOffsets);
|
2019-08-26 18:57:51 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
auto xType = sd::ArrayOptions::dataType(xShape[0]);
|
2019-08-26 18:57:51 +02:00
|
|
|
|
|
|
|
BUILD_SINGLE_SELECTOR(xType, shuffleGeneric,
|
|
|
|
(hX, xShape, hz, zShape, N, shuffleMap, tadOnlyShapeInfo, tadOffset), LIBND4J_TYPES);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
bool isExperimentalEnabled() {
|
2020-06-06 14:26:55 +02:00
|
|
|
return sd::Environment::getInstance().isExperimentalBuild();
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Minimum OpenMP thread count setter. Not implemented: `threads` is ignored.
 */
void setOmpMinThreads(int threads) {
    // TODO: to be implemented
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Current device query.
 * @return always 0 in this backend
 */
int getDevice() {
    constexpr int onlyDevice = 0;
    return onlyDevice;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void execScalarTad(Nd4jPointer *extraPointers,
|
2019-06-06 14:21:15 +02:00
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const*dXShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const*dZShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbScalars, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape,
|
|
|
|
Nd4jLong const*tadShapeInfo, Nd4jLong const* tadOffsets,
|
|
|
|
Nd4jLong const*tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-01-04 11:27:50 +01:00
|
|
|
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
2019-08-26 18:57:51 +02:00
|
|
|
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
NativeOpExecutioner::execScalar(nullptr,
|
|
|
|
opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
|
|
|
extraParams,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbScalars->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hScalarShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbScalars->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dScalarShapeInfo,
|
|
|
|
dimension,
|
|
|
|
shape::length(hDimensionShape),
|
|
|
|
tadShapeInfo,
|
|
|
|
tadOffsets,
|
|
|
|
tadShapeInfoZ,
|
|
|
|
tadOffsetsZ);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void execScalarBoolTad(Nd4jPointer *extraPointers,
|
2019-06-06 14:21:15 +02:00
|
|
|
int opNum,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbScalars, const Nd4jLong *hScalarShapeInfo, const Nd4jLong *dScalarShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraParams,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape,
|
|
|
|
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
|
|
|
|
const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-01-04 11:27:50 +01:00
|
|
|
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
2019-08-26 18:57:51 +02:00
|
|
|
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
NativeOpExecutioner::execScalarBool(nullptr,
|
|
|
|
opNum,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hXShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbX->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dXShapeInfo,
|
|
|
|
extraParams,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbZ->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dZShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbScalars->primary(),
|
2019-08-26 18:57:51 +02:00
|
|
|
hScalarShapeInfo,
|
2020-01-04 11:27:50 +01:00
|
|
|
dbScalars->special(),
|
2019-08-26 18:57:51 +02:00
|
|
|
dScalarShapeInfo,
|
|
|
|
dimension,
|
|
|
|
dimensionLength,
|
|
|
|
tadShapeInfo,
|
|
|
|
tadOffsets,
|
|
|
|
tadShapeInfoZ,
|
|
|
|
tadOffsetsZ);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
const char * getDeviceName(int deviceId) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
|
|
|
if (!nameSet) {
|
|
|
|
name = reinterpret_cast<char *>(malloc(256 * sizeof(char)));
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
CHECK_ALLOC(name, "Failed to allocate new string buffer", 256);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
std::memset(name, 0, 256 * sizeof(char));
|
|
|
|
nameSet = true;
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
// TODO: provide proper CPU model name here
|
|
|
|
sprintf(name, "x86-compatible CPU");
|
|
|
|
}
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return name;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void execAggregate(Nd4jPointer *extraPointers,int opNum,
|
2019-06-06 14:21:15 +02:00
|
|
|
void **arguments,
|
|
|
|
int numArguments,
|
|
|
|
Nd4jLong **shapeArguments,
|
|
|
|
int numShapeArguments,
|
|
|
|
int *indexArguments,
|
|
|
|
int numIndexArguments,
|
|
|
|
int **intArrays,
|
|
|
|
int numIntArrays,
|
|
|
|
void *realArguments,
|
|
|
|
int numRealArguments,
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::DataType dtype) {
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void batchExecutor(Nd4jPointer *extraPointers,
|
|
|
|
int numAggregates,
|
|
|
|
int opNum,
|
|
|
|
int maxArgs,
|
|
|
|
int maxShapes,
|
|
|
|
int maxIntArrays,
|
|
|
|
int maxIntArraySize,
|
|
|
|
int maxIdx,
|
|
|
|
int maxReals,
|
|
|
|
void *ptrToArguments,
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::DataType dtype) {
|
2019-11-13 15:04:59 +01:00
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void execAggregateBatch(Nd4jPointer *extraPointers,
|
2019-06-06 14:21:15 +02:00
|
|
|
int numAggregates,
|
|
|
|
int opNum,
|
|
|
|
int maxArgs,
|
|
|
|
int maxShapes,
|
|
|
|
int maxIntArrays,
|
|
|
|
int maxIntArraySize,
|
|
|
|
int maxIdx,
|
|
|
|
int maxReals,
|
|
|
|
void *ptrToArguments,
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::DataType dtype) {
|
2019-11-13 15:04:59 +01:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void execRandom(Nd4jPointer *extraPointers,
|
2019-06-06 14:21:15 +02:00
|
|
|
int opNum,
|
|
|
|
Nd4jPointer state,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraArguments) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-01-04 11:27:50 +01:00
|
|
|
NativeOpExecutioner::execRandom(nullptr, opNum, state, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments);
|
2019-08-26 18:57:51 +02:00
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void execRandom3(Nd4jPointer *extraPointers,
|
2019-06-06 14:21:15 +02:00
|
|
|
int opNum,
|
|
|
|
Nd4jPointer state,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraArguments) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-01-04 11:27:50 +01:00
|
|
|
NativeOpExecutioner::execRandom(nullptr, opNum, state, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments);
|
2019-08-26 18:57:51 +02:00
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void execRandom2(Nd4jPointer *extraPointers,
|
2019-06-06 14:21:15 +02:00
|
|
|
int opNum,
|
|
|
|
Nd4jPointer state,
|
2020-05-09 07:06:14 +02:00
|
|
|
OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo,
|
|
|
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
void *extraArguments) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-01-04 11:27:50 +01:00
|
|
|
NativeOpExecutioner::execRandom(nullptr, opNum, state, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments);
|
2019-08-26 18:57:51 +02:00
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Creates a heap-allocated sd::graph::RandomGenerator and returns it as an
 * opaque pointer. destroyRandom() deletes it.
 *
 * `seed` is passed for both constructor arguments. `bufferSize` and
 * `ptrToBuffer` are not used by this implementation.
 *
 * @return the generator, or nullptr if construction threw (the error is
 *         recorded in the default LaunchContext)
 */
Nd4jPointer initRandom(Nd4jPointer *extraPointers, long seed, long bufferSize, Nd4jPointer ptrToBuffer) {
    try {
        // fully qualified, matching refreshBuffer/reSeedBuffer/destroyRandom
        auto generator = new sd::graph::RandomGenerator(seed, seed);

        return reinterpret_cast<Nd4jPointer>(generator);
    } catch (std::exception &e) {
        sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
        sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());

        return nullptr;
    }
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Re-seeds the RandomGenerator behind `ptrRandom` via setStates(seed).
 * `ptrRandom` is dereferenced without a null check.
 */
void refreshBuffer(Nd4jPointer *extraPointers, long seed, Nd4jPointer ptrRandom) {
    reinterpret_cast<sd::graph::RandomGenerator *>(ptrRandom)->setStates(seed);
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Re-seeds the RandomGenerator behind `ptrRandom` via setStates(seed). The
 * body is identical to refreshBuffer. `ptrRandom` is dereferenced without a
 * null check.
 */
void reSeedBuffer(Nd4jPointer *extraPointers, long seed, Nd4jPointer ptrRandom) {
    auto &rng = *reinterpret_cast<sd::graph::RandomGenerator *>(ptrRandom);
    rng.setStates(seed);
}
|
|
|
|
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Deletes the sd::graph::RandomGenerator behind `ptrBuffer`. A null pointer is
 * harmless because `delete nullptr` is a no-op.
 */
void destroyRandom(Nd4jPointer ptrBuffer) {
    delete reinterpret_cast<sd::graph::RandomGenerator *>(ptrBuffer);
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
* Return the length of a shape buffer
|
|
|
|
* based on the pointer
|
|
|
|
* @param buffer the buffer pointer to check
|
|
|
|
* @return
|
|
|
|
*/
|
2019-07-22 13:34:08 +02:00
|
|
|
int lengthForShapeBufferPointer(Nd4jPointer buffer) {
|
2019-06-06 14:21:15 +02:00
|
|
|
auto shapeBuffer = reinterpret_cast<Nd4jLong *>(buffer);
|
|
|
|
return shape::shapeInfoLength(shape::rank(shapeBuffer));
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
* The pointer to get the address for
|
|
|
|
*
|
|
|
|
* @param address the address to get the pointer
|
|
|
|
* @return the pointer for the given address
|
|
|
|
*/
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
Nd4jPointer pointerForAddress(Nd4jLong address) {
|
2019-06-06 14:21:15 +02:00
|
|
|
return reinterpret_cast<Nd4jPointer >(address);
|
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void sort(Nd4jPointer *extraPointers,
|
2020-05-09 07:06:14 +02:00
|
|
|
void *hX, const Nd4jLong *hXShapeInfo,
|
|
|
|
void *dX, const Nd4jLong *dXShapeInfo,
|
2019-06-06 14:21:15 +02:00
|
|
|
bool descending) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
|
|
|
NativeOpExecutioner::execSort(hX, hXShapeInfo, descending);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void sortTad(Nd4jPointer *extraPointers,
|
2020-05-09 07:06:14 +02:00
|
|
|
void *hX, const Nd4jLong *hXShapeInfo,
|
|
|
|
void *dX, const Nd4jLong *dXShapeInfo,
|
|
|
|
int *dimension, int dimensionLength,
|
|
|
|
const Nd4jLong *tadShapeInfo,
|
|
|
|
const Nd4jLong *tadOffsets,
|
2019-06-06 14:21:15 +02:00
|
|
|
bool descending) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
|
|
|
NativeOpExecutioner::execSort(hX, hXShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void sortCooIndices(Nd4jPointer *extraPointers,
|
2019-06-06 14:21:15 +02:00
|
|
|
Nd4jLong *indices,
|
|
|
|
void *values,
|
|
|
|
Nd4jLong length,
|
|
|
|
int rank) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
|
|
|
NativeOpExecutioner::execSortCooIndices(indices, values, length, rank);
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
2020-05-09 07:06:14 +02:00
|
|
|
/**
 * C API entry: bitmap-encodes hX into dz using `threshold`, delegating to
 * NativeOpExecutioner::encodeBitmap.
 *
 * Exceptions are caught and recorded in the default LaunchContext's error
 * reference, as in the other entry points, so they never unwind across the
 * C API boundary.
 *
 * @return the executioner's result, or 0 if it threw (check the error code)
 */
Nd4jLong encodeBitmap(Nd4jPointer *extraPointers, void *hX, Nd4jLong const* hXShapeInfo, Nd4jLong N, int *dz, float threshold) {
    try {
        return NativeOpExecutioner::encodeBitmap(hX, hXShapeInfo, N, dz, threshold);
    } catch (std::exception &e) {
        sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
        sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
        return 0;
    }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Memory-maps `length` bytes of `fileName` for read/write (shared mapping on
 * POSIX; `_mmap` on Windows).
 *
 * @return a new Nd4jLong[2] with [0] = mapped address and [1] = file
 *         descriptor/handle (release with munmapFile), or nullptr on failure.
 *         On every failure path the result array and any opened descriptor are
 *         released. An open() failure or a thrown exception is additionally
 *         recorded in the default LaunchContext.
 */
Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length) {
    auto hZ = new Nd4jLong[2];
    errno = 0;

    try {
#if defined(_WIN32) || defined(_WIN64)
        _mmap(hZ, static_cast<size_t>(length), fileName);
#else
        int fd = open(fileName, O_RDWR, 0);// checking for failed fopen
        if (fd < 0) {
            nd4j_printf("Errno: %i\n", errno);
            throw std::runtime_error("Failed to open file for MMAP");
        }

        void *ptr = mmap(NULL, length, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);

        // check for failed allocation
        if (ptr == MAP_FAILED) {
            // previously both the descriptor and hZ leaked here
            close(fd);
            delete[] hZ;
            return nullptr;
        }

        hZ[0] = reinterpret_cast<Nd4jLong>(ptr);
        hZ[1] = fd;
#endif

        return hZ;
    } catch (std::exception &e) {
        // the result array is never handed out on this path, so free it
        delete[] hZ;

        sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
        sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
        return nullptr;
    }
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Undoes a mapping produced by mmapFile: unmaps ptrMap[0] (length bytes),
 * closes the descriptor/handle in ptrMap[1], and frees ptrMap itself.
 *
 * mmapFile returns nullptr on failure, so a null `ptrMap` is accepted and
 * ignored rather than dereferenced.
 */
void munmapFile(Nd4jPointer *extraPointers, Nd4jLong *ptrMap, Nd4jLong length) {
    if (ptrMap == nullptr)
        return;

    munmap(reinterpret_cast<Nd4jPointer>(ptrMap[0]), length);
#if defined(_WIN32) || defined(_WIN64)
    CloseHandle(reinterpret_cast<HANDLE>(ptrMap[1]));
#else
    close(static_cast<int>(ptrMap[1]));
#endif

    delete[] ptrMap;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * C API entry: executes a serialized FlatBuffers graph via
 * GraphExecutioner::executeFlatBuffer.
 * @return the executor's result, or nullptr if it threw (the error is recorded
 *         in the default LaunchContext)
 */
sd::graph::ResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPointer flatBufferPointer) {
    sd::graph::ResultWrapper *result = nullptr;

    try {
        result = sd::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer);
    } catch (const std::exception &ex) {
        auto errRef = sd::LaunchContext::defaultContext()->errorReference();
        errRef->setErrorCode(1);
        errRef->setErrorMessage(ex.what());
    }

    return result;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * @return ptr->size() for the given ResultWrapper (no null check)
 */
Nd4jLong getResultWrapperSize(sd::graph::ResultWrapper* ptr) {
    const auto wrapperSize = ptr->size();
    return wrapperSize;
}
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * @return ptr->pointer() for the given ResultWrapper (no null check)
 */
Nd4jPointer getResultWrapperPointer(sd::graph::ResultWrapper* ptr) {
    auto payload = ptr->pointer();
    return payload;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
const char* getAllCustomOps() {
|
2020-06-06 14:26:55 +02:00
|
|
|
return sd::ops::OpRegistrator::getInstance().getAllCustomOperations();
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
FORCEINLINE int estimateThresholdGeneric(Nd4jPointer *extraPointers, Nd4jPointer hX, int N, T threshold) {
|
|
|
|
auto buffer = reinterpret_cast<T *>(hX);
|
|
|
|
int span = (N / 6) + 8;
|
|
|
|
|
2019-11-13 15:04:59 +01:00
|
|
|
auto func = PRAGMA_REDUCE_LONG {
|
|
|
|
int64_t cnt = 0;
|
2019-06-06 14:21:15 +02:00
|
|
|
PRAGMA_OMP_SIMD
|
2019-11-13 15:04:59 +01:00
|
|
|
for (auto e = start; e < stop; e++) {
|
2020-03-02 10:49:41 +01:00
|
|
|
auto v = sd::math::nd4j_abs<T>(buffer[e]);
|
2019-06-06 14:21:15 +02:00
|
|
|
if (v >= threshold)
|
|
|
|
cnt++;
|
|
|
|
}
|
|
|
|
|
2019-11-13 15:04:59 +01:00
|
|
|
return cnt;
|
|
|
|
};
|
|
|
|
|
2020-03-09 06:22:49 +01:00
|
|
|
return samediff::Threads::parallel_long(func, LAMBDA_AL { return _old + _new; }, 0, N);
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
|
2020-05-09 07:06:14 +02:00
|
|
|
/**
 * Counts the elements of hX whose magnitude is >= threshold, dispatching on the data type
 * stored in hXShapeInfo over the FLOAT_TYPES list.
 *
 * @return the count, or 0 if an exception was thrown (error code 1 and the message are
 *         recorded in the default LaunchContext's error reference)
 */
int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer hX, Nd4jLong const* hXShapeInfo, int N, float threshold) {
    try {
        const auto dataType = ArrayOptions::dataType(hXShapeInfo);
        BUILD_SINGLE_SELECTOR(dataType, return estimateThresholdGeneric, (extraPointers, hX, N, threshold), FLOAT_TYPES);
    } catch (std::exception &e) {
        auto errorRef = sd::LaunchContext::defaultContext()->errorReference();
        errorRef->setErrorCode(1);
        errorRef->setErrorMessage(e.what());
        return 0;
    }
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Returns the number of shape-info entries in `list` (dereferenced without a null check). */
Nd4jLong getShapeListSize(sd::ShapeList* list) {
    const auto entries = list->size();
    return entries;
}
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-05-09 07:06:14 +02:00
|
|
|
/** Returns the i-th shape-info buffer of `list` as a read-only pointer. */
Nd4jLong const* getShape(sd::ShapeList* list, Nd4jLong i) {
    // adding const needs no cast; the implicit pointer conversion is enough
    Nd4jLong const* shapeInfo = list->at(i);
    return shapeInfo;
}
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Deletes a ShapeList object handed out through the native API.
 * NOTE: only the list object is deleted here; ShapeList::destroy() is deliberately not called.
 */
void deleteShapeList(Nd4jPointer shapeList) {
    delete reinterpret_cast<sd::ShapeList*>(shapeList);
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Asks `op` for its output shapes, given raw input buffers/shape-infos and op arguments.
 *
 * A throw-away Context backed by a local VariableSpace is built: every input e is wrapped
 * in an NDArray (empty arrays get a null buffer), registered as variable (1, e) and picked
 * as a Context input; its shape-info is also appended to the ShapeList handed to
 * calculateOutputShape().
 *
 * @throws std::runtime_error if op->validateDataTypes() does not return Status::OK()
 * @return the ShapeList produced by the op, detached when the VariableSpace has a launch context
 */
sd::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, sd::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
    sd::graph::VariableSpace varSpace;
    Context block(2, &varSpace);
    sd::ShapeList inShapes;

    // plain op arguments
    auto integerArgs = block.getIArguments();
    for (int k = 0; k < numIArgs; ++k)
        integerArgs->push_back(iArgs[k]);

    auto floatArgs = block.getTArguments();
    for (int k = 0; k < numTArgs; ++k)
        floatArgs->push_back(tArgs[k]);

    auto booleanArgs = block.getBArguments();
    for (int k = 0; k < numBArgs; ++k)
        booleanArgs->push_back(bArgs[k]);

    auto dtypeArgs = block.getDArguments();
    for (int k = 0; k < numDArgs; ++k)
        dtypeArgs->push_back(static_cast<sd::DataType>(dArgs[k]));

    // input arrays
    for (int k = 0; k < numInputShapes; ++k) {
        auto shapeInfo = reinterpret_cast<Nd4jLong *>(inputShapes[k]);

        // an empty array has no data, so no buffer is attached to it
        const bool isEmpty = sd::ArrayOptions::arrayType(shapeInfo) == ArrayType::EMPTY;
        void *dataBuffer = isEmpty ? nullptr : inputBuffers[k];

        auto array = new sd::NDArray(dataBuffer, shapeInfo, varSpace.launchContext(), false);

        // the Context reads its inputs through the VariableSpace, so register and pick (1, k)
        varSpace.putVariable(1, k, array);
        block.pickInput(1, k);

        inShapes.push_back(shapeInfo);
    }

    if (op->validateDataTypes(block) != Status::OK())
        throw std::runtime_error("Data types validation failed");

    auto shapeList = op->calculateOutputShape(&inShapes, block);

    if (varSpace.launchContext() != nullptr)
        shapeList->detach();

    return shapeList;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Looks up the op registered under `hash` and computes its output shapes from raw
 * input buffers, shape-infos and (t/i/b/d) arguments.
 *
 * @return the resulting ShapeList, or nullptr if an exception was thrown (error code 1 and
 *         the message are recorded in the default LaunchContext's error reference)
 */
sd::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) {
    try {
        auto &registrator = sd::ops::OpRegistrator::getInstance();
        auto op = registrator.getOperation(hash);

        return _calculateOutputShapes(extraPointers, op,
                                      inputBuffers, inputShapes, numInputShapes,
                                      tArgs, numTArgs,
                                      iArgs, numIArgs,
                                      bArgs, numBArgs,
                                      dArgs, numDArgs);
    } catch (std::exception &e) {
        auto errorRef = sd::LaunchContext::defaultContext()->errorReference();
        errorRef->setErrorCode(1);
        errorRef->setErrorMessage(e.what());
        return nullptr;
    }
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Shape-only variant: asks `op` for its output shapes given input shape-infos and
 * t/i arguments. No input buffers are wrapped; the shape-infos go directly into the
 * ShapeList passed to calculateOutputShape().
 *
 * @return the op-produced ShapeList, always detached
 */
sd::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, sd::ops::DeclarableOp *op, Nd4jPointer* inputShapes, int numInputShapes, double *tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs) {
    Context block(1);
    sd::ShapeList inShapes;

    auto integerArgs = block.getIArguments();
    for (int k = 0; k < numIArgs; ++k)
        integerArgs->push_back(iArgs[k]);

    auto floatArgs = block.getTArguments();
    for (int k = 0; k < numTArgs; ++k)
        floatArgs->push_back(tArgs[k]);

    for (int k = 0; k < numInputShapes; ++k)
        inShapes.push_back(reinterpret_cast<Nd4jLong *>(inputShapes[k]));

    auto shapeList = op->calculateOutputShape(&inShapes, block);
    shapeList->detach();

    return shapeList;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Looks up the op registered under `hash` and computes its output shapes from input
 * shape-infos and t/i arguments only.
 *
 * @return the resulting ShapeList, or nullptr if an exception was thrown (error code 1 and
 *         the message are recorded in the default LaunchContext's error reference)
 */
sd::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs) {
    try {
        auto &registrator = sd::ops::OpRegistrator::getInstance();
        auto op = registrator.getOperation(hash);

        return _calculateOutputShapes(extraPointers, op,
                                      inputShapes, numInputShapes,
                                      tArgs, numTArgs,
                                      iArgs, numIArgs);
    } catch (std::exception &e) {
        auto errorRef = sd::LaunchContext::defaultContext()->errorReference();
        errorRef->setErrorCode(1);
        errorRef->setErrorMessage(e.what());
        return nullptr;
    }
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Executes the op registered under `hash` against a caller-prepared Context.
 *
 * @param opContext pointer to a Context, reinterpreted as-is
 * @return the op's execution status, or 20 if an exception was thrown (error code 1 and
 *         the message are recorded in the default LaunchContext's error reference)
 */
int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext) {
    try {
        auto &registrator = sd::ops::OpRegistrator::getInstance();
        auto op = registrator.getOperation(hash);
        auto context = reinterpret_cast<Context *>(opContext);

        return op->execute(context);
    } catch (std::exception &e) {
        auto errorRef = sd::LaunchContext::defaultContext()->errorReference();
        errorRef->setErrorCode(1);
        errorRef->setErrorMessage(e.what());
        return 20;
    }
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Executes `op` on raw buffers and shape-infos supplied by the caller.
 *
 * - Inputs are wrapped in temporary NDArrays; empty arrays get a null buffer.
 * - Unless `isInplace` is set, outputs are wrapped as well. Each output is built from a
 *   private copy of its shape-info. A non-null output buffer that does not alias any
 *   input buffer is zero-filled before execution.
 * - After execution, a non-inplace output whose ordering differs from the caller-provided
 *   shape-info is streamlined to that ordering.
 * - The temporary NDArray wrappers are released on every exit path, including exceptions.
 *
 * @return the op's status, or Status::THROW() if `op` is null
 */
Nd4jStatus realExec(sd::ops::DeclarableOp* op, Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace) {
    if (op == nullptr) {
        // previously this only logged and then dereferenced the null op
        nd4j_printf("Can't find requested operation: [%lld]\n", hash);
        return Status::THROW();
    }

    // we're using the same fake nodeId everywhere here

    std::vector<sd::NDArray*> inputs(numInputs);
    std::vector<sd::NDArray*> outputs(numOutputs);

    // Deletes the temporary NDArray wrappers when the function exits, including when
    // op->execute() throws. Slots never filled stay nullptr, and deleting nullptr is a no-op.
    struct ArrayReleaser {
        std::vector<sd::NDArray*> &arrays;
        ~ArrayReleaser() {
            for (auto v : arrays)
                delete v;
        }
    };
    ArrayReleaser outputsReleaser{outputs};
    ArrayReleaser inputsReleaser{inputs};

    std::vector<double> ttArgs(numTArgs);
    std::vector<Nd4jLong> iiArgs(numIArgs);
    std::vector<bool> biArgs(numBArgs);

    // filling block now with inputs
    for (int e = 0; e < numInputs; e++) {
        auto shape = reinterpret_cast<Nd4jLong *>(inputShapes[e]);
        void *buffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY ? nullptr : inputBuffers[e];

        inputs[e] = new sd::NDArray(buffer, shape);
    }

    // if not inplace - transferring output arrays
    if (!isInplace)
        for (int e = 0; e < numOutputs; e++) {
            // we want to keep original output shape intact
            auto shape = shape::copyShape(reinterpret_cast<Nd4jLong *>(outputShapes[e]));
            void *buffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY ? nullptr : outputBuffers[e];

            // FIXME: revisit this.
            // Zero the output only if it has a buffer and does not alias an input; otherwise the
            // memset would wipe input data. Each input's emptiness is judged from that input's own
            // shape-info (the old code used the output's shape here).
            bool canNullify = buffer != nullptr;
            for (int i = 0; i < numInputs && canNullify; i++) {
                auto inputShape = reinterpret_cast<Nd4jLong *>(inputShapes[i]);
                void *ibuffer = sd::ArrayOptions::arrayType(inputShape) == ArrayType::EMPTY ? nullptr : inputBuffers[i];
                if (ibuffer == buffer)
                    canNullify = false;
            }

            if (canNullify)
                memset(buffer, 0, shape::length(shape) * DataTypeUtils::sizeOfElement(ArrayOptions::dataType(shape)));

            outputs[e] = new sd::NDArray(buffer, shape);

            // and we want to release shape copy once we're done
            delete []shape;
        }

    for (int e = 0; e < numIArgs; e++)
        iiArgs[e] = iArgs[e];

    for (int e = 0; e < numTArgs; e++)
        ttArgs[e] = tArgs[e];

    for (int e = 0; e < numBArgs; e++)
        biArgs[e] = bArgs[e];

    // hypothetically at this point we have everything filled
    auto hZ = op->execute(inputs, outputs, ttArgs, iiArgs, biArgs, std::vector<sd::DataType>(), isInplace);

    // bring outputs back to the ordering the caller asked for
    if (!isInplace)
        for (int e = 0; e < numOutputs; e++) {
            auto requestedOrder = shape::order(reinterpret_cast<Nd4jLong *>(outputShapes[e]));
            if (outputs[e]->ordering() != requestedOrder)
                outputs[e]->streamline(requestedOrder);
        }

    return hZ;
}
|
|
|
|
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Looks up the op registered under `hash` and executes it on raw buffers via realExec().
 *
 * @return realExec()'s status, or 1 if an exception was thrown (error code 1 and the
 *         message are recorded in the default LaunchContext's error reference)
 */
int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace) {
    try {
        auto &registrator = sd::ops::OpRegistrator::getInstance();
        auto op = registrator.getOperation(hash);

        return realExec(op, extraPointers, hash,
                        inputBuffers, inputShapes, numInputs,
                        outputBuffers, outputShapes, numOutputs,
                        tArgs, numTArgs,
                        iArgs, numIArgs,
                        bArgs, numBArgs,
                        isInplace);
    } catch (std::exception &e) {
        auto errorRef = sd::LaunchContext::defaultContext()->errorReference();
        errorRef->setErrorCode(1);
        errorRef->setErrorMessage(e.what());
        return 1;
    }
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Imports a FlatBuffers-serialized graph and stores it in the GraphHolder under `graphId`.
 *
 * @return ND4J_STATUS_OK on success, or 1 if an exception was thrown (error code 1 and
 *         the message are recorded in the default LaunchContext's error reference)
 */
int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer) {
    try {
        auto importedGraph = sd::graph::GraphExecutioner::importFromFlatPointer(flatBufferPointer);

        auto &holder = sd::graph::GraphHolder::getInstance();
        holder.registerGraph(graphId, importedGraph);

        return ND4J_STATUS_OK;
    } catch (std::exception &e) {
        auto errorRef = sd::LaunchContext::defaultContext()->errorReference();
        errorRef->setErrorCode(1);
        errorRef->setErrorMessage(e.what());
        return 1;
    }
}
|
|
|
|
|
|
|
|
/**
 * Runs a clone of the graph stored under `graphId`, with the supplied inputs substituted
 * into the clone's VariableSpace.
 *
 * For each input e, an NDArray wrapping inputBuffers[e]/inputShapes[e] is placed at variable
 * id inputIndices[e]. If that variable already exists, its current NDArray (if any) is
 * deleted and replaced; otherwise a new variable is created.
 *
 * @return a new VariablesSet carrying the execution status. Clones of the graph's output
 *         variables are added only when execution returned ND4J_STATUS_OK. The cloned graph
 *         is deleted before returning.
 */
static VariablesSet* executeStoredGraphT(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs) {
    auto graph = sd::graph::GraphHolder::getInstance().cloneGraph(graphId);
    auto varSpace = graph->getVariableSpace();

    for (int e = 0; e < numInputs; e++) {
        auto idx = inputIndices[e];

        // we'll delete this array later, together with cloned VariableSpace
        auto array = new sd::NDArray(inputBuffers[e], reinterpret_cast<Nd4jLong *>(inputShapes[e]));

        if (varSpace->hasVariable(idx)) {
            auto var = varSpace->getVariable(idx);
            if (var->hasNDArray())
                delete var->getNDArray();

            var->setNDArray(array);
        } else
            varSpace->putVariable(idx, array);
    }

    auto hZ = sd::graph::GraphExecutioner::execute(graph, varSpace);
    auto varSet = new sd::graph::VariablesSet(hZ);

    if (hZ == ND4J_STATUS_OK) {
        // pull back results, and provide them
        auto outputs = graph->fetchOutputs();
        for (size_t e = 0; e < outputs->size(); e++) {
            // only the variable ID/index comes from the original graph; values are taken from the cloned VariableSpace
            std::pair<int, int> varId(outputs->at(e)->id(), outputs->at(e)->index());

            auto var = varSpace->getVariable(varId);
            varSet->push_back(var->clone());
        }

        delete outputs;
    }

    delete graph;

    return varSet;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Entry point for executing a stored graph. This implementation does not run anything:
 * it does not call executeStoredGraphT() and always returns nullptr.
 */
sd::graph::VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs) {
    sd::graph::VariablesSet *noResult = nullptr;
    return noResult;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Returns the number of variables held by `set` (dereferenced without a null check). */
Nd4jLong getVariablesSetSize(sd::graph::VariablesSet* set) {
    const auto count = set->size();
    return count;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Returns the execution status stored in `set`. */
Nd4jStatus getVariablesSetStatus(sd::graph::VariablesSet* set) {
    const auto status = set->status();
    return status;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Returns the i-th Variable of `set`. */
sd::graph::Variable* getVariable(sd::graph::VariablesSet* set, Nd4jLong i) {
    auto variable = set->at(i);
    return variable;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Returns the id of `variable`. */
int getVariableId(sd::graph::Variable* variable) {
    const auto id = variable->id();
    return id;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Returns the index (output slot) of `variable`. */
int getVariableIndex(sd::graph::Variable* variable) {
    const auto index = variable->index();
    return index;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Returns the C-string view of `variable`'s name. The pointer refers to storage owned by
 * the variable's name object, not a copy.
 */
const char* getVariableName(sd::graph::Variable* variable) {
    auto name = variable->getName();
    return name->c_str();
}
|
|
|
|
|
2020-05-09 07:06:14 +02:00
|
|
|
/** Returns the shape-info of the NDArray held by `variable`, as a read-only pointer. */
Nd4jLong const* getVariableShape(sd::graph::Variable* variable) {
    // adding const needs no cast; the implicit pointer conversion is enough
    Nd4jLong const* shapeInfo = variable->getNDArray()->shapeInfo();
    return shapeInfo;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Returns the host data buffer of the NDArray held by `variable`. */
void* getVariableBuffer(sd::graph::Variable* variable) {
    auto array = variable->getNDArray();
    return array->buffer();
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/** Removes the graph registered under `graphId` from the GraphHolder; always reports OK. */
int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId) {
    auto &holder = sd::graph::GraphHolder::getInstance();
    holder.dropGraphAny(graphId);

    return sd::Status::OK();
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/** Frees an array of Nd4jPointer that was allocated with new[]. */
void deletePointerArray(Nd4jPointer pointer) {
    delete[] reinterpret_cast<Nd4jPointer *>(pointer);
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/** Frees a char array that was allocated with new[]. */
void deleteCharArray(Nd4jPointer pointer) {
    delete[] reinterpret_cast<char *>(pointer);
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/** Frees an int array that was allocated with new[]. */
void deleteIntArray(Nd4jPointer pointer) {
    delete[] reinterpret_cast<int *>(pointer);
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/** Frees an Nd4jLong array that was allocated with new[]. */
void deleteLongArray(Nd4jPointer pointer) {
    delete[] reinterpret_cast<Nd4jLong *>(pointer);
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Destroys a VariablesSet handed out through the native API; nullptr is a no-op. */
void deleteVariablesSet(sd::graph::VariablesSet* pointer) {
    sd::graph::VariablesSet *doomed = pointer;
    delete doomed;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
const char* getAllOperations() {
|
2020-06-06 14:26:55 +02:00
|
|
|
return sd::OpTracker::getInstance().exportOperations();
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Allocates a new GraphState for `id` and returns it as an opaque pointer.
 * Release it with deleteGraphState().
 */
Nd4jPointer getGraphState(Nd4jLong id) {
    auto state = new sd::graph::GraphState(id);
    return reinterpret_cast<Nd4jPointer>(state);
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/** Destroys a GraphState previously returned (as an opaque pointer) by getGraphState(). */
void deleteGraphState(Nd4jPointer state) {
    delete reinterpret_cast<sd::graph::GraphState*>(state);
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Executes a logic op (e.g. while / if) against the VariableSpace held by `state`.
 *
 * A transient Node (type OpType_LOGIC, id 0) is built:
 * - Each input e is wrapped in an NDArray, stored in the state's VariableSpace as
 *   variable (0, e), and picked as a node input.
 * - Each requested scope id is then picked as input (scopeId, 0).
 *
 * After LogicExecutor processes the node, variable (0, e) is copied into the caller's
 * output buffer e. The input variables are dropped from the shared VariableSpace on every
 * status-returning exit path. Scopes are validated before anything is registered, so an
 * unknown scope leaves the state untouched.
 *
 * @return Status::OK() on success, Status::THROW() if a referenced scope does not exist,
 *         or LogicExecutor's status if processing failed
 */
Nd4jStatus execCustomOpWithScope_(Nd4jPointer *extraPointers, sd::graph::GraphState *state, Nd4jLong opHash, Nd4jLong *scopes, int numScopes, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs) {
    auto graph = state->graph();
    auto varSpace = state->variableSpace();

    // check scope existence in GraphState before touching the shared VariableSpace
    for (int e = 0; e < numScopes; e++) {
        int scopeId = (int) scopes[e];
        if (!state->hasScope(scopeId)) {
            // nd4j_printf("execCustomOpWithScope: referenced scope [%i] doesn't exist\n", scopeId);
            return Status::THROW();
        }
    }

    // Node is dynamically created, and has nothing beyond it: only inputs and outputs
    // this node has id of 0
    Node node(OpType_LOGIC, opHash, 0);

    // removes the temporary input variables (0, e) from the state's VariableSpace
    auto dropInputs = [&]() {
        for (int e = 0; e < numInputs; e++)
            varSpace->dropVariable(0, e);
    };

    // mapping inputs
    for (int e = 0; e < numInputs; e++) {
        auto buffer = inputBuffers[e];
        auto shapeInfo = reinterpret_cast<Nd4jLong *>(inputShapes[e]);

        auto array = new sd::NDArray(buffer, shapeInfo, varSpace->launchContext());

        // now we just put array to VarSpace
        varSpace->putVariable(0, e, array);
        node.pickInput(0, e);
    }

    // mapping scopes (already validated above)
    for (int e = 0; e < numScopes; e++)
        node.pickInput((int) scopes[e], 0);

    auto hZ = LogicExecutor::processNode(graph, &node);
    if (hZ != Status::OK()) {
        // don't leave this call's inputs behind in the shared state
        dropInputs();
        return hZ;
    }

    // mapping outputs: copy variable (0, e) into the caller's output buffer e
    for (int e = 0; e < numOutputs; e++) {
        auto buffer = outputBuffers[e];
        auto shapeInfo = reinterpret_cast<Nd4jLong *>(outputShapes[e]);

        NDArray array(buffer, shapeInfo, varSpace->launchContext());

        auto t = varSpace->getVariable(0, e)->getNDArray();
        array.assign(t);
    }

    // removing input variables
    dropInputs();

    return Status::OK();
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * C-API wrapper around execCustomOpWithScope_(): reinterprets `state` as a GraphState.
 *
 * @return the inner status, or 1 if an exception was thrown (error code 1 and the message
 *         are recorded in the default LaunchContext's error reference)
 */
Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPointer state, Nd4jLong opHash, Nd4jLong *scopes, int numScopes, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs) {
    try {
        auto graphState = reinterpret_cast<sd::graph::GraphState *>(state);

        return execCustomOpWithScope_(extraPointers, graphState, opHash,
                                      scopes, numScopes,
                                      inputBuffers, inputShapes, numInputs,
                                      outputBuffers, outputShapes, numOutputs);
    } catch (std::exception &e) {
        auto errorRef = sd::LaunchContext::defaultContext()->errorReference();
        errorRef->setErrorCode(1);
        errorRef->setErrorMessage(e.what());
        return 1;
    }
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Destroys a ResultWrapper passed back through the native API as an opaque pointer,
 * such as the one returned by executeFlatGraph().
 */
void deleteResultWrapper(Nd4jPointer ptr) {
    delete reinterpret_cast<sd::graph::ResultWrapper *>(ptr);
}
|
|
|
|
|
|
|
|
/*
 * TypeDef:
 *     void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer hX, Nd4jLong N, int dstType, Nd4jPointer hZ);
 */
|
2019-07-22 13:34:08 +02:00
|
|
|
void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer hX, Nd4jLong N, int dstType, Nd4jPointer hZ) {
|
2019-06-06 14:21:15 +02:00
|
|
|
auto hx = reinterpret_cast<void *>(hX);
|
|
|
|
auto hz = reinterpret_cast<void *>(hZ);
|
|
|
|
|
|
|
|
if (srcType == ND4J_FLOAT8) {
|
|
|
|
if (dstType == ND4J_FLOAT8) {
|
2020-03-02 10:49:41 +01:00
|
|
|
// convertGeneric<double, sd::float8>(hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_INT8) {
|
2020-03-02 10:49:41 +01:00
|
|
|
//sd::TypeCast::convertGeneric<sd::float8, sd::int8>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_UINT8) {
|
2020-03-02 10:49:41 +01:00
|
|
|
//sd::TypeCast::convertGeneric<sd::float8, sd::uint8>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_FLOAT16) {
|
2020-03-02 10:49:41 +01:00
|
|
|
//sd::TypeCast::convertGeneric<sd::float8, float16>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_INT16) {
|
2020-03-02 10:49:41 +01:00
|
|
|
//sd::TypeCast::convertGeneric<sd::float8, sd::int16>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_UINT16) {
|
2020-03-02 10:49:41 +01:00
|
|
|
//sd::TypeCast::convertGeneric<sd::float8, sd::uint16>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_FLOAT24) {
|
|
|
|
|
|
|
|
} else if (dstType == ND4J_FLOAT32) {
|
2020-03-02 10:49:41 +01:00
|
|
|
//sd::TypeCast::convertGeneric<sd::float8, float>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_DOUBLE) {
|
2020-03-02 10:49:41 +01:00
|
|
|
//sd::TypeCast::convertGeneric<sd::float8, double>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else {
|
|
|
|
//nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType);
|
|
|
|
}
|
|
|
|
} else if (srcType == ND4J_INT8) {
|
|
|
|
if (dstType == ND4J_FLOAT8) {
|
2020-03-02 10:49:41 +01:00
|
|
|
//sd::TypeCast::convertGeneric<sd::int8, sd::float8>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_INT8) {
|
2020-03-02 10:49:41 +01:00
|
|
|
//convertGeneric<sd::int8, sd::int8>(hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_UINT8) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<int8_t, uint8_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_FLOAT16) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<int8_t, float16>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_INT16) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<int8_t, int16_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_UINT16) {
|
2020-03-02 10:49:41 +01:00
|
|
|
//sd::TypeCast::convertGeneric<int8_t, uint16_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_FLOAT24) {
|
|
|
|
// TODO: eventually we might want to add it
|
|
|
|
} else if (dstType == ND4J_FLOAT32) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<int8_t, float>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_DOUBLE) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<int8_t, double>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else {
|
|
|
|
nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType);
|
|
|
|
}
|
|
|
|
} else if (srcType == ND4J_UINT8) {
|
|
|
|
if (dstType == ND4J_FLOAT8) {
|
2020-03-02 10:49:41 +01:00
|
|
|
// sd::TypeCast::convertGeneric<uint8_t, sd::float8>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_INT8) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<uint8_t, int8_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_UINT8) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<uint8_t, uint8_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_FLOAT16) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<uint8_t, float16>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_INT16) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<uint8_t, int16_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_UINT16) {
|
2020-03-02 10:49:41 +01:00
|
|
|
// sd::TypeCast::convertGeneric<uint8_t, uint16_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_FLOAT24) {
|
|
|
|
// TODO: still might want to add
|
|
|
|
} else if (dstType == ND4J_FLOAT32) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<uint8_t, float>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_DOUBLE) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<uint8_t, double>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else {
|
|
|
|
nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType);
|
|
|
|
}
|
|
|
|
} else if (srcType == ND4J_FLOAT16) {
|
|
|
|
if (dstType == ND4J_FLOAT8) {
|
2020-03-02 10:49:41 +01:00
|
|
|
// sd::TypeCast::convertGeneric<float16, sd::float8>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_INT8) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<float16, int8_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_UINT8) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<float16, uint8_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_FLOAT16) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<float16, float16>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_INT16) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<float16, int16_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_UINT16) {
|
2020-03-02 10:49:41 +01:00
|
|
|
// sd::TypeCast::convertGeneric<float16, uint16_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_FLOAT24) {
|
|
|
|
// TODO: .... ^^^
|
|
|
|
} else if (dstType == ND4J_FLOAT32) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<float16, float>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_DOUBLE) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<float16, double>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_THRESHOLD) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertToThreshold<float16>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else {
|
|
|
|
nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType);
|
|
|
|
}
|
|
|
|
} else if (srcType == ND4J_INT16) {
|
|
|
|
if (dstType == ND4J_FLOAT8) {
|
2020-03-02 10:49:41 +01:00
|
|
|
// sd::TypeCast::convertGeneric<int16_t, sd::float8>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_INT8) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<int16_t, int8_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_UINT8) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<int16_t, uint8_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_FLOAT16) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<int16_t, float16>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_INT16) {
|
2020-03-02 10:49:41 +01:00
|
|
|
//sd::TypeCast::convertGeneric<int16_t, int16_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_UINT16) {
|
2020-03-02 10:49:41 +01:00
|
|
|
// sd::TypeCast::convertGeneric<int16_t, uint16_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_FLOAT24) {
|
|
|
|
// TODO...
|
|
|
|
} else if (dstType == ND4J_FLOAT32) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<int16_t, float>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_DOUBLE) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<int16_t, double>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else {
|
|
|
|
printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType);
|
|
|
|
}
|
|
|
|
} else if (srcType == ND4J_FLOAT24) {
|
|
|
|
|
|
|
|
} else if (srcType == ND4J_FLOAT32) {
|
|
|
|
if (dstType == ND4J_FLOAT8) {
|
2020-03-02 10:49:41 +01:00
|
|
|
// sd::TypeCast::convertGeneric<float, sd::float8>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_INT8) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<float, int8_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_UINT8) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<float, uint8_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_FLOAT16) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<float, float16>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_INT16) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<float, int16_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_UINT16) {
|
2020-03-02 10:49:41 +01:00
|
|
|
// sd::TypeCast::convertGeneric<float, uint16_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_FLOAT24) {
|
|
|
|
|
|
|
|
} else if (dstType == ND4J_DOUBLE) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<float, double>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_THRESHOLD) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertToThreshold<float>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else {
|
|
|
|
nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType);
|
|
|
|
}
|
|
|
|
} else if (srcType == ND4J_DOUBLE) {
|
|
|
|
if (dstType == ND4J_FLOAT8) {
|
2020-03-02 10:49:41 +01:00
|
|
|
// sd::TypeCast::convertGeneric<double, sd::float8>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_INT8) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<double, int8_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_UINT8) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<double, uint8_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_FLOAT16) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<double, float16>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_INT16) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<double, int16_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_UINT16) {
|
2020-03-02 10:49:41 +01:00
|
|
|
// sd::TypeCast::convertGeneric<double, uint16_t>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_FLOAT24) {
|
|
|
|
|
|
|
|
} else if (dstType == ND4J_FLOAT32) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertGeneric<double, float>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_DOUBLE) {
|
|
|
|
//
|
|
|
|
} else if (dstType == ND4J_THRESHOLD) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertToThreshold<double>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else {
|
|
|
|
nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType);
|
|
|
|
}
|
|
|
|
} else if (srcType == ND4J_THRESHOLD) {
|
|
|
|
if (dstType == ND4J_FLOAT16) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertFromThreshold<float16>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_FLOAT32) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertFromThreshold<float>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else if (dstType == ND4J_DOUBLE) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::TypeCast::convertFromThreshold<double>(nullptr, hx, N, hz);
|
2019-06-06 14:21:15 +02:00
|
|
|
} else {
|
|
|
|
nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType);
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/*
|
2019-07-22 13:34:08 +02:00
|
|
|
void fillUtf8String(Nd4jPointer *extraPointers, const char **strings, int numStrings, Nd4jPointer buffer) {
|
2020-03-02 10:49:41 +01:00
|
|
|
auto hZ = reinterpret_cast<sd::utf8string**>(buffer);
|
2019-06-06 14:21:15 +02:00
|
|
|
for (int e = 0; e < numStrings; e++) {
|
2020-03-02 10:49:41 +01:00
|
|
|
hZ[e] = reinterpret_cast<sd::utf8string*>(createUtf8String(extraPointers, strings[e]));
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
*/
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Wraps a raw character buffer into a heap-allocated sd::utf8string.
 * @param string  source characters
 * @param length  number of characters to take from `string`
 * @return opaque handle to the new utf8string; free it with deleteUtf8String()
 */
Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int length) {
    sd::utf8string *str = new sd::utf8string(string, length);
    return reinterpret_cast<Nd4jPointer>(str);
}
|
|
|
|
|
2019-07-24 14:14:54 +02:00
|
|
|
/** Returns the `_length` field of the utf8string behind the opaque handle `ptr`. */
Nd4jLong getUtf8StringLength(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
    auto str = reinterpret_cast<sd::utf8string*>(ptr);
    return str->_length;
}
|
|
|
|
/** Returns the `_buffer` pointer of the utf8string behind `ptr`; no copy is made. */
char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
    auto str = reinterpret_cast<sd::utf8string*>(ptr);
    return str->_buffer;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/** Destroys a utf8string handle such as the ones produced by createUtf8String(). */
void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
    auto str = reinterpret_cast<sd::utf8string*>(ptr);
    delete str;
}
|
|
|
|
|
2019-10-31 10:23:09 +01:00
|
|
|
template <typename I>
|
2020-05-09 07:06:14 +02:00
|
|
|
static void _scatterUpdate(
|
|
|
|
Nd4jPointer *extraPointers, int opCode, int numOfSubArrs,
|
|
|
|
void* hX, const Nd4jLong* hXShapeInfo, const Nd4jLong* hXOffsets,
|
|
|
|
void* dX, const Nd4jLong* dXShapeInfo, const Nd4jLong* dXOffsets,
|
|
|
|
void* hY, const Nd4jLong* hYShapeInfo, const Nd4jLong* hYOffsets,
|
|
|
|
void* dY, const Nd4jLong* dYShapeInfo, const Nd4jLong* dYOffsets,
|
|
|
|
void* vIindexes, const Nd4jLong* hIndicesShapeInfo, void* dIindexes, const Nd4jLong* dIndicesShapeInfo) {
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-10-31 10:23:09 +01:00
|
|
|
auto hIindexes = reinterpret_cast<I*>(vIindexes);
|
2019-11-13 15:04:59 +01:00
|
|
|
auto func = PRAGMA_THREADS_DO {
|
|
|
|
for (int i = 0; i < numOfSubArrs; ++i) {
|
|
|
|
int threadIndex = thread_id;
|
2019-10-31 10:23:09 +01:00
|
|
|
const auto xIndex = hIindexes[i];
|
|
|
|
const bool isOwner = xIndex < numThreads ? threadIndex == xIndex : threadIndex == xIndex % numThreads;
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-10-31 10:23:09 +01:00
|
|
|
if (!isOwner)
|
|
|
|
continue;
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-11-13 15:04:59 +01:00
|
|
|
NDArray inSubArr(reinterpret_cast<int8_t *>(hX) + (hXOffsets[hIindexes[i]] * DataTypeUtils::sizeOf(hXShapeInfo)), hXShapeInfo);
|
|
|
|
NDArray updSubArr(reinterpret_cast<int8_t *>(hY) + (hYOffsets[i] * DataTypeUtils::sizeOf(hXShapeInfo)), hYShapeInfo);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-10-31 10:23:09 +01:00
|
|
|
if (inSubArr.lengthOf() != updSubArr.lengthOf()) {
|
|
|
|
continue;
|
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2019-11-13 15:04:59 +01:00
|
|
|
switch (opCode) {
|
|
|
|
case 0:
|
2019-12-20 20:35:39 +01:00
|
|
|
inSubArr.applyPairwiseTransform(pairwise::Add, updSubArr, inSubArr);
|
2019-11-13 15:04:59 +01:00
|
|
|
break;
|
|
|
|
case 1:
|
2019-12-20 20:35:39 +01:00
|
|
|
inSubArr.applyPairwiseTransform(pairwise::Subtract, updSubArr, inSubArr);
|
2019-11-13 15:04:59 +01:00
|
|
|
break;
|
|
|
|
case 2:
|
2019-12-20 20:35:39 +01:00
|
|
|
inSubArr.applyPairwiseTransform(pairwise::Multiply, updSubArr, inSubArr);
|
2019-11-13 15:04:59 +01:00
|
|
|
break;
|
|
|
|
case 3:
|
2019-12-20 20:35:39 +01:00
|
|
|
inSubArr.applyPairwiseTransform(pairwise::Divide, updSubArr, inSubArr);
|
2019-11-13 15:04:59 +01:00
|
|
|
break;
|
|
|
|
case 4:
|
2019-12-20 20:35:39 +01:00
|
|
|
inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, updSubArr, inSubArr);
|
2019-11-13 15:04:59 +01:00
|
|
|
break;
|
|
|
|
case 5:
|
2019-12-20 20:35:39 +01:00
|
|
|
inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, updSubArr, inSubArr);
|
2019-11-13 15:04:59 +01:00
|
|
|
break;
|
|
|
|
case 6:
|
2019-12-20 20:35:39 +01:00
|
|
|
inSubArr.applyPairwiseTransform(pairwise::CopyPws, updSubArr, inSubArr);
|
2019-11-13 15:04:59 +01:00
|
|
|
break;
|
|
|
|
default:
|
|
|
|
continue;
|
|
|
|
}
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
2019-11-13 15:04:59 +01:00
|
|
|
};
|
2019-10-31 10:23:09 +01:00
|
|
|
|
2020-03-09 06:22:49 +01:00
|
|
|
samediff::Threads::parallel_do(func);
|
2019-10-31 10:23:09 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////
|
|
|
|
/**
 * C entry point for scatter updates.
 *
 * Dispatches on the index dtype taken from hIndicesShapeInfo (one of INDEXING_TYPES)
 * to the typed _scatterUpdate worker. Any std::exception is recorded in the default
 * LaunchContext's error reference (code 1 plus message) instead of propagating.
 */
void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs,
                   void* hX, const Nd4jLong* hXShapeInfo, const Nd4jLong* hXOffsets,
                   void* dX, const Nd4jLong* dXShapeInfo, const Nd4jLong* dXOffsets,
                   void* hY, const Nd4jLong* hYShapeInfo, const Nd4jLong* hYOffsets,
                   void* dY, const Nd4jLong* dYShapeInfo, const Nd4jLong* dYOffsets,
                   void* hIindexes, const Nd4jLong* hIndicesShapeInfo, void* dIindexes, const Nd4jLong* dIndicesShapeInfo) {
    try {
        // Decoded inside the try: if reading the index dtype throws, the failure is
        // reported through the error reference rather than escaping this C entry point.
        auto iType = ArrayOptions::dataType(hIndicesShapeInfo);

        BUILD_SINGLE_SELECTOR(iType, _scatterUpdate, (extraPointers, opCode, numOfSubArrs, hX, hXShapeInfo, hXOffsets, dX, dXShapeInfo, dXOffsets, hY, hYShapeInfo, hYOffsets, dY, dYShapeInfo, dYOffsets, hIindexes, hIndicesShapeInfo, dIindexes, dIndicesShapeInfo), INDEXING_TYPES);
    } catch (std::exception &e) {
        sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
        sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
    }
}
|
|
|
|
|
2019-11-13 15:04:59 +01:00
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Fills the sd::DebugInfo behind `debugInfo` with statistics of the host array
 * described by (buffer, shapeInfo). The special (device) buffer/shape are not used here.
 * Errors are recorded in the default LaunchContext's error reference.
 */
void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo) {
    try {
        auto info = reinterpret_cast<sd::DebugInfo *>(debugInfo);
        NDArray view(buffer, shapeInfo);
        sd::DebugHelper::retrieveDebugStatistics(info, &view);
    } catch (std::exception &e) {
        auto errors = sd::LaunchContext::defaultContext()->errorReference();
        errors->setErrorCode(1);
        errors->setErrorMessage(e.what());
    }
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Probes `len` bytes starting at `p` by reading each of them once.
 *
 * An invalid pointer will typically fault here rather than raise a C++ exception;
 * the catch only covers std::exception.
 * The reads go through a volatile pointer so the compiler cannot drop the probe.
 */
void tryPointer(Nd4jPointer extra, Nd4jPointer p, int len) {
    try {
        auto buf = static_cast<volatile int8_t *>(p);
        int cnt = 0;
        // Was `buf[cnt]`: indexing by the running sum read arbitrary (even negative)
        // offsets instead of walking the requested range.
        for (int i = 0; i < len; i++)
            cnt += buf[i];
        (void) cnt;
    } catch (std::exception &e) {
        sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
        sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
    }
}
|
|
|
|
|
2020-06-06 14:26:55 +02:00
|
|
|
/**
 * Builds (or looks up) the cached shape buffer for the given descriptor and returns
 * a heap-allocated copy of the holder; release it with deleteConstantShapeBuffer().
 *
 * @return the new holder, or nullptr on failure (the error is recorded in the
 *         default LaunchContext's error reference)
 */
sd::ConstantShapeBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, sd::DataType dtype, char order, Nd4jLong ews, bool empty) {
    sd::ConstantShapeBuffer *buffer = nullptr;
    try {
        buffer = new sd::ConstantShapeBuffer();
        *buffer = sd::ConstantShapeHelper::getInstance().bufferForShapeInfo(
                ShapeDescriptor(dtype, order, shape, strides, rank, ews, empty));
        return buffer;
    } catch (std::exception &e) {
        // Previously the holder allocated above leaked when the lookup threw.
        delete buffer;
        sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
        sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
        return nullptr;
    }
}
|
|
|
|
|
2020-06-06 14:26:55 +02:00
|
|
|
/** Frees a ConstantShapeBuffer holder (e.g. one returned by shapeBuffer()); nullptr is a no-op. */
void deleteConstantShapeBuffer(sd::ConstantShapeBuffer* ptr) {
    if (ptr != nullptr)
        delete ptr;
}
|
|
|
|
|
2020-06-06 14:26:55 +02:00
|
|
|
/** Frees a ConstantDataBuffer handle; nullptr is a no-op. */
void deleteConstantDataBuffer(sd::ConstantDataBuffer* ptr) {
    if (ptr != nullptr)
        delete ptr;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Frees a TadPack handle; nullptr is a no-op. */
void deleteTadPack(sd::TadPack* ptr) {
    if (ptr != nullptr)
        delete ptr;
}
|
|
|
|
|
2020-05-09 07:06:14 +02:00
|
|
|
/** Not implemented in this file: the arguments are ignored and nullptr is always returned. */
sd::ConstantDataBuffer* constantBufferLong(sd::DataType dtype, const Nd4jLong *data, int length) {
    (void) dtype;
    (void) data;
    (void) length;
    return nullptr;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Not implemented in this file: the arguments are ignored and nullptr is always returned. */
sd::ConstantDataBuffer* constantBufferDouble(sd::DataType dtype, double *data, int length) {
    (void) dtype;
    (void) data;
    (void) length;
    return nullptr;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Returns the constant buffer that ConstantHelper holds for `descriptor` in `dtype`.
 * @return the buffer, or nullptr on failure (error recorded in the default LaunchContext)
 */
sd::ConstantDataBuffer* constantBuffer(sd::DataType dtype, sd::ConstantDescriptor *descriptor) {
    sd::ConstantDataBuffer *result = nullptr;
    try {
        result = sd::ConstantHelper::getInstance().constantBuffer(*descriptor, dtype);
    } catch (std::exception &e) {
        auto errors = sd::LaunchContext::defaultContext()->errorReference();
        errors->setErrorCode(1);
        errors->setErrorMessage(e.what());
    }
    return result;
}
|
|
|
|
|
2020-06-06 14:26:55 +02:00
|
|
|
/** Exposes the primary (host) shape info of `dbf` as an untyped, non-const pointer. */
Nd4jPointer getConstantShapeBufferPrimary(sd::ConstantShapeBuffer* dbf) {
    auto shapeInfo = dbf->primary();
    return const_cast<Nd4jLong*>(shapeInfo);
}
|
|
|
|
|
|
|
|
/** Exposes the special shape info of `dbf` as an untyped, non-const pointer. */
Nd4jPointer getConstantShapeBufferSpecial(sd::ConstantShapeBuffer* dbf) {
    auto shapeInfo = dbf->special();
    return const_cast<Nd4jLong*>(shapeInfo);
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Returns the primary buffer pointer of `dbf`. */
Nd4jPointer getConstantDataBufferPrimary(sd::ConstantDataBuffer* dbf) {
    auto buffer = dbf->primary();
    return buffer;
}
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Returns the special buffer pointer of `dbf`. */
Nd4jPointer getConstantDataBufferSpecial(sd::ConstantDataBuffer* dbf) {
    auto buffer = dbf->special();
    return buffer;
}
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Returns dbf->length(). */
Nd4jLong getConstantDataBufferLength(sd::ConstantDataBuffer* dbf) {
    auto len = dbf->length();
    return len;
}
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Returns dbf->sizeOf(). */
Nd4jLong getConstantDataBufferSizeOf(sd::ConstantDataBuffer* dbf) {
    auto bytes = dbf->sizeOf();
    return bytes;
}
|
|
|
|
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Allocates a new graph Context for `nodeId`; release it with deleteGraphContext().
 * @return the context, or nullptr on failure (error recorded in the default LaunchContext)
 */
sd::graph::Context* createGraphContext(int nodeId) {
    sd::graph::Context *context = nullptr;
    try {
        context = new sd::graph::Context(nodeId);
    } catch (std::exception &e) {
        auto errors = sd::LaunchContext::defaultContext()->errorReference();
        errors->setErrorCode(1);
        errors->setErrorMessage(e.what());
    }
    return context;
}
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Returns a non-owning pointer to the RandomGenerator embedded in the context. */
sd::graph::RandomGenerator* getGraphContextRandomGenerator(sd::graph::Context* ptr) {
    auto &generator = ptr->randomGenerator();
    return &generator;
}
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Forwards the in-place flag to Context::markInplace. */
void markGraphContextInplace(sd::graph::Context* ptr, bool reallyInplace) {
    sd::graph::Context &ctx = *ptr;
    ctx.markInplace(reallyInplace);
}
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Intentionally a no-op in this file: none of the arguments are used.
 * (Presumably only meaningful for a CUDA build — confirm against the CUDA NativeOps.)
 */
void setGraphContextCudaContext(sd::graph::Context* ptr, void *stream, void *reductionPointer, void *allocationPointer) {
    (void) ptr;
    (void) stream;
    (void) reductionPointer;
    (void) allocationPointer;
}
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Binds input slot `index` of the context to the given buffers and shape infos. */
void setGraphContextInputArray(sd::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) {
    sd::graph::Context &ctx = *ptr;
    ctx.setInputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo);
}
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Binds output slot `index` of the context to the given buffers and shape infos. */
void setGraphContextOutputArray(sd::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) {
    sd::graph::Context &ctx = *ptr;
    ctx.setOutputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo);
}
|
2020-01-04 11:27:50 +01:00
|
|
|
|
|
|
|
/** Binds input slot `index` to an OpaqueDataBuffer plus its host/special shape infos. */
void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo) {
    OpaqueContext &ctx = *ptr;
    ctx.setInputArray(index, buffer, shapeInfo, specialShapeInfo);
}
|
|
|
|
|
|
|
|
/** Binds output slot `index` to an OpaqueDataBuffer plus its host/special shape infos. */
void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo) {
    OpaqueContext &ctx = *ptr;
    ctx.setOutputArray(index, buffer, shapeInfo, specialShapeInfo);
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Passes `numberOfArguments` floating-point (T) arguments to the context. */
void setGraphContextTArguments(sd::graph::Context* ptr, double *arguments, int numberOfArguments) {
    sd::graph::Context &ctx = *ptr;
    ctx.setTArguments(arguments, numberOfArguments);
}
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Passes `numberOfArguments` integer (I) arguments to the context. */
void setGraphContextIArguments(sd::graph::Context* ptr, Nd4jLong *arguments, int numberOfArguments) {
    sd::graph::Context &ctx = *ptr;
    ctx.setIArguments(arguments, numberOfArguments);
}
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Passes `numberOfArguments` boolean (B) arguments to the context. */
void setGraphContextBArguments(sd::graph::Context* ptr, bool *arguments, int numberOfArguments) {
    sd::graph::Context &ctx = *ptr;
    ctx.setBArguments(arguments, numberOfArguments);
}
|
2020-01-30 08:07:24 +01:00
|
|
|
|
|
|
|
/**
 * Converts `numberOfArguments` raw integer codes into sd::DataType values
 * and hands them to the context as its D (dtype) arguments.
 */
void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments) {
    std::vector<sd::DataType> dtypes;
    dtypes.reserve(numberOfArguments);

    for (int idx = 0; idx < numberOfArguments; ++idx)
        dtypes.push_back(static_cast<sd::DataType>(arguments[idx]));

    ptr->setDArguments(dtypes);
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Frees a context, e.g. one created by createGraphContext(); nullptr is a no-op. */
void deleteGraphContext(sd::graph::Context* ptr) {
    if (ptr != nullptr)
        delete ptr;
}
|
|
|
|
|
2019-11-14 12:35:02 +01:00
|
|
|
/** Forwards the allow-helpers flag to the context. */
void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow) {
    OpaqueContext &ctx = *ptr;
    ctx.allowHelpers(reallyAllow);
}
|
2019-07-24 14:14:54 +02:00
|
|
|
|
2020-01-27 08:00:07 +01:00
|
|
|
/**
 * Sets the context's execution mode. Only the codes 0, 1 and 2 are accepted;
 * any other value is replaced by 0.
 */
void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) {
    const bool inRange = execMode >= 0 && execMode <= 2;
    const int mode = inRange ? execMode : 0;
    ptr->setExecutionMode(static_cast<samediff::ExecutionMode>(mode));
}
|
|
|
|
|
2020-02-05 05:27:24 +01:00
|
|
|
/** Calls clearFastPath() on the context. */
void ctxPurge(OpaqueContext* ptr) {
    OpaqueContext &ctx = *ptr;
    ctx.clearFastPath();
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Allocates a RandomGenerator seeded with (rootSeed, nodeSeed);
 * release it with deleteRandomGenerator().
 *
 * Like createGraphContext(), failures are recorded in the default LaunchContext's
 * error reference and nullptr is returned, so no exception escapes the C API.
 */
sd::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) {
    try {
        return new sd::graph::RandomGenerator(rootSeed, nodeSeed);
    } catch (std::exception &e) {
        sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
        sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
        return nullptr;
    }
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Returns the generator's root state. */
Nd4jLong getRandomGeneratorRootState(sd::graph::RandomGenerator* ptr) {
    auto state = ptr->rootState();
    return state;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Returns the generator's node state. */
Nd4jLong getRandomGeneratorNodeState(sd::graph::RandomGenerator* ptr) {
    auto state = ptr->nodeState();
    return state;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Overwrites both the root and node states of the generator. */
void setRandomGeneratorStates(sd::graph::RandomGenerator* ptr, Nd4jLong rootSeed, Nd4jLong nodeSeed) {
    sd::graph::RandomGenerator &rng = *ptr;
    rng.setStates(rootSeed, nodeSeed);
}
|
|
|
|
|
2020-05-30 20:13:33 +02:00
|
|
|
/** Returns relativeT<float>(index) from the generator. */
float getRandomGeneratorRelativeFloat(sd::graph::RandomGenerator* ptr, Nd4jLong index) {
    auto value = ptr->relativeT<float>(index);
    return value;
}
|
|
|
|
|
|
|
|
/** Returns relativeT<double>(index) from the generator. */
double getRandomGeneratorRelativeDouble(sd::graph::RandomGenerator* ptr, Nd4jLong index) {
    auto value = ptr->relativeT<double>(index);
    return value;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Returns relativeInt(index) from the generator, converted to int. */
int getRandomGeneratorRelativeInt(sd::graph::RandomGenerator* ptr, Nd4jLong index) {
    auto value = ptr->relativeInt(index);
    return value;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Returns relativeLong(index) from the generator. */
Nd4jLong getRandomGeneratorRelativeLong(sd::graph::RandomGenerator* ptr, Nd4jLong index) {
    auto value = ptr->relativeLong(index);
    return value;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/** Frees a generator, e.g. one created by createRandomGenerator(); nullptr is a no-op. */
void deleteRandomGenerator(sd::graph::RandomGenerator* ptr) {
    if (ptr != nullptr)
        delete ptr;
}
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
int dataTypeFromNpyHeader(void *header) {
|
2019-06-06 14:21:15 +02:00
|
|
|
return (int) cnpy::dataTypeFromHeader(reinterpret_cast<char *>(header));
|
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Parses an in-memory .npy array and returns the matching cached shape info
 * (obtained from ConstantShapeHelper::createFromExisting).
 *
 * Shape handling:
 *  - shape == [0]              -> scalar shape info (this is how the code treats it)
 *  - any other zero dimension  -> empty shape info keeping the dimensions
 *  - otherwise                 -> regular shape info in 'f' or 'c' order per the header
 *
 * @return the shape info pointer, or nullptr on failure (error recorded in the
 *         default LaunchContext's error reference)
 */
Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) {
    try {
        auto header = reinterpret_cast<char *>(npyArray);
        cnpy::NpyArray arr = cnpy::loadNpyFromPointer(header);

        const unsigned int rank = arr.shape.size();
        std::vector<Nd4jLong> shape;
        shape.reserve(rank);
        bool hasZeroDim = false;
        for (unsigned int d = 0; d < rank; d++) {
            shape.push_back(arr.shape[d]);
            if (arr.shape[d] == 0)
                hasZeroDim = true;
        }

        auto dtype = cnpy::dataTypeFromHeader(header);
        const char order = arr.fortranOrder ? 'f' : 'c';

        Nd4jLong *shapeInfo = nullptr;
        if (shape.size() == 1 && shape[0] == 0) {
            // scalar case
            shapeInfo = sd::ShapeBuilders::createScalarShapeInfo(dtype);
        } else if (!hasZeroDim) {
            shapeInfo = sd::ShapeBuilders::createShapeInfo(dtype, order, shape);
        } else if (rank > 0) {
            shapeInfo = sd::ShapeBuilders::emptyShapeInfo(dtype, order, shape);
        } else {
            shapeInfo = sd::ShapeBuilders::emptyShapeInfo(dtype);
        }

        // `true`: hand shapeInfo over to the helper, which takes care of the original
        return const_cast<Nd4jLong*>(sd::ConstantShapeHelper::getInstance().createFromExisting(shapeInfo, true));
    } catch (std::exception &e) {
        auto errors = sd::LaunchContext::defaultContext()->errorReference();
        errors->setErrorCode(1);
        errors->setErrorMessage(e.what());
        return nullptr;
    }
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void sortByKey(Nd4jPointer *extraPointers,
|
2020-05-09 07:06:14 +02:00
|
|
|
void *x, const Nd4jLong *xShapeInfo,
|
|
|
|
void *dx, const Nd4jLong *dxShapeInfo,
|
|
|
|
void *y, const Nd4jLong *yShapeInfo,
|
|
|
|
void *dy, const Nd4jLong *dyShapeInfo,
|
Merge master to upstream (#7945)
* Shugeo strided slice zeros (#14)
* Modified strided_slice op to properly work with empty-like shapes.
* Fixed test for reduce_mean with empty-like input.
* [WIP] Last merge (#15)
* correct logsoftmax looss (#2)
* Small SameDiff listener fix (#4)
* Various fixes (#6)
* #7839 Fix for asXMatrix and tests
* #7866 EmbeddingSequenceLayer dtype fix + test
* #7856 SameDiff save/load stream methods
* #7859 RegressionEvaluation rank 4 fix + tests + axis configuration
* EvaluationBinary 3d/4d
* More evaluation 3d/4d tests
* #7847 Evaluation empty checks
* Small test ifx
* #7848 Fix median edge case
* Improve DL4J samediff layer tests
* [WIP] FastText wrapper implemented (#8)
* FastText implemented
* Some fixes
* Fix shapes for wordsNearest
* Validation of input vectors
* Fixes
* Fixed test
* Thread tagged
* Some tweaks
* setContextClassLoader for DeallocatorServiceThread
* Numpy format tests (#1)
* Various fixes (#11)
* #7852 SameDiff gather fix
* #7892 SameDiff placeholder to constant conversion
* #7890 validate input rank for MLN/CG init methods
* Fix broken permute shape calculation
* Permute and gather fixes
* Tests
* #7850 LogSumExp fix + test
* Handful of test fixes
* Empty arrays with non-scalar shapes (#10)
* minor rearrangements for lambdas
* empty tensors with non-scalar shapes
* numpy empty tensors with non-scalar shapes
* few more empty tweaks
* Small fixes
* conv3d signature update
* micro fix in batchnorm mkldnn
* Import fixes
* Fix
* MKL-DNN update
* Small fill fix
* fill with empty input + test
* Fixes
* Small error improvement
* Fix
* one special test
* couple of fixes for lstm
* Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone
* Fixes
* FP16
* Unsigned
* BFloat16
* Fill op - empty tweaks
* - couple of fixes for empty arrays construction
- stack updated
* strided slice fix
* one transform test
* provide method for reducing shapeInfo in case of input array is empty
* Fixed reduceAlongDimensions to use empty input properly.
* couple of broadcast tests
* couple of tests broadcast tests + tweak to make them pass
* add check of non-empty to methods producing sub-arrays
* Fixed reshapeC with zeros in shape.
* complete empty check in reduce_... legacy ops
* Concat and cumsum/prod
* Tweak to empty shape inference on import
* add empty check to the rest of reduce legacy ops
* one more test
* correct typo in evalReduceShapeInfoEmpty
* Added tests for reduce_* ops to tests with zero shapes.
* few more tests for empty reductions
* Fixed strided_slice op with empty case and tests.
* one more empty reduction test
* Fixed strided_slice test.
* add empty check to NDArray::reshapei
* infOrMax
* empty min/max with infinity tests
* made unstack working correctly with empty arrays
* few IndexReduce tests + tweaks for empty shapes
* add test for empty concat
* few tests fixed
* Validation fix for reductions on empty shapes
* Reverse fix
* Reduction shape calc fixes
* SameDiff.generateOutputVariable: don't use shape function to determine number of outputs
* Range fix
* - NDArray constructor updated for scalars/empty arrays
- few tests fixed
* More fixes
* Empty creator fixes
* concat fix
* concat fix
* TF import tests: allow 'both all NaN' and 'both all inf' to pass
* Slice, zero fraction, and reshape fixes
* transpose, gather
* Zero fraction
* scalar cast fix
* Empty reduction axis support
* few more tests fixed
* Fixed input checks conforming with TF for concat op and tests.
* few tests fixed
* matmul scalar shape fix
* Fixed checkout for data type and scalarity with concat to allow non-empty scalars with vector concats.
* broadcast bool fix
* few more tests
* few more tests
* correct evalReduceShapeInfoEmpty
* argmax/argmin + tests
* one more empty edge case + one more test
* argmax/argmin/realdiv_bp tweaks
* empty reshape test + fix
* Helper fixes
* Small fixes
* Gather test fix
* Gather test fix
* Small fixes
* reduce scalar zero values
* scalar mean workaround
* Remove debug code
* along dim mean workaround
* one more test
* - equalsTo() tweak for empty arrays
- one more test
* broadcast tweaks
* [WIP] Fixing outstanding issues for NLP (#9)
* Avoid using not-inited objects
* Test fixed.
* Redundant method avoided for models like FastText
* KMeans++ implementation
* KMeans++ implementation
* Disable parallel execution
* KMeans++
* Tests
* Dev branch merge (#16)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Fix some issues on master (#17)
* Fix DataVec test issue
* Fix issue with dl4j SameDiff output layer
* Dtype fix for lambda layers
* #7912 BertIterator dtype fix (use float32 not global default)
* [WIP] Next set of CUDA stuff (#7)
New CUDA implementations and improvements
* bad file
* Dev branch master merge (#23)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* SameDiff ops, TF import and fixes (#24)
* CheckNumerics tests + fixes + misc fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fake quant
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* FakeQuantWithMinMaxArgs
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* CheckNumerics fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix libnd4j ALL_INTS and ALL_FLOATS declaration (uint and bfloat types)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Small fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Javadoc
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Exception tweak
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix for out of scope stack allocated var use
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignores
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignore for known failing test (already logged issue)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Merge upstream to fork (#25)
* Add thousand-separator commas to TotalParams (#7915)
* Add thousand-separator commas to TotalParams
The number of parameters can be quite large, and it would help the reading of the summary printout to have the TotalParams column & values at the bottom have thousand-separator-commas in them.
* Add thousand-separator commas to MultiLayerNetwork
Corresponding change to MultiLayerNetwork
Signed-off-by: Jxtps Jxtps <jxtps435@gmail.com>
* Update contributing and issue/PR templates (#7934)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix link to AdaDelta paper (#7942)
Fix link to AdaDelta paper hosted on matthewzeiler.com
Signed-off-by: Jxtps
* Fixes, and ignores for known/logged failing issues (#7943)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* SameDiff + DL4J/SameDiff: Multiple fixes (#28)
* #7919 HDF5 attribute buffer length fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7909 Arbiter constructor exception ux improvements
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7925 RNN output layer length checks
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Add listener for validating inputs are not incorrectly modified
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Integrate NonInplaceValidationListener into tests
* #7844 DL4J SameDiff fixes for variable minibatch size
* DL4J SameDiff fixes - ensure gradient for input placeholder is available
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Tweaks to ExternalErrorsFunction - use placeholders, make more robust
* Another fix
* More fixes
* More SameDiff/DL4J fixes
* Scope out scalar array creation in BaseScalarOp
* Remove debug code
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] Final dev branch merge (#29)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* [WIP] Multiple dataset iterators (#27)
* Splitting dataset into arbitrary number
* Fixes
* Multiple split of iterator
* Test
* Test
* Some fixes
* signature change
* one more tweak
Signed-off-by: raver119 <raver119@gmail.com>
* one more test for sequential use of DataSetIteratorSplitter
Signed-off-by: raver119 <raver119@gmail.com>
* Fixes
* Fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* minor test fix
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* couple of assertions tweaked
Signed-off-by: raver119 <raver119@gmail.com>
* MDS splitter test :/
Signed-off-by: raver119 <raver119@gmail.com>
* Minor refactoring
* Multi dataset
* Some fixes
* More tests
* Small number of test fixes/improvements (failures on CI) (#31)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] More CUDA stuff (#26)
* initial commit
Signed-off-by: raver119 <raver119@gmail.com>
* LRN BP CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* less memory
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed bug with crop_and_resize op helper.
* get rid of unnecessary index-calculation dunction
Signed-off-by: Yurii <yurii@skymind.io>
* Fixed sort with nth_element cuda-based helper.
* Refactored nth_element.
* Refactored nth_element op and tests.
* Modified usage of dim array with sortTad routine.
* Refactored main routine of helper for non_max_image_suppression op.
* non_max_image_suppression op helper with cuda kernel implementation. Initial revision.
* fix vol2col cuda kernel
* meh
Signed-off-by: raver119 <raver119@gmail.com>
* topK concept
Signed-off-by: raver119 <raver119@gmail.com>
* unsorted topK with scanWitdh of 1
Signed-off-by: raver119 <raver119@gmail.com>
* correct vol2col tests
* sorted/unsorted topK
Signed-off-by: raver119 <raver119@gmail.com>
* implementation and fixing col2im/col2vol
* Corrected usage flags with input/output with reverse op.
* dup is const now
Signed-off-by: raver119 <raver119@gmail.com>
* percentile op
Signed-off-by: raver119 <raver119@gmail.com>
* group tests for mapool2d
Signed-off-by: Yurii <yurii@skymind.io>
* special test for george
Signed-off-by: raver119 <raver119@gmail.com>
* less threads for sortTad
Signed-off-by: raver119 <raver119@gmail.com>
* provide conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* remove auther in sort tad kernel code
Signed-off-by: Yurii <yurii@skymind.io>
* provide depthwise_conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* - max_pooling_with_argmax
- null check for special use
Signed-off-by: raver119 <raver119@gmail.com>
* dts cuda
Signed-off-by: raver119 <raver119@gmail.com>
* provide sconv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* std cuda
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op to conform TF implementation.
* Improved suppression helper.
* provide pooling3d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* more of minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* (bi)dynamic_rnn
Signed-off-by: raver119 <raver119@gmail.com>
* templates init order
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op.
* Added cuda kernel for non_max_suppression.
* CPU sort by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value tests
Signed-off-by: raver119 <raver119@gmail.com>
* Eliminate compiler error with cuda implementation.
* - repaired gradCheck in cuda
- provide conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* missed signature
Signed-off-by: raver119 <raver119@gmail.com>
* provide depthwise_conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* Implementation of lup helper with cuda kernel. Initial commit.
* further work on backprops for convolutions
Signed-off-by: Yurii <yurii@skymind.io>
* CUDA linear sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* CUDA tad sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* start providing of backprop for pooling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* Added atomicAdd for bool datatype.
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition scalar CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* important comment
Signed-off-by: raver119 <raver119@gmail.com>
* fix pooling2d/3d backprop helpers
Signed-off-by: Yurii <yurii@skymind.io>
* Added non-linear test with dynamic_partition.
* Improved test for dynamic_partition.
* dynamic_partition TAD concept
Signed-off-by: raver119 <raver119@gmail.com>
* - dynamic_partition TAD CUDA impl
- dynamic_partition TAD CPU fix
Signed-off-by: raver119 <raver119@gmail.com>
* - rewrite cpu code for usampling2d/3d
- write cuda code for usampling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* dynamic_stitch CUDA vector case
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case impl
Signed-off-by: raver119 <raver119@gmail.com>
* Added tests for dynamic_stitch 3D-4D cases.
* minor tests tweaks
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed type check for dynamic stitch.
* min/max bp
Signed-off-by: raver119 <raver119@gmail.com>
* rewrite code for upsampling2d/3d cpu
Signed-off-by: Yurii <yurii@skymind.io>
* reduce min/max/norm_max bp
Signed-off-by: raver119 <raver119@gmail.com>
* lup implementation. Additional enhancements.
* provide code for upsamling2d/3d backprop
Signed-off-by: Yurii <yurii@skymind.io>
* weightedCrossEntropyWithLogits
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed template math atomicMul for 64bit ints.
* Refactored dynamic_partition_bp op.
* inverseBroadcast fix
Signed-off-by: raver119 <raver119@gmail.com>
* DynamicPartitionBP test datatype fixed.
* - nd4j_atomicMul Windows fix
- cpu/NDArrayLambda.hpp excluded from CUDA
Signed-off-by: raver119 <raver119@gmail.com>
2019-06-27 17:37:04 +02:00
|
|
|
bool descending) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
|
|
|
auto xType = ArrayOptions::dataType(xShapeInfo);
|
|
|
|
auto yType = ArrayOptions::dataType(yShapeInfo);
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
BUILD_DOUBLE_SELECTOR(xType, yType, sd::DoubleMethods, ::sortByKey(x, xShapeInfo, y, yShapeInfo, descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
2019-08-26 18:57:51 +02:00
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
Merge master to upstream (#7945)
* Shugeo strided slice zeros (#14)
* Modified strided_slice op to properly work with empty-like shapes.
* Fixed test for reduce_mean with empty-like input.
* [WIP] Last merge (#15)
* correct logsoftmax looss (#2)
* Small SameDiff listener fix (#4)
* Various fixes (#6)
* #7839 Fix for asXMatrix and tests
* #7866 EmbeddingSequenceLayer dtype fix + test
* #7856 SameDiff save/load stream methods
* #7859 RegressionEvaluation rank 4 fix + tests + axis configuration
* EvaluationBinary 3d/4d
* More evaluation 3d/4d tests
* #7847 Evaluation empty checks
* Small test ifx
* #7848 Fix median edge case
* Improve DL4J samediff layer tests
* [WIP] FastText wrapper implemented (#8)
* FastText implemented
* Some fixes
* Fix shapes for wordsNearest
* Validation of input vectors
* Fixes
* Fixed test
* Thread tagged
* Some tweaks
* setContextClassLoader for DeallocatorServiceThread
* Numpy format tests (#1)
* Various fixes (#11)
* #7852 SameDiff gather fix
* #7892 SameDiff placeholder to constant conversion
* #7890 validate input rank for MLN/CG init methods
* Fix broken permute shape calculation
* Permute and gather fixes
* Tests
* #7850 LogSumExp fix + test
* Handful of test fixes
* Empty arrays with non-scalar shapes (#10)
* minor rearrangements for lambdas
* empty tensors with non-scalar shapes
* numpy empty tensors with non-scalar shapes
* few more empty tweaks
* Small fixes
* conv3d signature update
* micro fix in batchnorm mkldnn
* Import fixes
* Fix
* MKL-DNN update
* Small fill fix
* fill with empty input + test
* Fixes
* Small error improvement
* Fix
* one special test
* couple of fixes for lstm
* Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone
* Fixes
* FP16
* Unsigned
* BFloat16
* Fill op - empty tweaks
* - couple of fixes for empty arrays construction
- stack updated
* strided slice fix
* one transform test
* provide method for reducing shapeInfo in case of input array is empty
* Fixed reduceAlongDimensions to use empty input properly.
* couple of broadcast tests
* couple of tests broadcast tests + tweak to make them pass
* add check of non-empty to methods producing sub-arrays
* Fixed reshapeC with zeros in shape.
* complete empty check in reduce_... legacy ops
* Concat and cumsum/prod
* Tweak to empty shape inference on import
* add empty check to the rest of reduce legacy ops
* one more test
* correct typo in evalReduceShapeInfoEmpty
* Added tests for reduce_* ops to tests with zero shapes.
* few more tests for empty reductions
* Fixed strided_slice op with empty case and tests.
* one more empty reduction test
* Fixed strided_slice test.
* add empty check to NDArray::reshapei
* infOrMax
* empty min/max with infinity tests
* made unstack working correctly with empty arrays
* few IndexReduce tests + tweaks for empty shapes
* add test for empty concat
* few tests fixed
* Validation fix for reductions on empty shapes
* Reverse fix
* Reduction shape calc fixes
* SameDiff.generateOutputVariable: don't use shape function to determine number of outputs
* Range fix
* - NDArray constructor updated for scalars/empty arrays
- few tests fixed
* More fixes
* Empty creator fixes
* concat fix
* concat fix
* TF import tests: allow 'both all NaN' and 'both all inf' to pass
* Slice, zero fraction, and reshape fixes
* transpose, gather
* Zero fraction
* scalar cast fix
* Empty reduction axis support
* few more tests fixed
* Fixed input checks conforming with TF for concat op and tests.
* few tests fixed
* matmul scalar shape fix
* Fixed checkout for data type and scalarity with concat to allow non-empty scalars with vector concats.
* broadcast bool fix
* few more tests
* few more tests
* correct evalReduceShapeInfoEmpty
* argmax/argmin + tests
* one more empty edge case + one more test
* argmax/argmin/realdiv_bp tweaks
* empty reshape test + fix
* Helper fixes
* Small fixes
* Gather test fix
* Gather test fix
* Small fixes
* reduce scalar zero values
* scalar mean workaround
* Remove debug code
* along dim mean workaround
* one more test
* - equalsTo() tweak for empty arrays
- one more test
* broadcast tweaks
* [WIP] Fixing outstanding issues for NLP (#9)
* Avoid using not-inited objects
* Test fixed.
* Redundant method avoided for models like FastText
* KMeans++ implementation
* KMeans++ implementation
* Disable parallel execution
* KMeans++
* Tests
* Dev branch merge (#16)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Fix some issues on master (#17)
* Fix DataVec test issue
* Fix issue with dl4j SameDiff output layer
* Dtype fix for lambda layers
* #7912 BertIterator dtype fix (use float32 not global default)
* [WIP] Next set of CUDA stuff (#7)
New CUDA implementations and improvements
* bad file
* Dev branch master merge (#23)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* SameDiff ops, TF import and fixes (#24)
* CheckNumerics tests + fixes + misc fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fake quant
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* FakeQuantWithMinMaxArgs
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* CheckNumerics fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix libnd4j ALL_INTS and ALL_FLOATS declaration (uint and bfloat types)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Small fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Javadoc
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Exception tweak
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix for out of scope stack allocated var use
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignores
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignore for known failing test (already logged issue)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Merge upstream to fork (#25)
* Add thousand-separator commas to TotalParams (#7915)
* Add thousand-separator commas to TotalParams
The number of parameters can be quite large, and it would help the reading of the summary printout to have the TotalParams column & values at the bottom have thousand-separator-commas in them.
* Add thousand-separator commas to MultiLayerNetwork
Corresponding change to MultiLayerNetwork
Signed-off-by: Jxtps Jxtps <jxtps435@gmail.com>
* Update contributing and issue/PR templates (#7934)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix link to AdaDelta paper (#7942)
Fix link to AdaDelta paper hosted on matthewzeiler.com
Signed-off-by: Jxtps
* Fixes, and ignores for known/logged failing issues (#7943)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* SameDiff + DL4J/SameDiff: Multiple fixes (#28)
* #7919 HDF5 attribute buffer length fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7909 Arbiter constructor exception ux improvements
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7925 RNN output layer length checks
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Add listener for validating inputs are not incorrectly modified
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Integrate NonInplaceValidationListener into tests
* #7844 DL4J SameDiff fixes for variable minibatch size
* DL4J SameDiff fixes - ensure gradient for input placeholder is available
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Tweaks to ExternalErrorsFunction - use placeholders, make more robust
* Another fix
* More fixes
* More SameDiff/DL4J fixes
* Scope out scalar array creation in BaseScalarOp
* Remove debug code
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] Final dev branch merge (#29)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* [WIP] Multiple dataset iterators (#27)
* Splitting dataset into arbitrary number
* Fixes
* Multiple split of iterator
* Test
* Test
* Some fixes
* signature change
* one more tweak
Signed-off-by: raver119 <raver119@gmail.com>
* one more test for sequential use of DataSetIteratorSplitter
Signed-off-by: raver119 <raver119@gmail.com>
* Fixes
* Fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* minor test fix
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* couple of assertions tweaked
Signed-off-by: raver119 <raver119@gmail.com>
* MDS splitter test :/
Signed-off-by: raver119 <raver119@gmail.com>
* Minor refactoring
* Multi dataset
* Some fixes
* More tests
* Small number of test fixes/improvements (failures on CI) (#31)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] More CUDA stuff (#26)
* initial commit
Signed-off-by: raver119 <raver119@gmail.com>
* LRN BP CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* less memory
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed bug with crop_and_resize op helper.
* get rid of unnecessary index-calculation dunction
Signed-off-by: Yurii <yurii@skymind.io>
* Fixed sort with nth_element cuda-based helper.
* Refactored nth_element.
* Refactored nth_element op and tests.
* Modified usage of dim array with sortTad routine.
* Refactored main routine of helper for non_max_image_suppression op.
* non_max_image_suppression op helper with cuda kernel implementation. Initial revision.
* fix vol2col cuda kernel
* meh
Signed-off-by: raver119 <raver119@gmail.com>
* topK concept
Signed-off-by: raver119 <raver119@gmail.com>
* unsorted topK with scanWitdh of 1
Signed-off-by: raver119 <raver119@gmail.com>
* correct vol2col tests
* sorted/unsorted topK
Signed-off-by: raver119 <raver119@gmail.com>
* implementation and fixing col2im/col2vol
* Corrected usage flags with input/output with reverse op.
* dup is const now
Signed-off-by: raver119 <raver119@gmail.com>
* percentile op
Signed-off-by: raver119 <raver119@gmail.com>
* group tests for mapool2d
Signed-off-by: Yurii <yurii@skymind.io>
* special test for george
Signed-off-by: raver119 <raver119@gmail.com>
* less threads for sortTad
Signed-off-by: raver119 <raver119@gmail.com>
* provide conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* remove auther in sort tad kernel code
Signed-off-by: Yurii <yurii@skymind.io>
* provide depthwise_conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* - max_pooling_with_argmax
- null check for special use
Signed-off-by: raver119 <raver119@gmail.com>
* dts cuda
Signed-off-by: raver119 <raver119@gmail.com>
* provide sconv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* std cuda
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op to conform TF implementation.
* Improved suppression helper.
* provide pooling3d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* more of minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* (bi)dynamic_rnn
Signed-off-by: raver119 <raver119@gmail.com>
* templates init order
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op.
* Added cuda kernel for non_max_suppression.
* CPU sort by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value tests
Signed-off-by: raver119 <raver119@gmail.com>
* Eliminate compiler error with cuda implementation.
* - repaired gradCheck in cuda
- provide conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* missed signature
Signed-off-by: raver119 <raver119@gmail.com>
* provide depthwise_conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* Implementation of lup helper with cuda kernel. Initial commit.
* further work on backprops for convolutions
Signed-off-by: Yurii <yurii@skymind.io>
* CUDA linear sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* CUDA tad sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* start providing of backprop for pooling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* Added atomicAdd for bool datatype.
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition scalar CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* important comment
Signed-off-by: raver119 <raver119@gmail.com>
* fix pooling2d/3d backprop helpers
Signed-off-by: Yurii <yurii@skymind.io>
* Added non-linear test with dynamic_partition.
* Improved test for dynamic_partition.
* dynamic_partition TAD concept
Signed-off-by: raver119 <raver119@gmail.com>
* - dynamic_partition TAD CUDA impl
- dynamic_partition TAD CPU fix
Signed-off-by: raver119 <raver119@gmail.com>
* - rewrite cpu code for usampling2d/3d
- write cuda code for usampling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* dynamic_stitch CUDA vector case
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case impl
Signed-off-by: raver119 <raver119@gmail.com>
* Added tests for dynamic_stitch 3D-4D cases.
* minor tests tweaks
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed type check for dynamic stitch.
* min/max bp
Signed-off-by: raver119 <raver119@gmail.com>
* rewrite code for upsampling2d/3d cpu
Signed-off-by: Yurii <yurii@skymind.io>
* reduce min/max/norm_max bp
Signed-off-by: raver119 <raver119@gmail.com>
* lup implementation. Additional enhancements.
* provide code for upsamling2d/3d backprop
Signed-off-by: Yurii <yurii@skymind.io>
* weightedCrossEntropyWithLogits
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed template math atomicMul for 64bit ints.
* Refactored dynamic_partition_bp op.
* inverseBroadcast fix
Signed-off-by: raver119 <raver119@gmail.com>
* DynamicPartitionBP test datatype fixed.
* - nd4j_atomicMul Windows fix
- cpu/NDArrayLambda.hpp excluded from CUDA
Signed-off-by: raver119 <raver119@gmail.com>
2019-06-27 17:37:04 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void sortByValue(Nd4jPointer *extraPointers,
|
2020-05-09 07:06:14 +02:00
|
|
|
void *x, const Nd4jLong *xShapeInfo,
|
|
|
|
void *dx, const Nd4jLong *dxShapeInfo,
|
|
|
|
void *y, const Nd4jLong *yShapeInfo,
|
|
|
|
void *dy, const Nd4jLong *dyShapeInfo,
|
Merge master to upstream (#7945)
* Shugeo strided slice zeros (#14)
* Modified strided_slice op to properly work with empty-like shapes.
* Fixed test for reduce_mean with empty-like input.
* [WIP] Last merge (#15)
* correct logsoftmax looss (#2)
* Small SameDiff listener fix (#4)
* Various fixes (#6)
* #7839 Fix for asXMatrix and tests
* #7866 EmbeddingSequenceLayer dtype fix + test
* #7856 SameDiff save/load stream methods
* #7859 RegressionEvaluation rank 4 fix + tests + axis configuration
* EvaluationBinary 3d/4d
* More evaluation 3d/4d tests
* #7847 Evaluation empty checks
* Small test ifx
* #7848 Fix median edge case
* Improve DL4J samediff layer tests
* [WIP] FastText wrapper implemented (#8)
* FastText implemented
* Some fixes
* Fix shapes for wordsNearest
* Validation of input vectors
* Fixes
* Fixed test
* Thread tagged
* Some tweaks
* setContextClassLoader for DeallocatorServiceThread
* Numpy format tests (#1)
* Various fixes (#11)
* #7852 SameDiff gather fix
* #7892 SameDiff placeholder to constant conversion
* #7890 validate input rank for MLN/CG init methods
* Fix broken permute shape calculation
* Permute and gather fixes
* Tests
* #7850 LogSumExp fix + test
* Handful of test fixes
* Empty arrays with non-scalar shapes (#10)
* minor rearrangements for lambdas
* empty tensors with non-scalar shapes
* numpy empty tensors with non-scalar shapes
* few more empty tweaks
* Small fixes
* conv3d signature update
* micro fix in batchnorm mkldnn
* Import fixes
* Fix
* MKL-DNN update
* Small fill fix
* fill with empty input + test
* Fixes
* Small error improvement
* Fix
* one special test
* couple of fixes for lstm
* Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone
* Fixes
* FP16
* Unsigned
* BFloat16
* Fill op - empty tweaks
* - couple of fixes for empty arrays construction
- stack updated
* strided slice fix
* one transform test
* provide method for reducing shapeInfo in case of input array is empty
* Fixed reduceAlongDimensions to use empty input properly.
* couple of broadcast tests
* couple of tests broadcast tests + tweak to make them pass
* add check of non-empty to methods producing sub-arrays
* Fixed reshapeC with zeros in shape.
* complete empty check in reduce_... legacy ops
* Concat and cumsum/prod
* Tweak to empty shape inference on import
* add empty check to the rest of reduce legacy ops
* one more test
* correct typo in evalReduceShapeInfoEmpty
* Added tests for reduce_* ops to tests with zero shapes.
* few more tests for empty reductions
* Fixed strided_slice op with empty case and tests.
* one more empty reduction test
* Fixed strided_slice test.
* add empty check to NDArray::reshapei
* infOrMax
* empty min/max with infinity tests
* made unstack working correctly with empty arrays
* few IndexReduce tests + tweaks for empty shapes
* add test for empty concat
* few tests fixed
* Validation fix for reductions on empty shapes
* Reverse fix
* Reduction shape calc fixes
* SameDiff.generateOutputVariable: don't use shape function to determine number of outputs
* Range fix
* - NDArray constructor updated for scalars/empty arrays
- few tests fixed
* More fixes
* Empty creator fixes
* concat fix
* concat fix
* TF import tests: allow 'both all NaN' and 'both all inf' to pass
* Slice, zero fraction, and reshape fixes
* transpose, gather
* Zero fraction
* scalar cast fix
* Empty reduction axis support
* few more tests fixed
* Fixed input checks conforming with TF for concat op and tests.
* few tests fixed
* matmul scalar shape fix
* Fixed checkout for data type and scalarity with concat to allow non-empty scalars with vector concats.
* broadcast bool fix
* few more tests
* few more tests
* correct evalReduceShapeInfoEmpty
* argmax/argmin + tests
* one more empty edge case + one more test
* argmax/argmin/realdiv_bp tweaks
* empty reshape test + fix
* Helper fixes
* Small fixes
* Gather test fix
* Gather test fix
* Small fixes
* reduce scalar zero values
* scalar mean workaround
* Remove debug code
* along dim mean workaround
* one more test
* - equalsTo() tweak for empty arrays
- one more test
* broadcast tweaks
* [WIP] Fixing outstanding issues for NLP (#9)
* Avoid using not-inited objects
* Test fixed.
* Redundant method avoided for models like FastText
* KMeans++ implementation
* KMeans++ implementation
* Disable parallel execution
* KMeans++
* Tests
* Dev branch merge (#16)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Fix some issues on master (#17)
* Fix DataVec test issue
* Fix issue with dl4j SameDiff output layer
* Dtype fix for lambda layers
* #7912 BertIterator dtype fix (use float32 not global default)
* [WIP] Next set of CUDA stuff (#7)
New CUDA implementations and improvements
* bad file
* Dev branch master merge (#23)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* SameDiff ops, TF import and fixes (#24)
* CheckNumerics tests + fixes + misc fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fake quant
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* FakeQuantWithMinMaxArgs
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* CheckNumerics fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix libnd4j ALL_INTS and ALL_FLOATS declaration (uint and bfloat types)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Small fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Javadoc
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Exception tweak
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix for out of scope stack allocated var use
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignores
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignore for known failing test (already logged issue)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Merge upstream to fork (#25)
* Add thousand-separator commas to TotalParams (#7915)
* Add thousand-separator commas to TotalParams
The number of parameters can be quite large, and it would help the reading of the summary printout to have the TotalParams column & values at the bottom have thousand-separator-commas in them.
* Add thousand-separator commas to MultiLayerNetwork
Corresponding change to MultiLayerNetwork
Signed-off-by: Jxtps Jxtps <jxtps435@gmail.com>
* Update contributing and issue/PR templates (#7934)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix link to AdaDelta paper (#7942)
Fix link to AdaDelta paper hosted on matthewzeiler.com
Signed-off-by: Jxtps
* Fixes, and ignores for known/logged failing issues (#7943)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* SameDiff + DL4J/SameDiff: Multiple fixes (#28)
* #7919 HDF5 attribute buffer length fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7909 Arbiter constructor exception ux improvements
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7925 RNN output layer length checks
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Add listener for validating inputs are not incorrectly modified
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Integrate NonInplaceValidationListener into tests
* #7844 DL4J SameDiff fixes for variable minibatch size
* DL4J SameDiff fixes - ensure gradient for input placeholder is available
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Tweaks to ExternalErrorsFunction - use placeholders, make more robust
* Another fix
* More fixes
* More SameDiff/DL4J fixes
* Scope out scalar array creation in BaseScalarOp
* Remove debug code
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] Final dev branch merge (#29)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* [WIP] Multiple dataset iterators (#27)
* Splitting dataset into arbitrary number
* Fixes
* Multiple split of iterator
* Test
* Test
* Some fixes
* signature change
* one more tweak
Signed-off-by: raver119 <raver119@gmail.com>
* one more test for sequential use of DataSetIteratorSplitter
Signed-off-by: raver119 <raver119@gmail.com>
* Fixes
* Fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* minor test fix
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* couple of assertions tweaked
Signed-off-by: raver119 <raver119@gmail.com>
* MDS splitter test :/
Signed-off-by: raver119 <raver119@gmail.com>
* Minor refactoring
* Multi dataset
* Some fixes
* More tests
* Small number of test fixes/improvements (failures on CI) (#31)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] More CUDA stuff (#26)
* initial commit
Signed-off-by: raver119 <raver119@gmail.com>
* LRN BP CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* less memory
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed bug with crop_and_resize op helper.
* get rid of unnecessary index-calculation dunction
Signed-off-by: Yurii <yurii@skymind.io>
* Fixed sort with nth_element cuda-based helper.
* Refactored nth_element.
* Refactored nth_element op and tests.
* Modified usage of dim array with sortTad routine.
* Refactored main routine of helper for non_max_image_suppression op.
* non_max_image_suppression op helper with cuda kernel implementation. Initial revision.
* fix vol2col cuda kernel
* meh
Signed-off-by: raver119 <raver119@gmail.com>
* topK concept
Signed-off-by: raver119 <raver119@gmail.com>
* unsorted topK with scanWitdh of 1
Signed-off-by: raver119 <raver119@gmail.com>
* correct vol2col tests
* sorted/unsorted topK
Signed-off-by: raver119 <raver119@gmail.com>
* implementation and fixing col2im/col2vol
* Corrected usage flags with input/output with reverse op.
* dup is const now
Signed-off-by: raver119 <raver119@gmail.com>
* percentile op
Signed-off-by: raver119 <raver119@gmail.com>
* group tests for mapool2d
Signed-off-by: Yurii <yurii@skymind.io>
* special test for george
Signed-off-by: raver119 <raver119@gmail.com>
* less threads for sortTad
Signed-off-by: raver119 <raver119@gmail.com>
* provide conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* remove auther in sort tad kernel code
Signed-off-by: Yurii <yurii@skymind.io>
* provide depthwise_conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* - max_pooling_with_argmax
- null check for special use
Signed-off-by: raver119 <raver119@gmail.com>
* dts cuda
Signed-off-by: raver119 <raver119@gmail.com>
* provide sconv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* std cuda
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op to conform TF implementation.
* Improved suppression helper.
* provide pooling3d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* more of minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* (bi)dynamic_rnn
Signed-off-by: raver119 <raver119@gmail.com>
* templates init order
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op.
* Added cuda kernel for non_max_suppression.
* CPU sort by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value tests
Signed-off-by: raver119 <raver119@gmail.com>
* Eliminate compiler error with cuda implementation.
* - repaired gradCheck in cuda
- provide conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* missed signature
Signed-off-by: raver119 <raver119@gmail.com>
* provide depthwise_conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* Implementation of lup helper with cuda kernel. Initial commit.
* further work on backprops for convolutions
Signed-off-by: Yurii <yurii@skymind.io>
* CUDA linear sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* CUDA tad sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* start providing of backprop for pooling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* Added atomicAdd for bool datatype.
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition scalar CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* important comment
Signed-off-by: raver119 <raver119@gmail.com>
* fix pooling2d/3d backprop helpers
Signed-off-by: Yurii <yurii@skymind.io>
* Added non-linear test with dynamic_partition.
* Improved test for dynamic_partition.
* dynamic_partition TAD concept
Signed-off-by: raver119 <raver119@gmail.com>
* - dynamic_partition TAD CUDA impl
- dynamic_partition TAD CPU fix
Signed-off-by: raver119 <raver119@gmail.com>
* - rewrite cpu code for usampling2d/3d
- write cuda code for usampling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* dynamic_stitch CUDA vector case
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case impl
Signed-off-by: raver119 <raver119@gmail.com>
* Added tests for dynamic_stitch 3D-4D cases.
* minor tests tweaks
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed type check for dynamic stitch.
* min/max bp
Signed-off-by: raver119 <raver119@gmail.com>
* rewrite code for upsampling2d/3d cpu
Signed-off-by: Yurii <yurii@skymind.io>
* reduce min/max/norm_max bp
Signed-off-by: raver119 <raver119@gmail.com>
* lup implementation. Additional enhancements.
* provide code for upsamling2d/3d backprop
Signed-off-by: Yurii <yurii@skymind.io>
* weightedCrossEntropyWithLogits
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed template math atomicMul for 64bit ints.
* Refactored dynamic_partition_bp op.
* inverseBroadcast fix
Signed-off-by: raver119 <raver119@gmail.com>
* DynamicPartitionBP test datatype fixed.
* - nd4j_atomicMul Windows fix
- cpu/NDArrayLambda.hpp excluded from CUDA
Signed-off-by: raver119 <raver119@gmail.com>
2019-06-27 17:37:04 +02:00
|
|
|
bool descending) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
|
|
|
auto xType = ArrayOptions::dataType(xShapeInfo);
|
|
|
|
auto yType = ArrayOptions::dataType(yShapeInfo);
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
BUILD_DOUBLE_SELECTOR(xType, yType, sd::DoubleMethods, ::sortByValue(x, xShapeInfo, y, yShapeInfo, descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
2019-08-26 18:57:51 +02:00
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
Merge master to upstream (#7945)
* Shugeo strided slice zeros (#14)
* Modified strided_slice op to properly work with empty-like shapes.
* Fixed test for reduce_mean with empty-like input.
* [WIP] Last merge (#15)
* correct logsoftmax looss (#2)
* Small SameDiff listener fix (#4)
* Various fixes (#6)
* #7839 Fix for asXMatrix and tests
* #7866 EmbeddingSequenceLayer dtype fix + test
* #7856 SameDiff save/load stream methods
* #7859 RegressionEvaluation rank 4 fix + tests + axis configuration
* EvaluationBinary 3d/4d
* More evaluation 3d/4d tests
* #7847 Evaluation empty checks
* Small test ifx
* #7848 Fix median edge case
* Improve DL4J samediff layer tests
* [WIP] FastText wrapper implemented (#8)
* FastText implemented
* Some fixes
* Fix shapes for wordsNearest
* Validation of input vectors
* Fixes
* Fixed test
* Thread tagged
* Some tweaks
* setContextClassLoader for DeallocatorServiceThread
* Numpy format tests (#1)
* Various fixes (#11)
* #7852 SameDiff gather fix
* #7892 SameDiff placeholder to constant conversion
* #7890 validate input rank for MLN/CG init methods
* Fix broken permute shape calculation
* Permute and gather fixes
* Tests
* #7850 LogSumExp fix + test
* Handful of test fixes
* Empty arrays with non-scalar shapes (#10)
* minor rearrangements for lambdas
* empty tensors with non-scalar shapes
* numpy empty tensors with non-scalar shapes
* few more empty tweaks
* Small fixes
* conv3d signature update
* micro fix in batchnorm mkldnn
* Import fixes
* Fix
* MKL-DNN update
* Small fill fix
* fill with empty input + test
* Fixes
* Small error improvement
* Fix
* one special test
* couple of fixes for lstm
* Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone
* Fixes
* FP16
* Unsigned
* BFloat16
* Fill op - empty tweaks
* - couple of fixes for empty arrays construction
- stack updated
* strided slice fix
* one transform test
* provide method for reducing shapeInfo in case of input array is empty
* Fixed reduceAlongDimensions to use empty input properly.
* couple of broadcast tests
* couple of tests broadcast tests + tweak to make them pass
* add check of non-empty to methods producing sub-arrays
* Fixed reshapeC with zeros in shape.
* complete empty check in reduce_... legacy ops
* Concat and cumsum/prod
* Tweak to empty shape inference on import
* add empty check to the rest of reduce legacy ops
* one more test
* correct typo in evalReduceShapeInfoEmpty
* Added tests for reduce_* ops to tests with zero shapes.
* few more tests for empty reductions
* Fixed strided_slice op with empty case and tests.
* one more empty reduction test
* Fixed strided_slice test.
* add empty check to NDArray::reshapei
* infOrMax
* empty min/max with infinity tests
* made unstack working correctly with empty arrays
* few IndexReduce tests + tweaks for empty shapes
* add test for empty concat
* few tests fixed
* Validation fix for reductions on empty shapes
* Reverse fix
* Reduction shape calc fixes
* SameDiff.generateOutputVariable: don't use shape function to determine number of outputs
* Range fix
* - NDArray constructor updated for scalars/empty arrays
- few tests fixed
* More fixes
* Empty creator fixes
* concat fix
* concat fix
* TF import tests: allow 'both all NaN' and 'both all inf' to pass
* Slice, zero fraction, and reshape fixes
* transpose, gather
* Zero fraction
* scalar cast fix
* Empty reduction axis support
* few more tests fixed
* Fixed input checks conforming with TF for concat op and tests.
* few tests fixed
* matmul scalar shape fix
* Fixed checkout for data type and scalarity with concat to allow non-empty scalars with vector concats.
* broadcast bool fix
* few more tests
* few more tests
* correct evalReduceShapeInfoEmpty
* argmax/argmin + tests
* one more empty edge case + one more test
* argmax/argmin/realdiv_bp tweaks
* empty reshape test + fix
* Helper fixes
* Small fixes
* Gather test fix
* Gather test fix
* Small fixes
* reduce scalar zero values
* scalar mean workaround
* Remove debug code
* along dim mean workaround
* one more test
* - equalsTo() tweak for empty arrays
- one more test
* broadcast tweaks
* [WIP] Fixing outstanding issues for NLP (#9)
* Avoid using not-inited objects
* Test fixed.
* Redundant method avoided for models like FastText
* KMeans++ implementation
* KMeans++ implementation
* Disable parallel execution
* KMeans++
* Tests
* Dev branch merge (#16)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Fix some issues on master (#17)
* Fix DataVec test issue
* Fix issue with dl4j SameDiff output layer
* Dtype fix for lambda layers
* #7912 BertIterator dtype fix (use float32 not global default)
* [WIP] Next set of CUDA stuff (#7)
New CUDA implementations and improvements
* bad file
* Dev branch master merge (#23)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* SameDiff ops, TF import and fixes (#24)
* CheckNumerics tests + fixes + misc fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fake quant
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* FakeQuantWithMinMaxArgs
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* CheckNumerics fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix libnd4j ALL_INTS and ALL_FLOATS declaration (uint and bfloat types)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Small fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Javadoc
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Exception tweak
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix for out of scope stack allocated var use
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignores
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignore for known failing test (already logged issue)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Merge upstream to fork (#25)
* Add thousand-separator commas to TotalParams (#7915)
* Add thousand-separator commas to TotalParams
The number of parameters can be quite large, and it would help the reading of the summary printout to have the TotalParams column & values at the bottom have thousand-separator-commas in them.
* Add thousand-separator commas to MultiLayerNetwork
Corresponding change to MultiLayerNetwork
Signed-off-by: Jxtps Jxtps <jxtps435@gmail.com>
* Update contributing and issue/PR templates (#7934)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix link to AdaDelta paper (#7942)
Fix link to AdaDelta paper hosted on matthewzeiler.com
Signed-off-by: Jxtps
* Fixes, and ignores for known/logged failing issues (#7943)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* SameDiff + DL4J/SameDiff: Multiple fixes (#28)
* #7919 HDF5 attribute buffer length fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7909 Arbiter constructor exception ux improvements
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7925 RNN output layer length checks
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Add listener for validating inputs are not incorrectly modified
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Integrate NonInplaceValidationListener into tests
* #7844 DL4J SameDiff fixes for variable minibatch size
* DL4J SameDiff fixes - ensure gradient for input placeholder is available
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Tweaks to ExternalErrorsFunction - use placeholders, make more robust
* Another fix
* More fixes
* More SameDiff/DL4J fixes
* Scope out scalar array creation in BaseScalarOp
* Remove debug code
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] Final dev branch merge (#29)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* [WIP] Multiple dataset iterators (#27)
* Splitting dataset into arbitrary number
* Fixes
* Multiple split of iterator
* Test
* Test
* Some fixes
* signature change
* one more tweak
Signed-off-by: raver119 <raver119@gmail.com>
* one more test for sequential use of DataSetIteratorSplitter
Signed-off-by: raver119 <raver119@gmail.com>
* Fixes
* Fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* minor test fix
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* couple of assertions tweaked
Signed-off-by: raver119 <raver119@gmail.com>
* MDS splitter test :/
Signed-off-by: raver119 <raver119@gmail.com>
* Minor refactoring
* Multi dataset
* Some fixes
* More tests
* Small number of test fixes/improvements (failures on CI) (#31)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] More CUDA stuff (#26)
* initial commit
Signed-off-by: raver119 <raver119@gmail.com>
* LRN BP CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* less memory
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed bug with crop_and_resize op helper.
* get rid of unnecessary index-calculation dunction
Signed-off-by: Yurii <yurii@skymind.io>
* Fixed sort with nth_element cuda-based helper.
* Refactored nth_element.
* Refactored nth_element op and tests.
* Modified usage of dim array with sortTad routine.
* Refactored main routine of helper for non_max_image_suppression op.
* non_max_image_suppression op helper with cuda kernel implementation. Initial revision.
* fix vol2col cuda kernel
* meh
Signed-off-by: raver119 <raver119@gmail.com>
* topK concept
Signed-off-by: raver119 <raver119@gmail.com>
* unsorted topK with scanWitdh of 1
Signed-off-by: raver119 <raver119@gmail.com>
* correct vol2col tests
* sorted/unsorted topK
Signed-off-by: raver119 <raver119@gmail.com>
* implementation and fixing col2im/col2vol
* Corrected usage flags with input/output with reverse op.
* dup is const now
Signed-off-by: raver119 <raver119@gmail.com>
* percentile op
Signed-off-by: raver119 <raver119@gmail.com>
* group tests for mapool2d
Signed-off-by: Yurii <yurii@skymind.io>
* special test for george
Signed-off-by: raver119 <raver119@gmail.com>
* less threads for sortTad
Signed-off-by: raver119 <raver119@gmail.com>
* provide conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* remove auther in sort tad kernel code
Signed-off-by: Yurii <yurii@skymind.io>
* provide depthwise_conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* - max_pooling_with_argmax
- null check for special use
Signed-off-by: raver119 <raver119@gmail.com>
* dts cuda
Signed-off-by: raver119 <raver119@gmail.com>
* provide sconv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* std cuda
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op to conform TF implementation.
* Improved suppression helper.
* provide pooling3d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* more of minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* (bi)dynamic_rnn
Signed-off-by: raver119 <raver119@gmail.com>
* templates init order
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op.
* Added cuda kernel for non_max_suppression.
* CPU sort by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value tests
Signed-off-by: raver119 <raver119@gmail.com>
* Eliminate compiler error with cuda implementation.
* - repaired gradCheck in cuda
- provide conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* missed signature
Signed-off-by: raver119 <raver119@gmail.com>
* provide depthwise_conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* Implementation of lup helper with cuda kernel. Initial commit.
* further work on backprops for convolutions
Signed-off-by: Yurii <yurii@skymind.io>
* CUDA linear sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* CUDA tad sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* start providing of backprop for pooling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* Added atomicAdd for bool datatype.
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition scalar CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* important comment
Signed-off-by: raver119 <raver119@gmail.com>
* fix pooling2d/3d backprop helpers
Signed-off-by: Yurii <yurii@skymind.io>
* Added non-linear test with dynamic_partition.
* Improved test for dynamic_partition.
* dynamic_partition TAD concept
Signed-off-by: raver119 <raver119@gmail.com>
* - dynamic_partition TAD CUDA impl
- dynamic_partition TAD CPU fix
Signed-off-by: raver119 <raver119@gmail.com>
* - rewrite cpu code for usampling2d/3d
- write cuda code for usampling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* dynamic_stitch CUDA vector case
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case impl
Signed-off-by: raver119 <raver119@gmail.com>
* Added tests for dynamic_stitch 3D-4D cases.
* minor tests tweaks
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed type check for dynamic stitch.
* min/max bp
Signed-off-by: raver119 <raver119@gmail.com>
* rewrite code for upsampling2d/3d cpu
Signed-off-by: Yurii <yurii@skymind.io>
* reduce min/max/norm_max bp
Signed-off-by: raver119 <raver119@gmail.com>
* lup implementation. Additional enhancements.
* provide code for upsamling2d/3d backprop
Signed-off-by: Yurii <yurii@skymind.io>
* weightedCrossEntropyWithLogits
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed template math atomicMul for 64bit ints.
* Refactored dynamic_partition_bp op.
* inverseBroadcast fix
Signed-off-by: raver119 <raver119@gmail.com>
* DynamicPartitionBP test datatype fixed.
* - nd4j_atomicMul Windows fix
- cpu/NDArrayLambda.hpp excluded from CUDA
Signed-off-by: raver119 <raver119@gmail.com>
2019-06-27 17:37:04 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void sortTadByKey(Nd4jPointer *extraPointers,
|
2020-05-09 07:06:14 +02:00
|
|
|
void *x, const Nd4jLong *xShapeInfo,
|
|
|
|
void *dx, const Nd4jLong *dxShapeInfo,
|
|
|
|
void *y, const Nd4jLong *yShapeInfo,
|
|
|
|
void *dy, const Nd4jLong *dyShapeInfo,
|
|
|
|
int *dimension, int dimensionLength,
|
Merge master to upstream (#7945)
* Shugeo strided slice zeros (#14)
* Modified strided_slice op to properly work with empty-like shapes.
* Fixed test for reduce_mean with empty-like input.
* [WIP] Last merge (#15)
* correct logsoftmax looss (#2)
* Small SameDiff listener fix (#4)
* Various fixes (#6)
* #7839 Fix for asXMatrix and tests
* #7866 EmbeddingSequenceLayer dtype fix + test
* #7856 SameDiff save/load stream methods
* #7859 RegressionEvaluation rank 4 fix + tests + axis configuration
* EvaluationBinary 3d/4d
* More evaluation 3d/4d tests
* #7847 Evaluation empty checks
* Small test ifx
* #7848 Fix median edge case
* Improve DL4J samediff layer tests
* [WIP] FastText wrapper implemented (#8)
* FastText implemented
* Some fixes
* Fix shapes for wordsNearest
* Validation of input vectors
* Fixes
* Fixed test
* Thread tagged
* Some tweaks
* setContextClassLoader for DeallocatorServiceThread
* Numpy format tests (#1)
* Various fixes (#11)
* #7852 SameDiff gather fix
* #7892 SameDiff placeholder to constant conversion
* #7890 validate input rank for MLN/CG init methods
* Fix broken permute shape calculation
* Permute and gather fixes
* Tests
* #7850 LogSumExp fix + test
* Handful of test fixes
* Empty arrays with non-scalar shapes (#10)
* minor rearrangements for lambdas
* empty tensors with non-scalar shapes
* numpy empty tensors with non-scalar shapes
* few more empty tweaks
* Small fixes
* conv3d signature update
* micro fix in batchnorm mkldnn
* Import fixes
* Fix
* MKL-DNN update
* Small fill fix
* fill with empty input + test
* Fixes
* Small error improvement
* Fix
* one special test
* couple of fixes for lstm
* Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone
* Fixes
* FP16
* Unsigned
* BFloat16
* Fill op - empty tweaks
* - couple of fixes for empty arrays construction
- stack updated
* strided slice fix
* one transform test
* provide method for reducing shapeInfo in case of input array is empty
* Fixed reduceAlongDimensions to use empty input properly.
* couple of broadcast tests
* couple of tests broadcast tests + tweak to make them pass
* add check of non-empty to methods producing sub-arrays
* Fixed reshapeC with zeros in shape.
* complete empty check in reduce_... legacy ops
* Concat and cumsum/prod
* Tweak to empty shape inference on import
* add empty check to the rest of reduce legacy ops
* one more test
* correct typo in evalReduceShapeInfoEmpty
* Added tests for reduce_* ops to tests with zero shapes.
* few more tests for empty reductions
* Fixed strided_slice op with empty case and tests.
* one more empty reduction test
* Fixed strided_slice test.
* add empty check to NDArray::reshapei
* infOrMax
* empty min/max with infinity tests
* made unstack working correctly with empty arrays
* few IndexReduce tests + tweaks for empty shapes
* add test for empty concat
* few tests fixed
* Validation fix for reductions on empty shapes
* Reverse fix
* Reduction shape calc fixes
* SameDiff.generateOutputVariable: don't use shape function to determine number of outputs
* Range fix
* - NDArray constructor updated for scalars/empty arrays
- few tests fixed
* More fixes
* Empty creator fixes
* concat fix
* concat fix
* TF import tests: allow 'both all NaN' and 'both all inf' to pass
* Slice, zero fraction, and reshape fixes
* transpose, gather
* Zero fraction
* scalar cast fix
* Empty reduction axis support
* few more tests fixed
* Fixed input checks conforming with TF for concat op and tests.
* few tests fixed
* matmul scalar shape fix
* Fixed checkout for data type and scalarity with concat to allow non-empty scalars with vector concats.
* broadcast bool fix
* few more tests
* few more tests
* correct evalReduceShapeInfoEmpty
* argmax/argmin + tests
* one more empty edge case + one more test
* argmax/argmin/realdiv_bp tweaks
* empty reshape test + fix
* Helper fixes
* Small fixes
* Gather test fix
* Gather test fix
* Small fixes
* reduce scalar zero values
* scalar mean workaround
* Remove debug code
* along dim mean workaround
* one more test
* - equalsTo() tweak for empty arrays
- one more test
* broadcast tweaks
* [WIP] Fixing outstanding issues for NLP (#9)
* Avoid using not-inited objects
* Test fixed.
* Redundant method avoided for models like FastText
* KMeans++ implementation
* KMeans++ implementation
* Disable parallel execution
* KMeans++
* Tests
* Dev branch merge (#16)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Fix some issues on master (#17)
* Fix DataVec test issue
* Fix issue with dl4j SameDiff output layer
* Dtype fix for lambda layers
* #7912 BertIterator dtype fix (use float32 not global default)
* [WIP] Next set of CUDA stuff (#7)
New CUDA implementations and improvements
* bad file
* Dev branch master merge (#23)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* SameDiff ops, TF import and fixes (#24)
* CheckNumerics tests + fixes + misc fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fake quant
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* FakeQuantWithMinMaxArgs
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* CheckNumerics fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix libnd4j ALL_INTS and ALL_FLOATS declaration (uint and bfloat types)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Small fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Javadoc
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Exception tweak
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix for out of scope stack allocated var use
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignores
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignore for known failing test (already logged issue)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Merge upstream to fork (#25)
* Add thousand-separator commas to TotalParams (#7915)
* Add thousand-separator commas to TotalParams
The number of parameters can be quite large, and it would help the reading of the summary printout to have the TotalParams column & values at the bottom have thousand-separator-commas in them.
* Add thousand-separator commas to MultiLayerNetwork
Corresponding change to MultiLayerNetwork
Signed-off-by: Jxtps Jxtps <jxtps435@gmail.com>
* Update contributing and issue/PR templates (#7934)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix link to AdaDelta paper (#7942)
Fix link to AdaDelta paper hosted on matthewzeiler.com
Signed-off-by: Jxtps
* Fixes, and ignores for known/logged failing issues (#7943)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* SameDiff + DL4J/SameDiff: Multiple fixes (#28)
* #7919 HDF5 attribute buffer length fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7909 Arbiter constructor exception ux improvements
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7925 RNN output layer length checks
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Add listener for validating inputs are not incorrectly modified
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Integrate NonInplaceValidationListener into tests
* #7844 DL4J SameDiff fixes for variable minibatch size
* DL4J SameDiff fixes - ensure gradient for input placeholder is available
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Tweaks to ExternalErrorsFunction - use placeholders, make more robust
* Another fix
* More fixes
* More SameDiff/DL4J fixes
* Scope out scalar array creation in BaseScalarOp
* Remove debug code
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] Final dev branch merge (#29)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* [WIP] Multiple dataset iterators (#27)
* Splitting dataset into arbitrary number
* Fixes
* Multiple split of iterator
* Test
* Test
* Some fixes
* signature change
* one more tweak
Signed-off-by: raver119 <raver119@gmail.com>
* one more test for sequential use of DataSetIteratorSplitter
Signed-off-by: raver119 <raver119@gmail.com>
* Fixes
* Fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* minor test fix
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* couple of assertions tweaked
Signed-off-by: raver119 <raver119@gmail.com>
* MDS splitter test :/
Signed-off-by: raver119 <raver119@gmail.com>
* Minor refactoring
* Multi dataset
* Some fixes
* More tests
* Small number of test fixes/improvements (failures on CI) (#31)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] More CUDA stuff (#26)
* initial commit
Signed-off-by: raver119 <raver119@gmail.com>
* LRN BP CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* less memory
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed bug with crop_and_resize op helper.
* get rid of unnecessary index-calculation dunction
Signed-off-by: Yurii <yurii@skymind.io>
* Fixed sort with nth_element cuda-based helper.
* Refactored nth_element.
* Refactored nth_element op and tests.
* Modified usage of dim array with sortTad routine.
* Refactored main routine of helper for non_max_image_suppression op.
* non_max_image_suppression op helper with cuda kernel implementation. Initial revision.
* fix vol2col cuda kernel
* meh
Signed-off-by: raver119 <raver119@gmail.com>
* topK concept
Signed-off-by: raver119 <raver119@gmail.com>
* unsorted topK with scanWitdh of 1
Signed-off-by: raver119 <raver119@gmail.com>
* correct vol2col tests
* sorted/unsorted topK
Signed-off-by: raver119 <raver119@gmail.com>
* implementation and fixing col2im/col2vol
* Corrected usage flags with input/output with reverse op.
* dup is const now
Signed-off-by: raver119 <raver119@gmail.com>
* percentile op
Signed-off-by: raver119 <raver119@gmail.com>
* group tests for mapool2d
Signed-off-by: Yurii <yurii@skymind.io>
* special test for george
Signed-off-by: raver119 <raver119@gmail.com>
* less threads for sortTad
Signed-off-by: raver119 <raver119@gmail.com>
* provide conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* remove auther in sort tad kernel code
Signed-off-by: Yurii <yurii@skymind.io>
* provide depthwise_conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* - max_pooling_with_argmax
- null check for special use
Signed-off-by: raver119 <raver119@gmail.com>
* dts cuda
Signed-off-by: raver119 <raver119@gmail.com>
* provide sconv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* std cuda
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op to conform TF implementation.
* Improved suppression helper.
* provide pooling3d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* more of minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* (bi)dynamic_rnn
Signed-off-by: raver119 <raver119@gmail.com>
* templates init order
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op.
* Added cuda kernel for non_max_suppression.
* CPU sort by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value tests
Signed-off-by: raver119 <raver119@gmail.com>
* Eliminate compiler error with cuda implementation.
* - repaired gradCheck in cuda
- provide conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* missed signature
Signed-off-by: raver119 <raver119@gmail.com>
* provide depthwise_conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* Implementation of lup helper with cuda kernel. Initial commit.
* further work on backprops for convolutions
Signed-off-by: Yurii <yurii@skymind.io>
* CUDA linear sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* CUDA tad sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* start providing of backprop for pooling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* Added atomicAdd for bool datatype.
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition scalar CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* important comment
Signed-off-by: raver119 <raver119@gmail.com>
* fix pooling2d/3d backprop helpers
Signed-off-by: Yurii <yurii@skymind.io>
* Added non-linear test with dynamic_partition.
* Improved test for dynamic_partition.
* dynamic_partition TAD concept
Signed-off-by: raver119 <raver119@gmail.com>
* - dynamic_partition TAD CUDA impl
- dynamic_partition TAD CPU fix
Signed-off-by: raver119 <raver119@gmail.com>
* - rewrite cpu code for usampling2d/3d
- write cuda code for usampling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* dynamic_stitch CUDA vector case
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case impl
Signed-off-by: raver119 <raver119@gmail.com>
* Added tests for dynamic_stitch 3D-4D cases.
* minor tests tweaks
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed type check for dynamic stitch.
* min/max bp
Signed-off-by: raver119 <raver119@gmail.com>
* rewrite code for upsampling2d/3d cpu
Signed-off-by: Yurii <yurii@skymind.io>
* reduce min/max/norm_max bp
Signed-off-by: raver119 <raver119@gmail.com>
* lup implementation. Additional enhancements.
* provide code for upsamling2d/3d backprop
Signed-off-by: Yurii <yurii@skymind.io>
* weightedCrossEntropyWithLogits
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed template math atomicMul for 64bit ints.
* Refactored dynamic_partition_bp op.
* inverseBroadcast fix
Signed-off-by: raver119 <raver119@gmail.com>
* DynamicPartitionBP test datatype fixed.
* - nd4j_atomicMul Windows fix
- cpu/NDArrayLambda.hpp excluded from CUDA
Signed-off-by: raver119 <raver119@gmail.com>
2019-06-27 17:37:04 +02:00
|
|
|
bool descending) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
|
|
|
auto xType = ArrayOptions::dataType(xShapeInfo);
|
|
|
|
auto yType = ArrayOptions::dataType(yShapeInfo);
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
BUILD_DOUBLE_SELECTOR(xType, yType, sd::DoubleMethods, ::sortTadByKey(x, xShapeInfo, y, yShapeInfo, dimension, dimensionLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
2019-08-26 18:57:51 +02:00
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
Merge master to upstream (#7945)
* Shugeo strided slice zeros (#14)
* Modified strided_slice op to properly work with empty-like shapes.
* Fixed test for reduce_mean with empty-like input.
* [WIP] Last merge (#15)
* correct logsoftmax looss (#2)
* Small SameDiff listener fix (#4)
* Various fixes (#6)
* #7839 Fix for asXMatrix and tests
* #7866 EmbeddingSequenceLayer dtype fix + test
* #7856 SameDiff save/load stream methods
* #7859 RegressionEvaluation rank 4 fix + tests + axis configuration
* EvaluationBinary 3d/4d
* More evaluation 3d/4d tests
* #7847 Evaluation empty checks
* Small test ifx
* #7848 Fix median edge case
* Improve DL4J samediff layer tests
* [WIP] FastText wrapper implemented (#8)
* FastText implemented
* Some fixes
* Fix shapes for wordsNearest
* Validation of input vectors
* Fixes
* Fixed test
* Thread tagged
* Some tweaks
* setContextClassLoader for DeallocatorServiceThread
* Numpy format tests (#1)
* Various fixes (#11)
* #7852 SameDiff gather fix
* #7892 SameDiff placeholder to constant conversion
* #7890 validate input rank for MLN/CG init methods
* Fix broken permute shape calculation
* Permute and gather fixes
* Tests
* #7850 LogSumExp fix + test
* Handful of test fixes
* Empty arrays with non-scalar shapes (#10)
* minor rearrangements for lambdas
* empty tensors with non-scalar shapes
* numpy empty tensors with non-scalar shapes
* few more empty tweaks
* Small fixes
* conv3d signature update
* micro fix in batchnorm mkldnn
* Import fixes
* Fix
* MKL-DNN update
* Small fill fix
* fill with empty input + test
* Fixes
* Small error improvement
* Fix
* one special test
* couple of fixes for lstm
* Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone
* Fixes
* FP16
* Unsigned
* BFloat16
* Fill op - empty tweaks
* - couple of fixes for empty arrays construction
- stack updated
* strided slice fix
* one transform test
* provide method for reducing shapeInfo in case of input array is empty
* Fixed reduceAlongDimensions to use empty input properly.
* couple of broadcast tests
* couple of tests broadcast tests + tweak to make them pass
* add check of non-empty to methods producing sub-arrays
* Fixed reshapeC with zeros in shape.
* complete empty check in reduce_... legacy ops
* Concat and cumsum/prod
* Tweak to empty shape inference on import
* add empty check to the rest of reduce legacy ops
* one more test
* correct typo in evalReduceShapeInfoEmpty
* Added tests for reduce_* ops to tests with zero shapes.
* few more tests for empty reductions
* Fixed strided_slice op with empty case and tests.
* one more empty reduction test
* Fixed strided_slice test.
* add empty check to NDArray::reshapei
* infOrMax
* empty min/max with infinity tests
* made unstack working correctly with empty arrays
* few IndexReduce tests + tweaks for empty shapes
* add test for empty concat
* few tests fixed
* Validation fix for reductions on empty shapes
* Reverse fix
* Reduction shape calc fixes
* SameDiff.generateOutputVariable: don't use shape function to determine number of outputs
* Range fix
* - NDArray constructor updated for scalars/empty arrays
- few tests fixed
* More fixes
* Empty creator fixes
* concat fix
* concat fix
* TF import tests: allow 'both all NaN' and 'both all inf' to pass
* Slice, zero fraction, and reshape fixes
* transpose, gather
* Zero fraction
* scalar cast fix
* Empty reduction axis support
* few more tests fixed
* Fixed input checks conforming with TF for concat op and tests.
* few tests fixed
* matmul scalar shape fix
* Fixed checkout for data type and scalarity with concat to allow non-empty scalars with vector concats.
* broadcast bool fix
* few more tests
* few more tests
* correct evalReduceShapeInfoEmpty
* argmax/argmin + tests
* one more empty edge case + one more test
* argmax/argmin/realdiv_bp tweaks
* empty reshape test + fix
* Helper fixes
* Small fixes
* Gather test fix
* Gather test fix
* Small fixes
* reduce scalar zero values
* scalar mean workaround
* Remove debug code
* along dim mean workaround
* one more test
* - equalsTo() tweak for empty arrays
- one more test
* broadcast tweaks
* [WIP] Fixing outstanding issues for NLP (#9)
* Avoid using not-inited objects
* Test fixed.
* Redundant method avoided for models like FastText
* KMeans++ implementation
* KMeans++ implementation
* Disable parallel execution
* KMeans++
* Tests
* Dev branch merge (#16)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Fix some issues on master (#17)
* Fix DataVec test issue
* Fix issue with dl4j SameDiff output layer
* Dtype fix for lambda layers
* #7912 BertIterator dtype fix (use float32 not global default)
* [WIP] Next set of CUDA stuff (#7)
New CUDA implementations and improvements
* bad file
* Dev branch master merge (#23)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* SameDiff ops, TF import and fixes (#24)
* CheckNumerics tests + fixes + misc fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fake quant
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* FakeQuantWithMinMaxArgs
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* CheckNumerics fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix libnd4j ALL_INTS and ALL_FLOATS declaration (uint and bfloat types)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Small fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Javadoc
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Exception tweak
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix for out of scope stack allocated var use
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignores
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignore for known failing test (already logged issue)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Merge upstream to fork (#25)
* Add thousand-separator commas to TotalParams (#7915)
* Add thousand-separator commas to TotalParams
The number of parameters can be quite large, and it would help the reading of the summary printout to have the TotalParams column & values at the bottom have thousand-separator-commas in them.
* Add thousand-separator commas to MultiLayerNetwork
Corresponding change to MultiLayerNetwork
Signed-off-by: Jxtps Jxtps <jxtps435@gmail.com>
* Update contributing and issue/PR templates (#7934)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix link to AdaDelta paper (#7942)
Fix link to AdaDelta paper hosted on matthewzeiler.com
Signed-off-by: Jxtps
* Fixes, and ignores for known/logged failing issues (#7943)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* SameDiff + DL4J/SameDiff: Multiple fixes (#28)
* #7919 HDF5 attribute buffer length fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7909 Arbiter constructor exception ux improvements
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7925 RNN output layer length checks
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Add listener for validating inputs are not incorrectly modified
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Integrate NonInplaceValidationListener into tests
* #7844 DL4J SameDiff fixes for variable minibatch size
* DL4J SameDiff fixes - ensure gradient for input placeholder is available
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Tweaks to ExternalErrorsFunction - use placeholders, make more robust
* Another fix
* More fixes
* More SameDiff/DL4J fixes
* Scope out scalar array creation in BaseScalarOp
* Remove debug code
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] Final dev branch merge (#29)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* [WIP] Multiple dataset iterators (#27)
* Splitting dataset into arbitrary number
* Fixes
* Multiple split of iterator
* Test
* Test
* Some fixes
* signature change
* one more tweak
Signed-off-by: raver119 <raver119@gmail.com>
* one more test for sequential use of DataSetIteratorSplitter
Signed-off-by: raver119 <raver119@gmail.com>
* Fixes
* Fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* minor test fix
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* couple of assertions tweaked
Signed-off-by: raver119 <raver119@gmail.com>
* MDS splitter test :/
Signed-off-by: raver119 <raver119@gmail.com>
* Minor refactoring
* Multi dataset
* Some fixes
* More tests
* Small number of test fixes/improvements (failures on CI) (#31)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] More CUDA stuff (#26)
* initial commit
Signed-off-by: raver119 <raver119@gmail.com>
* LRN BP CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* less memory
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed bug with crop_and_resize op helper.
* get rid of unnecessary index-calculation dunction
Signed-off-by: Yurii <yurii@skymind.io>
* Fixed sort with nth_element cuda-based helper.
* Refactored nth_element.
* Refactored nth_element op and tests.
* Modified usage of dim array with sortTad routine.
* Refactored main routine of helper for non_max_image_suppression op.
* non_max_image_suppression op helper with cuda kernel implementation. Initial revision.
* fix vol2col cuda kernel
* meh
Signed-off-by: raver119 <raver119@gmail.com>
* topK concept
Signed-off-by: raver119 <raver119@gmail.com>
* unsorted topK with scanWitdh of 1
Signed-off-by: raver119 <raver119@gmail.com>
* correct vol2col tests
* sorted/unsorted topK
Signed-off-by: raver119 <raver119@gmail.com>
* implementation and fixing col2im/col2vol
* Corrected usage flags with input/output with reverse op.
* dup is const now
Signed-off-by: raver119 <raver119@gmail.com>
* percentile op
Signed-off-by: raver119 <raver119@gmail.com>
* group tests for mapool2d
Signed-off-by: Yurii <yurii@skymind.io>
* special test for george
Signed-off-by: raver119 <raver119@gmail.com>
* less threads for sortTad
Signed-off-by: raver119 <raver119@gmail.com>
* provide conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* remove auther in sort tad kernel code
Signed-off-by: Yurii <yurii@skymind.io>
* provide depthwise_conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* - max_pooling_with_argmax
- null check for special use
Signed-off-by: raver119 <raver119@gmail.com>
* dts cuda
Signed-off-by: raver119 <raver119@gmail.com>
* provide sconv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* std cuda
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op to conform TF implementation.
* Improved suppression helper.
* provide pooling3d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* more of minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* (bi)dynamic_rnn
Signed-off-by: raver119 <raver119@gmail.com>
* templates init order
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op.
* Added cuda kernel for non_max_suppression.
* CPU sort by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value tests
Signed-off-by: raver119 <raver119@gmail.com>
* Eliminate compiler error with cuda implementation.
* - repaired gradCheck in cuda
- provide conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* missed signature
Signed-off-by: raver119 <raver119@gmail.com>
* provide depthwise_conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* Implementation of lup helper with cuda kernel. Initial commit.
* further work on backprops for convolutions
Signed-off-by: Yurii <yurii@skymind.io>
* CUDA linear sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* CUDA tad sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* start providing of backprop for pooling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* Added atomicAdd for bool datatype.
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition scalar CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* important comment
Signed-off-by: raver119 <raver119@gmail.com>
* fix pooling2d/3d backprop helpers
Signed-off-by: Yurii <yurii@skymind.io>
* Added non-linear test with dynamic_partition.
* Improved test for dynamic_partition.
* dynamic_partition TAD concept
Signed-off-by: raver119 <raver119@gmail.com>
* - dynamic_partition TAD CUDA impl
- dynamic_partition TAD CPU fix
Signed-off-by: raver119 <raver119@gmail.com>
* - rewrite cpu code for usampling2d/3d
- write cuda code for usampling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* dynamic_stitch CUDA vector case
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case impl
Signed-off-by: raver119 <raver119@gmail.com>
* Added tests for dynamic_stitch 3D-4D cases.
* minor tests tweaks
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed type check for dynamic stitch.
* min/max bp
Signed-off-by: raver119 <raver119@gmail.com>
* rewrite code for upsampling2d/3d cpu
Signed-off-by: Yurii <yurii@skymind.io>
* reduce min/max/norm_max bp
Signed-off-by: raver119 <raver119@gmail.com>
* lup implementation. Additional enhancements.
* provide code for upsamling2d/3d backprop
Signed-off-by: Yurii <yurii@skymind.io>
* weightedCrossEntropyWithLogits
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed template math atomicMul for 64bit ints.
* Refactored dynamic_partition_bp op.
* inverseBroadcast fix
Signed-off-by: raver119 <raver119@gmail.com>
* DynamicPartitionBP test datatype fixed.
* - nd4j_atomicMul Windows fix
- cpu/NDArrayLambda.hpp excluded from CUDA
Signed-off-by: raver119 <raver119@gmail.com>
2019-06-27 17:37:04 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
void sortTadByValue(Nd4jPointer *extraPointers,
|
2020-05-09 07:06:14 +02:00
|
|
|
void *x, const Nd4jLong *xShapeInfo,
|
|
|
|
void *dx, const Nd4jLong *dxShapeInfo,
|
|
|
|
void *y, const Nd4jLong *yShapeInfo,
|
|
|
|
void *dy, const Nd4jLong *dyShapeInfo,
|
|
|
|
int *dimension, int dimensionLength,
|
Merge master to upstream (#7945)
* Shugeo strided slice zeros (#14)
* Modified strided_slice op to properly work with empty-like shapes.
* Fixed test for reduce_mean with empty-like input.
* [WIP] Last merge (#15)
* correct logsoftmax looss (#2)
* Small SameDiff listener fix (#4)
* Various fixes (#6)
* #7839 Fix for asXMatrix and tests
* #7866 EmbeddingSequenceLayer dtype fix + test
* #7856 SameDiff save/load stream methods
* #7859 RegressionEvaluation rank 4 fix + tests + axis configuration
* EvaluationBinary 3d/4d
* More evaluation 3d/4d tests
* #7847 Evaluation empty checks
* Small test ifx
* #7848 Fix median edge case
* Improve DL4J samediff layer tests
* [WIP] FastText wrapper implemented (#8)
* FastText implemented
* Some fixes
* Fix shapes for wordsNearest
* Validation of input vectors
* Fixes
* Fixed test
* Thread tagged
* Some tweaks
* setContextClassLoader for DeallocatorServiceThread
* Numpy format tests (#1)
* Various fixes (#11)
* #7852 SameDiff gather fix
* #7892 SameDiff placeholder to constant conversion
* #7890 validate input rank for MLN/CG init methods
* Fix broken permute shape calculation
* Permute and gather fixes
* Tests
* #7850 LogSumExp fix + test
* Handful of test fixes
* Empty arrays with non-scalar shapes (#10)
* minor rearrangements for lambdas
* empty tensors with non-scalar shapes
* numpy empty tensors with non-scalar shapes
* few more empty tweaks
* Small fixes
* conv3d signature update
* micro fix in batchnorm mkldnn
* Import fixes
* Fix
* MKL-DNN update
* Small fill fix
* fill with empty input + test
* Fixes
* Small error improvement
* Fix
* one special test
* couple of fixes for lstm
* Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone
* Fixes
* FP16
* Unsigned
* BFloat16
* Fill op - empty tweaks
* - couple of fixes for empty arrays construction
- stack updated
* strided slice fix
* one transform test
* provide method for reducing shapeInfo in case of input array is empty
* Fixed reduceAlongDimensions to use empty input properly.
* couple of broadcast tests
* couple of tests broadcast tests + tweak to make them pass
* add check of non-empty to methods producing sub-arrays
* Fixed reshapeC with zeros in shape.
* complete empty check in reduce_... legacy ops
* Concat and cumsum/prod
* Tweak to empty shape inference on import
* add empty check to the rest of reduce legacy ops
* one more test
* correct typo in evalReduceShapeInfoEmpty
* Added tests for reduce_* ops to tests with zero shapes.
* few more tests for empty reductions
* Fixed strided_slice op with empty case and tests.
* one more empty reduction test
* Fixed strided_slice test.
* add empty check to NDArray::reshapei
* infOrMax
* empty min/max with infinity tests
* made unstack working correctly with empty arrays
* few IndexReduce tests + tweaks for empty shapes
* add test for empty concat
* few tests fixed
* Validation fix for reductions on empty shapes
* Reverse fix
* Reduction shape calc fixes
* SameDiff.generateOutputVariable: don't use shape function to determine number of outputs
* Range fix
* - NDArray constructor updated for scalars/empty arrays
- few tests fixed
* More fixes
* Empty creator fixes
* concat fix
* concat fix
* TF import tests: allow 'both all NaN' and 'both all inf' to pass
* Slice, zero fraction, and reshape fixes
* transpose, gather
* Zero fraction
* scalar cast fix
* Empty reduction axis support
* few more tests fixed
* Fixed input checks conforming with TF for concat op and tests.
* few tests fixed
* matmul scalar shape fix
* Fixed checkout for data type and scalarity with concat to allow non-empty scalars with vector concats.
* broadcast bool fix
* few more tests
* few more tests
* correct evalReduceShapeInfoEmpty
* argmax/argmin + tests
* one more empty edge case + one more test
* argmax/argmin/realdiv_bp tweaks
* empty reshape test + fix
* Helper fixes
* Small fixes
* Gather test fix
* Gather test fix
* Small fixes
* reduce scalar zero values
* scalar mean workaround
* Remove debug code
* along dim mean workaround
* one more test
* - equalsTo() tweak for empty arrays
- one more test
* broadcast tweaks
* [WIP] Fixing outstanding issues for NLP (#9)
* Avoid using not-inited objects
* Test fixed.
* Redundant method avoided for models like FastText
* KMeans++ implementation
* KMeans++ implementation
* Disable parallel execution
* KMeans++
* Tests
* Dev branch merge (#16)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Fix some issues on master (#17)
* Fix DataVec test issue
* Fix issue with dl4j SameDiff output layer
* Dtype fix for lambda layers
* #7912 BertIterator dtype fix (use float32 not global default)
* [WIP] Next set of CUDA stuff (#7)
New CUDA implementations and improvements
* bad file
* Dev branch master merge (#23)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* SameDiff ops, TF import and fixes (#24)
* CheckNumerics tests + fixes + misc fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fake quant
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* FakeQuantWithMinMaxArgs
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* CheckNumerics fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix libnd4j ALL_INTS and ALL_FLOATS declaration (uint and bfloat types)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Small fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Javadoc
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Exception tweak
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix for out of scope stack allocated var use
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignores
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignore for known failing test (already logged issue)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Merge upstream to fork (#25)
* Add thousand-separator commas to TotalParams (#7915)
* Add thousand-separator commas to TotalParams
The number of parameters can be quite large, and it would help the reading of the summary printout to have the TotalParams column & values at the bottom have thousand-separator-commas in them.
* Add thousand-separator commas to MultiLayerNetwork
Corresponding change to MultiLayerNetwork
Signed-off-by: Jxtps Jxtps <jxtps435@gmail.com>
* Update contributing and issue/PR templates (#7934)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix link to AdaDelta paper (#7942)
Fix link to AdaDelta paper hosted on matthewzeiler.com
Signed-off-by: Jxtps
* Fixes, and ignores for known/logged failing issues (#7943)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* SameDiff + DL4J/SameDiff: Multiple fixes (#28)
* #7919 HDF5 attribute buffer length fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7909 Arbiter constructor exception ux improvements
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7925 RNN output layer length checks
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Add listener for validating inputs are not incorrectly modified
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Integrate NonInplaceValidationListener into tests
* #7844 DL4J SameDiff fixes for variable minibatch size
* DL4J SameDiff fixes - ensure gradient for input placeholder is available
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Tweaks to ExternalErrorsFunction - use placeholders, make more robust
* Another fix
* More fixes
* More SameDiff/DL4J fixes
* Scope out scalar array creation in BaseScalarOp
* Remove debug code
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] Final dev branch merge (#29)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* [WIP] Multiple dataset iterators (#27)
* Splitting dataset into arbitrary number
* Fixes
* Multiple split of iterator
* Test
* Test
* Some fixes
* signature change
* one more tweak
Signed-off-by: raver119 <raver119@gmail.com>
* one more test for sequential use of DataSetIteratorSplitter
Signed-off-by: raver119 <raver119@gmail.com>
* Fixes
* Fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* minor test fix
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* couple of assertions tweaked
Signed-off-by: raver119 <raver119@gmail.com>
* MDS splitter test :/
Signed-off-by: raver119 <raver119@gmail.com>
* Minor refactoring
* Multi dataset
* Some fixes
* More tests
* Small number of test fixes/improvements (failures on CI) (#31)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] More CUDA stuff (#26)
* initial commit
Signed-off-by: raver119 <raver119@gmail.com>
* LRN BP CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* less memory
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed bug with crop_and_resize op helper.
* get rid of unnecessary index-calculation dunction
Signed-off-by: Yurii <yurii@skymind.io>
* Fixed sort with nth_element cuda-based helper.
* Refactored nth_element.
* Refactored nth_element op and tests.
* Modified usage of dim array with sortTad routine.
* Refactored main routine of helper for non_max_image_suppression op.
* non_max_image_suppression op helper with cuda kernel implementation. Initial revision.
* fix vol2col cuda kernel
* meh
Signed-off-by: raver119 <raver119@gmail.com>
* topK concept
Signed-off-by: raver119 <raver119@gmail.com>
* unsorted topK with scanWitdh of 1
Signed-off-by: raver119 <raver119@gmail.com>
* correct vol2col tests
* sorted/unsorted topK
Signed-off-by: raver119 <raver119@gmail.com>
* implementation and fixing col2im/col2vol
* Corrected usage flags with input/output with reverse op.
* dup is const now
Signed-off-by: raver119 <raver119@gmail.com>
* percentile op
Signed-off-by: raver119 <raver119@gmail.com>
* group tests for mapool2d
Signed-off-by: Yurii <yurii@skymind.io>
* special test for george
Signed-off-by: raver119 <raver119@gmail.com>
* less threads for sortTad
Signed-off-by: raver119 <raver119@gmail.com>
* provide conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* remove auther in sort tad kernel code
Signed-off-by: Yurii <yurii@skymind.io>
* provide depthwise_conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* - max_pooling_with_argmax
- null check for special use
Signed-off-by: raver119 <raver119@gmail.com>
* dts cuda
Signed-off-by: raver119 <raver119@gmail.com>
* provide sconv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* std cuda
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op to conform TF implementation.
* Improved suppression helper.
* provide pooling3d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* more of minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* (bi)dynamic_rnn
Signed-off-by: raver119 <raver119@gmail.com>
* templates init order
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op.
* Added cuda kernel for non_max_suppression.
* CPU sort by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value tests
Signed-off-by: raver119 <raver119@gmail.com>
* Eliminate compiler error with cuda implementation.
* - repaired gradCheck in cuda
- provide conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* missed signature
Signed-off-by: raver119 <raver119@gmail.com>
* provide depthwise_conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* Implementation of lup helper with cuda kernel. Initial commit.
* further work on backprops for convolutions
Signed-off-by: Yurii <yurii@skymind.io>
* CUDA linear sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* CUDA tad sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* start providing of backprop for pooling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* Added atomicAdd for bool datatype.
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition scalar CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* important comment
Signed-off-by: raver119 <raver119@gmail.com>
* fix pooling2d/3d backprop helpers
Signed-off-by: Yurii <yurii@skymind.io>
* Added non-linear test with dynamic_partition.
* Improved test for dynamic_partition.
* dynamic_partition TAD concept
Signed-off-by: raver119 <raver119@gmail.com>
* - dynamic_partition TAD CUDA impl
- dynamic_partition TAD CPU fix
Signed-off-by: raver119 <raver119@gmail.com>
* - rewrite cpu code for usampling2d/3d
- write cuda code for usampling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* dynamic_stitch CUDA vector case
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case impl
Signed-off-by: raver119 <raver119@gmail.com>
* Added tests for dynamic_stitch 3D-4D cases.
* minor tests tweaks
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed type check for dynamic stitch.
* min/max bp
Signed-off-by: raver119 <raver119@gmail.com>
* rewrite code for upsampling2d/3d cpu
Signed-off-by: Yurii <yurii@skymind.io>
* reduce min/max/norm_max bp
Signed-off-by: raver119 <raver119@gmail.com>
* lup implementation. Additional enhancements.
* provide code for upsamling2d/3d backprop
Signed-off-by: Yurii <yurii@skymind.io>
* weightedCrossEntropyWithLogits
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed template math atomicMul for 64bit ints.
* Refactored dynamic_partition_bp op.
* inverseBroadcast fix
Signed-off-by: raver119 <raver119@gmail.com>
* DynamicPartitionBP test datatype fixed.
* - nd4j_atomicMul Windows fix
- cpu/NDArrayLambda.hpp excluded from CUDA
Signed-off-by: raver119 <raver119@gmail.com>
2019-06-27 17:37:04 +02:00
|
|
|
bool descending) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
|
|
|
auto xType = ArrayOptions::dataType(xShapeInfo);
|
|
|
|
auto yType = ArrayOptions::dataType(yShapeInfo);
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
BUILD_DOUBLE_SELECTOR(xType, yType, sd::DoubleMethods, ::sortTadByValue(x, xShapeInfo, y, yShapeInfo, dimension, dimensionLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
2019-08-26 18:57:51 +02:00
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
Merge master to upstream (#7945)
* Shugeo strided slice zeros (#14)
* Modified strided_slice op to properly work with empty-like shapes.
* Fixed test for reduce_mean with empty-like input.
* [WIP] Last merge (#15)
* correct logsoftmax looss (#2)
* Small SameDiff listener fix (#4)
* Various fixes (#6)
* #7839 Fix for asXMatrix and tests
* #7866 EmbeddingSequenceLayer dtype fix + test
* #7856 SameDiff save/load stream methods
* #7859 RegressionEvaluation rank 4 fix + tests + axis configuration
* EvaluationBinary 3d/4d
* More evaluation 3d/4d tests
* #7847 Evaluation empty checks
* Small test ifx
* #7848 Fix median edge case
* Improve DL4J samediff layer tests
* [WIP] FastText wrapper implemented (#8)
* FastText implemented
* Some fixes
* Fix shapes for wordsNearest
* Validation of input vectors
* Fixes
* Fixed test
* Thread tagged
* Some tweaks
* setContextClassLoader for DeallocatorServiceThread
* Numpy format tests (#1)
* Various fixes (#11)
* #7852 SameDiff gather fix
* #7892 SameDiff placeholder to constant conversion
* #7890 validate input rank for MLN/CG init methods
* Fix broken permute shape calculation
* Permute and gather fixes
* Tests
* #7850 LogSumExp fix + test
* Handful of test fixes
* Empty arrays with non-scalar shapes (#10)
* minor rearrangements for lambdas
* empty tensors with non-scalar shapes
* numpy empty tensors with non-scalar shapes
* few more empty tweaks
* Small fixes
* conv3d signature update
* micro fix in batchnorm mkldnn
* Import fixes
* Fix
* MKL-DNN update
* Small fill fix
* fill with empty input + test
* Fixes
* Small error improvement
* Fix
* one special test
* couple of fixes for lstm
* Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone
* Fixes
* FP16
* Unsigned
* BFloat16
* Fill op - empty tweaks
* - couple of fixes for empty arrays construction
- stack updated
* strided slice fix
* one transform test
* provide method for reducing shapeInfo in case of input array is empty
* Fixed reduceAlongDimensions to use empty input properly.
* couple of broadcast tests
* couple of tests broadcast tests + tweak to make them pass
* add check of non-empty to methods producing sub-arrays
* Fixed reshapeC with zeros in shape.
* complete empty check in reduce_... legacy ops
* Concat and cumsum/prod
* Tweak to empty shape inference on import
* add empty check to the rest of reduce legacy ops
* one more test
* correct typo in evalReduceShapeInfoEmpty
* Added tests for reduce_* ops to tests with zero shapes.
* few more tests for empty reductions
* Fixed strided_slice op with empty case and tests.
* one more empty reduction test
* Fixed strided_slice test.
* add empty check to NDArray::reshapei
* infOrMax
* empty min/max with infinity tests
* made unstack working correctly with empty arrays
* few IndexReduce tests + tweaks for empty shapes
* add test for empty concat
* few tests fixed
* Validation fix for reductions on empty shapes
* Reverse fix
* Reduction shape calc fixes
* SameDiff.generateOutputVariable: don't use shape function to determine number of outputs
* Range fix
* - NDArray constructor updated for scalars/empty arrays
- few tests fixed
* More fixes
* Empty creator fixes
* concat fix
* concat fix
* TF import tests: allow 'both all NaN' and 'both all inf' to pass
* Slice, zero fraction, and reshape fixes
* transpose, gather
* Zero fraction
* scalar cast fix
* Empty reduction axis support
* few more tests fixed
* Fixed input checks conforming with TF for concat op and tests.
* few tests fixed
* matmul scalar shape fix
* Fixed checkout for data type and scalarity with concat to allow non-empty scalars with vector concats.
* broadcast bool fix
* few more tests
* few more tests
* correct evalReduceShapeInfoEmpty
* argmax/argmin + tests
* one more empty edge case + one more test
* argmax/argmin/realdiv_bp tweaks
* empty reshape test + fix
* Helper fixes
* Small fixes
* Gather test fix
* Gather test fix
* Small fixes
* reduce scalar zero values
* scalar mean workaround
* Remove debug code
* along dim mean workaround
* one more test
* - equalsTo() tweak for empty arrays
- one more test
* broadcast tweaks
* [WIP] Fixing outstanding issues for NLP (#9)
* Avoid using not-inited objects
* Test fixed.
* Redundant method avoided for models like FastText
* KMeans++ implementation
* KMeans++ implementation
* Disable parallel execution
* KMeans++
* Tests
* Dev branch merge (#16)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Fix some issues on master (#17)
* Fix DataVec test issue
* Fix issue with dl4j SameDiff output layer
* Dtype fix for lambda layers
* #7912 BertIterator dtype fix (use float32 not global default)
* [WIP] Next set of CUDA stuff (#7)
New CUDA implementations and improvements
* bad file
* Dev branch master merge (#23)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* SameDiff ops, TF import and fixes (#24)
* CheckNumerics tests + fixes + misc fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fake quant
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* FakeQuantWithMinMaxArgs
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* CheckNumerics fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix libnd4j ALL_INTS and ALL_FLOATS declaration (uint and bfloat types)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Small fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Javadoc
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Exception tweak
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix for out of scope stack allocated var use
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignores
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignore for known failing test (already logged issue)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Merge upstream to fork (#25)
* Add thousand-separator commas to TotalParams (#7915)
* Add thousand-separator commas to TotalParams
The number of parameters can be quite large, and it would help the reading of the summary printout to have the TotalParams column & values at the bottom have thousand-separator-commas in them.
* Add thousand-separator commas to MultiLayerNetwork
Corresponding change to MultiLayerNetwork
Signed-off-by: Jxtps Jxtps <jxtps435@gmail.com>
* Update contributing and issue/PR templates (#7934)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix link to AdaDelta paper (#7942)
Fix link to AdaDelta paper hosted on matthewzeiler.com
Signed-off-by: Jxtps
* Fixes, and ignores for known/logged failing issues (#7943)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* SameDiff + DL4J/SameDiff: Multiple fixes (#28)
* #7919 HDF5 attribute buffer length fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7909 Arbiter constructor exception ux improvements
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7925 RNN output layer length checks
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Add listener for validating inputs are not incorrectly modified
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Integrate NonInplaceValidationListener into tests
* #7844 DL4J SameDiff fixes for variable minibatch size
* DL4J SameDiff fixes - ensure gradient for input placeholder is available
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Tweaks to ExternalErrorsFunction - use placeholders, make more robust
* Another fix
* More fixes
* More SameDiff/DL4J fixes
* Scope out scalar array creation in BaseScalarOp
* Remove debug code
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] Final dev branch merge (#29)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* [WIP] Multiple dataset iterators (#27)
* Splitting dataset into arbitrary number
* Fixes
* Multiple split of iterator
* Test
* Test
* Some fixes
* signature change
* one more tweak
Signed-off-by: raver119 <raver119@gmail.com>
* one more test for sequential use of DataSetIteratorSplitter
Signed-off-by: raver119 <raver119@gmail.com>
* Fixes
* Fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* minor test fix
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* couple of assertions tweaked
Signed-off-by: raver119 <raver119@gmail.com>
* MDS splitter test :/
Signed-off-by: raver119 <raver119@gmail.com>
* Minor refactoring
* Multi dataset
* Some fixes
* More tests
* Small number of test fixes/improvements (failures on CI) (#31)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] More CUDA stuff (#26)
* initial commit
Signed-off-by: raver119 <raver119@gmail.com>
* LRN BP CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* less memory
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed bug with crop_and_resize op helper.
* get rid of unnecessary index-calculation dunction
Signed-off-by: Yurii <yurii@skymind.io>
* Fixed sort with nth_element cuda-based helper.
* Refactored nth_element.
* Refactored nth_element op and tests.
* Modified usage of dim array with sortTad routine.
* Refactored main routine of helper for non_max_image_suppression op.
* non_max_image_suppression op helper with cuda kernel implementation. Initial revision.
* fix vol2col cuda kernel
* meh
Signed-off-by: raver119 <raver119@gmail.com>
* topK concept
Signed-off-by: raver119 <raver119@gmail.com>
* unsorted topK with scanWitdh of 1
Signed-off-by: raver119 <raver119@gmail.com>
* correct vol2col tests
* sorted/unsorted topK
Signed-off-by: raver119 <raver119@gmail.com>
* implementation and fixing col2im/col2vol
* Corrected usage flags with input/output with reverse op.
* dup is const now
Signed-off-by: raver119 <raver119@gmail.com>
* percentile op
Signed-off-by: raver119 <raver119@gmail.com>
* group tests for mapool2d
Signed-off-by: Yurii <yurii@skymind.io>
* special test for george
Signed-off-by: raver119 <raver119@gmail.com>
* less threads for sortTad
Signed-off-by: raver119 <raver119@gmail.com>
* provide conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* remove auther in sort tad kernel code
Signed-off-by: Yurii <yurii@skymind.io>
* provide depthwise_conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* - max_pooling_with_argmax
- null check for special use
Signed-off-by: raver119 <raver119@gmail.com>
* dts cuda
Signed-off-by: raver119 <raver119@gmail.com>
* provide sconv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* std cuda
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op to conform TF implementation.
* Improved suppression helper.
* provide pooling3d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* more of minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* (bi)dynamic_rnn
Signed-off-by: raver119 <raver119@gmail.com>
* templates init order
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op.
* Added cuda kernel for non_max_suppression.
* CPU sort by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value tests
Signed-off-by: raver119 <raver119@gmail.com>
* Eliminate compiler error with cuda implementation.
* - repaired gradCheck in cuda
- provide conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* missed signature
Signed-off-by: raver119 <raver119@gmail.com>
* provide depthwise_conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* Implementation of lup helper with cuda kernel. Initial commit.
* further work on backprops for convolutions
Signed-off-by: Yurii <yurii@skymind.io>
* CUDA linear sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* CUDA tad sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* start providing of backprop for pooling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* Added atomicAdd for bool datatype.
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition scalar CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* important comment
Signed-off-by: raver119 <raver119@gmail.com>
* fix pooling2d/3d backprop helpers
Signed-off-by: Yurii <yurii@skymind.io>
* Added non-linear test with dynamic_partition.
* Improved test for dynamic_partition.
* dynamic_partition TAD concept
Signed-off-by: raver119 <raver119@gmail.com>
* - dynamic_partition TAD CUDA impl
- dynamic_partition TAD CPU fix
Signed-off-by: raver119 <raver119@gmail.com>
* - rewrite cpu code for usampling2d/3d
- write cuda code for usampling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* dynamic_stitch CUDA vector case
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case impl
Signed-off-by: raver119 <raver119@gmail.com>
* Added tests for dynamic_stitch 3D-4D cases.
* minor tests tweaks
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed type check for dynamic stitch.
* min/max bp
Signed-off-by: raver119 <raver119@gmail.com>
* rewrite code for upsampling2d/3d cpu
Signed-off-by: Yurii <yurii@skymind.io>
* reduce min/max/norm_max bp
Signed-off-by: raver119 <raver119@gmail.com>
* lup implementation. Additional enhancements.
* provide code for upsamling2d/3d backprop
Signed-off-by: Yurii <yurii@skymind.io>
* weightedCrossEntropyWithLogits
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed template math atomicMul for 64bit ints.
* Refactored dynamic_partition_bp op.
* inverseBroadcast fix
Signed-off-by: raver119 <raver119@gmail.com>
* DynamicPartitionBP test datatype fixed.
* - nd4j_atomicMul Windows fix
- cpu/NDArrayLambda.hpp excluded from CUDA
Signed-off-by: raver119 <raver119@gmail.com>
2019-06-27 17:37:04 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
const char* runLightBenchmarkSuit(bool printOut) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LightBenchmarkSuit suit;
|
2019-08-26 18:57:51 +02:00
|
|
|
auto result = suit.runSuit();
|
2019-07-12 07:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
if (printOut)
|
|
|
|
nd4j_printf("%s\n", result.data());
|
2019-07-12 07:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
auto chars = new char[result.length() + 1];
|
|
|
|
std::memcpy(chars, result.data(), result.length());
|
|
|
|
chars[result.length()] = (char) 0x0;
|
2019-07-12 07:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
return chars;
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
return nullptr;
|
|
|
|
}
|
2019-07-12 07:21:15 +02:00
|
|
|
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
/**
 * Reports the amount currently cached by the ConstantHelper singleton for the
 * given device, as returned by ConstantHelper::getCachedAmount.
 *
 * @param deviceId device whose cached amount is queried.
 */
Nd4jLong getCachedMemory(int deviceId) {
    const auto cachedAmount = sd::ConstantHelper::getInstance().getCachedAmount(deviceId);
    return cachedAmount;
}
|
|
|
|
|
2019-07-22 13:34:08 +02:00
|
|
|
const char* runFullBenchmarkSuit(bool printOut) {
|
2019-08-26 18:57:51 +02:00
|
|
|
try {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::FullBenchmarkSuit suit;
|
2019-08-26 18:57:51 +02:00
|
|
|
auto result = suit.runSuit();
|
2019-07-12 07:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
if (printOut)
|
|
|
|
nd4j_printf("%s\n", result.data());
|
2019-07-12 07:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
auto chars = new char[result.length() + 1];
|
|
|
|
std::memcpy(chars, result.data(), result.length());
|
|
|
|
chars[result.length()] = (char) 0x0;
|
2019-07-12 07:21:15 +02:00
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
return chars;
|
|
|
|
} catch (std::exception &e) {
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
|
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
2019-08-26 18:57:51 +02:00
|
|
|
return nullptr;
|
|
|
|
}
|
2019-07-12 07:21:15 +02:00
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
sd::LaunchContext* defaultLaunchContext() {
|
[WIP] multi-device support (#80)
* fix pad javadoc and @see links. (#72)
Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* [WIP] More fixes (#73)
* special tests for ConstantTadHelper/ConstantShapeHelper
Signed-off-by: raver119 <raver119@gmail.com>
* release methods for data buffers
Signed-off-by: raver119 <raver119@gmail.com>
* delete temporary buffer Java side
Signed-off-by: raver119 <raver119@gmail.com>
* delete temporary buffer Java side
Signed-off-by: raver119 <raver119@gmail.com>
* delete temporary TadPack C++/Java side (#74)
Signed-off-by: raver119 <raver119@gmail.com>
* Zoo model TF import test updates (#75)
* argLine fix, update compression_gru comment
* updated comment for xception
* undid but commented argLine change
* updated xlnet comment
* copyright headers
* - new NDArray methods like()/ulike() (#77)
- fix for depthwise_conv2d_bp + special test
Signed-off-by: raver119 <raver119@gmail.com>
* upsampling2d fix CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* DL4J trace logging (#79)
* MLN/CG trace logging for debugging
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Tiny tweak
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* strided_slice_bp shape fn leak fix
Signed-off-by: raver119 <raver119@gmail.com>
* SameDiff fixes and naming (#78)
* remove SDVariable inplace methods
* import methods
* npe fix in OpVal
* removed SameDiff inplace ops from tests
* Naming updates, moved to centralized methods in SameDiff, should use op_#:# for everything
* quick fixes
* javadoc
* SDVariable eval with placeholders
* use regex match
* better matching
* initial commit
Signed-off-by: raver119 <raver119@gmail.com>
* initial commit
Signed-off-by: raver119 <raver119@gmail.com>
* fix javadoc. (#76)
* fix javadoc.
Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* replace most @see with @link s.
Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* 4 additional tests
Signed-off-by: raver119 <raver119@gmail.com>
* launch context reorganization
Signed-off-by: raver119 <raver119@gmail.com>
* LaunchContext reorganization
Signed-off-by: raver119 <raver119@gmail.com>
* per-device LaunchContext
Signed-off-by: raver119 <raver119@gmail.com>
* Various DL4J/ND4J fixes (#81)
* #7954 Force refresh of UI when switching tabs on overview page
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #8017 Concurrent modification exception (synchronize) fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #8033 Don't initialize updater in middle of writing memory crash dump
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #8208 Fix shape checks for ND4J int[] creator methods
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #6385 #7992 Keras import naming fixes + cleanup
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #8016 Upsampling3D - add NDHWC format support
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* ContextBuffers as separate entity
Signed-off-by: raver119 <raver119@gmail.com>
* Refactor NativeOps.h to export C functions
* Actually export functions from NativeOps.h
* Adapt the Java wrappers in ND4J generated with JavaCPP
* Create C wrappers for some of the C++ classes currently used by ND4J
* ContextBuffers as separate entity
Signed-off-by: raver119 <raver119@gmail.com>
* remove duplicate code in createBufferDetached. (#83)
Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* Keras model import - updater lr fix (#84)
* Keras model import - updater lr fix
Signed-off-by: eraly <susan.eraly@gmail.com>
* Keras model import - updater lr fix, cleanup
Signed-off-by: eraly <susan.eraly@gmail.com>
* ContextBuffers as separate entity
Signed-off-by: raver119 <raver119@gmail.com>
* ContextBuffers as separate entity
Signed-off-by: raver119 <raver119@gmail.com>
* Fix functions of OpaqueVariablesSet
* thread-local buffers/affinity
Signed-off-by: raver119 <raver119@gmail.com>
* thread safety for LaunchContext
Signed-off-by: raver119 <raver119@gmail.com>
* more of thread safety
Signed-off-by: raver119 <raver119@gmail.com>
* one more multi threaded test
Signed-off-by: raver119 <raver119@gmail.com>
* SameDiff Convolution Config validation, better output methods (#82)
* Conv Config validation & tests
Signed-off-by: Ryan Nett <rnett@skymind.io>
* stackOutputs utility method
Signed-off-by: Ryan Nett <rnett@skymind.io>
* use constructor for validation, support negative kernel sizes (infered from weights)
Signed-off-by: Ryan Nett <rnett@skymind.io>
* better output methods
Signed-off-by: Ryan Nett <rnett@skymind.io>
* move output to be with fit and evaluate
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* more fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* refactor duplicate code from pad methods. (#86)
* refactor duplicate code from pad methods.
Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* replace switch with if.
Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* Various ND4J/DL4J fixes and improvements (#87)
* Reshape and reallocate - small fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Reshape and reallocate - small fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #6488 ElementWiseVertex broadcast support
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Constructors and broadcast supported it Transforms.max/min
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #8054 ElementWiseVertex now supports broadcast inputs
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #8057 Nd4j.create overload dtype fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7551 ND4J Shape validation fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] Numpy boolean import (#91)
* numpy bool type
Signed-off-by: raver119 <raver119@gmail.com>
* numpy bool java side
Signed-off-by: raver119 <raver119@gmail.com>
* remove create method with unused parameter. (#89)
* remove create method with unused parameter.
* removed more unused methods.
Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* removing more unused code.
Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* last removal of unused code.
Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* remove createSparse methods. (#92)
Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* Various ND4J/DL4J fixes (#90)
* Deprecate Old*Op instances
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #8063 #8054 Broadcast exceptions + cleanup inplace ops
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Small fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Remove bad test condition
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7993 Fix shape function issue in crop_and_resize op
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* DL4J SameDiff lambda layer fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #8029 Fix for pnorm backprop math
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #8038 Fix Op profiler NaN/Inf triggering + add tests (#93)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* createUninitializedDetached refactoring. (#94)
* wip
* update interface, add null implementations.
* Breaking one test in a weird way.
Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* createUninitializedDetached refactored.
Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* cuda build fix for issues introduced by recent refactoring
Signed-off-by: raver119 <raver119@gmail.com>
* [WIP] More of CUDA (#95)
* initial commit
Signed-off-by: raver119 <raver119@gmail.com>
* Implementation of hashcode cuda helper. Working edition.
* Fixed parallel test input arangements.
* Fixed tests for hashcode op.
* Fixed shape calculation for image:crop_and_resize op and test.
* NativeOps tests. Initial test suite.
* Added tests for indexReduce methods.
* Added test on execBroadcast with NDArray as dimensions.
* Added test on execBroadcastBool with NDArray as dimensions.
* Added tests on execPairwiseTransform and execPairwiseTransofrmBool.
* Added tests for execReduce with scalar results.
* Added reduce tests for non-empty dims array.
* Added tests for reduce3.
* Added tests for execScalar.
* Added tests for execSummaryStats.
* - provide cpu/cuda code for batch_to_space
- testing it
Signed-off-by: Yurii <yurii@skymind.io>
* - remove old test for batch_to_space (had wrong format and numbers were not checked)
Signed-off-by: Yurii <yurii@skymind.io>
* Fixed complilation errors with test.
* Added test for execTransformFloat.
* Added test for execTransformSame.
* Added test for execTransformBool.
* Added test for execTransformStrict.
* Added tests for execScalar/execScalarBool with TADs.
* Added test for flatten.
* - provide cpu/cuda code for space_to_Batch operaion
Signed-off-by: Yurii <yurii@skymind.io>
* Added test for concat.
* comment unnecessary stuff in s_t_b
Signed-off-by: Yurii <yurii@skymind.io>
* Added test for specialConcat.
* Added tests for memcpy/set routines.
* Fixed pullRow cuda test.
* Added pullRow test.
* Added average test.
* - correct typo in NDArray::applyPairwiseTransform(nd4j::pairwise::BoolOps op...)
Signed-off-by: Yurii <yurii@skymind.io>
* - debugging and fixing cuda tests in JavaInteropTests file
Signed-off-by: Yurii <yurii@skymind.io>
* - correct some tests
Signed-off-by: Yurii <yurii@skymind.io>
* Added test for shuffle.
* Fixed ops declarations.
* Restored omp and added shuffle test.
* Added convertTypes test.
* Added tests for execRandom. Eliminated usage of RandomBuffer with NativeOps.
* Added sort tests.
* Added tests for execCustomOp.
* - further debuging and fixing tests terminated with crash
Signed-off-by: Yurii <yurii@skymind.io>
* Added tests for calculateOutputShapes.
* Addded Benchmarks test.
* Commented benchmark tests.
* change assertion
Signed-off-by: raver119 <raver119@gmail.com>
* Added tests for apply_sgd op. Added cpu helper for that op.
* Implement cuda helper for aplly_sgd op. Fixed tests for NativeOps.
* Added test for assign broadcastable.
* Added tests for assign_bp op.
* Added tests for axpy op.
* - assign/execScalar/execTransformAny signature change
- minor test fix
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed axpy op.
* meh
Signed-off-by: raver119 <raver119@gmail.com>
* - fix tests for nativeOps::concat
Signed-off-by: Yurii <yurii@skymind.io>
* sequential transform/scalar
Signed-off-by: raver119 <raver119@gmail.com>
* allow nested parallelism
Signed-off-by: raver119 <raver119@gmail.com>
* assign_bp leak fix
Signed-off-by: raver119 <raver119@gmail.com>
* block setRNG fix
Signed-off-by: raver119 <raver119@gmail.com>
* enable parallelism by default
Signed-off-by: raver119 <raver119@gmail.com>
* enable nested parallelism by default
Signed-off-by: raver119 <raver119@gmail.com>
* Added cuda implementation for row_count helper.
* Added implementation for tnse gains op helper.
* - take into account possible situations when input arrays are empty in reduce_ cuda stuff
Signed-off-by: Yurii <yurii@skymind.io>
* Implemented tsne/edge_forces op cuda-based helper. Parallelized cpu-based helper for edge_forces.
* Added kernel for tsne/symmetrized op heleper.
* Implementation of tsne/symmetrized op cuda helper. Working edition.
* Eliminated waste printfs.
* Added test for broadcastgradientargs op.
* host-only fallback for empty reduce float
Signed-off-by: raver119 <raver119@gmail.com>
* - some tests fixes
Signed-off-by: Yurii <yurii@skymind.io>
* - correct the rest of reduce_ stuff
Signed-off-by: Yurii <yurii@skymind.io>
* - further correction of reduce_ stuff
Signed-off-by: Yurii <yurii@skymind.io>
* Added test for Cbow op. Also added cuda implementation for cbow helpers.
* - improve code of stack operation for scalar case
Signed-off-by: Yurii <yurii@skymind.io>
* - provide cuda kernel for gatherND operation
Signed-off-by: Yurii <yurii@skymind.io>
* Implementation of cbow helpers with cuda kernels.
* minor tests tweaks
Signed-off-by: raver119 <raver119@gmail.com>
* minor tests tweaks
Signed-off-by: raver119 <raver119@gmail.com>
* - further correction of cuda stuff
Signed-off-by: Yurii <yurii@skymind.io>
* Implementatation of cbow op helper with cuda kernels. Working edition.
* Skip random testing for cudablas case.
* lstmBlockCell context fix
Signed-off-by: raver119 <raver119@gmail.com>
* Added tests for ELU and ELU_BP ops.
* Added tests for eq_scalar, gt_scalar, gte_scalar and lte_scalar ops.
* Added tests for neq_scalar.
* Added test for noop.
* - further work on clipbynorm_bp
Signed-off-by: Yurii <yurii@skymind.io>
* - get rid of concat op call, use instead direct concat helper call
Signed-off-by: Yurii <yurii@skymind.io>
* lstmBlockCell context fix
Signed-off-by: raver119 <raver119@gmail.com>
* Added tests for lrelu and lrelu_bp.
* Added tests for selu and selu_bp.
* Fixed lrelu derivative helpers.
* - some corrections in lstm
Signed-off-by: Yurii <yurii@skymind.io>
* operator * result shape fix
Signed-off-by: raver119 <raver119@gmail.com>
* - correct typo in lstmCell
Signed-off-by: Yurii <yurii@skymind.io>
* few tests fixed
Signed-off-by: raver119 <raver119@gmail.com>
* CUDA inverse broadcast bool fix
Signed-off-by: raver119 <raver119@gmail.com>
* disable MMAP test for CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* BooleanOp syncToDevice
Signed-off-by: raver119 <raver119@gmail.com>
* meh
Signed-off-by: raver119 <raver119@gmail.com>
* additional data types for im2col/col2im
Signed-off-by: raver119 <raver119@gmail.com>
* Added test for firas_sparse op.
* one more RandomBuffer test excluded
Signed-off-by: raver119 <raver119@gmail.com>
* Added tests for flatten op.
* Added test for Floor op.
* bunch of tests fixed
Signed-off-by: raver119 <raver119@gmail.com>
* mmulDot tests fixed
Signed-off-by: raver119 <raver119@gmail.com>
* more tests fixed
Signed-off-by: raver119 <raver119@gmail.com>
* Implemented floordiv_bp op and tests.
* Fixed scalar case with cuda implementation for bds.
* - work on cuda kernel for clip_by_norm backprop op is completed
Signed-off-by: Yurii <yurii@skymind.io>
* Eliminate cbow crach.
* more tests fixed
Signed-off-by: raver119 <raver119@gmail.com>
* more tests fixed
Signed-off-by: raver119 <raver119@gmail.com>
* Eliminated abortion with batched nlp test.
* more tests fixed
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed shared flag initializing.
* disabled bunch of cpu workspaces tests
Signed-off-by: raver119 <raver119@gmail.com>
* scalar operators fix: missing registerSpecialUse call
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed logdet for cuda and tests.
* - correct clipBynorm_bp
Signed-off-by: Yurii <yurii@skymind.io>
* Fixed crop_and_resize shape datatype.
* - correct some mmul tests
Signed-off-by: Yurii <yurii@skymind.io>
* build fix
Signed-off-by: raver119 <raver119@gmail.com>
* exclude two methods for JNI
Signed-off-by: raver119 <raver119@gmail.com>
* exclude two methods for JNI
Signed-off-by: raver119 <raver119@gmail.com>
* exclude two methods for JNI (#97)
Signed-off-by: raver119 <raver119@gmail.com>
* temporary stack fix
Signed-off-by: raver119 <raver119@gmail.com>
* round robin affinity test
Signed-off-by: raver119 <raver119@gmail.com>
* get rid of legacy CudaContext methods
Signed-off-by: raver119 <raver119@gmail.com>
* get rid of legacy ContextPool classes/methods
Signed-off-by: raver119 <raver119@gmail.com>
* one legacy test removed
Signed-off-by: raver119 <raver119@gmail.com>
* few more fields rearranged
Signed-off-by: raver119 <raver119@gmail.com>
* OpaqueLaunchContext
Signed-off-by: raver119 <raver119@gmail.com>
* OpaqueLaunchContext++
Signed-off-by: raver119 <raver119@gmail.com>
* more of OpaqueLaunchContext methods
Signed-off-by: raver119 <raver119@gmail.com>
* LaunchContext -> CudaContext
Signed-off-by: raver119 <raver119@gmail.com>
* AffinityManger changes
Signed-off-by: raver119 <raver119@gmail.com>
* AffinityManger changes
Signed-off-by: raver119 <raver119@gmail.com>
* cusolver handles
Signed-off-by: raver119 <raver119@gmail.com>
* typo
Signed-off-by: raver119 <raver119@gmail.com>
* cusolver method
Signed-off-by: raver119 <raver119@gmail.com>
* cusolver handle propagated
Signed-off-by: raver119 <raver119@gmail.com>
* blas/solver handles
Signed-off-by: raver119 <raver119@gmail.com>
* one more test
Signed-off-by: raver119 <raver119@gmail.com>
* legacy concat implementations replaced with new CustomOp
Signed-off-by: raver119 <raver119@gmail.com>
* one more test
Signed-off-by: raver119 <raver119@gmail.com>
* concat now uses way more blocks
Signed-off-by: raver119 <raver119@gmail.com>
* print
Signed-off-by: raver119 <raver119@gmail.com>
* no more triple template mmul
Signed-off-by: raver119 <raver119@gmail.com>
* bunch of kernels have dtypes reconsidered
Signed-off-by: raver119 <raver119@gmail.com>
* bunch of kernels have dtypes reconsidered
Signed-off-by: raver119 <raver119@gmail.com>
* bitonic sort reorganized
Signed-off-by: raver119 <raver119@gmail.com>
* bunch of cpu stuff removed from cuda scope
Signed-off-by: raver119 <raver119@gmail.com>
* bunch of cpu stuff removed from cuda scope
Signed-off-by: raver119 <raver119@gmail.com>
* type conversions moved to generic impl
Signed-off-by: raver119 <raver119@gmail.com>
* cpu data types pass
Signed-off-by: raver119 <raver119@gmail.com>
* non_max_suppression
Signed-off-by: raver119 <raver119@gmail.com>
* sortByValue fix
Signed-off-by: raver119 <raver119@gmail.com>
* ignore all mixed datatype tests for mmul
Signed-off-by: raver119 <raver119@gmail.com>
* special handling of OpProfiler exceptions
Signed-off-by: raver119 <raver119@gmail.com>
* - one failing concat test in cpp
- Nd4j.tile now uses op internally
Signed-off-by: raver119 <raver119@gmail.com>
* get back dtype exception for legacy arrays deserialization
Signed-off-by: raver119 <raver119@gmail.com>
2019-08-14 15:52:34 +02:00
|
|
|
return LaunchContext::defaultContext();
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
 * Scalar-pointer accessor for a launch context.
 * This implementation ignores `lc` and always returns nullptr
 * (presumably no such resource exists for this backend — TODO confirm).
 */
Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc) {
    static_cast<void>(lc);
    return nullptr;
}
|
|
|
|
|
|
|
|
/**
 * Reduction-pointer accessor for a launch context.
 * This implementation ignores `lc` and always returns nullptr
 * (presumably no such resource exists for this backend — TODO confirm).
 */
Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc) {
    static_cast<void>(lc);
    return nullptr;
}
|
|
|
|
|
|
|
|
/**
 * Allocation-pointer accessor for a launch context.
 * This implementation ignores `lc` and always returns nullptr
 * (presumably no such resource exists for this backend — TODO confirm).
 */
Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc) {
    static_cast<void>(lc);
    return nullptr;
}
|
|
|
|
|
|
|
|
/**
 * Execution-stream accessor for a launch context.
 * This implementation ignores `lc` and always returns nullptr
 * (presumably no such resource exists for this backend — TODO confirm).
 */
Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc) {
    static_cast<void>(lc);
    return nullptr;
}
|
|
|
|
|
|
|
|
/**
 * Copy-stream accessor for a launch context.
 * This implementation ignores `lc` and always returns nullptr
 * (presumably no such resource exists for this backend — TODO confirm).
 */
Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc) {
    static_cast<void>(lc);
    return nullptr;
}
|
|
|
|
|
|
|
|
/**
 * BLAS-handle accessor for a launch context.
 * This implementation ignores `lc` and always returns nullptr
 * (presumably no such handle exists for this backend — TODO confirm).
 */
Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc) {
    static_cast<void>(lc);
    return nullptr;
}
|
|
|
|
|
|
|
|
/**
 * Solver-handle accessor for a launch context.
 * This implementation ignores `lc` and always returns nullptr
 * (presumably no such handle exists for this backend — TODO confirm).
 */
Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc) {
    static_cast<void>(lc);
    return nullptr;
}
|
|
|
|
|
2019-08-26 18:57:51 +02:00
|
|
|
int lastErrorCode() {
|
2020-03-02 10:49:41 +01:00
|
|
|
return sd::LaunchContext::defaultContext()->errorReference()->errorCode();
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
const char* lastErrorMessage() {
|
2020-03-02 10:49:41 +01:00
|
|
|
return sd::LaunchContext::defaultContext()->errorReference()->errorMessage();
|
2019-08-26 18:57:51 +02:00
|
|
|
}
|
Merge master to upstream (#7945)
* Shugeo strided slice zeros (#14)
* Modified strided_slice op to properly work with empty-like shapes.
* Fixed test for reduce_mean with empty-like input.
* [WIP] Last merge (#15)
* correct logsoftmax looss (#2)
* Small SameDiff listener fix (#4)
* Various fixes (#6)
* #7839 Fix for asXMatrix and tests
* #7866 EmbeddingSequenceLayer dtype fix + test
* #7856 SameDiff save/load stream methods
* #7859 RegressionEvaluation rank 4 fix + tests + axis configuration
* EvaluationBinary 3d/4d
* More evaluation 3d/4d tests
* #7847 Evaluation empty checks
* Small test ifx
* #7848 Fix median edge case
* Improve DL4J samediff layer tests
* [WIP] FastText wrapper implemented (#8)
* FastText implemented
* Some fixes
* Fix shapes for wordsNearest
* Validation of input vectors
* Fixes
* Fixed test
* Thread tagged
* Some tweaks
* setContextClassLoader for DeallocatorServiceThread
* Numpy format tests (#1)
* Various fixes (#11)
* #7852 SameDiff gather fix
* #7892 SameDiff placeholder to constant conversion
* #7890 validate input rank for MLN/CG init methods
* Fix broken permute shape calculation
* Permute and gather fixes
* Tests
* #7850 LogSumExp fix + test
* Handful of test fixes
* Empty arrays with non-scalar shapes (#10)
* minor rearrangements for lambdas
* empty tensors with non-scalar shapes
* numpy empty tensors with non-scalar shapes
* few more empty tweaks
* Small fixes
* conv3d signature update
* micro fix in batchnorm mkldnn
* Import fixes
* Fix
* MKL-DNN update
* Small fill fix
* fill with empty input + test
* Fixes
* Small error improvement
* Fix
* one special test
* couple of fixes for lstm
* Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone
* Fixes
* FP16
* Unsigned
* BFloat16
* Fill op - empty tweaks
* - couple of fixes for empty arrays construction
- stack updated
* strided slice fix
* one transform test
* provide method for reducing shapeInfo in case of input array is empty
* Fixed reduceAlongDimensions to use empty input properly.
* couple of broadcast tests
* couple of tests broadcast tests + tweak to make them pass
* add check of non-empty to methods producing sub-arrays
* Fixed reshapeC with zeros in shape.
* complete empty check in reduce_... legacy ops
* Concat and cumsum/prod
* Tweak to empty shape inference on import
* add empty check to the rest of reduce legacy ops
* one more test
* correct typo in evalReduceShapeInfoEmpty
* Added tests for reduce_* ops to tests with zero shapes.
* few more tests for empty reductions
* Fixed strided_slice op with empty case and tests.
* one more empty reduction test
* Fixed strided_slice test.
* add empty check to NDArray::reshapei
* infOrMax
* empty min/max with infinity tests
* made unstack working correctly with empty arrays
* few IndexReduce tests + tweaks for empty shapes
* add test for empty concat
* few tests fixed
* Validation fix for reductions on empty shapes
* Reverse fix
* Reduction shape calc fixes
* SameDiff.generateOutputVariable: don't use shape function to determine number of outputs
* Range fix
* - NDArray constructor updated for scalars/empty arrays
- few tests fixed
* More fixes
* Empty creator fixes
* concat fix
* concat fix
* TF import tests: allow 'both all NaN' and 'both all inf' to pass
* Slice, zero fraction, and reshape fixes
* transpose, gather
* Zero fraction
* scalar cast fix
* Empty reduction axis support
* few more tests fixed
* Fixed input checks conforming with TF for concat op and tests.
* few tests fixed
* matmul scalar shape fix
* Fixed checkout for data type and scalarity with concat to allow non-empty scalars with vector concats.
* broadcast bool fix
* few more tests
* few more tests
* correct evalReduceShapeInfoEmpty
* argmax/argmin + tests
* one more empty edge case + one more test
* argmax/argmin/realdiv_bp tweaks
* empty reshape test + fix
* Helper fixes
* Small fixes
* Gather test fix
* Gather test fix
* Small fixes
* reduce scalar zero values
* scalar mean workaround
* Remove debug code
* along dim mean workaround
* one more test
* - equalsTo() tweak for empty arrays
- one more test
* broadcast tweaks
* [WIP] Fixing outstanding issues for NLP (#9)
* Avoid using not-inited objects
* Test fixed.
* Redundant method avoided for models like FastText
* KMeans++ implementation
* KMeans++ implementation
* Disable parallel execution
* KMeans++
* Tests
* Dev branch merge (#16)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Fix some issues on master (#17)
* Fix DataVec test issue
* Fix issue with dl4j SameDiff output layer
* Dtype fix for lambda layers
* #7912 BertIterator dtype fix (use float32 not global default)
* [WIP] Next set of CUDA stuff (#7)
New CUDA implementations and improvements
* bad file
* Dev branch master merge (#23)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* SameDiff ops, TF import and fixes (#24)
* CheckNumerics tests + fixes + misc fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fake quant
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* FakeQuantWithMinMaxArgs
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* CheckNumerics fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix libnd4j ALL_INTS and ALL_FLOATS declaration (uint and bfloat types)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Small fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Javadoc
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Exception tweak
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix for out of scope stack allocated var use
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignores
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignore for known failing test (already logged issue)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Merge upstream to fork (#25)
* Add thousand-separator commas to TotalParams (#7915)
* Add thousand-separator commas to TotalParams
The number of parameters can be quite large, and it would help the reading of the summary printout to have the TotalParams column & values at the bottom have thousand-separator-commas in them.
* Add thousand-separator commas to MultiLayerNetwork
Corresponding change to MultiLayerNetwork
Signed-off-by: Jxtps Jxtps <jxtps435@gmail.com>
* Update contributing and issue/PR templates (#7934)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix link to AdaDelta paper (#7942)
Fix link to AdaDelta paper hosted on matthewzeiler.com
Signed-off-by: Jxtps
* Fixes, and ignores for known/logged failing issues (#7943)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* SameDiff + DL4J/SameDiff: Multiple fixes (#28)
* #7919 HDF5 attribute buffer length fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7909 Arbiter constructor exception ux improvements
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7925 RNN output layer length checks
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Add listener for validating inputs are not incorrectly modified
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Integrate NonInplaceValidationListener into tests
* #7844 DL4J SameDiff fixes for variable minibatch size
* DL4J SameDiff fixes - ensure gradient for input placeholder is available
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Tweaks to ExternalErrorsFunction - use placeholders, make more robust
* Another fix
* More fixes
* More SameDiff/DL4J fixes
* Scope out scalar array creation in BaseScalarOp
* Remove debug code
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] Final dev branch merge (#29)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* [WIP] Multiple dataset iterators (#27)
* Splitting dataset into arbitrary number
* Fixes
* Multiple split of iterator
* Test
* Test
* Some fixes
* signature change
* one more tweak
Signed-off-by: raver119 <raver119@gmail.com>
* one more test for sequential use of DataSetIteratorSplitter
Signed-off-by: raver119 <raver119@gmail.com>
* Fixes
* Fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* minor test fix
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* couple of assertions tweaked
Signed-off-by: raver119 <raver119@gmail.com>
* MDS splitter test :/
Signed-off-by: raver119 <raver119@gmail.com>
* Minor refactoring
* Multi dataset
* Some fixes
* More tests
* Small number of test fixes/improvements (failures on CI) (#31)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] More CUDA stuff (#26)
* initial commit
Signed-off-by: raver119 <raver119@gmail.com>
* LRN BP CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* less memory
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed bug with crop_and_resize op helper.
* get rid of unnecessary index-calculation function
Signed-off-by: Yurii <yurii@skymind.io>
* Fixed sort with nth_element cuda-based helper.
* Refactored nth_element.
* Refactored nth_element op and tests.
* Modified usage of dim array with sortTad routine.
* Refactored main routine of helper for non_max_image_suppression op.
* non_max_image_suppression op helper with cuda kernel implementation. Initial revision.
* fix vol2col cuda kernel
* meh
Signed-off-by: raver119 <raver119@gmail.com>
* topK concept
Signed-off-by: raver119 <raver119@gmail.com>
* unsorted topK with scanWidth of 1
Signed-off-by: raver119 <raver119@gmail.com>
* correct vol2col tests
* sorted/unsorted topK
Signed-off-by: raver119 <raver119@gmail.com>
* implementation and fixing col2im/col2vol
* Corrected usage flags with input/output with reverse op.
* dup is const now
Signed-off-by: raver119 <raver119@gmail.com>
* percentile op
Signed-off-by: raver119 <raver119@gmail.com>
* group tests for maxpool2d
Signed-off-by: Yurii <yurii@skymind.io>
* special test for george
Signed-off-by: raver119 <raver119@gmail.com>
* less threads for sortTad
Signed-off-by: raver119 <raver119@gmail.com>
* provide conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* remove auther in sort tad kernel code
Signed-off-by: Yurii <yurii@skymind.io>
* provide depthwise_conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* - max_pooling_with_argmax
- null check for special use
Signed-off-by: raver119 <raver119@gmail.com>
* dts cuda
Signed-off-by: raver119 <raver119@gmail.com>
* provide sconv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* std cuda
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op to conform to the TF implementation.
* Improved suppression helper.
* provide pooling3d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* more of minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* (bi)dynamic_rnn
Signed-off-by: raver119 <raver119@gmail.com>
* templates init order
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op.
* Added cuda kernel for non_max_suppression.
* CPU sort by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value tests
Signed-off-by: raver119 <raver119@gmail.com>
* Eliminate compiler error with cuda implementation.
* - repaired gradCheck in cuda
- provide conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* missed signature
Signed-off-by: raver119 <raver119@gmail.com>
* provide depthwise_conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* Implementation of lup helper with cuda kernel. Initial commit.
* further work on backprops for convolutions
Signed-off-by: Yurii <yurii@skymind.io>
* CUDA linear sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* CUDA tad sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* start providing of backprop for pooling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* Added atomicAdd for bool datatype.
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition scalar CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* important comment
Signed-off-by: raver119 <raver119@gmail.com>
* fix pooling2d/3d backprop helpers
Signed-off-by: Yurii <yurii@skymind.io>
* Added non-linear test with dynamic_partition.
* Improved test for dynamic_partition.
* dynamic_partition TAD concept
Signed-off-by: raver119 <raver119@gmail.com>
* - dynamic_partition TAD CUDA impl
- dynamic_partition TAD CPU fix
Signed-off-by: raver119 <raver119@gmail.com>
* - rewrite cpu code for upsampling2d/3d
- write cuda code for upsampling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* dynamic_stitch CUDA vector case
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case impl
Signed-off-by: raver119 <raver119@gmail.com>
* Added tests for dynamic_stitch 3D-4D cases.
* minor tests tweaks
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed type check for dynamic stitch.
* min/max bp
Signed-off-by: raver119 <raver119@gmail.com>
* rewrite code for upsampling2d/3d cpu
Signed-off-by: Yurii <yurii@skymind.io>
* reduce min/max/norm_max bp
Signed-off-by: raver119 <raver119@gmail.com>
* lup implementation. Additional enhancements.
* provide code for upsampling2d/3d backprop
Signed-off-by: Yurii <yurii@skymind.io>
* weightedCrossEntropyWithLogits
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed template math atomicMul for 64bit ints.
* Refactored dynamic_partition_bp op.
* inverseBroadcast fix
Signed-off-by: raver119 <raver119@gmail.com>
* DynamicPartitionBP test datatype fixed.
* - nd4j_atomicMul Windows fix
- cpu/NDArrayLambda.hpp excluded from CUDA
Signed-off-by: raver119 <raver119@gmail.com>
2019-06-27 17:37:04 +02:00
|
|
|
|
2020-01-04 07:06:44 +01:00
|
|
|
/**
 * Sets the shape-function override flag on an op context.
 *
 * @param ptr            op context to configure; it is dereferenced without a null check
 * @param reallyOverride value forwarded verbatim to setShapeFunctionOverride()
 */
void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride) {
    auto &context = *ptr;
    context.setShapeFunctionOverride(reallyOverride);
}
|
|
|
|
|
2019-09-11 20:50:28 +02:00
|
|
|
/**
 * Reports the instruction-set flavour this binary was compiled for.
 *
 * @return 1 for an F_X64 build, 2 for F_AVX2, 3 for F_AVX512. Returns 0 when none of
 *         these macros is defined or when CPU_FEATURES support is compiled out.
 *         The macros are tested in that order, so the first one that is defined wins.
 */
int binaryLevel() {
    int level = 0;
#ifdef CPU_FEATURES
#  if defined(F_X64)
    level = 1;
#  elif defined(F_AVX2)
    level = 2;
#  elif defined(F_AVX512)
    level = 3;
#  endif
#endif
    return level;
}
|
|
|
|
|
|
|
|
/**
 * Reports the highest instruction-set level that the host CPU supports, detected at
 * runtime via cpu_features.
 *
 * @return 3 when AVX, AVX2 and the AVX-512 F/VL/BW/DQ/CD subsets are all present;
 *         2 when AVX and AVX2 are present; 1 otherwise.
 *         Returns 0 when built without CPU_FEATURES.
 */
int optimalLevel() {
#ifdef CPU_FEATURES
    const auto features = cpu_features::GetX86Info().features;

    const bool hasAvx2 = features.avx && features.avx2;
    const bool hasAvx512 = features.avx512f && features.avx512vl && features.avx512bw
                           && features.avx512dq && features.avx512cd;

    // Without AVX/AVX2 only the generic level is possible.
    if (!hasAvx2)
        return 1;

    return hasAvx512 ? 3 : 2;
#else
    return 0;
#endif
}
|
|
|
|
|
|
|
|
/**
 * Checks whether the host CPU provides the instruction-set extensions that this
 * binary's build flavour (F_X64 / F_AVX2 / F_AVX512) requires.
 *
 * @return true for a generic F_X64 build or an unrecognised flavour; otherwise true only
 *         when the required CPU features are reported. Always true without CPU_FEATURES.
 */
bool isMinimalRequirementsMet() {
#ifndef CPU_FEATURES
    return true;
#else
    const auto features = cpu_features::GetX86Info().features;
#  if defined(F_X64)
    (void) features;  // a generic x86-64 build needs no extensions
    return true;
#  elif defined(F_AVX2)
    return features.avx && features.avx2;
#  elif defined(F_AVX512)
    // the AVX-512 build targets skylake-avx512, so every subset it uses must be present
    return features.avx && features.avx2
           && features.avx512f && features.avx512vl && features.avx512bw
           && features.avx512dq && features.avx512cd;
#  else
    (void) features;
    return true;
#  endif
#endif
}
|
|
|
|
|
|
|
|
/**
 * Checks whether this binary was compiled for exactly the instruction-set level that
 * the host CPU supports.
 *
 * @return true when ::binaryLevel() equals ::optimalLevel(); always true without CPU_FEATURES.
 */
bool isOptimalRequirementsMet() {
#ifndef CPU_FEATURES
    return true;
#else
    const int compiledLevel = ::binaryLevel();
    const int supportedLevel = ::optimalLevel();
    return compiledLevel == supportedLevel;
#endif
}
|
|
|
|
|
2020-04-28 19:38:16 +02:00
|
|
|
/**
 * `db`-prefixed alias for allocateDataBuffer(); all arguments are forwarded unchanged.
 *
 * @return whatever allocateDataBuffer() returns. That is nullptr when it caught an
 *         exception, in which case the error is recorded on the default LaunchContext.
 */
OpaqueDataBuffer* dbAllocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) {
    OpaqueDataBuffer *created = allocateDataBuffer(elements, dataType, allocateBoth);
    return created;
}
|
|
|
|
|
2020-01-04 11:27:50 +01:00
|
|
|
/**
 * Creates a heap-allocated InteropDataBuffer large enough for `elements` values of `dataType`.
 *
 * @param elements     number of elements; the byte length passed on is elements * sizeOf(dtype)
 * @param dataType     integer type code, converted with DataTypeUtils::fromInt()
 * @param allocateBoth forwarded unchanged to the InteropDataBuffer constructor
 * @return the new buffer, or nullptr if a std::exception was thrown. In that case error
 *         code 1 and the exception message are stored on the default LaunchContext's
 *         error reference instead of letting the exception escape.
 */
OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) {
    try {
        const auto dtype = DataTypeUtils::fromInt(dataType);
        const auto lengthInBytes = elements * DataTypeUtils::sizeOf(dtype);
        return new sd::InteropDataBuffer(lengthInBytes, dtype, allocateBoth);
    } catch (const std::exception &e) {
        auto errors = sd::LaunchContext::defaultContext()->errorReference();
        errors->setErrorCode(1);
        errors->setErrorMessage(e.what());
        return nullptr;
    }
}
|
|
|
|
|
|
|
|
/**
 * Returns the primary (host) pointer held by the buffer, as reported by primary().
 */
Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer) {
    Nd4jPointer hostPointer = dataBuffer->primary();
    return hostPointer;
}
|
|
|
|
|
|
|
|
/**
 * Returns the special (device) pointer held by the buffer, as reported by special().
 */
Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer *dataBuffer) {
    Nd4jPointer devicePointer = dataBuffer->special();
    return devicePointer;
}
|
|
|
|
|
|
|
|
/**
 * Destroys a buffer obtained from one of the `new`-based factories in this file
 * (allocateDataBuffer, dbAllocateDataBuffer, dbCreateExternalDataBuffer, dbCreateView).
 * As with any `delete`, passing nullptr is a no-op.
 */
void deleteDataBuffer(OpaqueDataBuffer *dataBuffer) {
    OpaqueDataBuffer *doomed = dataBuffer;
    delete doomed;
}
|
|
|
|
|
2020-04-28 19:38:16 +02:00
|
|
|
/**
 * Wraps externally owned memory in a new data buffer.
 *
 * A zero-length buffer of `dataType` is allocated first. Each non-null pointer is then
 * attached with setPrimary()/setSpecial(), passing `elements` as the length.
 *
 * NOTE(review): dbSetPrimaryBuffer/dbSetSpecialBuffer pass an argument named `numBytes`
 * to the same setters, while this function passes an element count. Confirm which unit
 * setPrimary/setSpecial expect.
 *
 * @param elements length forwarded to setPrimary()/setSpecial()
 * @param dataType integer type code for the new buffer
 * @param primary  external primary (host) pointer, or nullptr to leave it unset
 * @param special  external special (device) pointer, or nullptr to leave it unset
 * @return the new buffer, or nullptr if the initial allocation failed. In that case the
 *         error has already been recorded on the default LaunchContext.
 */
OpaqueDataBuffer* dbCreateExternalDataBuffer(Nd4jLong elements, int dataType, Nd4jPointer primary, Nd4jPointer special) {
    auto buffer = dbAllocateDataBuffer(0, dataType, false);

    // dbAllocateDataBuffer reports failure by returning nullptr (allocateDataBuffer catches
    // the exception and stores the error). Propagate that instead of dereferencing it below.
    if (buffer == nullptr)
        return nullptr;

    if (primary != nullptr)
        buffer->setPrimary(primary, elements);

    if (special != nullptr)
        buffer->setSpecial(special, elements);

    return buffer;
}
|
|
|
|
|
2020-01-04 11:27:50 +01:00
|
|
|
/**
 * Points the buffer's primary (host) side at `primaryBuffer`, forwarding `numBytes`
 * unchanged to setPrimary().
 */
void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, Nd4jLong numBytes) {
    auto &target = *dataBuffer;
    target.setPrimary(primaryBuffer, numBytes);
}
|
|
|
|
|
|
|
|
/**
 * Points the buffer's special (device) side at `specialBuffer`, forwarding `numBytes`
 * unchanged to setSpecial().
 */
void dbSetSpecialBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer specialBuffer, Nd4jLong numBytes) {
    auto &target = *dataBuffer;
    target.setSpecial(specialBuffer, numBytes);
}
|
|
|
|
|
|
|
|
/** Calls allocatePrimary() on the DataBuffer wrapped by this interop buffer. */
void dbAllocatePrimaryBuffer(OpaqueDataBuffer *dataBuffer) {
    auto &&underlying = dataBuffer->dataBuffer();
    underlying->allocatePrimary();
}
|
|
|
|
|
|
|
|
/** Calls allocateSpecial() on the DataBuffer wrapped by this interop buffer. */
void dbAllocateSpecialBuffer(OpaqueDataBuffer *dataBuffer) {
    auto &&underlying = dataBuffer->dataBuffer();
    underlying->allocateSpecial();
}
|
|
|
|
|
|
|
|
/**
 * Grows the wrapped DataBuffer so it can hold `elements` values of its current data type.
 * The element count is converted to bytes before calling expand().
 *
 * Any std::exception is caught. Error code 1 and the message are stored on the default
 * LaunchContext's error reference instead of letting the exception escape.
 */
void dbExpandBuffer(OpaqueDataBuffer *dataBuffer, Nd4jLong elements) {
    try {
        auto &&underlying = dataBuffer->dataBuffer();
        const auto lengthInBytes = elements * DataTypeUtils::sizeOf(underlying->getDataType());
        underlying->expand(lengthInBytes);
    } catch (const std::exception &e) {
        auto errors = sd::LaunchContext::defaultContext()->errorReference();
        errors->setErrorCode(1);
        errors->setErrorMessage(e.what());
    }
}
|
|
|
|
|
|
|
|
/**
 * Creates a new heap-allocated InteropDataBuffer from `dataBuffer`, `length` and `offset`.
 * Release it with deleteDataBuffer().
 */
OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset) {
    auto view = new InteropDataBuffer(*dataBuffer, length, offset);
    return view;
}
|
|
|
|
|
|
|
|
/** Calls syncToSpecial() on the wrapped DataBuffer. */
void dbSyncToSpecial(OpaqueDataBuffer *dataBuffer) {
    auto &&underlying = dataBuffer->dataBuffer();
    underlying->syncToSpecial();
}
|
|
|
|
|
|
|
|
/** Calls syncToPrimary() on the wrapped DataBuffer, passing nullptr as its argument. */
void dbSyncToPrimary(OpaqueDataBuffer *dataBuffer) {
    auto &&underlying = dataBuffer->dataBuffer();
    underlying->syncToPrimary(nullptr);
}
|
|
|
|
|
|
|
|
/** Calls readPrimary() on the wrapped DataBuffer to record a host-side read. */
void dbTickHostRead(OpaqueDataBuffer *dataBuffer) {
    auto &&underlying = dataBuffer->dataBuffer();
    underlying->readPrimary();
}
|
|
|
|
|
|
|
|
/** Calls writePrimary() on the wrapped DataBuffer to record a host-side write. */
void dbTickHostWrite(OpaqueDataBuffer *dataBuffer) {
    auto &&underlying = dataBuffer->dataBuffer();
    underlying->writePrimary();
}
|
|
|
|
|
|
|
|
/** Calls readSpecial() on the wrapped DataBuffer to record a device-side read. */
void dbTickDeviceRead(OpaqueDataBuffer *dataBuffer) {
    auto &&underlying = dataBuffer->dataBuffer();
    underlying->readSpecial();
}
|
|
|
|
|
|
|
|
/** Calls writeSpecial() on the wrapped DataBuffer to record a device-side write. */
void dbTickDeviceWrite(OpaqueDataBuffer *dataBuffer) {
    auto &&underlying = dataBuffer->dataBuffer();
    underlying->writeSpecial();
}
|
|
|
|
|
|
|
|
/**
 * Forwards `elements` to the interop buffer's own expand().
 * Unlike dbExpandBuffer, this function does not catch exceptions.
 */
void dbExpand(OpaqueDataBuffer *dataBuffer, Nd4jLong elements) {
    auto &target = *dataBuffer;
    target.expand(elements);
}
|
|
|
|
|
|
|
|
/**
 * Locality query for a data buffer. This implementation ignores its argument
 * and always returns 0.
 */
int dbLocality(OpaqueDataBuffer *dataBuffer) {
    (void) dataBuffer;  // unused here
    constexpr int kLocality = 0;
    return kLocality;
}
|
|
|
|
|
|
|
|
/** Stores `deviceId` on the interop buffer via setDeviceId(). */
void dbSetDeviceId(OpaqueDataBuffer *dataBuffer, int deviceId) {
    auto &target = *dataBuffer;
    target.setDeviceId(deviceId);
}
|
|
|
|
|
|
|
|
/** Returns the device id reported by the interop buffer's deviceId(). */
int dbDeviceId(OpaqueDataBuffer *dataBuffer) {
    const auto &source = *dataBuffer;
    int id = const_cast<decltype(*dataBuffer)>(source).deviceId();
    return id;
}
|
|
|
|
|
|
|
|
/** Calls close() on the DataBuffer returned by getDataBuffer(). */
void dbClose(OpaqueDataBuffer *dataBuffer) {
    auto &&underlying = dataBuffer->getDataBuffer();
    underlying->close();
}
|
|
|
|
|
2020-05-09 07:06:14 +02:00
|
|
|
// Template instantiations of the generic pullRows / tear / shuffle helpers for every type
// in LIBND4J_TYPES. These presumably expand to explicit instantiations; the
// BUILD_SINGLE_TEMPLATE macro body is defined elsewhere, so confirm there.
// Note: `const Nd4jLong*` is the same type as `Nd4jLong const*`; `Nd4jLong* const*` is not.
BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric,
                      (void*, const Nd4jLong*, void*, const Nd4jLong*, const int,
                       const Nd4jLong*, const Nd4jLong*, const Nd4jLong*, const Nd4jLong*, const Nd4jLong*),
                      LIBND4J_TYPES);

BUILD_SINGLE_TEMPLATE(template void tearGeneric,
                      (void*, const Nd4jLong*, Nd4jPointer*, const Nd4jLong*, const Nd4jLong*, const Nd4jLong*),
                      LIBND4J_TYPES);

BUILD_SINGLE_TEMPLATE(template void shuffleGeneric,
                      (void**, Nd4jLong* const*, void**, Nd4jLong* const*, int, int*,
                       Nd4jLong* const*, Nd4jLong* const*),
                      LIBND4J_TYPES);
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
|