2019-06-06 14:21:15 +02:00
|
|
|
/*******************************************************************************
|
|
|
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
|
|
*
|
|
|
|
* This program and the accompanying materials are made available under the
|
|
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
|
|
*
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
|
|
* License for the specific language governing permissions and limitations
|
|
|
|
* under the License.
|
|
|
|
*
|
|
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
|
|
******************************************************************************/
|
|
|
|
|
|
|
|
//
|
|
|
|
// @author GS <sgazeos@gmail.com>
|
|
|
|
//
|
|
|
|
|
|
|
|
#include <ops/declarable/helpers/confusion.h>
#include <cuda_exception.h>
#include <TAD.h>
#include <PointersManager.h>
#include <helpers/ConstantTadHelper.h>
#include <cstdint>
#include <stdexcept>
|
|
|
|
|
|
|
|
namespace nd4j {
|
|
|
|
namespace ops {
|
|
|
|
namespace helpers {
|
|
|
|
|
|
|
|
/**
 * Widens `bufferLength` elements of type T read from `source` into the Nd4jLong array
 * `destination` (both must be device-accessible, since the kernel dereferences them).
 *
 * Uses a grid-stride loop, so any 1-D launch configuration covers the whole buffer.
 */
template <typename T>
__global__ static void copyBuffers(Nd4jLong* destination, void const* source, Nd4jLong bufferLength) {
    auto typedSource = reinterpret_cast<T const*>(source);

    // 64-bit indexing: an `int` counter would overflow for buffers longer than INT_MAX.
    const Nd4jLong tid = static_cast<Nd4jLong>(blockIdx.x) * blockDim.x + threadIdx.x;
    const Nd4jLong step = static_cast<Nd4jLong>(gridDim.x) * blockDim.x;

    for (Nd4jLong t = tid; t < bufferLength; t += step) {
        destination[t] = static_cast<Nd4jLong>(typedSource[t]);
    }
}
|
|
|
|
|
|
|
|
/**
 * Scatters one value per sample into the confusion-matrix output.
 *
 * For every t in [0, bufferLength) the kernel writes `w[t]` (or 1 when `weightsBuffer` is
 * null) to element `predictionBuffer[t]` of the output TAD selected by `labelsBuffer[t]`,
 * i.e. into `outputBuffer + tadOffsets[label]` at the offset computed from `tadShape`.
 *
 * @param labelsBuffer     device array of bufferLength label indices (Nd4jLong)
 * @param predictionBuffer device array of bufferLength prediction indices (Nd4jLong)
 * @param bufferLength     number of samples
 * @param weightsBuffer    optional device array of bufferLength weights, read as T; may be null
 * @param outputBuffer     device output buffer, written as T
 * @param tadShape         shape info of one output TAD (used for its length and element offsets)
 * @param tadOffsets       offset of each output TAD inside outputBuffer
 *
 * Any 1-D launch configuration works (grid-stride loop); only static shared memory is used.
 * The value is assigned, not accumulated: if several samples hit the same cell, the last
 * writer wins (nondeterministic when their weights differ).
 * NOTE(review): a counting confusion matrix would need atomicAdd here; confirm the intended semantics.
 * NOTE(review): label/prediction indices are not range-checked in this kernel; callers must
 * guarantee they address valid TADs/elements.
 */
template <typename T>
__global__ static void confusionFunctorKernel(Nd4jLong* labelsBuffer, Nd4jLong* predictionBuffer, Nd4jLong bufferLength, void const* weightsBuffer, void* outputBuffer, Nd4jLong* tadShape, Nd4jLong* tadOffsets) {
    // Every sample uses the same TAD shape, so its length is computed once per block.
    __shared__ Nd4jLong arrLen;
    if (threadIdx.x == 0)
        arrLen = shape::length(tadShape);
    __syncthreads();

    auto z = reinterpret_cast<T*>(outputBuffer);
    auto w = reinterpret_cast<T const*>(weightsBuffer);

    // 64-bit indexing: an `int` counter would overflow for more than INT_MAX samples.
    const Nd4jLong tid = static_cast<Nd4jLong>(blockIdx.x) * blockDim.x + threadIdx.x;
    const Nd4jLong step = static_cast<Nd4jLong>(gridDim.x) * blockDim.x;

    for (Nd4jLong t = tid; t < bufferLength; t += step) {
        auto label = labelsBuffer[t];
        auto pred = predictionBuffer[t];
        auto tZ = z + tadOffsets[label];
        T val = (weightsBuffer == nullptr ? (T)1.0f : w[t]);

        auto idx = shape::getIndexOffset(pred, tadShape, arrLen);
        tZ[idx] = val;
    }
}
|
|
|
|
|
2019-08-27 13:30:37 +02:00
|
|
|
/**
 * Widens the integer contents of `source` into the Nd4jLong device buffer `destination`
 * (which must hold at least source->lengthOf() elements), dispatching on the array's OWN
 * data type. The conversion kernel is enqueued on `stream`; this call does not wait for it.
 *
 * Throws std::invalid_argument for non-integer types and cuda_exception if the launch fails.
 */
static void copyToLongBuffer(cudaStream_t* stream, NDArray* source, Nd4jLong* destination) {
    auto length = source->lengthOf();
    auto buffer = source->getSpecialBuffer();

    switch (source->dataType()) {
        case nd4j::DataType::INT8:
            copyBuffers<int8_t><<<256, 512, 1024, *stream>>>(destination, buffer, length);
            break;
        case nd4j::DataType::UINT8:
            copyBuffers<uint8_t><<<256, 512, 1024, *stream>>>(destination, buffer, length);
            break;
        case nd4j::DataType::INT16:
            copyBuffers<int16_t><<<256, 512, 1024, *stream>>>(destination, buffer, length);
            break;
        case nd4j::DataType::UINT16:
            copyBuffers<uint16_t><<<256, 512, 1024, *stream>>>(destination, buffer, length);
            break;
        case nd4j::DataType::INT32:
            copyBuffers<int32_t><<<256, 512, 1024, *stream>>>(destination, buffer, length);
            break;
        case nd4j::DataType::UINT32:
            copyBuffers<uint32_t><<<256, 512, 1024, *stream>>>(destination, buffer, length);
            break;
        case nd4j::DataType::INT64:
            copyBuffers<Nd4jLong><<<256, 512, 1024, *stream>>>(destination, buffer, length);
            break;
        case nd4j::DataType::UINT64:
            copyBuffers<uint64_t><<<256, 512, 1024, *stream>>>(destination, buffer, length);
            break;
        default:
            throw std::invalid_argument("helpers::confusion: labels and predictions must have an integer data type");
    }

    // Kernel launches do not return errors directly; bad launches surface here.
    auto err = cudaGetLastError();
    if (err != 0)
        throw nd4j::cuda_exception::build("Cannot launch type conversion kernel for confusion matrix", err);
}

/**
 * Fills `output` on the GPU from paired label/prediction indices and optional weights.
 *
 * X is the predictions data type and Z the output data type, as chosen by the caller's
 * dispatch; X is only used for that dispatch. Each of labels/predictions that is not already
 * INT64 is converted into a temporary Nd4jLong device buffer using its own data type (labels
 * and predictions may differ). The temporaries are released before returning, including
 * when an error is thrown.
 *
 * The output is split into TADs along dimension 1, and confusionFunctorKernel<Z> writes into
 * them. When weights are present, their buffer is read by that kernel as Z.
 * NOTE(review): weights are not converted; confirm callers guarantee that
 * weights->dataType() matches output->dataType().
 *
 * Throws cuda_exception on allocation, launch, synchronisation or deallocation failures, and
 * std::invalid_argument for non-integer label/prediction types.
 */
template <typename X, typename Z>
void _confusionFunctor(nd4j::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) {
    auto stream = context->getCudaStream();
    auto pack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), 1);
    PointersManager manager(context, "helpers::confusion");

    // INT64 inputs are used in place; any other integer type needs a widened temporary copy.
    const bool convertLabels = labels->dataType() != nd4j::DataType::INT64;
    const bool convertPredictions = predictions->dataType() != nd4j::DataType::INT64;

    Nd4jLong* labelsLongBuffer = convertLabels ? nullptr : (Nd4jLong*)labels->specialBuffer();
    Nd4jLong* predictionLongBuffer = convertPredictions ? nullptr : (Nd4jLong*)predictions->specialBuffer();

    // Error-path cleanup: frees whichever temporaries have been allocated so far.
    auto releaseTemporaries = [&]() {
        if (convertLabels && labelsLongBuffer != nullptr)
            cudaFree(labelsLongBuffer);
        if (convertPredictions && predictionLongBuffer != nullptr)
            cudaFree(predictionLongBuffer);
    };

    try {
        if (convertLabels) {
            auto err = cudaMalloc(&labelsLongBuffer, labels->lengthOf() * sizeof(Nd4jLong));
            if (err != 0) {
                labelsLongBuffer = nullptr;
                throw nd4j::cuda_exception::build("Cannot allocate memory for labels long buffer", err);
            }
            // Copy with type conversion, keyed on the labels' own type (it may differ from X).
            copyToLongBuffer(stream, labels, labelsLongBuffer);
        }

        if (convertPredictions) {
            auto err = cudaMalloc(&predictionLongBuffer, predictions->lengthOf() * sizeof(Nd4jLong));
            if (err != 0) {
                predictionLongBuffer = nullptr;
                throw nd4j::cuda_exception::build("Cannot allocate memory for predictions long buffer", err);
            }
            // copy with type conversion
            copyToLongBuffer(stream, predictions, predictionLongBuffer);
        }

        auto bufferLength = labels->lengthOf();
        dim3 launchDims(32, 32, 1024);
        confusionFunctorKernel<Z><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(labelsLongBuffer, predictionLongBuffer, bufferLength, weights != nullptr ? weights->getSpecialBuffer() : nullptr, output->specialBuffer(), pack.specialShapeInfo(), pack.specialOffsets());

        auto err = cudaGetLastError();
        if (err != 0)
            throw nd4j::cuda_exception::build("Cannot launch confusion matrix kernel", err);

        manager.synchronize();
    } catch (...) {
        releaseTemporaries();
        throw;
    }

    // Free both temporaries before reporting any failure, so that neither one leaks.
    cudaError_t predictionsFreeError = convertPredictions ? cudaFree(predictionLongBuffer) : cudaSuccess;
    cudaError_t labelsFreeError = convertLabels ? cudaFree(labelsLongBuffer) : cudaSuccess;

    if (predictionsFreeError != 0)
        throw nd4j::cuda_exception::build("Cannot deallocate memory for predictions long buffer", predictionsFreeError);
    if (labelsFreeError != 0)
        throw nd4j::cuda_exception::build("Cannot deallocate memory for labels long buffer", labelsFreeError);
}
|
|
|
|
|
|
|
|
/**
 * Public entry point for the CUDA confusion-matrix helper.
 *
 * Declares which arrays are read (labels, predictions, optional weights) and which is
 * written (output) around the device work. It then picks the _confusionFunctor
 * instantiation: the predictions dtype is taken from INDEXING_TYPES, and the output dtype
 * is taken from NUMERIC_TYPES.
 */
void confusionFunctor(nd4j::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) {
    // `weights` may be null; only predictions and output take part in type dispatch.
    const auto predictionsType = predictions->dataType();
    const auto outputType = output->dataType();

    NDArray::prepareSpecialUse({output}, {labels, predictions, weights});

    BUILD_DOUBLE_SELECTOR(predictionsType, outputType, _confusionFunctor,
                          (context, labels, predictions, weights, output),
                          INDEXING_TYPES, NUMERIC_TYPES);

    NDArray::registerSpecialUse({output}, {labels, predictions, weights});
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|