Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
AlexDBlack 2019-08-19 18:46:47 +10:00
commit 01cb57041a
125 changed files with 3721 additions and 1559 deletions

View File

@ -380,19 +380,23 @@ public class VPTree implements Serializable {
private Node buildFromPoints(INDArray items) {
if (executorService == null && items == this.items && workers > 1) {
final val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
executorService = Executors.newFixedThreadPool(workers, new ThreadFactory() {
@Override
public Thread newThread(Runnable r) {
Thread t = Executors.defaultThreadFactory().newThread(r);
public Thread newThread(final Runnable r) {
Thread t = new Thread(new Runnable() {
@Override
public void run() {
Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
r.run();
}
});
t.setDaemon(true);
t.setName("VPTree thread");
// we don't want threads to be working on different devices
Nd4j.getAffinityManager().attachThreadToDevice(t,
Nd4j.getAffinityManager().getDeviceForCurrentThread());
return t;
}
});

View File

@ -132,9 +132,8 @@ public class ParallelInference {
boolean cRoot = !assignedRoot.get() && cDevice == currentDevice;
assignedRoot.compareAndSet(false, cRoot);
zoo[i] = new InferenceWorker(i, model, observables, cRoot);
zoo[i] = new InferenceWorker(i, model, observables, cRoot, cDevice);
Nd4j.getAffinityManager().attachThreadToDevice(zoo[i], cDevice);
zoo[i].setDaemon(true);
zoo[i].start();
}
@ -425,13 +424,15 @@ public class ParallelInference {
private Model replicatedModel;
private AtomicLong counter = new AtomicLong(0);
private boolean rootDevice;
private int deviceId;
private ReentrantReadWriteLock modelLock = new ReentrantReadWriteLock();
private InferenceWorker(int id, @NonNull Model model, @NonNull BlockingQueue inputQueue, boolean rootDevice) {
private InferenceWorker(int id, @NonNull Model model, @NonNull BlockingQueue inputQueue, boolean rootDevice, int deviceId) {
this.inputQueue = inputQueue;
this.protoModel = model;
this.rootDevice = rootDevice;
this.deviceId = deviceId;
this.setDaemon(true);
this.setName("InferenceThread-" + id);
@ -491,6 +492,7 @@ public class ParallelInference {
@Override
public void run() {
Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
try {
// model should be replicated & initialized here
initializeReplicaModel();

View File

@ -151,18 +151,21 @@ public class ParallelWrapper implements AutoCloseable {
workerCounter.set(0);
this.executorService = (ThreadPoolExecutor) Executors.newFixedThreadPool(workers, new ThreadFactory() {
@Override
public Thread newThread(@NonNull Runnable r) {
Thread t = Executors.defaultThreadFactory().newThread(r);
public Thread newThread(@NonNull final Runnable r) {
final int cThread = workerCounter.getAndIncrement();
int cThread = workerCounter.getAndIncrement();
Thread t = new Thread(new Runnable() {
@Override
public void run() {
Nd4j.getAffinityManager().unsafeSetDevice(cThread % Nd4j.getAffinityManager().getNumberOfDevices());
r.run();
}
});
t.setName("ParallelWrapper training thread " + cThread);
t.setDaemon(true);
t.setUncaughtExceptionHandler(handler);
Nd4j.getAffinityManager().attachThreadToDevice(t,
cThread % Nd4j.getAffinityManager().getNumberOfDevices());
return t;
}
});

View File

@ -80,14 +80,9 @@ public class Word2VecPerformer implements VoidFunction<Pair<List<VocabWord>, Ato
initExpTable();
if (negative > 0 && conf.contains(Word2VecVariables.TABLE)) {
try {
ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(Word2VecVariables.TABLE).getBytes());
DataInputStream dis = new DataInputStream(bis);
table = Nd4j.read(dis);
} catch (IOException e) {
e.printStackTrace();
}
}
}

View File

@ -95,16 +95,10 @@ public class Word2VecPerformerVoid implements VoidFunction<Pair<List<VocabWord>,
initExpTable();
if (negative > 0 && conf.contains(TABLE)) {
try {
ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(TABLE).getBytes());
DataInputStream dis = new DataInputStream(bis);
table = Nd4j.read(dis);
} catch (IOException e) {
e.printStackTrace();
}
}
}

View File

@ -86,7 +86,7 @@ public class SharedTrainingWorker extends BaseTrainingWorker<SharedTrainingResul
// This method will be called ONLY once, in master thread
//Before getting NetBroadcastTuple, to ensure it always gets mapped to device 0
Nd4j.getAffinityManager().attachThreadToDevice(Thread.currentThread(), 0);
Nd4j.getAffinityManager().unsafeSetDevice(0);
NetBroadcastTuple tuple = broadcastModel.getValue();
if (tuple.getConfiguration() != null) {
@ -109,7 +109,7 @@ public class SharedTrainingWorker extends BaseTrainingWorker<SharedTrainingResul
@Override
public ComputationGraph getInitialModelGraph() {
//Before getting NetBroadcastTuple, to ensure it always gets mapped to device 0
Nd4j.getAffinityManager().attachThreadToDevice(Thread.currentThread(), 0);
Nd4j.getAffinityManager().unsafeSetDevice(0);
NetBroadcastTuple tuple = broadcastModel.getValue();
if (tuple.getGraphConfiguration() != null) {
ComputationGraphConfiguration conf = tuple.getGraphConfiguration();

View File

@ -108,7 +108,7 @@ public class EvaluationRunner {
INDArray p;
try{
p = Nd4j.read(new ByteArrayInputStream(pBytes));
} catch (IOException e){
} catch (RuntimeException e){
throw new RuntimeException(e); //Should never happen
}
DeviceLocalNDArray dlp = new DeviceLocalNDArray(p);

View File

@ -97,13 +97,12 @@ public class SparkADSI extends AsyncDataSetIterator {
context = TaskContext.get();
this.thread = new SparkPrefetchThread(buffer, iterator, terminator, null);
this.thread = new SparkPrefetchThread(buffer, iterator, terminator, null, Nd4j.getAffinityManager().getDeviceForCurrentThread());
/**
* We want to ensure, that background thread will have the same thread->device affinity, as master thread
*/
Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId);
thread.setDaemon(true);
thread.start();
}
@ -116,9 +115,8 @@ public class SparkADSI extends AsyncDataSetIterator {
public class SparkPrefetchThread extends AsyncPrefetchThread {
protected SparkPrefetchThread(BlockingQueue<DataSet> queue, DataSetIterator iterator, DataSet terminator,
MemoryWorkspace workspace) {
super(queue, iterator, terminator, workspace);
protected SparkPrefetchThread(BlockingQueue<DataSet> queue, DataSetIterator iterator, DataSet terminator, MemoryWorkspace workspace, int deviceId) {
super(queue, iterator, terminator, workspace, deviceId);
}

View File

@ -97,15 +97,10 @@ public class SparkAMDSI extends AsyncMultiDataSetIterator {
if (iterator.resetSupported())
this.backedIterator.reset();
this.thread = new SparkPrefetchThread(buffer, iterator, terminator);
this.thread = new SparkPrefetchThread(buffer, iterator, terminator, Nd4j.getAffinityManager().getDeviceForCurrentThread());
context = TaskContext.get();
/**
* We want to ensure, that background thread will have the same thread->device affinity, as master thread
*/
Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId);
thread.setDaemon(true);
thread.start();
}
@ -117,9 +112,8 @@ public class SparkAMDSI extends AsyncMultiDataSetIterator {
protected class SparkPrefetchThread extends AsyncPrefetchThread {
protected SparkPrefetchThread(@NonNull BlockingQueue<MultiDataSet> queue,
@NonNull MultiDataSetIterator iterator, @NonNull MultiDataSet terminator) {
super(queue, iterator, terminator);
protected SparkPrefetchThread(@NonNull BlockingQueue<MultiDataSet> queue, @NonNull MultiDataSetIterator iterator, @NonNull MultiDataSet terminator, int deviceId) {
super(queue, iterator, terminator, deviceId);
}
}
}

View File

@ -2462,7 +2462,7 @@ double NDArray::getTrace() const {
double sum = 0.;
PRAGMA_OMP_PARALLEL_FOR_ARGS(reduction(OMP_SUMT:sum) if(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
PRAGMA_OMP_PARALLEL_FOR_ARGS(reduction(OMP_SUMT:sum) OMP_IF(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
for(int i = 0; i < minDim; ++i)
sum += e<double>(i * offset);

View File

@ -100,7 +100,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, const char
std::vector<Nd4jLong> coords(zRank);
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords))
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords))
for (Nd4jLong i = 0; i < zLen; ++i) {
shape::index2coords(zRank, target->shapeOf(), i, zLen, coords.data());
@ -141,7 +141,7 @@ void NDArray::setIdentity() {
minDim = shape[i];
float v = 1.0f;
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
for(int i = 0; i < minDim; ++i)
templatedSet<float>(buffer(), i*offset, this->dataType(), &v);
}

View File

@ -172,7 +172,9 @@ template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector<int8
template NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<float16>& data, nd4j::LaunchContext * context);
template NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<bfloat16>& data, nd4j::LaunchContext * context);
template NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<Nd4jLong>& data, nd4j::LaunchContext * context);
template NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<uint64_t>& data, nd4j::LaunchContext * context);
template NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<int>& data, nd4j::LaunchContext * context);
template NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<unsigned int>& data, nd4j::LaunchContext * context);
template NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<int16_t>& data, nd4j::LaunchContext * context);
template NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<int8_t>& data, nd4j::LaunchContext * context);
template NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<uint8_t>& data, nd4j::LaunchContext * context);

View File

@ -785,48 +785,9 @@ void NativeOpExecutioner::execTransformAny(nd4j::LaunchContext *lc,
auto xType = ArrayOptions::dataType(hXShapeInfo);
auto zType = ArrayOptions::dataType(hZShapeInfo);
switch (opNum) {
case transform::IsMax: {
bool scalarCheat = false;
if (extraParams == nullptr) {
scalarCheat = true;
}
void* special = lc->getAllocationPointer();
if (scalarCheat) {
auto scalarShape = nd4j::ConstantShapeHelper::getInstance()->bufferForShapeInfo(ShapeDescriptor::scalarDescriptor(nd4j::DataType::INT64)); //ShapeBuilders::createScalarShapeInfo(nd4j::DataType::INT64);
/**
* In case of vector-input for IsMax, it just turns into IndexReduce call + further filler call
*/
execIndexReduceScalar(lc, indexreduce::IndexMax, nullptr, hXShapeInfo, dX, dXShapeInfo, extraParams, nullptr, scalarShape.primaryAsT<Nd4jLong>(), special, scalarShape.specialAsT<Nd4jLong>());
Nd4jLong maxIdx = -119;
nd4j::DebugHelper::checkErrorCode(stream, "IsMax: execIndexReduce(...) failed");
cudaMemcpyAsync(&maxIdx, special, sizeof(Nd4jLong), cudaMemcpyDeviceToHost, *stream);
nd4j::DebugHelper::checkErrorCode(stream, "IsMax: cudaMemcpyAsync(...) failed");
int targetIdx = 0;
if (shape::order(hXShapeInfo) == 'c' || shape::order(hXShapeInfo) == 'f' && maxIdx * shape::stride(hXShapeInfo)[shape::rank(hXShapeInfo) - 1] >= shape::length(hXShapeInfo))
targetIdx = maxIdx;
else
targetIdx = maxIdx * shape::stride(hXShapeInfo)[shape::rank(hXShapeInfo) - 1];
dim3 launchDims(1, 512, 1024);
BUILD_SINGLE_SELECTOR(zType, fillIsMaxGeneric, (launchDims, stream, dZ, shape::length(hZShapeInfo), targetIdx), LIBND4J_TYPES);
nd4j::DebugHelper::checkErrorCode(stream, "Legacy IsMax(...) failed");
//delete[] scalarShape;
}
}
break;
default: {
dim3 launchDims(512, 512, 16384);
dim3 launchDims(512, 512, 2048);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, LIBND4J_TYPES);
}
}
// TODO: remove after the release
auto res = cudaStreamSynchronize(*stream);
@ -884,7 +845,7 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext *lc,
if (!DataTypeUtils::isR(zType))
throw datatype_exception::build("NativeOpExecutioner::execTransformFloat requires Z to have floating point type", zType);
dim3 launchDims(512, 512, 16384);
dim3 launchDims(512, 512, 2048);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, FLOAT_TYPES);
// TODO: remove after the release

View File

@ -653,36 +653,7 @@ void execTransformAny(Nd4jPointer *extraPointers,int opNum,
auto streamSpecial = reinterpret_cast<cudaStream_t&>(extraPointers[4]);
LaunchContext lc(stream, streamSpecial, extraPointers[5], extraPointers[3], reinterpret_cast<int*>(extraPointers[6]));
// FIXME: remove this once all operations are enabled
if (opNum == nd4j::transform::IsMax && extraParams != nullptr) {
auto hostYShapeInfo = reinterpret_cast<Nd4jLong *>(extraPointers[7]);
auto hostTShapeInfo = reinterpret_cast<Nd4jLong *>(extraPointers[19]);
auto tadMaxShapeInfo = reinterpret_cast<Nd4jLong *> (extraPointers[10]);
auto tadMaxOffsets = reinterpret_cast<Nd4jLong *> (extraPointers[11]);
int *dimension = reinterpret_cast<int *> (extraPointers[15]);
int *hDimension = reinterpret_cast<int *> (extraPointers[16]);
int dimensionLength = getDeviceId(extraPointers[18]);
auto special = reinterpret_cast<double *>(extraPointers[17]);
auto cshape = ShapeBuilders::createVectorShapeInfo(nd4j::DataType::INT32, dimensionLength);
// we call for IMax on specified dimension
execIndexReduce(extraPointers, indexreduce::IndexMax, nullptr, hXShapeInfo, dX, dXShapeInfo, extraParams, nullptr, hostTShapeInfo, special, hostYShapeInfo, hDimension, cshape, dimension, nullptr);
DEBUG_KERNEL(stream, opNum);
dim3 launchDims(256, 256, 16384);
auto zType = ArrayOptions::dataType(hZShapeInfo);
// at this point, all IMax indexes are gathered, and we execute filler
BUILD_SINGLE_SELECTOR(zType, fillDimensionalIsMaxGeneric, (launchDims, stream, special, dZ, dZShapeInfo, tadMaxShapeInfo, dimension, dimensionLength, tadMaxOffsets), LIBND4J_TYPES);
nd4j::DebugHelper::checkErrorCode(stream, "Legacy IsMax(...) failed");
delete[] cshape;
} else {
NativeOpExecutioner::execTransformAny(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, nullptr, nullptr);
}
}
////////////////////////////////////////////////////////////////////////
@ -712,7 +683,7 @@ void execTransformFloat(Nd4jPointer *extraPointers,int opNum,
auto tadOffsets = reinterpret_cast<Nd4jLong *>(extraPointers != nullptr ? extraPointers[11] : nullptr);
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
NativeOpExecutioner::execTransformFloat(&lc, opNum, hX, hXShapeInfo, dZ, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets);
NativeOpExecutioner::execTransformFloat(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets);
}

View File

@ -137,8 +137,8 @@ namespace nd4j {
auto newAxis = ShapeUtils::evalDimsToExclude(array->rankOf(), args);
auto result = array->allTensorsAlongDimension(newAxis);
for (int e = 0; e < result->size(); e++) {
auto chunk = result->at(e)->dup(array->ordering());
write(e, chunk);
auto chunk = result->at(e);//->dup(array->ordering());
write(e, chunk->dup(array->ordering()));
}
delete result;
}

View File

@ -922,7 +922,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::EWS1: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint i = 0; i < zLen; ++i) {
extraParams[0] = param0;
@ -944,7 +944,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::EWSNONZERO: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint i = 0; i < zLen; ++i) {
extraParams[0] = param0;
@ -966,7 +966,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::RANK1: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint i = 0; i < zLen; i++) {
extraParams[0] = param0;
@ -990,7 +990,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::RANK2: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint i = 0; i < zLen; i++) {
extraParams[0] = param0;
@ -1016,7 +1016,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::RANK3: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint i = 0; i < zLen; i++) {
extraParams[0] = param0;
@ -1044,7 +1044,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::RANK4: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint i = 0; i < zLen; i++) {
extraParams[0] = param0;
@ -1074,7 +1074,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::RANK5: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint i = 0; i < zLen; i++) {
extraParams[0] = param0;
@ -1111,7 +1111,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
if(shape::haveSameShapeAndStrides(xTadShapeInfo, yTadShapeInfo)) {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint i = 0; i < zLen; ++i) {
extraParams[0] = param0;
@ -1135,7 +1135,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
uint castYTadShapeInfo[MAX_RANK];
const bool canCastYTad = nd4j::DataTypeUtils::castShapeInfo<uint>(yTadShapeInfo, castYTadShapeInfo);
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint i = 0; i < zLen; ++i) {
extraParams[0] = param0;
@ -1199,7 +1199,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::EWS1: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint ix = 0; ix < numXTads; ++ix) {
for (uint iy = 0; iy < numYTads; ++iy) {
@ -1224,7 +1224,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::EWSNONZERO: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint ix = 0; ix < numXTads; ++ix) {
for (uint iy = 0; iy < numYTads; ++iy) {
@ -1249,7 +1249,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::RANK1: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint ix = 0; ix < numXTads; ++ix) {
for (uint iy = 0; iy < numYTads; ++iy) {
@ -1276,7 +1276,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::RANK2: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint ix = 0; ix < numXTads; ++ix) {
for (uint iy = 0; iy < numYTads; ++iy) {
@ -1305,7 +1305,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::RANK3: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint ix = 0; ix < numXTads; ++ix) {
for (uint iy = 0; iy < numYTads; ++iy) {
@ -1336,7 +1336,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::RANK4: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint ix = 0; ix < numXTads; ++ix) {
for (uint iy = 0; iy < numYTads; ++iy) {
@ -1369,7 +1369,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
//*********************************************//
case LoopKind::RANK5: {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint ix = 0; ix < numXTads; ++ix) {
for (uint iy = 0; iy < numYTads; ++iy) {
@ -1409,7 +1409,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
if(shape::haveSameShapeAndStrides(xTadShapeInfo, yTadShapeInfo)) {
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint ix = 0; ix < numXTads; ++ix) {
for (uint iy = 0; iy < numYTads; ++iy) {
@ -1435,7 +1435,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
uint castYTadShapeInfo[MAX_RANK];
const bool canCastYTad = nd4j::DataTypeUtils::castShapeInfo<uint>(yTadShapeInfo, castYTadShapeInfo);
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams))
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams))
for (uint ix = 0; ix < numXTads; ++ix) {
for (uint iy = 0; iy < numYTads; ++iy) {

View File

@ -40,7 +40,7 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c
const bool flagA = (flagC && transA) || (!flagC && !transA);
const bool flagB = (flagC && transB) || (!flagC && !transB);
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
// for(uint row = 0; row < M; ++row) {
// T3* c = flagC ? (C + row) : (C + row * ldc);
@ -74,7 +74,7 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c
// }
// }
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided) collapse(2))
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided) collapse(2))
for(uint row = 0; row < M; ++row) {
for(uint col = 0; col < N; ++col) {
@ -108,7 +108,7 @@ static void usualGemv(const char aOrder, const int M, const int N, const double
const bool flagA = aOrder == 'f';
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(M > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
for(int row = 0; row < M; ++row) {
T3* y = Y + row * incy;
@ -139,7 +139,7 @@ static void usualDot(const Nd4jLong length, const double alpha, const void* vX,
T3 alphaZ(alpha), betaZ(beta);
T3 sum = 0;
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(length > Environment::getInstance()->elementwiseThreshold()) schedule(guided) reduction(OMP_SUMT:sum))
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(length > Environment::getInstance()->elementwiseThreshold()) schedule(guided) reduction(OMP_SUMT:sum))
for(int i = 0; i < length; ++i)
sum = sum + X[i * incx] * Y[i * incy];

View File

@ -25,21 +25,21 @@ namespace nd4j {
////////////////////////////////////////////////////////////////////////
template <typename T>
__global__ void execFillIsMax(void *vdZ, Nd4jLong length, long idx) {
__global__ void execFillIsMax(void *vdZ, Nd4jLong *xShapeInfo, Nd4jLong length, long idx) {
auto dz = reinterpret_cast<T*>(vdZ);
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x)
dz[i] = (i == idx ? (T) 1 : (T) 0);
dz[shape::getIndexOffset(i, xShapeInfo, length)] = (i == idx ? (T) 1 : (T) 0);
}
////////////////////////////////////////////////////////////////////////
template <typename T>
__host__ void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong length, long idx) {
execFillIsMax<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(dx, length, idx);
__host__ void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong *xShapeInfo, Nd4jLong length, long idx) {
execFillIsMax<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(dx, xShapeInfo, length, idx);
nd4j::DebugHelper::checkErrorCode(stream, "fillIsMax(...) failed");
}
BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT fillIsMaxGeneric, (dim3& launchDims, cudaStream_t *stream, void* dz, Nd4jLong length, long idx), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT fillIsMaxGeneric, (dim3& launchDims, cudaStream_t *stream, void* dz, Nd4jLong *zShapeInfo, Nd4jLong length, long idx), LIBND4J_TYPES);
}

View File

@ -99,18 +99,18 @@ namespace functions {
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
for (int i = tid; i < length; i += totalThreads)
for (Nd4jLong i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);
}
else {
if(vx == vz) {
for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) {
for (Nd4jLong i = tid; i < length; i+= totalThreads) {
auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
z[xOffset] = OpType::op(x[xOffset], params);
}
}
else {
for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) {
for (Nd4jLong i = tid; i < length; i+= totalThreads) {
auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
auto zOffset = shape::getIndexOffset(i, zShapeInfo, length);
z[zOffset] = OpType::op(x[xOffset], params);

View File

@ -92,8 +92,7 @@
(21, Copy)
#define TRANSFORM_ANY_OPS \
(0, Assign) , \
(1, IsMax)
(0, Assign)
// these ops return bool
#define TRANSFORM_BOOL_OPS \

View File

@ -36,7 +36,7 @@
namespace nd4j {
template <typename T>
_CUDA_H void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong length, long idx);
_CUDA_H void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong *xShapeInfo, Nd4jLong length, long idx);
template <typename T>
_CUDA_H void fillDimensionalIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dX, void *dZ, Nd4jLong *zShapeInfo, Nd4jLong *tadOnlyShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOffsets);

View File

@ -1328,7 +1328,8 @@
REGISTER_C(NAME) \
nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { \
auto shapeList = SHAPELIST(); \
for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) { \
auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs(); \
for (int e = 0; e < opLimit; e++) { \
auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e))); \
shapeList->push_back(newshape); \
} \
@ -1365,7 +1366,8 @@
REGISTER_C(NAME) \
nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { \
auto shapeList = SHAPELIST(); \
for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) { \
auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs(); \
for (int e = 0; e < opLimit; e++) { \
Nd4jLong* newshape; \
COPY_SHAPE(inputShape->at(0), newshape); \
shapeList->push_back(CONSTANT(newshape)); \
@ -1388,7 +1390,8 @@
REGISTER_C(NAME) \
nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { \
auto shapeList = SHAPELIST(); \
for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) { \
auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs(); \
for (int e = 0; e < opLimit; e++) { \
auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e))); \
shapeList->push_back(newshape); \
} \

View File

@ -0,0 +1,58 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_cyclic_rshift_bits)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/helpers.h>
#include <ops/declarable/helpers/shift.h>
namespace nd4j {
namespace ops {
CONFIGURABLE_OP_IMPL(cyclic_rshift_bits, 1, 1, true, 0, -2) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "cyclic_rshift_bits: actual shift value is missing");
uint32_t shift = 0;
if (block.width() > 1) {
shift = INPUT_VARIABLE(1)->e<uint32_t>(0);
} else if (block.numI() > 0) {
shift = INT_ARG(0);
};
helpers::cyclic_rshift_bits(block.launchContext(), *input, *output, shift);
REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "cyclic_rshift_bits: can't shift beyond size of data type")
return Status::OK();
}
DECLARE_TYPES(cyclic_rshift_bits) {
getOpDescriptor()
->setAllowedInputTypes({ALL_INTS})
->setSameMode(true);
}
}
}
#endif

View File

@ -0,0 +1,58 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_cyclic_shift_bits)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/helpers.h>
#include <ops/declarable/helpers/shift.h>
namespace nd4j {
namespace ops {
CONFIGURABLE_OP_IMPL(cyclic_shift_bits, 1, 1, true, 0, -2) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "cyclic_shift_bits: actual shift value is missing");
uint32_t shift = 0;
if (block.width() > 1) {
shift = INPUT_VARIABLE(1)->e<uint32_t>(0);
} else if (block.numI() > 0) {
shift = INT_ARG(0);
};
helpers::cyclic_shift_bits(block.launchContext(), *input, *output, shift);
REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "cyclic_shift_bits: can't shift beyond size of data type")
return Status::OK();
}
DECLARE_TYPES(cyclic_shift_bits) {
getOpDescriptor()
->setAllowedInputTypes({ALL_INTS})
->setSameMode(true);
}
}
}
#endif

View File

@ -0,0 +1,58 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_rshift_bits)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/helpers.h>
#include <ops/declarable/helpers/shift.h>
namespace nd4j {
namespace ops {
CONFIGURABLE_OP_IMPL(rshift_bits, 1, 1, true, 0, -2) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "rshift_bits: actual shift value is missing");
uint32_t shift = 0;
if (block.width() > 1) {
shift = INPUT_VARIABLE(1)->e<uint32_t>(0);
} else if (block.numI() > 0) {
shift = INT_ARG(0);
};
REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "rshift_bits: can't shift beyond size of data type")
helpers::rshift_bits(block.launchContext(), *input, *output, shift);
return Status::OK();
}
DECLARE_TYPES(rshift_bits) {
getOpDescriptor()
->setAllowedInputTypes({ALL_INTS})
->setSameMode(true);
}
}
}
#endif

