From d4e7997134eeec8b5c4d71badd9637eb960e6e21 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 26 Jul 2019 20:05:16 -0700 Subject: [PATCH] SameDiff Convolution Config validation, better output methods (#82) * Conv Config validation & tests Signed-off-by: Ryan Nett * stackOutputs utility method Signed-off-by: Ryan Nett * use constructor for validation, support negative kernel sizes (inferred from weights) Signed-off-by: Ryan Nett * better output methods Signed-off-by: Ryan Nett * move output to be with fit and evaluate Signed-off-by: Ryan Nett * fixes Signed-off-by: Ryan Nett * more fixes Signed-off-by: Ryan Nett --- .../declarable/generic/convo/fullconv3d.cpp | 444 --------------- .../include/ops/declarable/headers/convo.h | 6 - .../org/nd4j/autodiff/samediff/SameDiff.java | 206 +++++-- .../org/nd4j/autodiff/util/TrainingUtils.java | 70 +++ .../autodiff/validation/OpValidation.java | 1 - .../DifferentialFunctionClassHolder.java | 1 - .../converters/ImportClassMapping.java | 2 - .../impl/layers/convolution/AvgPooling2D.java | 4 - .../impl/layers/convolution/FullConv3D.java | 228 -------- .../convolution/FullConv3DDerivative.java | 77 --- .../impl/layers/convolution/MaxPooling2D.java | 4 - .../impl/layers/convolution/Pooling2D.java | 4 - .../config/BaseConvolutionConfig.java | 5 + .../convolution/config/Conv1DConfig.java | 28 +- .../convolution/config/Conv2DConfig.java | 34 +- .../convolution/config/Conv3DConfig.java | 53 +- .../convolution/config/DeConv2DConfig.java | 35 +- .../convolution/config/DeConv3DConfig.java | 38 +- .../convolution/config/FullConv3DConfig.java | 53 -- .../LocalResponseNormalizationConfig.java | 23 +- .../convolution/config/Pooling2DConfig.java | 50 +- .../convolution/config/Pooling3DConfig.java | 47 +- .../nd4j/linalg/convolution/Convolution.java | 2 - .../org/nd4j/linalg/util/ConvConfigUtil.java | 93 ++++ .../java/org/nd4j/nativeblas/Nd4jCpu.java | 48 -- .../opvalidation/LayerOpValidation.java | 6 +- 
.../autodiff/samediff/ConvConfigTests.java | 515 ++++++++++++++++++ 27 files changed, 1085 insertions(+), 992 deletions(-) delete mode 100644 libnd4j/include/ops/declarable/generic/convo/fullconv3d.cpp create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/TrainingUtils.java delete mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/FullConv3D.java delete mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/FullConv3DDerivative.java delete mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/FullConv3DConfig.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/ConvConfigUtil.java create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java diff --git a/libnd4j/include/ops/declarable/generic/convo/fullconv3d.cpp b/libnd4j/include/ops/declarable/generic/convo/fullconv3d.cpp deleted file mode 100644 index 61da819f6..000000000 --- a/libnd4j/include/ops/declarable/generic/convo/fullconv3d.cpp +++ /dev/null @@ -1,444 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 08.10.2017. -// - -#include -#if NOT_EXCLUDED(OP_fullconv3d) - -#include -#include - -namespace nd4j { - namespace ops { - ////////////////////////////////////////////////////////////////////////// - CUSTOM_OP_IMPL(fullconv3d, 5, 1, false, 0, 13) { - // auto input = INPUT_VARIABLE(0); - // auto weights = INPUT_VARIABLE(1); - // auto bias = INPUT_VARIABLE(2); - // auto columns = INPUT_VARIABLE(3); - // auto ones = INPUT_VARIABLE(4); - - // REQUIRE_TRUE(weights->rankOf() == 5, 0, "Weights should be 5D, got %i instead", weights->rankOf()); - // REQUIRE_TRUE(input->rankOf() == 5, 0, "Input should be 5D, got %i instead", input->rankOf()); - - // // strides - // int dT = INT_ARG(0); - // int dW = INT_ARG(1); - // int dH = INT_ARG(2); - - // // padding - // int pT = INT_ARG(3); - // int pW = INT_ARG(4); - // int pH = INT_ARG(5); - - // // dilation - // int dilationT = INT_ARG(6); - // int dilationW = INT_ARG(7); - // int dilationH = INT_ARG(8); - - // // output padding - // int aT = INT_ARG(9); - // int aW = INT_ARG(10); - // int aH = INT_ARG(11); - - // // bias - // bool biasUsed = INT_ARG(12) != 0; - - - // REQUIRE_TRUE(dT > 0 && dW > 0 && dH > 0, 11, - // "stride should be greater than zero, but got dT: %d dH: %d dW: %d", dT, dH, dW); - // REQUIRE_TRUE(dilationT > 0 && dilationW > 0 && dilationH > 0, 15, - // "dilation should be greater than zero, but got dilationT: %d, dilationH: %d, dilationW: %d", - // dilationT, dilationH, dilationW); - // REQUIRE_TRUE((aT < dT || aT < dilationT) - // && (aW < dW || aW < dilationW) - // && (aH < dH || aH < dilationH), 15, - // "output padding must be smaller than either stride or dilation," - // " but got aT: %d aH: %d aW: %d dT: %d dH: %d dW: %d " - // "dilationT: %d dilationH: %d dilationW: %d", - // aT, aH, aW, dT, dH, dW, dilationT, dilationH, dilationW); - - // auto 
output = this->getZ(block); - - // const int nInputPlane = weights->shapeOf()[0]; - // const int nOutputPlane = weights->shapeOf()[1]; - // const int kT = weights->shapeOf()[2]; - // const int kH = weights->shapeOf()[3]; - // const int kW = weights->shapeOf()[4]; - - // const Nd4jLong inputWidth = input->shapeOf()[4]; - // const Nd4jLong inputHeight = input->shapeOf()[3]; - // const Nd4jLong inputDepth = input->shapeOf()[2]; - // const Nd4jLong outputDepth = (inputDepth - 1) * dT - 2*pT + (dilationT * (kT - 1) + 1) + aT; - // const Nd4jLong outputHeight = (inputHeight - 1) * dH - 2*pH + (dilationH * (kH - 1) + 1) + aH; - // const Nd4jLong outputWidth = (inputWidth - 1) * dW - 2*pW + (dilationW * (kW - 1) + 1) + aW; - - // const Nd4jLong batchSize = input->shapeOf()[0]; - - // REQUIRE_TRUE(output->isSameShape({ (int) batchSize, (int)nOutputPlane, (int)outputDepth, (int)outputHeight, (int)outputWidth}), 0, "Output should have shape of [%i, %i, %i, %i, %i], but got [%i, %i, %i, %i, %i] instead", (int) batchSize, (int)nOutputPlane, (int)outputDepth, (int)outputHeight, (int)outputWidth, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3), output->sizeAt(4)); - - // std::unique_ptr inputs(input->allExamples()); - // std::unique_ptr outputs(output->allExamples()); - // for (int e = 0; e < batchSize; e++) { - // auto tadIn = inputs->at(e); - // auto tadOut = outputs->at(e); - - // const int m = weights->shapeOf()[1] * weights->shapeOf()[2] * weights->shapeOf()[3] * weights->shapeOf()[4]; - // const int n = columns->shapeOf()[1]; - // const int k = weights->shapeOf()[0]; - - // // FIXME: mmul helper should be used here - // /* - // nd4j::blas::GEMM::op('c', 'n', 't', m, n, k, - // 1.0, - // tadIn->getBuffer(), n, - // weights->getBuffer(), m, - // 0.0, - // columns->getBuffer(), n); - // */ - - // // ConvolutionUtils::_col2vol(columns->getBuffer(), - // // nOutputPlane, outputDepth, outputHeight, outputWidth, - // // inputDepth, inputHeight, 
inputWidth, - // // kT, kH, kW, - // // pT, pH, pW, - // // dT, dH, dW, - // // dilationT, dilationH, dilationW, - // // tadOut->getBuffer()); - // ConvolutionUtils::col2vol(*columns, *tadOut, dT, dH, dW, pT, pH, pW, dilationT, dilationH, dilationW); - - - // const int m_ = nOutputPlane; - // const int n_ = outputDepth * outputHeight * outputWidth; - // const int k_ = 1; - - // if (biasUsed) { - // // FIXME: mmul helper should be used here - // /* - // nd4j::blas::GEMM::op('c', 't', 'n', n_, m_, k_, - // 1.0, - // ones->getBuffer(), k_, - // bias->getBuffer(), k_, - // 1.0, - // tadOut->getBuffer(), n_); - // */ - // } - // } - - // STORE_RESULT(*output); - - return Status::OK(); - } - - DECLARE_TYPES(fullconv3d) { - getOpDescriptor() - ->setAllowedInputTypes(nd4j::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - DECLARE_SHAPE_FN(fullconv3d) { - // auto input = inputShape->at(0); - // auto weights = inputShape->at(1); - - // // strides - // int dT = INT_ARG(0); - // int dW = INT_ARG(1); - // int dH = INT_ARG(2); - - // // padding - // int pT = INT_ARG(3); - // int pW = INT_ARG(4); - // int pH = INT_ARG(5); - - // // dilation - // int dilationT = INT_ARG(6); - // int dilationW = INT_ARG(7); - // int dilationH = INT_ARG(8); - - // // output padding - // int aT = INT_ARG(9); - // int aW = INT_ARG(10); - // int aH = INT_ARG(11); - - // // bias - // bool biasUsed = INT_ARG(12) != 0; - - // Nd4jLong *shapeOf; - // Nd4jLong *newShape; - // ALLOCATE(shapeOf, block.getWorkspace(), 5, Nd4jLong); - // ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(5), Nd4jLong); - - // const int nInputPlane = weights[1]; - // const int nOutputPlane = weights[2]; - // const int kT = weights[3]; - // const int kH = weights[4]; - // const int kW = weights[5]; - - // const int batchSize = input[1]; - // const Nd4jLong inputWidth = input[5]; - // const Nd4jLong inputHeight = input[4]; - // const Nd4jLong inputDepth = input[3]; - // const Nd4jLong outputDepth = 
(inputDepth - 1) * dT - 2*pT + (dilationT * (kT - 1) + 1) + aT; - // const Nd4jLong outputHeight = (inputHeight - 1) * dH - 2*pH + (dilationH * (kH - 1) + 1) + aH; - // const Nd4jLong outputWidth = (inputWidth - 1) * dW - 2*pW + (dilationW * (kW - 1) + 1) + aW; - - // nd4j::ArrayUtils::toLongPtr({(Nd4jLong) batchSize, (Nd4jLong)nOutputPlane, (Nd4jLong)outputDepth, (Nd4jLong)outputHeight, (Nd4jLong)outputWidth}, shapeOf); - - // shape::shapeBuffer(5, shapeOf, newShape); - - // RELEASE(shapeOf, block.getWorkspace()); - - // return SHAPELIST(newShape); - return SHAPELIST(); - } - - DECLARE_TYPES(fullconv3d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(nd4j::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } -////////////////////////////////////////////////////////////////////////// - CUSTOM_OP_IMPL(fullconv3d_bp, 5, 1, false, 0, 13) { - // auto input = INPUT_VARIABLE(0); - // auto gradNext = INPUT_VARIABLE(1); - // auto weights = INPUT_VARIABLE(2); - // auto finput = INPUT_VARIABLE(3); - - // // not used - // auto fgradInput = INPUT_VARIABLE(4); - - - // REQUIRE_TRUE(weights->rankOf() == 5, 0, "Weights should be 5D, got %i instead", weights->rankOf()); - // REQUIRE_TRUE(input->rankOf() == 5, 0, "Input should be 5D, got %i instead", input->rankOf()); - - // auto output = OUTPUT_VARIABLE(0); - - // int dT = INT_ARG(0); - // int dW = INT_ARG(1); - // int dH = INT_ARG(2); - // int pT = INT_ARG(3); - // int pW = INT_ARG(4); - // int pH = INT_ARG(5); - // int dilationT = INT_ARG(6); - // int dilationW = INT_ARG(7); - // int dilationH = INT_ARG(8); - // int aT = INT_ARG(9); - // int aW = INT_ARG(10); - // int aH = INT_ARG(11); - // bool biasUsed = INT_ARG(12) != 0; - - // const int nInputPlane = (int)weights->shapeOf()[0]; - // const int nOutputPlane = (int)weights->shapeOf()[1]; - // const int kT = (int)weights->shapeOf()[2]; - // const int kH = (int)weights->shapeOf()[3]; - // const int kW = (int)weights->shapeOf()[4]; - - // const Nd4jLong inputWidth = 
input->shapeOf()[4]; - // const Nd4jLong inputHeight = input->shapeOf()[3]; - // const Nd4jLong inputDepth = input->shapeOf()[2]; - // const Nd4jLong outputDepth = (inputDepth - 1) * dT - 2*pT + (dilationT * (kT - 1) + 1) + aT; - // const Nd4jLong outputHeight = (inputHeight - 1) * dH - 2*pH + (dilationH * (kH - 1) + 1) + aH; - // const Nd4jLong outputWidth = (inputWidth - 1) * dW - 2*pW + (dilationW * (kW - 1) + 1) + aW; - - // const Nd4jLong batchSize = input->shapeOf()[0]; - - - // REQUIRE_TRUE(output->isSameShape({(int) batchSize, (int) nInputPlane, (int) inputDepth, (int) inputHeight, (int) inputWidth}) ,0, "Output should have shape of [%i, %i, %i, %i, %i], but got [%i, %i, %i, %i, %i] instead", (int) batchSize, (int) nInputPlane, (int) inputDepth, (int) inputHeight, (int) inputWidth, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3), output->sizeAt(4)); - - // output->assign(0.0); - - // // FIXME: non-inplace reshape!!!! - // NDArray *gradColumns; - // //auto gradColumns = finput->reshape('c', {nOutputPlane*kW*kH*kT, inputDepth*inputHeight*inputWidth }); - - // std::unique_ptr tadsNext(gradNext->allExamples()); - // std::unique_ptr tadsOutput(output->allExamples()); - // for (int e = 0; e < tadsNext->size(); e++) { - // auto tadNext = tadsNext->at(e); - // auto tadOutput = tadsOutput->at(e); - - // // ConvolutionUtils::_vol2col( - // // tadNext->getBuffer(), - // // nOutputPlane, outputDepth, outputHeight, outputWidth, - // // kT, kH, kW, - // // pT, pH, pW, - // // dT, dH, dW, - // // dilationT, dilationH, dilationW, - // // gradColumns->getBuffer()); - // ConvolutionUtils::vol2col(*tadNext, *gradColumns, dT, dH, dW, pT, pH, pW, dilationT, dilationH, dilationW); - - // const auto m = weights->shapeOf()[0]; - // const auto n = gradColumns->shapeOf()[1]; - // const auto k = weights->shapeOf()[1] * weights->shapeOf()[2] * weights->shapeOf()[3] * weights->shapeOf()[4]; - - // // FIXME: mmul helper should be used here - // /* - // 
nd4j::blas::GEMM::op('f', 'n', 'n', - // n, m, k, - // 1.0f, - // gradColumns->getBuffer(), n, - // weights->getBuffer(), k, - // 0, - // tadOutput->getBuffer(), n - - // ); - // */ - // } - - - // STORE_RESULT(*output); - - // delete gradColumns; - return ND4J_STATUS_OK; - } - DECLARE_SHAPE_FN(fullconv3d_bp) { - // output shape equals to input shape, all out of sudden - // Nd4jLong* newShape; - // COPY_SHAPE(inputShape->at(0), newShape); - - // return SHAPELIST(newShape); - return SHAPELIST(); - } - - DECLARE_TYPES(fullconv3d_grad) { - getOpDescriptor() - ->setAllowedInputTypes(nd4j::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - -////////////////////////////////////////////////////////////////////////// - CUSTOM_OP_IMPL(fullconv3d_grad, 4, 2, false, 1, 13) { - // auto input = INPUT_VARIABLE(0); - // auto epsilon = INPUT_VARIABLE(1); - // auto columns = INPUT_VARIABLE(2); - // auto ones = INPUT_VARIABLE(3); - - // REQUIRE_TRUE(input->rankOf() == epsilon->rankOf(), 0, "Rank of input (%i) & epsilon (%i) should be equal", input->rankOf(), epsilon->rankOf()); - // REQUIRE_TRUE(input->sizeAt(0) == epsilon->sizeAt(0), 1, "Batch size should be equal for input and epsilon"); - - // auto gradWeight = OUTPUT_VARIABLE(0); - // auto gradBias = OUTPUT_VARIABLE(1); - - // REQUIRE_TRUE(gradBias->sizeAt(0) == gradWeight->sizeAt(1), 0, "Bias shape mismatch"); - - // int dT = INT_ARG(0); - // int dW = INT_ARG(1); - // int dH = INT_ARG(2); - // int pT = INT_ARG(3); - // int pW = INT_ARG(4); - // int pH = INT_ARG(5); - // int dilationT = INT_ARG(6); - // int dilationW = INT_ARG(7); - // int dilationH = INT_ARG(8); - // int aT = INT_ARG(9); - // int aW = INT_ARG(10); - // int aH = INT_ARG(11); - // bool biasUsed = INT_ARG(12) != 0; - - // double scale = block.getTArguments()->at(0); - - // int nInputPlane = (int)gradWeight->shapeOf()[0]; - // int nOutputPlane = (int)gradWeight->shapeOf()[1]; - // int kT = (int)gradWeight->shapeOf()[2]; - // int kH = 
(int)gradWeight->shapeOf()[3]; - // int kW = (int)gradWeight->shapeOf()[4]; - - - // const Nd4jLong inputWidth = input->shapeOf()[4]; - // const Nd4jLong inputHeight = input->shapeOf()[3]; - // const Nd4jLong inputDepth = input->shapeOf()[2]; - // const Nd4jLong outputDepth = (inputDepth - 1) * dT - 2*pT + (dilationT * (kT - 1) + 1) + aT; - // const Nd4jLong outputHeight = (inputHeight - 1) * dH - 2*pH + (dilationH * (kH - 1) + 1) + aH; - // const Nd4jLong outputWidth = (inputWidth - 1) * dW - 2*pW + (dilationW * (kW - 1) + 1) + aW; - - - // REQUIRE_TRUE(gradWeight->isContiguous(), 0, "gradWight should be continuous"); - // REQUIRE_TRUE(gradBias->isContiguous(), 0, "gradBias should be continuous"); - // REQUIRE_TRUE(ones->rankOf() == 3, 0, "Ones should have rank 3, got %i instead", ones->rankOf()); - - // REQUIRE_TRUE(ones->isSameShape({outputDepth, outputHeight, outputWidth}), 0, ""); - - // ones->assign(1.0); - - // std::unique_ptr tadsInput(input->allExamples()); - // std::unique_ptr tadsEpsilon(epsilon->allExamples()); - - // for (int e = 0; e < tadsInput->size(); e++) { - // auto tadInput = tadsInput->at(e); - // auto tadEpsilon = tadsEpsilon->at(e); - - // // ConvolutionUtils::_vol2col( - // // tadEpsilon->getBuffer(), nOutputPlane, - // // outputDepth, outputHeight, outputWidth, - // // kT, kH, kW, - // // pT, pH, pW, - // // dT, dH, dW, - // // dilationT, dilationH, dilationW, - // // columns->getBuffer() - // // ); - // ConvolutionUtils::vol2col(*tadEpsilon, *columns, dT, dH, dW, pT, pH, pW, dilationT, dilationH, dilationW); - // const Nd4jLong n = columns->shapeOf()[0]; // nOutputPlane * kt * kh * kw - // const Nd4jLong m = tadInput->shapeOf()[0]; // nInputPlane - // const Nd4jLong k = columns->shapeOf()[1]; - - // // FIXME: mmul helper should be used here - // /** - // nd4j::blas::GEMM::op('f', 't', 'n', - // n, m, k, - // scale, - // columns->getBuffer(), k, - // tadInput->getBuffer(), k, - // 1, - // gradWeight->getBuffer(), n); - // */ - - // const 
Nd4jLong m_ = nOutputPlane; - // const Nd4jLong k_ = outputDepth * outputHeight * outputWidth; - - - // if (gradBias) { - // // FIXME: mmul helper should be used here - // /* - // nd4j::blas::GEMV::op('t', - // k_, m_, - // scale, - // tadEpsilon->getBuffer(), k_, - // ones->getBuffer(), 1, (T)1.0f, - // gradBias->getBuffer(), 1); - // */ - // } - // } - - - // STORE_2_RESULTS(*gradWeight, *gradBias); - - return Status::OK(); - } - DECLARE_SHAPE_FN(fullconv3d_grad) { - // auto list = SHAPELIST(); - - // _grad ops MUST have output arrays provided - - // return list; - return SHAPELIST(); - } - } -} - -#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/convo.h b/libnd4j/include/ops/declarable/headers/convo.h index 33e0d89fd..ee1417386 100644 --- a/libnd4j/include/ops/declarable/headers/convo.h +++ b/libnd4j/include/ops/declarable/headers/convo.h @@ -187,12 +187,6 @@ namespace nd4j { DECLARE_CUSTOM_OP(pnormpool2d_bp, 2, 1, false, 1, 10); #endif - #if NOT_EXCLUDED(OP_fullconv3d) - DECLARE_CUSTOM_OP(fullconv3d, 5, 1, false, 0, 13); - DECLARE_CUSTOM_OP(fullconv3d_bp, 5, 1, false, 0, 13); - DECLARE_CUSTOM_OP(fullconv3d_grad, 4, 2, false, 1, 13); - #endif - /** * This op implements im2col algorithm, widely used in convolution neural networks * Input: 4D input expected diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index 0da607a42..fd46da910 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -16,6 +16,9 @@ package org.nd4j.autodiff.samediff; +import static org.nd4j.autodiff.util.TrainingUtils.getSingleOutput; +import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs; + import 
com.google.common.collect.HashBasedTable; import com.google.common.collect.Table; import com.google.common.primitives.Ints; @@ -73,6 +76,7 @@ import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; +import org.nd4j.linalg.exception.ND4JException; import org.nd4j.linalg.exception.ND4JIllegalArgumentException; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4UnresolvedOutputVariables; @@ -109,7 +113,7 @@ import org.tensorflow.framework.GraphDef; *

* That graph accumulates operations. *

- * In order to execute the graph, you run one of the execution methods, such as {@link #exec(Map, String...)} + * In order to execute the graph, you run one of the execution methods, such as {@link #output(Map, String...)} */ @AllArgsConstructor @Builder @@ -2262,7 +2266,7 @@ public class SameDiff extends SDBaseOps { MultiDataSet ds = iterator.next(); Map placeholderMap = toPlaceholderMap(ds); - Map m = exec(placeholderMap, reqVars); + Map m = output(placeholderMap, reqVars); for(Map.Entry> e : variableEvals.entrySet()){ INDArray prediction = m.get(e.getKey()); @@ -2288,7 +2292,15 @@ public class SameDiff extends SDBaseOps { * @param outputs The variables to evaluate */ public Map output(DataSet dataSet, String... outputs){ - return output(new SingletonMultiDataSetIterator(dataSet.toMultiDataSet()), outputs).get(0); + return outputBatches(new SingletonMultiDataSetIterator(dataSet.toMultiDataSet()), outputs).get(0); + } + + /** + * Single output inference. + * See {@link #output(DataSet, String...)} + */ + public INDArray outputSingle(DataSet dataSet, String output){ + return output(dataSet, output).get(output); } /** @@ -2299,13 +2311,40 @@ public class SameDiff extends SDBaseOps { * sameDiff.output(iterator, "softmax");} * * + * Uses concatenation on the outputs of {@link #outputBatches(DataSetIterator, String...)} which may cause issues with some inputs. + * RNNs with variable time series length and CNNs with variable image sizes will most likely have issues. + * * @param iterator Iterator as source of data to evaluate * @param outputs The variables to evaluate */ - public List> output(DataSetIterator iterator, String... outputs){ + public Map output(DataSetIterator iterator, String... outputs){ return output(new MultiDataSetIteratorAdapter(iterator), outputs); } + /** + * See {@link #output(DataSetIterator, String...)}, but without the concatenation of batches. + * + */ + public List> outputBatches(DataSetIterator iterator, String... 
outputs){ + return outputBatches(new MultiDataSetIteratorAdapter(iterator), outputs); + } + + /** + * Single output inference. + * See {@link #output(DataSetIterator, String...)} + */ + public INDArray outputSingle(DataSetIterator dataSet, String output){ + return output(dataSet, output).get(output); + } + + /** + * Single batched output inference. + * See {@link #output(DataSetIterator, String...)} + */ + public List outputSingleBatches(DataSetIterator dataSet, String output){ + return getSingleOutput(outputBatches(dataSet, output), output); + } + /** * Perform inference.
*
@@ -2321,10 +2360,20 @@ public class SameDiff extends SDBaseOps { * } * * + * Uses concatenation on the outputs of {@link #outputBatches(MultiDataSetIterator, String...)} which may cause issues with some inputs. + * RNNs with variable time series length and CNNs with variable image sizes will most likely have issues. + * * @param iterator The iterator - the source of the data for inference * @param outputs The set of outputs to report. If null, defaults to all outputs of this SameDiff. */ - public List> output(MultiDataSetIterator iterator, String... outputs){ + public Map output(MultiDataSetIterator iterator, String... outputs){ + return stackOutputs(outputBatches(iterator, outputs)); + } + + /** + * See {@link #output(MultiDataSetIterator, String...)}, but without the concatenation of batches. + */ + public List> outputBatches(MultiDataSetIterator iterator, String... outputs){ Preconditions.checkState(trainingConfig != null, "Training config has not been set"); List reqVars; @@ -2344,12 +2393,114 @@ public class SameDiff extends SDBaseOps { MultiDataSet ds = iterator.next(); Map placeholderMap = toPlaceholderMap(ds); - predictions.add(exec(placeholderMap, reqVars)); + predictions.add(output(placeholderMap, reqVars)); } return predictions; } + /** + * Single output inference. + * See {@link #output(MultiDataSetIterator, String...)} + */ + public INDArray outputSingle(MultiDataSetIterator dataSet, String output){ + return output(dataSet, output).get(output); + } + + /** + * Single batched output inference. 
+ * See {@link #output(MultiDataSetIterator, String...)} + */ + public List outputSingleBatches(MultiDataSetIterator dataSet, String output){ + return getSingleOutput(outputBatches(dataSet, output), output); + } + + /** + * @deprecated See {@link #outputAll(Map)} + */ + @Deprecated + public Map execAll(Map placeholders){ + return outputAll(placeholders); + } + + /** + * Do inference for all variables for a single batch + */ + public Map outputAll(Map placeholders){ + List allVars = new ArrayList<>(); + for(Variable v : variables.values()){ + allVars.add(v.getName()); + } + return output(placeholders, allVars.toArray(new String[0])); + } + /** + * @deprecated See {@link #outputSingle(Map, String)} + */ + @Deprecated + public INDArray execSingle(Map placeholders, String output){ + return outputSingle(placeholders, output); + } + + /** + * Do inference for a single variable for a single batch + */ + public INDArray outputSingle(Map placeholders, String output){ + return output(placeholders, output).get(output); + } + /** + * @deprecated See {@link #output(Map, List)} + */ + @Deprecated + public Map exec(Map placeholders, List outputs){ + return output(placeholders, outputs); + } + + /** + * Do inference for the given variables for a single batch + */ + public Map output(Map placeholders, List outputs){ + return output(placeholders, outputs.toArray(new String[outputs.size()])); + } + + /** + * @deprecated See {@link #output(Map, String...)} + */ + @Deprecated + public Map exec(Map placeholders, String... outputs) { + return output(placeholders, outputs); + } + + + /** + * Do inference for the given variables for a single batch + */ + public Map output(Map placeholders, String... outputs) { + return output(placeholders, false, null, outputs); + } + + /** + * Do inference for the given variables for a single batch, with training information + */ + protected Map output(Map placeholders, boolean training, At at, String... 
outputs){ + Preconditions.checkState(outputs != null && outputs.length > 0, "No outputs were specified"); + long threadId = Thread.currentThread().getId(); + if(!sessions.containsKey(threadId)){ + log.info("Creating new InferenceSession for thread {}", threadId); + sessions.put(threadId, new InferenceSession(this)); + } + + List phNames = inputs(); + if(placeholders == null && phNames != null){ + //Maybe user set placeholders before calling exec method? + placeholders = placeholdersPerThread.get(Thread.currentThread().getId()); + } + + //Placeholder validation is performed in InferenceSession + + InferenceSession is = sessions.get(threadId); + Map ret = is.output(Arrays.asList(outputs), placeholders, listeners, training, at); + return ret; + } public SDVariable one(String name, int... shape){ return one(name, Nd4j.defaultFloatingPointType(), shape); @@ -3779,7 +3930,7 @@ public class SameDiff extends SDBaseOps { } //TODO is this 'train' flag the best approach? - sd.exec(placeholders, trainingConfig != null, at, variableGradNamesList.toArray(new String[variableGradNamesList.size()])); + sd.output(placeholders, trainingConfig != null, at, variableGradNamesList.toArray(new String[variableGradNamesList.size()])); } /** @@ -4459,47 +4610,6 @@ public class SameDiff extends SDBaseOps { } } - public Map execAll(Map placeholders){ - List allVars = new ArrayList<>(); - for(Variable v : variables.values()){ - allVars.add(v.getName()); - } - return exec(placeholders, allVars.toArray(new String[allVars.size()])); - } - - public INDArray execSingle(Map placeholders, String output){ - return exec(placeholders, output).get(output); - } - - public Map exec(Map placeholders, List outputs){ - return exec(placeholders, outputs.toArray(new String[outputs.size()])); - } - - public Map exec(Map placeholders, String... outputs) { - return exec(placeholders, false, null, outputs); - } - - protected Map exec(Map placeholders, boolean training, At at, String... 
outputs){ - Preconditions.checkState(outputs != null && outputs.length > 0, "No outputs were specified"); - long threadId = Thread.currentThread().getId(); - if(!sessions.containsKey(threadId)){ - log.info("Creating new InferenceSession for thread {}", threadId); - sessions.put(threadId, new InferenceSession(this)); - } - - List phNames = inputs(); - if(placeholders == null && phNames != null){ - //Maybe user set placeholders before calling exec method? - placeholders = placeholdersPerThread.get(Thread.currentThread().getId()); - } - - //Placeholder validation is performed in InferenceSession - - InferenceSession is = sessions.get(threadId); - Map ret = is.output(Arrays.asList(outputs), placeholders, listeners, training, at); - return ret; - } - protected int asFlatNode(String name, @NonNull SameDiff scope, @NonNull FlatBufferBuilder bufferBuilder) { int scopeName = bufferBuilder.createString(name); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/TrainingUtils.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/TrainingUtils.java new file mode 100644 index 000000000..289bd15be --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/TrainingUtils.java @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.autodiff.util; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; +import org.nd4j.linalg.exception.ND4JException; +import org.nd4j.linalg.factory.Nd4j; + +/** + * Utilities for SameDiff training and inference + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public class TrainingUtils { + + /** + * Stack batch outputs, like an output from {@link org.nd4j.autodiff.samediff.SameDiff#output(MultiDataSetIterator, String...)} + */ + public static Map stackOutputs(List> outputs){ + Map> outs = new HashMap<>(); + for(Map batch : outputs){ + for(String k : batch.keySet()){ + if(!outs.containsKey(k)) + outs.put(k, new ArrayList()); + outs.get(k).add(batch.get(k)); + } + } + + Map ret = new HashMap<>(); + for(String k : outs.keySet()){ + try { + ret.put(k, Nd4j.concat(0, outs.get(k).toArray(new INDArray[0]))); + } catch(Exception e){ + throw new ND4JException("Error concatenating batch outputs", e); + } + } + return ret; + } + + /** + * Get a list of batch outputs for a single variable from a list of batch outputs for all variables + */ + public static List getSingleOutput(List> outputs, String output){ + List batches = new ArrayList<>(); + for(Map batch : outputs) + batches.add(batch.get(output)); + + return batches; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java index 586c279c1..7dcd66530 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java @@ -915,7 +915,6 @@ public class OpValidation { Conv2DDerivative.class, Conv3DDerivative.class, DeConv2DDerivative.class, - FullConv3DDerivative.class, LocalResponseNormalizationDerivative.class, Pooling2DDerivative.class, Pooling3DDerivative.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java index c7464448a..ec37deeb8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java @@ -72,7 +72,6 @@ public class DifferentialFunctionClassHolder { add(AvgPooling2D.class.getName()); add(Conv2D.class.getName()); add(Conv3D.class.getName()); - add(FullConv3D.class.getName()); add(LocalResponseNormalization.class.getName()); add(MaxPooling2D.class.getName()); add(Pooling2D.class.getName()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 19bdb6d55..78e7e74b2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -117,8 +117,6 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace.class, org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D.class, - 
org.nd4j.linalg.api.ops.impl.layers.convolution.FullConv3D.class, - org.nd4j.linalg.api.ops.impl.layers.convolution.FullConv3DDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Im2colBp.class, org.nd4j.linalg.api.ops.impl.layers.convolution.LegacyPooling2D.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java index f5c4100a7..3198a6a56 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java @@ -251,8 +251,6 @@ public class AvgPooling2D extends DynamicCustomOp { .kW(kW) .pH(pH) .pW(pW) - .virtualHeight(1) - .virtualWidth(1) .isNHWC(data_format.equalsIgnoreCase("nhwc")) .extra(0.0) // averaging only for non-padded values .build(); @@ -277,8 +275,6 @@ public class AvgPooling2D extends DynamicCustomOp { .kW(kernelShape.get(1).intValue()) .pH(padding.get(0).intValue()) .pW(padding.size() < 2 ? 
padding.get(0).intValue() : padding.get(1).intValue()) - .virtualWidth(1) - .virtualHeight(1) .build(); this.config = pooling2DConfig; addArgs(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/FullConv3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/FullConv3D.java deleted file mode 100644 index bb3d76fd7..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/FullConv3D.java +++ /dev/null @@ -1,228 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.layers.convolution; - -import lombok.Builder; -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.base.Preconditions; -import org.nd4j.imports.descriptors.properties.AttributeAdapter; -import org.nd4j.imports.descriptors.properties.PropertyMapping; -import org.nd4j.imports.descriptors.properties.adapters.IntArrayIntIndexAdpater; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.FullConv3DConfig; -import org.nd4j.linalg.util.ArrayUtil; - -import java.lang.reflect.Field; -import java.util.*; - - -/** - * FullConv3D operation - */ -@Slf4j -public class FullConv3D extends DynamicCustomOp { - - protected FullConv3DConfig config; - - @Builder(builderMethodName = "builder") - public FullConv3D(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, FullConv3DConfig config) { - super(null,sameDiff, inputFunctions, false); - this.config = config; - if(inputs != null) { - addInputArgument(inputs); - } - - if(outputs != null) { - addOutputArgument(outputs); - } - - addArgs(); - } - - - public FullConv3D() {} - - @Override - public Map propertiesForFunction() { - return config.toProperties(); - } - - @Override - public long[] iArgs() { - if (iArguments.size() == 0) - addArgs(); - - return super.iArgs(); - } - - - @Override - public boolean isConfigProperties() { - return true; - } - - @Override - public String configFieldName() { - return "config"; - } - - @Override - public Map> attributeAdaptersForFunction() { - Map> ret = new LinkedHashMap<>(); - Map tfAdapters = new LinkedHashMap<>(); - - tfAdapters.put("dT", new 
IntArrayIntIndexAdpater(1)); - tfAdapters.put("dW", new IntArrayIntIndexAdpater(2)); - tfAdapters.put("dH",new IntArrayIntIndexAdpater(3)); - - - tfAdapters.put("pT", new IntArrayIntIndexAdpater(1)); - tfAdapters.put("pW", new IntArrayIntIndexAdpater(2)); - tfAdapters.put("pH",new IntArrayIntIndexAdpater(3)); - - ret.put(tensorflowName(),tfAdapters); - - return ret; - } - - - @Override - public Map> mappingsForFunction() { - Map> ret = new HashMap<>(); - Map map = new HashMap<>(); - - - - val strideMapping = PropertyMapping.builder() - .tfAttrName("strides") - .onnxAttrName("strides") - .propertyNames(new String[]{"dT","dW","dH"}) - .build(); - - - - val dilationMapping = PropertyMapping.builder() - .onnxAttrName("dilations") - .propertyNames(new String[]{"dD","dH","dW"}) - .tfAttrName("rates") - .build(); - - - - val sameMode = PropertyMapping.builder() - .onnxAttrName("auto_pad") - .propertyNames(new String[]{"isSameMode"}) - .tfAttrName("padding") - .build(); - - val paddingWidthHeight = PropertyMapping.builder() - .onnxAttrName("padding") - .propertyNames(new String[]{"pT","pW","pH"}) - .build(); - - val dataFormat = PropertyMapping.builder() - .onnxAttrName("data_format") - .tfAttrName("data_format") - .propertyNames(new String[]{"dataFormat"}) - .build(); - - - val outputPadding = PropertyMapping.builder() - .propertyNames(new String[]{"aT","aH","aW"}) - .build(); - - - val biasUsed = PropertyMapping.builder() - .propertyNames(new String[]{"biasUsed"}) - .build(); - - - - - for(val propertyMapping : new PropertyMapping[] { - strideMapping, - dilationMapping, - sameMode, - paddingWidthHeight, - dataFormat, - outputPadding,biasUsed}) { - for(val keys : propertyMapping.getPropertyNames()) - map.put(keys,propertyMapping); - - } - - - ret.put(onnxName(),map); - ret.put(tensorflowName(),map); - return ret; - } - - private void addArgs() { - addIArgument(new long[]{ - config.getDT(), - config.getDW(), - config.getDH(), - config.getPT(), - config.getPW(), - 
config.getPH(), - config.getDilationT(), - config.getDilationW(), - config.getDilationH(), - config.getAT(), - config.getAW(), - config.getAH(), - ArrayUtil.fromBoolean(config.isBiasUsed())}); - - - } - - - - - @Override - public String opName() { - return "fullconv3d"; - } - - - @Override - public List doDiff(List f1) { - List inputs = new ArrayList<>(); - inputs.addAll(Arrays.asList(args())); - inputs.addAll(f1); - List ret = new ArrayList<>(); - FullConv3DDerivative fullConv3DDerivative = FullConv3DDerivative.derivativeBuilder() - .conv3DConfig(config) - .sameDiff(sameDiff) - .inputFunctions(inputs.toArray(new SDVariable[inputs.size()])) - .build(); - ret.addAll(Arrays.asList(fullConv3DDerivative.outputVariables())); - return ret; - } - - @Override - public List calculateOutputDataTypes(List inputDataTypes){ - int n = args().length; - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); - return Collections.singletonList(inputDataTypes.get(0)); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/FullConv3DDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/FullConv3DDerivative.java deleted file mode 100644 index 92c41886d..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/FullConv3DDerivative.java +++ /dev/null @@ -1,77 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.layers.convolution; - -import lombok.Builder; -import lombok.extern.slf4j.Slf4j; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.base.Preconditions; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.FullConv3DConfig; - -import java.util.ArrayList; -import java.util.List; - - -/** - * FullConv3DDerivative operation - */ -@Slf4j -public class FullConv3DDerivative extends FullConv3D { - - public FullConv3DDerivative() {} - - @Builder(builderMethodName = "derivativeBuilder") - public FullConv3DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, FullConv3DConfig conv3DConfig) { - super(sameDiff, inputFunctions, inputs, outputs, conv3DConfig); - } - - @Override - public String opName() { - return "fullconv3d_bp"; - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflwo op name found for conv3d derivative"); - } - - @Override - public String[] tensorflowNames() { - throw new NoOpNameFoundException("No tensorflwo op name found for conv3d derivative"); - } - - - @Override - public List doDiff(List f1) { - throw new UnsupportedOperationException("Unable to take derivative of derivative."); - } - - @Override - public List calculateOutputDataTypes(List inputDataTypes){ - int n = 
args().length; //Original inputs + gradient at - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); - List out = new ArrayList<>(n-1); - for( int i=0; i toProperties(); + /** * Get the value for a given property * for this function @@ -154,4 +158,5 @@ public abstract class BaseConvolutionConfig { } + protected abstract void validate(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv1DConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv1DConfig.java index 2553ed2ac..f04e27533 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv1DConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv1DConfig.java @@ -16,15 +16,16 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution.config; -import lombok.*; -import org.nd4j.base.Preconditions; - import java.util.LinkedHashMap; import java.util.Map; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.util.ConvConfigUtil; -@Builder @Data -@AllArgsConstructor +@Builder @NoArgsConstructor public class Conv1DConfig extends BaseConvolutionConfig { public static final String NCW = "NCW"; @@ -40,6 +41,16 @@ public class Conv1DConfig extends BaseConvolutionConfig { private String dataFormat = NCW; private boolean isSameMode; + public Conv1DConfig(long k, long s, long p, String dataFormat, boolean isSameMode) { + this.k = k; + this.s = s; + this.p = p; + this.dataFormat = dataFormat; + this.isSameMode = isSameMode; + + validate(); + } + public boolean isNWC(){ Preconditions.checkState(dataFormat.equalsIgnoreCase(NCW) || 
dataFormat.equalsIgnoreCase(NWC), "Data format must be one of %s or %s, got %s", NCW, NWC, dataFormat); @@ -54,6 +65,7 @@ public class Conv1DConfig extends BaseConvolutionConfig { } } + @Override public Map toProperties() { Map ret = new LinkedHashMap<>(); ret.put("k", k); @@ -64,5 +76,11 @@ public class Conv1DConfig extends BaseConvolutionConfig { return ret; } + @Override + protected void validate() { + ConvConfigUtil.validate1D(k, s, p); + Preconditions.checkArgument(dataFormat != null, "Data format can't be null"); + } + } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv2DConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv2DConfig.java index 2969ddaaa..40a2a3908 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv2DConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv2DConfig.java @@ -16,18 +16,16 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution.config; -import lombok.AllArgsConstructor; +import java.util.LinkedHashMap; +import java.util.Map; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.util.ConvConfigUtil; -import java.util.LinkedHashMap; -import java.util.Map; - -@Builder @Data -@AllArgsConstructor +@Builder @NoArgsConstructor public class Conv2DConfig extends BaseConvolutionConfig { public static final String NCHW = "NCHW"; @@ -53,6 +51,23 @@ public class Conv2DConfig extends BaseConvolutionConfig { @Builder.Default private String dataFormat = NCHW; + public Conv2DConfig(long kH, long kW, long sH, long sW, long pH, long pW, long dH, long dW, boolean isSameMode, + String dataFormat) { + + this.kH = kH; + this.kW = kW; + this.sH = sH; + this.sW 
= sW; + this.pH = pH; + this.pW = pW; + this.dH = dH; + this.dW = dW; + this.isSameMode = isSameMode; + this.dataFormat = dataFormat; + + validate(); + } + public boolean isNHWC(){ Preconditions.checkState(dataFormat.equalsIgnoreCase(NCHW) || dataFormat.equalsIgnoreCase(NHWC), "Data format must be one of %s or %s, got %s", NCHW, NHWC, dataFormat); @@ -67,6 +82,7 @@ public class Conv2DConfig extends BaseConvolutionConfig { } } + @Override public Map toProperties() { Map ret = new LinkedHashMap<>(); ret.put("kH", kH); @@ -82,5 +98,11 @@ public class Conv2DConfig extends BaseConvolutionConfig { return ret; } + @Override + protected void validate() { + ConvConfigUtil.validate2D(kH, kW, sH, sW, pH, pW, dH, dW); + Preconditions.checkArgument(dataFormat != null, "Data format can't be null"); + } + } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv3DConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv3DConfig.java index 478a634c4..f94c23329 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv3DConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv3DConfig.java @@ -17,30 +17,28 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution.config; -import lombok.AllArgsConstructor; +import java.util.LinkedHashMap; +import java.util.Map; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; import org.nd4j.base.Preconditions; - -import java.util.LinkedHashMap; -import java.util.Map; +import org.nd4j.linalg.util.ConvConfigUtil; @Data @Builder @NoArgsConstructor -@AllArgsConstructor public class Conv3DConfig extends BaseConvolutionConfig { public static final String NDHWC = "NDHWC"; public static final String NCDHW = "NCDHW"; //kernel 
@Builder.Default - private long kD = 1; + private long kD = -1; @Builder.Default - private long kW = 1; + private long kW = -1; @Builder.Default - private long kH = 1; + private long kH = -1; //strides @Builder.Default @@ -66,14 +64,6 @@ public class Conv3DConfig extends BaseConvolutionConfig { @Builder.Default private long dH = 1; - //output padding - @Builder.Default - private long aD = 0; - @Builder.Default - private long aW = 0; - @Builder.Default - private long aH = 0; - @Builder.Default private boolean biasUsed = false; private boolean isSameMode; @@ -81,6 +71,27 @@ public class Conv3DConfig extends BaseConvolutionConfig { @Builder.Default private String dataFormat = NDHWC; + public Conv3DConfig(long kD, long kW, long kH, long sD, long sW, long sH, long pD, long pW, long pH, long dD, + long dW, long dH, boolean biasUsed, boolean isSameMode, String dataFormat) { + this.kD = kD; + this.kW = kW; + this.kH = kH; + this.sD = sD; + this.sW = sW; + this.sH = sH; + this.pD = pD; + this.pW = pW; + this.pH = pH; + this.dD = dD; + this.dW = dW; + this.dH = dH; + this.biasUsed = biasUsed; + this.isSameMode = isSameMode; + this.dataFormat = dataFormat; + + validate(); + } + public boolean isNCDHW(){ Preconditions.checkState(dataFormat.equalsIgnoreCase(NCDHW) || dataFormat.equalsIgnoreCase(NDHWC), "Data format must be one of %s or %s, got %s", NCDHW, NDHWC, dataFormat); @@ -95,6 +106,7 @@ public class Conv3DConfig extends BaseConvolutionConfig { } } + @Override public Map toProperties() { Map ret = new LinkedHashMap<>(); ret.put("kD", kD); @@ -109,9 +121,6 @@ public class Conv3DConfig extends BaseConvolutionConfig { ret.put("dD", dD); ret.put("dW", dW); ret.put("dH", dH); - ret.put("aD", aD); - ret.put("aW", aW); - ret.put("aH", aH); ret.put("biasUsed", biasUsed); ret.put("dataFormat", dataFormat); ret.put("isSameMode", isSameMode); @@ -119,5 +128,11 @@ public class Conv3DConfig extends BaseConvolutionConfig { return ret; } + @Override + protected void validate() { + 
ConvConfigUtil.validate3D(kH, kW, kD, sH, sW, sD, pH, pW, pD, dH, dW, dD); + Preconditions.checkArgument(dataFormat != null, "Data format can't be null"); + } + } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/DeConv2DConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/DeConv2DConfig.java index f95a444ac..89e8e7bf3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/DeConv2DConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/DeConv2DConfig.java @@ -16,22 +16,22 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution.config; -import lombok.AllArgsConstructor; +import java.util.LinkedHashMap; +import java.util.Map; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.util.ConvConfigUtil; -import java.util.LinkedHashMap; -import java.util.Map; - -@Builder @Data -@AllArgsConstructor +@Builder @NoArgsConstructor public class DeConv2DConfig extends BaseConvolutionConfig { public static final String NCHW = "NCHW"; public static final String NHWC = "NHWC"; + @Builder.Default private long kH = -1L; @Builder.Default private long kW = -1L; @Builder.Default private long sH = 1L; @@ -43,8 +43,25 @@ public class DeConv2DConfig extends BaseConvolutionConfig { @Builder.Default private boolean isSameMode = false; @Builder.Default private String dataFormat = NCHW; + public DeConv2DConfig(long kH, long kW, long sH, long sW, long pH, long pW, long dH, long dW, boolean isSameMode, + String dataFormat) { + this.kH = kH; + this.kW = kW; + this.sH = sH; + this.sW = sW; + this.pH = pH; + this.pW = pW; + this.dH = dH; + this.dW = dW; + this.isSameMode = isSameMode; + this.dataFormat = dataFormat; + 
validate(); + } + + @Override public Map toProperties() { + Map ret = new LinkedHashMap<>(); ret.put("kH", kH); ret.put("kW", kW); @@ -58,4 +75,10 @@ public class DeConv2DConfig extends BaseConvolutionConfig { ret.put("dataFormat", dataFormat); return ret; } + + @Override + protected void validate() { + ConvConfigUtil.validate2D(kH, kW, sH, sW, pH, pW, dH, dW); + Preconditions.checkArgument(dataFormat != null, "Data format can't be null"); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/DeConv3DConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/DeConv3DConfig.java index 89a938f89..2936c2ba2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/DeConv3DConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/DeConv3DConfig.java @@ -16,17 +16,16 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution.config; -import lombok.AllArgsConstructor; +import java.util.LinkedHashMap; +import java.util.Map; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.util.ConvConfigUtil; -import java.util.LinkedHashMap; -import java.util.Map; - -@Builder @Data -@AllArgsConstructor +@Builder @NoArgsConstructor public class DeConv3DConfig extends BaseConvolutionConfig { public static final String NCDHW = "NCDHW"; @@ -47,7 +46,28 @@ public class DeConv3DConfig extends BaseConvolutionConfig { @Builder.Default private boolean isSameMode = false; @Builder.Default private String dataFormat = NCDHW; + public DeConv3DConfig(long kD, long kH, long kW, long sD, long sH, long sW, long pD, long pH, long pW, long dD, + long dH, long dW, boolean isSameMode, String dataFormat) { + this.kD = kD; + this.kH = kH; + 
this.kW = kW; + this.sD = sD; + this.sH = sH; + this.sW = sW; + this.pD = pD; + this.pH = pH; + this.pW = pW; + this.dD = dD; + this.dH = dH; + this.dW = dW; + this.isSameMode = isSameMode; + this.dataFormat = dataFormat; + validate(); + } + + + @Override public Map toProperties() { Map ret = new LinkedHashMap<>(); ret.put("kD", kD); @@ -66,4 +86,10 @@ public class DeConv3DConfig extends BaseConvolutionConfig { ret.put("dataFormat", dataFormat); return ret; } + + @Override + protected void validate() { + ConvConfigUtil.validate3D(kH, kW, kD, sH, sW, sD, pH, pW, pD, dH, dW, dD); + Preconditions.checkArgument(dataFormat != null, "Data format can't be null"); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/FullConv3DConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/FullConv3DConfig.java deleted file mode 100644 index f740ce929..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/FullConv3DConfig.java +++ /dev/null @@ -1,53 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.layers.convolution.config; - -import lombok.Builder; -import lombok.Data; - -import java.util.LinkedHashMap; -import java.util.Map; - -@Builder -@Data -public class FullConv3DConfig extends BaseConvolutionConfig { - private long dT,dW,dH,pT,pW,pH,dilationT,dilationW,dilationH,aT,aW,aH; - private boolean biasUsed; - private String dataFormat; - - - - - public Map toProperties() { - Map ret = new LinkedHashMap<>(); - ret.put("dT",dT); - ret.put("dW",dW); - ret.put("dH",dH); - ret.put("pT",pT); - ret.put("pW",pW); - ret.put("pH",pH); - ret.put("dD",dilationT); - ret.put("dW",dilationW); - ret.put("dH",dilationH); - ret.put("aT",aT); - ret.put("aW",aW); - ret.put("aH",aH); - ret.put("biasUsed",biasUsed); - ret.put("dataFormat",dataFormat); - return ret; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/LocalResponseNormalizationConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/LocalResponseNormalizationConfig.java index 5099003c0..92634af90 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/LocalResponseNormalizationConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/LocalResponseNormalizationConfig.java @@ -16,19 +16,31 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution.config; -import lombok.Builder; -import lombok.Data; - import java.util.LinkedHashMap; import java.util.Map; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; +import org.nd4j.linalg.util.ConvConfigUtil; @Data @Builder +@NoArgsConstructor public class LocalResponseNormalizationConfig extends 
BaseConvolutionConfig { private double alpha, beta, bias; private int depth; + public LocalResponseNormalizationConfig(double alpha, double beta, double bias, int depth) { + this.alpha = alpha; + this.beta = beta; + this.bias = bias; + this.depth = depth; + + validate(); + } + + @Override public Map toProperties() { Map ret = new LinkedHashMap<>(); ret.put("alpha", alpha); @@ -38,4 +50,9 @@ public class LocalResponseNormalizationConfig extends BaseConvolutionConfig { return ret; } + @Override + protected void validate() { + ConvConfigUtil.validateLRN(alpha, beta, bias, depth); + } + } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Pooling2DConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Pooling2DConfig.java index 487cbacc0..7176bf0f0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Pooling2DConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Pooling2DConfig.java @@ -16,32 +16,32 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution.config; -import lombok.AllArgsConstructor; +import java.util.LinkedHashMap; +import java.util.Map; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D; +import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.Divisor; +import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.Pooling2DType; +import org.nd4j.linalg.util.ConvConfigUtil; -import java.util.LinkedHashMap; -import java.util.Map; - -@Builder -@AllArgsConstructor @Data +@Builder @NoArgsConstructor public class Pooling2DConfig extends BaseConvolutionConfig { - private long kH, kW; - private long sH, sW; - private long pH, pW; - private long virtualHeight, 
virtualWidth; + @Builder.Default private long kH = -1, kW = -1; + @Builder.Default private long sH = 1, sW = 1; + @Builder.Default private long pH = 0, pW = 0; /** * Extra is an optional parameter mainly for use with pnorm right now. * All pooling implementations take 9 parameters save pnorm. * Pnorm takes 10 and is cast to an int. */ private double extra; - private Pooling2D.Pooling2DType type; + @Builder.Default + private Pooling2D.Pooling2DType type = Pooling2DType.MAX; @Builder.Default private Pooling2D.Divisor divisor = Pooling2D.Divisor.EXCLUDE_PADDING; private boolean isSameMode; @@ -52,7 +52,26 @@ public class Pooling2DConfig extends BaseConvolutionConfig { @Builder.Default private boolean isNHWC = false; + public Pooling2DConfig(long kH, long kW, long sH, long sW, long pH, long pW, double extra, Pooling2DType type, + Divisor divisor, boolean isSameMode, long dH, long dW, boolean isNHWC) { + this.kH = kH; + this.kW = kW; + this.sH = sH; + this.sW = sW; + this.pH = pH; + this.pW = pW; + this.extra = extra; + this.type = type; + this.divisor = divisor; + this.isSameMode = isSameMode; + this.dH = dH; + this.dW = dW; + this.isNHWC = isNHWC; + validate(); + } + + @Override public Map toProperties() { Map ret = new LinkedHashMap<>(); ret.put("kH", kH); @@ -61,8 +80,6 @@ public class Pooling2DConfig extends BaseConvolutionConfig { ret.put("sW", sW); ret.put("pH", pH); ret.put("pW", pW); - ret.put("virtualHeight", virtualHeight); - ret.put("virtualWidth", virtualWidth); ret.put("extra", extra); ret.put("type", type.toString()); ret.put("isSameMode", isSameMode); @@ -72,4 +89,11 @@ public class Pooling2DConfig extends BaseConvolutionConfig { return ret; } + @Override + protected void validate() { + ConvConfigUtil.validate2D(kH, kW, sH, sW, pH, pW, dH, dW); + + //TODO check other args? 
+ } + } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Pooling3DConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Pooling3DConfig.java index 58581be38..ec155fc10 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Pooling3DConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Pooling3DConfig.java @@ -16,23 +16,22 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution.config; -import lombok.AllArgsConstructor; +import java.util.LinkedHashMap; +import java.util.Map; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D; - -import java.util.LinkedHashMap; -import java.util.Map; +import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D.Pooling3DType; +import org.nd4j.linalg.util.ConvConfigUtil; @Data @Builder -@AllArgsConstructor @NoArgsConstructor public class Pooling3DConfig extends BaseConvolutionConfig { - private long kD, kW, kH; // kernel - private long sD, sW, sH; // strides - private long pD, pW, pH; // padding + @Builder.Default private long kD = -1, kW = -1, kH = -1; // kernel + @Builder.Default private long sD = 1, sW = 1, sH = 1; // strides + @Builder.Default private long pD = 0, pW = 0, pH = 0; // padding // dilation @Builder.Default private long dD = 1; @@ -40,10 +39,33 @@ public class Pooling3DConfig extends BaseConvolutionConfig { private long dW = 1; @Builder.Default private long dH = 1; - private Pooling3D.Pooling3DType type; + @Builder.Default + private Pooling3D.Pooling3DType type = Pooling3DType.MAX; private boolean isSameMode; @Builder.Default private boolean isNCDHW = true; + public Pooling3DConfig(long kD, long kW, long kH, long sD, long sW, long 
sH, long pD, long pW, long pH, long dD, + long dW, long dH, Pooling3DType type, boolean isSameMode, boolean isNCDHW) { + this.kD = kD; + this.kW = kW; + this.kH = kH; + this.sD = sD; + this.sW = sW; + this.sH = sH; + this.pD = pD; + this.pW = pW; + this.pH = pH; + this.dD = dD; + this.dW = dW; + this.dH = dH; + this.type = type; + this.isSameMode = isSameMode; + this.isNCDHW = isNCDHW; + + validate(); + } + + @Override public Map toProperties() { Map ret = new LinkedHashMap<>(); ret.put("kD", kD); @@ -63,4 +85,11 @@ public class Pooling3DConfig extends BaseConvolutionConfig { return ret; } + + @Override + protected void validate() { + ConvConfigUtil.validate3D(kH, kW, kD, sH, sW, sD, pH, pW, pD, dH, dW, dD); + + //TODO check other args + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java index 3123b5b01..cab411916 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java @@ -249,8 +249,6 @@ public class Convolution { .isSameMode(isSameMode) .sH(sy) .sW(sx) - .virtualHeight(virtualHeight) - .virtualWidth(virtualWidth) .type(type) .divisor(divisor) .build()) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/ConvConfigUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/ConvConfigUtil.java new file mode 100644 index 000000000..91b854923 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/ConvConfigUtil.java @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.linalg.util; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.nd4j.base.Preconditions; + +/** + * Class with utility methods for validating convolution op configurations like {@link org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig} + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public class ConvConfigUtil { + + /** + * Validate a 2D convolution's Kernel, Stride, Padding, and Dilation + */ + public static void validate2D(long kH, long kW, long sH, long sW, long pH, long pW, long dH, long dW){ + Preconditions.checkArgument(kH != 0, "Kernel height can not be 0"); + Preconditions.checkArgument(kW != 0, "Kernel width can not be 0"); + + Preconditions.checkArgument(sH > 0, "Stride height can not be negative or 0, got: %s", sH); + Preconditions.checkArgument(sW > 0, "Stride width can not be negative or 0, got: %s", sW); + + Preconditions.checkArgument(pH >= 0, "Padding height can not be negative, got: %s", pH); + Preconditions.checkArgument(pW >= 0, "Padding width can not be negative, got: %s", pW); + + Preconditions.checkArgument(dH > 0, "Dilation height can not be negative or 0, got: %s", dH); + Preconditions.checkArgument(dW > 0, "Dilation width can not be negative or 0, got: %s", dW); + } + + /** + * Validate a 3D convolution's Kernel, Stride, Padding, and Dilation + */ + public static void validate3D(long kH, long kW, long kD, 
long sH, long sW, long sD, long pH, long pW, long pD, long dH, long dW, long dD){ + Preconditions.checkArgument(kH != 0, "Kernel height can not be 0"); + Preconditions.checkArgument(kW != 0, "Kernel width can not be 0"); + Preconditions.checkArgument(kD != 0, "Kernel depth can not be 0"); + + Preconditions.checkArgument(sH > 0, "Stride height can not be negative or 0, got: %s", sH); + Preconditions.checkArgument(sW > 0, "Stride width can not be negative or 0, got: %s", sW); + Preconditions.checkArgument(sD > 0, "Stride depth can not be negative or 0, got: %s", sD); + + Preconditions.checkArgument(pH >= 0, "Padding height can not be negative, got: %s", pH); + Preconditions.checkArgument(pW >= 0, "Padding width can not be negative, got: %s", pW); + Preconditions.checkArgument(pD >= 0, "Padding depth can not be negative, got: %s", pD); + + Preconditions.checkArgument(dH > 0, "Dilation height can not be negative or 0, got: %s", dH); + Preconditions.checkArgument(dW > 0, "Dilation width can not be negative or 0, got: %s", dW); + Preconditions.checkArgument(dD > 0, "Dilation depth can not be negative or 0, got: %s", dD); + } + + /** + * Validate a 3D convolution's Output Padding + */ + public static void validateExtra3D(long aH, long aW, long aD){ + Preconditions.checkArgument(aH >= 0, "Output padding height can not be negative, got: %s", aH); + Preconditions.checkArgument(aW >= 0, "Output padding width can not be negative, got: %s", aW); + Preconditions.checkArgument(aD >= 0, "Output padding depth can not be negative, got: %s", aD); + } + + /** + * Validate a 1D convolution's Kernel, Stride, and Padding + */ + public static void validate1D(long k, long s, long p){ + Preconditions.checkArgument(k != 0, "Kernel can not be 0"); + + Preconditions.checkArgument(s > 0, "Stride can not be negative or 0, got: %s", s); + + Preconditions.checkArgument(p >= 0, "Padding can not be negative, got: %s", p); + } + + /** + * Validate a LocalResponseNormalizationConfig + */ + public 
static void validateLRN(double alpha, double beta, double bias, int depth) { + Preconditions.checkArgument(depth > 0, "Depth can not be 0 or negative, got: %s", depth); + } +} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index d840850e4..858108705 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -14662,54 +14662,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif -// #if NOT_EXCLUDED(OP_fullconv3d) - @Namespace("nd4j::ops") public static class fullconv3d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public fullconv3d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public fullconv3d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public fullconv3d position(long position) { - return (fullconv3d)super.position(position); - } - - public fullconv3d() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } - @Namespace("nd4j::ops") public static class fullconv3d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public fullconv3d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. 
*/ - public fullconv3d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public fullconv3d_bp position(long position) { - return (fullconv3d_bp)super.position(position); - } - - public fullconv3d_bp() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } - @Namespace("nd4j::ops") public static class fullconv3d_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public fullconv3d_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public fullconv3d_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public fullconv3d_grad position(long position) { - return (fullconv3d_grad)super.position(position); - } - - public fullconv3d_grad() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } -// #endif - /** * This op implements im2col algorithm, widely used in convolution neural networks * Input: 4D input expected diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index 1ad9978e7..c8cbff5f9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -1124,7 +1124,7 @@ public class LayerOpValidation extends BaseOpValidation { assertNull(err, err); } - @Test(expected = IllegalStateException.class) + @Test(expected = IllegalArgumentException.class) public 
void exceptionThrown_WhenConv1DConfigInvalid() { int nIn = 3; int nOut = 4; @@ -1150,7 +1150,7 @@ public class LayerOpValidation extends BaseOpValidation { } - @Test(expected = IllegalStateException.class) + @Test(expected = IllegalArgumentException.class) public void exceptionThrown_WhenConv2DConfigInvalid() { Nd4j.getRandom().setSeed(12345); @@ -1171,7 +1171,7 @@ public class LayerOpValidation extends BaseOpValidation { .build()); } - @Test(expected = IllegalStateException.class) + @Test(expected = IllegalArgumentException.class) public void exceptionThrown_WhenConf3DInvalid() { Nd4j.getRandom().setSeed(12345); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java new file mode 100644 index 000000000..996ccff7f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java @@ -0,0 +1,515 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.autodiff.samediff; + +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import org.junit.Assert; +import org.junit.Test; +import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; + +public class ConvConfigTests { + + @Test + public void testDeConv2D(){ + DeConv2DConfig.builder().kH(2).kW(4).build(); + + try{ + DeConv2DConfig.builder().kW(4).kH(0).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Kernel height")); + } + + try{ + DeConv2DConfig.builder().kH(4).kW(0).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Kernel width")); + } + + try{ + DeConv2DConfig.builder().kH(4).kW(3).sH(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Stride height")); + } + + try{ + DeConv2DConfig.builder().kH(4).kW(3).sW(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Stride width")); + } + + try{ + DeConv2DConfig.builder().kH(4).kW(3).pH(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Padding height")); + } + + try{ + DeConv2DConfig.builder().kH(4).kW(3).pW(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + 
assertTrue(e.getMessage().contains("Padding width")); + } + + try{ + DeConv2DConfig.builder().kH(4).kW(3).dH(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Dilation height")); + } + + try{ + DeConv2DConfig.builder().kH(4).kW(3).dW(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Dilation width")); + } + } + + @Test + public void testConv2D(){ + Conv2DConfig.builder().kH(2).kW(4).build(); + + try{ + Conv2DConfig.builder().kW(4).kH(0).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Kernel height")); + } + + try{ + Conv2DConfig.builder().kH(4).kW(0).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Kernel width")); + } + + try{ + Conv2DConfig.builder().kH(4).kW(3).sH(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Stride height")); + } + + try{ + Conv2DConfig.builder().kH(4).kW(3).sW(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Stride width")); + } + + try{ + Conv2DConfig.builder().kH(4).kW(3).pH(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Padding height")); + } + + try{ + Conv2DConfig.builder().kH(4).kW(3).pW(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Padding width")); + } + + try{ + Conv2DConfig.builder().kH(4).kW(3).dH(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Dilation height")); + } + + try{ + Conv2DConfig.builder().kH(4).kW(3).dW(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Dilation width")); + } + } + + @Test + public void testPooling2D(){ + Pooling2DConfig.builder().kH(2).kW(4).build(); + + try{ + 
Pooling2DConfig.builder().kW(4).kH(0).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Kernel height")); + } + + try{ + Pooling2DConfig.builder().kH(4).kW(0).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Kernel width")); + } + + try{ + Pooling2DConfig.builder().kH(4).kW(3).sH(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Stride height")); + } + + try{ + Pooling2DConfig.builder().kH(4).kW(3).sW(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Stride width")); + } + + try{ + Pooling2DConfig.builder().kH(4).kW(3).pH(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Padding height")); + } + + try{ + Pooling2DConfig.builder().kH(4).kW(3).pW(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Padding width")); + } + + try{ + Pooling2DConfig.builder().kH(4).kW(3).dH(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Dilation height")); + } + + try{ + Pooling2DConfig.builder().kH(4).kW(3).dW(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Dilation width")); + } + } + + @Test + public void testDeConv3D(){ + DeConv3DConfig.builder().kH(2).kW(4).kD(3).build(); + + try{ + DeConv3DConfig.builder().kW(4).kD(3).kH(0).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Kernel height")); + } + + try{ + DeConv3DConfig.builder().kH(4).kD(3).kW(0).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Kernel width")); + } + + try{ + DeConv3DConfig.builder().kH(4).kW(3).kD(0).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Kernel depth")); + } + + 
try{ + DeConv3DConfig.builder().kH(4).kW(3).kD(3).sH(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Stride height")); + } + + try{ + DeConv3DConfig.builder().kH(4).kW(3).kD(3).sW(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Stride width")); + } + + try{ + DeConv3DConfig.builder().kH(4).kW(3).kD(3).sD(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Stride depth")); + } + + try{ + DeConv3DConfig.builder().kH(4).kW(3).kD(3).pH(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Padding height")); + } + + try{ + DeConv3DConfig.builder().kH(4).kW(3).kD(3).pW(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Padding width")); + } + + try{ + DeConv3DConfig.builder().kH(4).kW(3).kD(3).pD(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Padding depth")); + } + + try{ + DeConv3DConfig.builder().kH(4).kW(3).kD(3).dH(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Dilation height")); + } + + try{ + DeConv3DConfig.builder().kH(4).kW(3).kD(3).dW(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Dilation width")); + } + + try{ + DeConv3DConfig.builder().kH(4).kW(3).kD(3).dD(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Dilation depth")); + } + } + + @Test + public void testConv3D(){ + Conv3DConfig.builder().kH(2).kW(4).kD(3).build(); + + try{ + Conv3DConfig.builder().kW(4).kD(3).kH(0).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Kernel height")); + } + + try{ + Conv3DConfig.builder().kH(4).kD(3).kW(0).build(); + fail(); + } catch (IllegalArgumentException e){ + 
assertTrue(e.getMessage().contains("Kernel width")); + } + + try{ + Conv3DConfig.builder().kH(4).kW(3).kD(0).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Kernel depth")); + } + + try{ + Conv3DConfig.builder().kH(4).kW(3).kD(3).sH(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Stride height")); + } + + try{ + Conv3DConfig.builder().kH(4).kW(3).kD(3).sW(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Stride width")); + } + + try{ + Conv3DConfig.builder().kH(4).kW(3).kD(3).sD(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Stride depth")); + } + + try{ + Conv3DConfig.builder().kH(4).kW(3).kD(3).pH(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Padding height")); + } + + try{ + Conv3DConfig.builder().kH(4).kW(3).kD(3).pW(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Padding width")); + } + + try{ + Conv3DConfig.builder().kH(4).kW(3).kD(3).pD(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Padding depth")); + } + + try{ + Conv3DConfig.builder().kH(4).kW(3).kD(3).dH(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Dilation height")); + } + + try{ + Conv3DConfig.builder().kH(4).kW(3).kD(3).dW(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Dilation width")); + } + + try{ + Conv3DConfig.builder().kH(4).kW(3).kD(3).dD(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Dilation depth")); + } + } + + + + @Test + public void testPooling3D(){ + Pooling3DConfig.builder().kH(2).kW(4).kD(3).build(); + + try{ + Pooling3DConfig.builder().kW(4).kD(3).kH(0).build(); 
+ fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Kernel height")); + } + + try{ + Pooling3DConfig.builder().kH(4).kD(3).kW(0).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Kernel width")); + } + + try{ + Pooling3DConfig.builder().kH(4).kW(3).kD(0).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Kernel depth")); + } + + try{ + Pooling3DConfig.builder().kH(4).kW(3).kD(3).sH(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Stride height")); + } + + try{ + Pooling3DConfig.builder().kH(4).kW(3).kD(3).sW(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Stride width")); + } + + try{ + Pooling3DConfig.builder().kH(4).kW(3).kD(3).sD(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Stride depth")); + } + + try{ + Pooling3DConfig.builder().kH(4).kW(3).kD(3).pH(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Padding height")); + } + + try{ + Pooling3DConfig.builder().kH(4).kW(3).kD(3).pW(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Padding width")); + } + + try{ + Pooling3DConfig.builder().kH(4).kW(3).kD(3).pD(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Padding depth")); + } + + try{ + Pooling3DConfig.builder().kH(4).kW(3).kD(3).dH(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Dilation height")); + } + + try{ + Pooling3DConfig.builder().kH(4).kW(3).kD(3).dW(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Dilation width")); + } + + try{ + Pooling3DConfig.builder().kH(4).kW(3).kD(3).dD(-2).build(); + fail(); + } catch 
(IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Dilation depth")); + } + } + + + + @Test + public void testConv1D(){ + Conv1DConfig.builder().k(2).build(); + + try{ + Conv1DConfig.builder().k(0).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Kernel")); + } + + try{ + Conv1DConfig.builder().k(4).s(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Stride")); + } + + try{ + Conv1DConfig.builder().k(3).p(-2).build(); + fail(); + } catch (IllegalArgumentException e){ + assertTrue(e.getMessage().contains("Padding")); + } + } +}