[WIP] More fixes (#178)
* skip string arrays for device validation Signed-off-by: raver119 <raver119@gmail.com> * histogram_fixed_width now really supports indexing types Signed-off-by: raver119 <raver119@gmail.com>master
parent
fd22a8ecc7
commit
a49f7c908b
|
@ -49,7 +49,7 @@ CUSTOM_OP_IMPL(histogram_fixed_width, 2, 1, false, 0, 0) {
|
|||
DECLARE_TYPES(histogram_fixed_width) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||
->setAllowedOutputTypes({ALL_INTS});
|
||||
->setAllowedOutputTypes({ALL_INDICES});
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -27,16 +27,16 @@ namespace ops {
|
|||
namespace helpers {
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
template<typename X, typename Z>
|
||||
__global__ static void histogramFixedWidthCuda( const void* vx, const Nd4jLong* xShapeInfo,
|
||||
void* vz, const Nd4jLong* zShapeInfo,
|
||||
const T leftEdge, const T rightEdge) {
|
||||
const X leftEdge, const X rightEdge) {
|
||||
|
||||
const T* x = reinterpret_cast<const T*>(vx);
|
||||
Nd4jLong* z = reinterpret_cast<Nd4jLong*>(vz);
|
||||
const auto x = reinterpret_cast<const X*>(vx);
|
||||
auto z = reinterpret_cast<Z*>(vz);
|
||||
|
||||
__shared__ Nd4jLong xLen, zLen, totalThreads, nbins;
|
||||
__shared__ T binWidth, secondEdge, lastButOneEdge;
|
||||
__shared__ X binWidth, secondEdge, lastButOneEdge;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
|
||||
|
@ -55,7 +55,7 @@ __global__ static void histogramFixedWidthCuda( const void* vx, const Nd4jLong*
|
|||
|
||||
for (Nd4jLong i = tid; i < xLen; i += totalThreads) {
|
||||
|
||||
const T value = x[shape::getIndexOffset(i, xShapeInfo, xLen)];
|
||||
const X value = x[shape::getIndexOffset(i, xShapeInfo, xLen)];
|
||||
|
||||
Nd4jLong zIndex;
|
||||
|
||||
|
@ -66,18 +66,18 @@ __global__ static void histogramFixedWidthCuda( const void* vx, const Nd4jLong*
|
|||
else
|
||||
zIndex = static_cast<Nd4jLong>((value - leftEdge) / binWidth);
|
||||
|
||||
nd4j::math::atomics::nd4j_atomicAdd(&z[shape::getIndexOffset(zIndex, zShapeInfo, nbins)], 1LL);
|
||||
nd4j::math::atomics::nd4j_atomicAdd<Z>(&z[shape::getIndexOffset(zIndex, zShapeInfo, nbins)], 1);
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
template<typename X, typename Z>
|
||||
__host__ static void histogramFixedWidthCudaLauncher(const cudaStream_t *stream, const NDArray& input, const NDArray& range, NDArray& output) {
|
||||
|
||||
const T leftEdge = range.e<T>(0);
|
||||
const T rightEdge = range.e<T>(1);
|
||||
const X leftEdge = range.e<X>(0);
|
||||
const X rightEdge = range.e<X>(1);
|
||||
|
||||
histogramFixedWidthCuda<T><<<512, MAX_NUM_THREADS / 2, 512, *stream>>>(input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftEdge, rightEdge);
|
||||
histogramFixedWidthCuda<X, Z><<<256, 256, 1024, *stream>>>(input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftEdge, rightEdge);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
@ -89,7 +89,7 @@ void histogramFixedWidth(nd4j::LaunchContext* context, const NDArray& input, con
|
|||
PointersManager manager(context, "histogramFixedWidth");
|
||||
|
||||
NDArray::prepareSpecialUse({&output}, {&input});
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), histogramFixedWidthCudaLauncher, (context->getCudaStream(), input, range, output), LIBND4J_TYPES);
|
||||
BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogramFixedWidthCudaLauncher, (context->getCudaStream(), input, range, output), LIBND4J_TYPES, INDEXING_TYPES);
|
||||
NDArray::registerSpecialUse({&output}, {&input});
|
||||
|
||||
manager.synchronize();
|
||||
|
|
|
@ -312,7 +312,7 @@ public class AtomicAllocator implements Allocator {
|
|||
@Override
|
||||
public Pointer getPointer(INDArray array, CudaContext context) {
|
||||
// DataBuffer buffer = array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer();
|
||||
if (array.isEmpty())
|
||||
if (array.isEmpty() || array.isS())
|
||||
return null;
|
||||
|
||||
return memoryHandler.getDevicePointer(array.data(), context);
|
||||
|
|
|
@ -172,7 +172,7 @@ public class SynchronousFlowController implements FlowController {
|
|||
val cId = allocator.getDeviceId();
|
||||
|
||||
|
||||
if (result != null && !result.isEmpty()) {
|
||||
if (result != null && !result.isEmpty() && !result.isS()) {
|
||||
Nd4j.getCompressor().autoDecompress(result);
|
||||
prepareDelayedMemory(result);
|
||||
val pointData = allocator.getAllocationPoint(result);
|
||||
|
@ -198,7 +198,8 @@ public class SynchronousFlowController implements FlowController {
|
|||
return context;
|
||||
|
||||
for (INDArray operand : operands) {
|
||||
if (operand == null || operand.isEmpty())
|
||||
// empty or String arrays can be skipped
|
||||
if (operand == null || operand.isEmpty() || operand.isS())
|
||||
continue;
|
||||
|
||||
Nd4j.getCompressor().autoDecompress(operand);
|
||||
|
|
|
@ -100,7 +100,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
|
|||
@Override
|
||||
public Pointer contextPointer() {
|
||||
for (val v:fastpath_in.values()) {
|
||||
if (v.isEmpty())
|
||||
if (v.isEmpty() || v.isS())
|
||||
continue;
|
||||
|
||||
AtomicAllocator.getInstance().getAllocationPoint(v).tickHostRead();
|
||||
|
@ -111,7 +111,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
|
|||
}
|
||||
|
||||
for (val v:fastpath_out.values()) {
|
||||
if (v.isEmpty())
|
||||
if (v.isEmpty() || v.isS())
|
||||
continue;
|
||||
|
||||
AtomicAllocator.getInstance().getAllocationPoint(v).tickHostRead();
|
||||
|
|
Loading…
Reference in New Issue