View File

@ -0,0 +1,58 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_shift_bits)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/helpers.h>
#include <ops/declarable/helpers/shift.h>
namespace nd4j {
namespace ops {
CONFIGURABLE_OP_IMPL(shift_bits, 1, 1, true, 0, -2) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "shift_bits: actual shift value is missing");
uint32_t shift = 0;
if (block.width() > 1) {
shift = INPUT_VARIABLE(1)->e<uint32_t>(0);
} else if (block.numI() > 0) {
shift = INT_ARG(0);
};
REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "shift_bits: can't shift beyond size of data type")
helpers::shift_bits(block.launchContext(), *input, *output, shift);
return Status::OK();
}
DECLARE_TYPES(shift_bits) {
getOpDescriptor()
->setAllowedInputTypes({ALL_INTS})
->setSameMode(true);
}
}
}
#endif

View File

@ -34,7 +34,7 @@ namespace nd4j {
auto z = OUTPUT_VARIABLE(i);
REQUIRE_TRUE(x->dataType() == z->dataType(), 0, "Toggle bits requires input and output to have same type");
REQUIRE_TRUE(x->isR(),0, "Toggle bits requires input and output to be integer type (int8, int16, int32, int64)");
REQUIRE_TRUE(x->isZ(),0, "Toggle bits requires input and output to be integer type (int8, int16, int32, int64)");
helpers::__toggle_bits(block.launchContext(), *x, *z);
}
@ -44,7 +44,8 @@ namespace nd4j {
DECLARE_TYPES(toggle_bits) {
getOpDescriptor()
->setAllowedInputTypes({ALL_INTS})
->setSameMode(true);
->setAllowedOutputTypes({ALL_INTS})
->setSameMode(false);
}
}
}

View File

