[WIP] confusion (#180)
* skip string arrays for device validation Signed-off-by: raver119 <raver119@gmail.com> * confusion_matrix fix Signed-off-by: raver119 <raver119@gmail.com>master
parent
dff599aa8f
commit
0e523490e9
|
@ -30,10 +30,10 @@ namespace helpers {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ static void copyBuffers(Nd4jLong* destination, void const* source, Nd4jLong bufferLength) {
|
__global__ static void copyBuffers(Nd4jLong* destination, void const* source, Nd4jLong bufferLength) {
|
||||||
const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
const auto step = gridDim.x * blockDim.x;
|
const auto step = gridDim.x * blockDim.x;
|
||||||
for (int t = tid; t < bufferLength; t += step) {
|
for (int t = tid; t < bufferLength; t += step) {
|
||||||
destination[t] = reinterpret_cast<T const*>(source)[t];
|
destination[t] = static_cast<Nd4jLong>(reinterpret_cast<T const*>(source)[t]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -51,38 +51,24 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
const auto step = gridDim.x * blockDim.x;
|
const auto step = gridDim.x * blockDim.x;
|
||||||
for (int t = tid; t < bufferLength; t += step) {
|
for (int t = tid; t < bufferLength; t += step) {
|
||||||
//auto tX = reinterpret_cast<T*>(inputList[t]);
|
|
||||||
//auto xShape = reinterpret_cast<Nd4jLong*>(inputShapeList[t]);
|
|
||||||
auto label = labelsBuffer[t]; //->e<Nd4jLong>(j);
|
auto label = labelsBuffer[t]; //->e<Nd4jLong>(j);
|
||||||
auto pred = predictionBuffer[t]; //->e<Nd4jLong>(j);
|
auto pred = predictionBuffer[t]; //->e<Nd4jLong>(j);
|
||||||
auto tZ = z + tadOffsets[label];
|
auto tZ = z + tadOffsets[label];
|
||||||
T val = (weightsBuffer == nullptr ? (T)1.0f : w[t]);
|
T val = (weightsBuffer == nullptr ? (T)1.0f : w[t]);
|
||||||
|
|
||||||
//for (int e = threadIdx.x; e < arrLen; e += blockDim.x) {
|
auto idx = shape::getIndexOffset(pred, tadShape, arrLen);
|
||||||
|
tZ[idx] = val;
|
||||||
tZ[shape::getIndexOffset(pred, tadShape, arrLen)] = val; //tX[shape::getIndexOffset(e, , arrLen)];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename X, typename Z>
|
||||||
void _confusionFunctor(nd4j::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) {
|
void _confusionFunctor(nd4j::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) {
|
||||||
// std::unique_ptr<ResultSet> arrs(output->allTensorsAlongDimension({1}));
|
auto stream = context->getCudaStream();
|
||||||
//
|
|
||||||
//#pragma omp parallel for if(labels->lengthOf() > Environment::getInstance()->elementwiseThreshold()) schedule(static)
|
|
||||||
// for (int j = 0; j < labels->lengthOf(); ++j){
|
|
||||||
// auto label = labels->e<Nd4jLong>(j);
|
|
||||||
// auto pred = predictions->e<Nd4jLong>(j);
|
|
||||||
// T value = (weights == nullptr ? (T)1.0f : weights->e<T>(j));
|
|
||||||
// (*arrs->at(label)).p<T>(pred, value);
|
|
||||||
// }
|
|
||||||
|
|
||||||
int dimension = 1;
|
|
||||||
|
|
||||||
auto pack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimension);
|
|
||||||
|
|
||||||
|
auto pack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), 1);
|
||||||
|
|
||||||
PointersManager manager(context, "helpers::confusion");
|
PointersManager manager(context, "helpers::confusion");
|
||||||
|
|
||||||
|
@ -90,26 +76,26 @@ namespace helpers {
|
||||||
Nd4jLong* predictionLongBuffer = predictions->dataType() == nd4j::DataType::INT64?(Nd4jLong*)predictions->specialBuffer():nullptr;
|
Nd4jLong* predictionLongBuffer = predictions->dataType() == nd4j::DataType::INT64?(Nd4jLong*)predictions->specialBuffer():nullptr;
|
||||||
|
|
||||||
if (labelsLongBuffer == nullptr) {
|
if (labelsLongBuffer == nullptr) {
|
||||||
cudaError_t err = cudaMalloc(&labelsLongBuffer, labels->lengthOf() * sizeof(Nd4jLong));
|
auto err = cudaMalloc(&labelsLongBuffer, labels->lengthOf() * sizeof(Nd4jLong));
|
||||||
if (err != 0)
|
if (err != 0)
|
||||||
throw nd4j::cuda_exception::build("Cannot allocate memory for labels long buffer", err);
|
throw nd4j::cuda_exception::build("Cannot allocate memory for labels long buffer", err);
|
||||||
// copy with type conversion
|
// copy with type conversion
|
||||||
copyBuffers<T><<<256, 512, 8192>>>(labelsLongBuffer, labels->getSpecialBuffer(), labels->lengthOf());
|
copyBuffers<X><<<256, 512, 1024, *stream>>>(labelsLongBuffer, labels->getSpecialBuffer(), labels->lengthOf());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (predictionLongBuffer == nullptr) {
|
if (predictionLongBuffer == nullptr) {
|
||||||
cudaError_t err = cudaMalloc(&predictionLongBuffer, predictions->lengthOf() * sizeof(Nd4jLong));
|
auto err = cudaMalloc(&predictionLongBuffer, predictions->lengthOf() * sizeof(Nd4jLong));
|
||||||
if (err != 0)
|
if (err != 0)
|
||||||
throw nd4j::cuda_exception::build("Cannot allocate memory for predictions long buffer", err);
|
throw nd4j::cuda_exception::build("Cannot allocate memory for predictions long buffer", err);
|
||||||
// copy with type conversion
|
// copy with type conversion
|
||||||
copyBuffers<T><<<256, 512, 8192>>>(predictionLongBuffer, predictions->getSpecialBuffer(), predictions->lengthOf());
|
copyBuffers<X><<<256, 512, 1024, *stream>>>(predictionLongBuffer, predictions->getSpecialBuffer(), predictions->lengthOf());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto bufferLength = labels->lengthOf();
|
auto bufferLength = labels->lengthOf();
|
||||||
dim3 launchDims(32, 32, 1024);
|
dim3 launchDims(32, 32, 1024);
|
||||||
auto stream = context->getCudaStream();
|
confusionFunctorKernel<Z><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(labelsLongBuffer, predictionLongBuffer, bufferLength, weights != nullptr? weights->getSpecialBuffer():nullptr, output->specialBuffer(), pack.specialShapeInfo(), pack.specialOffsets());
|
||||||
confusionFunctorKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(labelsLongBuffer, predictionLongBuffer,
|
|
||||||
bufferLength, weights != nullptr? weights->getSpecialBuffer():nullptr, output->specialBuffer(), pack.specialShapeInfo(), pack.specialOffsets());
|
manager.synchronize();
|
||||||
|
|
||||||
if (predictionLongBuffer != predictions->getSpecialBuffer()) {
|
if (predictionLongBuffer != predictions->getSpecialBuffer()) {
|
||||||
cudaError_t err = cudaFree(predictionLongBuffer);
|
cudaError_t err = cudaFree(predictionLongBuffer);
|
||||||
|
@ -122,17 +108,15 @@ namespace helpers {
|
||||||
if (err != 0)
|
if (err != 0)
|
||||||
throw nd4j::cuda_exception::build("Cannot deallocate memory for labels long buffer", err);
|
throw nd4j::cuda_exception::build("Cannot deallocate memory for labels long buffer", err);
|
||||||
}
|
}
|
||||||
manager.synchronize();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void confusionFunctor(nd4j::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) {
|
void confusionFunctor(nd4j::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) {
|
||||||
auto xType = output->dataType(); // weights can be null
|
auto xType = predictions->dataType();
|
||||||
|
auto zType = output->dataType(); // weights can be null
|
||||||
BUILD_SINGLE_SELECTOR(xType, _confusionFunctor, (context, labels, predictions, weights, output), NUMERIC_TYPES);
|
NDArray::prepareSpecialUse({output}, {labels, predictions, weights});
|
||||||
}
|
BUILD_DOUBLE_SELECTOR(xType, zType, _confusionFunctor, (context, labels, predictions, weights, output), INDEXING_TYPES, NUMERIC_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {labels, predictions, weights});
|
||||||
BUILD_SINGLE_TEMPLATE(template void _confusionFunctor, (nd4j::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output);, NUMERIC_TYPES);
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue