commit
c0163f6e01
|
@ -24,6 +24,7 @@ import org.bytedeco.javacv.OpenCVFrameConverter;
|
|||
import org.datavec.image.data.Image;
|
||||
import org.datavec.image.data.ImageWritable;
|
||||
import org.datavec.image.transform.ImageTransform;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
||||
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -284,6 +285,9 @@ public class NativeImageLoader extends BaseImageLoader {
|
|||
private Mat streamToMat(InputStream is) throws IOException {
|
||||
if(buffer == null){
|
||||
buffer = IOUtils.toByteArray(is);
|
||||
if(buffer.length <= 0){
|
||||
throw new IOException("Could not decode image from input stream: input stream was empty (no data)");
|
||||
}
|
||||
bufferMat = new Mat(buffer);
|
||||
return bufferMat;
|
||||
} else {
|
||||
|
@ -292,6 +296,10 @@ public class NativeImageLoader extends BaseImageLoader {
|
|||
//(a) if numRead < buffer.length - got everything
|
||||
//(b) if numRead >= buffer.length: we MIGHT have got everything (exact right size buffer) OR we need more data
|
||||
|
||||
if(numReadTotal <= 0){
|
||||
throw new IOException("Could not decode image from input stream: input stream was empty (no data)");
|
||||
}
|
||||
|
||||
if(numReadTotal < buffer.length){
|
||||
bufferMat.data().put(buffer, 0, numReadTotal);
|
||||
bufferMat.cols(numReadTotal);
|
||||
|
|
|
@ -24,7 +24,9 @@ import org.bytedeco.javacv.Frame;
|
|||
import org.bytedeco.javacv.Java2DFrameConverter;
|
||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||
import org.datavec.image.data.ImageWritable;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -55,6 +57,9 @@ public class TestNativeImageLoader {
|
|||
static final long seed = 10;
|
||||
static final Random rng = new Random(seed);
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
@Test
|
||||
public void testConvertPix() throws Exception {
|
||||
PIX pix;
|
||||
|
@ -554,4 +559,43 @@ public class TestNativeImageLoader {
|
|||
assertEquals(img1LargeBuffer, img1ExactBuffer);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testNativeImageLoaderEmptyStreams() throws Exception {
|
||||
File dir = testDir.newFolder();
|
||||
File f = new File(dir, "myFile.jpg");
|
||||
f.createNewFile();
|
||||
|
||||
NativeImageLoader nil = new NativeImageLoader(32, 32, 3);
|
||||
|
||||
try(InputStream is = new FileInputStream(f)){
|
||||
nil.asMatrix(is);
|
||||
} catch (IOException e){
|
||||
String msg = e.getMessage();
|
||||
assertTrue(msg, msg.contains("decode image"));
|
||||
}
|
||||
|
||||
try(InputStream is = new FileInputStream(f)){
|
||||
nil.asImageMatrix(is);
|
||||
} catch (IOException e){
|
||||
String msg = e.getMessage();
|
||||
assertTrue(msg, msg.contains("decode image"));
|
||||
}
|
||||
|
||||
try(InputStream is = new FileInputStream(f)){
|
||||
nil.asRowVector(is);
|
||||
} catch (IOException e){
|
||||
String msg = e.getMessage();
|
||||
assertTrue(msg, msg.contains("decode image"));
|
||||
}
|
||||
|
||||
try(InputStream is = new FileInputStream(f)){
|
||||
INDArray arr = Nd4j.create(DataType.FLOAT, 1, 3, 32, 32);
|
||||
nil.asMatrixView(is, arr);
|
||||
} catch (IOException e){
|
||||
String msg = e.getMessage();
|
||||
assertTrue(msg, msg.contains("decode image"));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -46,9 +46,9 @@ void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr
|
|||
|
||||
const Nd4jLong zLen = zArr.lengthOf();
|
||||
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
for (auto i = start; i < stop; ++i) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||
|
@ -109,6 +109,7 @@ void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yAr
|
|||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
for (auto i = start; i < stop; ++i) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||
|
@ -167,9 +168,9 @@ void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, N
|
|||
|
||||
const Nd4jLong zLen = zArr.lengthOf();
|
||||
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
for (auto i = start; i < stop; ++i) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||
|
|
|
@ -104,9 +104,9 @@ void TrueBroadcastHelper<X,Y,Z>::exec(const nd4j::broadcast::Ops opNum, const ND
|
|||
|
||||
dim3 launchDims;
|
||||
|
||||
launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock
|
||||
launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
|
||||
launchDims.z = sizeof(Nd4jLong) * launchDims.x * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMem
|
||||
launchDims.y = MAX_NUM_THREADS / 8; // threadsPerBlock
|
||||
launchDims.x = (zArr.lengthOf() + launchDims.y - 1) / launchDims.y; // blocksPerGrid
|
||||
launchDims.z = sizeof(Nd4jLong) * launchDims.y * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMe
|
||||
|
||||
PointersManager manager(xArr.getContext(), "TrueBroadcastHelper<X,Y,Z>::exec");
|
||||
|
||||
|
@ -189,9 +189,10 @@ template<typename X, typename Y>
|
|||
void TrueBroadcastBoolHelper<X,Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
|
||||
dim3 launchDims;
|
||||
launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock
|
||||
launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
|
||||
launchDims.z = sizeof(Nd4jLong) * launchDims.x * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMem
|
||||
|
||||
launchDims.y = MAX_NUM_THREADS / 8; // threadsPerBlock
|
||||
launchDims.x = (zArr.lengthOf() + launchDims.y - 1) / launchDims.y; // blocksPerGrid
|
||||
launchDims.z = sizeof(Nd4jLong) * launchDims.y * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMe
|
||||
|
||||
PointersManager manager(xArr.getContext(), "TrueBroadcastBoolHelper<X,Y>::exec");
|
||||
|
||||
|
@ -274,9 +275,10 @@ template<typename X>
|
|||
void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
|
||||
dim3 launchDims;
|
||||
launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock
|
||||
launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
|
||||
launchDims.z = sizeof(Nd4jLong) * launchDims.x * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMem
|
||||
|
||||
launchDims.y = MAX_NUM_THREADS / 8; // threadsPerBlock
|
||||
launchDims.x = (zArr.lengthOf() + launchDims.y - 1) / launchDims.y; // blocksPerGrid
|
||||
launchDims.z = sizeof(Nd4jLong) * launchDims.y * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMe
|
||||
|
||||
PointersManager manager(xArr.getContext(), "TrueBroadcastIntHelper<X>::exec");
|
||||
|
||||
|
|
|
@ -237,6 +237,8 @@ template <typename X, typename Z>
|
|||
template<typename OpType>
|
||||
__host__ void ReduceBoolFunction<X,Z>::intermediateXD(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||||
|
||||
nd4j_printf("Step A%i\n", -1);
|
||||
|
||||
if(shape::isEmpty(hXShapeInfo)) {
|
||||
|
||||
if(shape::isEmpty(hZShapeInfo))
|
||||
|
@ -251,7 +253,8 @@ __host__ void ReduceBoolFunction<X,Z>::intermediateXD(dim3 launchDims, cudaStrea
|
|||
auto ptr = nd4j::LaunchContext::defaultContext()->getScalarPointer();
|
||||
|
||||
// scalar assign
|
||||
functions::scalar::ScalarTransform<Z, Z, Z>::executeCudaShaped(launchDims, stream, 14, z, zShapeInfo, hXShapeInfo, z, zShapeInfo, hZShapeInfo, ptr, nullptr);
|
||||
functions::scalar::ScalarTransform<Z, Z, Z>::executeCudaShaped(launchDims, stream, 14, z, zShapeInfo, hZShapeInfo, z, zShapeInfo, hZShapeInfo, ptr, nullptr);
|
||||
nd4j::DebugHelper::checkErrorCode(stream, "reduceBoolDim empty(...) failed");
|
||||
}
|
||||
else {
|
||||
simpleReduce<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets);
|
||||
|
@ -274,6 +277,9 @@ __host__ void ReduceBoolFunction<X,Z>::intermediateScalar(dim3 launchDims, cudaS
|
|||
auto res = cudaMemcpyAsync(z, &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream);
|
||||
if (res != 0)
|
||||
throw nd4j::cuda_exception::build("ReduceBoolFunction<X,Z>::intermediateScalar: failed to copy resulting scalar", res);
|
||||
|
||||
nd4j::DebugHelper::checkErrorCode(stream, "reduceBoolScalar empty(...) failed");
|
||||
|
||||
}
|
||||
else {
|
||||
simpleScalar<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo);
|
||||
|
|
|
@ -249,7 +249,7 @@ __host__ void ReduceFloatFunction<X,Z>::intermediateXD(dim3 launchDims, cudaStre
|
|||
auto ptr = nd4j::LaunchContext::defaultContext()->getScalarPointer();
|
||||
|
||||
// scalar assign
|
||||
functions::scalar::ScalarTransform<Z, Z, Z>::executeCudaShaped(launchDims, stream, 14, z, zShape, hXShapeInfo, z, zShape, hZShapeInfo, ptr, nullptr);
|
||||
functions::scalar::ScalarTransform<Z, Z, Z>::executeCudaShaped(launchDims, stream, 14, z, zShape, hZShapeInfo, z, zShape, hZShapeInfo, ptr, nullptr);
|
||||
}
|
||||
else {
|
||||
simpleReduce<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShape, extraParams, z, zShape, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets);
|
||||
|
|
|
@ -76,7 +76,7 @@ namespace nd4j {
|
|||
// Y gradient
|
||||
//epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY);
|
||||
|
||||
gradY->assign(epsNext * -(*x) / ((*y) * (*y)));
|
||||
gradY->assign((*epsNext) * -(*x) / ((*y) * (*y)));
|
||||
|
||||
} else if (y->isScalar()) {
|
||||
// scalar case
|
||||
|
|
|
@ -89,7 +89,7 @@ namespace nd4j {
|
|||
gradY->assign(tmpX);
|
||||
|
||||
//epsNext->applyPairwiseLambda(x, lambdaS, gradX);
|
||||
gradX->assign(epsNext * ts * ((*x) - (*y)));
|
||||
gradX->assign((*epsNext) * ts * ((*x) - (*y)));
|
||||
} else {
|
||||
// broadcast case
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, 0, 0) {
|
|||
|
||||
REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST: Scale factor required");
|
||||
REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST: op expects rank of input array to be >= 3, but got %i instead", input->rankOf());
|
||||
REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1));
|
||||
// REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1));
|
||||
|
||||
NDArray* factor = nullptr;
|
||||
|
||||
|
@ -84,10 +84,15 @@ CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, 0, 0) {
|
|||
return Status::OK();
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST_V2: op expects rank of input array to be >= 3, but got %i instead", input->rankOf());
|
||||
REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST_V2: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1));
|
||||
// REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST_V2: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1));
|
||||
REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST_V2: Scale factor required");
|
||||
|
||||
NDArray* factor = nullptr;
|
||||
auto size = input->sizeAt(-2) * input->sizeAt(-3);
|
||||
auto channels = input->sizeAt(-1);
|
||||
auto batch = input->lengthOf() / (size * channels);
|
||||
auto input3D = input->reshape(input->ordering(), {batch, size, channels});
|
||||
auto output3D = input->reshape(input->ordering(), {batch, size, channels});
|
||||
|
||||
if(block.width() > 1)
|
||||
factor = INPUT_VARIABLE(1);
|
||||
|
@ -96,20 +101,17 @@ CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, 0, 0) {
|
|||
factor->p(0, T_ARG(0));
|
||||
}
|
||||
|
||||
// compute mean before
|
||||
std::vector<int> axes(input->rankOf() - 1);
|
||||
for (auto i = 0; i < axes.size(); ++i)
|
||||
axes[i] = i;
|
||||
std::vector<int> axes({1}); // dim 1 of pseudoresult
|
||||
|
||||
// mean as reduction for last dimension set
|
||||
auto mean = input->reduceAlongDims(reduce::Mean, axes);
|
||||
// mean as reduction for last dimension set over size (dim 1) of result3D
|
||||
auto mean = input3D.reduceAlongDims(reduce::Mean, axes);
|
||||
|
||||
// result as (x - mean) * factor + mean
|
||||
auto temp = input->ulike();
|
||||
input->applyTrueBroadcast(BroadcastOpsTuple::Subtract(), &mean, &temp);
|
||||
auto temp = input3D.ulike();
|
||||
input3D.applyBroadcast(broadcast::Subtract, {0, 2}, &mean, &temp, nullptr);
|
||||
temp.applyScalarArr(scalar::Multiply, factor);
|
||||
temp.applyTrueBroadcast(BroadcastOpsTuple::Add(), &mean, output);
|
||||
|
||||
temp.applyBroadcast(broadcast::Add, {0, 2}, &mean, &output3D);
|
||||
output->assign(output3D);
|
||||
if(block.width() == 1)
|
||||
delete factor;
|
||||
|
||||
|
|
|
@ -52,12 +52,11 @@ namespace nd4j {
|
|||
if (block.getIArguments() && block.getIArguments()->size())
|
||||
numBits = INT_ARG(0);
|
||||
bool narrowed = false;
|
||||
//INT_ARG(1);
|
||||
if (block.getIArguments()->size() == 2) {
|
||||
numBits = INT_ARG(0);
|
||||
narrowed = INT_ARG(1);
|
||||
REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars: Number of bits for quatization should be in between 2 and 16, but %i was given.", numBits);
|
||||
if (block.getBArguments() && block.getBArguments()->size()) {
|
||||
narrowed = B_ARG(0);
|
||||
}
|
||||
REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars: Number of \
|
||||
bits for quantization should be in between 2 and 16, but %i was given.", numBits);
|
||||
helpers::fakeQuantWithMinMaxVars(x, min, max, numBits, narrowed, output);
|
||||
return ND4J_STATUS_OK;
|
||||
}
|
||||
|
|
|
@ -96,16 +96,16 @@ namespace nd4j {
|
|||
outputShape[2] = height;
|
||||
outputShape[3] = in[3];
|
||||
}
|
||||
ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in));
|
||||
ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, shape::order(in));
|
||||
|
||||
shapeList->push_back(CONSTANT(outputShape));
|
||||
return shapeList;
|
||||
}
|
||||
DECLARE_TYPES(resize_bicubic) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, {ALL_FLOATS})
|
||||
->setAllowedInputTypes(1, {DataType::INT32})
|
||||
->setAllowedOutputTypes({ALL_FLOATS});
|
||||
->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS})
|
||||
->setAllowedInputTypes(1, DataType::INT32)
|
||||
->setAllowedOutputTypes({DataType::FLOAT32});
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -164,6 +164,7 @@ namespace nd4j {
|
|||
|
||||
// we can launch op using Int arguments
|
||||
if (inputShape->size() == 1) {
|
||||
REQUIRE_TRUE(block.numI() > 0, 0, "Reshape: new shape should be provided as NDArray or int arguments, but nothing was defined");
|
||||
std::vector<int> *arguments = block.getIArguments();
|
||||
|
||||
int e = 1;
|
||||
|
|
|
@ -352,14 +352,12 @@ namespace helpers {
|
|||
|
||||
int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const *images, int const width, int const height,
|
||||
bool const alignCorners, bool const halfPixelCenter, NDArray *output) {
|
||||
BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), return resizeBilinearFunctor_,
|
||||
(images, width, height, alignCorners, halfPixelCenter, output), NUMERIC_TYPES, FLOAT_TYPES);
|
||||
BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), return resizeBilinearFunctor_, (images, width, height, alignCorners, halfPixelCenter, output), NUMERIC_TYPES, FLOAT_TYPES);
|
||||
}
|
||||
|
||||
int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const *images, int const width, int const height,
|
||||
bool const alignCorners, bool const halfPixelCenter, NDArray *output) {
|
||||
BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_,
|
||||
(images, width, height, alignCorners, halfPixelCenter, output), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_, (images, width, height, alignCorners, halfPixelCenter, output), LIBND4J_TYPES);
|
||||
}
|
||||
|
||||
|
||||
|
@ -696,7 +694,7 @@ namespace helpers {
|
|||
const Nd4jLong inBatchWidth = resizerState.inHeight * inRowWidth;
|
||||
|
||||
const T* inputPtr = image->getDataBuffer()->primaryAsT<T>();
|
||||
T* pOutputY = output->dataBuffer()->primaryAsT<T>(); //_data.data();
|
||||
float* pOutputY = output->dataBuffer()->primaryAsT<float>(); // output is float anyway
|
||||
std::vector<float> cachedValue(numChannels == 3 ? 0 : 4 * numChannels, 0);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
|
@ -881,8 +879,7 @@ namespace helpers {
|
|||
}
|
||||
int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
|
||||
bool const alignCorners, bool const halfPixelAlign, NDArray* output) {
|
||||
BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctorA_, (context,
|
||||
image, width, height, alignCorners, halfPixelAlign, output), NUMERIC_TYPES);
|
||||
BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctorA_, (context, image, width, height, alignCorners, halfPixelAlign, output), NUMERIC_TYPES);
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------------------------ //
|
||||
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
|
||||
|
@ -921,4 +918,4 @@ namespace helpers {
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -689,7 +689,7 @@ namespace helpers {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
static __global__ void bicubicInterpolateWithCachingKernel(float const* cachedTable, float* cachedValue, T const* inputPtr, ImageResizerState* pResizerState, WeightsAndIndices* xWais, bool halfPixelCenters, Nd4jLong inBatchWidth, Nd4jLong inRowWidth, T* outputPtr) {
|
||||
static __global__ void bicubicInterpolateWithCachingKernel(float const* cachedTable, float* cachedValue, T const* inputPtr, ImageResizerState* pResizerState, WeightsAndIndices* xWais, bool halfPixelCenters, Nd4jLong inBatchWidth, Nd4jLong inRowWidth, float* outputPtr) {
|
||||
// auto numChannels = pResizerState->channels;
|
||||
for (Nd4jLong b = blockIdx.x; b < pResizerState->batchSize; b += gridDim.x) {
|
||||
auto pInput = inputPtr + b * inBatchWidth;
|
||||
|
@ -877,7 +877,7 @@ namespace helpers {
|
|||
throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: computeXWeigtsAndInidces finished with error", err);
|
||||
}
|
||||
const T* pInput = image->getDataBuffer()->specialAsT<T>();
|
||||
T* pOutput = output->dataBuffer()->specialAsT<T>(); //_data.data();
|
||||
float* pOutput = output->dataBuffer()->specialAsT<float>(); //_data.data();
|
||||
bicubicInterpolateWithCachingKernel<T><<<128, 1, 512, *stream>>>(coeffsTable, cachedValue, pInput,
|
||||
resizerStateD, xWais, halfPixelCenters, inBatchWidth, inRowWidth, pOutput);
|
||||
err = cudaStreamSynchronize(*stream);
|
||||
|
|
|
@ -832,3 +832,49 @@ TEST_F(BroadcastableOpsTests, broadcast_3) {
|
|||
ASSERT_TRUE(z.isSameShape(e));
|
||||
ASSERT_TRUE(z.equalsTo(e));
|
||||
}
|
||||
|
||||
TEST_F(BroadcastableOpsTests, test_bert_multiply_1) {
|
||||
auto x = NDArrayFactory::create<float>('c', {4, 128, 1});
|
||||
auto y = NDArrayFactory::create<float>('c', {4, 1, 128});
|
||||
auto z = NDArrayFactory::create<float>('c', {4, 128, 128});
|
||||
auto e = NDArrayFactory::create<float>('c', {4, 128, 128});
|
||||
|
||||
x.assign(0.f);
|
||||
y.assign(1.f);
|
||||
z.assign(119.f);
|
||||
e.assign(0.f);
|
||||
/*
|
||||
Context ctx(1);
|
||||
ctx.setInputArray(0, &x);
|
||||
ctx.setInputArray(1, &y);
|
||||
ctx.setOutputArray(0, &z);
|
||||
|
||||
nd4j::ops::multiply op;
|
||||
auto status = op.execute(&ctx);
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
z.printIndexedBuffer();
|
||||
*/
|
||||
|
||||
x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &z);
|
||||
|
||||
//z.printIndexedBuffer();
|
||||
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
|
||||
TEST_F(BroadcastableOpsTests, test_bert_multiply_2) {
|
||||
auto x = NDArrayFactory::create<float>('c', {4, 128, 1});
|
||||
auto y = NDArrayFactory::create<float>('c', {768});
|
||||
auto z = NDArrayFactory::create<float>('c', {4, 128, 768});
|
||||
auto e = NDArrayFactory::create<float>('c', {4, 128, 768});
|
||||
|
||||
x.assign(1.f);
|
||||
y.assign(2.f);
|
||||
z.assign(119.f);
|
||||
e.assign(2.f);
|
||||
|
||||
x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &z);
|
||||
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
|
|
|
@ -479,166 +479,166 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test13) {
|
|||
|
||||
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test1) {
|
||||
|
||||
NDArray input = NDArrayFactory::create<double>('c', {1, 7, 7, 1}, {
|
||||
1, 2.1, 3.15, 4.2, 5.15, 6.1, 7,
|
||||
8, 9.1, 10., 11, 12.9, 13.1, 14,
|
||||
15, 16., 17., 18, 19, 20., 21,
|
||||
22, 23., 24., 25, 26, 27, 28,
|
||||
30, 31, 32, 33, 34., 35, 36,
|
||||
37, 38, 39, 40, 41., 42, 43,
|
||||
44, 45, 46, 47, 48., 49, 50
|
||||
NDArray input = NDArrayFactory::create<float>('c', {1, 7, 7, 1}, {
|
||||
1.f, 2.1f, 3.15f, 4.2f, 5.15f, 6.1f, 7.f,
|
||||
8.f, 9.1f, 10.f, 11.f, 12.9f, 13.1f, 14.f,
|
||||
15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f,
|
||||
22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f,
|
||||
30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f,
|
||||
37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f,
|
||||
44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f
|
||||
});
|
||||
NDArray expected = NDArrayFactory::create<double>('c', {1, 30, 30, 1}, {
|
||||
1. ,1.1976162 ,1.4174359 ,1.6775769 ,1.9961575 ,2.3283265 ,
|
||||
2.550918 ,2.7360606 ,2.9655411 ,3.2929654 ,3.5441515 ,3.7380352 ,
|
||||
3.948995 ,4.248106 ,4.5073795 ,4.6843743 ,4.8572845 ,5.104302 ,
|
||||
5.3869915 ,5.581401 ,5.7539616 ,5.974285 ,6.272836 ,6.5204263 ,
|
||||
6.718899 ,6.8871036 ,7.039068 ,7.099216 ,7.0784245 ,7.0281887 ,
|
||||
2.247592 ,2.446947 ,2.6694887 ,2.9312382 ,3.248216 ,3.5745337 ,
|
||||
3.78931 ,3.9656973 ,4.186417 ,4.5046535 ,4.740569 ,4.9217057 ,
|
||||
5.133866 ,5.459533 ,5.7744613 ,6.0197873 ,6.254011 ,6.535633 ,
|
||||
6.8097296 ,6.9607787 ,7.0749416 ,7.241601 ,7.5094895 ,7.7499495 ,
|
||||
7.954571 ,8.131972 ,8.286526 ,8.346463 ,8.325745 ,8.275683 ,
|
||||
3.6286845 ,3.830573 ,4.0569587 ,4.3211575 ,4.6364856 ,4.9556503 ,
|
||||
5.160583 ,5.3258467 ,5.535462 ,5.84216 ,6.058749 ,6.223753 ,
|
||||
6.437597 ,6.797369 ,7.1836042 ,7.5164022 ,7.8290343 ,8.154773 ,
|
||||
8.417635 ,8.512958 ,8.5521 ,8.649708 ,8.87788 ,9.108794 ,
|
||||
9.320926 ,9.509781 ,9.667375 ,9.72694 ,9.706349 ,9.656599 ,
|
||||
5.276778 ,5.480438 ,5.709702 ,5.9754477 ,6.288551 ,6.6005697 ,
|
||||
6.796207 ,6.9511423 ,7.1503997 ,7.4461427 ,7.644651 ,7.794562 ,
|
||||
8.009684 ,8.400473 ,8.851847 ,9.26469 ,9.649218, 10.015648 ,
|
||||
10.268647 ,10.313368 ,10.2843275 ,10.319379 ,10.512033 ,10.734956 ,
|
||||
10.954604 ,11.154507 ,11.315369 ,11.374779 ,11.354242 ,11.304622 ,
|
||||
7.325373 ,7.5284843 ,7.757575 ,8.022221 ,8.331997 ,8.638187 ,
|
||||
8.827649 ,8.976217 ,9.168955 ,9.45726 ,9.6442375 ,9.784517 ,
|
||||
9.999621, 10.407702 ,10.896234, 11.355122, 11.781423, 12.172186 ,
|
||||
12.420712 ,12.4374485 ,12.370511 ,12.371386 ,12.545973 ,12.766424 ,
|
||||
12.992249 ,13.20012 ,13.364252 ,13.424109 ,13.40342 ,13.353425 ,
|
||||
9.493208 ,9.692467 ,9.9169445, 10.176801, 10.482199, 10.78547 ,
|
||||
10.974367 ,11.123442 ,11.31637 ,11.603645 ,11.790616 ,11.930889 ,
|
||||
12.144082 ,12.546447 ,13.024898 ,13.4723 ,13.889232 ,14.276275 ,
|
||||
14.528972 ,14.555555 ,14.50145 ,14.515459 ,14.700572 ,14.927055 ,
|
||||
15.156046 ,15.366046 ,15.532901 ,15.594008 ,15.5728855 ,15.521847 ,
|
||||
10.970133 ,11.163599 ,11.380694 ,11.633735 ,11.935032 ,12.238887 ,
|
||||
12.43254 ,12.588294 ,12.787534 ,13.079956 ,13.27752 ,13.426631 ,
|
||||
13.636713 ,14.013844 ,14.441672 ,14.827978 ,15.191209 ,15.549808 ,
|
||||
15.81343 ,15.881828 ,15.883522 ,15.950411 ,16.16933 ,16.40794 ,
|
||||
16.636436 ,16.842583 ,17.010887 ,17.07363 ,17.05194 ,16.999537 ,
|
||||
12.219155 ,12.406129 ,12.614796 ,12.860335 ,13.157928 ,13.464224 ,
|
||||
13.665207 ,13.830567 ,14.039036 ,14.339629 ,14.552863 ,14.715049 ,
|
||||
14.921564 ,15.264454 ,15.622843 ,15.924977 ,16.213829 ,16.532364 ,
|
||||
16.8099 ,16.934835 ,17.012146 ,17.150164 ,17.413412 ,17.666712 ,
|
||||
17.892765 ,18.09207 ,18.261044 ,18.325531 ,18.303238 ,18.249378 ,
|
||||
13.7663965 ,13.947391 ,14.148263 ,14.386917 ,14.681246 ,14.990087 ,
|
||||
15.198166 ,15.372728 ,15.590062 ,15.898583 ,16.126892 ,16.301655 ,
|
||||
16.50487 ,16.815214 ,17.107498 ,17.329458 ,17.547403 ,17.827654 ,
|
||||
18.118288 ,18.296928 ,18.4461 ,18.651634 ,18.956806 ,19.22382 ,
|
||||
19.447308 ,19.639887 ,19.809319 ,19.875397 ,19.852556 ,19.797365 ,
|
||||
15.9419365 ,16.118704 ,16.314133 ,16.547867 ,16.839561 ,17.14954 ,
|
||||
17.361883 ,17.542162 ,17.764957 ,18.078188 ,18.315733 ,18.498205 ,
|
||||
18.699116 ,18.988684 ,19.238989 ,19.410137 ,19.583265 ,19.839512 ,
|
||||
20.13878 ,20.35177 ,20.546844 ,20.795671 ,21.128067 ,21.404358 ,
|
||||
21.626736 ,21.8155 ,21.98561 ,22.052843 ,22.029604 ,21.973448 ,
|
||||
17.53522 ,17.71077 ,17.904636 ,18.13695 ,18.42784 ,18.738056 ,
|
||||
18.951529 ,19.133352 ,19.357613 ,19.672083 ,19.912102 ,20.096638 ,
|
||||
20.296894 ,20.580765 ,20.819603 ,20.976887 ,21.137802 ,21.387535 ,
|
||||
21.689209 ,21.911621 ,22.119276 ,22.37999 ,22.71991 ,22.998823 ,
|
||||
23.22097 ,23.40876 ,23.57911 ,23.646685 ,23.623325 ,23.566887 ,
|
||||
18.746353 ,18.922657 ,19.117487 ,19.350685 ,19.64207 ,19.952137 ,
|
||||
20.164913 ,20.345781 ,20.569134 ,20.88284 ,21.12133 ,21.30459 ,
|
||||
21.505253 ,21.792645 ,22.038572 ,22.204426 ,22.37289 ,22.626648 ,
|
||||
22.926834 ,23.143423 ,23.343302 ,23.596668 ,23.931936 ,24.209232 ,
|
||||
24.431519 ,24.619913 ,24.79011 ,24.857473 ,24.83419 ,24.777927 ,
|
||||
20.16656 ,20.344206 ,20.540766 ,20.775532 ,21.067804 ,21.377607 ,
|
||||
21.589132 ,21.768297 ,21.99003 ,22.302366 ,22.538124 ,22.719105 ,
|
||||
22.920494 ,23.214176 ,23.472767 ,23.653934 ,23.83589 ,24.096842 ,
|
||||
24.394371 ,24.600555 ,24.786541 ,25.026773 ,25.353731 ,25.62813 ,
|
||||
25.850672 ,26.04014 ,26.210072 ,26.277063 ,26.253906 ,26.197956 ,
|
||||
22.363024 ,22.54125 ,22.738552 ,22.973991 ,23.266647 ,23.57634 ,
|
||||
23.787327 ,23.96576 ,24.186796 ,24.498543 ,24.733124 ,24.913122 ,
|
||||
25.114826 ,25.411213 ,25.675262 ,25.863028 ,26.050789 ,26.314838 ,
|
||||
26.611223 ,26.812925 ,26.992926 ,27.227505 ,27.550882 ,27.824034 ,
|
||||
28.046684 ,28.236614 ,28.406433 ,28.473265 ,28.450163 ,28.394344 ,
|
||||
24.429443 ,24.60767 ,24.80497 ,25.04041 ,25.333065 ,25.642756 ,
|
||||
25.853743 ,26.032173 ,26.25321 ,26.564959 ,26.79954 ,26.97954 ,
|
||||
27.181242 ,27.47763 ,27.74168 ,27.929441 ,28.117207 ,28.381254 ,
|
||||
28.677637 ,28.879343 ,29.059345 ,29.293922 ,29.617298 ,29.890451 ,
|
||||
30.113104 ,30.303034 ,30.472853 ,30.539684 ,30.516582 ,30.460762 ,
|
||||
26. ,26.178228 ,26.375526 ,26.61097 ,26.903624 ,27.213314 ,
|
||||
27.424305 ,27.602734 ,27.823772 ,28.135519 ,28.3701 ,28.550098 ,
|
||||
28.7518 ,29.04819 ,29.312237 ,29.5 ,29.687763 ,29.951813 ,
|
||||
30.2482 ,30.449903 ,30.629902 ,30.864483 ,31.187859 ,31.461012 ,
|
||||
31.683659 ,31.873592 ,32.043407 ,32.11024 ,32.087135 ,32.03132 ,
|
||||
27.570559 ,27.748787 ,27.946087 ,28.181528 ,28.474184 ,28.783876 ,
|
||||
28.994865 ,29.173294 ,29.39433 ,29.70608 ,29.940659 ,30.120655 ,
|
||||
30.32236 ,30.618746 ,30.882797 ,31.070557 ,31.25832 ,31.522371 ,
|
||||
31.818754 ,32.02046 ,32.20046 ,32.43504 ,32.758415 ,33.031567 ,
|
||||
33.25422 ,33.44415 ,33.613964 ,33.680794 ,33.657696 ,33.60188 ,
|
||||
29.636976 ,29.815207 ,30.0125 ,30.247944 ,30.5406 ,30.85029 ,
|
||||
31.061283 ,31.239712 ,31.46075 ,31.7725 ,32.00708 ,32.187077 ,
|
||||
32.38878 ,32.685165 ,32.949215 ,33.13698 ,33.32474 ,33.58879 ,
|
||||
33.885178 ,34.086884 ,34.26688 ,34.501457 ,34.824837 ,35.09799 ,
|
||||
35.320637 ,35.510574 ,35.68039 ,35.747215 ,35.724117 ,35.6683 ,
|
||||
31.83344 ,32.011665 ,32.20897 ,32.444412 ,32.73707 ,33.046757 ,
|
||||
33.257744 ,33.436176 ,33.657207 ,33.96896 ,34.203537 ,34.383537 ,
|
||||
34.58524 ,34.88163 ,35.145676 ,35.33344 ,35.521206 ,35.785255 ,
|
||||
36.081642 ,36.28334 ,36.46334 ,36.69792 ,37.021297 ,37.294453 ,
|
||||
37.517097 ,37.707027 ,37.876846 ,37.94368 ,37.920578 ,37.864758 ,
|
||||
33.253647 ,33.431873 ,33.62917 ,33.864613 ,34.15727 ,34.466957 ,
|
||||
34.677948 ,34.856377 ,35.077415 ,35.38916 ,35.623745 ,35.803745 ,
|
||||
36.005447 ,36.301834 ,36.565884 ,36.753647 ,36.941406 ,37.205456 ,
|
||||
37.50184 ,37.703545 ,37.883545 ,38.118122 ,38.4415 ,38.714653 ,
|
||||
38.9373 ,39.127235 ,39.297054 ,39.363884 ,39.340782 ,39.28496 ,
|
||||
34.464783 ,34.64301 ,34.840305 ,35.075752 ,35.368404 ,35.6781 ,
|
||||
35.889088 ,36.067516 ,36.28855 ,36.6003 ,36.834885 ,37.014877 ,
|
||||
37.216583 ,37.51297 ,37.77702 ,37.964783 ,38.152546 ,38.416595 ,
|
||||
38.71298 ,38.914684 ,39.094685 ,39.32926 ,39.652645 ,39.925793 ,
|
||||
40.14844 ,40.338375 ,40.508194 ,40.575024 ,40.55192 ,40.496105 ,
|
||||
36.058067 ,36.23629 ,36.43359 ,36.669033 ,36.961685 ,37.271378 ,
|
||||
37.48237 ,37.6608 ,37.881836 ,38.19359 ,38.42817 ,38.608162 ,
|
||||
38.809868 ,39.10625 ,39.3703 ,39.558064 ,39.74583 ,40.00988 ,
|
||||
40.306267 ,40.50797 ,40.68797 ,40.92255 ,41.245926 ,41.519077 ,
|
||||
41.741722 ,41.931652 ,42.101475 ,42.168304 ,42.145203 ,42.089386 ,
|
||||
38.315002 ,38.493233 ,38.690533 ,38.925976 ,39.218628 ,39.52832 ,
|
||||
39.739307 ,39.917736 ,40.138775 ,40.45052 ,40.685104 ,40.865097 ,
|
||||
41.066803 ,41.36319 ,41.627243 ,41.815002 ,42.002766 ,42.26682 ,
|
||||
42.5632 ,42.764908 ,42.944904 ,43.179485 ,43.50286 ,43.776016 ,
|
||||
43.998665 ,44.188595 ,44.358418 ,44.425247 ,44.402145 ,44.34633 ,
|
||||
40.22708 ,40.40531 ,40.602608 ,40.83805 ,41.130707 ,41.440395 ,
|
||||
41.651382 ,41.82982 ,42.050854 ,42.3626 ,42.597183 ,42.77718 ,
|
||||
42.97888 ,43.27527 ,43.53932 ,43.72708 ,43.914845 ,44.178894 ,
|
||||
44.47528 ,44.676983 ,44.856983 ,45.09156 ,45.41494 ,45.68809 ,
|
||||
45.91074 ,46.100674 ,46.270493 ,46.337322 ,46.31422 ,46.2584 ,
|
||||
41.785618 ,41.963844 ,42.161144 ,42.396584 ,42.68924 ,42.998936 ,
|
||||
43.209923 ,43.388355 ,43.609394 ,43.921143 ,44.15572 ,44.335716 ,
|
||||
44.53742 ,44.833805 ,45.09786 ,45.285614 ,45.473377 ,45.737427 ,
|
||||
46.033817 ,46.235523 ,46.415524 ,46.650105 ,46.973476 ,47.24663 ,
|
||||
47.469276 ,47.65921 ,47.82903 ,47.895855 ,47.872753 ,47.81694 ,
|
||||
43.11514 ,43.293365 ,43.490665 ,43.726105 ,44.018764 ,44.328457 ,
|
||||
44.539444 ,44.717873 ,44.93891 ,45.25066 ,45.48524 ,45.665237 ,
|
||||
45.86694 ,46.163326 ,46.427376 ,46.615143 ,46.802902 ,47.066956 ,
|
||||
47.363342 ,47.56505 ,47.74505 ,47.979626 ,48.302998 ,48.576153 ,
|
||||
48.798798 ,48.98873 ,49.158546 ,49.225376 ,49.202282 ,49.146458 ,
|
||||
44.303867 ,44.482094 ,44.679394 ,44.914833 ,45.207493 ,45.51718 ,
|
||||
45.72817 ,45.9066 ,46.12764 ,46.439384 ,46.673965 ,46.853966 ,
|
||||
47.055668 ,47.352055 ,47.6161 ,47.803867 ,47.99163 ,48.25568 ,
|
||||
48.552063 ,48.75377 ,48.933773 ,49.16835 ,49.491726 ,49.764877 ,
|
||||
49.987526 ,50.17746 ,50.347275 ,50.4141 ,50.391006 ,50.335186 ,
|
||||
44.771675 ,44.949905 ,45.1472 ,45.382645 ,45.6753 ,45.98499 ,
|
||||
46.195976 ,46.374413 ,46.595448 ,46.907196 ,47.141773 ,47.321774 ,
|
||||
47.523476 ,47.819862 ,48.08391 ,48.27168 ,48.459446 ,48.72349 ,
|
||||
49.019882 ,49.22158 ,49.401585 ,49.63616 ,49.959538 ,50.232693 ,
|
||||
50.455338 ,50.64527 ,50.81509 ,50.88192 ,50.858818 ,50.803 ,
|
||||
44.609966 ,44.788193 ,44.985493 ,45.220936 ,45.51359 ,45.82328 ,
|
||||
46.03427 ,46.2127 ,46.433743 ,46.74549 ,46.98007 ,47.160065 ,
|
||||
47.36177 ,47.658157 ,47.922207 ,48.10997 ,48.297733 ,48.561783 ,
|
||||
48.858166 ,49.059875 ,49.239872 ,49.47445 ,49.79783 ,50.07098 ,
|
||||
50.293625 ,50.48356 ,50.653378 ,50.720203 ,50.6971 ,50.64128 ,
|
||||
44.219246 ,44.397472 ,44.594772 ,44.83021 ,45.122868 ,45.43256 ,
|
||||
45.643543 ,45.82198 ,46.04302 ,46.354763 ,46.589344 ,46.76934 ,
|
||||
46.971046 ,47.267433 ,47.531483 ,47.719242 ,47.907005 ,48.17105 ,
|
||||
48.467438 ,48.66914 ,48.849144 ,49.08372 ,49.4071 ,49.680256 ,
|
||||
49.902905 ,50.092834 ,50.262653 ,50.329483 ,50.30638 ,50.25057});
|
||||
NDArray expected = NDArrayFactory::create<float>('c', {1, 30, 30, 1}, {
|
||||
1.f, 1.1976162f, 1.4174359f, 1.6775769f, 1.9961575f, 2.3283265f,
|
||||
2.550918f, 2.7360606f, 2.9655411f, 3.2929654f, 3.5441515f, 3.7380352f,
|
||||
3.948995f, 4.248106f, 4.5073795f, 4.6843743f, 4.8572845f, 5.104302f,
|
||||
5.3869915f, 5.581401f, 5.7539616f, 5.974285f, 6.272836f, 6.5204263f,
|
||||
6.718899f, 6.8871036f, 7.039068f, 7.099216f, 7.0784245f, 7.0281887f,
|
||||
2.247592f, 2.446947f, 2.6694887f, 2.9312382f, 3.248216f, 3.5745337f,
|
||||
3.78931f, 3.9656973f, 4.186417f, 4.5046535f, 4.740569f, 4.9217057f,
|
||||
5.133866f, 5.459533f, 5.7744613f, 6.0197873f, 6.254011f, 6.535633f,
|
||||
6.8097296f, 6.9607787f, 7.0749416f, 7.241601f, 7.5094895f, 7.7499495f,
|
||||
7.954571f, 8.131972f, 8.286526f, 8.346463f, 8.325745f, 8.275683f,
|
||||
3.6286845f, 3.830573f, 4.0569587f, 4.3211575f, 4.6364856f, 4.9556503f,
|
||||
5.160583f, 5.3258467f, 5.535462f, 5.84216f, 6.058749f, 6.223753f,
|
||||
6.437597f, 6.797369f, 7.1836042f, 7.5164022f, 7.8290343f, 8.154773f,
|
||||
8.417635f, 8.512958f, 8.5521f, 8.649708f, 8.87788f, 9.108794f,
|
||||
9.320926f, 9.509781f, 9.667375f, 9.72694f, 9.706349f, 9.656599f,
|
||||
5.276778f, 5.480438f, 5.709702f, 5.9754477f, 6.288551f, 6.6005697f,
|
||||
6.796207f, 6.9511423f, 7.1503997f, 7.4461427f, 7.644651f, 7.794562f,
|
||||
8.009684f, 8.400473f, 8.851847f, 9.26469f, 9.649218f, 10.015648f,
|
||||
10.268647f, 10.313368f, 10.2843275f, 10.319379f, 10.512033f, 10.734956f,
|
||||
10.954604f, 11.154507f, 11.315369f, 11.374779f, 11.354242f, 11.304622f,
|
||||
7.325373f, 7.5284843f, 7.757575f, 8.022221f, 8.331997f, 8.638187f,
|
||||
8.827649f, 8.976217f, 9.168955f, 9.45726f, 9.6442375f, 9.784517f,
|
||||
9.999621f, 10.407702f, 10.896234f, 11.355122f, 11.781423f, 12.172186f,
|
||||
12.420712f, 12.4374485f, 12.370511f, 12.371386f, 12.545973f, 12.766424f,
|
||||
12.992249f, 13.20012f, 13.364252f, 13.424109f, 13.40342f, 13.353425f,
|
||||
9.493208f, 9.692467f, 9.9169445f, 10.176801f, 10.482199f, 10.78547f,
|
||||
10.974367f, 11.123442f, 11.31637f, 11.603645f, 11.790616f, 11.930889f,
|
||||
12.144082f, 12.546447f, 13.024898f, 13.4723f, 13.889232f, 14.276275f,
|
||||
14.528972f, 14.555555f, 14.50145f, 14.515459f, 14.700572f, 14.927055f,
|
||||
15.156046f, 15.366046f, 15.532901f, 15.594008f, 15.5728855f, 15.521847f,
|
||||
10.970133f, 11.163599f, 11.380694f, 11.633735f, 11.935032f, 12.238887f,
|
||||
12.43254f, 12.588294f, 12.787534f, 13.079956f, 13.27752f, 13.426631f,
|
||||
13.636713f, 14.013844f, 14.441672f, 14.827978f, 15.191209f, 15.549808f,
|
||||
15.81343f, 15.881828f, 15.883522f, 15.950411f, 16.16933f, 16.40794f,
|
||||
16.636436f, 16.842583f, 17.010887f, 17.07363f, 17.05194f, 16.999537f,
|
||||
12.219155f, 12.406129f, 12.614796f, 12.860335f, 13.157928f, 13.464224f,
|
||||
13.665207f, 13.830567f, 14.039036f, 14.339629f, 14.552863f, 14.715049f,
|
||||
14.921564f, 15.264454f, 15.622843f, 15.924977f, 16.213829f, 16.532364f,
|
||||
16.8099f, 16.934835f, 17.012146f, 17.150164f, 17.413412f, 17.666712f,
|
||||
17.892765f, 18.09207f, 18.261044f, 18.325531f, 18.303238f, 18.249378f,
|
||||
13.7663965f, 13.947391f, 14.148263f, 14.386917f, 14.681246f, 14.990087f,
|
||||
15.198166f, 15.372728f, 15.590062f, 15.898583f, 16.126892f, 16.301655f,
|
||||
16.50487f, 16.815214f, 17.107498f, 17.329458f, 17.547403f, 17.827654f,
|
||||
18.118288f, 18.296928f, 18.4461f, 18.651634f, 18.956806f, 19.22382f,
|
||||
19.447308f, 19.639887f, 19.809319f, 19.875397f, 19.852556f, 19.797365f,
|
||||
15.9419365f, 16.118704f, 16.314133f, 16.547867f, 16.839561f, 17.14954f,
|
||||
17.361883f, 17.542162f, 17.764957f, 18.078188f, 18.315733f, 18.498205f,
|
||||
18.699116f, 18.988684f, 19.238989f, 19.410137f, 19.583265f, 19.839512f,
|
||||
20.13878f, 20.35177f, 20.546844f, 20.795671f, 21.128067f, 21.404358f,
|
||||
21.626736f, 21.8155f, 21.98561f, 22.052843f, 22.029604f, 21.973448f,
|
||||
17.53522f, 17.71077f, 17.904636f, 18.13695f, 18.42784f, 18.738056f,
|
||||
18.951529f, 19.133352f, 19.357613f, 19.672083f, 19.912102f, 20.096638f,
|
||||
20.296894f, 20.580765f, 20.819603f, 20.976887f, 21.137802f, 21.387535f,
|
||||
21.689209f, 21.911621f, 22.119276f, 22.37999f, 22.71991f, 22.998823f,
|
||||
23.22097f, 23.40876f, 23.57911f, 23.646685f, 23.623325f, 23.566887f,
|
||||
18.746353f, 18.922657f, 19.117487f, 19.350685f, 19.64207f, 19.952137f,
|
||||
20.164913f, 20.345781f, 20.569134f, 20.88284f, 21.12133f, 21.30459f,
|
||||
21.505253f, 21.792645f, 22.038572f, 22.204426f, 22.37289f, 22.626648f,
|
||||
22.926834f, 23.143423f, 23.343302f, 23.596668f, 23.931936f, 24.209232f,
|
||||
24.431519f, 24.619913f, 24.79011f, 24.857473f, 24.83419f, 24.777927f,
|
||||
20.16656f, 20.344206f, 20.540766f, 20.775532f, 21.067804f, 21.377607f,
|
||||
21.589132f, 21.768297f, 21.99003f, 22.302366f, 22.538124f, 22.719105f,
|
||||
22.920494f, 23.214176f, 23.472767f, 23.653934f, 23.83589f, 24.096842f,
|
||||
24.394371f, 24.600555f, 24.786541f, 25.026773f, 25.353731f, 25.62813f,
|
||||
25.850672f, 26.04014f, 26.210072f, 26.277063f, 26.253906f, 26.197956f,
|
||||
22.363024f, 22.54125f, 22.738552f, 22.973991f, 23.266647f, 23.57634f,
|
||||
23.787327f, 23.96576f, 24.186796f, 24.498543f, 24.733124f, 24.913122f,
|
||||
25.114826f, 25.411213f, 25.675262f, 25.863028f, 26.050789f, 26.314838f,
|
||||
26.611223f, 26.812925f, 26.992926f, 27.227505f, 27.550882f, 27.824034f,
|
||||
28.046684f, 28.236614f, 28.406433f, 28.473265f, 28.450163f, 28.394344f,
|
||||
24.429443f, 24.60767f, 24.80497f, 25.04041f, 25.333065f, 25.642756f,
|
||||
25.853743f, 26.032173f, 26.25321f, 26.564959f, 26.79954f, 26.97954f,
|
||||
27.181242f, 27.47763f, 27.74168f, 27.929441f, 28.117207f, 28.381254f,
|
||||
28.677637f, 28.879343f, 29.059345f, 29.293922f, 29.617298f, 29.890451f,
|
||||
30.113104f, 30.303034f, 30.472853f, 30.539684f, 30.516582f, 30.460762f,
|
||||
26.f, 26.178228f, 26.375526f, 26.61097f, 26.903624f, 27.213314f,
|
||||
27.424305f, 27.602734f, 27.823772f, 28.135519f, 28.3701f, 28.550098f,
|
||||
28.7518f, 29.04819f, 29.312237f, 29.5f, 29.687763f, 29.951813f,
|
||||
30.2482f, 30.449903f, 30.629902f, 30.864483f, 31.187859f, 31.461012f,
|
||||
31.683659f, 31.873592f, 32.043407f, 32.11024f, 32.087135f, 32.03132f,
|
||||
27.570559f, 27.748787f, 27.946087f, 28.181528f, 28.474184f, 28.783876f,
|
||||
28.994865f, 29.173294f, 29.39433f, 29.70608f, 29.940659f, 30.120655f,
|
||||
30.32236f, 30.618746f, 30.882797f, 31.070557f, 31.25832f, 31.522371f,
|
||||
31.818754f, 32.02046f, 32.20046f, 32.43504f, 32.758415f, 33.031567f,
|
||||
33.25422f, 33.44415f, 33.613964f, 33.680794f, 33.657696f, 33.60188f,
|
||||
29.636976f, 29.815207f, 30.0125f, 30.247944f, 30.5406f, 30.85029f,
|
||||
31.061283f, 31.239712f, 31.46075f, 31.7725f, 32.00708f, 32.187077f,
|
||||
32.38878f, 32.685165f, 32.949215f, 33.13698f, 33.32474f, 33.58879f,
|
||||
33.885178f, 34.086884f, 34.26688f, 34.501457f, 34.824837f, 35.09799f,
|
||||
35.320637f, 35.510574f, 35.68039f, 35.747215f, 35.724117f, 35.6683f,
|
||||
31.83344f, 32.011665f, 32.20897f, 32.444412f, 32.73707f, 33.046757f,
|
||||
33.257744f, 33.436176f, 33.657207f, 33.96896f, 34.203537f, 34.383537f,
|
||||
34.58524f, 34.88163f, 35.145676f, 35.33344f, 35.521206f, 35.785255f,
|
||||
36.081642f, 36.28334f, 36.46334f, 36.69792f, 37.021297f, 37.294453f,
|
||||
37.517097f, 37.707027f, 37.876846f, 37.94368f, 37.920578f, 37.864758f,
|
||||
33.253647f, 33.431873f, 33.62917f, 33.864613f, 34.15727f, 34.466957f,
|
||||
34.677948f, 34.856377f, 35.077415f, 35.38916f, 35.623745f, 35.803745f,
|
||||
36.005447f, 36.301834f, 36.565884f, 36.753647f, 36.941406f, 37.205456f,
|
||||
37.50184f, 37.703545f, 37.883545f, 38.118122f, 38.4415f, 38.714653f,
|
||||
38.9373f, 39.127235f, 39.297054f, 39.363884f, 39.340782f, 39.28496f,
|
||||
34.464783f, 34.64301f, 34.840305f, 35.075752f, 35.368404f, 35.6781f,
|
||||
35.889088f, 36.067516f, 36.28855f, 36.6003f, 36.834885f, 37.014877f,
|
||||
37.216583f, 37.51297f, 37.77702f, 37.964783f, 38.152546f, 38.416595f,
|
||||
38.71298f, 38.914684f, 39.094685f, 39.32926f, 39.652645f, 39.925793f,
|
||||
40.14844f, 40.338375f, 40.508194f, 40.575024f, 40.55192f, 40.496105f,
|
||||
36.058067f, 36.23629f, 36.43359f, 36.669033f, 36.961685f, 37.271378f,
|
||||
37.48237f, 37.6608f, 37.881836f, 38.19359f, 38.42817f, 38.608162f,
|
||||
38.809868f, 39.10625f, 39.3703f, 39.558064f, 39.74583f, 40.00988f,
|
||||
40.306267f, 40.50797f, 40.68797f, 40.92255f, 41.245926f, 41.519077f,
|
||||
41.741722f, 41.931652f, 42.101475f, 42.168304f, 42.145203f, 42.089386f,
|
||||
38.315002f, 38.493233f, 38.690533f, 38.925976f, 39.218628f, 39.52832f,
|
||||
39.739307f, 39.917736f, 40.138775f, 40.45052f, 40.685104f, 40.865097f,
|
||||
41.066803f, 41.36319f, 41.627243f, 41.815002f, 42.002766f, 42.26682f,
|
||||
42.5632f, 42.764908f, 42.944904f, 43.179485f, 43.50286f, 43.776016f,
|
||||
43.998665f, 44.188595f, 44.358418f, 44.425247f, 44.402145f, 44.34633f,
|
||||
40.22708f, 40.40531f, 40.602608f, 40.83805f, 41.130707f, 41.440395f,
|
||||
41.651382f, 41.82982f, 42.050854f, 42.3626f, 42.597183f, 42.77718f,
|
||||
42.97888f, 43.27527f, 43.53932f, 43.72708f, 43.914845f, 44.178894f,
|
||||
44.47528f, 44.676983f, 44.856983f, 45.09156f, 45.41494f, 45.68809f,
|
||||
45.91074f, 46.100674f, 46.270493f, 46.337322f, 46.31422f, 46.2584f,
|
||||
41.785618f, 41.963844f, 42.161144f, 42.396584f, 42.68924f, 42.998936f,
|
||||
43.209923f, 43.388355f, 43.609394f, 43.921143f, 44.15572f, 44.335716f,
|
||||
44.53742f, 44.833805f, 45.09786f, 45.285614f, 45.473377f, 45.737427f,
|
||||
46.033817f, 46.235523f, 46.415524f, 46.650105f, 46.973476f, 47.24663f,
|
||||
47.469276f, 47.65921f, 47.82903f, 47.895855f, 47.872753f, 47.81694f,
|
||||
43.11514f, 43.293365f, 43.490665f, 43.726105f, 44.018764f, 44.328457f,
|
||||
44.539444f, 44.717873f, 44.93891f, 45.25066f, 45.48524f, 45.665237f,
|
||||
45.86694f, 46.163326f, 46.427376f, 46.615143f, 46.802902f, 47.066956f,
|
||||
47.363342f, 47.56505f, 47.74505f, 47.979626f, 48.302998f, 48.576153f,
|
||||
48.798798f, 48.98873f, 49.158546f, 49.225376f, 49.202282f, 49.146458f,
|
||||
44.303867f, 44.482094f, 44.679394f, 44.914833f, 45.207493f, 45.51718f,
|
||||
45.72817f, 45.9066f, 46.12764f, 46.439384f, 46.673965f, 46.853966f,
|
||||
47.055668f, 47.352055f, 47.6161f, 47.803867f, 47.99163f, 48.25568f,
|
||||
48.552063f, 48.75377f, 48.933773f, 49.16835f, 49.491726f, 49.764877f,
|
||||
49.987526f, 50.17746f, 50.347275f, 50.4141f, 50.391006f, 50.335186f,
|
||||
44.771675f, 44.949905f, 45.1472f, 45.382645f, 45.6753f, 45.98499f,
|
||||
46.195976f, 46.374413f, 46.595448f, 46.907196f, 47.141773f, 47.321774f,
|
||||
47.523476f, 47.819862f, 48.08391f, 48.27168f, 48.459446f, 48.72349f,
|
||||
49.019882f, 49.22158f, 49.401585f, 49.63616f, 49.959538f, 50.232693f,
|
||||
50.455338f, 50.64527f, 50.81509f, 50.88192f, 50.858818f, 50.803f,
|
||||
44.609966f, 44.788193f, 44.985493f, 45.220936f, 45.51359f, 45.82328f,
|
||||
46.03427f, 46.2127f, 46.433743f, 46.74549f, 46.98007f, 47.160065f,
|
||||
47.36177f, 47.658157f, 47.922207f, 48.10997f, 48.297733f, 48.561783f,
|
||||
48.858166f, 49.059875f, 49.239872f, 49.47445f, 49.79783f, 50.07098f,
|
||||
50.293625f, 50.48356f, 50.653378f, 50.720203f, 50.6971f, 50.64128f,
|
||||
44.219246f, 44.397472f, 44.594772f, 44.83021f, 45.122868f, 45.43256f,
|
||||
45.643543f, 45.82198f, 46.04302f, 46.354763f, 46.589344f, 46.76934f,
|
||||
46.971046f, 47.267433f, 47.531483f, 47.719242f, 47.907005f, 48.17105f,
|
||||
48.467438f, 48.66914f, 48.849144f, 49.08372f, 49.4071f, 49.680256f,
|
||||
49.902905f, 50.092834f, 50.262653f, 50.329483f, 50.30638f, 50.25057f});
|
||||
|
||||
auto size = NDArrayFactory::create<int>({30, 30});
|
||||
nd4j::ops::resize_bicubic op;
|
||||
|
@ -656,64 +656,63 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test1) {
|
|||
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test2) {
|
||||
|
||||
NDArray input = NDArrayFactory::create<double>('c', {2, 5, 4, 3});
|
||||
NDArray expected = NDArrayFactory::create<double>('c', {2, 10, 8, 3}, {
|
||||
1. , 2. ,3. ,2.21875, 3.21875, 4.21875, 4. , 5. , 6. ,5.5,
|
||||
6.5, 7.5, 7., 8., 9. ,8.78125, 9.78125, 10.78125, 10., 11. ,
|
||||
12., 10.28125 , 11.28125 ,12.28125, 5.875, 6.875, 7.875, 7.09375, 8.09375 ,9.09375,
|
||||
8.875, 9.875, 10.875, 10.375, 11.375, 12.375 ,11.875 ,12.875 , 13.875, 13.65625,
|
||||
14.65625, 15.65625, 14.875 ,15.875 ,16.875 , 15.15625, 16.15625, 17.15625, 13., 14.,
|
||||
15. ,14.21875, 15.21875, 16.21875, 16., 17., 18. ,17.5 ,18.5 , 19.5,
|
||||
19., 20., 21., 20.78125 ,21.78125 ,22.78125, 22., 23. , 24. , 22.28125,
|
||||
23.28125 ,24.28125 ,19. , 20., 21., 20.21875, 21.21875, 22.21875 ,22. ,23.,
|
||||
24. , 23.5, 24.5, 25.5, 25. ,26. ,27., 26.78125 , 27.78125, 28.78125,
|
||||
28., 29. ,30. ,28.28125, 29.28125, 30.28125, 25., 26., 27. ,26.21875,
|
||||
27.21875, 28.21875, 28., 29., 30., 29.5 ,30.5 ,31.5 , 31., 32.,
|
||||
33., 32.78125, 33.78125 ,34.78125 ,34., 35., 36., 34.28125, 35.28125, 36.28125,
|
||||
31. ,32., 33. , 32.21875, 33.21875, 34.21875, 34. ,35. ,36., 35.5,
|
||||
36.5 , 37.5 , 37., 38. ,39. ,38.78125, 39.78125, 40.78125, 40., 41.,
|
||||
42. ,40.28125 ,41.28125, 42.28125, 37. , 38., 39., 38.21875 ,39.21875 ,40.21875,
|
||||
40. , 41. , 42. , 41.5, 42.5, 43.5 ,43., 44., 45., 44.78125,
|
||||
45.78125, 46.78125 ,46. ,47. , 48. , 46.28125 , 47.28125, 48.28125, 44.125 ,45.125,
|
||||
46.125, 45.34375, 46.34375, 47.34375, 47.125, 48.125 ,49.125 ,48.625, 49.625 , 50.625,
|
||||
50.125 , 51.125, 52.125 ,51.90625 ,52.90625, 53.90625, 53.125, 54.125, 55.125, 53.40625,
|
||||
54.40625 ,55.40625, 49. ,50. , 51. ,50.21875, 51.21875 ,52.21875 ,52. ,53.,
|
||||
54. ,53.5 , 54.5, 55.5 ,55. ,56. ,57. ,56.78125 ,57.78125, 58.78125,
|
||||
58. ,59. ,60. ,58.28125 ,59.28125 ,60.28125, 50.125, 51.125 ,52.125 ,51.34375,
|
||||
52.34375 ,53.34375 ,53.125, 54.125, 55.125 ,54.625 ,55.625 ,56.625 ,56.125 ,57.125,
|
||||
58.125, 57.90625 ,58.90625 ,59.90625 ,59.125 ,60.125 ,61.125, 59.40625, 60.40625 ,61.40625,
|
||||
61. ,62. ,63. ,62.21875, 63.21875, 64.21875 ,64. ,65. ,66. ,65.5 ,
|
||||
66.5, 67.5, 67. ,68. ,69. ,68.78125 ,69.78125 ,70.78125 ,70., 71. ,
|
||||
72. ,70.28125 ,71.28125 ,72.28125 ,65.875 ,66.875, 67.875 ,67.09375 ,68.09375 ,69.09375,
|
||||
68.875 ,69.875 ,70.875, 70.375 ,71.375 ,72.375 ,71.875 ,72.875 ,73.875 ,73.65625,
|
||||
74.65625 ,75.65625 ,74.875 ,75.875, 76.875 ,75.15625 ,76.15625,
|
||||
77.15625 ,73. ,74. ,75., 74.21875 ,75.21875 ,76.21875,
|
||||
76. ,77. ,78. ,77.5 ,78.5 ,79.5 ,79.,
|
||||
80. ,81. ,80.78125 ,81.78125, 82.78125 ,82. ,83.,
|
||||
84. ,82.28125 ,83.28125 ,84.28125, 79. ,80. ,81.,
|
||||
80.21875 ,81.21875 ,82.21875 ,82., 83. ,84. ,83.5,
|
||||
84.5 ,85.5 ,85. ,86., 87. ,86.78125 ,87.78125,
|
||||
88.78125 ,88. ,89. ,90., 88.28125 ,89.28125 ,90.28125,
|
||||
85. ,86. ,87. ,86.21875, 87.21875 ,88.21875 ,88.,
|
||||
89. ,90. ,89.5 ,90.5, 91.5 ,91. ,92.,
|
||||
93. ,92.78125 ,93.78125 ,94.78125, 94. ,95. ,96.,
|
||||
94.28125 ,95.28125 ,96.28125 ,91., 92. ,93. ,92.21875,
|
||||
93.21875 ,94.21875 ,94. ,95., 96. ,95.5 ,96.5,
|
||||
97.5 ,97. ,98. ,99., 98.78125 ,99.78125 ,100.78125,
|
||||
100. ,101. ,102. ,100.28125, 101.28125 ,102.28125, 97.,
|
||||
98. ,99. ,98.21875 ,99.21875, 100.21875 ,100., 101.,
|
||||
102. ,101.5 ,102.5 ,103.5, 103. ,104., 105.,
|
||||
104.78125 ,105.78125 ,106.78125 ,106., 107. ,108., 106.28125,
|
||||
107.28125 ,108.28125 ,104.125 ,105.125, 106.125 ,105.34375, 106.34375,
|
||||
107.34375 ,107.125 ,108.125 ,109.125, 108.625 ,109.625, 110.625,
|
||||
110.125 ,111.125 ,112.125 ,111.90625, 112.90625 ,113.90625, 113.125,
|
||||
114.125 ,115.125 ,113.40625 ,114.40625, 115.40625 ,109., 110.,
|
||||
111. ,110.21875 ,111.21875 ,112.21875, 112., 113., 114.,
|
||||
113.5 ,114.5 ,115.5 ,115., 116., 117., 116.78125,
|
||||
117.78125 ,118.78125 ,118. ,119., 120., 118.28125, 119.28125,
|
||||
120.28125 ,110.125 ,111.125 ,112.125, 111.34375, 112.34375, 113.34375,
|
||||
113.125 ,114.125 ,115.125 ,114.625, 115.625, 116.625, 116.125,
|
||||
117.125 ,118.125 ,117.90625, 118.90625, 119.90625, 119.125, 120.125,
|
||||
121.125 ,119.40625 ,120.40625, 121.40625}); //input = 1.f;
|
||||
NDArray expected = NDArrayFactory::create<float>('c', {2, 10, 8, 3}, {
|
||||
1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 4.218750f, 4.000000f, 5.000000f, 6.000000f,
|
||||
5.500000f, 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f,
|
||||
10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, 6.875000f, 7.875000f,
|
||||
7.093750f, 8.093750f, 9.093750f, 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 12.375000f,
|
||||
11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f,
|
||||
15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f,
|
||||
16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, 19.000000f, 20.000000f, 21.000000f,
|
||||
20.781250f, 21.781250f, 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f,
|
||||
19.000000f, 20.000000f, 21.000000f, 20.218750f, 21.218750f, 22.218750f, 22.000000f, 23.000000f, 24.000000f,
|
||||
23.500000f, 24.500000f, 25.500000f, 25.000000f, 26.000000f, 27.000000f, 26.781250f, 27.781250f, 28.781250f,
|
||||
28.000000f, 29.000000f, 30.000000f, 28.281250f, 29.281250f, 30.281250f, 25.000000f, 26.000000f, 27.000000f,
|
||||
26.218750f, 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f,
|
||||
31.000000f, 32.000000f, 33.000000f, 32.781250f, 33.781250f, 34.781250f, 34.000000f, 35.000000f, 36.000000f,
|
||||
34.281250f, 35.281250f, 36.281250f, 31.000000f, 32.000000f, 33.000000f, 32.218750f, 33.218750f, 34.218750f,
|
||||
34.000000f, 35.000000f, 36.000000f, 35.500000f, 36.500000f, 37.500000f, 37.000000f, 38.000000f, 39.000000f,
|
||||
38.781250f, 39.781250f, 40.781250f, 40.000000f, 41.000000f, 42.000000f, 40.281250f, 41.281250f, 42.281250f,
|
||||
37.000000f, 38.000000f, 39.000000f, 38.218750f, 39.218750f, 40.218750f, 40.000000f, 41.000000f, 42.000000f,
|
||||
41.500000f, 42.500000f, 43.500000f, 43.000000f, 44.000000f, 45.000000f, 44.781250f, 45.781250f, 46.781250f,
|
||||
46.000000f, 47.000000f, 48.000000f, 46.281250f, 47.281250f, 48.281250f, 44.125000f, 45.125000f, 46.125000f,
|
||||
45.343750f, 46.343750f, 47.343750f, 47.125000f, 48.125000f, 49.125000f, 48.625000f, 49.625000f, 50.625000f,
|
||||
50.125000f, 51.125000f, 52.125000f, 51.906250f, 52.906250f, 53.906250f, 53.125000f, 54.125000f, 55.125000f,
|
||||
53.406250f, 54.406250f, 55.406250f, 49.000000f, 50.000000f, 51.000000f, 50.218750f, 51.218750f, 52.218750f,
|
||||
52.000000f, 53.000000f, 54.000000f, 53.500000f, 54.500000f, 55.500000f, 55.000000f, 56.000000f, 57.000000f,
|
||||
56.781250f, 57.781250f, 58.781250f, 58.000000f, 59.000000f, 60.000000f, 58.281250f, 59.281250f, 60.281250f,
|
||||
50.125000f, 51.125000f, 52.125000f, 51.343750f, 52.343750f, 53.343750f, 53.125000f, 54.125000f, 55.125000f,
|
||||
54.625000f, 55.625000f, 56.625000f, 56.125000f, 57.125000f, 58.125000f, 57.906250f, 58.906250f, 59.906250f,
|
||||
59.125000f, 60.125000f, 61.125000f, 59.406250f, 60.406250f, 61.406250f, 61.000000f, 62.000000f, 63.000000f,
|
||||
62.218750f, 63.218750f, 64.218750f, 64.000000f, 65.000000f, 66.000000f, 65.500000f, 66.500000f, 67.500000f,
|
||||
67.000000f, 68.000000f, 69.000000f, 68.781250f, 69.781250f, 70.781250f, 70.000000f, 71.000000f, 72.000000f,
|
||||
70.281250f, 71.281250f, 72.281250f, 65.875000f, 66.875000f, 67.875000f, 67.093750f, 68.093750f, 69.093750f,
|
||||
68.875000f, 69.875000f, 70.875000f, 70.375000f, 71.375000f, 72.375000f, 71.875000f, 72.875000f, 73.875000f,
|
||||
73.656250f, 74.656250f, 75.656250f, 74.875000f, 75.875000f, 76.875000f, 75.156250f, 76.156250f, 77.156250f,
|
||||
73.000000f, 74.000000f, 75.000000f, 74.218750f, 75.218750f, 76.218750f, 76.000000f, 77.000000f, 78.000000f,
|
||||
77.500000f, 78.500000f, 79.500000f, 79.000000f, 80.000000f, 81.000000f, 80.781250f, 81.781250f, 82.781250f,
|
||||
82.000000f, 83.000000f, 84.000000f, 82.281250f, 83.281250f, 84.281250f, 79.000000f, 80.000000f, 81.000000f,
|
||||
80.218750f, 81.218750f, 82.218750f, 82.000000f, 83.000000f, 84.000000f, 83.500000f, 84.500000f, 85.500000f,
|
||||
85.000000f, 86.000000f, 87.000000f, 86.781250f, 87.781250f, 88.781250f, 88.000000f, 89.000000f, 90.000000f,
|
||||
88.281250f, 89.281250f, 90.281250f, 85.000000f, 86.000000f, 87.000000f, 86.218750f, 87.218750f, 88.218750f,
|
||||
88.000000f, 89.000000f, 90.000000f, 89.500000f, 90.500000f, 91.500000f, 91.000000f, 92.000000f, 93.000000f,
|
||||
92.781250f, 93.781250f, 94.781250f, 94.000000f, 95.000000f, 96.000000f, 94.281250f, 95.281250f, 96.281250f,
|
||||
91.000000f, 92.000000f, 93.000000f, 92.218750f, 93.218750f, 94.218750f, 94.000000f, 95.000000f, 96.000000f,
|
||||
95.500000f, 96.500000f, 97.500000f, 97.000000f, 98.000000f, 99.000000f, 98.781250f, 99.781250f, 100.781250f,
|
||||
100.000000f, 101.000000f, 102.000000f, 100.281250f, 101.281250f, 102.281250f, 97.000000f, 98.000000f,
|
||||
99.000000f, 98.218750f, 99.218750f, 100.218750f, 100.000000f, 101.000000f, 102.000000f, 101.500000f,
|
||||
102.500000f, 103.500000f, 103.000000f, 104.000000f, 105.000000f, 104.781250f, 105.781250f, 106.781250f,
|
||||
106.000000f, 107.000000f, 108.000000f, 106.281250f, 107.281250f, 108.281250f, 104.125000f, 105.125000f,
|
||||
106.125000f, 105.343750f, 106.343750f, 107.343750f, 107.125000f, 108.125000f, 109.125000f, 108.625000f,
|
||||
109.625000f, 110.625000f, 110.125000f, 111.125000f, 112.125000f, 111.906250f, 112.906250f, 113.906250f,
|
||||
113.125000f, 114.125000f, 115.125000f, 113.406250f, 114.406250f, 115.406250f, 109.000000f, 110.000000f,
|
||||
111.000000f, 110.218750f, 111.218750f, 112.218750f, 112.000000f, 113.000000f, 114.000000f, 113.500000f,
|
||||
114.500000f, 115.500000f, 115.000000f, 116.000000f, 117.000000f, 116.781250f, 117.781250f, 118.781250f,
|
||||
118.000000f, 119.000000f, 120.000000f, 118.281250f, 119.281250f, 120.281250f, 110.125000f, 111.125000f,
|
||||
112.125000f, 111.343750f, 112.343750f, 113.343750f, 113.125000f, 114.125000f, 115.125000f, 114.625000f,
|
||||
115.625000f, 116.625000f, 116.125000f, 117.125000f, 118.125000f, 117.906250f, 118.906250f, 119.906250f,
|
||||
119.125000f, 120.125000f, 121.125000f, 119.406250f, 120.406250f, 121.406250f
|
||||
}); //input = 1.f;
|
||||
input.linspace(1);
|
||||
auto size = NDArrayFactory::create<int>({10, 8});
|
||||
nd4j::ops::resize_bicubic op;
|
||||
|
@ -733,48 +732,23 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test2) {
|
|||
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test3) {
|
||||
|
||||
NDArray input = NDArrayFactory::create<double>('c', {1, 3, 3, 4});
|
||||
NDArray expected = NDArrayFactory::create<double>('c', {1, 6, 6, 4}, {
|
||||
1. ,2. ,3. ,4.,
|
||||
2.625 ,3.625 ,4.625 ,5.625,
|
||||
5. ,6. ,7. ,8.,
|
||||
7.375 ,8.375 ,9.375, 10.375,
|
||||
9. ,10. ,11. ,12.,
|
||||
9.375 ,10.375 ,11.375 ,12.375,
|
||||
|
||||
5.875 ,6.875 ,7.875 , 8.875 ,
|
||||
7.5 ,8.5 ,9.5 , 10.5 ,
|
||||
9.875 ,10.875 ,11.875, 12.875,
|
||||
12.25 ,13.25 ,14.25 , 15.25 ,
|
||||
13.875 ,14.875 ,15.875, 16.875,
|
||||
14.25 ,15.25 ,16.25 , 17.25 ,
|
||||
|
||||
13. ,14. ,15. ,16.,
|
||||
14.625 ,15.625 ,16.625 ,17.625,
|
||||
17. ,18. ,19. ,20.,
|
||||
19.375 ,20.375 ,21.375 ,22.375,
|
||||
21. ,22. ,23. ,24.,
|
||||
21.375 ,22.375 ,23.375 ,24.375,
|
||||
|
||||
20.125 ,21.125 ,22.125 ,23.125,
|
||||
21.75 ,22.75 ,23.75 ,24.75,
|
||||
24.125 ,25.125 ,26.125 ,27.125,
|
||||
26.5 ,27.5 ,28.5 ,29.5,
|
||||
28.125 ,29.125 ,30.125 ,31.125,
|
||||
28.5 ,29.5 ,30.5 ,31.5,
|
||||
|
||||
25. , 26. , 27. , 28.,
|
||||
26.625 ,27.625 ,28.625 ,29.625,
|
||||
29. ,30. ,31. ,32.,
|
||||
31.375 ,32.375 ,33.375 ,34.375,
|
||||
33. ,34. ,35. ,36.,
|
||||
33.375 ,34.375 ,35.375 ,36.375,
|
||||
|
||||
26.125, 27.125, 28.125, 29.125,
|
||||
27.75 ,28.75 ,29.75 ,30.75,
|
||||
30.125 ,31.125 ,32.125 ,33.125,
|
||||
32.5 ,33.5 ,34.5 ,35.5,
|
||||
34.125 ,35.125 ,36.125 ,37.125,
|
||||
34.5 ,35.5 ,36.5 ,37.5
|
||||
NDArray expected = NDArrayFactory::create<float>('c', {1, 6, 6, 4}, {
|
||||
1.000000f, 2.000000f, 3.000000f, 4.000000f, 2.625000f, 3.625000f, 4.625000f, 5.625000f, 5.000000f,
|
||||
6.000000f, 7.000000f, 8.000000f, 7.375000f, 8.375000f, 9.375000f, 10.375000f, 9.000000f, 10.000000f,
|
||||
11.000000f, 12.000000f, 9.375000f, 10.375000f, 11.375000f, 12.375000f, 5.875000f, 6.875000f, 7.875000f,
|
||||
8.875000f, 7.500000f, 8.500000f, 9.500000f, 10.500000f, 9.875000f, 10.875000f, 11.875000f, 12.875000f,
|
||||
12.250000f, 13.250000f, 14.250000f, 15.250000f, 13.875000f, 14.875000f, 15.875000f, 16.875000f, 14.250000f,
|
||||
15.250000f, 16.250000f, 17.250000f, 13.000000f, 14.000000f, 15.000000f, 16.000000f, 14.625000f, 15.625000f,
|
||||
16.625000f, 17.625000f, 17.000000f, 18.000000f, 19.000000f, 20.000000f, 19.375000f, 20.375000f, 21.375000f,
|
||||
22.375000f, 21.000000f, 22.000000f, 23.000000f, 24.000000f, 21.375000f, 22.375000f, 23.375000f, 24.375000f,
|
||||
20.125000f, 21.125000f, 22.125000f, 23.125000f, 21.750000f, 22.750000f, 23.750000f, 24.750000f, 24.125000f,
|
||||
25.125000f, 26.125000f, 27.125000f, 26.500000f, 27.500000f, 28.500000f, 29.500000f, 28.125000f, 29.125000f,
|
||||
30.125000f, 31.125000f, 28.500000f, 29.500000f, 30.500000f, 31.500000f, 25.000000f, 26.000000f, 27.000000f,
|
||||
28.000000f, 26.625000f, 27.625000f, 28.625000f, 29.625000f, 29.000000f, 30.000000f, 31.000000f, 32.000000f,
|
||||
31.375000f, 32.375000f, 33.375000f, 34.375000f, 33.000000f, 34.000000f, 35.000000f, 36.000000f, 33.375000f,
|
||||
34.375000f, 35.375000f, 36.375000f, 26.125000f, 27.125000f, 28.125000f, 29.125000f, 27.750000f, 28.750000f,
|
||||
29.750000f, 30.750000f, 30.125000f, 31.125000f, 32.125000f, 33.125000f, 32.500000f, 33.500000f, 34.500000f,
|
||||
35.500000f, 34.125000f, 35.125000f, 36.125000f, 37.125000f, 34.500000f, 35.500000f, 36.500000f, 37.500000f
|
||||
});
|
||||
input.linspace(1);
|
||||
auto size = NDArrayFactory::create<int>({6, 6});
|
||||
|
@ -795,60 +769,24 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test3) {
|
|||
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test4) {
|
||||
|
||||
NDArray input = NDArrayFactory::create<double>('c', {1, 3, 4, 3});
|
||||
NDArray expected = NDArrayFactory::create<double>('c', {1, 6, 8, 3}, {
|
||||
1. , 2. , 3. ,
|
||||
2.21875 ,3.21875 ,4.21875,
|
||||
4. ,5. ,6. ,
|
||||
5.5 ,6.5 ,7.5 ,
|
||||
7. ,8. ,9. ,
|
||||
8.78125 ,9.78125, 10.78125,
|
||||
10. ,11., 12. ,
|
||||
10.28125 ,11.28125, 12.28125,
|
||||
|
||||
5.875 , 6.875 , 7.875 ,
|
||||
7.09375 , 8.09375 , 9.09375,
|
||||
8.875 , 9.875 ,10.875 ,
|
||||
10.375 ,11.375 ,12.375 ,
|
||||
11.875 ,12.875 ,13.875 ,
|
||||
13.65625 ,14.65625 ,15.65625,
|
||||
14.875 ,15.875 ,16.875 ,
|
||||
15.15625 ,16.15625 ,17.15625,
|
||||
|
||||
13., 14., 15.,
|
||||
14.21875 ,15.21875 ,16.21875,
|
||||
16. ,17. ,18. ,
|
||||
17.5 ,18.5 ,19.5 ,
|
||||
19. ,20. ,21. ,
|
||||
20.78125 ,21.78125 ,22.78125,
|
||||
22. ,23. ,24. ,
|
||||
22.28125 ,23.28125 ,24.28125,
|
||||
|
||||
20.125 , 21.125 , 22.125,
|
||||
21.34375 ,22.34375 ,23.34375,
|
||||
23.125 ,24.125 ,25.125 ,
|
||||
24.625 ,25.625 ,26.625 ,
|
||||
26.125 ,27.125 ,28.125 ,
|
||||
27.90625 ,28.90625 ,29.90625,
|
||||
29.125 ,30.125 ,31.125 ,
|
||||
29.40625 ,30.40625 ,31.40625,
|
||||
|
||||
25. ,26. ,27. ,
|
||||
26.21875 ,27.21875 ,28.21875,
|
||||
28. ,29. ,30. ,
|
||||
29.5 ,30.5 ,31.5 ,
|
||||
31. ,32. ,33. ,
|
||||
32.78125 ,33.78125 ,34.78125,
|
||||
34. ,35. ,36. ,
|
||||
34.28125 ,35.28125 ,36.28125,
|
||||
|
||||
26.125 ,27.125 , 28.125 ,
|
||||
27.34375 ,28.34375 ,29.34375,
|
||||
29.125 ,30.125 ,31.125 ,
|
||||
30.625 ,31.625 ,32.625 ,
|
||||
32.125 ,33.125 ,34.125 ,
|
||||
33.90625 ,34.90625 ,35.90625,
|
||||
35.125 ,36.125 ,37.125 ,
|
||||
35.40625 ,36.40625 ,37.40625 });
|
||||
NDArray expected = NDArrayFactory::create<float>('c', {1, 6, 8, 3}, {
|
||||
1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 4.218750f, 4.000000f, 5.000000f, 6.000000f,
|
||||
5.500000f, 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f,
|
||||
10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, 6.875000f, 7.875000f,
|
||||
7.093750f, 8.093750f, 9.093750f, 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 12.375000f,
|
||||
11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f,
|
||||
15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f,
|
||||
16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, 19.000000f, 20.000000f, 21.000000f,
|
||||
20.781250f, 21.781250f, 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f,
|
||||
20.125000f, 21.125000f, 22.125000f, 21.343750f, 22.343750f, 23.343750f, 23.125000f, 24.125000f, 25.125000f,
|
||||
24.625000f, 25.625000f, 26.625000f, 26.125000f, 27.125000f, 28.125000f, 27.906250f, 28.906250f, 29.906250f,
|
||||
29.125000f, 30.125000f, 31.125000f, 29.406250f, 30.406250f, 31.406250f, 25.000000f, 26.000000f, 27.000000f,
|
||||
26.218750f, 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f,
|
||||
31.000000f, 32.000000f, 33.000000f, 32.781250f, 33.781250f, 34.781250f, 34.000000f, 35.000000f, 36.000000f,
|
||||
34.281250f, 35.281250f, 36.281250f, 26.125000f, 27.125000f, 28.125000f, 27.343750f, 28.343750f, 29.343750f,
|
||||
29.125000f, 30.125000f, 31.125000f, 30.625000f, 31.625000f, 32.625000f, 32.125000f, 33.125000f, 34.125000f,
|
||||
33.906250f, 34.906250f, 35.906250f, 35.125000f, 36.125000f, 37.125000f, 35.406250f, 36.406250f, 37.406250f
|
||||
});
|
||||
input.linspace(1);
|
||||
auto size = NDArrayFactory::create<int>({6, 8});
|
||||
nd4j::ops::resize_bicubic op;
|
||||
|
@ -868,32 +806,30 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test4) {
|
|||
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test5) {
|
||||
|
||||
NDArray input = NDArrayFactory::create<double>('c', {1, 4, 4, 3});
|
||||
NDArray expected = NDArrayFactory::create<double>('c', {1, 8, 8, 3}, {
|
||||
1. ,2. , 3. , 2.21875 , 3.21875 , 4.21875 , 4. , 5. ,
|
||||
6. ,5.5 , 6.5 , 7.5 , 7. , 8. , 9. , 8.78125 ,
|
||||
9.78125 ,10.78125 ,10. ,11. ,12. ,10.28125 ,11.28125 ,12.28125 ,
|
||||
5.875 ,6.875 , 7.875 , 7.09375 , 8.09375 , 9.09375 , 8.875 , 9.875 ,
|
||||
10.875 ,10.375 , 11.375 , 12.375 , 11.875 , 12.875 , 13.875 , 13.65625,
|
||||
14.65625 ,15.65625, 14.875 , 15.875 , 16.875 , 15.15625, 16.15625, 17.15625,
|
||||
13. ,14. , 15. , 14.21875, 15.21875, 16.21875, 16. , 17. ,
|
||||
18. ,17.5 , 18.5 , 19.5 , 19. , 20. , 21. , 20.78125,
|
||||
21.78125 ,22.78125, 22. , 23. , 24. , 22.28125, 23.28125, 24.28125,
|
||||
19. ,20. , 21. , 20.21875, 21.21875, 22.21875, 22. , 23. ,
|
||||
24. ,23.5 , 24.5 , 25.5 , 25. , 26. , 27. , 26.78125,
|
||||
27.78125 ,28.78125, 28. , 29. , 30. , 28.28125, 29.28125, 30.28125,
|
||||
25. ,26. , 27. , 26.21875, 27.21875, 28.21875, 28. , 29. ,
|
||||
30. ,29.5 , 30.5 , 31.5 , 31. , 32. , 33. , 32.78125,
|
||||
33.78125 ,34.78125, 34. , 35. , 36. , 34.28125, 35.28125, 36.28125,
|
||||
32.125 ,33.125 , 34.125 , 33.34375, 34.34375, 35.34375, 35.125 , 36.125 ,
|
||||
37.125 ,36.625 , 37.625 , 38.625 , 38.125 , 39.125 , 40.125 , 39.90625,
|
||||
40.90625 ,41.90625, 41.125 , 42.125 , 43.125 , 41.40625, 42.40625, 43.40625,
|
||||
37. ,38. , 39. , 38.21875, 39.21875, 40.21875, 40. , 41. ,
|
||||
42. ,41.5 , 42.5 , 43.5 , 43. , 44. , 45. , 44.78125,
|
||||
45.78125 ,46.78125, 46. , 47. , 48. , 46.28125, 47.28125, 48.28125,
|
||||
38.125 ,39.125 , 40.125 , 39.34375, 40.34375, 41.34375, 41.125 , 42.125 ,
|
||||
43.125 ,42.625 , 43.625 , 44.625 , 44.125 , 45.125 , 46.125 , 45.90625,
|
||||
46.90625 ,47.90625, 47.125 , 48.125 , 49.125 , 47.40625, 48.40625, 49.40625,
|
||||
});
|
||||
NDArray expected = NDArrayFactory::create<float>('c', {1, 8, 8, 3}, {
|
||||
1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 4.218750f, 4.000000f, 5.000000f, 6.000000f,
|
||||
5.500000f, 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f,
|
||||
10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, 6.875000f, 7.875000f,
|
||||
7.093750f, 8.093750f, 9.093750f, 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 12.375000f,
|
||||
11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f,
|
||||
15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f,
|
||||
16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, 19.000000f, 20.000000f, 21.000000f,
|
||||
20.781250f, 21.781250f, 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f,
|
||||
19.000000f, 20.000000f, 21.000000f, 20.218750f, 21.218750f, 22.218750f, 22.000000f, 23.000000f, 24.000000f,
|
||||
23.500000f, 24.500000f, 25.500000f, 25.000000f, 26.000000f, 27.000000f, 26.781250f, 27.781250f, 28.781250f,
|
||||
28.000000f, 29.000000f, 30.000000f, 28.281250f, 29.281250f, 30.281250f, 25.000000f, 26.000000f, 27.000000f,
|
||||
26.218750f, 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f,
|
||||
31.000000f, 32.000000f, 33.000000f, 32.781250f, 33.781250f, 34.781250f, 34.000000f, 35.000000f, 36.000000f,
|
||||
34.281250f, 35.281250f, 36.281250f, 32.125000f, 33.125000f, 34.125000f, 33.343750f, 34.343750f, 35.343750f,
|
||||
35.125000f, 36.125000f, 37.125000f, 36.625000f, 37.625000f, 38.625000f, 38.125000f, 39.125000f, 40.125000f,
|
||||
39.906250f, 40.906250f, 41.906250f, 41.125000f, 42.125000f, 43.125000f, 41.406250f, 42.406250f, 43.406250f,
|
||||
37.000000f, 38.000000f, 39.000000f, 38.218750f, 39.218750f, 40.218750f, 40.000000f, 41.000000f, 42.000000f,
|
||||
41.500000f, 42.500000f, 43.500000f, 43.000000f, 44.000000f, 45.000000f, 44.781250f, 45.781250f, 46.781250f,
|
||||
46.000000f, 47.000000f, 48.000000f, 46.281250f, 47.281250f, 48.281250f, 38.125000f, 39.125000f, 40.125000f,
|
||||
39.343750f, 40.343750f, 41.343750f, 41.125000f, 42.125000f, 43.125000f, 42.625000f, 43.625000f, 44.625000f,
|
||||
44.125000f, 45.125000f, 46.125000f, 45.906250f, 46.906250f, 47.906250f, 47.125000f, 48.125000f, 49.125000f,
|
||||
47.406250f, 48.406250f, 49.406250f,
|
||||
});
|
||||
input.linspace(1);
|
||||
auto size = NDArrayFactory::create<int>({8, 8});
|
||||
nd4j::ops::resize_bicubic op;
|
||||
|
@ -912,167 +848,118 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test5) {
|
|||
|
||||
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test6) {
|
||||
|
||||
NDArray input = NDArrayFactory::create<double>('c', {7, 7, 1}, {
|
||||
1, 2.1, 3.15, 4.2, 5.15, 6.1, 7,
|
||||
8, 9.1, 10., 11, 12.9, 13.1, 14,
|
||||
15, 16., 17., 18, 19, 20., 21,
|
||||
22, 23., 24., 25, 26, 27, 28,
|
||||
30, 31, 32, 33, 34., 35, 36,
|
||||
37, 38, 39, 40, 41., 42, 43,
|
||||
44, 45, 46, 47, 48., 49, 50
|
||||
NDArray input = NDArrayFactory::create<float>('c', {7, 7, 1}, {
|
||||
1.f, 2.1f, 3.15f, 4.2f, 5.15f, 6.1f, 7.f,
|
||||
8.f, 9.1f, 10.f, 11.f, 12.9f, 13.1f, 14.f,
|
||||
15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f,
|
||||
22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f,
|
||||
30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f,
|
||||
37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f,
|
||||
44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f
|
||||
});
|
||||
|
||||
NDArray expected = NDArrayFactory::create<double>('c', {30, 30, 1}, {
|
||||
1. ,1.1976162 ,1.4174359 ,1.6775769 ,1.9961575 ,2.3283265 ,
|
||||
2.550918 ,2.7360606 ,2.9655411 ,3.2929654 ,3.5441515 ,3.7380352 ,
|
||||
3.948995 ,4.248106 ,4.5073795 ,4.6843743 ,4.8572845 ,5.104302 ,
|
||||
5.3869915 ,5.581401 ,5.7539616 ,5.974285 ,6.272836 ,6.5204263 ,
|
||||
6.718899 ,6.8871036 ,7.039068 ,7.099216 ,7.0784245 ,7.0281887 ,
|
||||
2.247592 ,2.446947 ,2.6694887 ,2.9312382 ,3.248216 ,3.5745337 ,
|
||||
3.78931 ,3.9656973 ,4.186417 ,4.5046535 ,4.740569 ,4.9217057 ,
|
||||
5.133866 ,5.459533 ,5.7744613 ,6.0197873 ,6.254011 ,6.535633 ,
|
||||
6.8097296 ,6.9607787 ,7.0749416 ,7.241601 ,7.5094895 ,7.7499495 ,
|
||||
7.954571 ,8.131972 ,8.286526 ,8.346463 ,8.325745 ,8.275683 ,
|
||||
3.6286845 ,3.830573 ,4.0569587 ,4.3211575 ,4.6364856 ,4.9556503 ,
|
||||
5.160583 ,5.3258467 ,5.535462 ,5.84216 ,6.058749 ,6.223753 ,
|
||||
6.437597 ,6.797369 ,7.1836042 ,7.5164022 ,7.8290343 ,8.154773 ,
|
||||
8.417635 ,8.512958 ,8.5521 ,8.649708 ,8.87788 ,9.108794 ,
|
||||
9.320926 ,9.509781 ,9.667375 ,9.72694 ,9.706349 ,9.656599 ,
|
||||
5.276778 ,5.480438 ,5.709702 ,5.9754477 ,6.288551 ,6.6005697 ,
|
||||
6.796207 ,6.9511423 ,7.1503997 ,7.4461427 ,7.644651 ,7.794562 ,
|
||||
8.009684 ,8.400473 ,8.851847 ,9.26469 ,9.649218, 10.015648 ,
|
||||
10.268647 ,10.313368 ,10.2843275 ,10.319379 ,10.512033 ,10.734956 ,
|
||||
10.954604 ,11.154507 ,11.315369 ,11.374779 ,11.354242 ,11.304622 ,
|
||||
7.325373 ,7.5284843 ,7.757575 ,8.022221 ,8.331997 ,8.638187 ,
|
||||
8.827649 ,8.976217 ,9.168955 ,9.45726 ,9.6442375 ,9.784517 ,
|
||||
9.999621, 10.407702 ,10.896234, 11.355122, 11.781423, 12.172186 ,
|
||||
12.420712 ,12.4374485 ,12.370511 ,12.371386 ,12.545973 ,12.766424 ,
|
||||
12.992249 ,13.20012 ,13.364252 ,13.424109 ,13.40342 ,13.353425 ,
|
||||
9.493208 ,9.692467 ,9.9169445, 10.176801, 10.482199, 10.78547 ,
|
||||
10.974367 ,11.123442 ,11.31637 ,11.603645 ,11.790616 ,11.930889 ,
|
||||
12.144082 ,12.546447 ,13.024898 ,13.4723 ,13.889232 ,14.276275 ,
|
||||
14.528972 ,14.555555 ,14.50145 ,14.515459 ,14.700572 ,14.927055 ,
|
||||
15.156046 ,15.366046 ,15.532901 ,15.594008 ,15.5728855 ,15.521847 ,
|
||||
10.970133 ,11.163599 ,11.380694 ,11.633735 ,11.935032 ,12.238887 ,
|
||||
12.43254 ,12.588294 ,12.787534 ,13.079956 ,13.27752 ,13.426631 ,
|
||||
13.636713 ,14.013844 ,14.441672 ,14.827978 ,15.191209 ,15.549808 ,
|
||||
15.81343 ,15.881828 ,15.883522 ,15.950411 ,16.16933 ,16.40794 ,
|
||||
16.636436 ,16.842583 ,17.010887 ,17.07363 ,17.05194 ,16.999537 ,
|
||||
12.219155 ,12.406129 ,12.614796 ,12.860335 ,13.157928 ,13.464224 ,
|
||||
13.665207 ,13.830567 ,14.039036 ,14.339629 ,14.552863 ,14.715049 ,
|
||||
14.921564 ,15.264454 ,15.622843 ,15.924977 ,16.213829 ,16.532364 ,
|
||||
16.8099 ,16.934835 ,17.012146 ,17.150164 ,17.413412 ,17.666712 ,
|
||||
17.892765 ,18.09207 ,18.261044 ,18.325531 ,18.303238 ,18.249378 ,
|
||||
13.7663965 ,13.947391 ,14.148263 ,14.386917 ,14.681246 ,14.990087 ,
|
||||
15.198166 ,15.372728 ,15.590062 ,15.898583 ,16.126892 ,16.301655 ,
|
||||
16.50487 ,16.815214 ,17.107498 ,17.329458 ,17.547403 ,17.827654 ,
|
||||
18.118288 ,18.296928 ,18.4461 ,18.651634 ,18.956806 ,19.22382 ,
|
||||
19.447308 ,19.639887 ,19.809319 ,19.875397 ,19.852556 ,19.797365 ,
|
||||
15.9419365 ,16.118704 ,16.314133 ,16.547867 ,16.839561 ,17.14954 ,
|
||||
17.361883 ,17.542162 ,17.764957 ,18.078188 ,18.315733 ,18.498205 ,
|
||||
18.699116 ,18.988684 ,19.238989 ,19.410137 ,19.583265 ,19.839512 ,
|
||||
20.13878 ,20.35177 ,20.546844 ,20.795671 ,21.128067 ,21.404358 ,
|
||||
21.626736 ,21.8155 ,21.98561 ,22.052843 ,22.029604 ,21.973448 ,
|
||||
17.53522 ,17.71077 ,17.904636 ,18.13695 ,18.42784 ,18.738056 ,
|
||||
18.951529 ,19.133352 ,19.357613 ,19.672083 ,19.912102 ,20.096638 ,
|
||||
20.296894 ,20.580765 ,20.819603 ,20.976887 ,21.137802 ,21.387535 ,
|
||||
21.689209 ,21.911621 ,22.119276 ,22.37999 ,22.71991 ,22.998823 ,
|
||||
23.22097 ,23.40876 ,23.57911 ,23.646685 ,23.623325 ,23.566887 ,
|
||||
18.746353 ,18.922657 ,19.117487 ,19.350685 ,19.64207 ,19.952137 ,
|
||||
20.164913 ,20.345781 ,20.569134 ,20.88284 ,21.12133 ,21.30459 ,
|
||||
21.505253 ,21.792645 ,22.038572 ,22.204426 ,22.37289 ,22.626648 ,
|
||||
22.926834 ,23.143423 ,23.343302 ,23.596668 ,23.931936 ,24.209232 ,
|
||||
24.431519 ,24.619913 ,24.79011 ,24.857473 ,24.83419 ,24.777927 ,
|
||||
20.16656 ,20.344206 ,20.540766 ,20.775532 ,21.067804 ,21.377607 ,
|
||||
21.589132 ,21.768297 ,21.99003 ,22.302366 ,22.538124 ,22.719105 ,
|
||||
22.920494 ,23.214176 ,23.472767 ,23.653934 ,23.83589 ,24.096842 ,
|
||||
24.394371 ,24.600555 ,24.786541 ,25.026773 ,25.353731 ,25.62813 ,
|
||||
25.850672 ,26.04014 ,26.210072 ,26.277063 ,26.253906 ,26.197956 ,
|
||||
22.363024 ,22.54125 ,22.738552 ,22.973991 ,23.266647 ,23.57634 ,
|
||||
23.787327 ,23.96576 ,24.186796 ,24.498543 ,24.733124 ,24.913122 ,
|
||||
25.114826 ,25.411213 ,25.675262 ,25.863028 ,26.050789 ,26.314838 ,
|
||||
26.611223 ,26.812925 ,26.992926 ,27.227505 ,27.550882 ,27.824034 ,
|
||||
28.046684 ,28.236614 ,28.406433 ,28.473265 ,28.450163 ,28.394344 ,
|
||||
24.429443 ,24.60767 ,24.80497 ,25.04041 ,25.333065 ,25.642756 ,
|
||||
25.853743 ,26.032173 ,26.25321 ,26.564959 ,26.79954 ,26.97954 ,
|
||||
27.181242 ,27.47763 ,27.74168 ,27.929441 ,28.117207 ,28.381254 ,
|
||||
28.677637 ,28.879343 ,29.059345 ,29.293922 ,29.617298 ,29.890451 ,
|
||||
30.113104 ,30.303034 ,30.472853 ,30.539684 ,30.516582 ,30.460762 ,
|
||||
26. ,26.178228 ,26.375526 ,26.61097 ,26.903624 ,27.213314 ,
|
||||
27.424305 ,27.602734 ,27.823772 ,28.135519 ,28.3701 ,28.550098 ,
|
||||
28.7518 ,29.04819 ,29.312237 ,29.5 ,29.687763 ,29.951813 ,
|
||||
30.2482 ,30.449903 ,30.629902 ,30.864483 ,31.187859 ,31.461012 ,
|
||||
31.683659 ,31.873592 ,32.043407 ,32.11024 ,32.087135 ,32.03132 ,
|
||||
27.570559 ,27.748787 ,27.946087 ,28.181528 ,28.474184 ,28.783876 ,
|
||||
28.994865 ,29.173294 ,29.39433 ,29.70608 ,29.940659 ,30.120655 ,
|
||||
30.32236 ,30.618746 ,30.882797 ,31.070557 ,31.25832 ,31.522371 ,
|
||||
31.818754 ,32.02046 ,32.20046 ,32.43504 ,32.758415 ,33.031567 ,
|
||||
33.25422 ,33.44415 ,33.613964 ,33.680794 ,33.657696 ,33.60188 ,
|
||||
29.636976 ,29.815207 ,30.0125 ,30.247944 ,30.5406 ,30.85029 ,
|
||||
31.061283 ,31.239712 ,31.46075 ,31.7725 ,32.00708 ,32.187077 ,
|
||||
32.38878 ,32.685165 ,32.949215 ,33.13698 ,33.32474 ,33.58879 ,
|
||||
33.885178 ,34.086884 ,34.26688 ,34.501457 ,34.824837 ,35.09799 ,
|
||||
35.320637 ,35.510574 ,35.68039 ,35.747215 ,35.724117 ,35.6683 ,
|
||||
31.83344 ,32.011665 ,32.20897 ,32.444412 ,32.73707 ,33.046757 ,
|
||||
33.257744 ,33.436176 ,33.657207 ,33.96896 ,34.203537 ,34.383537 ,
|
||||
34.58524 ,34.88163 ,35.145676 ,35.33344 ,35.521206 ,35.785255 ,
|
||||
36.081642 ,36.28334 ,36.46334 ,36.69792 ,37.021297 ,37.294453 ,
|
||||
37.517097 ,37.707027 ,37.876846 ,37.94368 ,37.920578 ,37.864758 ,
|
||||
33.253647 ,33.431873 ,33.62917 ,33.864613 ,34.15727 ,34.466957 ,
|
||||
34.677948 ,34.856377 ,35.077415 ,35.38916 ,35.623745 ,35.803745 ,
|
||||
36.005447 ,36.301834 ,36.565884 ,36.753647 ,36.941406 ,37.205456 ,
|
||||
37.50184 ,37.703545 ,37.883545 ,38.118122 ,38.4415 ,38.714653 ,
|
||||
38.9373 ,39.127235 ,39.297054 ,39.363884 ,39.340782 ,39.28496 ,
|
||||
34.464783 ,34.64301 ,34.840305 ,35.075752 ,35.368404 ,35.6781 ,
|
||||
35.889088 ,36.067516 ,36.28855 ,36.6003 ,36.834885 ,37.014877 ,
|
||||
37.216583 ,37.51297 ,37.77702 ,37.964783 ,38.152546 ,38.416595 ,
|
||||
38.71298 ,38.914684 ,39.094685 ,39.32926 ,39.652645 ,39.925793 ,
|
||||
40.14844 ,40.338375 ,40.508194 ,40.575024 ,40.55192 ,40.496105 ,
|
||||
36.058067 ,36.23629 ,36.43359 ,36.669033 ,36.961685 ,37.271378 ,
|
||||
37.48237 ,37.6608 ,37.881836 ,38.19359 ,38.42817 ,38.608162 ,
|
||||
38.809868 ,39.10625 ,39.3703 ,39.558064 ,39.74583 ,40.00988 ,
|
||||
40.306267 ,40.50797 ,40.68797 ,40.92255 ,41.245926 ,41.519077 ,
|
||||
41.741722 ,41.931652 ,42.101475 ,42.168304 ,42.145203 ,42.089386 ,
|
||||
38.315002 ,38.493233 ,38.690533 ,38.925976 ,39.218628 ,39.52832 ,
|
||||
39.739307 ,39.917736 ,40.138775 ,40.45052 ,40.685104 ,40.865097 ,
|
||||
41.066803 ,41.36319 ,41.627243 ,41.815002 ,42.002766 ,42.26682 ,
|
||||
42.5632 ,42.764908 ,42.944904 ,43.179485 ,43.50286 ,43.776016 ,
|
||||
43.998665 ,44.188595 ,44.358418 ,44.425247 ,44.402145 ,44.34633 ,
|
||||
40.22708 ,40.40531 ,40.602608 ,40.83805 ,41.130707 ,41.440395 ,
|
||||
41.651382 ,41.82982 ,42.050854 ,42.3626 ,42.597183 ,42.77718 ,
|
||||
42.97888 ,43.27527 ,43.53932 ,43.72708 ,43.914845 ,44.178894 ,
|
||||
44.47528 ,44.676983 ,44.856983 ,45.09156 ,45.41494 ,45.68809 ,
|
||||
45.91074 ,46.100674 ,46.270493 ,46.337322 ,46.31422 ,46.2584 ,
|
||||
41.785618 ,41.963844 ,42.161144 ,42.396584 ,42.68924 ,42.998936 ,
|
||||
43.209923 ,43.388355 ,43.609394 ,43.921143 ,44.15572 ,44.335716 ,
|
||||
44.53742 ,44.833805 ,45.09786 ,45.285614 ,45.473377 ,45.737427 ,
|
||||
46.033817 ,46.235523 ,46.415524 ,46.650105 ,46.973476 ,47.24663 ,
|
||||
47.469276 ,47.65921 ,47.82903 ,47.895855 ,47.872753 ,47.81694 ,
|
||||
43.11514 ,43.293365 ,43.490665 ,43.726105 ,44.018764 ,44.328457 ,
|
||||
44.539444 ,44.717873 ,44.93891 ,45.25066 ,45.48524 ,45.665237 ,
|
||||
45.86694 ,46.163326 ,46.427376 ,46.615143 ,46.802902 ,47.066956 ,
|
||||
47.363342 ,47.56505 ,47.74505 ,47.979626 ,48.302998 ,48.576153 ,
|
||||
48.798798 ,48.98873 ,49.158546 ,49.225376 ,49.202282 ,49.146458 ,
|
||||
44.303867 ,44.482094 ,44.679394 ,44.914833 ,45.207493 ,45.51718 ,
|
||||
45.72817 ,45.9066 ,46.12764 ,46.439384 ,46.673965 ,46.853966 ,
|
||||
47.055668 ,47.352055 ,47.6161 ,47.803867 ,47.99163 ,48.25568 ,
|
||||
48.552063 ,48.75377 ,48.933773 ,49.16835 ,49.491726 ,49.764877 ,
|
||||
49.987526 ,50.17746 ,50.347275 ,50.4141 ,50.391006 ,50.335186 ,
|
||||
44.771675 ,44.949905 ,45.1472 ,45.382645 ,45.6753 ,45.98499 ,
|
||||
46.195976 ,46.374413 ,46.595448 ,46.907196 ,47.141773 ,47.321774 ,
|
||||
47.523476 ,47.819862 ,48.08391 ,48.27168 ,48.459446 ,48.72349 ,
|
||||
49.019882 ,49.22158 ,49.401585 ,49.63616 ,49.959538 ,50.232693 ,
|
||||
50.455338 ,50.64527 ,50.81509 ,50.88192 ,50.858818 ,50.803 ,
|
||||
44.609966 ,44.788193 ,44.985493 ,45.220936 ,45.51359 ,45.82328 ,
|
||||
46.03427 ,46.2127 ,46.433743 ,46.74549 ,46.98007 ,47.160065 ,
|
||||
47.36177 ,47.658157 ,47.922207 ,48.10997 ,48.297733 ,48.561783 ,
|
||||
48.858166 ,49.059875 ,49.239872 ,49.47445 ,49.79783 ,50.07098 ,
|
||||
50.293625 ,50.48356 ,50.653378 ,50.720203 ,50.6971 ,50.64128 ,
|
||||
44.219246 ,44.397472 ,44.594772 ,44.83021 ,45.122868 ,45.43256 ,
|
||||
45.643543 ,45.82198 ,46.04302 ,46.354763 ,46.589344 ,46.76934 ,
|
||||
46.971046 ,47.267433 ,47.531483 ,47.719242 ,47.907005 ,48.17105 ,
|
||||
48.467438 ,48.66914 ,48.849144 ,49.08372 ,49.4071 ,49.680256 ,
|
||||
49.902905 ,50.092834 ,50.262653 ,50.329483 ,50.30638 ,50.25057});
|
||||
NDArray expected = NDArrayFactory::create<float>('c', {30, 30, 1}, {
|
||||
1.000000f, 1.197616f, 1.417436f, 1.677577f, 1.996158f, 2.328327f, 2.550918f, 2.736061f, 2.965541f,
|
||||
3.292965f, 3.544151f, 3.738035f, 3.948995f, 4.248106f, 4.507379f, 4.684374f, 4.857284f, 5.104302f,
|
||||
5.386991f, 5.581401f, 5.753962f, 5.974285f, 6.272836f, 6.520426f, 6.718899f, 6.887104f, 7.039068f,
|
||||
7.099216f, 7.078424f, 7.028189f, 2.247592f, 2.446947f, 2.669489f, 2.931238f, 3.248216f, 3.574534f,
|
||||
3.789310f, 3.965697f, 4.186417f, 4.504653f, 4.740569f, 4.921706f, 5.133866f, 5.459533f, 5.774461f,
|
||||
6.019787f, 6.254011f, 6.535633f, 6.809730f, 6.960779f, 7.074942f, 7.241601f, 7.509489f, 7.749949f,
|
||||
7.954571f, 8.131972f, 8.286526f, 8.346463f, 8.325745f, 8.275683f, 3.628684f, 3.830573f, 4.056959f,
|
||||
4.321157f, 4.636486f, 4.955650f, 5.160583f, 5.325847f, 5.535462f, 5.842160f, 6.058749f, 6.223753f,
|
||||
6.437597f, 6.797369f, 7.183604f, 7.516402f, 7.829034f, 8.154773f, 8.417635f, 8.512958f, 8.552100f,
|
||||
8.649708f, 8.877880f, 9.108794f, 9.320926f, 9.509781f, 9.667375f, 9.726940f, 9.706349f, 9.656599f,
|
||||
5.276778f, 5.480438f, 5.709702f, 5.975448f, 6.288551f, 6.600570f, 6.796207f, 6.951142f, 7.150400f,
|
||||
7.446143f, 7.644651f, 7.794562f, 8.009684f, 8.400473f, 8.851847f, 9.264690f, 9.649218f, 10.015648f,
|
||||
10.268647f, 10.313368f, 10.284327f, 10.319379f, 10.512033f, 10.734956f, 10.954604f, 11.154507f, 11.315369f,
|
||||
11.374779f, 11.354242f, 11.304622f, 7.325373f, 7.528484f, 7.757575f, 8.022221f, 8.331997f, 8.638187f,
|
||||
8.827649f, 8.976217f, 9.168955f, 9.457260f, 9.644237f, 9.784517f, 9.999621f, 10.407702f, 10.896234f,
|
||||
11.355122f, 11.781423f, 12.172186f, 12.420712f, 12.437449f, 12.370511f, 12.371386f, 12.545973f, 12.766424f,
|
||||
12.992249f, 13.200120f, 13.364252f, 13.424109f, 13.403420f, 13.353425f, 9.493208f, 9.692467f, 9.916944f,
|
||||
10.176801f, 10.482199f, 10.785470f, 10.974367f, 11.123442f, 11.316370f, 11.603645f, 11.790616f, 11.930889f,
|
||||
12.144082f, 12.546447f, 13.024898f, 13.472300f, 13.889232f, 14.276275f, 14.528972f, 14.555555f, 14.501450f,
|
||||
14.515459f, 14.700572f, 14.927055f, 15.156046f, 15.366046f, 15.532901f, 15.594008f, 15.572885f, 15.521847f,
|
||||
10.970133f, 11.163599f, 11.380694f, 11.633735f, 11.935032f, 12.238887f, 12.432540f, 12.588294f, 12.787534f,
|
||||
13.079956f, 13.277520f, 13.426631f, 13.636713f, 14.013844f, 14.441672f, 14.827978f, 15.191209f, 15.549808f,
|
||||
15.813430f, 15.881828f, 15.883522f, 15.950411f, 16.169330f, 16.407940f, 16.636436f, 16.842583f, 17.010887f,
|
||||
17.073630f, 17.051940f, 16.999537f, 12.219155f, 12.406129f, 12.614796f, 12.860335f, 13.157928f, 13.464224f,
|
||||
13.665207f, 13.830567f, 14.039036f, 14.339629f, 14.552863f, 14.715049f, 14.921564f, 15.264454f, 15.622843f,
|
||||
15.924977f, 16.213829f, 16.532364f, 16.809900f, 16.934835f, 17.012146f, 17.150164f, 17.413412f, 17.666712f,
|
||||
17.892765f, 18.092070f, 18.261044f, 18.325531f, 18.303238f, 18.249378f, 13.766397f, 13.947391f, 14.148263f,
|
||||
14.386917f, 14.681246f, 14.990087f, 15.198166f, 15.372728f, 15.590062f, 15.898583f, 16.126892f, 16.301655f,
|
||||
16.504870f, 16.815214f, 17.107498f, 17.329458f, 17.547403f, 17.827654f, 18.118288f, 18.296928f, 18.446100f,
|
||||
18.651634f, 18.956806f, 19.223820f, 19.447308f, 19.639887f, 19.809319f, 19.875397f, 19.852556f, 19.797365f,
|
||||
15.941937f, 16.118704f, 16.314133f, 16.547867f, 16.839561f, 17.149540f, 17.361883f, 17.542162f, 17.764957f,
|
||||
18.078188f, 18.315733f, 18.498205f, 18.699116f, 18.988684f, 19.238989f, 19.410137f, 19.583265f, 19.839512f,
|
||||
20.138780f, 20.351770f, 20.546844f, 20.795671f, 21.128067f, 21.404358f, 21.626736f, 21.815500f, 21.985610f,
|
||||
22.052843f, 22.029604f, 21.973448f, 17.535220f, 17.710770f, 17.904636f, 18.136950f, 18.427840f, 18.738056f,
|
||||
18.951529f, 19.133352f, 19.357613f, 19.672083f, 19.912102f, 20.096638f, 20.296894f, 20.580765f, 20.819603f,
|
||||
20.976887f, 21.137802f, 21.387535f, 21.689209f, 21.911621f, 22.119276f, 22.379990f, 22.719910f, 22.998823f,
|
||||
23.220970f, 23.408760f, 23.579110f, 23.646685f, 23.623325f, 23.566887f, 18.746353f, 18.922657f, 19.117487f,
|
||||
19.350685f, 19.642070f, 19.952137f, 20.164913f, 20.345781f, 20.569134f, 20.882840f, 21.121330f, 21.304590f,
|
||||
21.505253f, 21.792645f, 22.038572f, 22.204426f, 22.372890f, 22.626648f, 22.926834f, 23.143423f, 23.343302f,
|
||||
23.596668f, 23.931936f, 24.209232f, 24.431519f, 24.619913f, 24.790110f, 24.857473f, 24.834190f, 24.777927f,
|
||||
20.166560f, 20.344206f, 20.540766f, 20.775532f, 21.067804f, 21.377607f, 21.589132f, 21.768297f, 21.990030f,
|
||||
22.302366f, 22.538124f, 22.719105f, 22.920494f, 23.214176f, 23.472767f, 23.653934f, 23.835890f, 24.096842f,
|
||||
24.394371f, 24.600555f, 24.786541f, 25.026773f, 25.353731f, 25.628130f, 25.850672f, 26.040140f, 26.210072f,
|
||||
26.277063f, 26.253906f, 26.197956f, 22.363024f, 22.541250f, 22.738552f, 22.973991f, 23.266647f, 23.576340f,
|
||||
23.787327f, 23.965760f, 24.186796f, 24.498543f, 24.733124f, 24.913122f, 25.114826f, 25.411213f, 25.675262f,
|
||||
25.863028f, 26.050789f, 26.314838f, 26.611223f, 26.812925f, 26.992926f, 27.227505f, 27.550882f, 27.824034f,
|
||||
28.046684f, 28.236614f, 28.406433f, 28.473265f, 28.450163f, 28.394344f, 24.429443f, 24.607670f, 24.804970f,
|
||||
25.040410f, 25.333065f, 25.642756f, 25.853743f, 26.032173f, 26.253210f, 26.564959f, 26.799540f, 26.979540f,
|
||||
27.181242f, 27.477630f, 27.741680f, 27.929441f, 28.117207f, 28.381254f, 28.677637f, 28.879343f, 29.059345f,
|
||||
29.293922f, 29.617298f, 29.890451f, 30.113104f, 30.303034f, 30.472853f, 30.539684f, 30.516582f, 30.460762f,
|
||||
26.000000f, 26.178228f, 26.375526f, 26.610970f, 26.903624f, 27.213314f, 27.424305f, 27.602734f, 27.823772f,
|
||||
28.135519f, 28.370100f, 28.550098f, 28.751800f, 29.048190f, 29.312237f, 29.500000f, 29.687763f, 29.951813f,
|
||||
30.248200f, 30.449903f, 30.629902f, 30.864483f, 31.187859f, 31.461012f, 31.683659f, 31.873592f, 32.043407f,
|
||||
32.110240f, 32.087135f, 32.031320f, 27.570559f, 27.748787f, 27.946087f, 28.181528f, 28.474184f, 28.783876f,
|
||||
28.994865f, 29.173294f, 29.394330f, 29.706080f, 29.940659f, 30.120655f, 30.322360f, 30.618746f, 30.882797f,
|
||||
31.070557f, 31.258320f, 31.522371f, 31.818754f, 32.020460f, 32.200460f, 32.435040f, 32.758415f, 33.031567f,
|
||||
33.254220f, 33.444150f, 33.613964f, 33.680794f, 33.657696f, 33.601880f, 29.636976f, 29.815207f, 30.012500f,
|
||||
30.247944f, 30.540600f, 30.850290f, 31.061283f, 31.239712f, 31.460750f, 31.772500f, 32.007080f, 32.187077f,
|
||||
32.388780f, 32.685165f, 32.949215f, 33.136980f, 33.324740f, 33.588790f, 33.885178f, 34.086884f, 34.266880f,
|
||||
34.501457f, 34.824837f, 35.097990f, 35.320637f, 35.510574f, 35.680390f, 35.747215f, 35.724117f, 35.668300f,
|
||||
31.833440f, 32.011665f, 32.208970f, 32.444412f, 32.737070f, 33.046757f, 33.257744f, 33.436176f, 33.657207f,
|
||||
33.968960f, 34.203537f, 34.383537f, 34.585240f, 34.881630f, 35.145676f, 35.333440f, 35.521206f, 35.785255f,
|
||||
36.081642f, 36.283340f, 36.463340f, 36.697920f, 37.021297f, 37.294453f, 37.517097f, 37.707027f, 37.876846f,
|
||||
37.943680f, 37.920578f, 37.864758f, 33.253647f, 33.431873f, 33.629170f, 33.864613f, 34.157270f, 34.466957f,
|
||||
34.677948f, 34.856377f, 35.077415f, 35.389160f, 35.623745f, 35.803745f, 36.005447f, 36.301834f, 36.565884f,
|
||||
36.753647f, 36.941406f, 37.205456f, 37.501840f, 37.703545f, 37.883545f, 38.118122f, 38.441500f, 38.714653f,
|
||||
38.937300f, 39.127235f, 39.297054f, 39.363884f, 39.340782f, 39.284960f, 34.464783f, 34.643010f, 34.840305f,
|
||||
35.075752f, 35.368404f, 35.678100f, 35.889088f, 36.067516f, 36.288550f, 36.600300f, 36.834885f, 37.014877f,
|
||||
37.216583f, 37.512970f, 37.777020f, 37.964783f, 38.152546f, 38.416595f, 38.712980f, 38.914684f, 39.094685f,
|
||||
39.329260f, 39.652645f, 39.925793f, 40.148440f, 40.338375f, 40.508194f, 40.575024f, 40.551920f, 40.496105f,
|
||||
36.058067f, 36.236290f, 36.433590f, 36.669033f, 36.961685f, 37.271378f, 37.482370f, 37.660800f, 37.881836f,
|
||||
38.193590f, 38.428170f, 38.608162f, 38.809868f, 39.106250f, 39.370300f, 39.558064f, 39.745830f, 40.009880f,
|
||||
40.306267f, 40.507970f, 40.687970f, 40.922550f, 41.245926f, 41.519077f, 41.741722f, 41.931652f, 42.101475f,
|
||||
42.168304f, 42.145203f, 42.089386f, 38.315002f, 38.493233f, 38.690533f, 38.925976f, 39.218628f, 39.528320f,
|
||||
39.739307f, 39.917736f, 40.138775f, 40.450520f, 40.685104f, 40.865097f, 41.066803f, 41.363190f, 41.627243f,
|
||||
41.815002f, 42.002766f, 42.266820f, 42.563200f, 42.764908f, 42.944904f, 43.179485f, 43.502860f, 43.776016f,
|
||||
43.998665f, 44.188595f, 44.358418f, 44.425247f, 44.402145f, 44.346330f, 40.227080f, 40.405310f, 40.602608f,
|
||||
40.838050f, 41.130707f, 41.440395f, 41.651382f, 41.829820f, 42.050854f, 42.362600f, 42.597183f, 42.777180f,
|
||||
42.978880f, 43.275270f, 43.539320f, 43.727080f, 43.914845f, 44.178894f, 44.475280f, 44.676983f, 44.856983f,
|
||||
45.091560f, 45.414940f, 45.688090f, 45.910740f, 46.100674f, 46.270493f, 46.337322f, 46.314220f, 46.258400f,
|
||||
41.785618f, 41.963844f, 42.161144f, 42.396584f, 42.689240f, 42.998936f, 43.209923f, 43.388355f, 43.609394f,
|
||||
43.921143f, 44.155720f, 44.335716f, 44.537420f, 44.833805f, 45.097860f, 45.285614f, 45.473377f, 45.737427f,
|
||||
46.033817f, 46.235523f, 46.415524f, 46.650105f, 46.973476f, 47.246630f, 47.469276f, 47.659210f, 47.829030f,
|
||||
47.895855f, 47.872753f, 47.816940f, 43.115140f, 43.293365f, 43.490665f, 43.726105f, 44.018764f, 44.328457f,
|
||||
44.539444f, 44.717873f, 44.938910f, 45.250660f, 45.485240f, 45.665237f, 45.866940f, 46.163326f, 46.427376f,
|
||||
46.615143f, 46.802902f, 47.066956f, 47.363342f, 47.565050f, 47.745050f, 47.979626f, 48.302998f, 48.576153f,
|
||||
48.798798f, 48.988730f, 49.158546f, 49.225376f, 49.202282f, 49.146458f, 44.303867f, 44.482094f, 44.679394f,
|
||||
44.914833f, 45.207493f, 45.517180f, 45.728170f, 45.906600f, 46.127640f, 46.439384f, 46.673965f, 46.853966f,
|
||||
47.055668f, 47.352055f, 47.616100f, 47.803867f, 47.991630f, 48.255680f, 48.552063f, 48.753770f, 48.933773f,
|
||||
49.168350f, 49.491726f, 49.764877f, 49.987526f, 50.177460f, 50.347275f, 50.414100f, 50.391006f, 50.335186f,
|
||||
44.771675f, 44.949905f, 45.147200f, 45.382645f, 45.675300f, 45.984990f, 46.195976f, 46.374413f, 46.595448f,
|
||||
46.907196f, 47.141773f, 47.321774f, 47.523476f, 47.819862f, 48.083910f, 48.271680f, 48.459446f, 48.723490f,
|
||||
49.019882f, 49.221580f, 49.401585f, 49.636160f, 49.959538f, 50.232693f, 50.455338f, 50.645270f, 50.815090f,
|
||||
50.881920f, 50.858818f, 50.803000f, 44.609966f, 44.788193f, 44.985493f, 45.220936f, 45.513590f, 45.823280f,
|
||||
46.034270f, 46.212700f, 46.433743f, 46.745490f, 46.980070f, 47.160065f, 47.361770f, 47.658157f, 47.922207f,
|
||||
48.109970f, 48.297733f, 48.561783f, 48.858166f, 49.059875f, 49.239872f, 49.474450f, 49.797830f, 50.070980f,
|
||||
50.293625f, 50.483560f, 50.653378f, 50.720203f, 50.697100f, 50.641280f, 44.219246f, 44.397472f, 44.594772f,
|
||||
44.830210f, 45.122868f, 45.432560f, 45.643543f, 45.821980f, 46.043020f, 46.354763f, 46.589344f, 46.769340f,
|
||||
46.971046f, 47.267433f, 47.531483f, 47.719242f, 47.907005f, 48.171050f, 48.467438f, 48.669140f, 48.849144f,
|
||||
49.083720f, 49.407100f, 49.680256f, 49.902905f, 50.092834f, 50.262653f, 50.329483f, 50.306380f, 50.250570f
|
||||
});
|
||||
|
||||
auto size = NDArrayFactory::create<int>({30, 30});
|
||||
nd4j::ops::resize_bicubic op;
|
||||
|
|
|
@ -229,6 +229,640 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) {
|
|||
ASSERT_TRUE(e.equalsTo(out));
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_5) {
|
||||
auto x = NDArrayFactory::create<double>('c', {1, 3, 4});
|
||||
auto e = NDArrayFactory::create<double>('c', {1, 3, 4}, {
|
||||
-3., -2., -1., 0., 5., 6., 7., 8., 13., 14., 15., 16.
|
||||
});
|
||||
x.linspace(1.);
|
||||
nd4j::ops::adjust_contrast_v2 op;
|
||||
auto result = op.execute({&x}, {2.}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
auto out = result->at(0);
|
||||
// out->printIndexedBuffer("Adjusted Constrast");
|
||||
ASSERT_TRUE(e.equalsTo(out));
|
||||
delete result;
|
||||
}
|
||||
|
||||
/*
|
||||
* public void testAdjustContrast1() {
|
||||
INDArray in = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,0.2309f,0.7271f,0.1804f,
|
||||
0.5056f,0.8925f,0.5461f,0.9234f,0.0856f,0.7938f,0.6591f,0.5555f,0.1596f,0.3087f,0.1548f,0.4695f,
|
||||
0.9939f,0.6113f,0.6765f,0.1800f,0.6750f,0.2246f,0.0509f,0.4601f,0.8284f,0.2354f,0.9752f,0.8361f,
|
||||
0.2585f,0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f,0.8428f,0.7494f,
|
||||
0.0755f,0.6245f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f,0.3019f,
|
||||
0.3574f,0.1704f,0.8395f,0.5468f,0.0744f,0.9011f,0.6574f,0.4124f,0.2445f,0.4248f,0.5219f,
|
||||
0.6952f,0.4900f,0.2158f,0.9549f,0.1386f,0.1544f,0.5365f,0.0134f,0.4163f,0.1456f,0.4109f,
|
||||
0.2484f, 0.3330f,0.2974f,0.6636f,0.3808f,0.8664f, 0.1896f, 0.7530f, 0.7215f, 0.6612f, 0.7270f,
|
||||
0.5704f,0.2666f,0.7453f,0.0444f,0.3024f,0.4850f,0.7982f,0.0965f,0.7843f,0.5075f,
|
||||
0.0844f,0.8370f,0.6103f,0.4604f,0.6087f, 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, 0.4143f,
|
||||
0.7821f,0.3505f,0.5040f,0.1180f,0.8307f,0.1817f,0.8442f,0.5074f,0.4471f,0.5105f,0.6666f,
|
||||
0.2576f,0.2341f,0.6801f,0.2652f,0.5394f,0.4690f,0.6146f,0.1210f,0.2576f,0.0769f,0.4643f,
|
||||
0.1628f,0.2026f,0.3774f,0.0506f,0.3462f,0.5720f,0.0838f,0.4228f,0.0588f,0.5362f,0.4756f,
|
||||
0.2530f,0.1778f,0.0751f,0.8977f,0.3648f,0.3065f,0.4739f,0.7014f,0.4473f,0.5171f,0.1744f,
|
||||
0.3487f,0.7759f,0.9491f,0.2072f,0.2182f,0.6520f,0.3092f,0.9545f,0.1881f,0.9579f,0.1785f,
|
||||
0.9636f,0.4830f,0.6569f,0.3353f,0.9997f,0.5869f,0.5747f,0.0238f,0.2943f,0.5248f,0.5879f,
|
||||
.7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f,
|
||||
0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f
|
||||
}).reshape(8,8,3,1);
|
||||
INDArray out = Nd4j.create(DataType.FLOAT, in.shape());
|
||||
INDArray[] res = Nd4j.exec(new AdjustContrast(in, 2.0, out));
|
||||
assertArrayEquals(out.shape(), in.shape());
|
||||
//assertEquals(expected, out);
|
||||
}
|
||||
* */
|
||||
|
||||
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_6) {
|
||||
auto x = NDArrayFactory::create<float>('c', {8,8, 3, 1}, {0.7788f,0.8012f,0.7244f,0.2309f,0.7271f,0.1804f,
|
||||
0.5056f,0.8925f,0.5461f,0.9234f,0.0856f,0.7938f,0.6591f,0.5555f,0.1596f,0.3087f,0.1548f,0.4695f,
|
||||
0.9939f,0.6113f,0.6765f,0.1800f,0.6750f,0.2246f,0.0509f,0.4601f,0.8284f,0.2354f,0.9752f,0.8361f,
|
||||
0.2585f,0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f,0.8428f,0.7494f,
|
||||
0.0755f,0.6245f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f,0.3019f,
|
||||
0.3574f,0.1704f,0.8395f,0.5468f,0.0744f,0.9011f,0.6574f,0.4124f,0.2445f,0.4248f,0.5219f,
|
||||
0.6952f,0.4900f,0.2158f,0.9549f,0.1386f,0.1544f,0.5365f,0.0134f,0.4163f,0.1456f,0.4109f,
|
||||
0.2484f, 0.3330f,0.2974f,0.6636f,0.3808f,0.8664f, 0.1896f, 0.7530f, 0.7215f, 0.6612f, 0.7270f,
|
||||
0.5704f,0.2666f,0.7453f,0.0444f,0.3024f,0.4850f,0.7982f,0.0965f,0.7843f,0.5075f,
|
||||
0.0844f,0.8370f,0.6103f,0.4604f,0.6087f, 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, 0.4143f,
|
||||
0.7821f,0.3505f,0.5040f,0.1180f,0.8307f,0.1817f,0.8442f,0.5074f,0.4471f,0.5105f,0.6666f,
|
||||
0.2576f,0.2341f,0.6801f,0.2652f,0.5394f,0.4690f,0.6146f,0.1210f,0.2576f,0.0769f,0.4643f,
|
||||
0.1628f,0.2026f,0.3774f,0.0506f,0.3462f,0.5720f,0.0838f,0.4228f,0.0588f,0.5362f,0.4756f,
|
||||
0.2530f,0.1778f,0.0751f,0.8977f,0.3648f,0.3065f,0.4739f,0.7014f,0.4473f,0.5171f,0.1744f,
|
||||
0.3487f,0.7759f,0.9491f,0.2072f,0.2182f,0.6520f,0.3092f,0.9545f,0.1881f,0.9579f,0.1785f,
|
||||
0.9636f,0.4830f,0.6569f,0.3353f,0.9997f,0.5869f,0.5747f,0.0238f,0.2943f,0.5248f,0.5879f,
|
||||
.7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f,
|
||||
0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f});
|
||||
auto e = NDArrayFactory::create<float>('c', {8, 8, 3, 1}, {
|
||||
1.0218375f,
|
||||
1.0666375f,
|
||||
0.9130375f,
|
||||
|
||||
-0.07396251f,
|
||||
0.91843754f,
|
||||
-0.17496246f,
|
||||
|
||||
0.47543746f,
|
||||
1.2492375f,
|
||||
0.55643755f,
|
||||
|
||||
1.3110375f,
|
||||
-0.36456245f,
|
||||
1.0518374f,
|
||||
|
||||
0.7824375f,
|
||||
0.57523745f,
|
||||
-0.21656245f,
|
||||
|
||||
0.0816375f,
|
||||
-0.2261625f,
|
||||
0.40323752f,
|
||||
|
||||
1.4520376f,
|
||||
0.6868375f,
|
||||
0.81723756f,
|
||||
|
||||
-0.17576247f,
|
||||
0.81423753f,
|
||||
-0.08656245f,
|
||||
|
||||
|
||||
-0.36249164f,
|
||||
0.45590833f,
|
||||
1.1925083f,
|
||||
|
||||
0.00650835f,
|
||||
1.4861084f,
|
||||
1.2079083f,
|
||||
|
||||
0.05270836f,
|
||||
0.37350836f,
|
||||
0.94130826f,
|
||||
|
||||
1.0715083f,
|
||||
0.6103083f,
|
||||
0.9825083f,
|
||||
|
||||
0.07370833f,
|
||||
-0.4518917f,
|
||||
-0.39889166f,
|
||||
|
||||
-0.3354917f,
|
||||
1.2213084f,
|
||||
1.0345083f,
|
||||
|
||||
-0.3132917f,
|
||||
0.78470826f,
|
||||
0.23390833f,
|
||||
|
||||
0.6943083f,
|
||||
0.68170834f,
|
||||
-0.09989169f,
|
||||
|
||||
|
||||
0.8352709f,
|
||||
1.3798709f,
|
||||
0.15507084f,
|
||||
|
||||
0.26607084f,
|
||||
-0.10792917f,
|
||||
1.2302709f,
|
||||
|
||||
0.6448709f,
|
||||
-0.29992914f,
|
||||
1.3534708f,
|
||||
|
||||
0.86607087f,
|
||||
0.37607086f,
|
||||
0.04027084f,
|
||||
|
||||
0.40087086f,
|
||||
0.59507084f,
|
||||
0.9416709f,
|
||||
|
||||
0.53127086f,
|
||||
-0.01712915f,
|
||||
1.4610709f,
|
||||
|
||||
-0.17152917f,
|
||||
-0.13992918f,
|
||||
0.6242708f,
|
||||
|
||||
-0.42192918f,
|
||||
0.38387084f,
|
||||
-0.15752912f,
|
||||
|
||||
|
||||
0.3311833f,
|
||||
0.00618333f,
|
||||
0.17538333f,
|
||||
|
||||
0.10418332f,
|
||||
0.8365834f,
|
||||
0.27098334f,
|
||||
|
||||
1.2421833f,
|
||||
-0.1114167f,
|
||||
1.0153834f,
|
||||
|
||||
0.9523833f,
|
||||
0.8317833f,
|
||||
0.9633833f,
|
||||
|
||||
0.6501833f,
|
||||
0.04258335f,
|
||||
0.9999833f,
|
||||
|
||||
-0.40181667f,
|
||||
0.11418331f,
|
||||
0.47938335f,
|
||||
|
||||
1.1057833f,
|
||||
-0.29761666f,
|
||||
1.0779834f,
|
||||
|
||||
0.5243833f,
|
||||
-0.32181668f,
|
||||
1.1833833f,
|
||||
|
||||
|
||||
0.73157084f,
|
||||
0.4317708f,
|
||||
0.7283708f,
|
||||
|
||||
1.2297708f,
|
||||
0.4307708f,
|
||||
0.85377085f,
|
||||
|
||||
0.05977082f,
|
||||
-0.09282917f,
|
||||
0.33957082f,
|
||||
|
||||
1.0751709f,
|
||||
0.2119708f,
|
||||
0.51897085f,
|
||||
|
||||
-0.25302917f,
|
||||
1.1723708f,
|
||||
-0.12562919f,
|
||||
|
||||
1.1993709f,
|
||||
0.5257708f,
|
||||
0.40517086f,
|
||||
|
||||
0.53197086f,
|
||||
0.8441708f,
|
||||
0.02617085f,
|
||||
|
||||
-0.0208292f,
|
||||
0.8711709f,
|
||||
0.04137081f,
|
||||
|
||||
|
||||
0.74936247f,
|
||||
0.6085625f,
|
||||
0.8997625f,
|
||||
|
||||
-0.08743751f,
|
||||
0.18576252f,
|
||||
-0.17563748f,
|
||||
|
||||
0.5991625f,
|
||||
-0.0038375f,
|
||||
0.07576251f,
|
||||
|
||||
0.42536253f,
|
||||
-0.22823751f,
|
||||
0.36296248f,
|
||||
|
||||
0.81456256f,
|
||||
-0.16183749f,
|
||||
0.5161625f,
|
||||
|
||||
-0.21183747f,
|
||||
0.7429625f,
|
||||
0.6217625f,
|
||||
|
||||
0.17656249f,
|
||||
0.02616251f,
|
||||
-0.17923748f,
|
||||
|
||||
1.4659625f,
|
||||
0.40016252f,
|
||||
0.28356248f,
|
||||
|
||||
|
||||
0.4195791f,
|
||||
0.8745791f,
|
||||
0.36637908f,
|
||||
|
||||
0.50597906f,
|
||||
-0.17942089f,
|
||||
0.16917908f,
|
||||
|
||||
1.0235791f,
|
||||
1.3699791f,
|
||||
-0.11382091f,
|
||||
|
||||
-0.0918209f,
|
||||
0.7757791f,
|
||||
0.09017909f,
|
||||
|
||||
1.3807791f,
|
||||
-0.15202093f,
|
||||
1.3875791f,
|
||||
|
||||
-0.1712209f,
|
||||
1.3989791f,
|
||||
0.43777913f,
|
||||
|
||||
0.7855791f,
|
||||
0.1423791f,
|
||||
1.4711791f,
|
||||
|
||||
0.6455791f,
|
||||
0.6211791f,
|
||||
-0.48062086f,
|
||||
|
||||
|
||||
0.10189578f,
|
||||
0.5628958f,
|
||||
0.68909574f,
|
||||
|
||||
0.96649575f,
|
||||
-0.09370419f,
|
||||
1.3466958f,
|
||||
|
||||
1.4584957f,
|
||||
1.3544958f,
|
||||
-0.3829042f,
|
||||
|
||||
0.11269578f,
|
||||
-0.47890422f,
|
||||
1.0436958f,
|
||||
|
||||
0.6128957f,
|
||||
0.27209583f,
|
||||
0.2714958f,
|
||||
|
||||
0.21889582f,
|
||||
0.08789578f,
|
||||
1.1296958f,
|
||||
|
||||
0.4596958f,
|
||||
0.39309582f,
|
||||
0.8344958f,
|
||||
|
||||
0.71149576f,
|
||||
-0.4799042f,
|
||||
0.4880958f
|
||||
});
|
||||
|
||||
nd4j::ops::adjust_contrast op;
|
||||
auto result = op.execute({&x}, {2.}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
auto out = result->at(0);
|
||||
// out->printBuffer("Adjusted Constrast6");
|
||||
// e.printBuffer("Adjusted Expected 6");
|
||||
// ASSERT_TRUE(e.equalsTo(out));
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) {
|
||||
auto x = NDArrayFactory::create<double>('c', {8,8, 3, 1}, {0.7788f,0.8012f,0.7244f,0.2309f,0.7271f,0.1804f,
|
||||
0.5056f,0.8925f,0.5461f,0.9234f,0.0856f,0.7938f,0.6591f,0.5555f,0.1596f,0.3087f,0.1548f,0.4695f,
|
||||
0.9939f,0.6113f,0.6765f,0.1800f,0.6750f,0.2246f,0.0509f,0.4601f,0.8284f,0.2354f,0.9752f,0.8361f,
|
||||
0.2585f,0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f,0.8428f,0.7494f,
|
||||
0.0755f,0.6245f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f,0.3019f,
|
||||
0.3574f,0.1704f,0.8395f,0.5468f,0.0744f,0.9011f,0.6574f,0.4124f,0.2445f,0.4248f,0.5219f,
|
||||
0.6952f,0.4900f,0.2158f,0.9549f,0.1386f,0.1544f,0.5365f,0.0134f,0.4163f,0.1456f,0.4109f,
|
||||
0.2484f, 0.3330f,0.2974f,0.6636f,0.3808f,0.8664f, 0.1896f, 0.7530f, 0.7215f, 0.6612f, 0.7270f,
|
||||
0.5704f,0.2666f,0.7453f,0.0444f,0.3024f,0.4850f,0.7982f,0.0965f,0.7843f,0.5075f,
|
||||
0.0844f,0.8370f,0.6103f,0.4604f,0.6087f, 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, 0.4143f,
|
||||
0.7821f,0.3505f,0.5040f,0.1180f,0.8307f,0.1817f,0.8442f,0.5074f,0.4471f,0.5105f,0.6666f,
|
||||
0.2576f,0.2341f,0.6801f,0.2652f,0.5394f,0.4690f,0.6146f,0.1210f,0.2576f,0.0769f,0.4643f,
|
||||
0.1628f,0.2026f,0.3774f,0.0506f,0.3462f,0.5720f,0.0838f,0.4228f,0.0588f,0.5362f,0.4756f,
|
||||
0.2530f,0.1778f,0.0751f,0.8977f,0.3648f,0.3065f,0.4739f,0.7014f,0.4473f,0.5171f,0.1744f,
|
||||
0.3487f,0.7759f,0.9491f,0.2072f,0.2182f,0.6520f,0.3092f,0.9545f,0.1881f,0.9579f,0.1785f,
|
||||
0.9636f,0.4830f,0.6569f,0.3353f,0.9997f,0.5869f,0.5747f,0.0238f,0.2943f,0.5248f,0.5879f,
|
||||
.7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f,
|
||||
0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f});
|
||||
auto e = NDArrayFactory::create<double>('c', {8, 8, 3, 1}, {
|
||||
1.0218375 ,
|
||||
1.0666375 ,
|
||||
0.9130375 ,
|
||||
|
||||
-0.07396251,
|
||||
0.91843754,
|
||||
-0.17496246,
|
||||
|
||||
0.47543746,
|
||||
1.2492375 ,
|
||||
0.55643755,
|
||||
|
||||
1.3110375 ,
|
||||
-0.36456245,
|
||||
1.0518374 ,
|
||||
|
||||
0.7824375 ,
|
||||
0.57523745,
|
||||
-0.21656245,
|
||||
|
||||
0.0816375 ,
|
||||
-0.2261625 ,
|
||||
0.40323752,
|
||||
|
||||
1.4520376 ,
|
||||
0.6868375 ,
|
||||
0.81723756,
|
||||
|
||||
-0.17576247,
|
||||
0.81423753,
|
||||
-0.08656245,
|
||||
|
||||
|
||||
-0.36249164,
|
||||
0.45590833,
|
||||
1.1925083 ,
|
||||
|
||||
0.00650835,
|
||||
1.4861084 ,
|
||||
1.2079083 ,
|
||||
|
||||
0.05270836,
|
||||
0.37350836,
|
||||
0.94130826,
|
||||
|
||||
1.0715083 ,
|
||||
0.6103083 ,
|
||||
0.9825083 ,
|
||||
|
||||
0.07370833,
|
||||
-0.4518917 ,
|
||||
-0.39889166,
|
||||
|
||||
-0.3354917 ,
|
||||
1.2213084 ,
|
||||
1.0345083 ,
|
||||
|
||||
-0.3132917 ,
|
||||
0.78470826,
|
||||
0.23390833,
|
||||
|
||||
0.6943083 ,
|
||||
0.68170834,
|
||||
-0.09989169,
|
||||
|
||||
|
||||
0.8352709 ,
|
||||
1.3798709 ,
|
||||
0.15507084,
|
||||
|
||||
0.26607084,
|
||||
-0.10792917,
|
||||
1.2302709 ,
|
||||
|
||||
0.6448709 ,
|
||||
-0.29992914,
|
||||
1.3534708 ,
|
||||
|
||||
0.86607087,
|
||||
0.37607086,
|
||||
0.04027084,
|
||||
|
||||
0.40087086,
|
||||
0.59507084,
|
||||
0.9416709 ,
|
||||
|
||||
0.53127086,
|
||||
-0.01712915,
|
||||
1.4610709 ,
|
||||
|
||||
-0.17152917,
|
||||
-0.13992918,
|
||||
0.6242708 ,
|
||||
|
||||
-0.42192918,
|
||||
0.38387084,
|
||||
-0.15752912,
|
||||
|
||||
|
||||
0.3311833 ,
|
||||
0.00618333,
|
||||
0.17538333,
|
||||
|
||||
0.10418332,
|
||||
0.8365834 ,
|
||||
0.27098334,
|
||||
|
||||
1.2421833 ,
|
||||
-0.1114167 ,
|
||||
1.0153834 ,
|
||||
|
||||
0.9523833 ,
|
||||
0.8317833 ,
|
||||
0.9633833 ,
|
||||
|
||||
0.6501833 ,
|
||||
0.04258335,
|
||||
0.9999833 ,
|
||||
|
||||
-0.40181667,
|
||||
0.11418331,
|
||||
0.47938335,
|
||||
|
||||
1.1057833 ,
|
||||
-0.29761666,
|
||||
1.0779834 ,
|
||||
|
||||
0.5243833 ,
|
||||
-0.32181668,
|
||||
1.1833833 ,
|
||||
|
||||
|
||||
0.73157084,
|
||||
0.4317708 ,
|
||||
0.7283708 ,
|
||||
|
||||
1.2297708 ,
|
||||
0.4307708 ,
|
||||
0.85377085,
|
||||
|
||||
0.05977082,
|
||||
-0.09282917,
|
||||
0.33957082,
|
||||
|
||||
1.0751709 ,
|
||||
0.2119708 ,
|
||||
0.51897085,
|
||||
|
||||
-0.25302917,
|
||||
1.1723708 ,
|
||||
-0.12562919,
|
||||
|
||||
1.1993709 ,
|
||||
0.5257708 ,
|
||||
0.40517086,
|
||||
|
||||
0.53197086,
|
||||
0.8441708 ,
|
||||
0.02617085,
|
||||
|
||||
-0.0208292 ,
|
||||
0.8711709 ,
|
||||
0.04137081,
|
||||
|
||||
|
||||
0.74936247,
|
||||
0.6085625 ,
|
||||
0.8997625 ,
|
||||
|
||||
-0.08743751,
|
||||
0.18576252,
|
||||
-0.17563748,
|
||||
|
||||
0.5991625 ,
|
||||
-0.0038375 ,
|
||||
0.07576251,
|
||||
|
||||
0.42536253,
|
||||
-0.22823751,
|
||||
0.36296248,
|
||||
|
||||
0.81456256,
|
||||
-0.16183749,
|
||||
0.5161625 ,
|
||||
|
||||
-0.21183747,
|
||||
0.7429625 ,
|
||||
0.6217625 ,
|
||||
|
||||
0.17656249,
|
||||
0.02616251,
|
||||
-0.17923748,
|
||||
|
||||
1.4659625 ,
|
||||
0.40016252,
|
||||
0.28356248,
|
||||
|
||||
|
||||
0.4195791 ,
|
||||
0.8745791 ,
|
||||
0.36637908,
|
||||
|
||||
0.50597906,
|
||||
-0.17942089,
|
||||
0.16917908,
|
||||
|
||||
1.0235791 ,
|
||||
1.3699791 ,
|
||||
-0.11382091,
|
||||
|
||||
-0.0918209 ,
|
||||
0.7757791 ,
|
||||
0.09017909,
|
||||
|
||||
1.3807791 ,
|
||||
-0.15202093,
|
||||
1.3875791 ,
|
||||
|
||||
-0.1712209 ,
|
||||
1.3989791 ,
|
||||
0.43777913,
|
||||
|
||||
0.7855791 ,
|
||||
0.1423791 ,
|
||||
1.4711791 ,
|
||||
|
||||
0.6455791 ,
|
||||
0.6211791 ,
|
||||
-0.48062086,
|
||||
|
||||
|
||||
0.10189578,
|
||||
0.5628958 ,
|
||||
0.68909574,
|
||||
|
||||
0.96649575,
|
||||
-0.09370419,
|
||||
1.3466958 ,
|
||||
|
||||
1.4584957 ,
|
||||
1.3544958 ,
|
||||
-0.3829042 ,
|
||||
|
||||
0.11269578,
|
||||
-0.47890422,
|
||||
1.0436958 ,
|
||||
|
||||
0.6128957 ,
|
||||
0.27209583,
|
||||
0.2714958 ,
|
||||
|
||||
0.21889582,
|
||||
0.08789578,
|
||||
1.1296958 ,
|
||||
|
||||
0.4596958 ,
|
||||
0.39309582,
|
||||
0.8344958 ,
|
||||
|
||||
0.71149576,
|
||||
-0.4799042,
|
||||
0.4880958
|
||||
});
|
||||
// x.linspace(1.);
|
||||
nd4j::ops::adjust_contrast_v2 op;
|
||||
auto result = op.execute({&x}, {2.}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
auto out = result->at(0);
|
||||
// out->printBuffer("Adjusted Constrast7");
|
||||
// e.printBuffer("Adjusted expected 7");
|
||||
auto diff = e - *out;
|
||||
// diff.printBuffer("Adjusted subtract 7");
|
||||
ASSERT_TRUE(e.equalsTo(out));
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, Test_BitCast_1) {
|
||||
auto x = NDArrayFactory::create<float>('c', {2, 2, 2});
|
||||
auto e = NDArrayFactory::create<double>('c', {2, 2}, {2., 512., 8192., 131072.032 });
|
||||
|
|
|
@ -2612,8 +2612,9 @@ public class DifferentialFunctionFactory {
|
|||
return new DrawBoundingBoxes(sameDiff, boxes, colors).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable fakeQuantWithMinMaxVarsPerChannel(SDVariable x, SDVariable min, SDVariable max) {
|
||||
return new FakeQuantWithMinMaxVarsPerChannel(sameDiff,x,min,max).outputVariable();
|
||||
public SDVariable fakeQuantWithMinMaxVarsPerChannel(SDVariable x, SDVariable min, SDVariable max,
|
||||
int num_bits, boolean narrow) {
|
||||
return new FakeQuantWithMinMaxVarsPerChannel(sameDiff,x,min,max,num_bits,narrow).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable betainc( SDVariable a, SDVariable b, SDVariable x) {
|
||||
|
|
|
@ -0,0 +1,118 @@
|
|||
package org.nd4j.autodiff.listeners.debugging;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.listeners.At;
|
||||
import org.nd4j.autodiff.listeners.BaseListener;
|
||||
import org.nd4j.autodiff.listeners.Operation;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class ArraySavingListener extends BaseListener {
|
||||
|
||||
protected final File dir;
|
||||
protected int count = 0;
|
||||
|
||||
public ArraySavingListener(@NonNull File dir){
|
||||
|
||||
if(!dir.exists()){
|
||||
dir.mkdir();
|
||||
}
|
||||
|
||||
if(dir.listFiles() != null && dir.listFiles().length > 0){
|
||||
throw new IllegalStateException("Directory is not empty: " + dir.getAbsolutePath());
|
||||
}
|
||||
|
||||
this.dir = dir;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isActive(Operation operation) {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
|
||||
List<String> outNames = op.getOutputsOfOp();
|
||||
for(int i=0; i<outputs.length; i++ ){
|
||||
String filename = (count++) + "_" + outNames.get(i).replaceAll("/", "__") + ".bin";
|
||||
File outFile = new File(dir, filename);
|
||||
|
||||
INDArray arr = outputs[i];
|
||||
try {
|
||||
Nd4j.saveBinary(arr, outFile);
|
||||
System.out.println(outFile.getAbsolutePath());
|
||||
} catch (IOException e){
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static void compare(File dir1, File dir2, double eps) throws Exception {
|
||||
File[] files1 = dir1.listFiles();
|
||||
File[] files2 = dir2.listFiles();
|
||||
Preconditions.checkNotNull(files1, "No files in directory 1: %s", dir1);
|
||||
Preconditions.checkNotNull(files2, "No files in directory 2: %s", dir2);
|
||||
Preconditions.checkState(files1.length == files2.length, "Different number of files: %s vs %s", files1.length, files2.length);
|
||||
|
||||
Map<String,File> m1 = toMap(files1);
|
||||
Map<String,File> m2 = toMap(files2);
|
||||
|
||||
for(File f : files1){
|
||||
String name = f.getName();
|
||||
String varName = name.substring(name.indexOf('_') + 1, name.length()-4); //Strip "x_" and ".bin"
|
||||
File f2 = m2.get(varName);
|
||||
|
||||
INDArray arr1 = Nd4j.readBinary(f);
|
||||
INDArray arr2 = Nd4j.readBinary(f2);
|
||||
|
||||
//TODO String arrays won't work here!
|
||||
boolean eq = arr1.equalsWithEps(arr2, eps);
|
||||
if(eq){
|
||||
System.out.println("Equals: " + varName.replaceAll("__", "/"));
|
||||
} else {
|
||||
if(arr1.dataType() == DataType.BOOL){
|
||||
INDArray xor = Nd4j.exec(new Xor(arr1, arr2));
|
||||
int count = xor.castTo(DataType.INT).sumNumber().intValue();
|
||||
System.out.println("FAILS: " + varName.replaceAll("__", "/") + " - boolean, # differences = " + count);
|
||||
System.out.println("\t" + f.getAbsolutePath());
|
||||
System.out.println("\t" + f2.getAbsolutePath());
|
||||
xor.close();
|
||||
} else {
|
||||
INDArray sub = arr1.sub(arr2);
|
||||
INDArray diff = Nd4j.math.abs(sub);
|
||||
double maxDiff = diff.maxNumber().doubleValue();
|
||||
System.out.println("FAILS: " + varName.replaceAll("__", "/") + " - max difference = " + maxDiff);
|
||||
System.out.println("\t" + f.getAbsolutePath());
|
||||
System.out.println("\t" + f2.getAbsolutePath());
|
||||
sub.close();
|
||||
diff.close();
|
||||
}
|
||||
}
|
||||
arr1.close();
|
||||
arr2.close();
|
||||
}
|
||||
}
|
||||
|
||||
private static Map<String,File> toMap(File[] files){
|
||||
Map<String,File> ret = new HashMap<>();
|
||||
for(File f : files) {
|
||||
String name = f.getName();
|
||||
String varName = name.substring(name.indexOf('_') + 1, name.length() - 4); //Strip "x_" and ".bin"
|
||||
ret.put(varName, f);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
}
|
|
@ -87,6 +87,7 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression.class,
|
||||
org.nd4j.linalg.api.ops.impl.image.NonMaxSuppressionV3.class,
|
||||
org.nd4j.linalg.api.ops.impl.image.ResizeBilinear.class,
|
||||
org.nd4j.linalg.api.ops.impl.image.ResizeBicubic.class,
|
||||
org.nd4j.linalg.api.ops.impl.image.ResizeNearestNeighbor.class,
|
||||
org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex.class,
|
||||
org.nd4j.linalg.api.ops.impl.indexaccum.IAMax.class,
|
||||
|
|
|
@ -102,4 +102,7 @@ public abstract class BaseReduceBoolOp extends BaseReduceOp implements ReduceBoo
|
|||
"with 2 inputs, second input (axis) must be an integer datatype for %s, got %s", getClass(), dataTypes);
|
||||
return Collections.singletonList(DataType.BOOL);
|
||||
}
|
||||
|
||||
|
||||
public abstract boolean emptyValue();
|
||||
}
|
||||
|
|
|
@ -21,30 +21,46 @@ import org.nd4j.base.Preconditions;
|
|||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp {
|
||||
protected boolean narrowRange;
|
||||
protected int numBits;
|
||||
|
||||
public FakeQuantWithMinMaxVarsPerChannel() {}
|
||||
|
||||
public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max) {
|
||||
public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max, int num_bits, boolean narrow) {
|
||||
Preconditions.checkArgument(min.isVector() && max.isVector() &&
|
||||
min.length() == max.length(),
|
||||
"FakeQuantWithMinMaxVarsPerChannel: min and max should be 1D tensors with the same length");
|
||||
inputArguments.add(x);
|
||||
inputArguments.add(min);
|
||||
inputArguments.add(max);
|
||||
addInputArgument(x,min,max);
|
||||
addIArgument(num_bits);
|
||||
addBArgument(narrow);
|
||||
}
|
||||
|
||||
public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max,
|
||||
INDArray output) {
|
||||
this(x,min,max);
|
||||
outputArguments.add(output);
|
||||
public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max, int num_bits) {
|
||||
this(x, min, max, num_bits, false);
|
||||
}
|
||||
|
||||
public FakeQuantWithMinMaxVarsPerChannel(SameDiff sameDiff, SDVariable x, SDVariable min, SDVariable max) {
|
||||
public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max, boolean narrow) {
|
||||
this(x, min, max, 8, narrow);
|
||||
}
|
||||
|
||||
public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max) {
|
||||
this(x, min, max, 8, false);
|
||||
}
|
||||
|
||||
public FakeQuantWithMinMaxVarsPerChannel(SameDiff sameDiff, SDVariable x, SDVariable min, SDVariable max,
|
||||
int num_bits, boolean narrow) {
|
||||
super("", sameDiff, new SDVariable[]{x, min, max});
|
||||
addIArgument(num_bits);
|
||||
addBArgument(narrow);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -57,6 +73,18 @@ public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp {
|
|||
return "FakeQuantWithMinMaxVarsPerChannel";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||
if(attributesForNode.containsKey("narrow_range")){
|
||||
this.narrowRange = attributesForNode.get("narrow_range").getB();
|
||||
}
|
||||
if(attributesForNode.containsKey("num_bits")) {
|
||||
this.numBits = (int) attributesForNode.get("num_bits").getI();
|
||||
}
|
||||
addIArgument(numBits);
|
||||
addBArgument(narrowRange);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 inputs, got %s", inputDataTypes);
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019 Konduit, K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.impl.image;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* ResizeBicubic op wrapper
|
||||
* @author Alexander Stoyakin
|
||||
*/
|
||||
@NoArgsConstructor
|
||||
public class ResizeBicubic extends DynamicCustomOp {
|
||||
|
||||
protected boolean alignCorners = false;
|
||||
protected boolean alignPixelCenters = false;
|
||||
|
||||
public ResizeBicubic(@NonNull INDArray image, INDArray size, boolean alignCorners, boolean alignPixelCenters) {
|
||||
addInputArgument(image, size);
|
||||
addBArgument(alignCorners, alignPixelCenters);
|
||||
}
|
||||
|
||||
public ResizeBicubic(@NonNull SameDiff sameDiff, @NonNull SDVariable image,
|
||||
SDVariable size, boolean alignCorners, boolean alignPixelCenters) {
|
||||
super(sameDiff, new SDVariable[]{image, size});
|
||||
addBArgument(alignCorners, alignPixelCenters);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "resize_bicubic";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "ResizeBicubic";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||
TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
|
||||
|
||||
this.alignCorners = attributesForNode.get("align_corners").getB();
|
||||
this.alignPixelCenters = attributesForNode.get("half_pixel_centers").getB();
|
||||
addBArgument(alignCorners, alignPixelCenters);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 1 || inputDataTypes.size() == 2),
|
||||
"Expected 1 or 2 input datatypes for %s, got %s", getClass(), inputDataTypes);
|
||||
return Collections.singletonList(Nd4j.defaultFloatingPointType());
|
||||
}
|
||||
}
|
|
@ -41,6 +41,10 @@ public class All extends BaseReduceBoolOp {
|
|||
super(x);
|
||||
}
|
||||
|
||||
public All(INDArray x, int... axis) {
|
||||
super(x, axis);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int opNum() {
|
||||
return 1;
|
||||
|
@ -65,4 +69,9 @@ public class All extends BaseReduceBoolOp {
|
|||
public String tensorflowName() {
|
||||
return "All";
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean emptyValue() {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -65,4 +65,9 @@ public class Any extends BaseReduceBoolOp {
|
|||
public String tensorflowName() {
|
||||
return "Any";
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean emptyValue() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -71,4 +71,8 @@ public class IsInf extends BaseReduceBoolOp {
|
|||
return Collections.singletonList(f().zerosLike(arg()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean emptyValue() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -71,4 +71,8 @@ public class IsNaN extends BaseReduceBoolOp {
|
|||
return Collections.singletonList(f().zerosLike(arg()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean emptyValue() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
|
|||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
|
@ -37,11 +38,21 @@ public class FakeQuantWithMinMaxArgs extends DynamicCustomOp {
|
|||
addArgs();
|
||||
}
|
||||
|
||||
public FakeQuantWithMinMaxArgs(INDArray x, INDArray min, INDArray max, int num_bits, boolean narrow) {
|
||||
Preconditions.checkArgument(min.isVector() && max.isVector() &&
|
||||
min.length() == max.length(),
|
||||
"FakeQuantWithMinMaxArgs: min and max should be 1D tensors with the same length");
|
||||
addInputArgument(x,min,max);
|
||||
addIArgument(num_bits);
|
||||
addBArgument(narrow);
|
||||
}
|
||||
|
||||
public FakeQuantWithMinMaxArgs(){ }
|
||||
|
||||
protected void addArgs(){
|
||||
iArguments.clear();
|
||||
addIArgument(numBits, narrowRange ? 1 : 0);
|
||||
addIArgument(numBits);
|
||||
addBArgument(narrowRange);
|
||||
addTArgument(min, max);
|
||||
}
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
|
|||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
|
@ -33,11 +34,22 @@ public class FakeQuantWithMinMaxVars extends DynamicCustomOp {
|
|||
addArgs();
|
||||
}
|
||||
|
||||
public FakeQuantWithMinMaxVars(INDArray x, INDArray min, INDArray max, int num_bits, boolean narrow) {
|
||||
Preconditions.checkArgument(min.isVector() && max.isVector() &&
|
||||
min.length() == max.length(),
|
||||
"FakeQuantWithMinMaxVars: min and max should be 1D tensors with the same length");
|
||||
addInputArgument(x,min,max);
|
||||
addIArgument(num_bits);
|
||||
addBArgument(narrow);
|
||||
}
|
||||
|
||||
public FakeQuantWithMinMaxVars(){ }
|
||||
|
||||
protected void addArgs(){
|
||||
iArguments.clear();
|
||||
addIArgument(numBits, narrowRange ? 1 : 0);
|
||||
bArguments.clear();
|
||||
addIArgument(numBits);
|
||||
addBArgument(narrowRange);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -935,6 +935,18 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
}
|
||||
}
|
||||
|
||||
// FIXME: this should be moved down to C++ on per-op basis
|
||||
// reduce to scalar case, ReduceBool ops require special treatment
|
||||
if (op instanceof BaseReduceBoolOp && op.x().isEmpty() && (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE))) {
|
||||
if (op.z() == null) {
|
||||
op.setZ(Nd4j.scalar(((BaseReduceBoolOp) op).emptyValue()));
|
||||
} else {
|
||||
op.z().assign(((BaseReduceBoolOp) op).emptyValue());
|
||||
}
|
||||
|
||||
return context;
|
||||
}
|
||||
|
||||
long st = profilingConfigurableHookIn(op);
|
||||
|
||||
checkForCompression(op);
|
||||
|
@ -994,9 +1006,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
}
|
||||
}
|
||||
|
||||
if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape)) {
|
||||
return null;
|
||||
}
|
||||
//if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape)) {
|
||||
// return null;
|
||||
//}
|
||||
|
||||
val dataType = op.resultType();
|
||||
|
||||
|
|
|
@ -265,7 +265,18 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
}
|
||||
}
|
||||
|
||||
// FIXME: this should be moved down to C++ on per-op basis
|
||||
val dimension = Shape.normalizeAxis(op.x().rank(), op.dimensions().toIntVector());
|
||||
// reduce to scalar case, ReduceBool ops require special treatment
|
||||
if (op instanceof BaseReduceBoolOp && op.x().isEmpty() && (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE))) {
|
||||
if (op.z() == null) {
|
||||
op.setZ(Nd4j.scalar(((BaseReduceBoolOp) op).emptyValue()));
|
||||
} else {
|
||||
op.z().assign(((BaseReduceBoolOp) op).emptyValue());
|
||||
}
|
||||
|
||||
return op.z();
|
||||
}
|
||||
|
||||
//validateDataType(Nd4j.dataType(), op);
|
||||
|
||||
|
|
|
@ -71,9 +71,6 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
|||
//Still failing 2019/09/11
|
||||
"slogdet/.*",
|
||||
|
||||
// Failing 2019/11/14 - |https://github.com/eclipse/deeplearning4j/issues/8374
|
||||
"adjust_contrast/*",
|
||||
"adjust_contrast/.*",
|
||||
//Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965
|
||||
"bincount/.*",
|
||||
// Failing 2019/11/14 https://github.com/eclipse/deeplearning4j/issues/8393
|
||||
|
@ -114,29 +111,17 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
|||
// 2019/11/15 - missing dtype argument in nd4j, tests are useless https://github.com/eclipse/deeplearning4j/issues/8398
|
||||
"zeros_like/rank2_float32_dtype_int.*",
|
||||
|
||||
// 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8402
|
||||
"fake_quant/min_max_args_per_channel.*",
|
||||
|
||||
// Suggesting TF 1.15 bug
|
||||
"non_max_suppression_v2/float16.*",
|
||||
|
||||
// 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8450
|
||||
"betainc.*",
|
||||
|
||||
// 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8452
|
||||
"polygamma.*",
|
||||
|
||||
// 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8453
|
||||
"roll/.*",
|
||||
|
||||
// 11.26.2019 failing https://github.com/eclipse/deeplearning4j/issues/8455
|
||||
"matrix_band_part/.*",
|
||||
|
||||
// 11.28.2019 failing https://github.com/eclipse/deeplearning4j/issues/8458
|
||||
"adjust_hue/.*",
|
||||
|
||||
// 11.28.2019 failing https://github.com/eclipse/deeplearning4j/issues/8459
|
||||
"adjust_saturation/.*"
|
||||
// 05.12.2019 failing https://github.com/eclipse/deeplearning4j/issues/8507
|
||||
"resize_bicubic/int32.*"
|
||||
};
|
||||
|
||||
/* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have
|
||||
|
|
|
@ -8134,6 +8134,36 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
assertEquals(Nd4j.createFromArray(1.0,2,3,4,5,6), hStack);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testReduceAll_1() {
|
||||
val x = Nd4j.empty(DataType.FLOAT);
|
||||
val e = Nd4j.scalar(true);
|
||||
val z = Nd4j.exec(new All(x));
|
||||
|
||||
assertEquals(e, z);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testReduceAll_2() {
|
||||
val x = Nd4j.ones(DataType.FLOAT, 0);
|
||||
val e = Nd4j.scalar(true);
|
||||
val z = Nd4j.exec(new All(x));
|
||||
|
||||
assertEquals(e, z);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testReduceAll_3() {
|
||||
val x = Nd4j.create(DataType.FLOAT, 0);
|
||||
assertEquals(1, x.rank());
|
||||
|
||||
val e = Nd4j.scalar(true);
|
||||
val z = Nd4j.exec(new All(x, 0));
|
||||
|
||||
assertEquals(e, z);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
|
|
|
@ -943,16 +943,9 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f}).reshape(3,5);
|
||||
|
||||
INDArray out = Nd4j.createUninitialized(x.shape());
|
||||
val op = new FakeQuantWithMinMaxVarsPerChannel(x,min,max,out);
|
||||
val op = new FakeQuantWithMinMaxVarsPerChannel(x,min,max);
|
||||
Nd4j.exec(op);
|
||||
assertEquals(expected, out);
|
||||
|
||||
/*TF: [[ 0.7801, 0.5966, 0.7260, 0.2320, 0.5084],
|
||||
[ 0.1800, 0.5046, 0.8684, 0.3513, 0.5084],
|
||||
[ 0.0877, 0.5966, 0.6600, 0.3513, 0.1604]]
|
||||
SD: [[ 0.7770, 0.5969, 0.7232, 0.2310, 0.5098],
|
||||
[ 0.1793, 0.5053, 0.8685, 0.3500, 0.5098],
|
||||
[ 0.0874, 0.5969, 0.6574, 0.3500, 0.1597]]*/
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -1036,13 +1029,12 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
INDArray min = Nd4j.createFromArray(new float[]{-63.65f});
|
||||
INDArray max = Nd4j.createFromArray(new float[]{0.1f});
|
||||
|
||||
INDArray output = Nd4j.createUninitialized(DataType.FLOAT, 1,2,3,1);
|
||||
INDArray expected = Nd4j.createFromArray(new float[]{-63.75f, -63.75f, -63.5f, -63.5f, 0.f, 0.f}).
|
||||
reshape(1,2,3,1);
|
||||
|
||||
Nd4j.exec(new FakeQuantWithMinMaxVarsPerChannel(x,min,max,output));
|
||||
INDArray[] output = Nd4j.exec(new FakeQuantWithMinMaxVarsPerChannel(x,min,max));
|
||||
|
||||
assertEquals(expected, output);
|
||||
assertEquals(expected, output[0]);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
Loading…
Reference in New Issue