@ -28,7 +28,7 @@
namespace nd4j {
namespace ops {
CONFIGURABLE_OP_IMPL(ismax, 1, 1, true, 0, -1) {
CONFIGURABLE_OP_IMPL(ismax, 1, 1, true, 0, -2) {
auto x = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0);

View File

@ -25,7 +25,7 @@
namespace nd4j {
namespace ops {
OP_IMPL(to_double, 1, 1, true) {
CUSTOM_OP_IMPL(to_double, 1, 1, true, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
@ -42,6 +42,12 @@ namespace nd4j {
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes(nd4j::DataType::DOUBLE);
}
DECLARE_SHAPE_FN(to_double) {
auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::DOUBLE, true, block.workspace());
return SHAPELIST(CONSTANT(outShape));
}
}
}

View File

@ -25,7 +25,7 @@
namespace nd4j {
namespace ops {
OP_IMPL(to_float16, 1, 1, true) {
CUSTOM_OP_IMPL(to_float16, 1, 1, true, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
@ -42,6 +42,12 @@ namespace nd4j {
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes(nd4j::DataType::HALF);
}
DECLARE_SHAPE_FN(to_float16) {
auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::HALF, true, block.workspace());
return SHAPELIST(CONSTANT(outShape));
}
}
}

View File

@ -25,7 +25,7 @@
namespace nd4j {
namespace ops {
OP_IMPL(to_float32, 1, 1, true) {
CUSTOM_OP_IMPL(to_float32, 1, 1, true, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
@ -42,6 +42,12 @@ namespace nd4j {
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes(nd4j::DataType::FLOAT32);
}
DECLARE_SHAPE_FN(to_float32) {
auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::FLOAT32, true, block.workspace());
return SHAPELIST(CONSTANT(outShape));
}
}
}

View File

@ -25,7 +25,7 @@
namespace nd4j {
namespace ops {
OP_IMPL(to_int32, 1, 1, true) {
CUSTOM_OP_IMPL(to_int32, 1, 1, true, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
@ -42,6 +42,11 @@ namespace nd4j {
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes(nd4j::DataType::INT32);
}
DECLARE_SHAPE_FN(to_int32) {
auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::INT32, true, block.workspace());
return SHAPELIST(CONSTANT(outShape));
}
}
}

View File

@ -25,7 +25,7 @@
namespace nd4j {
namespace ops {
OP_IMPL(to_int64, 1, 1, true) {
CUSTOM_OP_IMPL(to_int64, 1, 1, true, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
@ -42,6 +42,11 @@ namespace nd4j {
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes(nd4j::DataType::INT64);
}
DECLARE_SHAPE_FN(to_int64) {
auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::INT64, true, block.workspace());
return SHAPELIST(CONSTANT(outShape));
}
}
}

View File

@ -25,7 +25,7 @@
namespace nd4j {
namespace ops {
OP_IMPL(to_uint32, 1, 1, true) {
CUSTOM_OP_IMPL(to_uint32, 1, 1, true, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
@ -40,8 +40,13 @@ namespace nd4j {
DECLARE_TYPES(to_uint32) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes(nd4j::DataType::INT16);
->setAllowedOutputTypes(nd4j::DataType::INT32);
}
DECLARE_SHAPE_FN(to_uint32) {
auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::UINT32, true, block.workspace());
return SHAPELIST(CONSTANT(outShape));
}
}
}

View File

@ -25,7 +25,7 @@
namespace nd4j {
namespace ops {
OP_IMPL(to_uint64, 1, 1, true) {
CUSTOM_OP_IMPL(to_uint64, 1, 1, true, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
@ -42,6 +42,10 @@ namespace nd4j {
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes(nd4j::DataType::INT8);
}
DECLARE_SHAPE_FN(to_uint64) {
auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::UINT64, true, block.workspace());
return SHAPELIST(CONSTANT(outShape));
}
}
}

View File

@ -26,13 +26,19 @@
namespace nd4j {
namespace ops {
LIST_OP_IMPL(unstack_list, 1, 1, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto outputList = INPUT_LIST(0);
auto input = INPUT_VARIABLE(int(outputList != nullptr) );
auto list = new NDArrayList(0, true);
list->unstack(input, 0);
if (outputList == nullptr) {
outputList = new NDArrayList(0, true);
//block.trackList(outputList);
setupResultList(outputList, block);
}
outputList->unstack(input, INT_ARG(0));
//OVERWRITE_RESULT(list);
setupResultList(list, block);
//
return Status::OK();
}
}

View File

@ -23,18 +23,41 @@
#include <ops/declarable/headers/parity_ops.h>
#include <ops/declarable/helpers/roll.h>
#include <ops/declarable/helpers/axis.h>
namespace nd4j {
namespace ops {
CONFIGURABLE_OP_IMPL(roll, 1, 1, true, 0, 1) {
CONFIGURABLE_OP_IMPL(roll, 1, 1, true, 0, 0) {
auto output = OUTPUT_VARIABLE(0);
auto input = INPUT_VARIABLE(0);
bool shiftIsLinear = true;
//std::vector<int> axes(input->rankOf());
int shift = INT_ARG(0);
int inputLen = input->lengthOf();
if (block.isInplace()) output = input;
bool shiftIsLinear = block.width() == 1;
std::vector<int> axes;
std::vector<int> shifts;
if (block.width() > 1) {
REQUIRE_TRUE(block.width() == 3, 0, "roll: 3 arguments required for roll - input, shifts and axes. But %i given.", block.width());
auto axesI = INPUT_VARIABLE(2);
auto shiftsI = INPUT_VARIABLE(1);
REQUIRE_TRUE(axesI->rankOf() == shiftsI->rankOf(), 0, "roll: shifts and axes should be the same rank, but %i and %i given.", (int)shiftsI->rankOf(), (int)axesI->rankOf());
REQUIRE_TRUE(axesI->lengthOf() == shiftsI->lengthOf(), 0, "roll: shifts and axes should be the same length, but %i and %i given.", (int)shiftsI->lengthOf(), (int)axesI->lengthOf());
helpers::adjustAxis(axesI->lengthOf(), axesI, axes );
shifts.resize(shiftsI->lengthOf());
for (Nd4jLong i = 0; i < shiftsI->lengthOf(); i++) {
auto shift = shiftsI->e<int>(i);
if (shift < 0) {
shift -= input->sizeAt(i) * (shift / inputLen - 1);
}
else {
shift %= input->sizeAt(i);
}
shifts[i] = shift;
}
}
else {
int shift = INT_ARG(0);
if (shift < 0) {
// convert shift to positive value between 1 and inputLen - 1
shift -= inputLen * (shift / inputLen - 1);
@ -42,21 +65,32 @@ namespace ops {
else
// cut shift to value between 1 and inputLen - 1
shift %= inputLen;
axes.resize(block.getIArguments()->size() - 1);
if (axes.size())
shifts.resize(axes.size());//emplace_back(shift);
else
shifts.push_back(shift);
for (auto& s: shifts)
s = shift;
for (unsigned e = 0; e < axes.size(); ++e) {
int axis = INT_ARG(e + 1);
REQUIRE_TRUE(axis < input->rankOf() && axis >= -input->rankOf(), 0, "roll: axe value should be between -%i and %i, but %i was given.",
input->rankOf(), input->rankOf() - 1, axis);
axes[e] = (axis < 0? (input->rankOf() + axis) : axis);
}
}
if (block.isInplace()) output = input;
shiftIsLinear = axes.size() == 0;
if (block.numI() > 1)
shiftIsLinear = false;
if (shiftIsLinear) {
helpers::rollFunctorLinear(block.launchContext(), input, output, shift, block.isInplace());
helpers::rollFunctorLinear(block.launchContext(), input, output, shifts[0], block.isInplace());
}
else {
std::vector<int> axes(block.numI() - 1);
for (unsigned e = 0; e < axes.size(); ++e) {
int axe = INT_ARG(e + 1);
REQUIRE_TRUE(axe < input->rankOf() && axe >= -input->rankOf(), 0, "roll: axe value should be between -%i and %i, but %i was given.",
input->rankOf(), input->rankOf() - 1, axe);
axes[e] = (axe < 0? (input->rankOf() + axe) : axe);
}
helpers::rollFunctorFull(block.launchContext(), input, output, shift, axes, block.isInplace());
helpers::rollFunctorFull(block.launchContext(), input, output, shifts, axes, block.isInplace());
}
return Status::OK();
@ -64,7 +98,9 @@ namespace ops {
DECLARE_TYPES(roll) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedInputTypes(0,nd4j::DataType::ANY)
->setAllowedInputTypes(1,nd4j::DataType::INT32) // TODO: all ints in future
->setAllowedInputTypes(2,nd4j::DataType::INT32)
->setAllowedOutputTypes(nd4j::DataType::ANY)
->setSameMode(true);
}

View File

@ -26,11 +26,11 @@
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(get_seed, -2, 1, false, 0, 0) {
REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be defined in Graph");
auto rng = block.getRNG();
// REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be defined in Graph");
auto rng = block.getRng();
auto z = OUTPUT_VARIABLE(0);
z->p(Nd4jLong(0), rng->getSeed());
z->p(Nd4jLong(0), rng.rootState());
return Status::OK();
}

View File

@ -27,8 +27,9 @@
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(set_seed, -2, 1, false, 0, -2) {
REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be defined in Graph");
auto rng = block.getRNG();
// REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be defined in Graph");
auto rng = block.getRng(); //.getRNG();
Nd4jLong seed = 0;
if (block.getIArguments()->size() > 0) {
seed = INT_ARG(0);
@ -41,8 +42,8 @@ namespace nd4j {
}
// FIXME: this approach isn't really good for cuda, since it'll assume that CUDA might get nullptr instead of stream
refreshBuffer(nullptr, seed, (Nd4jPointer) rng);
//refreshBuffer(nullptr, seed, (Nd4jPointer) rng);
rng.setSeed((int)seed);
return Status::OK();
}

View File

@ -25,7 +25,7 @@
namespace nd4j {
namespace ops {
OP_IMPL(Log1p, 2, 1, true) {
OP_IMPL(Log1p, 1, 1, true) {
auto x = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0);

View File

@ -27,7 +27,7 @@
namespace nd4j {
namespace ops {
OP_IMPL(mergemaxindex, -1, 1, false) {
CUSTOM_OP_IMPL(mergemaxindex, -1, 1, false, 0, 0) {
REQUIRE_OK(this->validateInputDimensionsMatch(block));
auto output = OUTPUT_VARIABLE(0);
@ -49,6 +49,15 @@ DECLARE_SYN(MergeMaxIndex, mergemaxindex);
->setAllowedInputTypes({ALL_INTS, ALL_FLOATS});
}
}
DECLARE_SHAPE_FN(mergemaxindex) {
auto in = inputShape->at(0);
auto dtype = DataType::INT32;
if (block.getIArguments()->size()> 0)
dtype = (DataType)INT_ARG(0);
auto resShape = ShapeBuilders::copyShapeInfoAndType(in, dtype, block.workspace());
return SHAPELIST(CONSTANT(resShape));
}
}
#endif

View File

@ -28,13 +28,58 @@ namespace nd4j {
/**
* This operation toggles individual bits of each element in array
*
* PLEASE NOTE: This operation is possible only on integer datatypes
* PLEASE NOTE: This operation is possible only on integer data types
*
* @tparam T
*/
#if NOT_EXCLUDED(OP_toggle_bits)
DECLARE_OP(toggle_bits, -1, -1, true);
#endif
/**
* This operation shift individual bits of each element in array to the left: <<
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* @tparam T
*/
#if NOT_EXCLUDED(OP_shift_bits)
DECLARE_CONFIGURABLE_OP(shift_bits, 1, 1, true, 0, -2);
#endif
/**
* This operation shift individual bits of each element in array to the right: >>
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* @tparam T
*/
#if NOT_EXCLUDED(OP_rshift_bits)
DECLARE_CONFIGURABLE_OP(rshift_bits, 1, 1, true, 0, -2);
#endif
/**
* This operation shift individual bits of each element in array, shifting to the left
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* @tparam T
*/
#if NOT_EXCLUDED(OP_cyclic_shift_bits)
DECLARE_CONFIGURABLE_OP(cyclic_shift_bits, 1, 1, true, 0, -2);
#endif
/**
* This operation shift individual bits of each element in array, shifting to the right
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* @tparam T
*/
#if NOT_EXCLUDED(OP_cyclic_rshift_bits)
DECLARE_CONFIGURABLE_OP(cyclic_rshift_bits, 1, 1, true, 0, -2);
#endif
}
}

View File

@ -260,7 +260,7 @@ namespace nd4j {
* 0: axis
*/
#if NOT_EXCLUDED(OP_ismax)
DECLARE_CONFIGURABLE_OP(ismax, 1, 1, true, 0, -1);
DECLARE_CONFIGURABLE_OP(ismax, 1, 1, true, 0, -2);
#endif
/**

View File

@ -30,7 +30,7 @@ namespace nd4j {
* PLEASE NOTE: This op is disabled atm, and reserved for future releases.
*/
#if NOT_EXCLUDED(OP_to_double)
DECLARE_OP(to_double, 1, 1, true);
DECLARE_CUSTOM_OP(to_double, 1, 1, true, 0, 0);
#endif
/**
@ -39,7 +39,7 @@ namespace nd4j {
* PLEASE NOTE: This op is disabled atm, and reserved for future releases.
*/
#if NOT_EXCLUDED(OP_to_float16)
DECLARE_OP(to_float16, 1, 1, true);
DECLARE_CUSTOM_OP(to_float16, 1, 1, true, 0, 0);
#endif
/**
@ -48,7 +48,7 @@ namespace nd4j {
* PLEASE NOTE: This op is disabled atm, and reserved for future releases.
*/
#if NOT_EXCLUDED(OP_to_float32)
DECLARE_OP(to_float32, 1, 1, true);
DECLARE_CUSTOM_OP(to_float32, 1, 1, true, 0, 0);
#endif
/**
@ -57,7 +57,7 @@ namespace nd4j {
* PLEASE NOTE: This op is disabled atm, and reserved for future releases.
*/
#if NOT_EXCLUDED(OP_to_int32)
DECLARE_OP(to_int32, 1, 1, true);
DECLARE_CUSTOM_OP(to_int32, 1, 1, true, 0, 0);
#endif
/**
@ -66,7 +66,7 @@ namespace nd4j {
* PLEASE NOTE: This op is disabled atm, and reserved for future releases.
*/
#if NOT_EXCLUDED(OP_to_int64)
DECLARE_OP(to_int64, 1, 1, true);
DECLARE_CUSTOM_OP(to_int64, 1, 1, true, 0, 0);
#endif
/**
@ -75,7 +75,7 @@ namespace nd4j {
* PLEASE NOTE: This op is disabled atm, and reserved for future releases.
*/
#if NOT_EXCLUDED(OP_to_uint32)
DECLARE_OP(to_uint32, 1, 1, true);
DECLARE_CUSTOM_OP(to_uint32, 1, 1, true, 0, 0);
#endif
/**
@ -84,7 +84,7 @@ namespace nd4j {
* PLEASE NOTE: This op is disabled atm, and reserved for future releases.
*/
#if NOT_EXCLUDED(OP_to_uint64)
DECLARE_OP(to_uint64, 1, 1, true);
DECLARE_CUSTOM_OP(to_uint64, 1, 1, true, 0, 0);
#endif
/**

View File

@ -65,9 +65,15 @@ namespace nd4j {
#if NOT_EXCLUDED(OP_mergemax)
DECLARE_OP(mergemax, -1, 1, false);
#endif
/*
* Complete tensor with max indices merged from all input tensors list
*
* INPUT: tensors with the same shape
* OUTPUT: integer tensor with the same shape
* INT_ARG: result type (one of int), INT32 by default
*/
#if NOT_EXCLUDED(OP_mergemaxindex)
DECLARE_OP(mergemaxindex, -1, 1, false);
DECLARE_CUSTOM_OP(mergemaxindex, -1, 1, false, 0, 0);
#endif
#if NOT_EXCLUDED(OP_mergeadd)

View File

@ -43,7 +43,6 @@ namespace helpers {
axisVector[e] = a + rank;
}
}
}
}
}

View File

@ -85,18 +85,19 @@ namespace helpers {
}
}
void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, std::vector<int> const& axes, bool inplace){
void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, std::vector<int> const& shifts, std::vector<int> const& axes, bool inplace){
if (!inplace)
output->assign(input);
auto source = output; //input;
for (int axe: axes) {
for (auto i = 0; i < axes.size(); i++) {
int axe = axes[i];
if (axe == source->rankOf() - 1) {// last dimension
std::unique_ptr<ResultSet> listOfTensors(source->allTensorsAlongDimension({axe}));
std::unique_ptr<ResultSet> listOfOutTensors(output->allTensorsAlongDimension({axe}));
int fullLen = listOfTensors->size();
int theShift = shift;
int theShift = shifts[i];
if (theShift > 0) {
theShift %= fullLen;
}
@ -118,7 +119,7 @@ namespace helpers {
int fullLen = listOfTensors->size();
int sizeAt = input->sizeAt(axe);
int theShift = shift;
int theShift = shifts[i];
if (theShift > 0) {
theShift %= sizeAt;

View File

@ -35,8 +35,8 @@ void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& ind
if(outRank == 1) {
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided))
for(Nd4jLong i = 0; i < indLen; ++i) {
Nd4jLong idx = indices.e<Nd4jLong>(i);
@ -54,8 +54,8 @@ PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
std::vector<int> dimsToExcludeUpd(sizeOfDims);
std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0);
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // causes known openMP asan bug !
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // causes known openMP asan bug !
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided))
for(Nd4jLong i = 0; i < indLen; ++i) {
NDArray outSubArr = output(indices.e<Nd4jLong>(i), std::vector<int>({0}));
@ -76,8 +76,8 @@ void scatterND(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& i
if(outRank == 1) {
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided))
for(Nd4jLong i = 0; i < indLen; ++i) {
Nd4jLong idx = indices.e<Nd4jLong>(i);
@ -93,8 +93,8 @@ PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0);
std::vector<Nd4jLong> idxRangeOut(2*outRank, 0);
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen/indLastDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided) firstprivate(idxRangeOut))
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided) firstprivate(idxRangeOut))
// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen/indLastDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided) firstprivate(idxRangeOut))
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided) firstprivate(idxRangeOut))
for(Nd4jLong i = 0; i < indLen/indLastDim; ++i) {
NDArray indSubArr = indices(i, dimsToExcludeInd);

View File

@ -479,7 +479,7 @@ namespace helpers {
for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) {
auto outputT = listOfOutTensors->at(fi->first);
outputT->assign(listOfTensors->at(fi->second.at(0)));
auto loopSize = fi->second.size();
Nd4jLong loopSize = fi->second.size();
PRAGMA_OMP_PARALLEL_FOR
for (Nd4jLong idx = 1; idx < loopSize; ++idx) {
auto current = listOfTensors->at(fi->second.at(idx));

View File

@ -0,0 +1,81 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/helpers/shift.h>
namespace nd4j {
namespace ops {
namespace helpers {
template <typename T>
void rshift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) {
auto lambda = LAMBDA_T(x, shift) {
return x >> shift;
};
input.applyLambda<T>(lambda, &output);
}
void rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) {
BUILD_SINGLE_SELECTOR(x.dataType(), rshift_bits_, (launchContext, x, z, shift), INTEGER_TYPES);
}
template <typename T>
void shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) {
auto lambda = LAMBDA_T(x, shift) {
return x << shift;
};
input.applyLambda<T>(lambda, &output);
}
void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) {
BUILD_SINGLE_SELECTOR(x.dataType(), shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES);
}
template <typename T>
void cyclic_rshift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) {
auto step = (sizeof(T) * 8) - shift;
auto lambda = LAMBDA_T(x, shift, step) {
return x >> shift | x << step;
};
input.applyLambda<T>(lambda, &output);
}
void cyclic_rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) {
BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_rshift_bits_, (launchContext, x, z, shift), INTEGER_TYPES);
}
template <typename T>
void cyclic_shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) {
auto step = (sizeof(T) * 8) - shift;
auto lambda = LAMBDA_T(x, shift, step) {
return x << shift | x >> step;
};
input.applyLambda<T>(lambda, &output);
}
void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) {
BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES);
}
}
}
}

View File

@ -562,7 +562,7 @@ static void gatherND_(NDArray& input, NDArray& indices, NDArray& output) {
std::vector<Nd4jLong> coords(maxRank);
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords))
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords))
for (Nd4jLong i = 0; i < zLen; ++i) {
Nd4jLong *zCoordStart, *xCoordStart;

View File

@ -27,6 +27,8 @@ namespace helpers {
void adjustAxis(Nd4jLong rank, NDArray* axisVector, std::vector<int>& output) {
output.resize(axisVector->lengthOf());
axisVector->tickReadDevice();
axisVector->syncToHost();
for (int e = 0; e < axisVector->lengthOf(); e++) {
auto ca = axisVector->e<int>(e);
if (ca < 0)

View File

@ -30,103 +30,101 @@
#include <ConstantTadHelper.h>
namespace nd4j {
namespace ops {
namespace helpers {
namespace ops {
namespace helpers {
///////////////////////////////////////////////////////////////////
template<typename T>
__global__ static void concatCuda(const int numOfArrs, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo) {
template<typename T>
__global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis) {
__shared__ int arrIdx, blocksPerArr;
T* z = reinterpret_cast<T*>(vz);
__shared__ Nd4jLong zLen, totalThreads, *sharedMem;
__shared__ int rank;
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
blocksPerArr = (gridDim.x + numOfArrs - 1) / numOfArrs; // ceil
arrIdx = blockIdx.x / blocksPerArr;
zLen = shape::length(zShapeInfo);
rank = shape::rank(zShapeInfo);
totalThreads = gridDim.x * blockDim.x;
}
__syncthreads();
for(int j = arrIdx; j < numOfArrs; j += gridDim.x) {
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const auto* x = reinterpret_cast<T*>(reinterpret_cast<void**>(pVx)[j]);
auto* z = reinterpret_cast<T*>(reinterpret_cast<void**>(pVz)[j]);
const auto* xShapeInfo = reinterpret_cast<Nd4jLong**>(pxShapeInfo)[j];
const auto* zShapeInfo = reinterpret_cast<Nd4jLong**>(pzShapeInfo)[j];
if(tid >= zLen)
return;
const auto arrLen = shape::length(xShapeInfo);
auto coords = sharedMem + threadIdx.x * rank;
const auto arrLenPerBlock = (arrLen + blocksPerArr - 1) / blocksPerArr; // ceil
shape::index2coords(rank, zShapeInfo + 1, tid, zLen, coords);
const auto start = (blockIdx.x % blocksPerArr) * arrLenPerBlock;
const auto end = (start + arrLenPerBlock) > arrLen ? arrLen : (start + arrLenPerBlock);
const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank);
for (Nd4jLong i = start + threadIdx.x; i < end; i += blockDim.x)
z[shape::getIndexOffset(i, zShapeInfo, arrLen)] = x[shape::getIndexOffset(i, xShapeInfo, arrLen)];
}
int inArrIdx = 0;
Nd4jLong *xShapeInfo = reinterpret_cast<Nd4jLong**>(pxShapeInfo)[inArrIdx];
while(coords[axis] >= xShapeInfo[axis + 1]) {
coords[axis] -= xShapeInfo[axis + 1];
xShapeInfo = reinterpret_cast<Nd4jLong**>(pxShapeInfo)[++inArrIdx];
}
const auto* x = reinterpret_cast<T*>(reinterpret_cast<void**>(pVx)[inArrIdx]);
const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank);
z[zOffset] = x[xOffset];
}
///////////////////////////////////////////////////////////////////
template<typename T>
__host__ static void concatCudaLauncher(const int numOfArrs, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo) {
template<typename T>
__host__ static void concatCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis) {
concatCuda<T><<<512, 512, 512, *stream>>>(numOfArrs, pVx, pxShapeInfo, pVz, pzShapeInfo);
}
BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int numOfArrs, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo), LIBND4J_TYPES);
concatCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(pVx, pxShapeInfo, vz, zShapeInfo, axis);
}
BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis), LIBND4J_TYPES);
//////////////////////////////////////////////////////////////////////////
void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
//////////////////////////////////////////////////////////////////////////
void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * output.rankOf() + 128;
const int numOfArrs = inArrs.size();
for(int i = 0; i < numOfArrs; ++i)
if(!inArrs[i]->isActualOnDeviceSide()) inArrs[i]->syncToDevice();
inArrs[i]->syncToDevice();
const int rank = inArrs[0]->rankOf();
const int rank2 = 2*rank;
std::vector<std::vector<Nd4jLong>> indices(numOfArrs, std::vector<Nd4jLong>(rank2,0));
// take into account indices for first array
indices[0][2 * axis + 1] = inArrs[0]->sizeAt(axis);
// loop through the rest of input arrays
for(int i = 1; i < numOfArrs; ++i) {
indices[i][2 * axis] = indices[i-1][2 * axis + 1]; // index start from
indices[i][2 * axis + 1] = indices[i-1][2 * axis + 1] + inArrs[i]->sizeAt(axis); // index end with (excluding)
}
std::vector<NDArray*> outSubArrs(numOfArrs);
for(int i = 0; i < numOfArrs; ++i)
outSubArrs[i] = new NDArray(output(indices[i], true));
output.syncToDevice();
// prepare arrays of pointers on buffers and shapes
std::vector<void*> hOutBuffers(numOfArrs), hInBuffers(numOfArrs);
std::vector<Nd4jLong*> hOutShapeInfo(numOfArrs), hInShapeInfo(numOfArrs);
std::vector<void*> hInBuffers(numOfArrs);
std::vector<Nd4jLong*> hInShapeInfo(numOfArrs);
for(int i = 0; i < numOfArrs; ++i) {
hOutBuffers[i] = outSubArrs[i]->getSpecialBuffer();
hInBuffers[i] = inArrs[i]->getSpecialBuffer();
hOutShapeInfo[i] = outSubArrs[i]->getSpecialShapeInfo();
hInShapeInfo[i] = inArrs[i]->getSpecialShapeInfo();
}
// allocate and copy all buffers and shapes arrays to global memory
PointersManager manager(context, "helpers::concat");
void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*));
void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*));
void* dInShapeInfo = manager.replicatePointer(hInShapeInfo.data(), hInShapeInfo.size() * sizeof(Nd4jLong*));
void* dOutShapeInfo = manager.replicatePointer(hOutShapeInfo.data(), hOutShapeInfo.size() * sizeof(Nd4jLong*));
BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), concatCudaLauncher, (numOfArrs, context->getCudaStream(), dInBuffers, dInShapeInfo, dOutBuffers, dOutShapeInfo), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), concatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), dInBuffers, dInShapeInfo, output.specialBuffer(), output.specialShapeInfo(), axis), LIBND4J_TYPES);
manager.synchronize();
for(int i = 0; i < numOfArrs; ++i)
delete outSubArrs[i];
for(int i = 0; i < numOfArrs; ++i)
inArrs[i]->tickReadHost();
inArrs[i]->tickReadDevice();
output.tickWriteDevice();
}
}
}
}
}
}
}

View File

@ -25,7 +25,7 @@ namespace nd4j {
namespace ops {
namespace helpers {
template <typename X, typename Z>
void _CUDA_G histogramKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, void *allocationPointer, void *reductionPointer, Nd4jLong numBins, double min_val, double max_val) {
void _CUDA_G histogramKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, void *allocationPointer, void *reductionPointer, Nd4jLong numBins, X* min_val, X* max_val) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
auto dx = reinterpret_cast<X*>(xBuffer);
auto result = reinterpret_cast<Z*>(zBuffer);
@ -42,19 +42,19 @@ namespace nd4j {
}
__syncthreads();
Z binSize = (max_val - min_val) / (numBins);
X binSize = X((*max_val - *min_val) / numBins);
for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
bins[e] = (Z) 0.0f;
bins[e] = (Z) 0;
}
__syncthreads();
for (int e = tid; e < length; e+= blockDim.x * gridDim.x) {
int idx = (int) ((dx[e] - min_val) / binSize);
if (idx < 0) idx = 0;
else if (idx >= numBins) idx = numBins - 1;
nd4j::math::atomics::nd4j_atomicAdd(&bins[idx], (Z) 1.0f);
for (int e = tid; e < length; e += blockDim.x * gridDim.x) {
int idx = int((dx[e] - *min_val) / binSize);
idx = math::nd4j_max(idx, 0); //atomicMax(&idx, 0);//atomicMax(&idx, 0);
idx = math::nd4j_min(idx, int(numBins - 1)); //atomicMin(&idx, int(numBins - 1));
nd4j::math::atomics::nd4j_atomicAdd(&bins[idx], (Z)1);
// bins[idx]++;
}
__syncthreads();
@ -82,7 +82,7 @@ namespace nd4j {
// nullify shared memory for future accumulation
for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
bins[e] = (Z) 0.0f;
bins[e] = (Z) 0;
}
// accumulate reduced bins
@ -90,7 +90,7 @@ namespace nd4j {
Z *ptrBuf = ((Z *)allocationPointer) + (r * numBins);
for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
bins[e] += ptrBuf[e];
math::atomics::nd4j_atomicAdd(&bins[e], ptrBuf[e]);
}
}
__syncthreads();
@ -109,24 +109,26 @@ namespace nd4j {
}
template <typename X, typename Z>
static void histogram_(nd4j::LaunchContext *context, void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong numBins, double min_val, double max_val) {
static void histogram_(nd4j::LaunchContext *context, void *xBuffer, Nd4jLong *xShapeInfo, Nd4jLong *dxShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong numBins, void* min_val, void* max_val) {
int numThreads = 256;
int numBlocks = nd4j::math::nd4j_max<int>(256, nd4j::math::nd4j_min<int>(1, shape::length(xShapeInfo) / numThreads));
int workspaceSize = numBlocks * numBins;
auto tmp = NDArrayFactory::create<Z>('c',{workspaceSize});
auto tmp = NDArrayFactory::create<Z>('c', {workspaceSize});
histogramKernel<X, Z><<<numBlocks, numThreads, 32768, *context->getCudaStream()>>>(xBuffer, xShapeInfo, zBuffer, zShapeInfo, tmp.getSpecialBuffer(), context->getReductionPointer(), numBins, min_val, max_val);
histogramKernel<X, Z><<<numBlocks, numThreads, 32768, *context->getCudaStream()>>>(xBuffer, dxShapeInfo, zBuffer, zShapeInfo, tmp.getSpecialBuffer(), context->getReductionPointer(), numBins, reinterpret_cast<X*>(min_val), reinterpret_cast<X*>(max_val));
cudaStreamSynchronize(*context->getCudaStream());
}
void histogramHelper(nd4j::LaunchContext *context, NDArray &input, NDArray &output) {
Nd4jLong numBins = output.lengthOf();
double min_val = input.reduceNumber(reduce::SameOps::Min).e<double>(0);
double max_val = input.reduceNumber(reduce::SameOps::Max).e<double>(0);
BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (context, input.specialBuffer(), input.specialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({&output}, {&input});
auto min_val = input.reduceNumber(reduce::SameOps::Min);
auto max_val = input.reduceNumber(reduce::SameOps::Max);
// min_val.printIndexedBuffer("MIN");
// max_val.printIndexedBuffer("MAX");
BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (context, input.specialBuffer(), input.shapeInfo(), input.specialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), numBins, min_val.specialBuffer(), max_val.specialBuffer()), LIBND4J_TYPES, INTEGER_TYPES);
NDArray::registerSpecialUse({&output}, {&input});
}
}

