SameDiff Convolution Config validation, better output methods (#82)
* Conv Config validation & tests Signed-off-by: Ryan Nett <rnett@skymind.io> * stackOutputs utility method Signed-off-by: Ryan Nett <rnett@skymind.io> * use constructor for validation, support negative kernel sizes (infered from weights) Signed-off-by: Ryan Nett <rnett@skymind.io> * better output methods Signed-off-by: Ryan Nett <rnett@skymind.io> * move output to be with fit and evaluate Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * more fixes Signed-off-by: Ryan Nett <rnett@skymind.io>master
parent
8d1fe8b1b3
commit
d4e7997134
|
@ -1,444 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
//
|
|
||||||
// Created by raver119 on 08.10.2017.
|
|
||||||
//
|
|
||||||
|
|
||||||
#include <op_boilerplate.h>
|
|
||||||
#if NOT_EXCLUDED(OP_fullconv3d)
|
|
||||||
|
|
||||||
#include <ops/declarable/CustomOperations.h>
|
|
||||||
#include <ops/declarable/helpers/convolutions.h>
|
|
||||||
|
|
||||||
namespace nd4j {
|
|
||||||
namespace ops {
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
CUSTOM_OP_IMPL(fullconv3d, 5, 1, false, 0, 13) {
|
|
||||||
// auto input = INPUT_VARIABLE(0);
|
|
||||||
// auto weights = INPUT_VARIABLE(1);
|
|
||||||
// auto bias = INPUT_VARIABLE(2);
|
|
||||||
// auto columns = INPUT_VARIABLE(3);
|
|
||||||
// auto ones = INPUT_VARIABLE(4);
|
|
||||||
|
|
||||||
// REQUIRE_TRUE(weights->rankOf() == 5, 0, "Weights should be 5D, got %i instead", weights->rankOf());
|
|
||||||
// REQUIRE_TRUE(input->rankOf() == 5, 0, "Input should be 5D, got %i instead", input->rankOf());
|
|
||||||
|
|
||||||
// // strides
|
|
||||||
// int dT = INT_ARG(0);
|
|
||||||
// int dW = INT_ARG(1);
|
|
||||||
// int dH = INT_ARG(2);
|
|
||||||
|
|
||||||
// // padding
|
|
||||||
// int pT = INT_ARG(3);
|
|
||||||
// int pW = INT_ARG(4);
|
|
||||||
// int pH = INT_ARG(5);
|
|
||||||
|
|
||||||
// // dilation
|
|
||||||
// int dilationT = INT_ARG(6);
|
|
||||||
// int dilationW = INT_ARG(7);
|
|
||||||
// int dilationH = INT_ARG(8);
|
|
||||||
|
|
||||||
// // output padding
|
|
||||||
// int aT = INT_ARG(9);
|
|
||||||
// int aW = INT_ARG(10);
|
|
||||||
// int aH = INT_ARG(11);
|
|
||||||
|
|
||||||
// // bias
|
|
||||||
// bool biasUsed = INT_ARG(12) != 0;
|
|
||||||
|
|
||||||
|
|
||||||
// REQUIRE_TRUE(dT > 0 && dW > 0 && dH > 0, 11,
|
|
||||||
// "stride should be greater than zero, but got dT: %d dH: %d dW: %d", dT, dH, dW);
|
|
||||||
// REQUIRE_TRUE(dilationT > 0 && dilationW > 0 && dilationH > 0, 15,
|
|
||||||
// "dilation should be greater than zero, but got dilationT: %d, dilationH: %d, dilationW: %d",
|
|
||||||
// dilationT, dilationH, dilationW);
|
|
||||||
// REQUIRE_TRUE((aT < dT || aT < dilationT)
|
|
||||||
// && (aW < dW || aW < dilationW)
|
|
||||||
// && (aH < dH || aH < dilationH), 15,
|
|
||||||
// "output padding must be smaller than either stride or dilation,"
|
|
||||||
// " but got aT: %d aH: %d aW: %d dT: %d dH: %d dW: %d "
|
|
||||||
// "dilationT: %d dilationH: %d dilationW: %d",
|
|
||||||
// aT, aH, aW, dT, dH, dW, dilationT, dilationH, dilationW);
|
|
||||||
|
|
||||||
// auto output = this->getZ(block);
|
|
||||||
|
|
||||||
// const int nInputPlane = weights->shapeOf()[0];
|
|
||||||
// const int nOutputPlane = weights->shapeOf()[1];
|
|
||||||
// const int kT = weights->shapeOf()[2];
|
|
||||||
// const int kH = weights->shapeOf()[3];
|
|
||||||
// const int kW = weights->shapeOf()[4];
|
|
||||||
|
|
||||||
// const Nd4jLong inputWidth = input->shapeOf()[4];
|
|
||||||
// const Nd4jLong inputHeight = input->shapeOf()[3];
|
|
||||||
// const Nd4jLong inputDepth = input->shapeOf()[2];
|
|
||||||
// const Nd4jLong outputDepth = (inputDepth - 1) * dT - 2*pT + (dilationT * (kT - 1) + 1) + aT;
|
|
||||||
// const Nd4jLong outputHeight = (inputHeight - 1) * dH - 2*pH + (dilationH * (kH - 1) + 1) + aH;
|
|
||||||
// const Nd4jLong outputWidth = (inputWidth - 1) * dW - 2*pW + (dilationW * (kW - 1) + 1) + aW;
|
|
||||||
|
|
||||||
// const Nd4jLong batchSize = input->shapeOf()[0];
|
|
||||||
|
|
||||||
// REQUIRE_TRUE(output->isSameShape({ (int) batchSize, (int)nOutputPlane, (int)outputDepth, (int)outputHeight, (int)outputWidth}), 0, "Output should have shape of [%i, %i, %i, %i, %i], but got [%i, %i, %i, %i, %i] instead", (int) batchSize, (int)nOutputPlane, (int)outputDepth, (int)outputHeight, (int)outputWidth, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3), output->sizeAt(4));
|
|
||||||
|
|
||||||
// std::unique_ptr<ResultSet> inputs(input->allExamples());
|
|
||||||
// std::unique_ptr<ResultSet> outputs(output->allExamples());
|
|
||||||
// for (int e = 0; e < batchSize; e++) {
|
|
||||||
// auto tadIn = inputs->at(e);
|
|
||||||
// auto tadOut = outputs->at(e);
|
|
||||||
|
|
||||||
// const int m = weights->shapeOf()[1] * weights->shapeOf()[2] * weights->shapeOf()[3] * weights->shapeOf()[4];
|
|
||||||
// const int n = columns->shapeOf()[1];
|
|
||||||
// const int k = weights->shapeOf()[0];
|
|
||||||
|
|
||||||
// // FIXME: mmul helper should be used here
|
|
||||||
// /*
|
|
||||||
// nd4j::blas::GEMM<T>::op('c', 'n', 't', m, n, k,
|
|
||||||
// 1.0,
|
|
||||||
// tadIn->getBuffer(), n,
|
|
||||||
// weights->getBuffer(), m,
|
|
||||||
// 0.0,
|
|
||||||
// columns->getBuffer(), n);
|
|
||||||
// */
|
|
||||||
|
|
||||||
// // ConvolutionUtils<T>::_col2vol(columns->getBuffer(),
|
|
||||||
// // nOutputPlane, outputDepth, outputHeight, outputWidth,
|
|
||||||
// // inputDepth, inputHeight, inputWidth,
|
|
||||||
// // kT, kH, kW,
|
|
||||||
// // pT, pH, pW,
|
|
||||||
// // dT, dH, dW,
|
|
||||||
// // dilationT, dilationH, dilationW,
|
|
||||||
// // tadOut->getBuffer());
|
|
||||||
// ConvolutionUtils::col2vol(*columns, *tadOut, dT, dH, dW, pT, pH, pW, dilationT, dilationH, dilationW);
|
|
||||||
|
|
||||||
|
|
||||||
// const int m_ = nOutputPlane;
|
|
||||||
// const int n_ = outputDepth * outputHeight * outputWidth;
|
|
||||||
// const int k_ = 1;
|
|
||||||
|
|
||||||
// if (biasUsed) {
|
|
||||||
// // FIXME: mmul helper should be used here
|
|
||||||
// /*
|
|
||||||
// nd4j::blas::GEMM<T>::op('c', 't', 'n', n_, m_, k_,
|
|
||||||
// 1.0,
|
|
||||||
// ones->getBuffer(), k_,
|
|
||||||
// bias->getBuffer(), k_,
|
|
||||||
// 1.0,
|
|
||||||
// tadOut->getBuffer(), n_);
|
|
||||||
// */
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// STORE_RESULT(*output);
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
DECLARE_TYPES(fullconv3d) {
|
|
||||||
getOpDescriptor()
|
|
||||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
|
||||||
->setAllowedOutputTypes({ALL_FLOATS});
|
|
||||||
}
|
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(fullconv3d) {
|
|
||||||
// auto input = inputShape->at(0);
|
|
||||||
// auto weights = inputShape->at(1);
|
|
||||||
|
|
||||||
// // strides
|
|
||||||
// int dT = INT_ARG(0);
|
|
||||||
// int dW = INT_ARG(1);
|
|
||||||
// int dH = INT_ARG(2);
|
|
||||||
|
|
||||||
// // padding
|
|
||||||
// int pT = INT_ARG(3);
|
|
||||||
// int pW = INT_ARG(4);
|
|
||||||
// int pH = INT_ARG(5);
|
|
||||||
|
|
||||||
// // dilation
|
|
||||||
// int dilationT = INT_ARG(6);
|
|
||||||
// int dilationW = INT_ARG(7);
|
|
||||||
// int dilationH = INT_ARG(8);
|
|
||||||
|
|
||||||
// // output padding
|
|
||||||
// int aT = INT_ARG(9);
|
|
||||||
// int aW = INT_ARG(10);
|
|
||||||
// int aH = INT_ARG(11);
|
|
||||||
|
|
||||||
// // bias
|
|
||||||
// bool biasUsed = INT_ARG(12) != 0;
|
|
||||||
|
|
||||||
// Nd4jLong *shapeOf;
|
|
||||||
// Nd4jLong *newShape;
|
|
||||||
// ALLOCATE(shapeOf, block.getWorkspace(), 5, Nd4jLong);
|
|
||||||
// ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(5), Nd4jLong);
|
|
||||||
|
|
||||||
// const int nInputPlane = weights[1];
|
|
||||||
// const int nOutputPlane = weights[2];
|
|
||||||
// const int kT = weights[3];
|
|
||||||
// const int kH = weights[4];
|
|
||||||
// const int kW = weights[5];
|
|
||||||
|
|
||||||
// const int batchSize = input[1];
|
|
||||||
// const Nd4jLong inputWidth = input[5];
|
|
||||||
// const Nd4jLong inputHeight = input[4];
|
|
||||||
// const Nd4jLong inputDepth = input[3];
|
|
||||||
// const Nd4jLong outputDepth = (inputDepth - 1) * dT - 2*pT + (dilationT * (kT - 1) + 1) + aT;
|
|
||||||
// const Nd4jLong outputHeight = (inputHeight - 1) * dH - 2*pH + (dilationH * (kH - 1) + 1) + aH;
|
|
||||||
// const Nd4jLong outputWidth = (inputWidth - 1) * dW - 2*pW + (dilationW * (kW - 1) + 1) + aW;
|
|
||||||
|
|
||||||
// nd4j::ArrayUtils::toLongPtr({(Nd4jLong) batchSize, (Nd4jLong)nOutputPlane, (Nd4jLong)outputDepth, (Nd4jLong)outputHeight, (Nd4jLong)outputWidth}, shapeOf);
|
|
||||||
|
|
||||||
// shape::shapeBuffer(5, shapeOf, newShape);
|
|
||||||
|
|
||||||
// RELEASE(shapeOf, block.getWorkspace());
|
|
||||||
|
|
||||||
// return SHAPELIST(newShape);
|
|
||||||
return SHAPELIST();
|
|
||||||
}
|
|
||||||
|
|
||||||
DECLARE_TYPES(fullconv3d_bp) {
|
|
||||||
getOpDescriptor()
|
|
||||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
|
||||||
->setAllowedOutputTypes({ALL_FLOATS});
|
|
||||||
}
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
CUSTOM_OP_IMPL(fullconv3d_bp, 5, 1, false, 0, 13) {
|
|
||||||
// auto input = INPUT_VARIABLE(0);
|
|
||||||
// auto gradNext = INPUT_VARIABLE(1);
|
|
||||||
// auto weights = INPUT_VARIABLE(2);
|
|
||||||
// auto finput = INPUT_VARIABLE(3);
|
|
||||||
|
|
||||||
// // not used
|
|
||||||
// auto fgradInput = INPUT_VARIABLE(4);
|
|
||||||
|
|
||||||
|
|
||||||
// REQUIRE_TRUE(weights->rankOf() == 5, 0, "Weights should be 5D, got %i instead", weights->rankOf());
|
|
||||||
// REQUIRE_TRUE(input->rankOf() == 5, 0, "Input should be 5D, got %i instead", input->rankOf());
|
|
||||||
|
|
||||||
// auto output = OUTPUT_VARIABLE(0);
|
|
||||||
|
|
||||||
// int dT = INT_ARG(0);
|
|
||||||
// int dW = INT_ARG(1);
|
|
||||||
// int dH = INT_ARG(2);
|
|
||||||
// int pT = INT_ARG(3);
|
|
||||||
// int pW = INT_ARG(4);
|
|
||||||
// int pH = INT_ARG(5);
|
|
||||||
// int dilationT = INT_ARG(6);
|
|
||||||
// int dilationW = INT_ARG(7);
|
|
||||||
// int dilationH = INT_ARG(8);
|
|
||||||
// int aT = INT_ARG(9);
|
|
||||||
// int aW = INT_ARG(10);
|
|
||||||
// int aH = INT_ARG(11);
|
|
||||||
// bool biasUsed = INT_ARG(12) != 0;
|
|
||||||
|
|
||||||
// const int nInputPlane = (int)weights->shapeOf()[0];
|
|
||||||
// const int nOutputPlane = (int)weights->shapeOf()[1];
|
|
||||||
// const int kT = (int)weights->shapeOf()[2];
|
|
||||||
// const int kH = (int)weights->shapeOf()[3];
|
|
||||||
// const int kW = (int)weights->shapeOf()[4];
|
|
||||||
|
|
||||||
// const Nd4jLong inputWidth = input->shapeOf()[4];
|
|
||||||
// const Nd4jLong inputHeight = input->shapeOf()[3];
|
|
||||||
// const Nd4jLong inputDepth = input->shapeOf()[2];
|
|
||||||
// const Nd4jLong outputDepth = (inputDepth - 1) * dT - 2*pT + (dilationT * (kT - 1) + 1) + aT;
|
|
||||||
// const Nd4jLong outputHeight = (inputHeight - 1) * dH - 2*pH + (dilationH * (kH - 1) + 1) + aH;
|
|
||||||
// const Nd4jLong outputWidth = (inputWidth - 1) * dW - 2*pW + (dilationW * (kW - 1) + 1) + aW;
|
|
||||||
|
|
||||||
// const Nd4jLong batchSize = input->shapeOf()[0];
|
|
||||||
|
|
||||||
|
|
||||||
// REQUIRE_TRUE(output->isSameShape({(int) batchSize, (int) nInputPlane, (int) inputDepth, (int) inputHeight, (int) inputWidth}) ,0, "Output should have shape of [%i, %i, %i, %i, %i], but got [%i, %i, %i, %i, %i] instead", (int) batchSize, (int) nInputPlane, (int) inputDepth, (int) inputHeight, (int) inputWidth, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3), output->sizeAt(4));
|
|
||||||
|
|
||||||
// output->assign(0.0);
|
|
||||||
|
|
||||||
// // FIXME: non-inplace reshape!!!!
|
|
||||||
// NDArray *gradColumns;
|
|
||||||
// //auto gradColumns = finput->reshape('c', {nOutputPlane*kW*kH*kT, inputDepth*inputHeight*inputWidth });
|
|
||||||
|
|
||||||
// std::unique_ptr<ResultSet> tadsNext(gradNext->allExamples());
|
|
||||||
// std::unique_ptr<ResultSet> tadsOutput(output->allExamples());
|
|
||||||
// for (int e = 0; e < tadsNext->size(); e++) {
|
|
||||||
// auto tadNext = tadsNext->at(e);
|
|
||||||
// auto tadOutput = tadsOutput->at(e);
|
|
||||||
|
|
||||||
// // ConvolutionUtils<T>::_vol2col(
|
|
||||||
// // tadNext->getBuffer(),
|
|
||||||
// // nOutputPlane, outputDepth, outputHeight, outputWidth,
|
|
||||||
// // kT, kH, kW,
|
|
||||||
// // pT, pH, pW,
|
|
||||||
// // dT, dH, dW,
|
|
||||||
// // dilationT, dilationH, dilationW,
|
|
||||||
// // gradColumns->getBuffer());
|
|
||||||
// ConvolutionUtils::vol2col(*tadNext, *gradColumns, dT, dH, dW, pT, pH, pW, dilationT, dilationH, dilationW);
|
|
||||||
|
|
||||||
// const auto m = weights->shapeOf()[0];
|
|
||||||
// const auto n = gradColumns->shapeOf()[1];
|
|
||||||
// const auto k = weights->shapeOf()[1] * weights->shapeOf()[2] * weights->shapeOf()[3] * weights->shapeOf()[4];
|
|
||||||
|
|
||||||
// // FIXME: mmul helper should be used here
|
|
||||||
// /*
|
|
||||||
// nd4j::blas::GEMM<T>::op('f', 'n', 'n',
|
|
||||||
// n, m, k,
|
|
||||||
// 1.0f,
|
|
||||||
// gradColumns->getBuffer(), n,
|
|
||||||
// weights->getBuffer(), k,
|
|
||||||
// 0,
|
|
||||||
// tadOutput->getBuffer(), n
|
|
||||||
|
|
||||||
// );
|
|
||||||
// */
|
|
||||||
// }
|
|
||||||
|
|
||||||
|
|
||||||
// STORE_RESULT(*output);
|
|
||||||
|
|
||||||
// delete gradColumns;
|
|
||||||
return ND4J_STATUS_OK;
|
|
||||||
}
|
|
||||||
DECLARE_SHAPE_FN(fullconv3d_bp) {
|
|
||||||
// output shape equals to input shape, all out of sudden
|
|
||||||
// Nd4jLong* newShape;
|
|
||||||
// COPY_SHAPE(inputShape->at(0), newShape);
|
|
||||||
|
|
||||||
// return SHAPELIST(newShape);
|
|
||||||
return SHAPELIST();
|
|
||||||
}
|
|
||||||
|
|
||||||
DECLARE_TYPES(fullconv3d_grad) {
|
|
||||||
getOpDescriptor()
|
|
||||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
|
||||||
->setAllowedOutputTypes({ALL_FLOATS});
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
CUSTOM_OP_IMPL(fullconv3d_grad, 4, 2, false, 1, 13) {
|
|
||||||
// auto input = INPUT_VARIABLE(0);
|
|
||||||
// auto epsilon = INPUT_VARIABLE(1);
|
|
||||||
// auto columns = INPUT_VARIABLE(2);
|
|
||||||
// auto ones = INPUT_VARIABLE(3);
|
|
||||||
|
|
||||||
// REQUIRE_TRUE(input->rankOf() == epsilon->rankOf(), 0, "Rank of input (%i) & epsilon (%i) should be equal", input->rankOf(), epsilon->rankOf());
|
|
||||||
// REQUIRE_TRUE(input->sizeAt(0) == epsilon->sizeAt(0), 1, "Batch size should be equal for input and epsilon");
|
|
||||||
|
|
||||||
// auto gradWeight = OUTPUT_VARIABLE(0);
|
|
||||||
// auto gradBias = OUTPUT_VARIABLE(1);
|
|
||||||
|
|
||||||
// REQUIRE_TRUE(gradBias->sizeAt(0) == gradWeight->sizeAt(1), 0, "Bias shape mismatch");
|
|
||||||
|
|
||||||
// int dT = INT_ARG(0);
|
|
||||||
// int dW = INT_ARG(1);
|
|
||||||
// int dH = INT_ARG(2);
|
|
||||||
// int pT = INT_ARG(3);
|
|
||||||
// int pW = INT_ARG(4);
|
|
||||||
// int pH = INT_ARG(5);
|
|
||||||
// int dilationT = INT_ARG(6);
|
|
||||||
// int dilationW = INT_ARG(7);
|
|
||||||
// int dilationH = INT_ARG(8);
|
|
||||||
// int aT = INT_ARG(9);
|
|
||||||
// int aW = INT_ARG(10);
|
|
||||||
// int aH = INT_ARG(11);
|
|
||||||
// bool biasUsed = INT_ARG(12) != 0;
|
|
||||||
|
|
||||||
// double scale = block.getTArguments()->at(0);
|
|
||||||
|
|
||||||
// int nInputPlane = (int)gradWeight->shapeOf()[0];
|
|
||||||
// int nOutputPlane = (int)gradWeight->shapeOf()[1];
|
|
||||||
// int kT = (int)gradWeight->shapeOf()[2];
|
|
||||||
// int kH = (int)gradWeight->shapeOf()[3];
|
|
||||||
// int kW = (int)gradWeight->shapeOf()[4];
|
|
||||||
|
|
||||||
|
|
||||||
// const Nd4jLong inputWidth = input->shapeOf()[4];
|
|
||||||
// const Nd4jLong inputHeight = input->shapeOf()[3];
|
|
||||||
// const Nd4jLong inputDepth = input->shapeOf()[2];
|
|
||||||
// const Nd4jLong outputDepth = (inputDepth - 1) * dT - 2*pT + (dilationT * (kT - 1) + 1) + aT;
|
|
||||||
// const Nd4jLong outputHeight = (inputHeight - 1) * dH - 2*pH + (dilationH * (kH - 1) + 1) + aH;
|
|
||||||
// const Nd4jLong outputWidth = (inputWidth - 1) * dW - 2*pW + (dilationW * (kW - 1) + 1) + aW;
|
|
||||||
|
|
||||||
|
|
||||||
// REQUIRE_TRUE(gradWeight->isContiguous(), 0, "gradWight should be continuous");
|
|
||||||
// REQUIRE_TRUE(gradBias->isContiguous(), 0, "gradBias should be continuous");
|
|
||||||
// REQUIRE_TRUE(ones->rankOf() == 3, 0, "Ones should have rank 3, got %i instead", ones->rankOf());
|
|
||||||
|
|
||||||
// REQUIRE_TRUE(ones->isSameShape({outputDepth, outputHeight, outputWidth}), 0, "");
|
|
||||||
|
|
||||||
// ones->assign(1.0);
|
|
||||||
|
|
||||||
// std::unique_ptr<ResultSet> tadsInput(input->allExamples());
|
|
||||||
// std::unique_ptr<ResultSet> tadsEpsilon(epsilon->allExamples());
|
|
||||||
|
|
||||||
// for (int e = 0; e < tadsInput->size(); e++) {
|
|
||||||
// auto tadInput = tadsInput->at(e);
|
|
||||||
// auto tadEpsilon = tadsEpsilon->at(e);
|
|
||||||
|
|
||||||
// // ConvolutionUtils<T>::_vol2col(
|
|
||||||
// // tadEpsilon->getBuffer(), nOutputPlane,
|
|
||||||
// // outputDepth, outputHeight, outputWidth,
|
|
||||||
// // kT, kH, kW,
|
|
||||||
// // pT, pH, pW,
|
|
||||||
// // dT, dH, dW,
|
|
||||||
// // dilationT, dilationH, dilationW,
|
|
||||||
// // columns->getBuffer()
|
|
||||||
// // );
|
|
||||||
// ConvolutionUtils::vol2col(*tadEpsilon, *columns, dT, dH, dW, pT, pH, pW, dilationT, dilationH, dilationW);
|
|
||||||
// const Nd4jLong n = columns->shapeOf()[0]; // nOutputPlane * kt * kh * kw
|
|
||||||
// const Nd4jLong m = tadInput->shapeOf()[0]; // nInputPlane
|
|
||||||
// const Nd4jLong k = columns->shapeOf()[1];
|
|
||||||
|
|
||||||
// // FIXME: mmul helper should be used here
|
|
||||||
// /**
|
|
||||||
// nd4j::blas::GEMM<T>::op('f', 't', 'n',
|
|
||||||
// n, m, k,
|
|
||||||
// scale,
|
|
||||||
// columns->getBuffer(), k,
|
|
||||||
// tadInput->getBuffer(), k,
|
|
||||||
// 1,
|
|
||||||
// gradWeight->getBuffer(), n);
|
|
||||||
// */
|
|
||||||
|
|
||||||
// const Nd4jLong m_ = nOutputPlane;
|
|
||||||
// const Nd4jLong k_ = outputDepth * outputHeight * outputWidth;
|
|
||||||
|
|
||||||
|
|
||||||
// if (gradBias) {
|
|
||||||
// // FIXME: mmul helper should be used here
|
|
||||||
// /*
|
|
||||||
// nd4j::blas::GEMV<T>::op('t',
|
|
||||||
// k_, m_,
|
|
||||||
// scale,
|
|
||||||
// tadEpsilon->getBuffer(), k_,
|
|
||||||
// ones->getBuffer(), 1, (T)1.0f,
|
|
||||||
// gradBias->getBuffer(), 1);
|
|
||||||
// */
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
|
|
||||||
// STORE_2_RESULTS(*gradWeight, *gradBias);
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
DECLARE_SHAPE_FN(fullconv3d_grad) {
|
|
||||||
// auto list = SHAPELIST();
|
|
||||||
|
|
||||||
// _grad ops MUST have output arrays provided
|
|
||||||
|
|
||||||
// return list;
|
|
||||||
return SHAPELIST();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
|
@ -187,12 +187,6 @@ namespace nd4j {
|
||||||
DECLARE_CUSTOM_OP(pnormpool2d_bp, 2, 1, false, 1, 10);
|
DECLARE_CUSTOM_OP(pnormpool2d_bp, 2, 1, false, 1, 10);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if NOT_EXCLUDED(OP_fullconv3d)
|
|
||||||
DECLARE_CUSTOM_OP(fullconv3d, 5, 1, false, 0, 13);
|
|
||||||
DECLARE_CUSTOM_OP(fullconv3d_bp, 5, 1, false, 0, 13);
|
|
||||||
DECLARE_CUSTOM_OP(fullconv3d_grad, 4, 2, false, 1, 13);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This op implements im2col algorithm, widely used in convolution neural networks
|
* This op implements im2col algorithm, widely used in convolution neural networks
|
||||||
* Input: 4D input expected
|
* Input: 4D input expected
|
||||||
|
|
|
@ -16,6 +16,9 @@
|
||||||
|
|
||||||
package org.nd4j.autodiff.samediff;
|
package org.nd4j.autodiff.samediff;
|
||||||
|
|
||||||
|
import static org.nd4j.autodiff.util.TrainingUtils.getSingleOutput;
|
||||||
|
import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
|
||||||
|
|
||||||
import com.google.common.collect.HashBasedTable;
|
import com.google.common.collect.HashBasedTable;
|
||||||
import com.google.common.collect.Table;
|
import com.google.common.collect.Table;
|
||||||
import com.google.common.primitives.Ints;
|
import com.google.common.primitives.Ints;
|
||||||
|
@ -73,6 +76,7 @@ import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
|
import org.nd4j.linalg.exception.ND4JException;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
|
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
import org.nd4j.linalg.exception.ND4UnresolvedOutputVariables;
|
import org.nd4j.linalg.exception.ND4UnresolvedOutputVariables;
|
||||||
|
@ -109,7 +113,7 @@ import org.tensorflow.framework.GraphDef;
|
||||||
* <p>
|
* <p>
|
||||||
* That graph accumulates operations.
|
* That graph accumulates operations.
|
||||||
* <p>
|
* <p>
|
||||||
* In order to execute the graph, you run one of the execution methods, such as {@link #exec(Map, String...)}
|
* In order to execute the graph, you run one of the execution methods, such as {@link #output(Map, String...)}
|
||||||
*/
|
*/
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@Builder
|
@Builder
|
||||||
|
@ -2262,7 +2266,7 @@ public class SameDiff extends SDBaseOps {
|
||||||
MultiDataSet ds = iterator.next();
|
MultiDataSet ds = iterator.next();
|
||||||
Map<String,INDArray> placeholderMap = toPlaceholderMap(ds);
|
Map<String,INDArray> placeholderMap = toPlaceholderMap(ds);
|
||||||
|
|
||||||
Map<String,INDArray> m = exec(placeholderMap, reqVars);
|
Map<String,INDArray> m = output(placeholderMap, reqVars);
|
||||||
|
|
||||||
for(Map.Entry<String,List<IEvaluation>> e : variableEvals.entrySet()){
|
for(Map.Entry<String,List<IEvaluation>> e : variableEvals.entrySet()){
|
||||||
INDArray prediction = m.get(e.getKey());
|
INDArray prediction = m.get(e.getKey());
|
||||||
|
@ -2288,7 +2292,15 @@ public class SameDiff extends SDBaseOps {
|
||||||
* @param outputs The variables to evaluate
|
* @param outputs The variables to evaluate
|
||||||
*/
|
*/
|
||||||
public Map<String, INDArray> output(DataSet dataSet, String... outputs){
|
public Map<String, INDArray> output(DataSet dataSet, String... outputs){
|
||||||
return output(new SingletonMultiDataSetIterator(dataSet.toMultiDataSet()), outputs).get(0);
|
return outputBatches(new SingletonMultiDataSetIterator(dataSet.toMultiDataSet()), outputs).get(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Single output inference.
|
||||||
|
* See {@link #output(DataSet, String...)}
|
||||||
|
*/
|
||||||
|
public INDArray outputSingle(DataSet dataSet, String output){
|
||||||
|
return output(dataSet, output).get(output);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -2299,13 +2311,40 @@ public class SameDiff extends SDBaseOps {
|
||||||
* sameDiff.output(iterator, "softmax");}
|
* sameDiff.output(iterator, "softmax");}
|
||||||
* </pre>
|
* </pre>
|
||||||
*
|
*
|
||||||
|
* Uses concatenation on the outputs of {@link #outputBatches(DataSetIterator, String...)} which may cause issues with some inputs.
|
||||||
|
* RNNs with variable time series length and CNNs with variable image sizes will most likely have issues.
|
||||||
|
*
|
||||||
* @param iterator Iterator as source of data to evaluate
|
* @param iterator Iterator as source of data to evaluate
|
||||||
* @param outputs The variables to evaluate
|
* @param outputs The variables to evaluate
|
||||||
*/
|
*/
|
||||||
public List<Map<String, INDArray>> output(DataSetIterator iterator, String... outputs){
|
public Map<String, INDArray> output(DataSetIterator iterator, String... outputs){
|
||||||
return output(new MultiDataSetIteratorAdapter(iterator), outputs);
|
return output(new MultiDataSetIteratorAdapter(iterator), outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #output(DataSetIterator, String...)}, but without the concatenation of batches.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public List<Map<String, INDArray>> outputBatches(DataSetIterator iterator, String... outputs){
|
||||||
|
return outputBatches(new MultiDataSetIteratorAdapter(iterator), outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Single output inference.
|
||||||
|
* See {@link #output(DataSetIterator, String...)}
|
||||||
|
*/
|
||||||
|
public INDArray outputSingle(DataSetIterator dataSet, String output){
|
||||||
|
return output(dataSet, output).get(output);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Single batched output inference.
|
||||||
|
* See {@link #output(DataSetIterator, String...)}
|
||||||
|
*/
|
||||||
|
public List<INDArray> outputSingleBatches(DataSetIterator dataSet, String output){
|
||||||
|
return getSingleOutput(outputBatches(dataSet, output), output);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Perform inference.<br>
|
* Perform inference.<br>
|
||||||
* <br>
|
* <br>
|
||||||
|
@ -2321,10 +2360,20 @@ public class SameDiff extends SDBaseOps {
|
||||||
* }
|
* }
|
||||||
* </pre>
|
* </pre>
|
||||||
*
|
*
|
||||||
|
* Uses concatenation on the outputs of {@link #outputBatches(MultiDataSetIterator, String...)} which may cause issues with some inputs.
|
||||||
|
* RNNs with variable time series length and CNNs with variable image sizes will most likely have issues.
|
||||||
|
*
|
||||||
* @param iterator The iterator - the source of the data for inference
|
* @param iterator The iterator - the source of the data for inference
|
||||||
* @param outputs The set of outputs to report. If null, defaults to all outputs of this SameDiff.
|
* @param outputs The set of outputs to report. If null, defaults to all outputs of this SameDiff.
|
||||||
*/
|
*/
|
||||||
public List<Map<String, INDArray>> output(MultiDataSetIterator iterator, String... outputs){
|
public Map<String, INDArray> output(MultiDataSetIterator iterator, String... outputs){
|
||||||
|
return stackOutputs(outputBatches(iterator, outputs));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #output(MultiDataSetIterator, String...)}, but without the concatenation of batches.
|
||||||
|
*/
|
||||||
|
public List<Map<String, INDArray>> outputBatches(MultiDataSetIterator iterator, String... outputs){
|
||||||
Preconditions.checkState(trainingConfig != null, "Training config has not been set");
|
Preconditions.checkState(trainingConfig != null, "Training config has not been set");
|
||||||
|
|
||||||
List<String> reqVars;
|
List<String> reqVars;
|
||||||
|
@ -2344,12 +2393,114 @@ public class SameDiff extends SDBaseOps {
|
||||||
MultiDataSet ds = iterator.next();
|
MultiDataSet ds = iterator.next();
|
||||||
Map<String,INDArray> placeholderMap = toPlaceholderMap(ds);
|
Map<String,INDArray> placeholderMap = toPlaceholderMap(ds);
|
||||||
|
|
||||||
predictions.add(exec(placeholderMap, reqVars));
|
predictions.add(output(placeholderMap, reqVars));
|
||||||
}
|
}
|
||||||
|
|
||||||
return predictions;
|
return predictions;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Single output inference.
|
||||||
|
* See {@link #output(MultiDataSetIterator, String...)}
|
||||||
|
*/
|
||||||
|
public INDArray outputSingle(MultiDataSetIterator dataSet, String output){
|
||||||
|
return output(dataSet, output).get(output);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Single batched output inference.
|
||||||
|
* See {@link #output(MultiDataSetIterator, String...)}
|
||||||
|
*/
|
||||||
|
public List<INDArray> outputSingleBatches(MultiDataSetIterator dataSet, String output){
|
||||||
|
return getSingleOutput(outputBatches(dataSet, output), output);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated See {@link #outputAll(Map)}
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
|
public Map<String,INDArray> execAll(Map<String,INDArray> placeholders){
|
||||||
|
return outputAll(placeholders);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Do inference for all variables for a single batch
|
||||||
|
*/
|
||||||
|
public Map<String,INDArray> outputAll(Map<String,INDArray> placeholders){
|
||||||
|
List<String> allVars = new ArrayList<>();
|
||||||
|
for(Variable v : variables.values()){
|
||||||
|
allVars.add(v.getName());
|
||||||
|
}
|
||||||
|
return output(placeholders, allVars.toArray(new String[0]));
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* @deprecated See {@link #outputSingle(Map, String)}
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
|
public INDArray execSingle(Map<String,INDArray> placeholders, String output){
|
||||||
|
return outputSingle(placeholders, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Do inference for a single variable for a single batch
|
||||||
|
*/
|
||||||
|
public INDArray outputSingle(Map<String,INDArray> placeholders, String output){
|
||||||
|
return output(placeholders, output).get(output);
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* @deprecated See {@link #output(Map, List)}
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
|
public Map<String,INDArray> exec(Map<String,INDArray> placeholders, List<String> outputs){
|
||||||
|
return output(placeholders, outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Do inference for the given variables for a single batch
|
||||||
|
*/
|
||||||
|
public Map<String,INDArray> output(Map<String,INDArray> placeholders, List<String> outputs){
|
||||||
|
return output(placeholders, outputs.toArray(new String[outputs.size()]));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated See {@link #output(Map, String...)}
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
|
public Map<String,INDArray> exec(Map<String,INDArray> placeholders, String... outputs) {
|
||||||
|
return output(placeholders, outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Do inference for the given variables for a single batch
|
||||||
|
*/
|
||||||
|
public Map<String,INDArray> output(Map<String,INDArray> placeholders, String... outputs) {
|
||||||
|
return output(placeholders, false, null, outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Do inference for the given variables for a single batch, with training information
|
||||||
|
*/
|
||||||
|
protected Map<String,INDArray> output(Map<String,INDArray> placeholders, boolean training, At at, String... outputs){
|
||||||
|
Preconditions.checkState(outputs != null && outputs.length > 0, "No outputs were specified");
|
||||||
|
long threadId = Thread.currentThread().getId();
|
||||||
|
if(!sessions.containsKey(threadId)){
|
||||||
|
log.info("Creating new InferenceSession for thread {}", threadId);
|
||||||
|
sessions.put(threadId, new InferenceSession(this));
|
||||||
|
}
|
||||||
|
|
||||||
|
List<String> phNames = inputs();
|
||||||
|
if(placeholders == null && phNames != null){
|
||||||
|
//Maybe user set placeholders before calling exec method?
|
||||||
|
placeholders = placeholdersPerThread.get(Thread.currentThread().getId());
|
||||||
|
}
|
||||||
|
|
||||||
|
//Placeholder validation is performed in InferenceSession
|
||||||
|
|
||||||
|
InferenceSession is = sessions.get(threadId);
|
||||||
|
Map<String,INDArray> ret = is.output(Arrays.asList(outputs), placeholders, listeners, training, at);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
public SDVariable one(String name, int... shape){
|
public SDVariable one(String name, int... shape){
|
||||||
return one(name, Nd4j.defaultFloatingPointType(), shape);
|
return one(name, Nd4j.defaultFloatingPointType(), shape);
|
||||||
|
@ -3779,7 +3930,7 @@ public class SameDiff extends SDBaseOps {
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO is this 'train' flag the best approach?
|
//TODO is this 'train' flag the best approach?
|
||||||
sd.exec(placeholders, trainingConfig != null, at, variableGradNamesList.toArray(new String[variableGradNamesList.size()]));
|
sd.output(placeholders, trainingConfig != null, at, variableGradNamesList.toArray(new String[variableGradNamesList.size()]));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -4459,47 +4610,6 @@ public class SameDiff extends SDBaseOps {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public Map<String,INDArray> execAll(Map<String,INDArray> placeholders){
|
|
||||||
List<String> allVars = new ArrayList<>();
|
|
||||||
for(Variable v : variables.values()){
|
|
||||||
allVars.add(v.getName());
|
|
||||||
}
|
|
||||||
return exec(placeholders, allVars.toArray(new String[allVars.size()]));
|
|
||||||
}
|
|
||||||
|
|
||||||
public INDArray execSingle(Map<String,INDArray> placeholders, String output){
|
|
||||||
return exec(placeholders, output).get(output);
|
|
||||||
}
|
|
||||||
|
|
||||||
public Map<String,INDArray> exec(Map<String,INDArray> placeholders, List<String> outputs){
|
|
||||||
return exec(placeholders, outputs.toArray(new String[outputs.size()]));
|
|
||||||
}
|
|
||||||
|
|
||||||
public Map<String,INDArray> exec(Map<String,INDArray> placeholders, String... outputs) {
|
|
||||||
return exec(placeholders, false, null, outputs);
|
|
||||||
}
|
|
||||||
|
|
||||||
protected Map<String,INDArray> exec(Map<String,INDArray> placeholders, boolean training, At at, String... outputs){
|
|
||||||
Preconditions.checkState(outputs != null && outputs.length > 0, "No outputs were specified");
|
|
||||||
long threadId = Thread.currentThread().getId();
|
|
||||||
if(!sessions.containsKey(threadId)){
|
|
||||||
log.info("Creating new InferenceSession for thread {}", threadId);
|
|
||||||
sessions.put(threadId, new InferenceSession(this));
|
|
||||||
}
|
|
||||||
|
|
||||||
List<String> phNames = inputs();
|
|
||||||
if(placeholders == null && phNames != null){
|
|
||||||
//Maybe user set placeholders before calling exec method?
|
|
||||||
placeholders = placeholdersPerThread.get(Thread.currentThread().getId());
|
|
||||||
}
|
|
||||||
|
|
||||||
//Placeholder validation is performed in InferenceSession
|
|
||||||
|
|
||||||
InferenceSession is = sessions.get(threadId);
|
|
||||||
Map<String,INDArray> ret = is.output(Arrays.asList(outputs), placeholders, listeners, training, at);
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
protected int asFlatNode(String name, @NonNull SameDiff scope, @NonNull FlatBufferBuilder bufferBuilder) {
|
protected int asFlatNode(String name, @NonNull SameDiff scope, @NonNull FlatBufferBuilder bufferBuilder) {
|
||||||
int scopeName = bufferBuilder.createString(name);
|
int scopeName = bufferBuilder.createString(name);
|
||||||
|
|
|
@ -0,0 +1,70 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.nd4j.autodiff.util;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import lombok.AccessLevel;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
|
import org.nd4j.linalg.exception.ND4JException;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Utilities for SameDiff training and inference
|
||||||
|
*/
|
||||||
|
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||||
|
public class TrainingUtils {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stack batch outputs, like an output from {@link org.nd4j.autodiff.samediff.SameDiff#output(MultiDataSetIterator, String...)}
|
||||||
|
*/
|
||||||
|
public static Map<String, INDArray> stackOutputs(List<Map<String, INDArray>> outputs){
|
||||||
|
Map<String, List<INDArray>> outs = new HashMap<>();
|
||||||
|
for(Map<String, INDArray> batch : outputs){
|
||||||
|
for(String k : batch.keySet()){
|
||||||
|
if(!outs.containsKey(k))
|
||||||
|
outs.put(k, new ArrayList<INDArray>());
|
||||||
|
outs.get(k).add(batch.get(k));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Map<String, INDArray> ret = new HashMap<>();
|
||||||
|
for(String k : outs.keySet()){
|
||||||
|
try {
|
||||||
|
ret.put(k, Nd4j.concat(0, outs.get(k).toArray(new INDArray[0])));
|
||||||
|
} catch(Exception e){
|
||||||
|
throw new ND4JException("Error concatenating batch outputs", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get a list of batch outputs for a single variable from a list of batch outputs for all variables
|
||||||
|
*/
|
||||||
|
public static List<INDArray> getSingleOutput(List<Map<String, INDArray>> outputs, String output){
|
||||||
|
List<INDArray> batches = new ArrayList<>();
|
||||||
|
for(Map<String, INDArray> batch : outputs)
|
||||||
|
batches.add(batch.get(output));
|
||||||
|
|
||||||
|
return batches;
|
||||||
|
}
|
||||||
|
}
|
|
@ -915,7 +915,6 @@ public class OpValidation {
|
||||||
Conv2DDerivative.class,
|
Conv2DDerivative.class,
|
||||||
Conv3DDerivative.class,
|
Conv3DDerivative.class,
|
||||||
DeConv2DDerivative.class,
|
DeConv2DDerivative.class,
|
||||||
FullConv3DDerivative.class,
|
|
||||||
LocalResponseNormalizationDerivative.class,
|
LocalResponseNormalizationDerivative.class,
|
||||||
Pooling2DDerivative.class,
|
Pooling2DDerivative.class,
|
||||||
Pooling3DDerivative.class,
|
Pooling3DDerivative.class,
|
||||||
|
|
|
@ -72,7 +72,6 @@ public class DifferentialFunctionClassHolder {
|
||||||
add(AvgPooling2D.class.getName());
|
add(AvgPooling2D.class.getName());
|
||||||
add(Conv2D.class.getName());
|
add(Conv2D.class.getName());
|
||||||
add(Conv3D.class.getName());
|
add(Conv3D.class.getName());
|
||||||
add(FullConv3D.class.getName());
|
|
||||||
add(LocalResponseNormalization.class.getName());
|
add(LocalResponseNormalization.class.getName());
|
||||||
add(MaxPooling2D.class.getName());
|
add(MaxPooling2D.class.getName());
|
||||||
add(Pooling2D.class.getName());
|
add(Pooling2D.class.getName());
|
||||||
|
|
|
@ -117,8 +117,6 @@ public class ImportClassMapping {
|
||||||
org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DDerivative.class,
|
org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DDerivative.class,
|
||||||
org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace.class,
|
org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace.class,
|
||||||
org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D.class,
|
org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D.class,
|
||||||
org.nd4j.linalg.api.ops.impl.layers.convolution.FullConv3D.class,
|
|
||||||
org.nd4j.linalg.api.ops.impl.layers.convolution.FullConv3DDerivative.class,
|
|
||||||
org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col.class,
|
org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col.class,
|
||||||
org.nd4j.linalg.api.ops.impl.layers.convolution.Im2colBp.class,
|
org.nd4j.linalg.api.ops.impl.layers.convolution.Im2colBp.class,
|
||||||
org.nd4j.linalg.api.ops.impl.layers.convolution.LegacyPooling2D.class,
|
org.nd4j.linalg.api.ops.impl.layers.convolution.LegacyPooling2D.class,
|
||||||
|
|
|
@ -251,8 +251,6 @@ public class AvgPooling2D extends DynamicCustomOp {
|
||||||
.kW(kW)
|
.kW(kW)
|
||||||
.pH(pH)
|
.pH(pH)
|
||||||
.pW(pW)
|
.pW(pW)
|
||||||
.virtualHeight(1)
|
|
||||||
.virtualWidth(1)
|
|
||||||
.isNHWC(data_format.equalsIgnoreCase("nhwc"))
|
.isNHWC(data_format.equalsIgnoreCase("nhwc"))
|
||||||
.extra(0.0) // averaging only for non-padded values
|
.extra(0.0) // averaging only for non-padded values
|
||||||
.build();
|
.build();
|
||||||
|
@ -277,8 +275,6 @@ public class AvgPooling2D extends DynamicCustomOp {
|
||||||
.kW(kernelShape.get(1).intValue())
|
.kW(kernelShape.get(1).intValue())
|
||||||
.pH(padding.get(0).intValue())
|
.pH(padding.get(0).intValue())
|
||||||
.pW(padding.size() < 2 ? padding.get(0).intValue() : padding.get(1).intValue())
|
.pW(padding.size() < 2 ? padding.get(0).intValue() : padding.get(1).intValue())
|
||||||
.virtualWidth(1)
|
|
||||||
.virtualHeight(1)
|
|
||||||
.build();
|
.build();
|
||||||
this.config = pooling2DConfig;
|
this.config = pooling2DConfig;
|
||||||
addArgs();
|
addArgs();
|
||||||
|
|
|
@ -1,228 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
|
||||||
|
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import lombok.val;
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
|
||||||
import org.nd4j.base.Preconditions;
|
|
||||||
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
|
|
||||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
|
||||||
import org.nd4j.imports.descriptors.properties.adapters.IntArrayIntIndexAdpater;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.FullConv3DConfig;
|
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
|
||||||
|
|
||||||
import java.lang.reflect.Field;
|
|
||||||
import java.util.*;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* FullConv3D operation
|
|
||||||
*/
|
|
||||||
@Slf4j
|
|
||||||
public class FullConv3D extends DynamicCustomOp {
|
|
||||||
|
|
||||||
protected FullConv3DConfig config;
|
|
||||||
|
|
||||||
@Builder(builderMethodName = "builder")
|
|
||||||
public FullConv3D(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, FullConv3DConfig config) {
|
|
||||||
super(null,sameDiff, inputFunctions, false);
|
|
||||||
this.config = config;
|
|
||||||
if(inputs != null) {
|
|
||||||
addInputArgument(inputs);
|
|
||||||
}
|
|
||||||
|
|
||||||
if(outputs != null) {
|
|
||||||
addOutputArgument(outputs);
|
|
||||||
}
|
|
||||||
|
|
||||||
addArgs();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
public FullConv3D() {}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Map<String, Object> propertiesForFunction() {
|
|
||||||
return config.toProperties();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public long[] iArgs() {
|
|
||||||
if (iArguments.size() == 0)
|
|
||||||
addArgs();
|
|
||||||
|
|
||||||
return super.iArgs();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isConfigProperties() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String configFieldName() {
|
|
||||||
return "config";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
|
|
||||||
Map<String,Map<String,AttributeAdapter>> ret = new LinkedHashMap<>();
|
|
||||||
Map<String,AttributeAdapter> tfAdapters = new LinkedHashMap<>();
|
|
||||||
|
|
||||||
tfAdapters.put("dT", new IntArrayIntIndexAdpater(1));
|
|
||||||
tfAdapters.put("dW", new IntArrayIntIndexAdpater(2));
|
|
||||||
tfAdapters.put("dH",new IntArrayIntIndexAdpater(3));
|
|
||||||
|
|
||||||
|
|
||||||
tfAdapters.put("pT", new IntArrayIntIndexAdpater(1));
|
|
||||||
tfAdapters.put("pW", new IntArrayIntIndexAdpater(2));
|
|
||||||
tfAdapters.put("pH",new IntArrayIntIndexAdpater(3));
|
|
||||||
|
|
||||||
ret.put(tensorflowName(),tfAdapters);
|
|
||||||
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
|
|
||||||
Map<String,Map<String,PropertyMapping>> ret = new HashMap<>();
|
|
||||||
Map<String,PropertyMapping> map = new HashMap<>();
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
val strideMapping = PropertyMapping.builder()
|
|
||||||
.tfAttrName("strides")
|
|
||||||
.onnxAttrName("strides")
|
|
||||||
.propertyNames(new String[]{"dT","dW","dH"})
|
|
||||||
.build();
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
val dilationMapping = PropertyMapping.builder()
|
|
||||||
.onnxAttrName("dilations")
|
|
||||||
.propertyNames(new String[]{"dD","dH","dW"})
|
|
||||||
.tfAttrName("rates")
|
|
||||||
.build();
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
val sameMode = PropertyMapping.builder()
|
|
||||||
.onnxAttrName("auto_pad")
|
|
||||||
.propertyNames(new String[]{"isSameMode"})
|
|
||||||
.tfAttrName("padding")
|
|
||||||
.build();
|
|
||||||
|
|
||||||
val paddingWidthHeight = PropertyMapping.builder()
|
|
||||||
.onnxAttrName("padding")
|
|
||||||
.propertyNames(new String[]{"pT","pW","pH"})
|
|
||||||
.build();
|
|
||||||
|
|
||||||
val dataFormat = PropertyMapping.builder()
|
|
||||||
.onnxAttrName("data_format")
|
|
||||||
.tfAttrName("data_format")
|
|
||||||
.propertyNames(new String[]{"dataFormat"})
|
|
||||||
.build();
|
|
||||||
|
|
||||||
|
|
||||||
val outputPadding = PropertyMapping.builder()
|
|
||||||
.propertyNames(new String[]{"aT","aH","aW"})
|
|
||||||
.build();
|
|
||||||
|
|
||||||
|
|
||||||
val biasUsed = PropertyMapping.builder()
|
|
||||||
.propertyNames(new String[]{"biasUsed"})
|
|
||||||
.build();
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
for(val propertyMapping : new PropertyMapping[] {
|
|
||||||
strideMapping,
|
|
||||||
dilationMapping,
|
|
||||||
sameMode,
|
|
||||||
paddingWidthHeight,
|
|
||||||
dataFormat,
|
|
||||||
outputPadding,biasUsed}) {
|
|
||||||
for(val keys : propertyMapping.getPropertyNames())
|
|
||||||
map.put(keys,propertyMapping);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
ret.put(onnxName(),map);
|
|
||||||
ret.put(tensorflowName(),map);
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void addArgs() {
|
|
||||||
addIArgument(new long[]{
|
|
||||||
config.getDT(),
|
|
||||||
config.getDW(),
|
|
||||||
config.getDH(),
|
|
||||||
config.getPT(),
|
|
||||||
config.getPW(),
|
|
||||||
config.getPH(),
|
|
||||||
config.getDilationT(),
|
|
||||||
config.getDilationW(),
|
|
||||||
config.getDilationH(),
|
|
||||||
config.getAT(),
|
|
||||||
config.getAW(),
|
|
||||||
config.getAH(),
|
|
||||||
ArrayUtil.fromBoolean(config.isBiasUsed())});
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String opName() {
|
|
||||||
return "fullconv3d";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
|
||||||
List<SDVariable> inputs = new ArrayList<>();
|
|
||||||
inputs.addAll(Arrays.asList(args()));
|
|
||||||
inputs.addAll(f1);
|
|
||||||
List<SDVariable> ret = new ArrayList<>();
|
|
||||||
FullConv3DDerivative fullConv3DDerivative = FullConv3DDerivative.derivativeBuilder()
|
|
||||||
.conv3DConfig(config)
|
|
||||||
.sameDiff(sameDiff)
|
|
||||||
.inputFunctions(inputs.toArray(new SDVariable[inputs.size()]))
|
|
||||||
.build();
|
|
||||||
ret.addAll(Arrays.asList(fullConv3DDerivative.outputVariables()));
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
|
||||||
int n = args().length;
|
|
||||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
|
||||||
return Collections.singletonList(inputDataTypes.get(0));
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,77 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
|
||||||
|
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
|
||||||
import org.nd4j.base.Preconditions;
|
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.FullConv3DConfig;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* FullConv3DDerivative operation
|
|
||||||
*/
|
|
||||||
@Slf4j
|
|
||||||
public class FullConv3DDerivative extends FullConv3D {
|
|
||||||
|
|
||||||
public FullConv3DDerivative() {}
|
|
||||||
|
|
||||||
@Builder(builderMethodName = "derivativeBuilder")
|
|
||||||
public FullConv3DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, FullConv3DConfig conv3DConfig) {
|
|
||||||
super(sameDiff, inputFunctions, inputs, outputs, conv3DConfig);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String opName() {
|
|
||||||
return "fullconv3d_bp";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
throw new NoOpNameFoundException("No tensorflwo op name found for conv3d derivative");
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String[] tensorflowNames() {
|
|
||||||
throw new NoOpNameFoundException("No tensorflwo op name found for conv3d derivative");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
|
||||||
throw new UnsupportedOperationException("Unable to take derivative of derivative.");
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
|
||||||
int n = args().length; //Original inputs + gradient at
|
|
||||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
|
||||||
List<DataType> out = new ArrayList<>(n-1);
|
|
||||||
for( int i=0; i<n-1; i++ ){
|
|
||||||
out.add(inputDataTypes.get(i));
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -204,8 +204,6 @@ public class MaxPooling2D extends DynamicCustomOp {
|
||||||
.kW(kW)
|
.kW(kW)
|
||||||
.pH(pH)
|
.pH(pH)
|
||||||
.pW(pW)
|
.pW(pW)
|
||||||
.virtualHeight(1)
|
|
||||||
.virtualWidth(1)
|
|
||||||
.isNHWC(data_format.equalsIgnoreCase("nhwc"))
|
.isNHWC(data_format.equalsIgnoreCase("nhwc"))
|
||||||
.extra(1.0) // averaging only for non-padded values
|
.extra(1.0) // averaging only for non-padded values
|
||||||
.build();
|
.build();
|
||||||
|
@ -230,8 +228,6 @@ public class MaxPooling2D extends DynamicCustomOp {
|
||||||
.kW(kernelShape.size() < 2 ? kernelShape.get(0).intValue() : kernelShape.get(1).intValue())
|
.kW(kernelShape.size() < 2 ? kernelShape.get(0).intValue() : kernelShape.get(1).intValue())
|
||||||
.pH(padding.get(0).intValue())
|
.pH(padding.get(0).intValue())
|
||||||
.pW(padding.size() < 2 ? padding.get(0).intValue() : padding.get(1).intValue())
|
.pW(padding.size() < 2 ? padding.get(0).intValue() : padding.get(1).intValue())
|
||||||
.virtualHeight(1)
|
|
||||||
.virtualWidth(1)
|
|
||||||
.build();
|
.build();
|
||||||
this.config = pooling2DConfig;
|
this.config = pooling2DConfig;
|
||||||
addArgs();
|
addArgs();
|
||||||
|
|
|
@ -174,8 +174,6 @@ public class Pooling2D extends DynamicCustomOp {
|
||||||
.kW(kW.intValue())
|
.kW(kW.intValue())
|
||||||
.pH(padding.get(0).intValue())
|
.pH(padding.get(0).intValue())
|
||||||
.pW(padding.get(1).intValue())
|
.pW(padding.get(1).intValue())
|
||||||
.virtualWidth(1)
|
|
||||||
.virtualHeight(1)
|
|
||||||
.build();
|
.build();
|
||||||
this.config = pooling2DConfig;
|
this.config = pooling2DConfig;
|
||||||
addArgs();
|
addArgs();
|
||||||
|
@ -200,8 +198,6 @@ public class Pooling2D extends DynamicCustomOp {
|
||||||
.kW(kernelShape.get(1).intValue())
|
.kW(kernelShape.get(1).intValue())
|
||||||
.pH(padding.get(0).intValue())
|
.pH(padding.get(0).intValue())
|
||||||
.pW(padding.get(1).intValue())
|
.pW(padding.get(1).intValue())
|
||||||
.virtualHeight(1)
|
|
||||||
.virtualWidth(1)
|
|
||||||
.build();
|
.build();
|
||||||
this.config = pooling2DConfig;
|
this.config = pooling2DConfig;
|
||||||
addArgs();
|
addArgs();
|
||||||
|
|
|
@ -16,6 +16,8 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
||||||
|
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.Map;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
|
|
||||||
|
@ -23,6 +25,8 @@ import java.lang.reflect.Field;
|
||||||
|
|
||||||
public abstract class BaseConvolutionConfig {
|
public abstract class BaseConvolutionConfig {
|
||||||
|
|
||||||
|
public abstract Map<String, Object> toProperties();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the value for a given property
|
* Get the value for a given property
|
||||||
* for this function
|
* for this function
|
||||||
|
@ -154,4 +158,5 @@ public abstract class BaseConvolutionConfig {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
protected abstract void validate();
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,15 +16,16 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
||||||
|
|
||||||
import lombok.*;
|
|
||||||
import org.nd4j.base.Preconditions;
|
|
||||||
|
|
||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
|
|
||||||
@Builder
|
|
||||||
@Data
|
@Data
|
||||||
@AllArgsConstructor
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
public class Conv1DConfig extends BaseConvolutionConfig {
|
public class Conv1DConfig extends BaseConvolutionConfig {
|
||||||
public static final String NCW = "NCW";
|
public static final String NCW = "NCW";
|
||||||
|
@ -40,6 +41,16 @@ public class Conv1DConfig extends BaseConvolutionConfig {
|
||||||
private String dataFormat = NCW;
|
private String dataFormat = NCW;
|
||||||
private boolean isSameMode;
|
private boolean isSameMode;
|
||||||
|
|
||||||
|
/**
 * Create a 1D convolution configuration and validate it immediately.
 *
 * @param k          presumably the kernel size - checked by {@code validate()}
 * @param s          presumably the stride - checked by {@code validate()}
 * @param p          presumably the padding - checked by {@code validate()}
 * @param dataFormat data format string
 * @param isSameMode same-mode flag
 */
public Conv1DConfig(long k, long s, long p, String dataFormat, boolean isSameMode) {
    this.dataFormat = dataFormat;
    this.isSameMode = isSameMode;
    this.p = p;
    this.s = s;
    this.k = k;

    //Reject invalid values at construction time
    validate();
}
|
||||||
|
|
||||||
public boolean isNWC(){
|
public boolean isNWC(){
|
||||||
Preconditions.checkState(dataFormat.equalsIgnoreCase(NCW) || dataFormat.equalsIgnoreCase(NWC),
|
Preconditions.checkState(dataFormat.equalsIgnoreCase(NCW) || dataFormat.equalsIgnoreCase(NWC),
|
||||||
"Data format must be one of %s or %s, got %s", NCW, NWC, dataFormat);
|
"Data format must be one of %s or %s, got %s", NCW, NWC, dataFormat);
|
||||||
|
@ -54,6 +65,7 @@ public class Conv1DConfig extends BaseConvolutionConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
public Map<String, Object> toProperties() {
|
public Map<String, Object> toProperties() {
|
||||||
Map<String, Object> ret = new LinkedHashMap<>();
|
Map<String, Object> ret = new LinkedHashMap<>();
|
||||||
ret.put("k", k);
|
ret.put("k", k);
|
||||||
|
@ -64,5 +76,11 @@ public class Conv1DConfig extends BaseConvolutionConfig {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void validate() {
|
||||||
|
ConvConfigUtil.validate1D(k, s, p);
|
||||||
|
Preconditions.checkArgument(dataFormat != null, "Data format can't be null");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,18 +16,16 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
|
|
||||||
import java.util.LinkedHashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Builder
|
|
||||||
@Data
|
@Data
|
||||||
@AllArgsConstructor
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
public class Conv2DConfig extends BaseConvolutionConfig {
|
public class Conv2DConfig extends BaseConvolutionConfig {
|
||||||
public static final String NCHW = "NCHW";
|
public static final String NCHW = "NCHW";
|
||||||
|
@ -53,6 +51,23 @@ public class Conv2DConfig extends BaseConvolutionConfig {
|
||||||
@Builder.Default
|
@Builder.Default
|
||||||
private String dataFormat = NCHW;
|
private String dataFormat = NCHW;
|
||||||
|
|
||||||
|
/**
 * Create a 2D convolution configuration and validate it immediately.
 * The k, s, p and d values are checked together by {@code validate()};
 * presumably they are kernel, stride, padding and dilation for height (H) and width (W).
 *
 * @param isSameMode same-mode flag
 * @param dataFormat data format string
 */
public Conv2DConfig(long kH, long kW, long sH, long sW, long pH, long pW, long dH, long dW, boolean isSameMode,
        String dataFormat) {

    this.dataFormat = dataFormat;
    this.isSameMode = isSameMode;

    this.dW = dW;
    this.dH = dH;
    this.pW = pW;
    this.pH = pH;
    this.sW = sW;
    this.sH = sH;
    this.kW = kW;
    this.kH = kH;

    //Reject invalid values at construction time
    validate();
}
|
||||||
|
|
||||||
public boolean isNHWC(){
|
public boolean isNHWC(){
|
||||||
Preconditions.checkState(dataFormat.equalsIgnoreCase(NCHW) || dataFormat.equalsIgnoreCase(NHWC),
|
Preconditions.checkState(dataFormat.equalsIgnoreCase(NCHW) || dataFormat.equalsIgnoreCase(NHWC),
|
||||||
"Data format must be one of %s or %s, got %s", NCHW, NHWC, dataFormat);
|
"Data format must be one of %s or %s, got %s", NCHW, NHWC, dataFormat);
|
||||||
|
@ -67,6 +82,7 @@ public class Conv2DConfig extends BaseConvolutionConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
public Map<String, Object> toProperties() {
|
public Map<String, Object> toProperties() {
|
||||||
Map<String, Object> ret = new LinkedHashMap<>();
|
Map<String, Object> ret = new LinkedHashMap<>();
|
||||||
ret.put("kH", kH);
|
ret.put("kH", kH);
|
||||||
|
@ -82,5 +98,11 @@ public class Conv2DConfig extends BaseConvolutionConfig {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void validate() {
|
||||||
|
ConvConfigUtil.validate2D(kH, kW, sH, sW, pH, pW, dH, dW);
|
||||||
|
Preconditions.checkArgument(dataFormat != null, "Data format can't be null");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,30 +17,28 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
||||||
|
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
import java.util.LinkedHashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@AllArgsConstructor
|
|
||||||
public class Conv3DConfig extends BaseConvolutionConfig {
|
public class Conv3DConfig extends BaseConvolutionConfig {
|
||||||
public static final String NDHWC = "NDHWC";
|
public static final String NDHWC = "NDHWC";
|
||||||
public static final String NCDHW = "NCDHW";
|
public static final String NCDHW = "NCDHW";
|
||||||
|
|
||||||
//kernel
|
//kernel
|
||||||
@Builder.Default
|
@Builder.Default
|
||||||
private long kD = 1;
|
private long kD = -1;
|
||||||
@Builder.Default
|
@Builder.Default
|
||||||
private long kW = 1;
|
private long kW = -1;
|
||||||
@Builder.Default
|
@Builder.Default
|
||||||
private long kH = 1;
|
private long kH = -1;
|
||||||
|
|
||||||
//strides
|
//strides
|
||||||
@Builder.Default
|
@Builder.Default
|
||||||
|
@ -66,14 +64,6 @@ public class Conv3DConfig extends BaseConvolutionConfig {
|
||||||
@Builder.Default
|
@Builder.Default
|
||||||
private long dH = 1;
|
private long dH = 1;
|
||||||
|
|
||||||
//output padding
|
|
||||||
@Builder.Default
|
|
||||||
private long aD = 0;
|
|
||||||
@Builder.Default
|
|
||||||
private long aW = 0;
|
|
||||||
@Builder.Default
|
|
||||||
private long aH = 0;
|
|
||||||
|
|
||||||
@Builder.Default
|
@Builder.Default
|
||||||
private boolean biasUsed = false;
|
private boolean biasUsed = false;
|
||||||
private boolean isSameMode;
|
private boolean isSameMode;
|
||||||
|
@ -81,6 +71,27 @@ public class Conv3DConfig extends BaseConvolutionConfig {
|
||||||
@Builder.Default
|
@Builder.Default
|
||||||
private String dataFormat = NDHWC;
|
private String dataFormat = NDHWC;
|
||||||
|
|
||||||
|
public Conv3DConfig(long kD, long kW, long kH, long sD, long sW, long sH, long pD, long pW, long pH, long dD,
|
||||||
|
long dW, long dH, boolean biasUsed, boolean isSameMode, String dataFormat) {
|
||||||
|
this.kD = kD;
|
||||||
|
this.kW = kW;
|
||||||
|
this.kH = kH;
|
||||||
|
this.sD = sD;
|
||||||
|
this.sW = sW;
|
||||||
|
this.sH = sH;
|
||||||
|
this.pD = pD;
|
||||||
|
this.pW = pW;
|
||||||
|
this.pH = pH;
|
||||||
|
this.dD = dD;
|
||||||
|
this.dW = dW;
|
||||||
|
this.dH = dH;
|
||||||
|
this.biasUsed = biasUsed;
|
||||||
|
this.isSameMode = isSameMode;
|
||||||
|
this.dataFormat = dataFormat;
|
||||||
|
|
||||||
|
validate();
|
||||||
|
}
|
||||||
|
|
||||||
public boolean isNCDHW(){
|
public boolean isNCDHW(){
|
||||||
Preconditions.checkState(dataFormat.equalsIgnoreCase(NCDHW) || dataFormat.equalsIgnoreCase(NDHWC),
|
Preconditions.checkState(dataFormat.equalsIgnoreCase(NCDHW) || dataFormat.equalsIgnoreCase(NDHWC),
|
||||||
"Data format must be one of %s or %s, got %s", NCDHW, NDHWC, dataFormat);
|
"Data format must be one of %s or %s, got %s", NCDHW, NDHWC, dataFormat);
|
||||||
|
@ -95,6 +106,7 @@ public class Conv3DConfig extends BaseConvolutionConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
public Map<String, Object> toProperties() {
|
public Map<String, Object> toProperties() {
|
||||||
Map<String, Object> ret = new LinkedHashMap<>();
|
Map<String, Object> ret = new LinkedHashMap<>();
|
||||||
ret.put("kD", kD);
|
ret.put("kD", kD);
|
||||||
|
@ -109,9 +121,6 @@ public class Conv3DConfig extends BaseConvolutionConfig {
|
||||||
ret.put("dD", dD);
|
ret.put("dD", dD);
|
||||||
ret.put("dW", dW);
|
ret.put("dW", dW);
|
||||||
ret.put("dH", dH);
|
ret.put("dH", dH);
|
||||||
ret.put("aD", aD);
|
|
||||||
ret.put("aW", aW);
|
|
||||||
ret.put("aH", aH);
|
|
||||||
ret.put("biasUsed", biasUsed);
|
ret.put("biasUsed", biasUsed);
|
||||||
ret.put("dataFormat", dataFormat);
|
ret.put("dataFormat", dataFormat);
|
||||||
ret.put("isSameMode", isSameMode);
|
ret.put("isSameMode", isSameMode);
|
||||||
|
@ -119,5 +128,11 @@ public class Conv3DConfig extends BaseConvolutionConfig {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void validate() {
|
||||||
|
ConvConfigUtil.validate3D(kH, kW, kD, sH, sW, sD, pH, pW, pD, dH, dW, dD);
|
||||||
|
Preconditions.checkArgument(dataFormat != null, "Data format can't be null");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,22 +16,22 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
|
|
||||||
import java.util.LinkedHashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Builder
|
|
||||||
@Data
|
@Data
|
||||||
@AllArgsConstructor
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
public class DeConv2DConfig extends BaseConvolutionConfig {
|
public class DeConv2DConfig extends BaseConvolutionConfig {
|
||||||
public static final String NCHW = "NCHW";
|
public static final String NCHW = "NCHW";
|
||||||
public static final String NHWC = "NHWC";
|
public static final String NHWC = "NHWC";
|
||||||
|
|
||||||
|
|
||||||
@Builder.Default private long kH = -1L;
|
@Builder.Default private long kH = -1L;
|
||||||
@Builder.Default private long kW = -1L;
|
@Builder.Default private long kW = -1L;
|
||||||
@Builder.Default private long sH = 1L;
|
@Builder.Default private long sH = 1L;
|
||||||
|
@ -43,8 +43,25 @@ public class DeConv2DConfig extends BaseConvolutionConfig {
|
||||||
@Builder.Default private boolean isSameMode = false;
|
@Builder.Default private boolean isSameMode = false;
|
||||||
@Builder.Default private String dataFormat = NCHW;
|
@Builder.Default private String dataFormat = NCHW;
|
||||||
|
|
||||||
|
public DeConv2DConfig(long kH, long kW, long sH, long sW, long pH, long pW, long dH, long dW, boolean isSameMode,
|
||||||
|
String dataFormat) {
|
||||||
|
this.kH = kH;
|
||||||
|
this.kW = kW;
|
||||||
|
this.sH = sH;
|
||||||
|
this.sW = sW;
|
||||||
|
this.pH = pH;
|
||||||
|
this.pW = pW;
|
||||||
|
this.dH = dH;
|
||||||
|
this.dW = dW;
|
||||||
|
this.isSameMode = isSameMode;
|
||||||
|
this.dataFormat = dataFormat;
|
||||||
|
|
||||||
|
validate();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
public Map<String, Object> toProperties() {
|
public Map<String, Object> toProperties() {
|
||||||
|
|
||||||
Map<String, Object> ret = new LinkedHashMap<>();
|
Map<String, Object> ret = new LinkedHashMap<>();
|
||||||
ret.put("kH", kH);
|
ret.put("kH", kH);
|
||||||
ret.put("kW", kW);
|
ret.put("kW", kW);
|
||||||
|
@ -58,4 +75,10 @@ public class DeConv2DConfig extends BaseConvolutionConfig {
|
||||||
ret.put("dataFormat", dataFormat);
|
ret.put("dataFormat", dataFormat);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void validate() {
|
||||||
|
ConvConfigUtil.validate2D(kH, kW, sH, sW, pH, pW, dH, dW);
|
||||||
|
Preconditions.checkArgument(dataFormat != null, "Data format can't be null");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,17 +16,16 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
|
|
||||||
import java.util.LinkedHashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Builder
|
|
||||||
@Data
|
@Data
|
||||||
@AllArgsConstructor
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
public class DeConv3DConfig extends BaseConvolutionConfig {
|
public class DeConv3DConfig extends BaseConvolutionConfig {
|
||||||
public static final String NCDHW = "NCDHW";
|
public static final String NCDHW = "NCDHW";
|
||||||
|
@ -47,7 +46,28 @@ public class DeConv3DConfig extends BaseConvolutionConfig {
|
||||||
@Builder.Default private boolean isSameMode = false;
|
@Builder.Default private boolean isSameMode = false;
|
||||||
@Builder.Default private String dataFormat = NCDHW;
|
@Builder.Default private String dataFormat = NCDHW;
|
||||||
|
|
||||||
|
public DeConv3DConfig(long kD, long kH, long kW, long sD, long sH, long sW, long pD, long pH, long pW, long dD,
|
||||||
|
long dH, long dW, boolean isSameMode, String dataFormat) {
|
||||||
|
this.kD = kD;
|
||||||
|
this.kH = kH;
|
||||||
|
this.kW = kW;
|
||||||
|
this.sD = sD;
|
||||||
|
this.sH = sH;
|
||||||
|
this.sW = sW;
|
||||||
|
this.pD = pD;
|
||||||
|
this.pH = pH;
|
||||||
|
this.pW = pW;
|
||||||
|
this.dD = dD;
|
||||||
|
this.dH = dH;
|
||||||
|
this.dW = dW;
|
||||||
|
this.isSameMode = isSameMode;
|
||||||
|
this.dataFormat = dataFormat;
|
||||||
|
|
||||||
|
validate();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
public Map<String, Object> toProperties() {
|
public Map<String, Object> toProperties() {
|
||||||
Map<String, Object> ret = new LinkedHashMap<>();
|
Map<String, Object> ret = new LinkedHashMap<>();
|
||||||
ret.put("kD", kD);
|
ret.put("kD", kD);
|
||||||
|
@ -66,4 +86,10 @@ public class DeConv3DConfig extends BaseConvolutionConfig {
|
||||||
ret.put("dataFormat", dataFormat);
|
ret.put("dataFormat", dataFormat);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void validate() {
|
||||||
|
ConvConfigUtil.validate3D(kH, kW, kD, sH, sW, sD, pH, pW, pD, dH, dW, dD);
|
||||||
|
Preconditions.checkArgument(dataFormat != null, "Data format can't be null");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,53 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
|
||||||
|
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
import java.util.LinkedHashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Builder
|
|
||||||
@Data
|
|
||||||
public class FullConv3DConfig extends BaseConvolutionConfig {
|
|
||||||
private long dT,dW,dH,pT,pW,pH,dilationT,dilationW,dilationH,aT,aW,aH;
|
|
||||||
private boolean biasUsed;
|
|
||||||
private String dataFormat;
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public Map<String,Object> toProperties() {
|
|
||||||
Map<String,Object> ret = new LinkedHashMap<>();
|
|
||||||
ret.put("dT",dT);
|
|
||||||
ret.put("dW",dW);
|
|
||||||
ret.put("dH",dH);
|
|
||||||
ret.put("pT",pT);
|
|
||||||
ret.put("pW",pW);
|
|
||||||
ret.put("pH",pH);
|
|
||||||
ret.put("dD",dilationT);
|
|
||||||
ret.put("dW",dilationW);
|
|
||||||
ret.put("dH",dilationH);
|
|
||||||
ret.put("aT",aT);
|
|
||||||
ret.put("aW",aW);
|
|
||||||
ret.put("aH",aH);
|
|
||||||
ret.put("biasUsed",biasUsed);
|
|
||||||
ret.put("dataFormat",dataFormat);
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -16,19 +16,31 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
||||||
|
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
public class LocalResponseNormalizationConfig extends BaseConvolutionConfig {
|
public class LocalResponseNormalizationConfig extends BaseConvolutionConfig {
|
||||||
|
|
||||||
private double alpha, beta, bias;
|
private double alpha, beta, bias;
|
||||||
private int depth;
|
private int depth;
|
||||||
|
|
||||||
|
public LocalResponseNormalizationConfig(double alpha, double beta, double bias, int depth) {
|
||||||
|
this.alpha = alpha;
|
||||||
|
this.beta = beta;
|
||||||
|
this.bias = bias;
|
||||||
|
this.depth = depth;
|
||||||
|
|
||||||
|
validate();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
public Map<String, Object> toProperties() {
|
public Map<String, Object> toProperties() {
|
||||||
Map<String, Object> ret = new LinkedHashMap<>();
|
Map<String, Object> ret = new LinkedHashMap<>();
|
||||||
ret.put("alpha", alpha);
|
ret.put("alpha", alpha);
|
||||||
|
@ -38,4 +50,9 @@ public class LocalResponseNormalizationConfig extends BaseConvolutionConfig {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void validate() {
|
||||||
|
ConvConfigUtil.validateLRN(alpha, beta, bias, depth);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,32 +16,32 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.Divisor;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.Pooling2DType;
|
||||||
|
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
|
|
||||||
import java.util.LinkedHashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Builder
|
|
||||||
@AllArgsConstructor
|
|
||||||
@Data
|
@Data
|
||||||
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
public class Pooling2DConfig extends BaseConvolutionConfig {
|
public class Pooling2DConfig extends BaseConvolutionConfig {
|
||||||
|
|
||||||
private long kH, kW;
|
@Builder.Default private long kH = -1, kW = -1;
|
||||||
private long sH, sW;
|
@Builder.Default private long sH = 1, sW = 1;
|
||||||
private long pH, pW;
|
@Builder.Default private long pH = 0, pW = 0;
|
||||||
private long virtualHeight, virtualWidth;
|
|
||||||
/**
|
/**
|
||||||
* Extra is an optional parameter mainly for use with pnorm right now.
|
* Extra is an optional parameter mainly for use with pnorm right now.
|
||||||
* All pooling implementations take 9 parameters save pnorm.
|
* All pooling implementations take 9 parameters save pnorm.
|
||||||
* Pnorm takes 10 and is cast to an int.
|
* Pnorm takes 10 and is cast to an int.
|
||||||
*/
|
*/
|
||||||
private double extra;
|
private double extra;
|
||||||
private Pooling2D.Pooling2DType type;
|
@Builder.Default
|
||||||
|
private Pooling2D.Pooling2DType type = Pooling2DType.MAX;
|
||||||
@Builder.Default
|
@Builder.Default
|
||||||
private Pooling2D.Divisor divisor = Pooling2D.Divisor.EXCLUDE_PADDING;
|
private Pooling2D.Divisor divisor = Pooling2D.Divisor.EXCLUDE_PADDING;
|
||||||
private boolean isSameMode;
|
private boolean isSameMode;
|
||||||
|
@ -52,7 +52,26 @@ public class Pooling2DConfig extends BaseConvolutionConfig {
|
||||||
@Builder.Default
|
@Builder.Default
|
||||||
private boolean isNHWC = false;
|
private boolean isNHWC = false;
|
||||||
|
|
||||||
|
public Pooling2DConfig(long kH, long kW, long sH, long sW, long pH, long pW, double extra, Pooling2DType type,
|
||||||
|
Divisor divisor, boolean isSameMode, long dH, long dW, boolean isNHWC) {
|
||||||
|
this.kH = kH;
|
||||||
|
this.kW = kW;
|
||||||
|
this.sH = sH;
|
||||||
|
this.sW = sW;
|
||||||
|
this.pH = pH;
|
||||||
|
this.pW = pW;
|
||||||
|
this.extra = extra;
|
||||||
|
this.type = type;
|
||||||
|
this.divisor = divisor;
|
||||||
|
this.isSameMode = isSameMode;
|
||||||
|
this.dH = dH;
|
||||||
|
this.dW = dW;
|
||||||
|
this.isNHWC = isNHWC;
|
||||||
|
|
||||||
|
validate();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
public Map<String, Object> toProperties() {
|
public Map<String, Object> toProperties() {
|
||||||
Map<String, Object> ret = new LinkedHashMap<>();
|
Map<String, Object> ret = new LinkedHashMap<>();
|
||||||
ret.put("kH", kH);
|
ret.put("kH", kH);
|
||||||
|
@ -61,8 +80,6 @@ public class Pooling2DConfig extends BaseConvolutionConfig {
|
||||||
ret.put("sW", sW);
|
ret.put("sW", sW);
|
||||||
ret.put("pH", pH);
|
ret.put("pH", pH);
|
||||||
ret.put("pW", pW);
|
ret.put("pW", pW);
|
||||||
ret.put("virtualHeight", virtualHeight);
|
|
||||||
ret.put("virtualWidth", virtualWidth);
|
|
||||||
ret.put("extra", extra);
|
ret.put("extra", extra);
|
||||||
ret.put("type", type.toString());
|
ret.put("type", type.toString());
|
||||||
ret.put("isSameMode", isSameMode);
|
ret.put("isSameMode", isSameMode);
|
||||||
|
@ -72,4 +89,11 @@ public class Pooling2DConfig extends BaseConvolutionConfig {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void validate() {
|
||||||
|
ConvConfigUtil.validate2D(kH, kW, sH, sW, pH, pW, dH, dW);
|
||||||
|
|
||||||
|
//TODO check other args?
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,23 +16,22 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D.Pooling3DType;
|
||||||
import java.util.LinkedHashMap;
|
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@AllArgsConstructor
|
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
public class Pooling3DConfig extends BaseConvolutionConfig {
|
public class Pooling3DConfig extends BaseConvolutionConfig {
|
||||||
private long kD, kW, kH; // kernel
|
@Builder.Default private long kD = -1, kW = -1, kH = -1; // kernel
|
||||||
private long sD, sW, sH; // strides
|
@Builder.Default private long sD = 1, sW = 1, sH = 1; // strides
|
||||||
private long pD, pW, pH; // padding
|
@Builder.Default private long pD = 0, pW = 0, pH = 0; // padding
|
||||||
// dilation
|
// dilation
|
||||||
@Builder.Default
|
@Builder.Default
|
||||||
private long dD = 1;
|
private long dD = 1;
|
||||||
|
@ -40,10 +39,33 @@ public class Pooling3DConfig extends BaseConvolutionConfig {
|
||||||
private long dW = 1;
|
private long dW = 1;
|
||||||
@Builder.Default
|
@Builder.Default
|
||||||
private long dH = 1;
|
private long dH = 1;
|
||||||
private Pooling3D.Pooling3DType type;
|
@Builder.Default
|
||||||
|
private Pooling3D.Pooling3DType type = Pooling3DType.MAX;
|
||||||
private boolean isSameMode;
|
private boolean isSameMode;
|
||||||
@Builder.Default private boolean isNCDHW = true;
|
@Builder.Default private boolean isNCDHW = true;
|
||||||
|
|
||||||
|
public Pooling3DConfig(long kD, long kW, long kH, long sD, long sW, long sH, long pD, long pW, long pH, long dD,
|
||||||
|
long dW, long dH, Pooling3DType type, boolean isSameMode, boolean isNCDHW) {
|
||||||
|
this.kD = kD;
|
||||||
|
this.kW = kW;
|
||||||
|
this.kH = kH;
|
||||||
|
this.sD = sD;
|
||||||
|
this.sW = sW;
|
||||||
|
this.sH = sH;
|
||||||
|
this.pD = pD;
|
||||||
|
this.pW = pW;
|
||||||
|
this.pH = pH;
|
||||||
|
this.dD = dD;
|
||||||
|
this.dW = dW;
|
||||||
|
this.dH = dH;
|
||||||
|
this.type = type;
|
||||||
|
this.isSameMode = isSameMode;
|
||||||
|
this.isNCDHW = isNCDHW;
|
||||||
|
|
||||||
|
validate();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
public Map<String, Object> toProperties() {
|
public Map<String, Object> toProperties() {
|
||||||
Map<String, Object> ret = new LinkedHashMap<>();
|
Map<String, Object> ret = new LinkedHashMap<>();
|
||||||
ret.put("kD", kD);
|
ret.put("kD", kD);
|
||||||
|
@ -63,4 +85,11 @@ public class Pooling3DConfig extends BaseConvolutionConfig {
|
||||||
return ret;
|
return ret;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void validate() {
|
||||||
|
ConvConfigUtil.validate3D(kH, kW, kD, sH, sW, sD, pH, pW, pD, dH, dW, dD);
|
||||||
|
|
||||||
|
//TODO check other args
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -249,8 +249,6 @@ public class Convolution {
|
||||||
.isSameMode(isSameMode)
|
.isSameMode(isSameMode)
|
||||||
.sH(sy)
|
.sH(sy)
|
||||||
.sW(sx)
|
.sW(sx)
|
||||||
.virtualHeight(virtualHeight)
|
|
||||||
.virtualWidth(virtualWidth)
|
|
||||||
.type(type)
|
.type(type)
|
||||||
.divisor(divisor)
|
.divisor(divisor)
|
||||||
.build())
|
.build())
|
||||||
|
|
|
@ -0,0 +1,93 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.util;
|
||||||
|
|
||||||
|
import lombok.AccessLevel;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Class with utility methods for validating convolution op configurations like {@link org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig}
|
||||||
|
*/
|
||||||
|
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||||
|
public class ConvConfigUtil {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate a 2D convolution's Kernel, Stride, Padding, and Dilation
|
||||||
|
*/
|
||||||
|
public static void validate2D(long kH, long kW, long sH, long sW, long pH, long pW, long dH, long dW){
|
||||||
|
Preconditions.checkArgument(kH != 0, "Kernel height can not be 0");
|
||||||
|
Preconditions.checkArgument(kW != 0, "Kernel width can not be 0");
|
||||||
|
|
||||||
|
Preconditions.checkArgument(sH > 0, "Stride height can not be negative or 0, got: %s", sH);
|
||||||
|
Preconditions.checkArgument(sW > 0, "Stride width can not be negative or 0, got: %s", sW);
|
||||||
|
|
||||||
|
Preconditions.checkArgument(pH >= 0, "Padding height can not be negative, got: %s", pH);
|
||||||
|
Preconditions.checkArgument(pW >= 0, "Padding width can not be negative, got: %s", pW);
|
||||||
|
|
||||||
|
Preconditions.checkArgument(dH > 0, "Dilation height can not be negative or 0, got: %s", dH);
|
||||||
|
Preconditions.checkArgument(dW > 0, "Dilation width can not be negative or 0, got: %s", dW);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate a 3D convolution's Kernel, Stride, Padding, and Dilation
|
||||||
|
*/
|
||||||
|
public static void validate3D(long kH, long kW, long kD, long sH, long sW, long sD, long pH, long pW, long pD, long dH, long dW, long dD){
|
||||||
|
Preconditions.checkArgument(kH != 0, "Kernel height can not be 0");
|
||||||
|
Preconditions.checkArgument(kW != 0, "Kernel width can not be 0");
|
||||||
|
Preconditions.checkArgument(kD != 0, "Kernel depth can not be 0");
|
||||||
|
|
||||||
|
Preconditions.checkArgument(sH > 0, "Stride height can not be negative or 0, got: %s", sH);
|
||||||
|
Preconditions.checkArgument(sW > 0, "Stride width can not be negative or 0, got: %s", sW);
|
||||||
|
Preconditions.checkArgument(sD > 0, "Stride depth can not be negative or 0, got: %s", sD);
|
||||||
|
|
||||||
|
Preconditions.checkArgument(pH >= 0, "Padding height can not be negative, got: %s", pH);
|
||||||
|
Preconditions.checkArgument(pW >= 0, "Padding width can not be negative, got: %s", pW);
|
||||||
|
Preconditions.checkArgument(pD >= 0, "Padding depth can not be negative, got: %s", pD);
|
||||||
|
|
||||||
|
Preconditions.checkArgument(dH > 0, "Dilation height can not be negative or 0, got: %s", dH);
|
||||||
|
Preconditions.checkArgument(dW > 0, "Dilation width can not be negative or 0, got: %s", dW);
|
||||||
|
Preconditions.checkArgument(dD > 0, "Dilation depth can not be negative or 0, got: %s", dD);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate a 3D convolution's Output Padding
|
||||||
|
*/
|
||||||
|
public static void validateExtra3D(long aH, long aW, long aD){
|
||||||
|
Preconditions.checkArgument(aH >= 0, "Output padding height can not be negative, got: %s", aH);
|
||||||
|
Preconditions.checkArgument(aW >= 0, "Output padding width can not be negative, got: %s", aW);
|
||||||
|
Preconditions.checkArgument(aD >= 0, "Output padding depth can not be negative, got: %s", aD);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate a 1D convolution's Kernel, Stride, and Padding
|
||||||
|
*/
|
||||||
|
public static void validate1D(long k, long s, long p){
|
||||||
|
Preconditions.checkArgument(k != 0, "Kernel can not be 0");
|
||||||
|
|
||||||
|
Preconditions.checkArgument(s > 0, "Stride can not be negative or 0, got: %s", s);
|
||||||
|
|
||||||
|
Preconditions.checkArgument(p >= 0, "Padding can not be negative, got: %s", p);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate a LocalResponseNormalizationConfig
|
||||||
|
*/
|
||||||
|
public static void validateLRN(double alpha, double beta, double bias, int depth) {
|
||||||
|
Preconditions.checkArgument(depth > 0, "Depth can not be 0 or negative, got: %s", depth);
|
||||||
|
}
|
||||||
|
}
|
|
@ -14662,54 +14662,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
}
|
}
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
// #if NOT_EXCLUDED(OP_fullconv3d)
|
|
||||||
@Namespace("nd4j::ops") public static class fullconv3d extends DeclarableCustomOp {
|
|
||||||
static { Loader.load(); }
|
|
||||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
|
||||||
public fullconv3d(Pointer p) { super(p); }
|
|
||||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
|
||||||
public fullconv3d(long size) { super((Pointer)null); allocateArray(size); }
|
|
||||||
private native void allocateArray(long size);
|
|
||||||
@Override public fullconv3d position(long position) {
|
|
||||||
return (fullconv3d)super.position(position);
|
|
||||||
}
|
|
||||||
|
|
||||||
public fullconv3d() { super((Pointer)null); allocate(); }
|
|
||||||
private native void allocate();
|
|
||||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
|
||||||
}
|
|
||||||
@Namespace("nd4j::ops") public static class fullconv3d_bp extends DeclarableCustomOp {
|
|
||||||
static { Loader.load(); }
|
|
||||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
|
||||||
public fullconv3d_bp(Pointer p) { super(p); }
|
|
||||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
|
||||||
public fullconv3d_bp(long size) { super((Pointer)null); allocateArray(size); }
|
|
||||||
private native void allocateArray(long size);
|
|
||||||
@Override public fullconv3d_bp position(long position) {
|
|
||||||
return (fullconv3d_bp)super.position(position);
|
|
||||||
}
|
|
||||||
|
|
||||||
public fullconv3d_bp() { super((Pointer)null); allocate(); }
|
|
||||||
private native void allocate();
|
|
||||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
|
||||||
}
|
|
||||||
@Namespace("nd4j::ops") public static class fullconv3d_grad extends DeclarableCustomOp {
|
|
||||||
static { Loader.load(); }
|
|
||||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
|
||||||
public fullconv3d_grad(Pointer p) { super(p); }
|
|
||||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
|
||||||
public fullconv3d_grad(long size) { super((Pointer)null); allocateArray(size); }
|
|
||||||
private native void allocateArray(long size);
|
|
||||||
@Override public fullconv3d_grad position(long position) {
|
|
||||||
return (fullconv3d_grad)super.position(position);
|
|
||||||
}
|
|
||||||
|
|
||||||
public fullconv3d_grad() { super((Pointer)null); allocate(); }
|
|
||||||
private native void allocate();
|
|
||||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
|
||||||
}
|
|
||||||
// #endif
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This op implements im2col algorithm, widely used in convolution neural networks
|
* This op implements im2col algorithm, widely used in convolution neural networks
|
||||||
* Input: 4D input expected
|
* Input: 4D input expected
|
||||||
|
|
|
@ -1124,7 +1124,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertNull(err, err);
|
assertNull(err, err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalStateException.class)
|
@Test(expected = IllegalArgumentException.class)
|
||||||
public void exceptionThrown_WhenConv1DConfigInvalid() {
|
public void exceptionThrown_WhenConv1DConfigInvalid() {
|
||||||
int nIn = 3;
|
int nIn = 3;
|
||||||
int nOut = 4;
|
int nOut = 4;
|
||||||
|
@ -1150,7 +1150,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalStateException.class)
|
@Test(expected = IllegalArgumentException.class)
|
||||||
public void exceptionThrown_WhenConv2DConfigInvalid() {
|
public void exceptionThrown_WhenConv2DConfigInvalid() {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -1171,7 +1171,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
.build());
|
.build());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalStateException.class)
|
@Test(expected = IllegalArgumentException.class)
|
||||||
public void exceptionThrown_WhenConf3DInvalid() {
|
public void exceptionThrown_WhenConf3DInvalid() {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,515 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.nd4j.autodiff.samediff;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertThat;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
import static org.junit.Assert.fail;
|
||||||
|
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
|
||||||
|
|
||||||
|
public class ConvConfigTests {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDeConv2D(){
|
||||||
|
DeConv2DConfig.builder().kH(2).kW(4).build();
|
||||||
|
|
||||||
|
try{
|
||||||
|
DeConv2DConfig.builder().kW(4).kH(0).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Kernel height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
DeConv2DConfig.builder().kH(4).kW(0).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Kernel width"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
DeConv2DConfig.builder().kH(4).kW(3).sH(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Stride height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
DeConv2DConfig.builder().kH(4).kW(3).sW(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Stride width"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
DeConv2DConfig.builder().kH(4).kW(3).pH(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Padding height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
DeConv2DConfig.builder().kH(4).kW(3).pW(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Padding width"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
DeConv2DConfig.builder().kH(4).kW(3).dH(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Dilation height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
DeConv2DConfig.builder().kH(4).kW(3).dW(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Dilation width"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testConv2D(){
|
||||||
|
Conv2DConfig.builder().kH(2).kW(4).build();
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv2DConfig.builder().kW(4).kH(0).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Kernel height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv2DConfig.builder().kH(4).kW(0).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Kernel width"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv2DConfig.builder().kH(4).kW(3).sH(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Stride height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv2DConfig.builder().kH(4).kW(3).sW(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Stride width"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv2DConfig.builder().kH(4).kW(3).pH(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Padding height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv2DConfig.builder().kH(4).kW(3).pW(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Padding width"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv2DConfig.builder().kH(4).kW(3).dH(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Dilation height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv2DConfig.builder().kH(4).kW(3).dW(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Dilation width"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testPooling2D(){
|
||||||
|
Pooling2DConfig.builder().kH(2).kW(4).build();
|
||||||
|
|
||||||
|
try{
|
||||||
|
Pooling2DConfig.builder().kW(4).kH(0).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Kernel height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Pooling2DConfig.builder().kH(4).kW(0).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Kernel width"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Pooling2DConfig.builder().kH(4).kW(3).sH(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Stride height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Pooling2DConfig.builder().kH(4).kW(3).sW(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Stride width"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Pooling2DConfig.builder().kH(4).kW(3).pH(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Padding height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Pooling2DConfig.builder().kH(4).kW(3).pW(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Padding width"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Pooling2DConfig.builder().kH(4).kW(3).dH(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Dilation height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Pooling2DConfig.builder().kH(4).kW(3).dW(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Dilation width"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDeConv3D(){
|
||||||
|
DeConv3DConfig.builder().kH(2).kW(4).kD(3).build();
|
||||||
|
|
||||||
|
try{
|
||||||
|
DeConv3DConfig.builder().kW(4).kD(3).kH(0).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Kernel height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
DeConv3DConfig.builder().kH(4).kD(3).kW(0).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Kernel width"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
DeConv3DConfig.builder().kH(4).kW(3).kD(0).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Kernel depth"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
DeConv3DConfig.builder().kH(4).kW(3).kD(3).sH(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Stride height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
DeConv3DConfig.builder().kH(4).kW(3).kD(3).sW(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Stride width"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
DeConv3DConfig.builder().kH(4).kW(3).kD(3).sD(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Stride depth"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
DeConv3DConfig.builder().kH(4).kW(3).kD(3).pH(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Padding height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
DeConv3DConfig.builder().kH(4).kW(3).kD(3).pW(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Padding width"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
DeConv3DConfig.builder().kH(4).kW(3).kD(3).pD(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Padding depth"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
DeConv3DConfig.builder().kH(4).kW(3).kD(3).dH(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Dilation height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
DeConv3DConfig.builder().kH(4).kW(3).kD(3).dW(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Dilation width"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
DeConv3DConfig.builder().kH(4).kW(3).kD(3).dD(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Dilation depth"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testConv3D(){
|
||||||
|
Conv3DConfig.builder().kH(2).kW(4).kD(3).build();
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv3DConfig.builder().kW(4).kD(3).kH(0).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Kernel height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv3DConfig.builder().kH(4).kD(3).kW(0).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Kernel width"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv3DConfig.builder().kH(4).kW(3).kD(0).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Kernel depth"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv3DConfig.builder().kH(4).kW(3).kD(3).sH(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Stride height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv3DConfig.builder().kH(4).kW(3).kD(3).sW(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Stride width"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv3DConfig.builder().kH(4).kW(3).kD(3).sD(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Stride depth"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv3DConfig.builder().kH(4).kW(3).kD(3).pH(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Padding height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv3DConfig.builder().kH(4).kW(3).kD(3).pW(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Padding width"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv3DConfig.builder().kH(4).kW(3).kD(3).pD(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Padding depth"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv3DConfig.builder().kH(4).kW(3).kD(3).dH(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Dilation height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv3DConfig.builder().kH(4).kW(3).kD(3).dW(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Dilation width"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv3DConfig.builder().kH(4).kW(3).kD(3).dD(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Dilation depth"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testPooling3D(){
|
||||||
|
Pooling3DConfig.builder().kH(2).kW(4).kD(3).build();
|
||||||
|
|
||||||
|
try{
|
||||||
|
Pooling3DConfig.builder().kW(4).kD(3).kH(0).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Kernel height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Pooling3DConfig.builder().kH(4).kD(3).kW(0).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Kernel width"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Pooling3DConfig.builder().kH(4).kW(3).kD(0).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Kernel depth"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Pooling3DConfig.builder().kH(4).kW(3).kD(3).sH(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Stride height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Pooling3DConfig.builder().kH(4).kW(3).kD(3).sW(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Stride width"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Pooling3DConfig.builder().kH(4).kW(3).kD(3).sD(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Stride depth"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Pooling3DConfig.builder().kH(4).kW(3).kD(3).pH(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Padding height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Pooling3DConfig.builder().kH(4).kW(3).kD(3).pW(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Padding width"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Pooling3DConfig.builder().kH(4).kW(3).kD(3).pD(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Padding depth"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Pooling3DConfig.builder().kH(4).kW(3).kD(3).dH(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Dilation height"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Pooling3DConfig.builder().kH(4).kW(3).kD(3).dW(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Dilation width"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Pooling3DConfig.builder().kH(4).kW(3).kD(3).dD(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Dilation depth"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testConv1D(){
|
||||||
|
Conv1DConfig.builder().k(2).build();
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv1DConfig.builder().k(0).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Kernel"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv1DConfig.builder().k(4).s(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Stride"));
|
||||||
|
}
|
||||||
|
|
||||||
|
try{
|
||||||
|
Conv1DConfig.builder().k(3).p(-2).build();
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e){
|
||||||
|
assertTrue(e.getMessage().contains("Padding"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue