SameDiff Convolution Config validation, better output methods (#82)
* Conv Config validation & tests Signed-off-by: Ryan Nett <rnett@skymind.io> * stackOutputs utility method Signed-off-by: Ryan Nett <rnett@skymind.io> * use constructor for validation, support negative kernel sizes (infered from weights) Signed-off-by: Ryan Nett <rnett@skymind.io> * better output methods Signed-off-by: Ryan Nett <rnett@skymind.io> * move output to be with fit and evaluate Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * more fixes Signed-off-by: Ryan Nett <rnett@skymind.io>
This commit is contained in:
		
							parent
							
								
									8d1fe8b1b3
								
							
						
					
					
						commit
						d4e7997134
					
				@ -1,444 +0,0 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2015-2018 Skymind, Inc.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
 | 
			
		||||
//
 | 
			
		||||
// Created by raver119 on 08.10.2017.
 | 
			
		||||
//
 | 
			
		||||
 | 
			
		||||
#include <op_boilerplate.h>
 | 
			
		||||
#if NOT_EXCLUDED(OP_fullconv3d)
 | 
			
		||||
 | 
			
		||||
#include <ops/declarable/CustomOperations.h>
 | 
			
		||||
#include <ops/declarable/helpers/convolutions.h>
 | 
			
		||||
 | 
			
		||||
namespace nd4j {
 | 
			
		||||
    namespace ops {
 | 
			
		||||
        //////////////////////////////////////////////////////////////////////////
 | 
			
		||||
        CUSTOM_OP_IMPL(fullconv3d, 5, 1, false, 0, 13) {
 | 
			
		||||
            // auto input = INPUT_VARIABLE(0);
 | 
			
		||||
            // auto weights = INPUT_VARIABLE(1);
 | 
			
		||||
            // auto bias = INPUT_VARIABLE(2);
 | 
			
		||||
            // auto columns = INPUT_VARIABLE(3);
 | 
			
		||||
            // auto ones = INPUT_VARIABLE(4);
 | 
			
		||||
 | 
			
		||||
            // REQUIRE_TRUE(weights->rankOf() == 5, 0, "Weights should be 5D, got %i instead", weights->rankOf());
 | 
			
		||||
            // REQUIRE_TRUE(input->rankOf() == 5, 0, "Input should be 5D, got %i instead", input->rankOf());
 | 
			
		||||
 | 
			
		||||
            // // strides
 | 
			
		||||
            // int dT = INT_ARG(0);
 | 
			
		||||
            // int dW = INT_ARG(1);
 | 
			
		||||
            // int dH = INT_ARG(2);
 | 
			
		||||
 | 
			
		||||
            // // padding
 | 
			
		||||
            // int pT = INT_ARG(3);
 | 
			
		||||
            // int pW = INT_ARG(4);
 | 
			
		||||
            // int pH = INT_ARG(5);
 | 
			
		||||
 | 
			
		||||
            // // dilation
 | 
			
		||||
            // int dilationT = INT_ARG(6);
 | 
			
		||||
            // int dilationW = INT_ARG(7);
 | 
			
		||||
            // int dilationH = INT_ARG(8);
 | 
			
		||||
 | 
			
		||||
            // // output padding
 | 
			
		||||
            // int aT = INT_ARG(9);
 | 
			
		||||
            // int aW = INT_ARG(10);
 | 
			
		||||
            // int aH = INT_ARG(11);
 | 
			
		||||
 | 
			
		||||
            // // bias
 | 
			
		||||
            // bool biasUsed = INT_ARG(12) != 0;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
            // REQUIRE_TRUE(dT > 0 && dW > 0 && dH > 0, 11,
 | 
			
		||||
            //              "stride should be greater than zero, but got dT: %d dH: %d dW: %d", dT, dH, dW);
 | 
			
		||||
            // REQUIRE_TRUE(dilationT > 0 && dilationW > 0 && dilationH > 0, 15,
 | 
			
		||||
            //              "dilation should be greater than zero, but got dilationT: %d, dilationH: %d, dilationW: %d",
 | 
			
		||||
            //              dilationT, dilationH, dilationW);
 | 
			
		||||
            // REQUIRE_TRUE((aT < dT || aT < dilationT)
 | 
			
		||||
            //              && (aW < dW || aW < dilationW)
 | 
			
		||||
            //              && (aH < dH || aH < dilationH), 15,
 | 
			
		||||
            //              "output padding must be smaller than either stride or dilation,"
 | 
			
		||||
            //                      " but got aT: %d aH: %d aW: %d dT: %d dH: %d dW: %d "
 | 
			
		||||
            //                      "dilationT: %d dilationH: %d dilationW: %d",
 | 
			
		||||
            //              aT, aH, aW, dT, dH, dW, dilationT, dilationH, dilationW);
 | 
			
		||||
 | 
			
		||||
            // auto output = this->getZ(block);
 | 
			
		||||
 | 
			
		||||
            // const int nInputPlane  = weights->shapeOf()[0];
 | 
			
		||||
            // const int nOutputPlane = weights->shapeOf()[1];
 | 
			
		||||
            // const int kT           = weights->shapeOf()[2];
 | 
			
		||||
            // const int kH           = weights->shapeOf()[3];
 | 
			
		||||
            // const int kW           = weights->shapeOf()[4];
 | 
			
		||||
 | 
			
		||||
            // const Nd4jLong inputWidth   = input->shapeOf()[4];
 | 
			
		||||
            // const Nd4jLong inputHeight  = input->shapeOf()[3];
 | 
			
		||||
            // const Nd4jLong inputDepth   = input->shapeOf()[2];
 | 
			
		||||
            // const Nd4jLong outputDepth  = (inputDepth - 1) * dT - 2*pT + (dilationT * (kT - 1) + 1) + aT;
 | 
			
		||||
            // const Nd4jLong outputHeight = (inputHeight - 1) * dH - 2*pH + (dilationH * (kH - 1) + 1) + aH;
 | 
			
		||||
            // const Nd4jLong outputWidth  = (inputWidth - 1) * dW - 2*pW + (dilationW * (kW - 1) + 1) + aW;
 | 
			
		||||
 | 
			
		||||
            // const Nd4jLong batchSize = input->shapeOf()[0];
 | 
			
		||||
 | 
			
		||||
            // REQUIRE_TRUE(output->isSameShape({ (int) batchSize, (int)nOutputPlane, (int)outputDepth, (int)outputHeight, (int)outputWidth}), 0, "Output should have shape of [%i, %i, %i, %i, %i], but got [%i, %i, %i, %i, %i] instead", (int) batchSize, (int)nOutputPlane, (int)outputDepth, (int)outputHeight, (int)outputWidth, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3), output->sizeAt(4));
 | 
			
		||||
 | 
			
		||||
            // std::unique_ptr<ResultSet> inputs(input->allExamples());
 | 
			
		||||
            // std::unique_ptr<ResultSet> outputs(output->allExamples());
 | 
			
		||||
            // for (int e = 0; e < batchSize; e++) {
 | 
			
		||||
            //     auto tadIn = inputs->at(e);
 | 
			
		||||
            //     auto tadOut = outputs->at(e);
 | 
			
		||||
 | 
			
		||||
            //     const int m = weights->shapeOf()[1] * weights->shapeOf()[2] * weights->shapeOf()[3] * weights->shapeOf()[4];
 | 
			
		||||
            //     const int n = columns->shapeOf()[1];
 | 
			
		||||
            //     const int k = weights->shapeOf()[0];
 | 
			
		||||
 | 
			
		||||
            //     // FIXME: mmul helper should be used here
 | 
			
		||||
            //     /*
 | 
			
		||||
            //     nd4j::blas::GEMM<T>::op('c', 'n', 't', m, n, k,
 | 
			
		||||
            //                             1.0,
 | 
			
		||||
            //                             tadIn->getBuffer(), n,
 | 
			
		||||
            //                             weights->getBuffer(), m,
 | 
			
		||||
            //                             0.0,
 | 
			
		||||
            //                             columns->getBuffer(), n);
 | 
			
		||||
            //                             */
 | 
			
		||||
 | 
			
		||||
            //     // ConvolutionUtils<T>::_col2vol(columns->getBuffer(),
 | 
			
		||||
            //     //                               nOutputPlane, outputDepth, outputHeight, outputWidth,
 | 
			
		||||
            //     //                               inputDepth, inputHeight, inputWidth,
 | 
			
		||||
            //     //                               kT, kH, kW,
 | 
			
		||||
            //     //                               pT, pH, pW,
 | 
			
		||||
            //     //                               dT, dH, dW,
 | 
			
		||||
            //     //                               dilationT,  dilationH,  dilationW,
 | 
			
		||||
            //     //                               tadOut->getBuffer());
 | 
			
		||||
            //     ConvolutionUtils::col2vol(*columns, *tadOut, dT, dH, dW, pT, pH, pW, dilationT, dilationH, dilationW);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
            //     const int m_ = nOutputPlane;
 | 
			
		||||
            //     const int n_ = outputDepth * outputHeight * outputWidth;
 | 
			
		||||
            //     const int k_ = 1;
 | 
			
		||||
 | 
			
		||||
            //     if (biasUsed) {
 | 
			
		||||
            //         // FIXME: mmul helper should be used here
 | 
			
		||||
            //         /*
 | 
			
		||||
            //         nd4j::blas::GEMM<T>::op('c', 't', 'n', n_, m_, k_,
 | 
			
		||||
            //                                 1.0,
 | 
			
		||||
            //                                 ones->getBuffer(), k_,
 | 
			
		||||
            //                                 bias->getBuffer(), k_,
 | 
			
		||||
            //                                 1.0,
 | 
			
		||||
            //                                 tadOut->getBuffer(), n_);
 | 
			
		||||
            //                                 */
 | 
			
		||||
            //     }
 | 
			
		||||
            // }
 | 
			
		||||
 | 
			
		||||
            // STORE_RESULT(*output);
 | 
			
		||||
 | 
			
		||||
            return Status::OK();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        DECLARE_TYPES(fullconv3d) {
 | 
			
		||||
            getOpDescriptor()
 | 
			
		||||
                    ->setAllowedInputTypes(nd4j::DataType::ANY)
 | 
			
		||||
                    ->setAllowedOutputTypes({ALL_FLOATS});
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        DECLARE_SHAPE_FN(fullconv3d) {
 | 
			
		||||
            // auto input = inputShape->at(0);
 | 
			
		||||
            // auto weights = inputShape->at(1);
 | 
			
		||||
 | 
			
		||||
            // // strides
 | 
			
		||||
            // int dT = INT_ARG(0);
 | 
			
		||||
            // int dW = INT_ARG(1);
 | 
			
		||||
            // int dH = INT_ARG(2);
 | 
			
		||||
 | 
			
		||||
            // // padding 
 | 
			
		||||
            // int pT = INT_ARG(3);
 | 
			
		||||
            // int pW = INT_ARG(4);
 | 
			
		||||
            // int pH = INT_ARG(5);
 | 
			
		||||
 | 
			
		||||
            // // dilation
 | 
			
		||||
            // int dilationT = INT_ARG(6);
 | 
			
		||||
            // int dilationW = INT_ARG(7);
 | 
			
		||||
            // int dilationH = INT_ARG(8);
 | 
			
		||||
 | 
			
		||||
            // // output padding
 | 
			
		||||
            // int aT = INT_ARG(9);
 | 
			
		||||
            // int aW = INT_ARG(10);
 | 
			
		||||
            // int aH = INT_ARG(11);
 | 
			
		||||
 | 
			
		||||
            // // bias
 | 
			
		||||
            // bool biasUsed = INT_ARG(12) != 0;
 | 
			
		||||
 | 
			
		||||
            // Nd4jLong *shapeOf;
 | 
			
		||||
            // Nd4jLong *newShape;
 | 
			
		||||
            // ALLOCATE(shapeOf, block.getWorkspace(), 5, Nd4jLong);
 | 
			
		||||
            // ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(5), Nd4jLong);
 | 
			
		||||
 | 
			
		||||
            // const int nInputPlane  = weights[1];
 | 
			
		||||
            // const int nOutputPlane = weights[2];
 | 
			
		||||
            // const int kT           = weights[3];
 | 
			
		||||
            // const int kH           = weights[4];
 | 
			
		||||
            // const int kW           = weights[5];
 | 
			
		||||
 | 
			
		||||
            // const int batchSize          = input[1];
 | 
			
		||||
            // const Nd4jLong inputWidth   = input[5];
 | 
			
		||||
            // const Nd4jLong inputHeight  = input[4];
 | 
			
		||||
            // const Nd4jLong inputDepth   = input[3];
 | 
			
		||||
            // const Nd4jLong outputDepth  = (inputDepth - 1) * dT - 2*pT + (dilationT * (kT - 1) + 1) + aT;
 | 
			
		||||
            // const Nd4jLong outputHeight = (inputHeight - 1) * dH - 2*pH + (dilationH * (kH - 1) + 1) + aH;
 | 
			
		||||
            // const Nd4jLong outputWidth  = (inputWidth - 1) * dW - 2*pW + (dilationW * (kW - 1) + 1) + aW;
 | 
			
		||||
 | 
			
		||||
            // nd4j::ArrayUtils::toLongPtr({(Nd4jLong) batchSize, (Nd4jLong)nOutputPlane, (Nd4jLong)outputDepth, (Nd4jLong)outputHeight, (Nd4jLong)outputWidth}, shapeOf);
 | 
			
		||||
 | 
			
		||||
            // shape::shapeBuffer(5, shapeOf, newShape);
 | 
			
		||||
 | 
			
		||||
            // RELEASE(shapeOf, block.getWorkspace());
 | 
			
		||||
 | 
			
		||||
            // return SHAPELIST(newShape);
 | 
			
		||||
            return SHAPELIST();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        DECLARE_TYPES(fullconv3d_bp) {
 | 
			
		||||
            getOpDescriptor()
 | 
			
		||||
                    ->setAllowedInputTypes(nd4j::DataType::ANY)
 | 
			
		||||
                    ->setAllowedOutputTypes({ALL_FLOATS});
 | 
			
		||||
        }
 | 
			
		||||
//////////////////////////////////////////////////////////////////////////
 | 
			
		||||
        CUSTOM_OP_IMPL(fullconv3d_bp, 5, 1, false, 0, 13) {
 | 
			
		||||
            // auto input = INPUT_VARIABLE(0);
 | 
			
		||||
            // auto gradNext = INPUT_VARIABLE(1);
 | 
			
		||||
            // auto weights = INPUT_VARIABLE(2);
 | 
			
		||||
            // auto finput = INPUT_VARIABLE(3);
 | 
			
		||||
 | 
			
		||||
            // // not used
 | 
			
		||||
            // auto fgradInput = INPUT_VARIABLE(4);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
            // REQUIRE_TRUE(weights->rankOf() == 5, 0, "Weights should be 5D, got %i instead", weights->rankOf());
 | 
			
		||||
            // REQUIRE_TRUE(input->rankOf() == 5, 0, "Input should be 5D, got %i instead", input->rankOf());
 | 
			
		||||
 | 
			
		||||
            // auto output = OUTPUT_VARIABLE(0);
 | 
			
		||||
 | 
			
		||||
            // int dT = INT_ARG(0);
 | 
			
		||||
            // int dW = INT_ARG(1);
 | 
			
		||||
            // int dH = INT_ARG(2);
 | 
			
		||||
            // int pT = INT_ARG(3);
 | 
			
		||||
            // int pW = INT_ARG(4);
 | 
			
		||||
            // int pH = INT_ARG(5);
 | 
			
		||||
            // int dilationT = INT_ARG(6);
 | 
			
		||||
            // int dilationW = INT_ARG(7);
 | 
			
		||||
            // int dilationH = INT_ARG(8);
 | 
			
		||||
            // int aT = INT_ARG(9);
 | 
			
		||||
            // int aW = INT_ARG(10);
 | 
			
		||||
            // int aH = INT_ARG(11);
 | 
			
		||||
            // bool biasUsed = INT_ARG(12) != 0;
 | 
			
		||||
 | 
			
		||||
            // const int nInputPlane  = (int)weights->shapeOf()[0];
 | 
			
		||||
            // const int nOutputPlane = (int)weights->shapeOf()[1];
 | 
			
		||||
            // const int kT           = (int)weights->shapeOf()[2];
 | 
			
		||||
            // const int kH           = (int)weights->shapeOf()[3];
 | 
			
		||||
            // const int kW           = (int)weights->shapeOf()[4];
 | 
			
		||||
 | 
			
		||||
            // const Nd4jLong inputWidth   = input->shapeOf()[4];
 | 
			
		||||
            // const Nd4jLong inputHeight  = input->shapeOf()[3];
 | 
			
		||||
            // const Nd4jLong inputDepth   = input->shapeOf()[2];
 | 
			
		||||
            // const Nd4jLong outputDepth  = (inputDepth - 1) * dT - 2*pT + (dilationT * (kT - 1) + 1) + aT;
 | 
			
		||||
            // const Nd4jLong outputHeight = (inputHeight - 1) * dH - 2*pH + (dilationH * (kH - 1) + 1) + aH;
 | 
			
		||||
            // const Nd4jLong outputWidth  = (inputWidth - 1) * dW - 2*pW + (dilationW * (kW - 1) + 1) + aW;
 | 
			
		||||
 | 
			
		||||
            // const Nd4jLong batchSize = input->shapeOf()[0];
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
            // REQUIRE_TRUE(output->isSameShape({(int) batchSize, (int) nInputPlane, (int) inputDepth, (int) inputHeight, (int) inputWidth}) ,0, "Output should have shape of [%i, %i, %i, %i, %i], but got [%i, %i, %i, %i, %i] instead", (int) batchSize, (int) nInputPlane, (int) inputDepth, (int) inputHeight, (int) inputWidth, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3), output->sizeAt(4));
 | 
			
		||||
 | 
			
		||||
            // output->assign(0.0);
 | 
			
		||||
 | 
			
		||||
            // // FIXME: non-inplace reshape!!!!
 | 
			
		||||
            // NDArray *gradColumns;
 | 
			
		||||
            // //auto gradColumns = finput->reshape('c', {nOutputPlane*kW*kH*kT, inputDepth*inputHeight*inputWidth });
 | 
			
		||||
 | 
			
		||||
            // std::unique_ptr<ResultSet> tadsNext(gradNext->allExamples());
 | 
			
		||||
            // std::unique_ptr<ResultSet> tadsOutput(output->allExamples());
 | 
			
		||||
            // for (int e = 0; e < tadsNext->size(); e++) {
 | 
			
		||||
            //     auto tadNext = tadsNext->at(e);
 | 
			
		||||
            //     auto tadOutput = tadsOutput->at(e);
 | 
			
		||||
 | 
			
		||||
            //     // ConvolutionUtils<T>::_vol2col(
 | 
			
		||||
            //     //         tadNext->getBuffer(),
 | 
			
		||||
            //     //         nOutputPlane, outputDepth, outputHeight, outputWidth,
 | 
			
		||||
            //     //         kT, kH, kW,
 | 
			
		||||
            //     //         pT, pH, pW,
 | 
			
		||||
            //     //         dT, dH, dW,
 | 
			
		||||
            //     //         dilationT,  dilationH,  dilationW,
 | 
			
		||||
            //     //         gradColumns->getBuffer());
 | 
			
		||||
            //     ConvolutionUtils::vol2col(*tadNext, *gradColumns, dT, dH, dW, pT, pH, pW, dilationT, dilationH, dilationW);
 | 
			
		||||
 | 
			
		||||
            //     const auto m = weights->shapeOf()[0];
 | 
			
		||||
            //     const auto n = gradColumns->shapeOf()[1];
 | 
			
		||||
            //     const auto k = weights->shapeOf()[1] * weights->shapeOf()[2] * weights->shapeOf()[3] * weights->shapeOf()[4];
 | 
			
		||||
 | 
			
		||||
            //     // FIXME: mmul helper should be used here
 | 
			
		||||
            //     /*
 | 
			
		||||
            //     nd4j::blas::GEMM<T>::op('f', 'n', 'n',
 | 
			
		||||
            //                             n, m, k,
 | 
			
		||||
            //                             1.0f,
 | 
			
		||||
            //                             gradColumns->getBuffer(), n,
 | 
			
		||||
            //                             weights->getBuffer(), k,
 | 
			
		||||
            //                             0,
 | 
			
		||||
            //                             tadOutput->getBuffer(), n
 | 
			
		||||
 | 
			
		||||
            //     );
 | 
			
		||||
            //      */
 | 
			
		||||
            // }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
            // STORE_RESULT(*output);
 | 
			
		||||
 | 
			
		||||
            // delete gradColumns;
 | 
			
		||||
            return ND4J_STATUS_OK;
 | 
			
		||||
        }
 | 
			
		||||
        DECLARE_SHAPE_FN(fullconv3d_bp) {
 | 
			
		||||
            // output shape equals to input shape, all out of sudden
 | 
			
		||||
            // Nd4jLong* newShape;
 | 
			
		||||
            // COPY_SHAPE(inputShape->at(0), newShape);
 | 
			
		||||
 | 
			
		||||
            // return SHAPELIST(newShape);
 | 
			
		||||
            return SHAPELIST();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        DECLARE_TYPES(fullconv3d_grad) {
 | 
			
		||||
            getOpDescriptor()
 | 
			
		||||
                    ->setAllowedInputTypes(nd4j::DataType::ANY)
 | 
			
		||||
                    ->setAllowedOutputTypes({ALL_FLOATS});
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
//////////////////////////////////////////////////////////////////////////
 | 
			
		||||
        CUSTOM_OP_IMPL(fullconv3d_grad, 4, 2, false, 1, 13) {
 | 
			
		||||
            // auto input = INPUT_VARIABLE(0);
 | 
			
		||||
            // auto epsilon = INPUT_VARIABLE(1);
 | 
			
		||||
            // auto columns = INPUT_VARIABLE(2);
 | 
			
		||||
            // auto ones = INPUT_VARIABLE(3);
 | 
			
		||||
 | 
			
		||||
            // REQUIRE_TRUE(input->rankOf() == epsilon->rankOf(), 0, "Rank of input (%i) & epsilon (%i) should be equal", input->rankOf(), epsilon->rankOf());
 | 
			
		||||
            // REQUIRE_TRUE(input->sizeAt(0) == epsilon->sizeAt(0), 1, "Batch size should be equal for input and epsilon");
 | 
			
		||||
 | 
			
		||||
            // auto gradWeight = OUTPUT_VARIABLE(0);
 | 
			
		||||
            // auto gradBias = OUTPUT_VARIABLE(1);
 | 
			
		||||
 | 
			
		||||
            // REQUIRE_TRUE(gradBias->sizeAt(0) == gradWeight->sizeAt(1), 0, "Bias shape mismatch");
 | 
			
		||||
 | 
			
		||||
            // int dT = INT_ARG(0);
 | 
			
		||||
            // int dW = INT_ARG(1);
 | 
			
		||||
            // int dH = INT_ARG(2);
 | 
			
		||||
            // int pT = INT_ARG(3);
 | 
			
		||||
            // int pW = INT_ARG(4);
 | 
			
		||||
            // int pH = INT_ARG(5);
 | 
			
		||||
            // int dilationT = INT_ARG(6);
 | 
			
		||||
            // int dilationW = INT_ARG(7);
 | 
			
		||||
            // int dilationH = INT_ARG(8);
 | 
			
		||||
            // int aT = INT_ARG(9);
 | 
			
		||||
            // int aW = INT_ARG(10);
 | 
			
		||||
            // int aH = INT_ARG(11);
 | 
			
		||||
            // bool biasUsed = INT_ARG(12) != 0;
 | 
			
		||||
 | 
			
		||||
            // double scale = block.getTArguments()->at(0);
 | 
			
		||||
 | 
			
		||||
            // int nInputPlane  = (int)gradWeight->shapeOf()[0];
 | 
			
		||||
            // int nOutputPlane = (int)gradWeight->shapeOf()[1];
 | 
			
		||||
            // int kT           = (int)gradWeight->shapeOf()[2];
 | 
			
		||||
            // int kH           = (int)gradWeight->shapeOf()[3];
 | 
			
		||||
            // int kW           = (int)gradWeight->shapeOf()[4];
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
            // const Nd4jLong inputWidth   = input->shapeOf()[4];
 | 
			
		||||
            // const Nd4jLong inputHeight  = input->shapeOf()[3];
 | 
			
		||||
            // const Nd4jLong inputDepth   = input->shapeOf()[2];
 | 
			
		||||
            // const Nd4jLong outputDepth  = (inputDepth - 1) * dT - 2*pT + (dilationT * (kT - 1) + 1) + aT;
 | 
			
		||||
            // const Nd4jLong outputHeight = (inputHeight - 1) * dH - 2*pH + (dilationH * (kH - 1) + 1) + aH;
 | 
			
		||||
            // const Nd4jLong outputWidth  = (inputWidth - 1) * dW - 2*pW + (dilationW * (kW - 1) + 1) + aW;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
            // REQUIRE_TRUE(gradWeight->isContiguous(), 0, "gradWight should be continuous");
 | 
			
		||||
            // REQUIRE_TRUE(gradBias->isContiguous(), 0, "gradBias should be continuous");
 | 
			
		||||
            // REQUIRE_TRUE(ones->rankOf() == 3, 0, "Ones should have rank 3, got %i instead", ones->rankOf());
 | 
			
		||||
 | 
			
		||||
            // REQUIRE_TRUE(ones->isSameShape({outputDepth, outputHeight, outputWidth}), 0, "");
 | 
			
		||||
 | 
			
		||||
            // ones->assign(1.0);
 | 
			
		||||
 | 
			
		||||
            // std::unique_ptr<ResultSet> tadsInput(input->allExamples());
 | 
			
		||||
            // std::unique_ptr<ResultSet> tadsEpsilon(epsilon->allExamples());
 | 
			
		||||
 | 
			
		||||
            // for (int e = 0; e < tadsInput->size(); e++) {
 | 
			
		||||
            //     auto tadInput = tadsInput->at(e);
 | 
			
		||||
            //     auto tadEpsilon = tadsEpsilon->at(e);
 | 
			
		||||
 | 
			
		||||
            //     // ConvolutionUtils<T>::_vol2col(
 | 
			
		||||
            //     //         tadEpsilon->getBuffer(), nOutputPlane,
 | 
			
		||||
            //     //         outputDepth, outputHeight, outputWidth,
 | 
			
		||||
            //     //         kT, kH, kW,
 | 
			
		||||
            //     //         pT, pH, pW,
 | 
			
		||||
            //     //         dT, dH, dW,
 | 
			
		||||
            //     //         dilationT,  dilationH,  dilationW,
 | 
			
		||||
            //     //         columns->getBuffer()
 | 
			
		||||
            //     // );
 | 
			
		||||
            //     ConvolutionUtils::vol2col(*tadEpsilon, *columns, dT, dH, dW, pT, pH, pW, dilationT, dilationH, dilationW);
 | 
			
		||||
            //     const Nd4jLong n = columns->shapeOf()[0];   // nOutputPlane * kt * kh * kw
 | 
			
		||||
            //     const Nd4jLong m = tadInput->shapeOf()[0];   // nInputPlane
 | 
			
		||||
            //     const Nd4jLong k = columns->shapeOf()[1];
 | 
			
		||||
 | 
			
		||||
            //     // FIXME: mmul helper should be used here
 | 
			
		||||
            //     /**
 | 
			
		||||
            //     nd4j::blas::GEMM<T>::op('f', 't', 'n',
 | 
			
		||||
            //                             n, m, k,
 | 
			
		||||
            //                             scale,
 | 
			
		||||
            //                             columns->getBuffer(), k,
 | 
			
		||||
            //                             tadInput->getBuffer(), k,
 | 
			
		||||
            //                             1,
 | 
			
		||||
            //                             gradWeight->getBuffer(), n);
 | 
			
		||||
            //                             */
 | 
			
		||||
 | 
			
		||||
            //     const Nd4jLong m_ = nOutputPlane;
 | 
			
		||||
            //     const Nd4jLong k_ = outputDepth * outputHeight * outputWidth;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
            //     if (gradBias) {
 | 
			
		||||
            //         // FIXME: mmul helper should be used here
 | 
			
		||||
            //         /*
 | 
			
		||||
            //         nd4j::blas::GEMV<T>::op('t',
 | 
			
		||||
            //                                 k_, m_,
 | 
			
		||||
            //                                 scale,
 | 
			
		||||
            //                                 tadEpsilon->getBuffer(), k_,
 | 
			
		||||
            //                                 ones->getBuffer(), 1, (T)1.0f,
 | 
			
		||||
            //                                 gradBias->getBuffer(), 1);
 | 
			
		||||
            //                                 */
 | 
			
		||||
            //     }
 | 
			
		||||
            // }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
            // STORE_2_RESULTS(*gradWeight, *gradBias);
 | 
			
		||||
 | 
			
		||||
            return Status::OK();
 | 
			
		||||
        }
 | 
			
		||||
        DECLARE_SHAPE_FN(fullconv3d_grad) {
 | 
			
		||||
            // auto list = SHAPELIST();
 | 
			
		||||
 | 
			
		||||
            // _grad ops MUST have output arrays provided
 | 
			
		||||
 | 
			
		||||
            // return list;
 | 
			
		||||
            return SHAPELIST();
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#endif
 | 
			
		||||
@ -187,12 +187,6 @@ namespace nd4j {
 | 
			
		||||
        DECLARE_CUSTOM_OP(pnormpool2d_bp, 2, 1, false, 1, 10);
 | 
			
		||||
        #endif
 | 
			
		||||
 | 
			
		||||
        #if NOT_EXCLUDED(OP_fullconv3d)
 | 
			
		||||
        DECLARE_CUSTOM_OP(fullconv3d, 5, 1, false, 0, 13);
 | 
			
		||||
        DECLARE_CUSTOM_OP(fullconv3d_bp, 5, 1, false, 0, 13);
 | 
			
		||||
        DECLARE_CUSTOM_OP(fullconv3d_grad, 4, 2, false, 1, 13);
 | 
			
		||||
        #endif
 | 
			
		||||
 | 
			
		||||
        /**
 | 
			
		||||
         * This op implements im2col algorithm, widely used in convolution neural networks
 | 
			
		||||
         * Input: 4D input expected
 | 
			
		||||
 | 
			
		||||
@ -16,6 +16,9 @@
 | 
			
		||||
 | 
			
		||||
package org.nd4j.autodiff.samediff;
 | 
			
		||||
 | 
			
		||||
import static org.nd4j.autodiff.util.TrainingUtils.getSingleOutput;
 | 
			
		||||
import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
 | 
			
		||||
 | 
			
		||||
import com.google.common.collect.HashBasedTable;
 | 
			
		||||
import com.google.common.collect.Table;
 | 
			
		||||
import com.google.common.primitives.Ints;
 | 
			
		||||
@ -73,6 +76,7 @@ import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
 | 
			
		||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
 | 
			
		||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 | 
			
		||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
 | 
			
		||||
import org.nd4j.linalg.exception.ND4JException;
 | 
			
		||||
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
 | 
			
		||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
 | 
			
		||||
import org.nd4j.linalg.exception.ND4UnresolvedOutputVariables;
 | 
			
		||||
@ -109,7 +113,7 @@ import org.tensorflow.framework.GraphDef;
 | 
			
		||||
 * <p>
 | 
			
		||||
 * That graph accumulates operations.
 | 
			
		||||
 * <p>
 | 
			
		||||
 * In order to execute the graph, you run one of the execution methods, such as {@link #exec(Map, String...)}
 | 
			
		||||
 * In order to execute the graph, you run one of the execution methods, such as {@link #output(Map, String...)}
 | 
			
		||||
 */
 | 
			
		||||
@AllArgsConstructor
 | 
			
		||||
@Builder
 | 
			
		||||
@ -2262,7 +2266,7 @@ public class SameDiff extends SDBaseOps {
 | 
			
		||||
            MultiDataSet ds = iterator.next();
 | 
			
		||||
            Map<String,INDArray> placeholderMap = toPlaceholderMap(ds);
 | 
			
		||||
 | 
			
		||||
            Map<String,INDArray> m = exec(placeholderMap, reqVars);
 | 
			
		||||
            Map<String,INDArray> m = output(placeholderMap, reqVars);
 | 
			
		||||
 | 
			
		||||
            for(Map.Entry<String,List<IEvaluation>> e : variableEvals.entrySet()){
 | 
			
		||||
                INDArray prediction = m.get(e.getKey());
 | 
			
		||||
@ -2288,7 +2292,15 @@ public class SameDiff extends SDBaseOps {
 | 
			
		||||
     * @param outputs        The variables to evaluate
 | 
			
		||||
     */
 | 
			
		||||
    public Map<String, INDArray> output(DataSet dataSet, String... outputs){
 | 
			
		||||
        return output(new SingletonMultiDataSetIterator(dataSet.toMultiDataSet()), outputs).get(0);
 | 
			
		||||
        return outputBatches(new SingletonMultiDataSetIterator(dataSet.toMultiDataSet()), outputs).get(0);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Single output inference.
 | 
			
		||||
     * See {@link #output(DataSet, String...)}
 | 
			
		||||
     */
 | 
			
		||||
    public INDArray outputSingle(DataSet dataSet, String output){
 | 
			
		||||
        return output(dataSet, output).get(output);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
@ -2299,13 +2311,40 @@ public class SameDiff extends SDBaseOps {
 | 
			
		||||
     * sameDiff.output(iterator, "softmax");}
 | 
			
		||||
     * </pre>
 | 
			
		||||
     *
 | 
			
		||||
     * Uses concatenation on the outputs of {@link #outputBatches(DataSetIterator, String...)} which may cause issues with some inputs.
 | 
			
		||||
     * RNNs with variable time series length and CNNs with variable image sizes will most likely have issues.
 | 
			
		||||
     *
 | 
			
		||||
     * @param iterator       Iterator as source of data to evaluate
 | 
			
		||||
     * @param outputs        The variables to evaluate
 | 
			
		||||
     */
 | 
			
		||||
    public List<Map<String, INDArray>> output(DataSetIterator iterator, String... outputs){
 | 
			
		||||
    public Map<String, INDArray> output(DataSetIterator iterator, String... outputs){
 | 
			
		||||
        return output(new MultiDataSetIteratorAdapter(iterator), outputs);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * See {@link #output(DataSetIterator, String...)}, but without the concatenation of batches.
 | 
			
		||||
     *
 | 
			
		||||
     */
 | 
			
		||||
    public List<Map<String, INDArray>> outputBatches(DataSetIterator iterator, String... outputs){
 | 
			
		||||
        return outputBatches(new MultiDataSetIteratorAdapter(iterator), outputs);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Single output inference.
 | 
			
		||||
     * See {@link #output(DataSetIterator, String...)}
 | 
			
		||||
     */
 | 
			
		||||
    public INDArray outputSingle(DataSetIterator dataSet, String output){
 | 
			
		||||
        return output(dataSet, output).get(output);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Single batched output inference.
 | 
			
		||||
     * See {@link #output(DataSetIterator, String...)}
 | 
			
		||||
     */
 | 
			
		||||
    public List<INDArray> outputSingleBatches(DataSetIterator dataSet, String output){
 | 
			
		||||
        return getSingleOutput(outputBatches(dataSet, output), output);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Perform inference.<br>
 | 
			
		||||
     * <br>
 | 
			
		||||
@ -2321,10 +2360,20 @@ public class SameDiff extends SDBaseOps {
 | 
			
		||||
     * }
 | 
			
		||||
     * </pre>
 | 
			
		||||
     *
 | 
			
		||||
     * Uses concatenation on the outputs of {@link #outputBatches(MultiDataSetIterator, String...)} which may cause issues with some inputs.
 | 
			
		||||
     * RNNs with variable time series length and CNNs with variable image sizes will most likely have issues.
 | 
			
		||||
     *
 | 
			
		||||
     * @param iterator  The iterator - the source of the data for inference
 | 
			
		||||
     * @param outputs   The set of outputs to report.  If null, defaults to all outputs of this SameDiff.
 | 
			
		||||
     */
 | 
			
		||||
    public List<Map<String, INDArray>> output(MultiDataSetIterator iterator, String... outputs){
 | 
			
		||||
    public Map<String, INDArray> output(MultiDataSetIterator iterator, String... outputs){
 | 
			
		||||
        return stackOutputs(outputBatches(iterator, outputs));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * See {@link #output(MultiDataSetIterator, String...)}, but without the concatenation of batches.
 | 
			
		||||
     */
 | 
			
		||||
    public List<Map<String, INDArray>> outputBatches(MultiDataSetIterator iterator, String... outputs){
 | 
			
		||||
        Preconditions.checkState(trainingConfig != null, "Training config has not been set");
 | 
			
		||||
 | 
			
		||||
        List<String> reqVars;
 | 
			
		||||
@ -2344,12 +2393,114 @@ public class SameDiff extends SDBaseOps {
 | 
			
		||||
            MultiDataSet ds = iterator.next();
 | 
			
		||||
            Map<String,INDArray> placeholderMap = toPlaceholderMap(ds);
 | 
			
		||||
 | 
			
		||||
            predictions.add(exec(placeholderMap, reqVars));
 | 
			
		||||
            predictions.add(output(placeholderMap, reqVars));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return predictions;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Single output inference.
 | 
			
		||||
     * See {@link #output(MultiDataSetIterator, String...)}
 | 
			
		||||
     */
 | 
			
		||||
    public INDArray outputSingle(MultiDataSetIterator dataSet, String output){
 | 
			
		||||
        return output(dataSet, output).get(output);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Single batched output inference.
 | 
			
		||||
     * See {@link #output(MultiDataSetIterator, String...)}
 | 
			
		||||
     */
 | 
			
		||||
    public List<INDArray> outputSingleBatches(MultiDataSetIterator dataSet, String output){
 | 
			
		||||
        return getSingleOutput(outputBatches(dataSet, output), output);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * @deprecated See {@link #outputAll(Map)}
 | 
			
		||||
     */
 | 
			
		||||
    @Deprecated
 | 
			
		||||
    public Map<String,INDArray> execAll(Map<String,INDArray> placeholders){
 | 
			
		||||
        return outputAll(placeholders);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Do inference for all variables for a single batch
 | 
			
		||||
     */
 | 
			
		||||
    public Map<String,INDArray> outputAll(Map<String,INDArray> placeholders){
 | 
			
		||||
        List<String> allVars = new ArrayList<>();
 | 
			
		||||
        for(Variable v : variables.values()){
 | 
			
		||||
            allVars.add(v.getName());
 | 
			
		||||
        }
 | 
			
		||||
        return output(placeholders, allVars.toArray(new String[0]));
 | 
			
		||||
    }
 | 
			
		||||
    /**
 | 
			
		||||
     * @deprecated See {@link #outputSingle(Map, String)}
 | 
			
		||||
     */
 | 
			
		||||
    @Deprecated
 | 
			
		||||
    public INDArray execSingle(Map<String,INDArray> placeholders, String output){
 | 
			
		||||
        return outputSingle(placeholders, output);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Do inference for a single variable for a single batch
 | 
			
		||||
     */
 | 
			
		||||
    public INDArray outputSingle(Map<String,INDArray> placeholders, String output){
 | 
			
		||||
        return output(placeholders, output).get(output);
 | 
			
		||||
    }
 | 
			
		||||
    /**
 | 
			
		||||
     * @deprecated See {@link #output(Map, List)}
 | 
			
		||||
     */
 | 
			
		||||
    @Deprecated
 | 
			
		||||
    public Map<String,INDArray> exec(Map<String,INDArray> placeholders, List<String> outputs){
 | 
			
		||||
        return output(placeholders, outputs);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Do inference for the given variables for a single batch
 | 
			
		||||
     */
 | 
			
		||||
    public Map<String,INDArray> output(Map<String,INDArray> placeholders, List<String> outputs){
 | 
			
		||||
        return output(placeholders, outputs.toArray(new String[outputs.size()]));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * @deprecated See {@link #output(Map, String...)}
 | 
			
		||||
     */
 | 
			
		||||
    @Deprecated
 | 
			
		||||
    public Map<String,INDArray> exec(Map<String,INDArray> placeholders, String... outputs) {
 | 
			
		||||
        return output(placeholders, outputs);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Do inference for the given variables for a single batch
 | 
			
		||||
     */
 | 
			
		||||
    public Map<String,INDArray> output(Map<String,INDArray> placeholders, String... outputs) {
 | 
			
		||||
        return output(placeholders, false, null, outputs);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Do inference for the given variables for a single batch, with training information
 | 
			
		||||
     */
 | 
			
		||||
    protected Map<String,INDArray> output(Map<String,INDArray> placeholders, boolean training, At at, String... outputs){
 | 
			
		||||
        Preconditions.checkState(outputs != null && outputs.length > 0, "No outputs were specified");
 | 
			
		||||
        long threadId = Thread.currentThread().getId();
 | 
			
		||||
        if(!sessions.containsKey(threadId)){
 | 
			
		||||
            log.info("Creating new InferenceSession for thread {}", threadId);
 | 
			
		||||
            sessions.put(threadId, new InferenceSession(this));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        List<String> phNames = inputs();
 | 
			
		||||
        if(placeholders == null && phNames != null){
 | 
			
		||||
            //Maybe user set placeholders before calling exec method?
 | 
			
		||||
            placeholders = placeholdersPerThread.get(Thread.currentThread().getId());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        //Placeholder validation is performed in InferenceSession
 | 
			
		||||
 | 
			
		||||
        InferenceSession is = sessions.get(threadId);
 | 
			
		||||
        Map<String,INDArray> ret = is.output(Arrays.asList(outputs), placeholders, listeners, training, at);
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public SDVariable one(String name, int... shape){
 | 
			
		||||
        return one(name, Nd4j.defaultFloatingPointType(), shape);
 | 
			
		||||
@ -3779,7 +3930,7 @@ public class SameDiff extends SDBaseOps {
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        //TODO is this 'train' flag the best approach?
 | 
			
		||||
        sd.exec(placeholders, trainingConfig != null, at, variableGradNamesList.toArray(new String[variableGradNamesList.size()]));
 | 
			
		||||
        sd.output(placeholders, trainingConfig != null, at, variableGradNamesList.toArray(new String[variableGradNamesList.size()]));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
@ -4459,47 +4610,6 @@ public class SameDiff extends SDBaseOps {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public Map<String,INDArray> execAll(Map<String,INDArray> placeholders){
 | 
			
		||||
        List<String> allVars = new ArrayList<>();
 | 
			
		||||
        for(Variable v : variables.values()){
 | 
			
		||||
            allVars.add(v.getName());
 | 
			
		||||
        }
 | 
			
		||||
        return exec(placeholders, allVars.toArray(new String[allVars.size()]));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public INDArray execSingle(Map<String,INDArray> placeholders, String output){
 | 
			
		||||
        return exec(placeholders, output).get(output);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public Map<String,INDArray> exec(Map<String,INDArray> placeholders, List<String> outputs){
 | 
			
		||||
        return exec(placeholders, outputs.toArray(new String[outputs.size()]));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public Map<String,INDArray> exec(Map<String,INDArray> placeholders, String... outputs) {
 | 
			
		||||
        return exec(placeholders, false, null, outputs);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    protected Map<String,INDArray> exec(Map<String,INDArray> placeholders, boolean training, At at, String... outputs){
 | 
			
		||||
        Preconditions.checkState(outputs != null && outputs.length > 0, "No outputs were specified");
 | 
			
		||||
        long threadId = Thread.currentThread().getId();
 | 
			
		||||
        if(!sessions.containsKey(threadId)){
 | 
			
		||||
            log.info("Creating new InferenceSession for thread {}", threadId);
 | 
			
		||||
            sessions.put(threadId, new InferenceSession(this));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        List<String> phNames = inputs();
 | 
			
		||||
        if(placeholders == null && phNames != null){
 | 
			
		||||
            //Maybe user set placeholders before calling exec method?
 | 
			
		||||
            placeholders = placeholdersPerThread.get(Thread.currentThread().getId());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        //Placeholder validation is performed in InferenceSession
 | 
			
		||||
 | 
			
		||||
        InferenceSession is = sessions.get(threadId);
 | 
			
		||||
        Map<String,INDArray> ret = is.output(Arrays.asList(outputs), placeholders, listeners, training, at);
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    protected int asFlatNode(String name, @NonNull SameDiff scope, @NonNull FlatBufferBuilder bufferBuilder) {
 | 
			
		||||
        int scopeName = bufferBuilder.createString(name);
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,70 @@
 | 
			
		||||
/*
 | 
			
		||||
 * Copyright (c) 2015-2019 Skymind, Inc.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
package org.nd4j.autodiff.util;
 | 
			
		||||
 | 
			
		||||
import java.util.ArrayList;
 | 
			
		||||
import java.util.HashMap;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
import lombok.AccessLevel;
 | 
			
		||||
import lombok.NoArgsConstructor;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
 | 
			
		||||
import org.nd4j.linalg.exception.ND4JException;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Utilities for SameDiff training and inference
 | 
			
		||||
 */
 | 
			
		||||
@NoArgsConstructor(access = AccessLevel.PRIVATE)
 | 
			
		||||
public class TrainingUtils {
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Stack batch outputs, like an output from {@link org.nd4j.autodiff.samediff.SameDiff#output(MultiDataSetIterator, String...)}
 | 
			
		||||
     */
 | 
			
		||||
    public static Map<String, INDArray> stackOutputs(List<Map<String, INDArray>> outputs){
 | 
			
		||||
        Map<String, List<INDArray>> outs = new HashMap<>();
 | 
			
		||||
        for(Map<String, INDArray> batch : outputs){
 | 
			
		||||
            for(String k : batch.keySet()){
 | 
			
		||||
                if(!outs.containsKey(k))
 | 
			
		||||
                    outs.put(k, new ArrayList<INDArray>());
 | 
			
		||||
                outs.get(k).add(batch.get(k));
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Map<String, INDArray> ret = new HashMap<>();
 | 
			
		||||
        for(String k : outs.keySet()){
 | 
			
		||||
            try {
 | 
			
		||||
                ret.put(k, Nd4j.concat(0, outs.get(k).toArray(new INDArray[0])));
 | 
			
		||||
            } catch(Exception e){
 | 
			
		||||
                throw new ND4JException("Error concatenating batch outputs", e);
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get a list of batch outputs for a single variable from a list of batch outputs for all variables
 | 
			
		||||
     */
 | 
			
		||||
    public static List<INDArray> getSingleOutput(List<Map<String, INDArray>> outputs, String output){
 | 
			
		||||
        List<INDArray> batches = new ArrayList<>();
 | 
			
		||||
        for(Map<String, INDArray> batch : outputs)
 | 
			
		||||
            batches.add(batch.get(output));
 | 
			
		||||
 | 
			
		||||
        return batches;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -915,7 +915,6 @@ public class OpValidation {
 | 
			
		||||
                Conv2DDerivative.class,
 | 
			
		||||
                Conv3DDerivative.class,
 | 
			
		||||
                DeConv2DDerivative.class,
 | 
			
		||||
                FullConv3DDerivative.class,
 | 
			
		||||
                LocalResponseNormalizationDerivative.class,
 | 
			
		||||
                Pooling2DDerivative.class,
 | 
			
		||||
                Pooling3DDerivative.class,
 | 
			
		||||
 | 
			
		||||
@ -72,7 +72,6 @@ public class DifferentialFunctionClassHolder {
 | 
			
		||||
        add(AvgPooling2D.class.getName());
 | 
			
		||||
        add(Conv2D.class.getName());
 | 
			
		||||
        add(Conv3D.class.getName());
 | 
			
		||||
        add(FullConv3D.class.getName());
 | 
			
		||||
        add(LocalResponseNormalization.class.getName());
 | 
			
		||||
        add(MaxPooling2D.class.getName());
 | 
			
		||||
        add(Pooling2D.class.getName());
 | 
			
		||||
 | 
			
		||||
@ -117,8 +117,6 @@ public class ImportClassMapping {
 | 
			
		||||
            org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DDerivative.class,
 | 
			
		||||
            org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace.class,
 | 
			
		||||
            org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D.class,
 | 
			
		||||
            org.nd4j.linalg.api.ops.impl.layers.convolution.FullConv3D.class,
 | 
			
		||||
            org.nd4j.linalg.api.ops.impl.layers.convolution.FullConv3DDerivative.class,
 | 
			
		||||
            org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col.class,
 | 
			
		||||
            org.nd4j.linalg.api.ops.impl.layers.convolution.Im2colBp.class,
 | 
			
		||||
            org.nd4j.linalg.api.ops.impl.layers.convolution.LegacyPooling2D.class,
 | 
			
		||||
 | 
			
		||||
@ -251,8 +251,6 @@ public class AvgPooling2D extends DynamicCustomOp {
 | 
			
		||||
                .kW(kW)
 | 
			
		||||
                .pH(pH)
 | 
			
		||||
                .pW(pW)
 | 
			
		||||
                .virtualHeight(1)
 | 
			
		||||
                .virtualWidth(1)
 | 
			
		||||
                .isNHWC(data_format.equalsIgnoreCase("nhwc"))
 | 
			
		||||
                .extra(0.0) // averaging only for non-padded values
 | 
			
		||||
                .build();
 | 
			
		||||
@ -277,8 +275,6 @@ public class AvgPooling2D extends DynamicCustomOp {
 | 
			
		||||
                .kW(kernelShape.get(1).intValue())
 | 
			
		||||
                .pH(padding.get(0).intValue())
 | 
			
		||||
                .pW(padding.size() < 2 ? padding.get(0).intValue() : padding.get(1).intValue())
 | 
			
		||||
                .virtualWidth(1)
 | 
			
		||||
                .virtualHeight(1)
 | 
			
		||||
                .build();
 | 
			
		||||
        this.config = pooling2DConfig;
 | 
			
		||||
        addArgs();
 | 
			
		||||
 | 
			
		||||
@ -1,228 +0,0 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2015-2018 Skymind, Inc.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.convolution;
 | 
			
		||||
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import lombok.extern.slf4j.Slf4j;
 | 
			
		||||
import lombok.val;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SameDiff;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
 | 
			
		||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
 | 
			
		||||
import org.nd4j.imports.descriptors.properties.adapters.IntArrayIntIndexAdpater;
 | 
			
		||||
import org.nd4j.linalg.api.buffer.DataType;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.FullConv3DConfig;
 | 
			
		||||
import org.nd4j.linalg.util.ArrayUtil;
 | 
			
		||||
 | 
			
		||||
import java.lang.reflect.Field;
 | 
			
		||||
import java.util.*;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * FullConv3D operation
 | 
			
		||||
 */
 | 
			
		||||
@Slf4j
 | 
			
		||||
public class FullConv3D extends DynamicCustomOp {
 | 
			
		||||
 | 
			
		||||
    protected FullConv3DConfig config;
 | 
			
		||||
 | 
			
		||||
    @Builder(builderMethodName = "builder")
 | 
			
		||||
    public FullConv3D(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, FullConv3DConfig config) {
 | 
			
		||||
        super(null,sameDiff, inputFunctions, false);
 | 
			
		||||
        this.config = config;
 | 
			
		||||
        if(inputs != null) {
 | 
			
		||||
            addInputArgument(inputs);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if(outputs != null) {
 | 
			
		||||
            addOutputArgument(outputs);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        addArgs();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    public FullConv3D() {}
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public Map<String, Object> propertiesForFunction() {
 | 
			
		||||
        return config.toProperties();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public long[] iArgs() {
 | 
			
		||||
        if (iArguments.size() == 0)
 | 
			
		||||
            addArgs();
 | 
			
		||||
 | 
			
		||||
        return super.iArgs();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public boolean isConfigProperties() {
 | 
			
		||||
        return true;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public String configFieldName() {
 | 
			
		||||
        return "config";
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
 | 
			
		||||
        Map<String,Map<String,AttributeAdapter>> ret = new LinkedHashMap<>();
 | 
			
		||||
        Map<String,AttributeAdapter> tfAdapters = new LinkedHashMap<>();
 | 
			
		||||
 | 
			
		||||
        tfAdapters.put("dT", new IntArrayIntIndexAdpater(1));
 | 
			
		||||
        tfAdapters.put("dW",  new IntArrayIntIndexAdpater(2));
 | 
			
		||||
        tfAdapters.put("dH",new IntArrayIntIndexAdpater(3));
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        tfAdapters.put("pT", new IntArrayIntIndexAdpater(1));
 | 
			
		||||
        tfAdapters.put("pW",  new IntArrayIntIndexAdpater(2));
 | 
			
		||||
        tfAdapters.put("pH",new IntArrayIntIndexAdpater(3));
 | 
			
		||||
 | 
			
		||||
        ret.put(tensorflowName(),tfAdapters);
 | 
			
		||||
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
 | 
			
		||||
        Map<String,Map<String,PropertyMapping>> ret = new HashMap<>();
 | 
			
		||||
        Map<String,PropertyMapping> map = new HashMap<>();
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        val strideMapping = PropertyMapping.builder()
 | 
			
		||||
                .tfAttrName("strides")
 | 
			
		||||
                .onnxAttrName("strides")
 | 
			
		||||
                .propertyNames(new String[]{"dT","dW","dH"})
 | 
			
		||||
                .build();
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        val dilationMapping = PropertyMapping.builder()
 | 
			
		||||
                .onnxAttrName("dilations")
 | 
			
		||||
                .propertyNames(new String[]{"dD","dH","dW"})
 | 
			
		||||
                .tfAttrName("rates")
 | 
			
		||||
                .build();
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        val sameMode = PropertyMapping.builder()
 | 
			
		||||
                .onnxAttrName("auto_pad")
 | 
			
		||||
                .propertyNames(new String[]{"isSameMode"})
 | 
			
		||||
                .tfAttrName("padding")
 | 
			
		||||
                .build();
 | 
			
		||||
 | 
			
		||||
        val paddingWidthHeight = PropertyMapping.builder()
 | 
			
		||||
                .onnxAttrName("padding")
 | 
			
		||||
                .propertyNames(new String[]{"pT","pW","pH"})
 | 
			
		||||
                .build();
 | 
			
		||||
 | 
			
		||||
        val dataFormat = PropertyMapping.builder()
 | 
			
		||||
                .onnxAttrName("data_format")
 | 
			
		||||
                .tfAttrName("data_format")
 | 
			
		||||
                .propertyNames(new String[]{"dataFormat"})
 | 
			
		||||
                .build();
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        val outputPadding = PropertyMapping.builder()
 | 
			
		||||
                .propertyNames(new String[]{"aT","aH","aW"})
 | 
			
		||||
                .build();
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        val biasUsed = PropertyMapping.builder()
 | 
			
		||||
                .propertyNames(new String[]{"biasUsed"})
 | 
			
		||||
                .build();
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        for(val propertyMapping : new PropertyMapping[] {
 | 
			
		||||
                strideMapping,
 | 
			
		||||
                dilationMapping,
 | 
			
		||||
                sameMode,
 | 
			
		||||
                paddingWidthHeight,
 | 
			
		||||
                dataFormat,
 | 
			
		||||
                outputPadding,biasUsed}) {
 | 
			
		||||
            for(val keys : propertyMapping.getPropertyNames())
 | 
			
		||||
                map.put(keys,propertyMapping);
 | 
			
		||||
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        ret.put(onnxName(),map);
 | 
			
		||||
        ret.put(tensorflowName(),map);
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private void addArgs() {
 | 
			
		||||
        addIArgument(new long[]{
 | 
			
		||||
                config.getDT(),
 | 
			
		||||
                config.getDW(),
 | 
			
		||||
                config.getDH(),
 | 
			
		||||
                config.getPT(),
 | 
			
		||||
                config.getPW(),
 | 
			
		||||
                config.getPH(),
 | 
			
		||||
                config.getDilationT(),
 | 
			
		||||
                config.getDilationW(),
 | 
			
		||||
                config.getDilationH(),
 | 
			
		||||
                config.getAT(),
 | 
			
		||||
                config.getAW(),
 | 
			
		||||
                config.getAH(),
 | 
			
		||||
                ArrayUtil.fromBoolean(config.isBiasUsed())});
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public String opName() {
 | 
			
		||||
        return "fullconv3d";
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public List<SDVariable> doDiff(List<SDVariable> f1) {
 | 
			
		||||
        List<SDVariable> inputs = new ArrayList<>();
 | 
			
		||||
        inputs.addAll(Arrays.asList(args()));
 | 
			
		||||
        inputs.addAll(f1);
 | 
			
		||||
        List<SDVariable> ret = new ArrayList<>();
 | 
			
		||||
        FullConv3DDerivative fullConv3DDerivative = FullConv3DDerivative.derivativeBuilder()
 | 
			
		||||
                .conv3DConfig(config)
 | 
			
		||||
                .sameDiff(sameDiff)
 | 
			
		||||
                .inputFunctions(inputs.toArray(new SDVariable[inputs.size()]))
 | 
			
		||||
                .build();
 | 
			
		||||
        ret.addAll(Arrays.asList(fullConv3DDerivative.outputVariables()));
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
 | 
			
		||||
        int n = args().length;
 | 
			
		||||
        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
 | 
			
		||||
        return Collections.singletonList(inputDataTypes.get(0));
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -1,77 +0,0 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2015-2018 Skymind, Inc.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.convolution;
 | 
			
		||||
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import lombok.extern.slf4j.Slf4j;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
			
		||||
import org.nd4j.autodiff.samediff.SameDiff;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
import org.nd4j.imports.NoOpNameFoundException;
 | 
			
		||||
import org.nd4j.linalg.api.buffer.DataType;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.FullConv3DConfig;
 | 
			
		||||
 | 
			
		||||
import java.util.ArrayList;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * FullConv3DDerivative operation
 | 
			
		||||
 */
 | 
			
		||||
@Slf4j
 | 
			
		||||
public class FullConv3DDerivative extends FullConv3D {
 | 
			
		||||
 | 
			
		||||
    public FullConv3DDerivative() {}
 | 
			
		||||
 | 
			
		||||
    @Builder(builderMethodName = "derivativeBuilder")
 | 
			
		||||
    public FullConv3DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, FullConv3DConfig conv3DConfig) {
 | 
			
		||||
        super(sameDiff, inputFunctions, inputs, outputs, conv3DConfig);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public String opName() {
 | 
			
		||||
        return "fullconv3d_bp";
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public String tensorflowName() {
 | 
			
		||||
        throw new NoOpNameFoundException("No tensorflwo op name found for conv3d derivative");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public String[] tensorflowNames() {
 | 
			
		||||
        throw new NoOpNameFoundException("No tensorflwo op name found for conv3d derivative");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public List<SDVariable> doDiff(List<SDVariable> f1) {
 | 
			
		||||
        throw new UnsupportedOperationException("Unable to take derivative of derivative.");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
 | 
			
		||||
        int n = args().length;  //Original inputs + gradient at
 | 
			
		||||
        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
 | 
			
		||||
        List<DataType> out = new ArrayList<>(n-1);
 | 
			
		||||
        for( int i=0; i<n-1; i++ ){
 | 
			
		||||
            out.add(inputDataTypes.get(i));
 | 
			
		||||
        }
 | 
			
		||||
        return out;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -204,8 +204,6 @@ public class MaxPooling2D extends DynamicCustomOp {
 | 
			
		||||
                .kW(kW)
 | 
			
		||||
                .pH(pH)
 | 
			
		||||
                .pW(pW)
 | 
			
		||||
                .virtualHeight(1)
 | 
			
		||||
                .virtualWidth(1)
 | 
			
		||||
                .isNHWC(data_format.equalsIgnoreCase("nhwc"))
 | 
			
		||||
                .extra(1.0) // averaging only for non-padded values
 | 
			
		||||
                .build();
 | 
			
		||||
@ -230,8 +228,6 @@ public class MaxPooling2D extends DynamicCustomOp {
 | 
			
		||||
                .kW(kernelShape.size() < 2 ? kernelShape.get(0).intValue() : kernelShape.get(1).intValue())
 | 
			
		||||
                .pH(padding.get(0).intValue())
 | 
			
		||||
                .pW(padding.size() < 2 ? padding.get(0).intValue() : padding.get(1).intValue())
 | 
			
		||||
                .virtualHeight(1)
 | 
			
		||||
                .virtualWidth(1)
 | 
			
		||||
                .build();
 | 
			
		||||
        this.config = pooling2DConfig;
 | 
			
		||||
        addArgs();
 | 
			
		||||
 | 
			
		||||
@ -174,8 +174,6 @@ public class Pooling2D extends DynamicCustomOp {
 | 
			
		||||
                .kW(kW.intValue())
 | 
			
		||||
                .pH(padding.get(0).intValue())
 | 
			
		||||
                .pW(padding.get(1).intValue())
 | 
			
		||||
                .virtualWidth(1)
 | 
			
		||||
                .virtualHeight(1)
 | 
			
		||||
                .build();
 | 
			
		||||
        this.config = pooling2DConfig;
 | 
			
		||||
        addArgs();
 | 
			
		||||
@ -200,8 +198,6 @@ public class Pooling2D extends DynamicCustomOp {
 | 
			
		||||
                .kW(kernelShape.get(1).intValue())
 | 
			
		||||
                .pH(padding.get(0).intValue())
 | 
			
		||||
                .pW(padding.get(1).intValue())
 | 
			
		||||
                .virtualHeight(1)
 | 
			
		||||
                .virtualWidth(1)
 | 
			
		||||
                .build();
 | 
			
		||||
        this.config = pooling2DConfig;
 | 
			
		||||
        addArgs();
 | 
			
		||||
 | 
			
		||||
@ -16,6 +16,8 @@
 | 
			
		||||
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
 | 
			
		||||
 | 
			
		||||
import java.util.LinkedHashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
import lombok.val;
 | 
			
		||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
 | 
			
		||||
 | 
			
		||||
@ -23,6 +25,8 @@ import java.lang.reflect.Field;
 | 
			
		||||
 | 
			
		||||
public abstract class BaseConvolutionConfig {
 | 
			
		||||
 | 
			
		||||
    public abstract Map<String, Object> toProperties();
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get the value for a given property
 | 
			
		||||
     * for this function
 | 
			
		||||
@ -154,4 +158,5 @@ public abstract class BaseConvolutionConfig {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    protected abstract void validate();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -16,15 +16,16 @@
 | 
			
		||||
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
 | 
			
		||||
 | 
			
		||||
import lombok.*;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
 | 
			
		||||
import java.util.LinkedHashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import lombok.NoArgsConstructor;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
import org.nd4j.linalg.util.ConvConfigUtil;
 | 
			
		||||
 | 
			
		||||
@Builder
 | 
			
		||||
@Data
 | 
			
		||||
@AllArgsConstructor
 | 
			
		||||
@Builder
 | 
			
		||||
@NoArgsConstructor
 | 
			
		||||
public class Conv1DConfig extends BaseConvolutionConfig {
 | 
			
		||||
    public static final String NCW = "NCW";
 | 
			
		||||
@ -40,6 +41,16 @@ public class Conv1DConfig extends BaseConvolutionConfig {
 | 
			
		||||
    private String dataFormat = NCW;
 | 
			
		||||
    private boolean isSameMode;
 | 
			
		||||
 | 
			
		||||
    public Conv1DConfig(long k, long s, long p, String dataFormat, boolean isSameMode) {
 | 
			
		||||
        this.k = k;
 | 
			
		||||
        this.s = s;
 | 
			
		||||
        this.p = p;
 | 
			
		||||
        this.dataFormat = dataFormat;
 | 
			
		||||
        this.isSameMode = isSameMode;
 | 
			
		||||
 | 
			
		||||
        validate();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public boolean isNWC(){
 | 
			
		||||
        Preconditions.checkState(dataFormat.equalsIgnoreCase(NCW) || dataFormat.equalsIgnoreCase(NWC),
 | 
			
		||||
                "Data format must be one of %s or %s, got %s", NCW, NWC, dataFormat);
 | 
			
		||||
@ -54,6 +65,7 @@ public class Conv1DConfig extends BaseConvolutionConfig {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public Map<String, Object> toProperties() {
 | 
			
		||||
        Map<String, Object> ret = new LinkedHashMap<>();
 | 
			
		||||
        ret.put("k", k);
 | 
			
		||||
@ -64,5 +76,11 @@ public class Conv1DConfig extends BaseConvolutionConfig {
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    protected void validate() {
 | 
			
		||||
        ConvConfigUtil.validate1D(k, s, p);
 | 
			
		||||
        Preconditions.checkArgument(dataFormat != null, "Data format can't be null");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -16,18 +16,16 @@
 | 
			
		||||
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
 | 
			
		||||
 | 
			
		||||
import lombok.AllArgsConstructor;
 | 
			
		||||
import java.util.LinkedHashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import lombok.NoArgsConstructor;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
import org.nd4j.linalg.util.ConvConfigUtil;
 | 
			
		||||
 | 
			
		||||
import java.util.LinkedHashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
@Builder
 | 
			
		||||
@Data
 | 
			
		||||
@AllArgsConstructor
 | 
			
		||||
@Builder
 | 
			
		||||
@NoArgsConstructor
 | 
			
		||||
public class Conv2DConfig extends BaseConvolutionConfig {
 | 
			
		||||
    public static final String NCHW = "NCHW";
 | 
			
		||||
@ -53,6 +51,23 @@ public class Conv2DConfig extends BaseConvolutionConfig {
 | 
			
		||||
    @Builder.Default
 | 
			
		||||
    private String dataFormat = NCHW;
 | 
			
		||||
 | 
			
		||||
    public Conv2DConfig(long kH, long kW, long sH, long sW, long pH, long pW, long dH, long dW, boolean isSameMode,
 | 
			
		||||
            String dataFormat) {
 | 
			
		||||
 | 
			
		||||
        this.kH = kH;
 | 
			
		||||
        this.kW = kW;
 | 
			
		||||
        this.sH = sH;
 | 
			
		||||
        this.sW = sW;
 | 
			
		||||
        this.pH = pH;
 | 
			
		||||
        this.pW = pW;
 | 
			
		||||
        this.dH = dH;
 | 
			
		||||
        this.dW = dW;
 | 
			
		||||
        this.isSameMode = isSameMode;
 | 
			
		||||
        this.dataFormat = dataFormat;
 | 
			
		||||
 | 
			
		||||
        validate();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public boolean isNHWC(){
 | 
			
		||||
        Preconditions.checkState(dataFormat.equalsIgnoreCase(NCHW) || dataFormat.equalsIgnoreCase(NHWC),
 | 
			
		||||
                "Data format must be one of %s or %s, got %s", NCHW, NHWC, dataFormat);
 | 
			
		||||
@ -67,6 +82,7 @@ public class Conv2DConfig extends BaseConvolutionConfig {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public Map<String, Object> toProperties() {
 | 
			
		||||
        Map<String, Object> ret = new LinkedHashMap<>();
 | 
			
		||||
        ret.put("kH", kH);
 | 
			
		||||
@ -82,5 +98,11 @@ public class Conv2DConfig extends BaseConvolutionConfig {
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    protected void validate() {
 | 
			
		||||
        ConvConfigUtil.validate2D(kH, kW, sH, sW, pH, pW, dH, dW);
 | 
			
		||||
        Preconditions.checkArgument(dataFormat != null, "Data format can't be null");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -17,30 +17,28 @@
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import lombok.AllArgsConstructor;
 | 
			
		||||
import java.util.LinkedHashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import lombok.NoArgsConstructor;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
 | 
			
		||||
import java.util.LinkedHashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
import org.nd4j.linalg.util.ConvConfigUtil;
 | 
			
		||||
 | 
			
		||||
@Data
 | 
			
		||||
@Builder
 | 
			
		||||
@NoArgsConstructor
 | 
			
		||||
@AllArgsConstructor
 | 
			
		||||
public class Conv3DConfig extends BaseConvolutionConfig {
 | 
			
		||||
    public static final String NDHWC = "NDHWC";
 | 
			
		||||
    public static final String NCDHW = "NCDHW";
 | 
			
		||||
 | 
			
		||||
    //kernel
 | 
			
		||||
    @Builder.Default
 | 
			
		||||
    private long kD = 1;
 | 
			
		||||
    private long kD = -1;
 | 
			
		||||
    @Builder.Default
 | 
			
		||||
    private long kW = 1;
 | 
			
		||||
    private long kW = -1;
 | 
			
		||||
    @Builder.Default
 | 
			
		||||
    private long kH = 1;
 | 
			
		||||
    private long kH = -1;
 | 
			
		||||
 | 
			
		||||
    //strides
 | 
			
		||||
    @Builder.Default
 | 
			
		||||
@ -66,14 +64,6 @@ public class Conv3DConfig extends BaseConvolutionConfig {
 | 
			
		||||
    @Builder.Default
 | 
			
		||||
    private long dH = 1;
 | 
			
		||||
 | 
			
		||||
    //output padding
 | 
			
		||||
    @Builder.Default
 | 
			
		||||
    private long aD = 0;
 | 
			
		||||
    @Builder.Default
 | 
			
		||||
    private long aW = 0;
 | 
			
		||||
    @Builder.Default
 | 
			
		||||
    private long aH = 0;
 | 
			
		||||
 | 
			
		||||
    @Builder.Default
 | 
			
		||||
    private boolean biasUsed = false;
 | 
			
		||||
    private boolean isSameMode;
 | 
			
		||||
@ -81,6 +71,27 @@ public class Conv3DConfig extends BaseConvolutionConfig {
 | 
			
		||||
    @Builder.Default
 | 
			
		||||
    private String dataFormat = NDHWC;
 | 
			
		||||
 | 
			
		||||
    public Conv3DConfig(long kD, long kW, long kH, long sD, long sW, long sH, long pD, long pW, long pH, long dD,
 | 
			
		||||
            long dW, long dH, boolean biasUsed, boolean isSameMode, String dataFormat) {
 | 
			
		||||
        this.kD = kD;
 | 
			
		||||
        this.kW = kW;
 | 
			
		||||
        this.kH = kH;
 | 
			
		||||
        this.sD = sD;
 | 
			
		||||
        this.sW = sW;
 | 
			
		||||
        this.sH = sH;
 | 
			
		||||
        this.pD = pD;
 | 
			
		||||
        this.pW = pW;
 | 
			
		||||
        this.pH = pH;
 | 
			
		||||
        this.dD = dD;
 | 
			
		||||
        this.dW = dW;
 | 
			
		||||
        this.dH = dH;
 | 
			
		||||
        this.biasUsed = biasUsed;
 | 
			
		||||
        this.isSameMode = isSameMode;
 | 
			
		||||
        this.dataFormat = dataFormat;
 | 
			
		||||
 | 
			
		||||
        validate();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public boolean isNCDHW(){
 | 
			
		||||
        Preconditions.checkState(dataFormat.equalsIgnoreCase(NCDHW) || dataFormat.equalsIgnoreCase(NDHWC),
 | 
			
		||||
                "Data format must be one of %s or %s, got %s", NCDHW, NDHWC, dataFormat);
 | 
			
		||||
@ -95,6 +106,7 @@ public class Conv3DConfig extends BaseConvolutionConfig {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public Map<String, Object> toProperties() {
 | 
			
		||||
        Map<String, Object> ret = new LinkedHashMap<>();
 | 
			
		||||
        ret.put("kD", kD);
 | 
			
		||||
@ -109,9 +121,6 @@ public class Conv3DConfig extends BaseConvolutionConfig {
 | 
			
		||||
        ret.put("dD", dD);
 | 
			
		||||
        ret.put("dW", dW);
 | 
			
		||||
        ret.put("dH", dH);
 | 
			
		||||
        ret.put("aD", aD);
 | 
			
		||||
        ret.put("aW", aW);
 | 
			
		||||
        ret.put("aH", aH);
 | 
			
		||||
        ret.put("biasUsed", biasUsed);
 | 
			
		||||
        ret.put("dataFormat", dataFormat);
 | 
			
		||||
        ret.put("isSameMode", isSameMode);
 | 
			
		||||
@ -119,5 +128,11 @@ public class Conv3DConfig extends BaseConvolutionConfig {
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    protected void validate() {
 | 
			
		||||
        ConvConfigUtil.validate3D(kH, kW, kD, sH, sW, sD, pH, pW, pD, dH, dW, dD);
 | 
			
		||||
        Preconditions.checkArgument(dataFormat != null, "Data format can't be null");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -16,22 +16,22 @@
 | 
			
		||||
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
 | 
			
		||||
 | 
			
		||||
import lombok.AllArgsConstructor;
 | 
			
		||||
import java.util.LinkedHashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import lombok.NoArgsConstructor;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
import org.nd4j.linalg.util.ConvConfigUtil;
 | 
			
		||||
 | 
			
		||||
import java.util.LinkedHashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
@Builder
 | 
			
		||||
@Data
 | 
			
		||||
@AllArgsConstructor
 | 
			
		||||
@Builder
 | 
			
		||||
@NoArgsConstructor
 | 
			
		||||
public class DeConv2DConfig extends BaseConvolutionConfig {
 | 
			
		||||
    public static final String NCHW = "NCHW";
 | 
			
		||||
    public static final String NHWC = "NHWC";
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @Builder.Default private long kH = -1L;
 | 
			
		||||
    @Builder.Default private long kW = -1L;
 | 
			
		||||
    @Builder.Default private long sH = 1L;
 | 
			
		||||
@ -43,8 +43,25 @@ public class DeConv2DConfig extends BaseConvolutionConfig {
 | 
			
		||||
    @Builder.Default private boolean isSameMode = false;
 | 
			
		||||
    @Builder.Default private String dataFormat = NCHW;
 | 
			
		||||
 | 
			
		||||
    public DeConv2DConfig(long kH, long kW, long sH, long sW, long pH, long pW, long dH, long dW, boolean isSameMode,
 | 
			
		||||
            String dataFormat) {
 | 
			
		||||
        this.kH = kH;
 | 
			
		||||
        this.kW = kW;
 | 
			
		||||
        this.sH = sH;
 | 
			
		||||
        this.sW = sW;
 | 
			
		||||
        this.pH = pH;
 | 
			
		||||
        this.pW = pW;
 | 
			
		||||
        this.dH = dH;
 | 
			
		||||
        this.dW = dW;
 | 
			
		||||
        this.isSameMode = isSameMode;
 | 
			
		||||
        this.dataFormat = dataFormat;
 | 
			
		||||
 | 
			
		||||
        validate();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public Map<String, Object> toProperties() {
 | 
			
		||||
 | 
			
		||||
        Map<String, Object> ret = new LinkedHashMap<>();
 | 
			
		||||
        ret.put("kH", kH);
 | 
			
		||||
        ret.put("kW", kW);
 | 
			
		||||
@ -58,4 +75,10 @@ public class DeConv2DConfig extends BaseConvolutionConfig {
 | 
			
		||||
        ret.put("dataFormat", dataFormat);
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    protected void validate() {
 | 
			
		||||
        ConvConfigUtil.validate2D(kH, kW, sH, sW, pH, pW, dH, dW);
 | 
			
		||||
        Preconditions.checkArgument(dataFormat != null, "Data format can't be null");
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -16,17 +16,16 @@
 | 
			
		||||
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
 | 
			
		||||
 | 
			
		||||
import lombok.AllArgsConstructor;
 | 
			
		||||
import java.util.LinkedHashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import lombok.NoArgsConstructor;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
import org.nd4j.linalg.util.ConvConfigUtil;
 | 
			
		||||
 | 
			
		||||
import java.util.LinkedHashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
@Builder
 | 
			
		||||
@Data
 | 
			
		||||
@AllArgsConstructor
 | 
			
		||||
@Builder
 | 
			
		||||
@NoArgsConstructor
 | 
			
		||||
public class DeConv3DConfig extends BaseConvolutionConfig {
 | 
			
		||||
    public static final String NCDHW = "NCDHW";
 | 
			
		||||
@ -47,7 +46,28 @@ public class DeConv3DConfig extends BaseConvolutionConfig {
 | 
			
		||||
    @Builder.Default private boolean isSameMode = false;
 | 
			
		||||
    @Builder.Default private String dataFormat = NCDHW;
 | 
			
		||||
 | 
			
		||||
    public DeConv3DConfig(long kD, long kH, long kW, long sD, long sH, long sW, long pD, long pH, long pW, long dD,
 | 
			
		||||
            long dH, long dW, boolean isSameMode, String dataFormat) {
 | 
			
		||||
        this.kD = kD;
 | 
			
		||||
        this.kH = kH;
 | 
			
		||||
        this.kW = kW;
 | 
			
		||||
        this.sD = sD;
 | 
			
		||||
        this.sH = sH;
 | 
			
		||||
        this.sW = sW;
 | 
			
		||||
        this.pD = pD;
 | 
			
		||||
        this.pH = pH;
 | 
			
		||||
        this.pW = pW;
 | 
			
		||||
        this.dD = dD;
 | 
			
		||||
        this.dH = dH;
 | 
			
		||||
        this.dW = dW;
 | 
			
		||||
        this.isSameMode = isSameMode;
 | 
			
		||||
        this.dataFormat = dataFormat;
 | 
			
		||||
 | 
			
		||||
        validate();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public Map<String, Object> toProperties() {
 | 
			
		||||
        Map<String, Object> ret = new LinkedHashMap<>();
 | 
			
		||||
        ret.put("kD", kD);
 | 
			
		||||
@ -66,4 +86,10 @@ public class DeConv3DConfig extends BaseConvolutionConfig {
 | 
			
		||||
        ret.put("dataFormat", dataFormat);
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    protected void validate() {
 | 
			
		||||
        ConvConfigUtil.validate3D(kH, kW, kD, sH, sW, sD, pH, pW, pD, dH, dW, dD);
 | 
			
		||||
        Preconditions.checkArgument(dataFormat != null, "Data format can't be null");
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1,53 +0,0 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2015-2018 Skymind, Inc.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
 | 
			
		||||
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
 | 
			
		||||
import java.util.LinkedHashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
@Builder
 | 
			
		||||
@Data
 | 
			
		||||
public class FullConv3DConfig extends BaseConvolutionConfig {
 | 
			
		||||
    private long dT,dW,dH,pT,pW,pH,dilationT,dilationW,dilationH,aT,aW,aH;
 | 
			
		||||
    private boolean biasUsed;
 | 
			
		||||
    private String dataFormat;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    public Map<String,Object> toProperties() {
 | 
			
		||||
        Map<String,Object> ret = new LinkedHashMap<>();
 | 
			
		||||
        ret.put("dT",dT);
 | 
			
		||||
        ret.put("dW",dW);
 | 
			
		||||
        ret.put("dH",dH);
 | 
			
		||||
        ret.put("pT",pT);
 | 
			
		||||
        ret.put("pW",pW);
 | 
			
		||||
        ret.put("pH",pH);
 | 
			
		||||
        ret.put("dD",dilationT);
 | 
			
		||||
        ret.put("dW",dilationW);
 | 
			
		||||
        ret.put("dH",dilationH);
 | 
			
		||||
        ret.put("aT",aT);
 | 
			
		||||
        ret.put("aW",aW);
 | 
			
		||||
        ret.put("aH",aH);
 | 
			
		||||
        ret.put("biasUsed",biasUsed);
 | 
			
		||||
        ret.put("dataFormat",dataFormat);
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -16,19 +16,31 @@
 | 
			
		||||
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
 | 
			
		||||
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
 | 
			
		||||
import java.util.LinkedHashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import lombok.NoArgsConstructor;
 | 
			
		||||
import org.nd4j.linalg.util.ConvConfigUtil;
 | 
			
		||||
 | 
			
		||||
@Data
 | 
			
		||||
@Builder
 | 
			
		||||
@NoArgsConstructor
 | 
			
		||||
public class LocalResponseNormalizationConfig extends BaseConvolutionConfig {
 | 
			
		||||
 | 
			
		||||
    private double alpha, beta, bias;
 | 
			
		||||
    private int depth;
 | 
			
		||||
 | 
			
		||||
    public LocalResponseNormalizationConfig(double alpha, double beta, double bias, int depth) {
 | 
			
		||||
        this.alpha = alpha;
 | 
			
		||||
        this.beta = beta;
 | 
			
		||||
        this.bias = bias;
 | 
			
		||||
        this.depth = depth;
 | 
			
		||||
 | 
			
		||||
        validate();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public Map<String, Object> toProperties() {
 | 
			
		||||
        Map<String, Object> ret = new LinkedHashMap<>();
 | 
			
		||||
        ret.put("alpha", alpha);
 | 
			
		||||
@ -38,4 +50,9 @@ public class LocalResponseNormalizationConfig extends BaseConvolutionConfig {
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    protected void validate() {
 | 
			
		||||
        ConvConfigUtil.validateLRN(alpha, beta, bias, depth);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -16,32 +16,32 @@
 | 
			
		||||
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
 | 
			
		||||
 | 
			
		||||
import lombok.AllArgsConstructor;
 | 
			
		||||
import java.util.LinkedHashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import lombok.NoArgsConstructor;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.Divisor;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.Pooling2DType;
 | 
			
		||||
import org.nd4j.linalg.util.ConvConfigUtil;
 | 
			
		||||
 | 
			
		||||
import java.util.LinkedHashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
@Builder
 | 
			
		||||
@AllArgsConstructor
 | 
			
		||||
@Data
 | 
			
		||||
@Builder
 | 
			
		||||
@NoArgsConstructor
 | 
			
		||||
public class Pooling2DConfig extends BaseConvolutionConfig {
 | 
			
		||||
 | 
			
		||||
    private long kH, kW;
 | 
			
		||||
    private long sH, sW;
 | 
			
		||||
    private long pH, pW;
 | 
			
		||||
    private long virtualHeight, virtualWidth;
 | 
			
		||||
    @Builder.Default private long kH = -1, kW = -1;
 | 
			
		||||
    @Builder.Default private long sH = 1, sW = 1;
 | 
			
		||||
    @Builder.Default private long pH = 0, pW = 0;
 | 
			
		||||
    /**
 | 
			
		||||
     * Extra is an optional parameter mainly for use with pnorm right now.
 | 
			
		||||
     * All pooling implementations take 9 parameters save pnorm.
 | 
			
		||||
     * Pnorm takes 10 and is cast to an int.
 | 
			
		||||
     */
 | 
			
		||||
    private double extra;
 | 
			
		||||
    private Pooling2D.Pooling2DType type;
 | 
			
		||||
    @Builder.Default
 | 
			
		||||
    private Pooling2D.Pooling2DType type = Pooling2DType.MAX;
 | 
			
		||||
    @Builder.Default
 | 
			
		||||
    private Pooling2D.Divisor divisor = Pooling2D.Divisor.EXCLUDE_PADDING;
 | 
			
		||||
    private boolean isSameMode;
 | 
			
		||||
@ -52,7 +52,26 @@ public class Pooling2DConfig extends BaseConvolutionConfig {
 | 
			
		||||
    @Builder.Default
 | 
			
		||||
    private boolean isNHWC = false;
 | 
			
		||||
 | 
			
		||||
    public Pooling2DConfig(long kH, long kW, long sH, long sW, long pH, long pW, double extra, Pooling2DType type,
 | 
			
		||||
            Divisor divisor, boolean isSameMode, long dH, long dW, boolean isNHWC) {
 | 
			
		||||
        this.kH = kH;
 | 
			
		||||
        this.kW = kW;
 | 
			
		||||
        this.sH = sH;
 | 
			
		||||
        this.sW = sW;
 | 
			
		||||
        this.pH = pH;
 | 
			
		||||
        this.pW = pW;
 | 
			
		||||
        this.extra = extra;
 | 
			
		||||
        this.type = type;
 | 
			
		||||
        this.divisor = divisor;
 | 
			
		||||
        this.isSameMode = isSameMode;
 | 
			
		||||
        this.dH = dH;
 | 
			
		||||
        this.dW = dW;
 | 
			
		||||
        this.isNHWC = isNHWC;
 | 
			
		||||
 | 
			
		||||
        validate();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public Map<String, Object> toProperties() {
 | 
			
		||||
        Map<String, Object> ret = new LinkedHashMap<>();
 | 
			
		||||
        ret.put("kH", kH);
 | 
			
		||||
@ -61,8 +80,6 @@ public class Pooling2DConfig extends BaseConvolutionConfig {
 | 
			
		||||
        ret.put("sW", sW);
 | 
			
		||||
        ret.put("pH", pH);
 | 
			
		||||
        ret.put("pW", pW);
 | 
			
		||||
        ret.put("virtualHeight", virtualHeight);
 | 
			
		||||
        ret.put("virtualWidth", virtualWidth);
 | 
			
		||||
        ret.put("extra", extra);
 | 
			
		||||
        ret.put("type", type.toString());
 | 
			
		||||
        ret.put("isSameMode", isSameMode);
 | 
			
		||||
@ -72,4 +89,11 @@ public class Pooling2DConfig extends BaseConvolutionConfig {
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    protected void validate() {
 | 
			
		||||
        ConvConfigUtil.validate2D(kH, kW, sH, sW, pH, pW, dH, dW);
 | 
			
		||||
 | 
			
		||||
        //TODO check other args?
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -16,23 +16,22 @@
 | 
			
		||||
 | 
			
		||||
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
 | 
			
		||||
 | 
			
		||||
import lombok.AllArgsConstructor;
 | 
			
		||||
import java.util.LinkedHashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
import lombok.NoArgsConstructor;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D;
 | 
			
		||||
 | 
			
		||||
import java.util.LinkedHashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D.Pooling3DType;
 | 
			
		||||
import org.nd4j.linalg.util.ConvConfigUtil;
 | 
			
		||||
 | 
			
		||||
@Data
 | 
			
		||||
@Builder
 | 
			
		||||
@AllArgsConstructor
 | 
			
		||||
@NoArgsConstructor
 | 
			
		||||
public class Pooling3DConfig extends BaseConvolutionConfig {
 | 
			
		||||
    private long kD, kW, kH; // kernel
 | 
			
		||||
    private long sD, sW, sH; // strides
 | 
			
		||||
    private long pD, pW, pH; // padding
 | 
			
		||||
    @Builder.Default private long kD = -1, kW = -1, kH = -1; // kernel
 | 
			
		||||
    @Builder.Default private long sD = 1, sW = 1, sH = 1; // strides
 | 
			
		||||
    @Builder.Default private long pD = 0, pW = 0, pH = 0; // padding
 | 
			
		||||
    // dilation
 | 
			
		||||
    @Builder.Default
 | 
			
		||||
    private long dD = 1;
 | 
			
		||||
@ -40,10 +39,33 @@ public class Pooling3DConfig extends BaseConvolutionConfig {
 | 
			
		||||
    private long dW = 1;
 | 
			
		||||
    @Builder.Default
 | 
			
		||||
    private long dH = 1;
 | 
			
		||||
    private Pooling3D.Pooling3DType type;
 | 
			
		||||
    @Builder.Default
 | 
			
		||||
    private Pooling3D.Pooling3DType type = Pooling3DType.MAX;
 | 
			
		||||
    private boolean isSameMode;
 | 
			
		||||
    @Builder.Default private boolean isNCDHW = true;
 | 
			
		||||
 | 
			
		||||
    public Pooling3DConfig(long kD, long kW, long kH, long sD, long sW, long sH, long pD, long pW, long pH, long dD,
 | 
			
		||||
            long dW, long dH, Pooling3DType type, boolean isSameMode, boolean isNCDHW) {
 | 
			
		||||
        this.kD = kD;
 | 
			
		||||
        this.kW = kW;
 | 
			
		||||
        this.kH = kH;
 | 
			
		||||
        this.sD = sD;
 | 
			
		||||
        this.sW = sW;
 | 
			
		||||
        this.sH = sH;
 | 
			
		||||
        this.pD = pD;
 | 
			
		||||
        this.pW = pW;
 | 
			
		||||
        this.pH = pH;
 | 
			
		||||
        this.dD = dD;
 | 
			
		||||
        this.dW = dW;
 | 
			
		||||
        this.dH = dH;
 | 
			
		||||
        this.type = type;
 | 
			
		||||
        this.isSameMode = isSameMode;
 | 
			
		||||
        this.isNCDHW = isNCDHW;
 | 
			
		||||
 | 
			
		||||
        validate();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public Map<String, Object> toProperties() {
 | 
			
		||||
        Map<String, Object> ret = new LinkedHashMap<>();
 | 
			
		||||
        ret.put("kD", kD);
 | 
			
		||||
@ -63,4 +85,11 @@ public class Pooling3DConfig extends BaseConvolutionConfig {
 | 
			
		||||
        return ret;
 | 
			
		||||
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    protected void validate() {
 | 
			
		||||
        ConvConfigUtil.validate3D(kH, kW, kD, sH, sW, sD, pH, pW, pD, dH, dW, dD);
 | 
			
		||||
 | 
			
		||||
        //TODO check other args
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -249,8 +249,6 @@ public class Convolution {
 | 
			
		||||
                        .isSameMode(isSameMode)
 | 
			
		||||
                        .sH(sy)
 | 
			
		||||
                        .sW(sx)
 | 
			
		||||
                        .virtualHeight(virtualHeight)
 | 
			
		||||
                        .virtualWidth(virtualWidth)
 | 
			
		||||
                        .type(type)
 | 
			
		||||
                        .divisor(divisor)
 | 
			
		||||
                        .build())
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,93 @@
 | 
			
		||||
/*
 | 
			
		||||
 * Copyright (c) 2015-2019 Skymind, Inc.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
package org.nd4j.linalg.util;
 | 
			
		||||
 | 
			
		||||
import lombok.AccessLevel;
 | 
			
		||||
import lombok.NoArgsConstructor;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Class with utility methods for validating convolution op configurations like {@link org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig}
 | 
			
		||||
 */
 | 
			
		||||
@NoArgsConstructor(access = AccessLevel.PRIVATE)
 | 
			
		||||
public class ConvConfigUtil {
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Validate a 2D convolution's Kernel, Stride, Padding, and Dilation
 | 
			
		||||
     */
 | 
			
		||||
    public static void validate2D(long kH, long kW, long sH, long sW, long pH, long pW, long dH, long dW){
 | 
			
		||||
        Preconditions.checkArgument(kH != 0, "Kernel height can not be 0");
 | 
			
		||||
        Preconditions.checkArgument(kW != 0, "Kernel width can not be 0");
 | 
			
		||||
 | 
			
		||||
        Preconditions.checkArgument(sH > 0, "Stride height can not be negative or 0, got: %s", sH);
 | 
			
		||||
        Preconditions.checkArgument(sW > 0, "Stride width can not be negative or 0, got: %s", sW);
 | 
			
		||||
 | 
			
		||||
        Preconditions.checkArgument(pH >= 0, "Padding height can not be negative, got: %s", pH);
 | 
			
		||||
        Preconditions.checkArgument(pW >= 0, "Padding width can not be negative, got: %s", pW);
 | 
			
		||||
 | 
			
		||||
        Preconditions.checkArgument(dH > 0, "Dilation height can not be negative or 0, got: %s", dH);
 | 
			
		||||
        Preconditions.checkArgument(dW > 0, "Dilation width can not be negative or 0, got: %s", dW);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Validate a 3D convolution's Kernel, Stride, Padding, and Dilation
 | 
			
		||||
     */
 | 
			
		||||
    public static void validate3D(long kH, long kW, long kD, long sH, long sW, long sD, long pH, long pW, long pD, long dH, long dW, long dD){
 | 
			
		||||
        Preconditions.checkArgument(kH != 0, "Kernel height can not be 0");
 | 
			
		||||
        Preconditions.checkArgument(kW != 0, "Kernel width can not be 0");
 | 
			
		||||
        Preconditions.checkArgument(kD != 0, "Kernel depth can not be 0");
 | 
			
		||||
 | 
			
		||||
        Preconditions.checkArgument(sH > 0, "Stride height can not be negative or 0, got: %s", sH);
 | 
			
		||||
        Preconditions.checkArgument(sW > 0, "Stride width can not be negative or 0, got: %s", sW);
 | 
			
		||||
        Preconditions.checkArgument(sD > 0, "Stride depth can not be negative or 0, got: %s", sD);
 | 
			
		||||
 | 
			
		||||
        Preconditions.checkArgument(pH >= 0, "Padding height can not be negative, got: %s", pH);
 | 
			
		||||
        Preconditions.checkArgument(pW >= 0, "Padding width can not be negative, got: %s", pW);
 | 
			
		||||
        Preconditions.checkArgument(pD >= 0, "Padding depth can not be negative, got: %s", pD);
 | 
			
		||||
 | 
			
		||||
        Preconditions.checkArgument(dH > 0, "Dilation height can not be negative or 0, got: %s", dH);
 | 
			
		||||
        Preconditions.checkArgument(dW > 0, "Dilation width can not be negative or 0, got: %s", dW);
 | 
			
		||||
        Preconditions.checkArgument(dD > 0, "Dilation depth can not be negative or 0, got: %s", dD);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Validate a 3D convolution's Output Padding
 | 
			
		||||
     */
 | 
			
		||||
    public static void validateExtra3D(long aH, long aW, long aD){
 | 
			
		||||
        Preconditions.checkArgument(aH >= 0, "Output padding height can not be negative, got: %s", aH);
 | 
			
		||||
        Preconditions.checkArgument(aW >= 0, "Output padding width can not be negative, got: %s", aW);
 | 
			
		||||
        Preconditions.checkArgument(aD >= 0, "Output padding depth can not be negative, got: %s", aD);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Validate a 1D convolution's Kernel, Stride, and Padding
 | 
			
		||||
     */
 | 
			
		||||
    public static void validate1D(long k, long s, long p){
 | 
			
		||||
        Preconditions.checkArgument(k != 0, "Kernel can not be 0");
 | 
			
		||||
 | 
			
		||||
        Preconditions.checkArgument(s > 0, "Stride can not be negative or 0, got: %s", s);
 | 
			
		||||
 | 
			
		||||
        Preconditions.checkArgument(p >= 0, "Padding can not be negative, got: %s", p);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Validate a LocalResponseNormalizationConfig
 | 
			
		||||
     */
 | 
			
		||||
    public static void validateLRN(double alpha, double beta, double bias, int depth) {
 | 
			
		||||
        Preconditions.checkArgument(depth > 0, "Depth can not be 0 or negative, got: %s", depth);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -14662,54 +14662,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 | 
			
		||||
                                                                                }
 | 
			
		||||
//         #endif
 | 
			
		||||
 | 
			
		||||
//         #if NOT_EXCLUDED(OP_fullconv3d)
 | 
			
		||||
        @Namespace("nd4j::ops") public static class fullconv3d extends DeclarableCustomOp {
 | 
			
		||||
            static { Loader.load(); }
 | 
			
		||||
            /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
 | 
			
		||||
            public fullconv3d(Pointer p) { super(p); }
 | 
			
		||||
            /** Native array allocator. Access with {@link Pointer#position(long)}. */
 | 
			
		||||
            public fullconv3d(long size) { super((Pointer)null); allocateArray(size); }
 | 
			
		||||
            private native void allocateArray(long size);
 | 
			
		||||
            @Override public fullconv3d position(long position) {
 | 
			
		||||
                return (fullconv3d)super.position(position);
 | 
			
		||||
            }
 | 
			
		||||
        
 | 
			
		||||
                                                                                    public fullconv3d() { super((Pointer)null); allocate(); }
 | 
			
		||||
                                                                                    private native void allocate();
 | 
			
		||||
                                                                                    public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
 | 
			
		||||
                                                                                }
 | 
			
		||||
        @Namespace("nd4j::ops") public static class fullconv3d_bp extends DeclarableCustomOp {
 | 
			
		||||
            static { Loader.load(); }
 | 
			
		||||
            /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
 | 
			
		||||
            public fullconv3d_bp(Pointer p) { super(p); }
 | 
			
		||||
            /** Native array allocator. Access with {@link Pointer#position(long)}. */
 | 
			
		||||
            public fullconv3d_bp(long size) { super((Pointer)null); allocateArray(size); }
 | 
			
		||||
            private native void allocateArray(long size);
 | 
			
		||||
            @Override public fullconv3d_bp position(long position) {
 | 
			
		||||
                return (fullconv3d_bp)super.position(position);
 | 
			
		||||
            }
 | 
			
		||||
        
 | 
			
		||||
                                                                                    public fullconv3d_bp() { super((Pointer)null); allocate(); }
 | 
			
		||||
                                                                                    private native void allocate();
 | 
			
		||||
                                                                                    public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
 | 
			
		||||
                                                                                }
 | 
			
		||||
        @Namespace("nd4j::ops") public static class fullconv3d_grad extends DeclarableCustomOp {
 | 
			
		||||
            static { Loader.load(); }
 | 
			
		||||
            /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
 | 
			
		||||
            public fullconv3d_grad(Pointer p) { super(p); }
 | 
			
		||||
            /** Native array allocator. Access with {@link Pointer#position(long)}. */
 | 
			
		||||
            public fullconv3d_grad(long size) { super((Pointer)null); allocateArray(size); }
 | 
			
		||||
            private native void allocateArray(long size);
 | 
			
		||||
            @Override public fullconv3d_grad position(long position) {
 | 
			
		||||
                return (fullconv3d_grad)super.position(position);
 | 
			
		||||
            }
 | 
			
		||||
        
 | 
			
		||||
                                                                                    public fullconv3d_grad() { super((Pointer)null); allocate(); }
 | 
			
		||||
                                                                                    private native void allocate();
 | 
			
		||||
                                                                                    public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
 | 
			
		||||
                                                                                }
 | 
			
		||||
//         #endif
 | 
			
		||||
 | 
			
		||||
        /**
 | 
			
		||||
         * This op implements im2col algorithm, widely used in convolution neural networks
 | 
			
		||||
         * Input: 4D input expected
 | 
			
		||||
 | 
			
		||||
@ -1124,7 +1124,7 @@ public class LayerOpValidation extends BaseOpValidation {
 | 
			
		||||
        assertNull(err, err);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(expected = IllegalStateException.class)
 | 
			
		||||
    @Test(expected = IllegalArgumentException.class)
 | 
			
		||||
    public void exceptionThrown_WhenConv1DConfigInvalid() {
 | 
			
		||||
        int nIn = 3;
 | 
			
		||||
        int nOut = 4;
 | 
			
		||||
@ -1150,7 +1150,7 @@ public class LayerOpValidation extends BaseOpValidation {
 | 
			
		||||
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(expected = IllegalStateException.class)
 | 
			
		||||
    @Test(expected = IllegalArgumentException.class)
 | 
			
		||||
    public void exceptionThrown_WhenConv2DConfigInvalid() {
 | 
			
		||||
 | 
			
		||||
        Nd4j.getRandom().setSeed(12345);
 | 
			
		||||
@ -1171,7 +1171,7 @@ public class LayerOpValidation extends BaseOpValidation {
 | 
			
		||||
                .build());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(expected = IllegalStateException.class)
 | 
			
		||||
    @Test(expected = IllegalArgumentException.class)
 | 
			
		||||
    public void exceptionThrown_WhenConf3DInvalid() {
 | 
			
		||||
        Nd4j.getRandom().setSeed(12345);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,515 @@
 | 
			
		||||
/*
 | 
			
		||||
 * Copyright (c) 2015-2019 Skymind, Inc.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
package org.nd4j.autodiff.samediff;
 | 
			
		||||
 | 
			
		||||
import static org.junit.Assert.assertThat;
 | 
			
		||||
import static org.junit.Assert.assertTrue;
 | 
			
		||||
import static org.junit.Assert.fail;
 | 
			
		||||
 | 
			
		||||
import org.junit.Assert;
 | 
			
		||||
import org.junit.Test;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
 | 
			
		||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
 | 
			
		||||
 | 
			
		||||
public class ConvConfigTests {
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void testDeConv2D(){
 | 
			
		||||
        DeConv2DConfig.builder().kH(2).kW(4).build();
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            DeConv2DConfig.builder().kW(4).kH(0).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Kernel height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            DeConv2DConfig.builder().kH(4).kW(0).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Kernel width"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            DeConv2DConfig.builder().kH(4).kW(3).sH(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Stride height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            DeConv2DConfig.builder().kH(4).kW(3).sW(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Stride width"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            DeConv2DConfig.builder().kH(4).kW(3).pH(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Padding height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            DeConv2DConfig.builder().kH(4).kW(3).pW(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Padding width"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            DeConv2DConfig.builder().kH(4).kW(3).dH(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Dilation height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            DeConv2DConfig.builder().kH(4).kW(3).dW(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Dilation width"));
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void testConv2D(){
 | 
			
		||||
        Conv2DConfig.builder().kH(2).kW(4).build();
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv2DConfig.builder().kW(4).kH(0).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Kernel height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv2DConfig.builder().kH(4).kW(0).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Kernel width"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv2DConfig.builder().kH(4).kW(3).sH(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Stride height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv2DConfig.builder().kH(4).kW(3).sW(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Stride width"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv2DConfig.builder().kH(4).kW(3).pH(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Padding height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv2DConfig.builder().kH(4).kW(3).pW(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Padding width"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv2DConfig.builder().kH(4).kW(3).dH(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Dilation height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv2DConfig.builder().kH(4).kW(3).dW(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Dilation width"));
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void testPooling2D(){
 | 
			
		||||
        Pooling2DConfig.builder().kH(2).kW(4).build();
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Pooling2DConfig.builder().kW(4).kH(0).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Kernel height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Pooling2DConfig.builder().kH(4).kW(0).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Kernel width"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Pooling2DConfig.builder().kH(4).kW(3).sH(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Stride height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Pooling2DConfig.builder().kH(4).kW(3).sW(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Stride width"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Pooling2DConfig.builder().kH(4).kW(3).pH(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Padding height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Pooling2DConfig.builder().kH(4).kW(3).pW(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Padding width"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Pooling2DConfig.builder().kH(4).kW(3).dH(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Dilation height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Pooling2DConfig.builder().kH(4).kW(3).dW(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Dilation width"));
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void testDeConv3D(){
 | 
			
		||||
        DeConv3DConfig.builder().kH(2).kW(4).kD(3).build();
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            DeConv3DConfig.builder().kW(4).kD(3).kH(0).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Kernel height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            DeConv3DConfig.builder().kH(4).kD(3).kW(0).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Kernel width"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            DeConv3DConfig.builder().kH(4).kW(3).kD(0).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Kernel depth"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            DeConv3DConfig.builder().kH(4).kW(3).kD(3).sH(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Stride height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            DeConv3DConfig.builder().kH(4).kW(3).kD(3).sW(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Stride width"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            DeConv3DConfig.builder().kH(4).kW(3).kD(3).sD(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Stride depth"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            DeConv3DConfig.builder().kH(4).kW(3).kD(3).pH(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Padding height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            DeConv3DConfig.builder().kH(4).kW(3).kD(3).pW(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Padding width"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            DeConv3DConfig.builder().kH(4).kW(3).kD(3).pD(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Padding depth"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            DeConv3DConfig.builder().kH(4).kW(3).kD(3).dH(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Dilation height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            DeConv3DConfig.builder().kH(4).kW(3).kD(3).dW(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Dilation width"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            DeConv3DConfig.builder().kH(4).kW(3).kD(3).dD(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Dilation depth"));
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void testConv3D(){
 | 
			
		||||
        Conv3DConfig.builder().kH(2).kW(4).kD(3).build();
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv3DConfig.builder().kW(4).kD(3).kH(0).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Kernel height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv3DConfig.builder().kH(4).kD(3).kW(0).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Kernel width"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv3DConfig.builder().kH(4).kW(3).kD(0).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Kernel depth"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv3DConfig.builder().kH(4).kW(3).kD(3).sH(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Stride height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv3DConfig.builder().kH(4).kW(3).kD(3).sW(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Stride width"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv3DConfig.builder().kH(4).kW(3).kD(3).sD(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Stride depth"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv3DConfig.builder().kH(4).kW(3).kD(3).pH(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Padding height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv3DConfig.builder().kH(4).kW(3).kD(3).pW(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Padding width"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv3DConfig.builder().kH(4).kW(3).kD(3).pD(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Padding depth"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv3DConfig.builder().kH(4).kW(3).kD(3).dH(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Dilation height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv3DConfig.builder().kH(4).kW(3).kD(3).dW(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Dilation width"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv3DConfig.builder().kH(4).kW(3).kD(3).dD(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Dilation depth"));
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void testPooling3D(){
 | 
			
		||||
        Pooling3DConfig.builder().kH(2).kW(4).kD(3).build();
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Pooling3DConfig.builder().kW(4).kD(3).kH(0).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Kernel height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Pooling3DConfig.builder().kH(4).kD(3).kW(0).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Kernel width"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Pooling3DConfig.builder().kH(4).kW(3).kD(0).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Kernel depth"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Pooling3DConfig.builder().kH(4).kW(3).kD(3).sH(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Stride height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Pooling3DConfig.builder().kH(4).kW(3).kD(3).sW(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Stride width"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Pooling3DConfig.builder().kH(4).kW(3).kD(3).sD(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Stride depth"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Pooling3DConfig.builder().kH(4).kW(3).kD(3).pH(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Padding height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Pooling3DConfig.builder().kH(4).kW(3).kD(3).pW(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Padding width"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Pooling3DConfig.builder().kH(4).kW(3).kD(3).pD(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Padding depth"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Pooling3DConfig.builder().kH(4).kW(3).kD(3).dH(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Dilation height"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Pooling3DConfig.builder().kH(4).kW(3).kD(3).dW(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Dilation width"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Pooling3DConfig.builder().kH(4).kW(3).kD(3).dD(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Dilation depth"));
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void testConv1D(){
 | 
			
		||||
        Conv1DConfig.builder().k(2).build();
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv1DConfig.builder().k(0).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Kernel"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv1DConfig.builder().k(4).s(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Stride"));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try{
 | 
			
		||||
            Conv1DConfig.builder().k(3).p(-2).build();
 | 
			
		||||
            fail();
 | 
			
		||||
        } catch (IllegalArgumentException e){
 | 
			
		||||
            assertTrue(e.getMessage().contains("Padding"));
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user