View File

@ -68,21 +68,21 @@ namespace helpers {
static __global__ void shouldSelectKernel(T* boxesBuf, Nd4jLong* boxesShape, I* indexBuf, I* selectedIndicesData, double threshold, int numSelected, int i, bool* shouldSelect) {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
auto step = gridDim.x * blockDim.x;
__shared__ bool shouldSelectShared;
__shared__ unsigned int shouldSelectShared;
if (threadIdx.x == 0) {
shouldSelectShared = shouldSelect[0];
shouldSelectShared = (unsigned int)shouldSelect[0];
}
__syncthreads();
for (int j = numSelected - 1 - tid; j >= 0; j -= step) {
if (shouldSelectShared) {
if (needToSuppressWithThreshold(boxesBuf, boxesShape, indexBuf[i],
indexBuf[selectedIndicesData[j]], T(threshold)))
shouldSelectShared = false;
atomicCAS(&shouldSelectShared, 1, 0);
}
}
__syncthreads();
if (threadIdx.x == 0) {
*shouldSelect = shouldSelectShared;
*shouldSelect = shouldSelectShared > 0;
}
}

View File

@ -34,11 +34,6 @@ namespace helpers {
template <typename T>
static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>& dimensions) {
void* extraParams = nullptr;
bool scalarCheat = false;
if (extraParams == nullptr) {
scalarCheat = true;
}
auto stream = context->getCudaStream();
auto xRank = input->rankOf();
@ -49,29 +44,16 @@ static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray*
Nd4jLong* special = nullptr;
PointersManager manager(context, "IsMaxHelper");
if (dimensions.size() == 0) {
// auto scalarShape = ShapeBuilders::createScalarShapeInfo(nd4j::DataType::INT64);
/**
* In case of vector-input for IsMax, it just turns into IndexReduce call + further filler call
* In case of vector-input for IsMax, it just turns into IndexReduce call + subsequent filler call
*/
auto indexMax = input->applyIndexReduce(indexreduce::IndexMax, dimensions);
//NativeOpExecutioner::execIndexReduceScalar(context, indexreduce::IndexMax, nullptr, input->getShapeInfo(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), extraParams, nullptr, scalarShape, special, nullptr);
//Nd4jLong maxIdx = -119;
//checkCudaErrors(cudaStreamSynchronize(*stream));
//cudaMemcpyAsync(&maxIdx, special, sizeof(Nd4jLong), cudaMemcpyDeviceToHost, *stream);
//checkCudaErrors(cudaStreamSynchronize(*stream));
int targetIdx = 0;
auto targetIdx = indexMax->e<Nd4jLong>(0);
if (input->ordering() == 'c' || input->ordering() == 'f' && indexMax->e<Nd4jLong>(0) * shape::stride(input->getShapeInfo())[input->rankOf() - 1] >= input->lengthOf())
targetIdx = indexMax->e<Nd4jLong>(0);
else
targetIdx = indexMax->e<Nd4jLong>(0) * shape::stride(input->getShapeInfo())[input->rankOf() - 1];
dim3 launchDims(128, 512, 1024);
BUILD_SINGLE_SELECTOR(zType, fillIsMaxGeneric, (launchDims, stream, output->specialBuffer(), output->specialShapeInfo(), output->lengthOf(), targetIdx), LIBND4J_TYPES);
manager.synchronize();
dim3 launchDims(1, 512, 1024);
BUILD_SINGLE_SELECTOR(zType, fillIsMaxGeneric, (launchDims, stream, output->specialBuffer(), output->lengthOf(), targetIdx), LIBND4J_TYPES);
nd4j::DebugHelper::checkErrorCode(stream, "Legacy IsMax(...) failed");
//delete[] scalarShape;
delete indexMax;
} else {
Nd4jLong* hostYShapeInfo = nullptr;
@ -82,13 +64,7 @@ static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray*
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), copy.data(), copy.size());
auto indexMaxArr = input->applyIndexReduce(indexreduce::IndexMax, dimensions);
//indexMaxArr->printIndexedBuffer("Index max!!!");
// we call for IMax on specified dimension
//NativeOpExecutioner::execIndexReduce(context, indexreduce::IndexMax, nullptr, input->getShapeInfo(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), extraParams, nullptr, hostTShapeInfo, special, hostYShapeInfo, const_cast<int*>(dimensions.data()), (int)dimensions.size(), nullptr, nullptr);
//DEBUG_KERNEL(stream, opNum);
dim3 launchDims(256, 256, 16384);
dimension = (int *) manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int));
@ -103,7 +79,11 @@ static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray*
void ismax(nd4j::LaunchContext * context, const NDArray *input, NDArray *output, const std::vector<int>& dimensions) {
NDArray::prepareSpecialUse({output}, {input});
BUILD_SINGLE_SELECTOR(input->dataType(), ismax_, (context, input, output, dimensions), LIBND4J_TYPES);
NDArray::registerSpecialUse({output}, {input});
}
BUILD_SINGLE_TEMPLATE(template void ismax_, (nd4j::LaunchContext * context, const NDArray *input, NDArray *output, const std::vector<int>& dimensions), LIBND4J_TYPES);

View File

@ -48,8 +48,10 @@ namespace nd4j {
auto x = reinterpret_cast<T*>(inArrs[i]);
auto xShape = reinterpret_cast<Nd4jLong *>(inShapes[i]);
auto val = x[shape::getIndexOffset(e, xShape, length)];;
if (mVal < val)
mIdx = static_cast<Z>(e);
if (mVal < val) {
mIdx = static_cast<Z>(i);
mVal = val;
}
}
__syncthreads();

View File

@ -228,22 +228,23 @@ namespace helpers {
}
template <typename T>
static void rollFunctorFull_(NDArray* input, NDArray* output, int shift, std::vector<int> const& axis, bool inplace){
static void rollFunctorFull_(NDArray* input, NDArray* output, std::vector<int> const& shifts, std::vector<int> const& axes, bool inplace){
if (!inplace)
output->assign(input);
for (int axe: axis) {
for (size_t i = 0; i < axes.size(); i++) {
int axe = axes[i];
if (axe == input->rankOf() - 1) { // last dimension
std::unique_ptr<ResultSet> listOfTensors(output->allTensorsAlongDimension({axe}));
std::unique_ptr<ResultSet> listOfOutTensors(output->allTensorsAlongDimension({axe}));
int fullLen = listOfTensors->size();
int theShift = shift;
if (theShift > 0) {
theShift %= fullLen;
}
else {
theShift -= fullLen * (theShift / fullLen - 1);
}
int theShift = shifts[i];
// if (theShift > 0) {
// theShift %= fullLen;
// }
// else {
// theShift -= fullLen * (theShift / fullLen - 1);
// }
for (int k = 0; k < fullLen; k++) {
rollFunctorLinear(output->getContext(), listOfTensors->at(k), listOfOutTensors->at(k), theShift, true);
}
@ -258,12 +259,12 @@ namespace helpers {
int sizeAt = input->sizeAt(axe);
auto tadLength = shape::length(packZ.primaryShapeInfo());
int theShift = shift;
int theShift = shifts[i];
if (theShift > 0)
theShift %= sizeAt;
else
theShift -= sizeAt * (theShift / sizeAt - 1);
// if (theShift > 0)
// theShift %= sizeAt;
// else
// theShift -= sizeAt * (theShift / sizeAt - 1);
if (theShift) {
for (int dim = 0; dim < numTads / sizeAt; ++dim) {
@ -307,10 +308,10 @@ namespace helpers {
}
}
void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, std::vector<int> const& axis, bool inplace){
void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, std::vector<int> const& shifts, std::vector<int> const& axes, bool inplace){
input->syncToDevice();
BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorFull_, (input, output, shift, axis, inplace), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorFull_, (input, output, shifts, axes, inplace), LIBND4J_TYPES);
output->tickWriteDevice();
}
@ -324,7 +325,7 @@ namespace helpers {
}
BUILD_SINGLE_TEMPLATE(template void rollFunctorLinear_, (NDArray* input, NDArray* output, int shift, bool inplace), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void rollFunctorFull_, (NDArray* input, NDArray* output, int shift, std::vector<int> const& axis, bool inplace), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void rollFunctorFull_, (NDArray* input, NDArray* output, std::vector<int> const& shifts, std::vector<int> const& axes, bool inplace), LIBND4J_TYPES);
}
}
}

View File

@ -123,14 +123,236 @@ namespace nd4j {
nSamplingKernel<T><<<1,1,128, *stream>>>(vsyn0, vsyn1Neg, vexpTable, vneu1e, alpha, vectorLength, code, expLength, isInference);
}
/*
* binarySearch - find element in haystack buffer (haystack - sorted device memory)
* */
int binarySearch(const int *haystack, const int needle, const int totalElements) {
return 0;
int firstIndex = 0;
int lastIndex = totalElements - 1;
int halfIndex = nd4j::math::nd4j_floor<float, int>((lastIndex + firstIndex) / (float) 2);
while(haystack[halfIndex] != needle && firstIndex < lastIndex) {
if (needle < haystack[halfIndex]) {
lastIndex = halfIndex - 1;
} else if (needle > haystack[halfIndex]) {
firstIndex = halfIndex + 1;
}
halfIndex = nd4j::math::nd4j_floor<float, int>((lastIndex + firstIndex) / (float) 2);
}
void skipgram(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, NDArray &negTable, NDArray &target, NDArray &ngStarter, int nsRounds, NDArray &indices, NDArray &codes, NDArray &alpha, NDArray &randomValue, NDArray &inferenceVector, const bool preciseMode, const int numWorkers) {
return (haystack[halfIndex] == needle) ? halfIndex : -1;
}
template <typename T>
__global__ void addInfVectorKernel(T* neu1, T* infVector, int vectorLength) {
auto start = blockIdx.x * blockDim.x + threadIdx.x;
auto step = blockDim.x * gridDim.x;
for (auto i = start; i < vectorLength; i += step) {
neu1[i] += infVector[i];
}
}
template <typename T>
void skipgram_(NDArray& s0, NDArray& s1, NDArray& s1n, NDArray& expTableV, NDArray& negTableV, NDArray& infV, int target, int ngStarter, NDArray& indices, NDArray& codes, double alpha, Nd4jLong randomValue, const int hsRounds, const int nsRounds) {
// void *vsyn0, void *vsyn1, void *vsyn1Neg, void *vexpTable, void *vnegTable, void *vinfVector, int target, int ngStarter, int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, const int hsRounds, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength) {
auto syn0 = reinterpret_cast<T*>(s0.specialBuffer());
auto syn1 = reinterpret_cast<T*>(s1.specialBuffer());
auto syn1Neg = reinterpret_cast<T*>(s1n.specialBuffer());
auto expTable = reinterpret_cast<T*>(expTableV.specialBuffer());
auto negTable = reinterpret_cast<T*>(negTableV.specialBuffer());
auto infVector = reinterpret_cast<T*>(infV.specialBuffer());
const int vocabSize = s0.sizeAt(0);
const int vectorLength = s0.sizeAt(1);
const int expLength = expTableV.lengthOf();
const int negLength = negTableV.lengthOf();
indices.tickReadDevice();
indices.syncToHost();
codes.tickReadDevice();
codes.syncToHost();
auto stream = s0.getContext()->getCudaStream();
T* neu1e; // = new T[vectorLength];
//memset(neu1e, 0, vectorLength * sizeof(T));
auto err = cudaMalloc(&neu1e, sizeof(T) * vectorLength);
err = cudaMemset(neu1e, 0, sizeof(T) * vectorLength);
// hierarchic softmax goes first (if enabled)
auto syn0row = infVector != nullptr ? infVector : syn0 + (target * vectorLength);
auto irow = 0;
if (hsRounds > 0) {
for (int r = 0; r < hsRounds; r++) {
irow = indices.t<int>(r);
if (irow < 0 || irow >= vocabSize)
break;
hSoftmax_<T>(syn0row, syn1 + (irow * vectorLength), expTable, neu1e, alpha, vectorLength, codes.t<int8_t>(r), expLength, infVector != nullptr, stream);
}
}
// negative sampling goes second (if enabled)
auto nsStarter = ngStarter;
irow = nsStarter;
if (nsRounds > 0) {
for (int r = 0; r < nsRounds + 1; r++) {
if (r == 0) {
// target is known in advance
} else {
randomValue = randomValue * (unsigned long long) 25214903917 + 11;
auto idx = nd4j::math::nd4j_abs<Nd4jLong >((randomValue >> 16) % negLength);
irow = idx >= negLength ? -1 : negTableV.e<int>(idx);
if (irow < 0 || irow >= vocabSize) irow = randomValue % (vocabSize - 1) + 1;
if (irow == nsStarter)
continue;
}
nSampling_<T>(syn0row, syn1Neg + (irow * vectorLength), expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr, stream);
}
}
if (infVector == nullptr) {
addInfVectorKernel<T><<<128, 256, 256, *stream>>>(syn0row, neu1e, vectorLength);
} else {
addInfVectorKernel<T><<<128, 256, 256, *stream>>>(infVector, neu1e, vectorLength);
}
err = cudaFree(neu1e);
if (0 != err) {
throw cuda_exception::build("helpers::skipgram_: Cannot deallocate temp memory for lingual net", err);
}
}
BUILD_SINGLE_TEMPLATE(template void skipgram_, (NDArray& syn0, NDArray& syn1, NDArray& syn1Neg, NDArray& expTable, NDArray& negTable, NDArray& infVector, int target, int ngStarter, NDArray& indices, NDArray& codes, double alpha, Nd4jLong randomValue, const int hsRounds, const int nsRounds), FLOAT_TYPES);
/*
* batched version of skipgram routine
* */
template <typename T>
void skipgramBatchExec_(NDArray &s0, NDArray &s1, NDArray &s1n, NDArray& expTableV, NDArray& negTableV, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const bool preciseMode, const int numThreads) {
// (NDArray &s0, NDArray &s1, NDArray &s1n, NDArray& expTable, NDArray& negTable, NDArray& infVector, NDArray& targets, NDArray& negStarters, NDArray& indices, NDArray& codes, NDArray& lr, NDArray& nextRandom, const int nsRounds, const bool preciseMode, const int numThreads) {
//auto syn0 = reinterpret_cast<T*>(vsyn0);
//auto syn1 = reinterpret_cast<T*>(vsyn1);
//auto syn1Neg = reinterpret_cast<T*>(vsyn1Neg);
auto stream = s0.getContext()->getCudaStream();
negTableV.tickReadDevice();
negTableV.syncToHost();
const auto expTable = reinterpret_cast<T*>(expTableV.specialBuffer());
const auto negTable = reinterpret_cast<T*>(negTableV.buffer());
const auto infVector = (T*)nullptr; //reinterpret_cast<T*>(infVector.specialBuffer());
const int vocabSize = s0.sizeAt(0);
const int vectorLength = s0.sizeAt(1);
const int expLength = expTableV.lengthOf();
const int negLength = negTableV.lengthOf();
//T sneu1e[600];
//const auto numThreads = omp_get_max_threads();
const auto idxShift = indices.isEmpty() ? 0 : indices.sizeAt(1);
const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1);
// regular mode provides 0 guarantees for reproducibility
auto numTargets = targets.lengthOf();
targets.syncToHost();
indices.syncToHost();
codes.syncToHost();
lr.syncToHost();
nextRandom.syncToHost();
negStarters.tickReadDevice();
negStarters.syncToHost();
auto bTarget = reinterpret_cast<int*>(targets.buffer()); //targets.bufferAsT<int>();
auto bIndices = reinterpret_cast<int*>(indices.buffer()); //indices.bufferAsT<int>();
auto bCodes = reinterpret_cast<int8_t*>(codes.buffer()); //codes.bufferAsT<int8_t>();
// PRAGMA_OMP_PARALLEL_FOR_ARGS(num_threads(numThreads))
for (int t = 0; t < numTargets; t++) {
T* neu1e;//lvectorLength <= 600 ? sneu1e : new T[vectorLength];
auto err = cudaMalloc(&neu1e, vectorLength * sizeof(T));
err = cudaMemset(neu1e, 0, vectorLength * sizeof(T));
//memset(neu1e, 0, vectorLength * sizeof(T));
auto target = bTarget[t];
auto alpha = lr.e<double>(t);
unsigned long long randomValue = nextRandom.e<Nd4jLong>(t);
auto syn0row = reinterpret_cast<T*>(s0.specialBuffer()) + (target * vectorLength);
if (hsRounds > 0) {
int irow = 0;
auto cShift = t * idxShift;
for (int e = 0; e < hsRounds; e++) {
irow = bIndices[e + cShift];
if (irow < 0 || irow >= vocabSize)
continue;
auto syn1row = reinterpret_cast<T*>(s1.getSpecialBuffer()) + (irow * vectorLength);
auto code = bCodes[e + cShift];
//nd4j_printf("syn0: [%i]; syn1: [%i]; code: [%i]\n", target, irow, code);
hSoftmax_<T>(syn0row, syn1row, expTable, neu1e, alpha, vectorLength, code, expLength, false, stream);
}
}
if (nsRounds > 0) {
int irow = negStarters.e<int>(t);
int nsStarter = irow;
for (int r = 0; r < nsRounds + 1; r++) {
if (r == 0) {
// target is known in advance
} else {
randomValue = randomValue * (unsigned long long) 25214903917 + 11;
auto idx = nd4j::math::nd4j_abs<Nd4jLong >((randomValue >> 16) % negLength);
irow = idx >= negLength ? -1 : static_cast<int>(negTable[idx]);
if (irow < 0 || irow >= vocabSize)
irow = randomValue % (vocabSize - 1) + 1;
if (irow == nsStarter)
continue;
}
auto syn1row = reinterpret_cast<T*>(s1n.getSpecialBuffer()) + (irow * vectorLength);
nSampling_<T>(syn0row, syn1row, expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, expLength, false, stream);
}
}
addInfVectorKernel<T><<<128, 256, 256, *stream>>>(syn0row, neu1e, vectorLength);
// optionally release temp arrays
err = cudaFree(neu1e);
if (err != 0) {
break;
}
// if (vectorLength > 600)
// delete[] neu1e;
}
}
BUILD_SINGLE_TEMPLATE(template void skipgramBatchExec_, (NDArray &s0, NDArray &s1, NDArray &s1n, NDArray& expTable, NDArray& negTable, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const bool preciseMode, const int numThreads), FLOAT_TYPES);
void skipgram(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, NDArray &negTable,
NDArray &target, NDArray &ngStarter, int nsRounds, NDArray &indices, NDArray &codes, NDArray &alpha, NDArray &randomValue, NDArray &inferenceVector, const bool preciseMode, const int numWorkers) {
auto xType = syn0.dataType();
// single round case
if ((ngStarter.isScalar() && !ngStarter.isEmpty())|| (target.isScalar() && !target.isEmpty())) {
auto hsRounds = codes.lengthOf();
target.syncToHost();
ngStarter.syncToHost();
alpha.syncToHost();
randomValue.syncToHost();
auto targetV = target.isEmpty() ? -1 : target.e<int>(0);
auto starterV = ngStarter.isEmpty() ? -1 : ngStarter.e<int>(0);
auto alphaV = alpha.e<double>(0);
auto randomV = randomValue.e<Nd4jLong>(0);
BUILD_SINGLE_SELECTOR(xType, skipgram_, (syn0, syn1, syn1Neg, expTable, negTable, inferenceVector, targetV, starterV, indices, codes, alphaV, randomV, hsRounds, nsRounds), FLOAT_TYPES);
} else if (ngStarter.isVector() || target.isVector()){
// batch mode
// NDArray& infVector, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const bool preciseMode, const int numThreads)
BUILD_SINGLE_SELECTOR(xType, skipgramBatchExec_, (syn0, syn1, syn1Neg, expTable, negTable, target, ngStarter, indices, codes, alpha, randomValue, nsRounds, preciseMode, numWorkers), FLOAT_TYPES);
} else
throw std::runtime_error("SkipGram: target must have rank 0 or 1");
}
template <typename T>
static __global__ void checkContextKernel(int* context, T* syn0, T* neu1, int contextWidth, int vectorLength, int vocabSize) {
__shared__ bool hasError;
@ -157,16 +379,6 @@ namespace nd4j {
}
}
template <typename T>
__global__ void addInfVectorKernel(T* neu1, T* infVector, int vectorLength) {
auto start = blockIdx.x * blockDim.x + threadIdx.x;
auto step = blockDim.x * gridDim.x;
for (auto i = start; i < vectorLength; i += step) {
neu1[i] += infVector[i];
}
}
template <typename T>
__global__ void shiftKernel(T* neu1, T* infVector, int contextWidth, int vectorLength) {
auto start = blockIdx.x * blockDim.x + threadIdx.x;

View File

@ -0,0 +1,81 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/helpers/shift.h>
namespace nd4j {
namespace ops {
namespace helpers {
template <typename T>
void rshift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) {
auto lambda = LAMBDA_T(x, shift) {
return x >> shift;
};
input.applyLambda(lambda, &output);
}
void rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) {
BUILD_SINGLE_SELECTOR(x.dataType(), rshift_bits_, (launchContext, x, z, shift), INTEGER_TYPES);
}
template <typename T>
void shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) {
auto lambda = LAMBDA_T(x, shift) {
return x << shift;
};
input.applyLambda(lambda, &output);
}
void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) {
BUILD_SINGLE_SELECTOR(x.dataType(), shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES);
}
template <typename T>
void cyclic_rshift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) {
auto step = (sizeof(T) * 8) - shift;
auto lambda = LAMBDA_T(x, shift, step) {
return x >> shift | x << step;
};
input.applyLambda(lambda, &output);
}
void cyclic_rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) {
BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_rshift_bits_, (launchContext, x, z, shift), INTEGER_TYPES);
}
template <typename T>
void cyclic_shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) {
auto step = (sizeof(T) * 8) - shift;
auto lambda = LAMBDA_T(x, shift, step) {
return x << shift | x >> step;
};
input.applyLambda(lambda, &output);
}
void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) {
BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES);
}
}
}
}

