[WIP] More fixes (#178)

* skip string arrays for device validation

Signed-off-by: raver119 <raver119@gmail.com>

* histogram_fixed_width now really supports indexing types

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-27 13:21:01 +03:00 committed by GitHub
parent fd22a8ecc7
commit a49f7c908b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 19 additions and 18 deletions

View File

@ -49,7 +49,7 @@ CUSTOM_OP_IMPL(histogram_fixed_width, 2, 1, false, 0, 0) {
DECLARE_TYPES(histogram_fixed_width) { DECLARE_TYPES(histogram_fixed_width) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_INTS}); ->setAllowedOutputTypes({ALL_INDICES});
} }

View File

@ -27,16 +27,16 @@ namespace ops {
namespace helpers { namespace helpers {
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
template<typename T> template<typename X, typename Z>
__global__ static void histogramFixedWidthCuda( const void* vx, const Nd4jLong* xShapeInfo, __global__ static void histogramFixedWidthCuda( const void* vx, const Nd4jLong* xShapeInfo,
void* vz, const Nd4jLong* zShapeInfo, void* vz, const Nd4jLong* zShapeInfo,
const T leftEdge, const T rightEdge) { const X leftEdge, const X rightEdge) {
const T* x = reinterpret_cast<const T*>(vx); const auto x = reinterpret_cast<const X*>(vx);
Nd4jLong* z = reinterpret_cast<Nd4jLong*>(vz); auto z = reinterpret_cast<Z*>(vz);
__shared__ Nd4jLong xLen, zLen, totalThreads, nbins; __shared__ Nd4jLong xLen, zLen, totalThreads, nbins;
__shared__ T binWidth, secondEdge, lastButOneEdge; __shared__ X binWidth, secondEdge, lastButOneEdge;
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
@ -55,7 +55,7 @@ __global__ static void histogramFixedWidthCuda( const void* vx, const Nd4jLong*
for (Nd4jLong i = tid; i < xLen; i += totalThreads) { for (Nd4jLong i = tid; i < xLen; i += totalThreads) {
const T value = x[shape::getIndexOffset(i, xShapeInfo, xLen)]; const X value = x[shape::getIndexOffset(i, xShapeInfo, xLen)];
Nd4jLong zIndex; Nd4jLong zIndex;
@ -66,18 +66,18 @@ __global__ static void histogramFixedWidthCuda( const void* vx, const Nd4jLong*
else else
zIndex = static_cast<Nd4jLong>((value - leftEdge) / binWidth); zIndex = static_cast<Nd4jLong>((value - leftEdge) / binWidth);
nd4j::math::atomics::nd4j_atomicAdd(&z[shape::getIndexOffset(zIndex, zShapeInfo, nbins)], 1LL); nd4j::math::atomics::nd4j_atomicAdd<Z>(&z[shape::getIndexOffset(zIndex, zShapeInfo, nbins)], 1);
} }
} }
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
template<typename T> template<typename X, typename Z>
__host__ static void histogramFixedWidthCudaLauncher(const cudaStream_t *stream, const NDArray& input, const NDArray& range, NDArray& output) { __host__ static void histogramFixedWidthCudaLauncher(const cudaStream_t *stream, const NDArray& input, const NDArray& range, NDArray& output) {
const T leftEdge = range.e<T>(0); const X leftEdge = range.e<X>(0);
const T rightEdge = range.e<T>(1); const X rightEdge = range.e<X>(1);
histogramFixedWidthCuda<T><<<512, MAX_NUM_THREADS / 2, 512, *stream>>>(input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftEdge, rightEdge); histogramFixedWidthCuda<X, Z><<<256, 256, 1024, *stream>>>(input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftEdge, rightEdge);
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -89,7 +89,7 @@ void histogramFixedWidth(nd4j::LaunchContext* context, const NDArray& input, con
PointersManager manager(context, "histogramFixedWidth"); PointersManager manager(context, "histogramFixedWidth");
NDArray::prepareSpecialUse({&output}, {&input}); NDArray::prepareSpecialUse({&output}, {&input});
BUILD_SINGLE_SELECTOR(input.dataType(), histogramFixedWidthCudaLauncher, (context->getCudaStream(), input, range, output), LIBND4J_TYPES); BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogramFixedWidthCudaLauncher, (context->getCudaStream(), input, range, output), LIBND4J_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({&output}, {&input}); NDArray::registerSpecialUse({&output}, {&input});
manager.synchronize(); manager.synchronize();

View File

@ -312,7 +312,7 @@ public class AtomicAllocator implements Allocator {
@Override @Override
public Pointer getPointer(INDArray array, CudaContext context) { public Pointer getPointer(INDArray array, CudaContext context) {
// DataBuffer buffer = array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer(); // DataBuffer buffer = array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer();
if (array.isEmpty()) if (array.isEmpty() || array.isS())
return null; return null;
return memoryHandler.getDevicePointer(array.data(), context); return memoryHandler.getDevicePointer(array.data(), context);

View File

@ -172,7 +172,7 @@ public class SynchronousFlowController implements FlowController {
val cId = allocator.getDeviceId(); val cId = allocator.getDeviceId();
if (result != null && !result.isEmpty()) { if (result != null && !result.isEmpty() && !result.isS()) {
Nd4j.getCompressor().autoDecompress(result); Nd4j.getCompressor().autoDecompress(result);
prepareDelayedMemory(result); prepareDelayedMemory(result);
val pointData = allocator.getAllocationPoint(result); val pointData = allocator.getAllocationPoint(result);
@ -198,7 +198,8 @@ public class SynchronousFlowController implements FlowController {
return context; return context;
for (INDArray operand : operands) { for (INDArray operand : operands) {
if (operand == null || operand.isEmpty()) // empty or String arrays can be skipped
if (operand == null || operand.isEmpty() || operand.isS())
continue; continue;
Nd4j.getCompressor().autoDecompress(operand); Nd4j.getCompressor().autoDecompress(operand);

View File

@ -100,7 +100,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
@Override @Override
public Pointer contextPointer() { public Pointer contextPointer() {
for (val v:fastpath_in.values()) { for (val v:fastpath_in.values()) {
if (v.isEmpty()) if (v.isEmpty() || v.isS())
continue; continue;
AtomicAllocator.getInstance().getAllocationPoint(v).tickHostRead(); AtomicAllocator.getInstance().getAllocationPoint(v).tickHostRead();
@ -111,7 +111,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
} }
for (val v:fastpath_out.values()) { for (val v:fastpath_out.values()) {
if (v.isEmpty()) if (v.isEmpty() || v.isS())
continue; continue;
AtomicAllocator.getInstance().getAllocationPoint(v).tickHostRead(); AtomicAllocator.getInstance().getAllocationPoint(v).tickHostRead();