[WIP] More fixes (#178)

* skip string arrays for device validation

Signed-off-by: raver119 <raver119@gmail.com>

* histogram_fixed_width now really supports indexing types

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-27 13:21:01 +03:00 committed by GitHub
parent fd22a8ecc7
commit a49f7c908b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 19 additions and 18 deletions

View File

@ -49,7 +49,7 @@ CUSTOM_OP_IMPL(histogram_fixed_width, 2, 1, false, 0, 0) {
DECLARE_TYPES(histogram_fixed_width) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_INTS});
->setAllowedOutputTypes({ALL_INDICES});
}

View File

@ -27,16 +27,16 @@ namespace ops {
namespace helpers {
///////////////////////////////////////////////////////////////////
template<typename T>
template<typename X, typename Z>
__global__ static void histogramFixedWidthCuda( const void* vx, const Nd4jLong* xShapeInfo,
void* vz, const Nd4jLong* zShapeInfo,
const T leftEdge, const T rightEdge) {
const X leftEdge, const X rightEdge) {
const T* x = reinterpret_cast<const T*>(vx);
Nd4jLong* z = reinterpret_cast<Nd4jLong*>(vz);
const auto x = reinterpret_cast<const X*>(vx);
auto z = reinterpret_cast<Z*>(vz);
__shared__ Nd4jLong xLen, zLen, totalThreads, nbins;
__shared__ T binWidth, secondEdge, lastButOneEdge;
__shared__ X binWidth, secondEdge, lastButOneEdge;
if (threadIdx.x == 0) {
@ -55,7 +55,7 @@ __global__ static void histogramFixedWidthCuda( const void* vx, const Nd4jLong*
for (Nd4jLong i = tid; i < xLen; i += totalThreads) {
const T value = x[shape::getIndexOffset(i, xShapeInfo, xLen)];
const X value = x[shape::getIndexOffset(i, xShapeInfo, xLen)];
Nd4jLong zIndex;
@ -66,18 +66,18 @@ __global__ static void histogramFixedWidthCuda( const void* vx, const Nd4jLong*
else
zIndex = static_cast<Nd4jLong>((value - leftEdge) / binWidth);
nd4j::math::atomics::nd4j_atomicAdd(&z[shape::getIndexOffset(zIndex, zShapeInfo, nbins)], 1LL);
nd4j::math::atomics::nd4j_atomicAdd<Z>(&z[shape::getIndexOffset(zIndex, zShapeInfo, nbins)], 1);
}
}
///////////////////////////////////////////////////////////////////
template<typename T>
template<typename X, typename Z>
__host__ static void histogramFixedWidthCudaLauncher(const cudaStream_t *stream, const NDArray& input, const NDArray& range, NDArray& output) {
const T leftEdge = range.e<T>(0);
const T rightEdge = range.e<T>(1);
const X leftEdge = range.e<X>(0);
const X rightEdge = range.e<X>(1);
histogramFixedWidthCuda<T><<<512, MAX_NUM_THREADS / 2, 512, *stream>>>(input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftEdge, rightEdge);
histogramFixedWidthCuda<X, Z><<<256, 256, 1024, *stream>>>(input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftEdge, rightEdge);
}
////////////////////////////////////////////////////////////////////////
@ -89,7 +89,7 @@ void histogramFixedWidth(nd4j::LaunchContext* context, const NDArray& input, con
PointersManager manager(context, "histogramFixedWidth");
NDArray::prepareSpecialUse({&output}, {&input});
BUILD_SINGLE_SELECTOR(input.dataType(), histogramFixedWidthCudaLauncher, (context->getCudaStream(), input, range, output), LIBND4J_TYPES);
BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogramFixedWidthCudaLauncher, (context->getCudaStream(), input, range, output), LIBND4J_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({&output}, {&input});
manager.synchronize();

View File

@ -312,7 +312,7 @@ public class AtomicAllocator implements Allocator {
@Override
public Pointer getPointer(INDArray array, CudaContext context) {
// DataBuffer buffer = array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer();
if (array.isEmpty())
if (array.isEmpty() || array.isS())
return null;
return memoryHandler.getDevicePointer(array.data(), context);

View File

@ -172,7 +172,7 @@ public class SynchronousFlowController implements FlowController {
val cId = allocator.getDeviceId();
if (result != null && !result.isEmpty()) {
if (result != null && !result.isEmpty() && !result.isS()) {
Nd4j.getCompressor().autoDecompress(result);
prepareDelayedMemory(result);
val pointData = allocator.getAllocationPoint(result);
@ -198,7 +198,8 @@ public class SynchronousFlowController implements FlowController {
return context;
for (INDArray operand : operands) {
if (operand == null || operand.isEmpty())
// empty or String arrays can be skipped
if (operand == null || operand.isEmpty() || operand.isS())
continue;
Nd4j.getCompressor().autoDecompress(operand);

View File

@ -100,7 +100,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
@Override
public Pointer contextPointer() {
for (val v:fastpath_in.values()) {
if (v.isEmpty())
if (v.isEmpty() || v.isS())
continue;
AtomicAllocator.getInstance().getAllocationPoint(v).tickHostRead();
@ -111,7 +111,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
}
for (val v:fastpath_out.values()) {
if (v.isEmpty())
if (v.isEmpty() || v.isS())
continue;
AtomicAllocator.getInstance().getAllocationPoint(v).tickHostRead();