View File

@ -26,7 +26,13 @@ namespace nd4j {
namespace helpers {
template<typename T>
void toggle_bits__(NDArray &in, NDArray &out) {
NDArray::prepareSpecialUse({&out}, {&in});
auto lambda = LAMBDA_T(_x) {
return ~_x;//eUtils::flip_bits(_x);
};
in.applyLambda(lambda, &out);
NDArray::registerSpecialUse({&out}, {&in});
}
BUILD_SINGLE_TEMPLATE(template void toggle_bits__, (NDArray &in, NDArray &out), INTEGER_TYPES);

View File

@ -685,13 +685,12 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr
BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES);
//////////////////////////////////////////////////////////////////////////
void eye(nd4j::LaunchContext * context, NDArray& output) {
output.setIdentity();
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
static __global__ void clipByNormInplaceKernel(Nd4jLong numOfSubArrs, T* inputBuffer, Nd4jLong* shape, Nd4jLong* inputOffsets, T* norm2Buf, Nd4jLong* norm2shape, T clipNorm) {
@ -807,7 +806,6 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr
void clipByGlobalNorm_(nd4j::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, nd4j::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace) {
NDArray globalNorm = NDArrayFactory::create<T>(0, inputs[0]->getContext()); //sqrt(sum([l2norm(t)**2 for t in t_list]))
PRAGMA_OMP_PARALLEL_FOR
for (auto i = 0; i < inputs.size(); i++) {
auto input = inputs[i];
auto l2norm = input->reduceNumber(reduce::Norm2);
@ -819,7 +817,6 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr
globalNorm.syncToHost();
const T factor = clipNorm / globalNorm.e<T>(0);
PRAGMA_OMP_PARALLEL_FOR
for (size_t e = 0; e < inputs.size(); e++) {
// all-reduce
auto input = inputs[e];

View File

@ -26,7 +26,7 @@ namespace ops {
namespace helpers {
void rollFunctorLinear(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, bool inplace = false);
void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, std::vector<int> const& axes, bool inplace = false);
void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, std::vector<int> const& shifts, std::vector<int> const& axes, bool inplace = false);
}
}
}

View File

@ -0,0 +1,42 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef DEV_TESTS_SHIFT_H
#define DEV_TESTS_SHIFT_H
#include <op_boilerplate.h>
#include <types/types.h>
#include <NDArray.h>
namespace nd4j {
namespace ops {
namespace helpers {
void rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift);
void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift);
void cyclic_rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift);
void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift);
}
}
}
#endif //DEV_TESTS_SHIFT_H

View File

@ -113,6 +113,14 @@ TEST_F(DataTypesValidationTests, test_bfloat16_rand_1) {
ASSERT_TRUE(x.sumNumber().e<float>(0) > 0);
}
TEST_F(DataTypesValidationTests, test_bfloat16_rand_2) {
auto x = NDArrayFactory::create<bfloat16>('c', {5, 10});
RandomGenerator gen(119, 120);
RandomLauncher::fillGaussian(LaunchContext::defaultContext(), gen, &x, 0, 1);
ASSERT_TRUE(x.sumNumber().e<float>(0) > 0);
}
TEST_F(DataTypesValidationTests, cast_1) {
float16 x = static_cast<float16>(1.f);

View File

@ -750,59 +750,6 @@ TEST_F(DeclarableOpsTests12, tensormmul_6) {
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, concat_test10) {
NDArray x0('c', {1,4,5}, nd4j::DataType::FLOAT32);
NDArray x1('c', {2,4,5}, nd4j::DataType::FLOAT32);
NDArray z('f', {3,4,5}, nd4j::DataType::FLOAT32);
x0 = 0.;
x1 = 1.;
nd4j::ops::concat op;
auto status = op.execute({&x0, &x1}, {&z}, {}, {0}, {});
ASSERT_EQ(ND4J_STATUS_OK, status);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, concat_14) {
NDArray x0('c', {1,6}, {1,2,3,4,5,6});
NDArray x1('c', {1,6}, {7,8,9,10,11,12});
NDArray output('f', {2,6}, nd4j::DataType::DOUBLE);
NDArray exp('c', {2,6}, {1,2,3,4,5,6,7,8,9,10,11,12});
nd4j::ops::concat op;
auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {});
ASSERT_EQ(ND4J_STATUS_OK, status);
// output.printBuffer();
// output.printIndexedBuffer();
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, concat_15) {
NDArray x0('c', {1,4}, {1,2,3,4});
NDArray x1('c', {1,4}, {5,6,7,8});
NDArray output('c', {2,4}, nd4j::DataType::DOUBLE);
NDArray exp('c', {2,4}, {1,2,3,4,5,6,7,8});
nd4j::ops::concat op;
auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {});
ASSERT_EQ(ND4J_STATUS_OK, status);
// output.printBuffer();
// output.printIndexedBuffer();
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, reduceMeanBp_4) {

View File

@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -33,8 +33,8 @@ class DeclarableOpsTests13 : public testing::Test {
public:
DeclarableOpsTests13() {
printf("\n");
fflush(stdout);
//printf("\n");
//fflush(stdout);
}
};
@ -103,8 +103,9 @@ TEST_F(DeclarableOpsTests13, test_argmax_edge_1) {
nd4j::ops::argmax op;
auto result = op.execute(ctx);
ASSERT_EQ(Status::OK(), result);
nd4j_printf("Done\n","");
//nd4j_printf("Done\n","");
delete ctx;
}
@ -258,7 +259,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_1) {
ASSERT_EQ(result->status(), Status::OK());
result->at(0)->printBuffer("Output");
//result->at(0)->printBuffer("Output");
ASSERT_TRUE(exp1.equalsTo(result->at(0)));
delete result;
}
@ -306,8 +307,8 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_3) {
//nd4j_printf("rows %lld, cols %lld, vals %lld, res full %lld\n", rows.lengthOf(), cols.lengthOf(), vals.lengthOf(), exp1.lengthOf());
ASSERT_EQ(result->status(), Status::OK());
result->at(0)->printBuffer("Output");
exp.printBuffer("Expect");
//result->at(0)->printBuffer("Output");
//exp.printBuffer("Expect");
//result->at(0)->printShapeInfo("Shape output");
ASSERT_TRUE(exp.equalsTo(result->at(0)));
delete result;
@ -327,7 +328,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_1) {
nd4j::ops::barnes_symmetrized op;
auto result = op.execute({&rows, &cols, &vals}, {}, {1});
ASSERT_EQ(result->status(), Status::OK());
result->at(2)->printBuffer("Symmetrized1");
//result->at(2)->printBuffer("Symmetrized1");
ASSERT_TRUE(exp.equalsTo(result->at(2)));
delete result;
@ -346,7 +347,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_2) {
nd4j::ops::barnes_symmetrized op;
auto result = op.execute({&rows, &cols, &vals}, {}, {3});
ASSERT_EQ(result->status(), Status::OK());
result->at(2)->printBuffer("Symmetrized2");
//result->at(2)->printBuffer("Symmetrized2");
// ASSERT_TRUE(exp[i]->equalsTo(result->at(i)));
ASSERT_TRUE(exp.equalsTo(result->at(2)));
delete result;
@ -365,7 +366,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_3) {
nd4j::ops::barnes_symmetrized op;
auto result = op.execute({&rows, &cols, &vals}, {}, {11});
ASSERT_EQ(result->status(), Status::OK());
result->at(2)->printBuffer("Symmetrized3");
//result->at(2)->printBuffer("Symmetrized3");
//exp.printBuffer("EXPect symm3");
// ASSERT_TRUE(exp[i]->equalsTo(result->at(i)));
//ASSERT_TRUE(exp.equalsTo(result->at(0)));
@ -390,10 +391,10 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_4) {
auto result = op.execute({&rows, &cols, &vals}, {}, {11});
ASSERT_EQ(result->status(), Status::OK());
auto res = result->at(2);
res->printBuffer("Symmetrized4");
exp4.printBuffer("Expected sym");
nd4j_printf("Total res is {1, %lld}\n", res->lengthOf());
nd4j_printf("Expected is {1, %lld}\n", exp4.lengthOf());
// res->printBuffer("Symmetrized4");
// exp4.printBuffer("Expected sym");
// nd4j_printf("Total res is {1, %lld}\n", res->lengthOf());
// nd4j_printf("Expected is {1, %lld}\n", exp4.lengthOf());
//exp.printBuffer("EXPect symm3");
// ASSERT_TRUE(exp[i]->equalsTo(result->at(i)));
@ -619,3 +620,72 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_5) {
delete results;
}
TEST_F(DeclarableOpsTests13, shift_bits_1) {
auto x = NDArrayFactory::create<int>('c', {5});
auto e = x.ulike();
x.assign(32);
e.assign(512);
nd4j::ops::shift_bits op;
auto result = op.execute({&x}, {}, {4});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(DeclarableOpsTests13, rshift_bits_1) {
auto x = NDArrayFactory::create<int>('c', {5});
auto e = x.ulike();
x.assign(512);
e.assign(32);
nd4j::ops::rshift_bits op;
auto result = op.execute({&x}, {}, {4});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) {
auto x = NDArrayFactory::create<int>('c', {5});
auto e = x.ulike();
x.assign(32);
e.assign(512);
nd4j::ops::cyclic_shift_bits op;
auto result = op.execute({&x}, {}, {4});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_1) {
auto x = NDArrayFactory::create<int>('c', {5});
auto e = x.ulike();
x.assign(512);
e.assign(32);
nd4j::ops::cyclic_rshift_bits op;
auto result = op.execute({&x}, {}, {4});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}

View File

@ -364,77 +364,6 @@ TEST_F(DeclarableOpsTests15, test_rank_2) {
delete result;
}
TEST_F(DeclarableOpsTests15, test_concat_column_1) {
auto x = NDArrayFactory::create<double>('c', {2, 1}, {1, 1});
auto y = NDArrayFactory::create<double>('c', {2, 1}, {0, 0});
auto e = NDArrayFactory::create<double>('c', {2, 2}, {1, 0, 1, 0});
auto z = NDArrayFactory::create<double>('c', {2, 2});
nd4j::ops::concat op;
auto status = op.execute({&x, &y}, {&z}, {}, {1}, {});
ASSERT_EQ(Status::OK(), status);
z.printIndexedBuffer("z");
ASSERT_EQ(e, z);
}
TEST_F(DeclarableOpsTests15, test_concat_large_1) {
std::array<NDArray*, 2000> arrays;
Context context(1);
Nd4jLong axis = 0;
// we crate bunch of arrays, filled with specific values
for (int e = 0; e < arrays.size(); e++) {
auto array = NDArrayFactory::create_<float>('c', {1, 300});
array->assign(e);
context.setInputArray(e, array, true);
}
auto z = NDArrayFactory::create<float>('c', {2000, 300});
context.setOutputArray(0, &z, false);
context.setIArguments(&axis, 1);
nd4j::ops::concat op;
op.execute(&context);
for (int e = 0; e < arrays.size(); e++) {
auto row = z.tensorAlongDimension(e, {1});
ASSERT_NEAR((float) e, row->e<float>(0), 1e-5f);
delete row;
}
}
TEST_F(DeclarableOpsTests15, test_concat_large_2) {
std::array<NDArray*, 10> arrays;
Context context(1);
Nd4jLong axis = 0;
// we crate bunch of arrays, filled with specific values
for (int e = 0; e < arrays.size(); e++) {
auto array = NDArrayFactory::create_<float>('c', {1, 5, 20});
array->assign(e);
context.setInputArray(e, array, true);
}
auto z = NDArrayFactory::create<float>('c', {arrays.size(), 5, 20});
context.setOutputArray(0, &z, false);
context.setIArguments(&axis, 1);
nd4j::ops::concat op;
op.execute(&context);
for (int e = 0; e < arrays.size(); e++) {
auto row = z.tensorAlongDimension(e, {1, 2});
ASSERT_NEAR((float) e, row->meanNumber().e<float>(0), 1e-5f);
delete row;
}
}
TEST_F(DeclarableOpsTests15, test_lstmBlock_1) {
auto x0 = NDArrayFactory::create<Nd4jLong>(5);
auto x1 = NDArrayFactory::create<float>('c', {5, 1, 4}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f, 0.50563407f, 0.89252293f, 0.5461209f, 0.92336726f, 0.085571885f, 0.7937801f, 0.65908563f, 0.55552566f, 0.15962744f, 0.30874777f, 0.15476847f, 0.46954823f, 0.9938899f, 0.6112741f});

View File

@ -0,0 +1,54 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "testlayers.h"
#include <ops/declarable/CustomOperations.h>
#include <NDArray.h>
#include <ops/ops.h>
#include <GradCheck.h>
#include <array>
using namespace nd4j;
class DeclarableOpsTests16 : public testing::Test {
public:
DeclarableOpsTests16() {
printf("\n");
fflush(stdout);
}
};
TEST_F(DeclarableOpsTests16, test_repeat_119) {
auto x = NDArrayFactory::create<double>('c', {2, 3}, {1, 2, 3, 4, 5, 6});
auto e = NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6});
nd4j::ops::repeat op;
auto result = op.execute({&x}, {}, {2, 0});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}

View File

@ -373,35 +373,6 @@ TEST_F(DeclarableOpsTests2, NLP_Cbow_Test_1) {
delete result;
}
TEST_F(DeclarableOpsTests2, Test_Concat_3D_1) {
auto x0 = NDArrayFactory::create<double>('c', {1, 100, 150});
auto x1 = NDArrayFactory::create<double>('c', {1, 100, 150});
auto x2 = NDArrayFactory::create<double>('c', {1, 100, 150});
auto x3 = NDArrayFactory::create<double>('c', {1, 100, 150});
x0.assign(1.0);
x1.assign(2.0);
x2.assign(3.0);
x3.assign(4.0);
nd4j::ops::concat op;
auto result = op.execute({&x0, &x1, &x2, &x3}, {}, {0}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0});
ASSERT_TRUE(4 == numOfTads);
for (int e = 0; e < numOfTads; e++) {
NDArray tad = (*z)(e, {0});
auto mean = tad.meanNumber().e<double>(0);
ASSERT_NEAR((double) e+1, mean, 1e-5);
}
delete result;
}
TEST_F(DeclarableOpsTests2, YetAnotherMatmulTest_1) {
auto A = NDArrayFactory::create<float>('c', {3, 3});
auto B = NDArrayFactory::create<float>('c', {3, 1});
@ -502,6 +473,7 @@ TEST_F(DeclarableOpsTests2, Test_FloorDiv_2) {
auto x = NDArrayFactory::create<float>('c', {1, 3}, {3.0, 6.0, -3.0});
auto y = NDArrayFactory::create<float>('c', {1, 3}, {-2.0, 2.0, -2.0});
auto eps = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});
auto exp1 = NDArrayFactory::create<float>('c', {1, 3}, {0.f, 0.f, 0.f});
auto exp2 = NDArrayFactory::create<float>('c', {1, 3}, {0.f, 0.f, 0.f});

View File

@ -223,21 +223,221 @@ TEST_F(DeclarableOpsTests5, Test_Boolean_diff_1) {
delete result;
}
TEST_F(DeclarableOpsTests5, Test_SetSeed_1) {
auto x = NDArrayFactory::create<int>('c', {1, 1}, {120});
auto y = NDArrayFactory::create<int>(5);
nd4j::ops::set_seed op;
auto result = op.execute({&x, &y}, {}, {120, 5}, {}, false, nd4j::DataType::INT32);
ASSERT_EQ(Status::OK(), result->status());
// result->at(0)->printIndexedBuffer("RES SEED");
nd4j::ops::get_seed getOp;
auto getRes = getOp.execute({}, {}, {});
ASSERT_EQ(Status::OK(), getRes->status());
// getRes->at(0)->printIndexedBuffer("Output RES GET SEED");
// ASSERT_EQ(result->at(0)->t<bool>(0), true);
delete result;
delete getRes;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, scatterMul_test1) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64);
auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10, 1});
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {10, 2, 3, 4});
nd4j::ops::scatter_mul op;
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, scatterDiv_test1) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64);
auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10, 1});
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {0.10, 2, 3, 4});
nd4j::ops::scatter_div op;
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
// z->printIndexedBuffer("Scatter Div");
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, scatterSub_test1) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64);
auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10, 1});
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {-9, 1, 3, 4});
nd4j::ops::scatter_sub op;
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
// z->printIndexedBuffer("Scatter Sub");
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, hardsigmoid_test1) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {0.7, 0.9, 1, 1});
nd4j::ops::hardsigmoid op;
auto result = op.execute({&matrix}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("Hadrdsigmoid 2x2");
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, hardsigmoid_test2) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
auto eps = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {0.2, 0.4, 0, 0});
nd4j::ops::hardsigmoid_bp op;
auto result = op.execute({&matrix, &eps}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("Hadrdsigmoid 2x2");
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, hardtanh_test1) {
auto matrix = NDArrayFactory::create<float>('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4});
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {-1, -1, -1, -1, 0, 1, 1, 1, 1});
nd4j::ops::hardtanh op;
auto result = op.execute({&matrix}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
// z->printIndexedBuffer("Hardtanh 2x2");
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, hardtanh_test2) {
auto matrix = NDArrayFactory::create<float>('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4});
auto eps = NDArrayFactory::create<float>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {0, 0, 0, 4, 5, 6, 0, 0, 0});
nd4j::ops::hardtanh_bp op;
auto result = op.execute({&matrix, &eps}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
// z->printIndexedBuffer("Hardtanh_bp 2x2");
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, histogram_test1) {
auto matrix = NDArrayFactory::create<double>('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4});
auto exp = NDArrayFactory::create<Nd4jLong>('c', {3}, {3, 3, 3});
nd4j::ops::histogram op;
auto result = op.execute({&matrix}, {}, {3}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
// z->printIndexedBuffer("Histogram3");
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, histogram_test2) {
auto matrix = NDArrayFactory::create<double>('c', {3}, {1, 2, 1});
auto exp = NDArrayFactory::create<Nd4jLong>('c', {4}, {2, 0, 0, 1});
nd4j::ops::histogram op;
auto result = op.execute({&matrix}, {}, {4}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("Histogram4");
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, Identity_test1) {
auto matrix = NDArrayFactory::create<float>('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4});
// auto exp = NDArrayFactory::create<Nd4jLong>('c', {3, 3}, {3, 3, 3});
nd4j::ops::identity op;
auto result = op.execute({&matrix}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
// z->printIndexedBuffer("Histogram3");
ASSERT_TRUE(matrix.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, Identity_test2) {
auto matrix = NDArrayFactory::create<float>('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4});
auto eps = NDArrayFactory::create<float>('c', {3, 3}, {1,2,3,4,5,6,7,8,9});
// auto exp = NDArrayFactory::create<float>('c', {3,3});
nd4j::ops::identity_bp op;
auto result = op.execute({&matrix, &eps}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("Identity_BP");
ASSERT_TRUE(z->equalsTo(eps));
delete result;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, Log1p_test1) {
auto matrix = NDArrayFactory::create<float>('c', {3, 3}, {4, 3, 2, 1, 0, 1, 2, 3, 4});
auto y = NDArrayFactory::create<float>('c', {3,3}, {5,4,3,2,1,2,3,4,5});
// auto eps = NDArrayFactory::create<float>('c', {3, 3}, {1,2,3,4,5,6,7,8,9});
// auto exp = NDArrayFactory::create<float>('c', {3,3});
nd4j::ops::Log1p op;
y.applyTransform(nd4j::transform::Log, nullptr, nullptr);
auto result = op.execute({&matrix}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("Log1p");
ASSERT_TRUE(z->equalsTo(y));
delete result;
}
TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_1) {

View File

@ -737,6 +737,44 @@ TEST_F(DeclarableOpsTests6, cumSum_20) {
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_1) {
auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
auto y = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
auto z = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 60.f, 7.f, 80.f});
auto exp = NDArrayFactory::create<int>('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2});
nd4j::ops::mergemaxindex op;
auto ress = op.execute({&x, &y, &z}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
// ress->at(0)->printIndexedBuffer("MergeMaxIndex Result is ");
// ress->at(0)->printShapeInfo("Shape info for MergeMaxIdex");
// x.printIndexedBuffer("Input is");
ASSERT_TRUE(ress->at(0)->equalsTo(exp));
delete ress;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_2) {
auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
auto y = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
auto z = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 60.f, 7.f, 80.f});
auto exp = NDArrayFactory::create<Nd4jLong>('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2});
nd4j::ops::mergemaxindex op;
auto ress = op.execute({&x, &y, &z}, {}, {nd4j::DataType::INT64}, {});
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
// ress->at(0)->printIndexedBuffer("MergeMaxIndex2 Result is ");
// ress->at(0)->printShapeInfo("Shape info for MergeMaxIdex2");
// x.printIndexedBuffer("Input is");
ASSERT_TRUE(ress->at(0)->equalsTo(exp));
delete ress;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestDropout_1) {
@ -752,8 +790,60 @@ TEST_F(DeclarableOpsTests6, TestDropout_1) {
delete ress;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestMod_1) {
auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
auto y = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
auto exp = NDArrayFactory::create<double>('c', {2, 2, 2}, {1, 0, 3, 0, 5, 0, 7, 0});
nd4j::ops::mod op;
auto ress = op.execute({&x, &y}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
// ress->at(0)->printIndexedBuffer("MOD Result is ");
// x.printIndexedBuffer("Input is");
ASSERT_TRUE(ress->at(0)->equalsTo(exp));
delete ress;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestMod_BP_1) {
auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
auto y = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
auto eps = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
auto exp = NDArrayFactory::create<double>('c', {2, 2, 2});
nd4j::ops::mod_bp op;
auto ress = op.execute({&x, &y, &eps}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
// ress->at(0)->printIndexedBuffer("MOD_BP Result is ");
// x.printIndexedBuffer("Input is");
ASSERT_TRUE(ress->at(0)->equalsTo(exp));
delete ress;
}
///////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestRank_1) {
auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
auto y = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
auto eps = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
auto exp = NDArrayFactory::create<int>(3);
nd4j::ops::rank op;
auto ress = op.execute({&x}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
ress->at(0)->printIndexedBuffer("RANK Result is ");
// x.printIndexedBuffer("Input is");
ASSERT_TRUE(ress->at(0)->equalsTo(exp));
delete ress;
}
TEST_F(DeclarableOpsTests6, TestDropout_2) {
// auto x0 = NDArrayFactory::create<double>('c', {10, 10});
// auto x1 = NDArrayFactory::create<double>('c', {10, 10});
@ -1480,8 +1570,8 @@ TEST_F(DeclarableOpsTests6, LogDet_1) {
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("LogDet Output1 ");
exp.printIndexedBuffer("LogDet Expected1 ");
// z->printIndexedBuffer("LogDet Output1 ");
// exp.printIndexedBuffer("LogDet Expected1 ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -1502,9 +1592,9 @@ TEST_F(DeclarableOpsTests6, LogDet_2) {
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("LogDet Output2 ");
// z->printIndexedBuffer("LogDet Output2 ");
// z->printShapeInfo("Shape");
exp.printIndexedBuffer("LogDet Expected2 ");
// exp.printIndexedBuffer("LogDet Expected2 ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -1525,9 +1615,9 @@ TEST_F(DeclarableOpsTests6, LogDet_3) {
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("LogDet Output3 ");
// z->printIndexedBuffer("LogDet Output3 ");
// z->printShapeInfo("Shape");
exp.printIndexedBuffer("LogDet Expected3 ");
// exp.printIndexedBuffer("LogDet Expected3 ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -1572,8 +1662,8 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("Output ");
exp.printIndexedBuffer("Expected ");
// z->printIndexedBuffer("Output ");
// exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -1608,8 +1698,8 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_01) {
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("Output ");
exp.printIndexedBuffer("Expected ");
// z->printIndexedBuffer("Output ");
// exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -1642,8 +1732,8 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_02) {
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("Output ");
exp.printIndexedBuffer("Expected ");
// z->printIndexedBuffer("Output ");
// exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -1722,8 +1812,8 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_03) {
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("Output ");
exp.printIndexedBuffer("Expected ");
// z->printIndexedBuffer("Output ");
// exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -2755,31 +2845,4 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_3) {
delete result;
}
TEST_F(DeclarableOpsTests6, concat_test14) {
NDArray x0('c', {1, 55, 40}, nd4j::DataType::DOUBLE);
NDArray x1('c', {1, 55, 40}, nd4j::DataType::DOUBLE);
x0 = 1.;
x1 = 2.;
nd4j::ops::concat op;
auto result = op.execute({&x0, &x1}, {}, {0}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// z->printShapeInfo();
// z->printIndexedBuffer();
Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0});
ASSERT_TRUE(2 == numOfTads);
for (int e = 0; e < numOfTads; ++e) {
NDArray tad = (*z)(e, {0});
auto mean = tad.meanNumber().e<double>(0);
ASSERT_NEAR((e+1)*1., mean, 1e-5);
}
delete result;
}

View File

@ -24,6 +24,7 @@
#include <helpers/helper_hash.h>
#include <NDArray.h>
#include <array/NDArrayList.h>
#include <GradCheck.h>
using namespace nd4j;
@ -3310,6 +3311,130 @@ auto exp = NDArrayFactory::create<double>('c', {2, 3, 3}, {
// delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestRoll_10) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
});
// ----------------------------------------------------------------
nd4j::ops::roll op;
auto result = op.execute({&x}, {}, {3, 1}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(result->status(), Status::OK());
auto out = result->at(0);
// out->printIndexedBuffer("Output");
//exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(out));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestRoll_11) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
});
auto shift = NDArrayFactory::create<int>({1,2});
auto axis = NDArrayFactory::create<int>({0, 1});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {
17., 18., 19., 20., 21., 22., 23., 24., 13., 14., 15., 16., 5., 6., 7, 8, 9, 10, 11, 12, 1, 2, 3, 4
});
// ----------------------------------------------------------------
nd4j::ops::roll op;
NDArray* y = nullptr;
auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(result->status(), Status::OK());
auto out = result->at(0);
// out->printIndexedBuffer("Output");
//exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(out));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestRoll_12) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
});
auto shift = NDArrayFactory::create<int>({1,1,1});
auto axis = NDArrayFactory::create<int>({0, 1, 2});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {
24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7
});
// ----------------------------------------------------------------
nd4j::ops::roll op;
NDArray* y = nullptr;
auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(result->status(), Status::OK());
auto out = result->at(0);
out->printIndexedBuffer("Output");
//exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(out));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestRoll_13) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
});
auto shift = NDArrayFactory::create<int>(3);
auto axis = NDArrayFactory::create<int>(2);
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {
2,3,4,1,6,7,8,5,10,11,12,9,14, 15, 16, 13, 18, 19, 20, 17, 22, 23, 24, 21
});
// ----------------------------------------------------------------
nd4j::ops::roll op;
NDArray* y = nullptr;
auto result = op.execute({&x}, {}, {3,2}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(result->status(), Status::OK());
auto out = result->at(0);
// out->printIndexedBuffer("Output");
//exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(out));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestRoll_14) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {
1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.
});
auto shift = NDArrayFactory::create<int>({1,1,1});
auto axis = NDArrayFactory::create<int>({0, 1, 2});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {
24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7
});
// ----------------------------------------------------------------
nd4j::ops::roll op;
auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(result->status(), Status::OK());
auto out = result->at(0);
// out->printIndexedBuffer("Output");
//exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(out));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, percentile_test1) {
@ -3605,6 +3730,289 @@ TEST_F(DeclarableOpsTests7, transpose_test3) {
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, rationaltanh_test1) {
auto input = NDArrayFactory::create<double>('c', {8}, {0, 1, 2, 3, 4, 5, 6, 7});
NDArray exp = NDArrayFactory::create<double>({0.000000, 0.998222, 1.516093, 1.658054, 1.695077, 1.706884, 1.711427, 1.713446});
nd4j::ops::rationaltanh op;
auto result = op.execute({&input}, {}, {});
auto output = result->at(0);
// output->printIndexedBuffer("Output rationaltanh");
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, rationaltanh_test2) {
auto input = NDArrayFactory::create<double>('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7});
NDArray exp = NDArrayFactory::create<double>('c', {2,2,2}, {0.000000, 0.998222, 1.516093, 1.658054, 1.695077, 1.706884, 1.711427, 1.713446});
nd4j::ops::rationaltanh op;
auto result = op.execute({&input}, {}, {});
auto output = result->at(0);
// output->printIndexedBuffer("Output rationaltanh");
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, rationaltanh_test3) {
auto input = NDArrayFactory::create<double>('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7});
auto eps = NDArrayFactory::create<double>('c', {2,2,2}, {1, 2, 3, 4, 5, 6, 7, 8});
NDArray exp = NDArrayFactory::create<double>('c', {2,2,2}, {1.143933, 1.605747, 0.795557, 0.261710, 0.095832, 0.041218, 0.020221, 0.010971});
nd4j::ops::rationaltanh_bp op;
auto result = op.execute({&input, &eps}, {}, {});
auto output = result->at(0);
// output->printBuffer("Output rationaltanh BP");
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, rectifiedtanh_test1) {
auto input = NDArrayFactory::create<double>('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7});
NDArray exp = NDArrayFactory::create<double>('c', {2,2,2}, {0.000000, 0.761594, 0.964028, 0.995055, 0.999329, 0.999909, 0.999988, 0.999998});
nd4j::ops::rectifiedtanh op;
auto result = op.execute({&input}, {}, {});
auto output = result->at(0);
// output->printIndexedBuffer("Output rectifiedtanh");
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, rectifiedtanh_test2) {
auto input = NDArrayFactory::create<double>('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7});
auto eps = NDArrayFactory::create<double>('c', {2,2,2}, {1, 2, 3, 4, 5, 6, 7, 8});
NDArray exp = NDArrayFactory::create<double>('c', {2,2,2}, {0.000000, 0.839949, 0.211952, 0.039464, 0.006705, 0.001089, 0.000172, 0.000027});
nd4j::ops::rectifiedtanh_bp op;
auto result = op.execute({&input, &eps}, {}, {});
auto output = result->at(0);
// output->printBuffer("Output rectifiedtanh BP");
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
delete result;
}
TEST_F(DeclarableOpsTests7, RealDiv_1) {
NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2, 4});
NDArray y = NDArrayFactory::create<float>('c', {1, 2}, {1,2});
NDArray e = NDArrayFactory::create<float>('c', {1, 2, 2}, {2, 1, 4, 2});
nd4j::ops::realdiv op;
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// z->printIndexedBuffer("OUtput RealDiv");
ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, RealDiv_BP_1) {
NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2, 4});
NDArray y = NDArrayFactory::create<float>('c', {1, 2}, {1,2});
NDArray e0 = NDArrayFactory::create<float>('c', {1, 2, 1}, {2, 5});
NDArray e1 = NDArrayFactory::create<float>('c', {1, 2}, {-14, -5});
NDArray eps = NDArrayFactory::create<float>('c', {1, 2, 2}, {1, 2, 3, 4});
nd4j::ops::realdiv_bp op;
auto result = op.execute({&x, &y, &eps}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z0 = result->at(0);
auto z1 = result->at(1);
// z0->printShapeInfo("OUtput RealDiv BP0 shape");
// z1->printShapeInfo("OUtput RealDiv BP1 shape");
// z0->printIndexedBuffer("OUtput RealDiv BP0");
// z1->printIndexedBuffer("OUtput RealDiv BP1");
// ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e0.equalsTo(z0));
ASSERT_TRUE(e1.equalsTo(z1));
delete result;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, ShapesOf_1) {
NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2, 4});
// NDArray y = NDArrayFactory::create<float>('c', {1, 2}, {1,2});
NDArray e = NDArrayFactory::create<Nd4jLong>({1, 2, 1});
nd4j::ops::shapes_of op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// z->printIndexedBuffer("OUtput RealDiv");
// ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, ShapesOf_2) {
NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2, 4});
NDArray y = NDArrayFactory::create<float>('c', {1, 2}, {1,2});
NDArray e0 = NDArrayFactory::create<Nd4jLong>({1, 2, 1});
NDArray e1 = NDArrayFactory::create<Nd4jLong>({1, 2});
nd4j::ops::shapes_of op;
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z0 = result->at(0);
auto z1 = result->at(1);
// z0->printIndexedBuffer("OUtput shapes2");
// z1->printIndexedBuffer("OUtput shapes2");
// ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e0.equalsTo(z0));
ASSERT_TRUE(e1.equalsTo(z1));
delete result;
}
TEST_F(DeclarableOpsTests7, Size_1) {
NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2, 4});
NDArray y = NDArrayFactory::create<float>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray e = NDArrayFactory::create<Nd4jLong>(2);
nd4j::ops::size op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// z->printIndexedBuffer("OUtput SIZE");
/// ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z));
delete result;
}
TEST_F(DeclarableOpsTests7, Size_2) {
NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2, 4});
NDArray y = NDArrayFactory::create<float>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray e = NDArrayFactory::create<Nd4jLong>(10);
nd4j::ops::size op;
auto result = op.execute({&y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// z->printIndexedBuffer("OUtput SIZE");
/// ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z));
delete result;
}
TEST_F(DeclarableOpsTests7, Softplus_1) {
NDArray x = NDArrayFactory::create<float>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray e = NDArrayFactory::create<float>('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016});
nd4j::ops::softplus op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// z->printIndexedBuffer("OUtput Softplus");
/// ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z));
delete result;
}
TEST_F(DeclarableOpsTests7, Softplus_BP_1) {
NDArray x = NDArrayFactory::create<double >('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
// NDArray e = NDArrayFactory::create<float>('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016});
NDArray eps = NDArrayFactory::create<double>('c', {5, 2}, {1,2,3,4,5,6,7,8, 9, 10});
nd4j::ops::softplus ffOP;
nd4j::ops::softplus_bp bpOp;
const OpArgsHolder argsHolderFF({&x}, {}, {});
const OpArgsHolder argsHolderBP({&x, &eps}, {}, {});
bool gradOK = GradCheck::checkGrad(ffOP, bpOp, argsHolderFF, argsHolderBP);
ASSERT_TRUE(gradOK);
//
// auto z = result->at(0);
// z->printIndexedBuffer("OUtput Softplus");
///// ASSERT_TRUE(e.isSameShape(z));
// ASSERT_TRUE(e.equalsTo(*z));
//
// delete result;
}
TEST_F(DeclarableOpsTests7, Softsign_1) {
NDArray x = NDArrayFactory::create<float>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray e = NDArrayFactory::create<float>('c', {5, 2}, {0.5, 0.6666667, 0.75, 0.8, 0.8333333, 0.875, 0.9, 0.90909094, 0.90909094, 0.9166667});
nd4j::ops::softsign op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// z->printIndexedBuffer("OUtput Softsign");
/// ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z));
delete result;
}
TEST_F(DeclarableOpsTests7, Softsign_BP_1) {
NDArray x = NDArrayFactory::create<double >('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
// NDArray e = NDArrayFactory::create<float>('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016});
NDArray eps = NDArrayFactory::create<double>('c', {5, 2}, {1,2,3,4,5,6,7,8, 9, 10});
nd4j::ops::softsign ffOP;
nd4j::ops::softsign_bp bpOp;
const OpArgsHolder argsHolderFF({&x}, {}, {});
const OpArgsHolder argsHolderBP({&x, &eps}, {}, {});
bool gradOK = GradCheck::checkGrad(ffOP, bpOp, argsHolderFF, argsHolderBP);
ASSERT_TRUE(gradOK);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, fill_test2) {
@ -3644,6 +4052,185 @@ TEST_F(DeclarableOpsTests7, fill_test3) {
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, ToggleBits_test1) {
auto x = NDArrayFactory::create<int>('c', {2}, {2, 2});
auto exp = NDArrayFactory::create<int>('c', {2}, {-3, -3});
nd4j::ops::toggle_bits op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::INT32);
auto output = result->at(0);
ASSERT_EQ(ND4J_STATUS_OK, result->status());
// output->printIndexedBuffer("Toggled");
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, ToggleBits_test2) {
auto x = NDArrayFactory::create<int>('c', {2}, {2, 2});
auto y = NDArrayFactory::create<int>('c', {2}, {1, 1});
auto exp0 = NDArrayFactory::create<int>('c', {2}, {-3, -3});
auto exp1 = NDArrayFactory::create<int>('c', {2}, {-2, -2});
nd4j::ops::toggle_bits op;
auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32);
auto output = result->at(0);
auto z = result->at(1);
ASSERT_EQ(ND4J_STATUS_OK, result->status());
// output->printIndexedBuffer("Toggled");
ASSERT_TRUE(exp0.isSameShape(output));
ASSERT_TRUE(exp0.equalsTo(output));
ASSERT_TRUE(exp1.isSameShape(z));
ASSERT_TRUE(exp1.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, Truncatediv_test1) {
NDArray x = NDArrayFactory::create<double >('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray y = NDArrayFactory::create<double >('c', {5, 2}, {2,2,2,2,2,2,2,2, 2, 2});
NDArray exp = NDArrayFactory::create<double >('c', {5, 2}, {0.5, 1., 1.5, 2., 2.5, 3.5, 4.5, 5., 5., 5.5});
nd4j::ops::truncatediv op;
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0);
// output->printIndexedBuffer("Toggled");
ASSERT_TRUE(exp.isSameShape(output));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, Truncatediv_test2) {
NDArray x = NDArrayFactory::create<double >('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray y = NDArrayFactory::create<double >('c', {1, 2}, {2,2});
NDArray exp = NDArrayFactory::create<double >('c', {5, 2}, {0.5, 1., 1.5, 2., 2.5, 3.5, 4.5, 5., 5., 5.5});
nd4j::ops::truncatediv op;
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0);
// output->printIndexedBuffer("Toggled");
ASSERT_TRUE(exp.isSameShape(output));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TypesConversion_test1) {
NDArray x = NDArrayFactory::create<double >('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray expI = NDArrayFactory::create<int>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray expL = NDArrayFactory::create<Nd4jLong>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray expF = NDArrayFactory::create<float>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray expF16 = NDArrayFactory::create<float16>('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f});
nd4j::ops::to_int32 op32;
nd4j::ops::to_int64 op64;
auto result32 = op32.execute({&x}, {}, {});
auto result64 = op64.execute({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result32->status());
ASSERT_EQ(ND4J_STATUS_OK, result64->status());
auto out1 = result32->at(0);
// out1->printIndexedBuffer("OUT_I");
auto out2 = result64->at(0);
// out2->printIndexedBuffer("OUT_L");
// output->printIndexedBuffer("Toggled");
ASSERT_TRUE(expI.equalsTo(out1));
ASSERT_TRUE(expL.equalsTo(out2));
delete result32;
delete result64;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TypesConversion_test2) {
NDArray x = NDArrayFactory::create<double >('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray expF = NDArrayFactory::create<float>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray expH = NDArrayFactory::create<float16>('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f});
nd4j::ops::to_float32 op32;
nd4j::ops::to_float16 op16;
auto result32 = op32.execute({&x}, {}, {});
auto result16 = op16.execute({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result32->status());
ASSERT_EQ(ND4J_STATUS_OK, result16->status());
auto out1 = result32->at(0);
// out1->printIndexedBuffer("OUT_F");
auto out2 = result16->at(0);
// out2->printIndexedBuffer("OUT_H");
// output->printIndexedBuffer("Toggled");
ASSERT_TRUE(expF.equalsTo(out1));
ASSERT_TRUE(expH.equalsTo(out2));
delete result32;
delete result16;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TypesConversion_test3) {
NDArray x = NDArrayFactory::create<Nd4jLong>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray exp32 = NDArrayFactory::create<unsigned int>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray exp64 = NDArrayFactory::create<uint64_t>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
nd4j::ops::to_uint32 op32;
nd4j::ops::to_uint64 op64;
auto result32 = op32.execute({&x}, {}, {});
auto result64 = op64.execute({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result32->status());
ASSERT_EQ(ND4J_STATUS_OK, result64->status());
auto out1 = result32->at(0);
// out1->printIndexedBuffer("OUT_U32");
auto out2 = result64->at(0);
// out2->printIndexedBuffer("OUT_U64");
// output->printIndexedBuffer("Toggled");
ASSERT_TRUE(exp32.equalsTo(out1));
ASSERT_TRUE(exp64.equalsTo(out2));
delete result32;
delete result64;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TypesConversion_test4) {
NDArray x = NDArrayFactory::create<Nd4jLong>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray exp32 = NDArrayFactory::create<float>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray exp64 = NDArrayFactory::create<double>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
nd4j::ops::to_float32 op32;
nd4j::ops::to_double op64;
auto result32 = op32.execute({&x}, {}, {});
auto result64 = op64.execute({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result32->status());
ASSERT_EQ(ND4J_STATUS_OK, result64->status());
auto out1 = result32->at(0);
out1->printIndexedBuffer("OUT_F");
auto out2 = result64->at(0);
out2->printIndexedBuffer("OUT_D");
// output->printIndexedBuffer("Toggled");
ASSERT_TRUE(exp32.equalsTo(out1));
ASSERT_TRUE(exp64.equalsTo(out2));
delete result32;
delete result64;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, mirrorPad_test1) {

View File

@ -584,6 +584,180 @@ TEST_F(DeclarableOpsTests9, concat_test16) {
delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test17) {
NDArray x0('c', {1, 55, 40}, nd4j::DataType::DOUBLE);
NDArray x1('c', {1, 55, 40}, nd4j::DataType::DOUBLE);
x0 = 1.;
x1 = 2.;
nd4j::ops::concat op;
auto result = op.execute({&x0, &x1}, {}, {0}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// z->printShapeInfo();
// z->printIndexedBuffer();
Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0});
ASSERT_TRUE(2 == numOfTads);
for (int e = 0; e < numOfTads; ++e) {
NDArray tad = (*z)(e, {0});
auto mean = tad.meanNumber().e<double>(0);
ASSERT_NEAR((e+1)*1., mean, 1e-5);
}
delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test18) {
std::array<NDArray*, 2000> arrays;
Context context(1);
Nd4jLong axis = 0;
// we crate bunch of arrays, filled with specific values
for (int e = 0; e < arrays.size(); e++) {
auto array = NDArrayFactory::create_<float>('c', {1, 300});
array->assign(e);
context.setInputArray(e, array, true);
}
auto z = NDArrayFactory::create<float>('c', {2000, 300});
context.setOutputArray(0, &z, false);
context.setIArguments(&axis, 1);
nd4j::ops::concat op;
op.execute(&context);
for (int e = 0; e < arrays.size(); e++) {
auto row = z.tensorAlongDimension(e, {1});
ASSERT_NEAR((float) e, row->e<float>(0), 1e-5f);
delete row;
}
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test19) {
std::array<NDArray*, 10> arrays;
Context context(1);
Nd4jLong axis = 0;
// we crate bunch of arrays, filled with specific values
for (int e = 0; e < arrays.size(); e++) {
auto array = NDArrayFactory::create_<float>('c', {1, 5, 20});
array->assign(e);
context.setInputArray(e, array, true);
}
auto z = NDArrayFactory::create<float>('c', {arrays.size(), 5, 20});
context.setOutputArray(0, &z, false);
context.setIArguments(&axis, 1);
nd4j::ops::concat op;
op.execute(&context);
for (int e = 0; e < arrays.size(); e++)
ASSERT_NEAR((float) e, z(e, {0}).meanNumber().e<float>(0), 1e-5f);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test20) {
auto x0 = NDArrayFactory::create<double>('c', {1, 100, 150});
auto x1 = NDArrayFactory::create<double>('c', {1, 100, 150});
auto x2 = NDArrayFactory::create<double>('c', {1, 100, 150});
auto x3 = NDArrayFactory::create<double>('c', {1, 100, 150});
x0.assign(1.0);
x1.assign(2.0);
x2.assign(3.0);
x3.assign(4.0);
nd4j::ops::concat op;
auto result = op.execute({&x0, &x1, &x2, &x3}, {}, {0}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0});
ASSERT_TRUE(4 == numOfTads);
for (int e = 0; e < numOfTads; e++) {
NDArray tad = (*z)(e, {0});
auto mean = tad.meanNumber().e<double>(0);
ASSERT_NEAR((double) e+1, mean, 1e-5);
}
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test21) {
NDArray x0('c', {1,4,5}, nd4j::DataType::FLOAT32);
NDArray x1('c', {2,4,5}, nd4j::DataType::FLOAT32);
NDArray z('f', {3,4,5}, nd4j::DataType::FLOAT32);
x0 = 0.;
x1 = 1.;
nd4j::ops::concat op;
auto status = op.execute({&x0, &x1}, {&z}, {}, {0}, {});
ASSERT_EQ(ND4J_STATUS_OK, status);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test22) {
NDArray x0('c', {1,6}, {1,2,3,4,5,6});
NDArray x1('c', {1,6}, {7,8,9,10,11,12});
NDArray output('f', {2,6}, nd4j::DataType::DOUBLE);
NDArray exp('c', {2,6}, {1,2,3,4,5,6,7,8,9,10,11,12});
nd4j::ops::concat op;
auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {});
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test23) {
NDArray x0('c', {1,4}, {1,2,3,4});
NDArray x1('c', {1,4}, {5,6,7,8});
NDArray output('c', {2,4}, nd4j::DataType::DOUBLE);
NDArray exp('c', {2,4}, {1,2,3,4,5,6,7,8});
nd4j::ops::concat op;
auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {});
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test24) {
auto x = NDArrayFactory::create<double>('c', {2, 1}, {1, 1});
auto y = NDArrayFactory::create<double>('c', {2, 1}, {0, 0});
auto e = NDArrayFactory::create<double>('c', {2, 2}, {1, 0, 1, 0});
auto z = NDArrayFactory::create<double>('c', {2, 2});
nd4j::ops::concat op;
auto status = op.execute({&x, &y}, {&z}, {}, {1}, {});
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z);
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, tile_bp_test1) {

View File

@ -975,69 +975,6 @@ TEST_F(JavaInteropTests, zeta_test10) {
ASSERT_EQ(e, z);
}
TEST_F(JavaInteropTests, Test_Is_Max_1) {
auto arrayX = NDArrayFactory::create<float>({1, 2, 1, 1});
auto arrayZ = NDArrayFactory::create<bool>({0, 0, 0, 0});
auto arrayE = NDArrayFactory::create<bool>({0, 1, 0, 0});
nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext();
Nd4jPointer* extraPointers = nullptr;
#ifdef __CUDABLAS__
extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()};
#endif
NDArray::prepareSpecialUse({&arrayZ}, {&arrayX});
execTransformAny(extraPointers, transform::IsMax,
arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(),
arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(),
nullptr);
NDArray::registerSpecialUse({&arrayZ}, {&arrayX});
ASSERT_EQ(arrayE, arrayZ);
delete []extraPointers;
}
TEST_F(JavaInteropTests, Test_Is_Max_1_2) {
auto arrayX = NDArrayFactory::create<float>({1, 2, 1, 1});
auto arrayZ = NDArrayFactory::create<float>({0, 0, 0, 0});
auto arrayE = NDArrayFactory::create<float>({0, 1, 0, 0});
nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext();
Nd4jPointer* extraPointers = nullptr;
#ifdef __CUDABLAS__
extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()};
#endif
NDArray::prepareSpecialUse({&arrayZ}, {&arrayX});
execTransformAny(extraPointers, transform::IsMax,
arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(),
arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(),
nullptr);
//arrayZ.printIndexedBuffer("JAVA ISMAX1");
NDArray::registerSpecialUse({&arrayZ}, {&arrayX});
ASSERT_EQ(arrayE, arrayZ);
delete []extraPointers;
}
TEST_F(JavaInteropTests, Test_Is_Max_2) {
auto arrayX = NDArrayFactory::create<float>('c', {3, 2, 3}, {1, 10, 2, 3, 4, 5, -10, -9, -8, -7, -6, -5, 4, 3, 2, 1, 0, -1});
auto arrayZ = NDArrayFactory::create<bool>('c', {3, 2, 3});
Nd4jLong tad[] = {2, 2, 3, 3, 1, 524288, -1, 99};
Nd4jLong off[] = {0, 6, 12};
Nd4jLong *ex[] = {tad, off};
float ea[] = {2, 1, 2};
NDArray::prepareSpecialUse({&arrayZ}, {&arrayX});
execTransformBool(reinterpret_cast<void **>(ex), transform::IsMax,
arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(),
arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(),
ea);
NDArray::registerSpecialUse({&arrayZ}, {&arrayX});
}
TEST_F(JavaInteropTests, Test_IAMax_1) {
auto arrayX = NDArrayFactory::create<float>({-0.24f, -0.26f, -0.07f, -0.01f});
auto arrayZ = arrayX.indexReduceNumber(indexreduce::IndexAbsoluteMax, nullptr);

View File

@ -367,49 +367,6 @@ TEST_F(LegacyOpsTests, IndexReduceTests_2) {
delete result;
}
TEST_F(LegacyOpsTests, Test_IsMax_1) {
if (!Environment::getInstance()->isCPU())
return;
auto x = NDArrayFactory::create<double>('c', {2, 2, 2, 2, 2, 2});
auto z = NDArrayFactory::create<double>('c', {2, 2, 2, 2, 2, 2});
x.linspace(1.0);
z.assign(-589);
double extra[] = {1.0, 0.0};
NativeOpExecutioner::execTransformAny(nullptr, transform::IsMax, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), extra, nullptr, nullptr);
// z.printIndexedBuffer("z");
for (Nd4jLong e = 0; e < z.lengthOf(); e++) {
ASSERT_TRUE(z.e<double>(e) >= 0);
}
}
TEST_F(LegacyOpsTests, Test_IsMax_2) {
if (!Environment::getInstance()->isCPU())
return;
auto x = NDArrayFactory::create<double>('c', {2, 2, 2, 2, 2, 2});
auto z = NDArrayFactory::create<bool>('c', {2, 2, 2, 2, 2, 2});
x.linspace(1.0);
z.assign(false);
double extra[] = {1.0, 0.0};
NativeOpExecutioner::execTransformAny(nullptr, transform::IsMax, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), extra, nullptr, nullptr);
// z.printIndexedBuffer("z");
for (Nd4jLong e = 0; e < z.lengthOf(); e++) {
if (e >= z.lengthOf() / 2)
ASSERT_TRUE(z.e<bool>(e));
else
ASSERT_FALSE(z.e<bool>(e));
}
}
TEST_F(LegacyOpsTests, BroadcastingTests_1) {
auto x = NDArrayFactory::create<double>('c', {5, 5});
x.assign(0.0f);

View File

@ -78,6 +78,72 @@ TEST_F(ListOperationsTests, BasicTest_Stack_1) {
delete tads;
}
TEST_F(ListOperationsTests, BasicTest_UnStackList_1) {
NDArrayList list(0, true);
auto x = NDArrayFactory::create<double>('c', {10, 100});
auto tads = x.allTensorsAlongDimension({1});
for (int e = 0; e < 10; e++) {
auto row = NDArrayFactory::create_<double>('c', {100});
row->assign((double) e);
//list.write(e, row);
tads->at(e)->assign(row);
delete row;
}
nd4j::ops::unstack_list op;
auto result = op.execute(&list, {&x}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(list.elements(), 10);
// auto z = result->at(0);
// z->printShapeInfo("The first of");
// ASSERT_TRUE(exp.isSameShape(z));
// ASSERT_TRUE(exp.equalsTo(z));
for (int e = 0; e < 10; e++) {
auto row = list.read(e);
ASSERT_TRUE(row->equalsTo(tads->at(e)));
//list.write(e, row);
delete row;
}
delete result;
delete tads;
}
//TEST_F(ListOperationsTests, BasicTest_UnStackList_2) {
//// NDArrayList list(0, true);
// auto x = NDArrayFactory::create<double>('c', {10, 100});
// auto tads = x.allTensorsAlongDimension({1});
// for (int e = 0; e < 10; e++) {
// auto row = NDArrayFactory::create_<double>('c', {100});
// row->assign((double) e);
// //list.write(e, row);
// tads->at(e)->assign(row);
// delete row;
// }
//
// nd4j::ops::unstack_list op;
//
// auto result = op.execute(nullptr, {&x}, {}, {0});
//
// ASSERT_EQ(ND4J_STATUS_OK, result->status());
// ASSERT_EQ(result->size(), 10);
//
// // auto z = result->at(0);
//// z->printShapeInfo("The first of");
//// ASSERT_TRUE(exp.isSameShape(z));
//// ASSERT_TRUE(exp.equalsTo(z));
// for (int e = 0; e < 10; e++) {
// auto row = result->at(e);
// ASSERT_TRUE(row->equalsTo(tads->at(e)));
// //list.write(e, row);
// }
//
// delete result;
// delete tads;
//}
TEST_F(ListOperationsTests, BasicTest_Read_1) {
NDArrayList list(10);

View File

@ -193,8 +193,8 @@ TEST_F(MultiDataTypeTests, ndarray_assign_number_test3) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(MultiDataTypeTests, ndarray_repeat_test1) {
NDArray x('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, nd4j::DataType::HALF);
NDArray y('c', {2, 4}, nd4j::DataType::UINT8);
NDArray exp('c', {2, 4}, {0, 0, 1, 1, 2, 2, 3, 3}, nd4j::DataType::UINT8);
NDArray y('c', {2, 4}, nd4j::DataType::HALF);
NDArray exp('c', {2, 4}, {0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5}, nd4j::DataType::HALF);
x.repeat(1, y);
@ -1790,6 +1790,7 @@ TEST_F(MultiDataTypeTests, RowCol_test2) {
}
//////////////////////////////////////////////////////////////////////
/*
TEST_F(MultiDataTypeTests, tile_test1) {
NDArray x1('c', {2,1}, {0,1}, nd4j::DataType::INT32);
@ -1823,6 +1824,7 @@ TEST_F(MultiDataTypeTests, tile_test1) {
x1.tile(x7);
ASSERT_EQ(x7, exp4);
}
*/
//////////////////////////////////////////////////////////////////////
TEST_F(MultiDataTypeTests, broadcast_test1) {

View File

@ -248,8 +248,8 @@ TEST_F(RNGTests, Test_Gaussian_21) {
RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, 0.0f, 1.0f);
RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, 0.0f, 1.0f);
x0.printIndexedBuffer("x0");
x1.printIndexedBuffer("x1");
// x0.printIndexedBuffer("x0");
// x1.printIndexedBuffer("x1");
ASSERT_TRUE(x0.equalsTo(&x1));
ASSERT_FALSE(x0.equalsTo(nexp0));

View File

@ -229,44 +229,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm;
import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Assign;
import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace;
import org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd;
import org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D;
import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention;
import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionBp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition;
import org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch;
import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Fill;
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan;
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNormBp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttentionBp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
import org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SpaceToBatch;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize;
import org.nd4j.linalg.api.ops.impl.transforms.custom.StandardizeBp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Trace;
import org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB;
import org.nd4j.linalg.api.ops.impl.transforms.custom.*;
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean;
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin;
@ -289,25 +252,8 @@ import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.TruncateDivOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.AddBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.DivBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorDivBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MulBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RDivBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RSubBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SubBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor;
@ -1290,7 +1236,7 @@ public class DifferentialFunctionFactory {
}
public SDVariable isMax(SDVariable ix) {
return new IsMax(sameDiff(), ix, false).outputVariable();
return new IsMax(sameDiff(), ix).outputVariable();
}
public SDVariable replaceWhere(SDVariable to, SDVariable from, Condition condition) {
@ -1317,6 +1263,21 @@ public class DifferentialFunctionFactory {
return new Xor(sameDiff(), ix, iy).outputVariable();
}
public SDVariable shift(SDVariable ix, int shift) {
return new ShiftBits(sameDiff(), ix, shift).outputVariable();
}
public SDVariable rshift(SDVariable ix, int shift) {
return new RShiftBits(sameDiff(), ix, shift).outputVariable();
}
public SDVariable rotl(SDVariable ix, int shift) {
return new CyclicShiftBits(sameDiff(), ix, shift).outputVariable();
}
public SDVariable rotr(SDVariable ix, int shift) {
return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable();
}
public SDVariable eq(SDVariable iX, SDVariable i_y) {
return new EqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable();
@ -2231,6 +2192,10 @@ public class DifferentialFunctionFactory {
return Arrays.asList(new MulBpOp(sameDiff(), x, y, grad).outputVariables());
}
public List<SDVariable> modBp(SDVariable x, SDVariable y, SDVariable grad) {
return Arrays.asList(new ModBpOp(sameDiff(), x, y, grad).outputVariables());
}
public SDVariable muli(SDVariable differentialFunction, SDVariable i_v) {
validateDifferentialFunctionsameDiff(differentialFunction);
@ -2238,6 +2203,10 @@ public class DifferentialFunctionFactory {
}
public SDVariable mod(SDVariable differentialFunction, SDVariable i_v) {
validateDifferentialFunctionsameDiff(differentialFunction);
return new ModOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable();
}
public SDVariable div(SDVariable differentialFunction, SDVariable i_v) {
validateDifferentialFunctionsameDiff(differentialFunction);

View File

@ -804,6 +804,34 @@ public class SDVariable extends DifferentialFunction implements Serializable {
return sameDiff.updateVariableNameAndReference(result, name);
}
/**
* Floor division operation: elementwise {@code this // x}<br>
* If this and x variables have equal shape, the output shape is the same as the inputs.<br>
* Supports broadcasting: if this and x have different shapes and are broadcastable, the output shape is broadcast.
*
* @param name Name of the output variable
* @param x Variable to perform operation with
* @return Output (result) SDVariable
*/
public SDVariable fdiv(String name, SDVariable x) {
val result = sameDiff.f().floorDiv(this, x);
return sameDiff.updateVariableNameAndReference(result, name);
}
/**
* Modulo operation: elementwise {@code this / x}<br>
* If this and x variables have equal shape, the output shape is the same as the inputs.<br>
* Supports broadcasting: if this and x have different shapes and are broadcastable, the output shape is broadcast.
*
* @param name Name of the output variable
* @param x Variable to perform operation with
* @return Output (result) SDVariable
*/
public SDVariable mod(String name, SDVariable x) {
val result = sameDiff.f().mod(this, x);
return sameDiff.updateVariableNameAndReference(result, name);
}
/**
* See {@link #mul(String, double)}
*/

View File

@ -682,6 +682,8 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
List<LongShapeDescriptor> outShape = customOp.calculateOutputShape();
Preconditions.checkState(outShape != null && outShape.size() > 0, "Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", customOp.opName(), customOp.getOwnName());
String[] outNames = df.outputVariablesNames();
Preconditions.checkState(outNames.length == outShape.size(), "Error in operation shape calculation for op \"%s\": Got %s op output shapes for an operation" +
" with %s outputs (number of shapes and outputs must be equal)", df.opName(), outShape.size(), outNames.length);
for( int i=0; i<outShape.size(); i++ ){
INDArray currOutput = (customOp.numOutputArguments() <= i ? null : customOp.getOutputArgument(i));
LongShapeDescriptor reqShape = outShape.get(i);

View File

@ -2421,6 +2421,58 @@ public class SDMath extends SDOps {
return updateVariableNameAndReference(result, name);
}
/**
* Shift integer bits to the left, i.e. var << 4
*
* @param name Name of the output variable
* @param x Input 1
* @return Output SDVariable with shifted bits
*/
public SDVariable bitShift(String name, SDVariable x, int shift) {
validateInteger("shift_bits", x);
SDVariable result = f().shift(x, shift);
return updateVariableNameAndReference(result, name);
}
/**
* Shift integer bits to the right, i.e. var >> 4
*
* @param name Name of the output variable
* @param x Input 1
* @return Output SDVariable with shifted bits
*/
public SDVariable bitShiftRight(String name, SDVariable x, int shift) {
validateInteger("rshift_bits", x);
SDVariable result = f().rshift(x, shift);
return updateVariableNameAndReference(result, name);
}
/**
* Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)
*
* @param name Name of the output variable
* @param x Input 1
* @return Output SDVariable with shifted bits
*/
public SDVariable bitRotl(String name, SDVariable x, int shift) {
validateInteger("cyclic_shift_bits", x);
SDVariable result = f().rotl(x, shift);
return updateVariableNameAndReference(result, name);
}
/**
* Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)
*
* @param name Name of the output variable
* @param x Input 1
* @return Output SDVariable with shifted bits
*/
public SDVariable bitRotr(String name, SDVariable x, int shift) {
validateInteger("cyclic_rshift_bits", x);
SDVariable result = f().rotr(x, shift);
return updateVariableNameAndReference(result, name);
}
/**
* Full array zero fraction array reduction operation, optionally along specified dimensions: out = (count(x == 0) / length(x))
*

View File

@ -262,7 +262,7 @@ public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration>
labelCountsEachClass.addi(labels2d.sum(0).castTo(labelCountsEachClass.dataType()));
//For prediction counts: do an IsMax op, but we need to take masking into account...
INDArray isPredictedClass = Nd4j.getExecutioner().exec(new IsMax(p.dup(), 1));
INDArray isPredictedClass = Nd4j.getExecutioner().exec(new IsMax(p, p.ulike(), 1))[0];
if (maskArray != null) {
LossUtil.applyMask(isPredictedClass, maskArray);
}

View File

@ -34,6 +34,12 @@ public interface AffinityManager {
*/
Integer getDeviceForCurrentThread();
/**
* This method returns deviceId for a given thread
* @return
*/
Integer getDeviceForThread(long threadId);
/**
* This method returns id of current device for a given INDArray

View File

@ -28,6 +28,11 @@ public abstract class BasicAffinityManager implements AffinityManager {
return 0;
}
@Override
public Integer getDeviceForThread(long threadId) {
return 0;
}
@Override
public Integer getDeviceForArray(INDArray array) {
return 0;

View File

@ -24,6 +24,7 @@ import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformAnyOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j;
@ -34,47 +35,29 @@ import java.util.List;
* [1, 2, 3, 1] -> [0, 0, 1, 0]
* @author Adam Gibson
*/
public class IsMax extends BaseTransformAnyOp {
public IsMax(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);
public class IsMax extends DynamicCustomOp {
public IsMax(SameDiff sameDiff, SDVariable i_v) {
super(sameDiff, i_v);
}
public IsMax(SameDiff sameDiff, SDVariable i_v, int[] shape, boolean inPlace, Object[] extraArgs) {
super(sameDiff, i_v, shape, inPlace, extraArgs);
}
public IsMax(SameDiff sameDiff, SDVariable i_v, Object[] extraArgs) {
super(sameDiff, i_v, extraArgs);
}
public IsMax(INDArray x, INDArray z) {
super(x, z);
super(new INDArray[]{x}, new INDArray[]{z});
}
public IsMax() {}
public IsMax(INDArray x) {
super(x, Nd4j.createUninitialized(DataType.BOOL, x.shape(), x.ordering()));
this(x, Nd4j.createUninitialized(DataType.BOOL, x.shape(), x.ordering()));
}
public IsMax(INDArray x, INDArray z, int... dimensions) {
super(x, z);
this.extraArgs = new Object[dimensions.length + 1];
this.extraArgs[0] = dimensions.length;
for (int i = 0; i < dimensions.length; i++)
this.extraArgs[i + 1] = dimensions[i];
this(x, z);
this.addIArgument(dimensions);
}
public IsMax(INDArray x, int... dimensions) {
super(x, Nd4j.createUninitialized(x.dataType(), x.shape(), x.ordering()));
this.extraArgs = new Object[dimensions.length + 1];
this.extraArgs[0] = dimensions.length;
for (int i = 0; i < dimensions.length; i++)
this.extraArgs[i + 1] = dimensions[i];
}
@Override
public int opNum() {
return 1;
this(x, Nd4j.createUninitialized(DataType.BOOL, x.shape(), x.ordering()), dimensions);
}
@Override
@ -82,7 +65,6 @@ public class IsMax extends BaseTransformAnyOp {
return "ismax";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
@ -93,14 +75,6 @@ public class IsMax extends BaseTransformAnyOp {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public DataBuffer extraArgsDataBuff(DataType dtype) {
if (Nd4j.getExecutioner().type() == OpExecutioner.ExecutionerType.CUDA)
return this.extraArgs == null ? null : Nd4j.createBuffer(DataType.LONG, 1, false);
else
return super.extraArgsDataBuff(dtype);
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
return Collections.singletonList(f().zerosLike(arg()));

View File

@ -77,7 +77,7 @@ public class BatchToSpace extends DynamicCustomOp {
@Override
public String tensorflowName() {
return "BatchToSpaceND";
return "BatchToSpace";
}
@Override

View File

@ -0,0 +1,93 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/**
* N-dimensional batch to space operation. Transforms data from a tensor from batch dimension into M spatial dimensions
* according to the "blocks" specified (a vector of length M). Afterwards the spatial dimensions are optionally cropped,
* as specified in "crops", a tensor of dim (M, 2), denoting the crop range.
* <p>
* Example:
* input: [[[[1]]], [[[2]]], [[[3]]], [[[4]]]]
* input shape: [4, 1, 1, 1]
* blocks: [2, 2]
* crops: [[0, 0], [0, 0]]
* <p>
* output: [[[[1], [2]], [[3], [4]]]]
* output shape: [1, 2, 2, 1]
*
* @author Max Pumperla
*/
public class BatchToSpaceND extends DynamicCustomOp {
private int[] blocks;
private int[][] crops;
public BatchToSpaceND() {
}
public BatchToSpaceND(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] crops, boolean inPlace) {
super(null, sameDiff, args, inPlace);
this.blocks = blocks;
this.crops = crops;
for (val b : blocks)
addIArgument(b);
for (int e = 0; e < crops.length; e++)
addIArgument(crops[e][0], crops[e][1]);
}
@Override
public String opName() {
return "batch_to_space_nd";
}
@Override
public String onnxName() {
return "batch_to_space_nd";
}
@Override
public String tensorflowName() {
return "BatchToSpaceND";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
// Inverse of batch to space is space to batch with same blocks and padding as crops
SDVariable gradient = sameDiff.setupFunction(i_v.get(0));
return Arrays.asList(sameDiff.cnn().spaceToBatch(gradient, blocks, crops));
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
return Collections.singletonList(dataTypes.get(0));
}
}

View File

@ -0,0 +1,80 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import java.util.Collections;
import java.util.List;
/**
* Element-wise roll operation, rolls bits to the left, <<
*
* @author raver119@gmail.com
*/
public class CyclicRShiftBits extends BaseDynamicTransformOp {
public CyclicRShiftBits(SameDiff sameDiff, SDVariable x, int shift) {
super(sameDiff, new SDVariable[] {x} ,false);
this.addIArgument(shift);
}
public CyclicRShiftBits(INDArray input, int shift, INDArray output) {
super(new INDArray[]{input}, new INDArray[]{output});
this.addIArgument(shift);
}
public CyclicRShiftBits(INDArray input, int shift) {
this(input, shift,null);
}
public CyclicRShiftBits() {}
@Override
public String opName() {
return "cyclic_rshift_bits";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
throw new UnsupportedOperationException("Not yet implemented: " + opName());
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
return Collections.singletonList(dataTypes.get(0));
}
}

View File

@ -0,0 +1,80 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import java.util.Collections;
import java.util.List;
/**
* Element-wise roll operation, rolls bits to the left, <<
*
* @author raver119@gmail.com
*/
public class CyclicShiftBits extends BaseDynamicTransformOp {
public CyclicShiftBits(SameDiff sameDiff, SDVariable x, int shift) {
super(sameDiff, new SDVariable[] {x} ,false);
this.addIArgument(shift);
}
public CyclicShiftBits(INDArray input, int shift, INDArray output) {
super(new INDArray[]{input}, new INDArray[]{output});
this.addIArgument(shift);
}
public CyclicShiftBits(INDArray input, int shift) {
this(input, shift,null);
}
public CyclicShiftBits() {}
@Override
public String opName() {
return "cyclic_shift_bits";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
throw new UnsupportedOperationException("Not yet implemented: " + opName());
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
return Collections.singletonList(dataTypes.get(0));
}
}

View File

@ -0,0 +1,80 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import java.util.Collections;
import java.util.List;
/**
* Element-wise shift operation, shift bits to the right, >>
*
* @author raver119@gmail.com
*/
public class RShiftBits extends BaseDynamicTransformOp {
public RShiftBits(SameDiff sameDiff, SDVariable x, int shift) {
super(sameDiff, new SDVariable[] {x} ,false);
this.addIArgument(shift);
}
public RShiftBits(INDArray input, int shift, INDArray output) {
super(new INDArray[]{input}, new INDArray[]{output});
this.addIArgument(shift);
}
public RShiftBits(INDArray input, int shift) {
this(input, shift,null);
}
public RShiftBits() {}
@Override
public String opName() {
return "rshift_bits";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
throw new UnsupportedOperationException("Not yet implemented: " + opName());
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
return Collections.singletonList(dataTypes.get(0));
}
}

View File

@ -0,0 +1,80 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import java.util.Collections;
import java.util.List;
/**
* Element-wise shift operation, shift bits to the left, <<
*
* @author raver119@gmail.com
*/
public class ShiftBits extends BaseDynamicTransformOp {
public ShiftBits(SameDiff sameDiff, SDVariable x, int shift) {
super(sameDiff, new SDVariable[] {x} ,false);
this.addIArgument(shift);
}
public ShiftBits(INDArray input, int shift, INDArray output) {
super(new INDArray[]{input}, new INDArray[]{output});
this.addIArgument(shift);
}
public ShiftBits(INDArray input, int shift) {
this(input, shift,null);
}
public ShiftBits() {}
@Override
public String opName() {
return "shift_bits";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
throw new UnsupportedOperationException("Not yet implemented: " + opName());
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
return Collections.singletonList(dataTypes.get(0));
}
}

View File

@ -77,7 +77,7 @@ public class SpaceToBatch extends DynamicCustomOp {
@Override
public String tensorflowName() {
return "SpaceToBatchND";
return "SpaceToBatch";
}
@Override

View File

@ -0,0 +1,95 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/**
* N-dimensional space to batch operation. Transforms data from a tensor from M spatial dimensions into batch dimension
* according to the "blocks" specified (a vector of length M). Afterwards the spatial dimensions are optionally padded,
* as specified in "padding", a tensor of dim (M, 2), denoting the padding range.
* <p>
* Example:
* input: [[[[1], [2]], [[3], [4]]]]
* input shape: [1, 2, 2, 1]
* blocks: [2, 2]
* padding: [[0, 0], [0, 0]]
* <p>
* output: [[[[1]]], [[[2]]], [[[3]]], [[[4]]]]
* output shape: [4, 1, 1, 1]
* *
*
* @author Max Pumperla
*/
public class SpaceToBatchND extends DynamicCustomOp {
protected int[] blocks;
protected int[][] padding;
public SpaceToBatchND() {
}
public SpaceToBatchND(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] padding, boolean inPlace) {
super(null, sameDiff, args, inPlace);
this.blocks = blocks;
this.padding = padding;
for (val b : blocks)
addIArgument(b);
for (int e = 0; e < padding.length; e++)
addIArgument(padding[e][0], padding[e][1]);
}
@Override
public String opName() {
return "space_to_batch_nd";
}
@Override
public String onnxName() {
return "space_to_batch_nd";
}
@Override
public String tensorflowName() {
return "SpaceToBatchND";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
// Inverse of space to batch is batch to space with same blocks and crops as padding
SDVariable gradient = sameDiff.setupFunction(i_v.get(0));
return Arrays.asList(sameDiff.cnn().batchToSpace(gradient, blocks, padding));
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
return Collections.singletonList(dataTypes.get(0));
}
}

View File

@ -0,0 +1,69 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import java.util.List;
/**
* Modulo operation
*
* @author raver119@gmail.com
*/
public class ModOp extends BaseDynamicTransformOp {
public static final String OP_NAME = "mod";
public ModOp() {}
public ModOp(SameDiff sameDiff, SDVariable[] args, boolean inPlace) {
super(sameDiff, args, inPlace);
}
public ModOp(INDArray first, INDArray second, INDArray result){
this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result});
}
public ModOp(INDArray[] inputs, INDArray[] outputs) {
super(inputs, outputs);
}
@Override
public String opName() {
return OP_NAME;
}
@Override
public String onnxName() {
return "Mod";
}
@Override
public String tensorflowName() {
return "mod";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
return f().modBp(larg(), rarg(), i_v.get(0));
}
}

View File

@ -0,0 +1,39 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
/**
* Modulo backprop operation. Supports 'undoing' of auto broadcast as applied in div op forward pass
*
* @author raver119@gmail.com
*/
public class ModBpOp extends BaseArithmeticBackpropOp {
public ModBpOp() {}
public ModBpOp(SameDiff sameDiff, SDVariable x, SDVariable y, SDVariable eps) {
super(sameDiff, x,y,eps);
}
@Override
public String opName() {
return "mod_bp";
}
}

View File

@ -3676,7 +3676,7 @@ public class Shape {
}
public static boolean isR(@NonNull DataType x) {
return x == DataType.FLOAT || x == DataType.HALF || x == DataType.DOUBLE;
return x == DataType.FLOAT || x == DataType.HALF || x == DataType.DOUBLE || x == DataType.BFLOAT16;
}
private static DataType max(@NonNull DataType typeX, @NonNull DataType typeY) {

View File

@ -378,7 +378,7 @@ public class Transforms {
public static INDArray asin(INDArray in, boolean copy) {
return Nd4j.getExecutioner().exec(new ASin(((copy ? in.dup() : in))));
return Nd4j.getExecutioner().exec(new ASin(in, (copy ? in.ulike() : in)));
}
public static INDArray atan(INDArray arr) {
@ -999,7 +999,8 @@ public class Transforms {
}
public static INDArray isMax(INDArray input, INDArray output) {
return Nd4j.getExecutioner().exec(new IsMax(input, output));
Nd4j.getExecutioner().exec(new IsMax(input, output));
return output;
}
@ -1035,7 +1036,7 @@ public class Transforms {
* @return
*/
public static INDArray sqrt(INDArray ndArray, boolean dup) {
return exec(dup ? new Sqrt(ndArray, ndArray.ulike()) : new Sqrt(ndArray));
return exec(dup ? new Sqrt(ndArray, ndArray.ulike()) : new Sqrt(ndArray, ndArray));
}
/**

Some files were not shown because too many files have changed in this diff Show More