SameDiff Convolution Config validation, better output methods (#82)

* Conv Config validation & tests

Signed-off-by: Ryan Nett <rnett@skymind.io>

* stackOutputs utility method

Signed-off-by: Ryan Nett <rnett@skymind.io>

* use constructor for validation, support negative kernel sizes (infered from weights)

Signed-off-by: Ryan Nett <rnett@skymind.io>

* better output methods

Signed-off-by: Ryan Nett <rnett@skymind.io>

* move output to be with fit and evaluate

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* more fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>
master
Ryan Nett 2019-07-26 20:05:16 -07:00 committed by AlexDBlack
parent 8d1fe8b1b3
commit d4e7997134
27 changed files with 1085 additions and 992 deletions

View File

@ -1,444 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by raver119 on 08.10.2017.
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_fullconv3d)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/convolutions.h>
namespace nd4j {
namespace ops {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(fullconv3d, 5, 1, false, 0, 13) {
// auto input = INPUT_VARIABLE(0);
// auto weights = INPUT_VARIABLE(1);
// auto bias = INPUT_VARIABLE(2);
// auto columns = INPUT_VARIABLE(3);
// auto ones = INPUT_VARIABLE(4);
// REQUIRE_TRUE(weights->rankOf() == 5, 0, "Weights should be 5D, got %i instead", weights->rankOf());
// REQUIRE_TRUE(input->rankOf() == 5, 0, "Input should be 5D, got %i instead", input->rankOf());
// // strides
// int dT = INT_ARG(0);
// int dW = INT_ARG(1);
// int dH = INT_ARG(2);
// // padding
// int pT = INT_ARG(3);
// int pW = INT_ARG(4);
// int pH = INT_ARG(5);
// // dilation
// int dilationT = INT_ARG(6);
// int dilationW = INT_ARG(7);
// int dilationH = INT_ARG(8);
// // output padding
// int aT = INT_ARG(9);
// int aW = INT_ARG(10);
// int aH = INT_ARG(11);
// // bias
// bool biasUsed = INT_ARG(12) != 0;
// REQUIRE_TRUE(dT > 0 && dW > 0 && dH > 0, 11,
// "stride should be greater than zero, but got dT: %d dH: %d dW: %d", dT, dH, dW);
// REQUIRE_TRUE(dilationT > 0 && dilationW > 0 && dilationH > 0, 15,
// "dilation should be greater than zero, but got dilationT: %d, dilationH: %d, dilationW: %d",
// dilationT, dilationH, dilationW);
// REQUIRE_TRUE((aT < dT || aT < dilationT)
// && (aW < dW || aW < dilationW)
// && (aH < dH || aH < dilationH), 15,
// "output padding must be smaller than either stride or dilation,"
// " but got aT: %d aH: %d aW: %d dT: %d dH: %d dW: %d "
// "dilationT: %d dilationH: %d dilationW: %d",
// aT, aH, aW, dT, dH, dW, dilationT, dilationH, dilationW);
// auto output = this->getZ(block);
// const int nInputPlane = weights->shapeOf()[0];
// const int nOutputPlane = weights->shapeOf()[1];
// const int kT = weights->shapeOf()[2];
// const int kH = weights->shapeOf()[3];
// const int kW = weights->shapeOf()[4];
// const Nd4jLong inputWidth = input->shapeOf()[4];
// const Nd4jLong inputHeight = input->shapeOf()[3];
// const Nd4jLong inputDepth = input->shapeOf()[2];
// const Nd4jLong outputDepth = (inputDepth - 1) * dT - 2*pT + (dilationT * (kT - 1) + 1) + aT;
// const Nd4jLong outputHeight = (inputHeight - 1) * dH - 2*pH + (dilationH * (kH - 1) + 1) + aH;
// const Nd4jLong outputWidth = (inputWidth - 1) * dW - 2*pW + (dilationW * (kW - 1) + 1) + aW;
// const Nd4jLong batchSize = input->shapeOf()[0];
// REQUIRE_TRUE(output->isSameShape({ (int) batchSize, (int)nOutputPlane, (int)outputDepth, (int)outputHeight, (int)outputWidth}), 0, "Output should have shape of [%i, %i, %i, %i, %i], but got [%i, %i, %i, %i, %i] instead", (int) batchSize, (int)nOutputPlane, (int)outputDepth, (int)outputHeight, (int)outputWidth, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3), output->sizeAt(4));
// std::unique_ptr<ResultSet> inputs(input->allExamples());
// std::unique_ptr<ResultSet> outputs(output->allExamples());
// for (int e = 0; e < batchSize; e++) {
// auto tadIn = inputs->at(e);
// auto tadOut = outputs->at(e);
// const int m = weights->shapeOf()[1] * weights->shapeOf()[2] * weights->shapeOf()[3] * weights->shapeOf()[4];
// const int n = columns->shapeOf()[1];
// const int k = weights->shapeOf()[0];
// // FIXME: mmul helper should be used here
// /*
// nd4j::blas::GEMM<T>::op('c', 'n', 't', m, n, k,
// 1.0,
// tadIn->getBuffer(), n,
// weights->getBuffer(), m,
// 0.0,
// columns->getBuffer(), n);
// */
// // ConvolutionUtils<T>::_col2vol(columns->getBuffer(),
// // nOutputPlane, outputDepth, outputHeight, outputWidth,
// // inputDepth, inputHeight, inputWidth,
// // kT, kH, kW,
// // pT, pH, pW,
// // dT, dH, dW,
// // dilationT, dilationH, dilationW,
// // tadOut->getBuffer());
// ConvolutionUtils::col2vol(*columns, *tadOut, dT, dH, dW, pT, pH, pW, dilationT, dilationH, dilationW);
// const int m_ = nOutputPlane;
// const int n_ = outputDepth * outputHeight * outputWidth;
// const int k_ = 1;
// if (biasUsed) {
// // FIXME: mmul helper should be used here
// /*
// nd4j::blas::GEMM<T>::op('c', 't', 'n', n_, m_, k_,
// 1.0,
// ones->getBuffer(), k_,
// bias->getBuffer(), k_,
// 1.0,
// tadOut->getBuffer(), n_);
// */
// }
// }
// STORE_RESULT(*output);
return Status::OK();
}
DECLARE_TYPES(fullconv3d) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
DECLARE_SHAPE_FN(fullconv3d) {
// auto input = inputShape->at(0);
// auto weights = inputShape->at(1);
// // strides
// int dT = INT_ARG(0);
// int dW = INT_ARG(1);
// int dH = INT_ARG(2);
// // padding
// int pT = INT_ARG(3);
// int pW = INT_ARG(4);
// int pH = INT_ARG(5);
// // dilation
// int dilationT = INT_ARG(6);
// int dilationW = INT_ARG(7);
// int dilationH = INT_ARG(8);
// // output padding
// int aT = INT_ARG(9);
// int aW = INT_ARG(10);
// int aH = INT_ARG(11);
// // bias
// bool biasUsed = INT_ARG(12) != 0;
// Nd4jLong *shapeOf;
// Nd4jLong *newShape;
// ALLOCATE(shapeOf, block.getWorkspace(), 5, Nd4jLong);
// ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(5), Nd4jLong);
// const int nInputPlane = weights[1];
// const int nOutputPlane = weights[2];
// const int kT = weights[3];
// const int kH = weights[4];
// const int kW = weights[5];
// const int batchSize = input[1];
// const Nd4jLong inputWidth = input[5];
// const Nd4jLong inputHeight = input[4];
// const Nd4jLong inputDepth = input[3];
// const Nd4jLong outputDepth = (inputDepth - 1) * dT - 2*pT + (dilationT * (kT - 1) + 1) + aT;
// const Nd4jLong outputHeight = (inputHeight - 1) * dH - 2*pH + (dilationH * (kH - 1) + 1) + aH;
// const Nd4jLong outputWidth = (inputWidth - 1) * dW - 2*pW + (dilationW * (kW - 1) + 1) + aW;
// nd4j::ArrayUtils::toLongPtr({(Nd4jLong) batchSize, (Nd4jLong)nOutputPlane, (Nd4jLong)outputDepth, (Nd4jLong)outputHeight, (Nd4jLong)outputWidth}, shapeOf);
// shape::shapeBuffer(5, shapeOf, newShape);
// RELEASE(shapeOf, block.getWorkspace());
// return SHAPELIST(newShape);
return SHAPELIST();
}
DECLARE_TYPES(fullconv3d_bp) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(fullconv3d_bp, 5, 1, false, 0, 13) {
// auto input = INPUT_VARIABLE(0);
// auto gradNext = INPUT_VARIABLE(1);
// auto weights = INPUT_VARIABLE(2);
// auto finput = INPUT_VARIABLE(3);
// // not used
// auto fgradInput = INPUT_VARIABLE(4);
// REQUIRE_TRUE(weights->rankOf() == 5, 0, "Weights should be 5D, got %i instead", weights->rankOf());
// REQUIRE_TRUE(input->rankOf() == 5, 0, "Input should be 5D, got %i instead", input->rankOf());
// auto output = OUTPUT_VARIABLE(0);
// int dT = INT_ARG(0);
// int dW = INT_ARG(1);
// int dH = INT_ARG(2);
// int pT = INT_ARG(3);
// int pW = INT_ARG(4);
// int pH = INT_ARG(5);
// int dilationT = INT_ARG(6);
// int dilationW = INT_ARG(7);
// int dilationH = INT_ARG(8);
// int aT = INT_ARG(9);
// int aW = INT_ARG(10);
// int aH = INT_ARG(11);
// bool biasUsed = INT_ARG(12) != 0;
// const int nInputPlane = (int)weights->shapeOf()[0];
// const int nOutputPlane = (int)weights->shapeOf()[1];
// const int kT = (int)weights->shapeOf()[2];
// const int kH = (int)weights->shapeOf()[3];
// const int kW = (int)weights->shapeOf()[4];
// const Nd4jLong inputWidth = input->shapeOf()[4];
// const Nd4jLong inputHeight = input->shapeOf()[3];
// const Nd4jLong inputDepth = input->shapeOf()[2];
// const Nd4jLong outputDepth = (inputDepth - 1) * dT - 2*pT + (dilationT * (kT - 1) + 1) + aT;
// const Nd4jLong outputHeight = (inputHeight - 1) * dH - 2*pH + (dilationH * (kH - 1) + 1) + aH;
// const Nd4jLong outputWidth = (inputWidth - 1) * dW - 2*pW + (dilationW * (kW - 1) + 1) + aW;
// const Nd4jLong batchSize = input->shapeOf()[0];
// REQUIRE_TRUE(output->isSameShape({(int) batchSize, (int) nInputPlane, (int) inputDepth, (int) inputHeight, (int) inputWidth}) ,0, "Output should have shape of [%i, %i, %i, %i, %i], but got [%i, %i, %i, %i, %i] instead", (int) batchSize, (int) nInputPlane, (int) inputDepth, (int) inputHeight, (int) inputWidth, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3), output->sizeAt(4));
// output->assign(0.0);
// // FIXME: non-inplace reshape!!!!
// NDArray *gradColumns;
// //auto gradColumns = finput->reshape('c', {nOutputPlane*kW*kH*kT, inputDepth*inputHeight*inputWidth });
// std::unique_ptr<ResultSet> tadsNext(gradNext->allExamples());
// std::unique_ptr<ResultSet> tadsOutput(output->allExamples());
// for (int e = 0; e < tadsNext->size(); e++) {
// auto tadNext = tadsNext->at(e);
// auto tadOutput = tadsOutput->at(e);
// // ConvolutionUtils<T>::_vol2col(
// // tadNext->getBuffer(),
// // nOutputPlane, outputDepth, outputHeight, outputWidth,
// // kT, kH, kW,
// // pT, pH, pW,
// // dT, dH, dW,
// // dilationT, dilationH, dilationW,
// // gradColumns->getBuffer());
// ConvolutionUtils::vol2col(*tadNext, *gradColumns, dT, dH, dW, pT, pH, pW, dilationT, dilationH, dilationW);
// const auto m = weights->shapeOf()[0];
// const auto n = gradColumns->shapeOf()[1];
// const auto k = weights->shapeOf()[1] * weights->shapeOf()[2] * weights->shapeOf()[3] * weights->shapeOf()[4];
// // FIXME: mmul helper should be used here
// /*
// nd4j::blas::GEMM<T>::op('f', 'n', 'n',
// n, m, k,
// 1.0f,
// gradColumns->getBuffer(), n,
// weights->getBuffer(), k,
// 0,
// tadOutput->getBuffer(), n
// );
// */
// }
// STORE_RESULT(*output);
// delete gradColumns;
return ND4J_STATUS_OK;
}
DECLARE_SHAPE_FN(fullconv3d_bp) {
// output shape equals to input shape, all out of sudden
// Nd4jLong* newShape;
// COPY_SHAPE(inputShape->at(0), newShape);
// return SHAPELIST(newShape);
return SHAPELIST();
}
DECLARE_TYPES(fullconv3d_grad) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(fullconv3d_grad, 4, 2, false, 1, 13) {
// auto input = INPUT_VARIABLE(0);
// auto epsilon = INPUT_VARIABLE(1);
// auto columns = INPUT_VARIABLE(2);
// auto ones = INPUT_VARIABLE(3);
// REQUIRE_TRUE(input->rankOf() == epsilon->rankOf(), 0, "Rank of input (%i) & epsilon (%i) should be equal", input->rankOf(), epsilon->rankOf());
// REQUIRE_TRUE(input->sizeAt(0) == epsilon->sizeAt(0), 1, "Batch size should be equal for input and epsilon");
// auto gradWeight = OUTPUT_VARIABLE(0);
// auto gradBias = OUTPUT_VARIABLE(1);
// REQUIRE_TRUE(gradBias->sizeAt(0) == gradWeight->sizeAt(1), 0, "Bias shape mismatch");
// int dT = INT_ARG(0);
// int dW = INT_ARG(1);
// int dH = INT_ARG(2);
// int pT = INT_ARG(3);
// int pW = INT_ARG(4);
// int pH = INT_ARG(5);
// int dilationT = INT_ARG(6);
// int dilationW = INT_ARG(7);
// int dilationH = INT_ARG(8);
// int aT = INT_ARG(9);
// int aW = INT_ARG(10);
// int aH = INT_ARG(11);
// bool biasUsed = INT_ARG(12) != 0;
// double scale = block.getTArguments()->at(0);
// int nInputPlane = (int)gradWeight->shapeOf()[0];
// int nOutputPlane = (int)gradWeight->shapeOf()[1];
// int kT = (int)gradWeight->shapeOf()[2];
// int kH = (int)gradWeight->shapeOf()[3];
// int kW = (int)gradWeight->shapeOf()[4];
// const Nd4jLong inputWidth = input->shapeOf()[4];
// const Nd4jLong inputHeight = input->shapeOf()[3];
// const Nd4jLong inputDepth = input->shapeOf()[2];
// const Nd4jLong outputDepth = (inputDepth - 1) * dT - 2*pT + (dilationT * (kT - 1) + 1) + aT;
// const Nd4jLong outputHeight = (inputHeight - 1) * dH - 2*pH + (dilationH * (kH - 1) + 1) + aH;
// const Nd4jLong outputWidth = (inputWidth - 1) * dW - 2*pW + (dilationW * (kW - 1) + 1) + aW;
// REQUIRE_TRUE(gradWeight->isContiguous(), 0, "gradWight should be continuous");
// REQUIRE_TRUE(gradBias->isContiguous(), 0, "gradBias should be continuous");
// REQUIRE_TRUE(ones->rankOf() == 3, 0, "Ones should have rank 3, got %i instead", ones->rankOf());
// REQUIRE_TRUE(ones->isSameShape({outputDepth, outputHeight, outputWidth}), 0, "");
// ones->assign(1.0);
// std::unique_ptr<ResultSet> tadsInput(input->allExamples());
// std::unique_ptr<ResultSet> tadsEpsilon(epsilon->allExamples());
// for (int e = 0; e < tadsInput->size(); e++) {
// auto tadInput = tadsInput->at(e);
// auto tadEpsilon = tadsEpsilon->at(e);
// // ConvolutionUtils<T>::_vol2col(
// // tadEpsilon->getBuffer(), nOutputPlane,
// // outputDepth, outputHeight, outputWidth,
// // kT, kH, kW,
// // pT, pH, pW,
// // dT, dH, dW,
// // dilationT, dilationH, dilationW,
// // columns->getBuffer()
// // );
// ConvolutionUtils::vol2col(*tadEpsilon, *columns, dT, dH, dW, pT, pH, pW, dilationT, dilationH, dilationW);
// const Nd4jLong n = columns->shapeOf()[0]; // nOutputPlane * kt * kh * kw
// const Nd4jLong m = tadInput->shapeOf()[0]; // nInputPlane
// const Nd4jLong k = columns->shapeOf()[1];
// // FIXME: mmul helper should be used here
// /**
// nd4j::blas::GEMM<T>::op('f', 't', 'n',
// n, m, k,
// scale,
// columns->getBuffer(), k,
// tadInput->getBuffer(), k,
// 1,
// gradWeight->getBuffer(), n);
// */
// const Nd4jLong m_ = nOutputPlane;
// const Nd4jLong k_ = outputDepth * outputHeight * outputWidth;
// if (gradBias) {
// // FIXME: mmul helper should be used here
// /*
// nd4j::blas::GEMV<T>::op('t',
// k_, m_,
// scale,
// tadEpsilon->getBuffer(), k_,
// ones->getBuffer(), 1, (T)1.0f,
// gradBias->getBuffer(), 1);
// */
// }
// }
// STORE_2_RESULTS(*gradWeight, *gradBias);
return Status::OK();
}
DECLARE_SHAPE_FN(fullconv3d_grad) {
// auto list = SHAPELIST();
// _grad ops MUST have output arrays provided
// return list;
return SHAPELIST();
}
}
}
#endif

View File

@ -187,12 +187,6 @@ namespace nd4j {
DECLARE_CUSTOM_OP(pnormpool2d_bp, 2, 1, false, 1, 10); DECLARE_CUSTOM_OP(pnormpool2d_bp, 2, 1, false, 1, 10);
#endif #endif
#if NOT_EXCLUDED(OP_fullconv3d)
DECLARE_CUSTOM_OP(fullconv3d, 5, 1, false, 0, 13);
DECLARE_CUSTOM_OP(fullconv3d_bp, 5, 1, false, 0, 13);
DECLARE_CUSTOM_OP(fullconv3d_grad, 4, 2, false, 1, 13);
#endif
/** /**
* This op implements im2col algorithm, widely used in convolution neural networks * This op implements im2col algorithm, widely used in convolution neural networks
* Input: 4D input expected * Input: 4D input expected

View File

@ -16,6 +16,9 @@
package org.nd4j.autodiff.samediff; package org.nd4j.autodiff.samediff;
import static org.nd4j.autodiff.util.TrainingUtils.getSingleOutput;
import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
import com.google.common.collect.HashBasedTable; import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table; import com.google.common.collect.Table;
import com.google.common.primitives.Ints; import com.google.common.primitives.Ints;
@ -73,6 +76,7 @@ import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JException;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException; import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4UnresolvedOutputVariables; import org.nd4j.linalg.exception.ND4UnresolvedOutputVariables;
@ -109,7 +113,7 @@ import org.tensorflow.framework.GraphDef;
* <p> * <p>
* That graph accumulates operations. * That graph accumulates operations.
* <p> * <p>
* In order to execute the graph, you run one of the execution methods, such as {@link #exec(Map, String...)} * In order to execute the graph, you run one of the execution methods, such as {@link #output(Map, String...)}
*/ */
@AllArgsConstructor @AllArgsConstructor
@Builder @Builder
@ -2262,7 +2266,7 @@ public class SameDiff extends SDBaseOps {
MultiDataSet ds = iterator.next(); MultiDataSet ds = iterator.next();
Map<String,INDArray> placeholderMap = toPlaceholderMap(ds); Map<String,INDArray> placeholderMap = toPlaceholderMap(ds);
Map<String,INDArray> m = exec(placeholderMap, reqVars); Map<String,INDArray> m = output(placeholderMap, reqVars);
for(Map.Entry<String,List<IEvaluation>> e : variableEvals.entrySet()){ for(Map.Entry<String,List<IEvaluation>> e : variableEvals.entrySet()){
INDArray prediction = m.get(e.getKey()); INDArray prediction = m.get(e.getKey());
@ -2288,7 +2292,15 @@ public class SameDiff extends SDBaseOps {
* @param outputs The variables to evaluate * @param outputs The variables to evaluate
*/ */
public Map<String, INDArray> output(DataSet dataSet, String... outputs){ public Map<String, INDArray> output(DataSet dataSet, String... outputs){
return output(new SingletonMultiDataSetIterator(dataSet.toMultiDataSet()), outputs).get(0); return outputBatches(new SingletonMultiDataSetIterator(dataSet.toMultiDataSet()), outputs).get(0);
}
/**
* Single output inference.
* See {@link #output(DataSet, String...)}
*/
public INDArray outputSingle(DataSet dataSet, String output){
return output(dataSet, output).get(output);
} }
/** /**
@ -2299,13 +2311,40 @@ public class SameDiff extends SDBaseOps {
* sameDiff.output(iterator, "softmax");} * sameDiff.output(iterator, "softmax");}
* </pre> * </pre>
* *
* Uses concatenation on the outputs of {@link #outputBatches(DataSetIterator, String...)} which may cause issues with some inputs.
* RNNs with variable time series length and CNNs with variable image sizes will most likely have issues.
*
* @param iterator Iterator as source of data to evaluate * @param iterator Iterator as source of data to evaluate
* @param outputs The variables to evaluate * @param outputs The variables to evaluate
*/ */
public List<Map<String, INDArray>> output(DataSetIterator iterator, String... outputs){ public Map<String, INDArray> output(DataSetIterator iterator, String... outputs){
return output(new MultiDataSetIteratorAdapter(iterator), outputs); return output(new MultiDataSetIteratorAdapter(iterator), outputs);
} }
/**
* See {@link #output(DataSetIterator, String...)}, but without the concatenation of batches.
*
*/
public List<Map<String, INDArray>> outputBatches(DataSetIterator iterator, String... outputs){
return outputBatches(new MultiDataSetIteratorAdapter(iterator), outputs);
}
/**
* Single output inference.
* See {@link #output(DataSetIterator, String...)}
*/
public INDArray outputSingle(DataSetIterator dataSet, String output){
return output(dataSet, output).get(output);
}
/**
* Single batched output inference.
* See {@link #output(DataSetIterator, String...)}
*/
public List<INDArray> outputSingleBatches(DataSetIterator dataSet, String output){
return getSingleOutput(outputBatches(dataSet, output), output);
}
/** /**
* Perform inference.<br> * Perform inference.<br>
* <br> * <br>
@ -2321,10 +2360,20 @@ public class SameDiff extends SDBaseOps {
* } * }
* </pre> * </pre>
* *
* Uses concatenation on the outputs of {@link #outputBatches(MultiDataSetIterator, String...)} which may cause issues with some inputs.
* RNNs with variable time series length and CNNs with variable image sizes will most likely have issues.
*
* @param iterator The iterator - the source of the data for inference * @param iterator The iterator - the source of the data for inference
* @param outputs The set of outputs to report. If null, defaults to all outputs of this SameDiff. * @param outputs The set of outputs to report. If null, defaults to all outputs of this SameDiff.
*/ */
public List<Map<String, INDArray>> output(MultiDataSetIterator iterator, String... outputs){ public Map<String, INDArray> output(MultiDataSetIterator iterator, String... outputs){
return stackOutputs(outputBatches(iterator, outputs));
}
/**
* See {@link #output(MultiDataSetIterator, String...)}, but without the concatenation of batches.
*/
public List<Map<String, INDArray>> outputBatches(MultiDataSetIterator iterator, String... outputs){
Preconditions.checkState(trainingConfig != null, "Training config has not been set"); Preconditions.checkState(trainingConfig != null, "Training config has not been set");
List<String> reqVars; List<String> reqVars;
@ -2344,12 +2393,114 @@ public class SameDiff extends SDBaseOps {
MultiDataSet ds = iterator.next(); MultiDataSet ds = iterator.next();
Map<String,INDArray> placeholderMap = toPlaceholderMap(ds); Map<String,INDArray> placeholderMap = toPlaceholderMap(ds);
predictions.add(exec(placeholderMap, reqVars)); predictions.add(output(placeholderMap, reqVars));
} }
return predictions; return predictions;
} }
/**
* Single output inference.
* See {@link #output(MultiDataSetIterator, String...)}
*/
public INDArray outputSingle(MultiDataSetIterator dataSet, String output){
return output(dataSet, output).get(output);
}
/**
* Single batched output inference.
* See {@link #output(MultiDataSetIterator, String...)}
*/
public List<INDArray> outputSingleBatches(MultiDataSetIterator dataSet, String output){
return getSingleOutput(outputBatches(dataSet, output), output);
}
/**
* @deprecated See {@link #outputAll(Map)}
*/
@Deprecated
public Map<String,INDArray> execAll(Map<String,INDArray> placeholders){
return outputAll(placeholders);
}
/**
* Do inference for all variables for a single batch
*/
public Map<String,INDArray> outputAll(Map<String,INDArray> placeholders){
List<String> allVars = new ArrayList<>();
for(Variable v : variables.values()){
allVars.add(v.getName());
}
return output(placeholders, allVars.toArray(new String[0]));
}
/**
* @deprecated See {@link #outputSingle(Map, String)}
*/
@Deprecated
public INDArray execSingle(Map<String,INDArray> placeholders, String output){
return outputSingle(placeholders, output);
}
/**
* Do inference for a single variable for a single batch
*/
public INDArray outputSingle(Map<String,INDArray> placeholders, String output){
return output(placeholders, output).get(output);
}
/**
* @deprecated See {@link #output(Map, List)}
*/
@Deprecated
public Map<String,INDArray> exec(Map<String,INDArray> placeholders, List<String> outputs){
return output(placeholders, outputs);
}
/**
* Do inference for the given variables for a single batch
*/
public Map<String,INDArray> output(Map<String,INDArray> placeholders, List<String> outputs){
return output(placeholders, outputs.toArray(new String[outputs.size()]));
}
/**
* @deprecated See {@link #output(Map, String...)}
*/
@Deprecated
public Map<String,INDArray> exec(Map<String,INDArray> placeholders, String... outputs) {
return output(placeholders, outputs);
}
/**
* Do inference for the given variables for a single batch
*/
public Map<String,INDArray> output(Map<String,INDArray> placeholders, String... outputs) {
return output(placeholders, false, null, outputs);
}
/**
* Do inference for the given variables for a single batch, with training information
*/
protected Map<String,INDArray> output(Map<String,INDArray> placeholders, boolean training, At at, String... outputs){
Preconditions.checkState(outputs != null && outputs.length > 0, "No outputs were specified");
long threadId = Thread.currentThread().getId();
if(!sessions.containsKey(threadId)){
log.info("Creating new InferenceSession for thread {}", threadId);
sessions.put(threadId, new InferenceSession(this));
}
List<String> phNames = inputs();
if(placeholders == null && phNames != null){
//Maybe user set placeholders before calling exec method?
placeholders = placeholdersPerThread.get(Thread.currentThread().getId());
}
//Placeholder validation is performed in InferenceSession
InferenceSession is = sessions.get(threadId);
Map<String,INDArray> ret = is.output(Arrays.asList(outputs), placeholders, listeners, training, at);
return ret;
}
public SDVariable one(String name, int... shape){ public SDVariable one(String name, int... shape){
return one(name, Nd4j.defaultFloatingPointType(), shape); return one(name, Nd4j.defaultFloatingPointType(), shape);
@ -3779,7 +3930,7 @@ public class SameDiff extends SDBaseOps {
} }
//TODO is this 'train' flag the best approach? //TODO is this 'train' flag the best approach?
sd.exec(placeholders, trainingConfig != null, at, variableGradNamesList.toArray(new String[variableGradNamesList.size()])); sd.output(placeholders, trainingConfig != null, at, variableGradNamesList.toArray(new String[variableGradNamesList.size()]));
} }
/** /**
@ -4459,47 +4610,6 @@ public class SameDiff extends SDBaseOps {
} }
} }
public Map<String,INDArray> execAll(Map<String,INDArray> placeholders){
List<String> allVars = new ArrayList<>();
for(Variable v : variables.values()){
allVars.add(v.getName());
}
return exec(placeholders, allVars.toArray(new String[allVars.size()]));
}
public INDArray execSingle(Map<String,INDArray> placeholders, String output){
return exec(placeholders, output).get(output);
}
public Map<String,INDArray> exec(Map<String,INDArray> placeholders, List<String> outputs){
return exec(placeholders, outputs.toArray(new String[outputs.size()]));
}
public Map<String,INDArray> exec(Map<String,INDArray> placeholders, String... outputs) {
return exec(placeholders, false, null, outputs);
}
protected Map<String,INDArray> exec(Map<String,INDArray> placeholders, boolean training, At at, String... outputs){
Preconditions.checkState(outputs != null && outputs.length > 0, "No outputs were specified");
long threadId = Thread.currentThread().getId();
if(!sessions.containsKey(threadId)){
log.info("Creating new InferenceSession for thread {}", threadId);
sessions.put(threadId, new InferenceSession(this));
}
List<String> phNames = inputs();
if(placeholders == null && phNames != null){
//Maybe user set placeholders before calling exec method?
placeholders = placeholdersPerThread.get(Thread.currentThread().getId());
}
//Placeholder validation is performed in InferenceSession
InferenceSession is = sessions.get(threadId);
Map<String,INDArray> ret = is.output(Arrays.asList(outputs), placeholders, listeners, training, at);
return ret;
}
protected int asFlatNode(String name, @NonNull SameDiff scope, @NonNull FlatBufferBuilder bufferBuilder) { protected int asFlatNode(String name, @NonNull SameDiff scope, @NonNull FlatBufferBuilder bufferBuilder) {
int scopeName = bufferBuilder.createString(name); int scopeName = bufferBuilder.createString(name);

View File

@ -0,0 +1,70 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.autodiff.util;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JException;
import org.nd4j.linalg.factory.Nd4j;
/**
* Utilities for SameDiff training and inference
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class TrainingUtils {
/**
* Stack batch outputs, like an output from {@link org.nd4j.autodiff.samediff.SameDiff#output(MultiDataSetIterator, String...)}
*/
public static Map<String, INDArray> stackOutputs(List<Map<String, INDArray>> outputs){
Map<String, List<INDArray>> outs = new HashMap<>();
for(Map<String, INDArray> batch : outputs){
for(String k : batch.keySet()){
if(!outs.containsKey(k))
outs.put(k, new ArrayList<INDArray>());
outs.get(k).add(batch.get(k));
}
}
Map<String, INDArray> ret = new HashMap<>();
for(String k : outs.keySet()){
try {
ret.put(k, Nd4j.concat(0, outs.get(k).toArray(new INDArray[0])));
} catch(Exception e){
throw new ND4JException("Error concatenating batch outputs", e);
}
}
return ret;
}
/**
* Get a list of batch outputs for a single variable from a list of batch outputs for all variables
*/
public static List<INDArray> getSingleOutput(List<Map<String, INDArray>> outputs, String output){
List<INDArray> batches = new ArrayList<>();
for(Map<String, INDArray> batch : outputs)
batches.add(batch.get(output));
return batches;
}
}

View File

@ -915,7 +915,6 @@ public class OpValidation {
Conv2DDerivative.class, Conv2DDerivative.class,
Conv3DDerivative.class, Conv3DDerivative.class,
DeConv2DDerivative.class, DeConv2DDerivative.class,
FullConv3DDerivative.class,
LocalResponseNormalizationDerivative.class, LocalResponseNormalizationDerivative.class,
Pooling2DDerivative.class, Pooling2DDerivative.class,
Pooling3DDerivative.class, Pooling3DDerivative.class,

View File

@ -72,7 +72,6 @@ public class DifferentialFunctionClassHolder {
add(AvgPooling2D.class.getName()); add(AvgPooling2D.class.getName());
add(Conv2D.class.getName()); add(Conv2D.class.getName());
add(Conv3D.class.getName()); add(Conv3D.class.getName());
add(FullConv3D.class.getName());
add(LocalResponseNormalization.class.getName()); add(LocalResponseNormalization.class.getName());
add(MaxPooling2D.class.getName()); add(MaxPooling2D.class.getName());
add(Pooling2D.class.getName()); add(Pooling2D.class.getName());

View File

@ -117,8 +117,6 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DDerivative.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace.class, org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D.class, org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.FullConv3D.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.FullConv3DDerivative.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Im2colBp.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Im2colBp.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.LegacyPooling2D.class, org.nd4j.linalg.api.ops.impl.layers.convolution.LegacyPooling2D.class,

View File

@ -251,8 +251,6 @@ public class AvgPooling2D extends DynamicCustomOp {
.kW(kW) .kW(kW)
.pH(pH) .pH(pH)
.pW(pW) .pW(pW)
.virtualHeight(1)
.virtualWidth(1)
.isNHWC(data_format.equalsIgnoreCase("nhwc")) .isNHWC(data_format.equalsIgnoreCase("nhwc"))
.extra(0.0) // averaging only for non-padded values .extra(0.0) // averaging only for non-padded values
.build(); .build();
@ -277,8 +275,6 @@ public class AvgPooling2D extends DynamicCustomOp {
.kW(kernelShape.get(1).intValue()) .kW(kernelShape.get(1).intValue())
.pH(padding.get(0).intValue()) .pH(padding.get(0).intValue())
.pW(padding.size() < 2 ? padding.get(0).intValue() : padding.get(1).intValue()) .pW(padding.size() < 2 ? padding.get(0).intValue() : padding.get(1).intValue())
.virtualWidth(1)
.virtualHeight(1)
.build(); .build();
this.config = pooling2DConfig; this.config = pooling2DConfig;
addArgs(); addArgs();

View File

@ -1,228 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.descriptors.properties.adapters.IntArrayIntIndexAdpater;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.FullConv3DConfig;
import org.nd4j.linalg.util.ArrayUtil;
import java.lang.reflect.Field;
import java.util.*;
/**
* FullConv3D operation
*/
@Slf4j
public class FullConv3D extends DynamicCustomOp {
protected FullConv3DConfig config;
@Builder(builderMethodName = "builder")
public FullConv3D(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, FullConv3DConfig config) {
super(null,sameDiff, inputFunctions, false);
this.config = config;
if(inputs != null) {
addInputArgument(inputs);
}
if(outputs != null) {
addOutputArgument(outputs);
}
addArgs();
}
public FullConv3D() {}
@Override
public Map<String, Object> propertiesForFunction() {
return config.toProperties();
}
@Override
public long[] iArgs() {
if (iArguments.size() == 0)
addArgs();
return super.iArgs();
}
@Override
public boolean isConfigProperties() {
return true;
}
@Override
public String configFieldName() {
return "config";
}
@Override
public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
Map<String,Map<String,AttributeAdapter>> ret = new LinkedHashMap<>();
Map<String,AttributeAdapter> tfAdapters = new LinkedHashMap<>();
tfAdapters.put("dT", new IntArrayIntIndexAdpater(1));
tfAdapters.put("dW", new IntArrayIntIndexAdpater(2));
tfAdapters.put("dH",new IntArrayIntIndexAdpater(3));
tfAdapters.put("pT", new IntArrayIntIndexAdpater(1));
tfAdapters.put("pW", new IntArrayIntIndexAdpater(2));
tfAdapters.put("pH",new IntArrayIntIndexAdpater(3));
ret.put(tensorflowName(),tfAdapters);
return ret;
}
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
Map<String,Map<String,PropertyMapping>> ret = new HashMap<>();
Map<String,PropertyMapping> map = new HashMap<>();
val strideMapping = PropertyMapping.builder()
.tfAttrName("strides")
.onnxAttrName("strides")
.propertyNames(new String[]{"dT","dW","dH"})
.build();
val dilationMapping = PropertyMapping.builder()
.onnxAttrName("dilations")
.propertyNames(new String[]{"dD","dH","dW"})
.tfAttrName("rates")
.build();
val sameMode = PropertyMapping.builder()
.onnxAttrName("auto_pad")
.propertyNames(new String[]{"isSameMode"})
.tfAttrName("padding")
.build();
val paddingWidthHeight = PropertyMapping.builder()
.onnxAttrName("padding")
.propertyNames(new String[]{"pT","pW","pH"})
.build();
val dataFormat = PropertyMapping.builder()
.onnxAttrName("data_format")
.tfAttrName("data_format")
.propertyNames(new String[]{"dataFormat"})
.build();
val outputPadding = PropertyMapping.builder()
.propertyNames(new String[]{"aT","aH","aW"})
.build();
val biasUsed = PropertyMapping.builder()
.propertyNames(new String[]{"biasUsed"})
.build();
for(val propertyMapping : new PropertyMapping[] {
strideMapping,
dilationMapping,
sameMode,
paddingWidthHeight,
dataFormat,
outputPadding,biasUsed}) {
for(val keys : propertyMapping.getPropertyNames())
map.put(keys,propertyMapping);
}
ret.put(onnxName(),map);
ret.put(tensorflowName(),map);
return ret;
}
private void addArgs() {
addIArgument(new long[]{
config.getDT(),
config.getDW(),
config.getDH(),
config.getPT(),
config.getPW(),
config.getPH(),
config.getDilationT(),
config.getDilationW(),
config.getDilationH(),
config.getAT(),
config.getAW(),
config.getAH(),
ArrayUtil.fromBoolean(config.isBiasUsed())});
}
@Override
public String opName() {
return "fullconv3d";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
List<SDVariable> inputs = new ArrayList<>();
inputs.addAll(Arrays.asList(args()));
inputs.addAll(f1);
List<SDVariable> ret = new ArrayList<>();
FullConv3DDerivative fullConv3DDerivative = FullConv3DDerivative.derivativeBuilder()
.conv3DConfig(config)
.sameDiff(sameDiff)
.inputFunctions(inputs.toArray(new SDVariable[inputs.size()]))
.build();
ret.addAll(Arrays.asList(fullConv3DDerivative.outputVariables()));
return ret;
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -1,77 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.FullConv3DConfig;
import java.util.ArrayList;
import java.util.List;
/**
* FullConv3DDerivative operation
*/
@Slf4j
public class FullConv3DDerivative extends FullConv3D {
public FullConv3DDerivative() {}
@Builder(builderMethodName = "derivativeBuilder")
public FullConv3DDerivative(SameDiff sameDiff, SDVariable[] inputFunctions, INDArray[] inputs, INDArray[] outputs, FullConv3DConfig conv3DConfig) {
super(sameDiff, inputFunctions, inputs, outputs, conv3DConfig);
}
@Override
public String opName() {
return "fullconv3d_bp";
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflwo op name found for conv3d derivative");
}
@Override
public String[] tensorflowNames() {
throw new NoOpNameFoundException("No tensorflwo op name found for conv3d derivative");
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Unable to take derivative of derivative.");
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length; //Original inputs + gradient at
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
List<DataType> out = new ArrayList<>(n-1);
for( int i=0; i<n-1; i++ ){
out.add(inputDataTypes.get(i));
}
return out;
}
}

View File

@ -204,8 +204,6 @@ public class MaxPooling2D extends DynamicCustomOp {
.kW(kW) .kW(kW)
.pH(pH) .pH(pH)
.pW(pW) .pW(pW)
.virtualHeight(1)
.virtualWidth(1)
.isNHWC(data_format.equalsIgnoreCase("nhwc")) .isNHWC(data_format.equalsIgnoreCase("nhwc"))
.extra(1.0) // averaging only for non-padded values .extra(1.0) // averaging only for non-padded values
.build(); .build();
@ -230,8 +228,6 @@ public class MaxPooling2D extends DynamicCustomOp {
.kW(kernelShape.size() < 2 ? kernelShape.get(0).intValue() : kernelShape.get(1).intValue()) .kW(kernelShape.size() < 2 ? kernelShape.get(0).intValue() : kernelShape.get(1).intValue())
.pH(padding.get(0).intValue()) .pH(padding.get(0).intValue())
.pW(padding.size() < 2 ? padding.get(0).intValue() : padding.get(1).intValue()) .pW(padding.size() < 2 ? padding.get(0).intValue() : padding.get(1).intValue())
.virtualHeight(1)
.virtualWidth(1)
.build(); .build();
this.config = pooling2DConfig; this.config = pooling2DConfig;
addArgs(); addArgs();

View File

@ -174,8 +174,6 @@ public class Pooling2D extends DynamicCustomOp {
.kW(kW.intValue()) .kW(kW.intValue())
.pH(padding.get(0).intValue()) .pH(padding.get(0).intValue())
.pW(padding.get(1).intValue()) .pW(padding.get(1).intValue())
.virtualWidth(1)
.virtualHeight(1)
.build(); .build();
this.config = pooling2DConfig; this.config = pooling2DConfig;
addArgs(); addArgs();
@ -200,8 +198,6 @@ public class Pooling2D extends DynamicCustomOp {
.kW(kernelShape.get(1).intValue()) .kW(kernelShape.get(1).intValue())
.pH(padding.get(0).intValue()) .pH(padding.get(0).intValue())
.pW(padding.get(1).intValue()) .pW(padding.get(1).intValue())
.virtualHeight(1)
.virtualWidth(1)
.build(); .build();
this.config = pooling2DConfig; this.config = pooling2DConfig;
addArgs(); addArgs();

View File

@ -16,6 +16,8 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution.config; package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
import java.util.LinkedHashMap;
import java.util.Map;
import lombok.val; import lombok.val;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
@ -23,6 +25,8 @@ import java.lang.reflect.Field;
public abstract class BaseConvolutionConfig { public abstract class BaseConvolutionConfig {
public abstract Map<String, Object> toProperties();
/** /**
* Get the value for a given property * Get the value for a given property
* for this function * for this function
@ -154,4 +158,5 @@ public abstract class BaseConvolutionConfig {
} }
protected abstract void validate();
} }

View File

@ -16,15 +16,16 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution.config; package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
import lombok.*;
import org.nd4j.base.Preconditions;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.util.ConvConfigUtil;
@Builder
@Data @Data
@AllArgsConstructor @Builder
@NoArgsConstructor @NoArgsConstructor
public class Conv1DConfig extends BaseConvolutionConfig { public class Conv1DConfig extends BaseConvolutionConfig {
public static final String NCW = "NCW"; public static final String NCW = "NCW";
@ -40,6 +41,16 @@ public class Conv1DConfig extends BaseConvolutionConfig {
private String dataFormat = NCW; private String dataFormat = NCW;
private boolean isSameMode; private boolean isSameMode;
public Conv1DConfig(long k, long s, long p, String dataFormat, boolean isSameMode) {
this.k = k;
this.s = s;
this.p = p;
this.dataFormat = dataFormat;
this.isSameMode = isSameMode;
validate();
}
public boolean isNWC(){ public boolean isNWC(){
Preconditions.checkState(dataFormat.equalsIgnoreCase(NCW) || dataFormat.equalsIgnoreCase(NWC), Preconditions.checkState(dataFormat.equalsIgnoreCase(NCW) || dataFormat.equalsIgnoreCase(NWC),
"Data format must be one of %s or %s, got %s", NCW, NWC, dataFormat); "Data format must be one of %s or %s, got %s", NCW, NWC, dataFormat);
@ -54,6 +65,7 @@ public class Conv1DConfig extends BaseConvolutionConfig {
} }
} }
@Override
public Map<String, Object> toProperties() { public Map<String, Object> toProperties() {
Map<String, Object> ret = new LinkedHashMap<>(); Map<String, Object> ret = new LinkedHashMap<>();
ret.put("k", k); ret.put("k", k);
@ -64,5 +76,11 @@ public class Conv1DConfig extends BaseConvolutionConfig {
return ret; return ret;
} }
@Override
protected void validate() {
ConvConfigUtil.validate1D(k, s, p);
Preconditions.checkArgument(dataFormat != null, "Data format can't be null");
}
} }

View File

@ -16,18 +16,16 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution.config; package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
import lombok.AllArgsConstructor; import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.util.ConvConfigUtil;
import java.util.LinkedHashMap;
import java.util.Map;
@Builder
@Data @Data
@AllArgsConstructor @Builder
@NoArgsConstructor @NoArgsConstructor
public class Conv2DConfig extends BaseConvolutionConfig { public class Conv2DConfig extends BaseConvolutionConfig {
public static final String NCHW = "NCHW"; public static final String NCHW = "NCHW";
@ -53,6 +51,23 @@ public class Conv2DConfig extends BaseConvolutionConfig {
@Builder.Default @Builder.Default
private String dataFormat = NCHW; private String dataFormat = NCHW;
public Conv2DConfig(long kH, long kW, long sH, long sW, long pH, long pW, long dH, long dW, boolean isSameMode,
String dataFormat) {
this.kH = kH;
this.kW = kW;
this.sH = sH;
this.sW = sW;
this.pH = pH;
this.pW = pW;
this.dH = dH;
this.dW = dW;
this.isSameMode = isSameMode;
this.dataFormat = dataFormat;
validate();
}
public boolean isNHWC(){ public boolean isNHWC(){
Preconditions.checkState(dataFormat.equalsIgnoreCase(NCHW) || dataFormat.equalsIgnoreCase(NHWC), Preconditions.checkState(dataFormat.equalsIgnoreCase(NCHW) || dataFormat.equalsIgnoreCase(NHWC),
"Data format must be one of %s or %s, got %s", NCHW, NHWC, dataFormat); "Data format must be one of %s or %s, got %s", NCHW, NHWC, dataFormat);
@ -67,6 +82,7 @@ public class Conv2DConfig extends BaseConvolutionConfig {
} }
} }
@Override
public Map<String, Object> toProperties() { public Map<String, Object> toProperties() {
Map<String, Object> ret = new LinkedHashMap<>(); Map<String, Object> ret = new LinkedHashMap<>();
ret.put("kH", kH); ret.put("kH", kH);
@ -82,5 +98,11 @@ public class Conv2DConfig extends BaseConvolutionConfig {
return ret; return ret;
} }
@Override
protected void validate() {
ConvConfigUtil.validate2D(kH, kW, sH, sW, pH, pW, dH, dW);
Preconditions.checkArgument(dataFormat != null, "Data format can't be null");
}
} }

View File

@ -17,30 +17,28 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution.config; package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
import lombok.AllArgsConstructor; import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.util.ConvConfigUtil;
import java.util.LinkedHashMap;
import java.util.Map;
@Data @Data
@Builder @Builder
@NoArgsConstructor @NoArgsConstructor
@AllArgsConstructor
public class Conv3DConfig extends BaseConvolutionConfig { public class Conv3DConfig extends BaseConvolutionConfig {
public static final String NDHWC = "NDHWC"; public static final String NDHWC = "NDHWC";
public static final String NCDHW = "NCDHW"; public static final String NCDHW = "NCDHW";
//kernel //kernel
@Builder.Default @Builder.Default
private long kD = 1; private long kD = -1;
@Builder.Default @Builder.Default
private long kW = 1; private long kW = -1;
@Builder.Default @Builder.Default
private long kH = 1; private long kH = -1;
//strides //strides
@Builder.Default @Builder.Default
@ -66,14 +64,6 @@ public class Conv3DConfig extends BaseConvolutionConfig {
@Builder.Default @Builder.Default
private long dH = 1; private long dH = 1;
//output padding
@Builder.Default
private long aD = 0;
@Builder.Default
private long aW = 0;
@Builder.Default
private long aH = 0;
@Builder.Default @Builder.Default
private boolean biasUsed = false; private boolean biasUsed = false;
private boolean isSameMode; private boolean isSameMode;
@ -81,6 +71,27 @@ public class Conv3DConfig extends BaseConvolutionConfig {
@Builder.Default @Builder.Default
private String dataFormat = NDHWC; private String dataFormat = NDHWC;
public Conv3DConfig(long kD, long kW, long kH, long sD, long sW, long sH, long pD, long pW, long pH, long dD,
long dW, long dH, boolean biasUsed, boolean isSameMode, String dataFormat) {
this.kD = kD;
this.kW = kW;
this.kH = kH;
this.sD = sD;
this.sW = sW;
this.sH = sH;
this.pD = pD;
this.pW = pW;
this.pH = pH;
this.dD = dD;
this.dW = dW;
this.dH = dH;
this.biasUsed = biasUsed;
this.isSameMode = isSameMode;
this.dataFormat = dataFormat;
validate();
}
public boolean isNCDHW(){ public boolean isNCDHW(){
Preconditions.checkState(dataFormat.equalsIgnoreCase(NCDHW) || dataFormat.equalsIgnoreCase(NDHWC), Preconditions.checkState(dataFormat.equalsIgnoreCase(NCDHW) || dataFormat.equalsIgnoreCase(NDHWC),
"Data format must be one of %s or %s, got %s", NCDHW, NDHWC, dataFormat); "Data format must be one of %s or %s, got %s", NCDHW, NDHWC, dataFormat);
@ -95,6 +106,7 @@ public class Conv3DConfig extends BaseConvolutionConfig {
} }
} }
@Override
public Map<String, Object> toProperties() { public Map<String, Object> toProperties() {
Map<String, Object> ret = new LinkedHashMap<>(); Map<String, Object> ret = new LinkedHashMap<>();
ret.put("kD", kD); ret.put("kD", kD);
@ -109,9 +121,6 @@ public class Conv3DConfig extends BaseConvolutionConfig {
ret.put("dD", dD); ret.put("dD", dD);
ret.put("dW", dW); ret.put("dW", dW);
ret.put("dH", dH); ret.put("dH", dH);
ret.put("aD", aD);
ret.put("aW", aW);
ret.put("aH", aH);
ret.put("biasUsed", biasUsed); ret.put("biasUsed", biasUsed);
ret.put("dataFormat", dataFormat); ret.put("dataFormat", dataFormat);
ret.put("isSameMode", isSameMode); ret.put("isSameMode", isSameMode);
@ -119,5 +128,11 @@ public class Conv3DConfig extends BaseConvolutionConfig {
return ret; return ret;
} }
@Override
protected void validate() {
ConvConfigUtil.validate3D(kH, kW, kD, sH, sW, sD, pH, pW, pD, dH, dW, dD);
Preconditions.checkArgument(dataFormat != null, "Data format can't be null");
}
} }

View File

@ -16,22 +16,22 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution.config; package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
import lombok.AllArgsConstructor; import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.util.ConvConfigUtil;
import java.util.LinkedHashMap;
import java.util.Map;
@Builder
@Data @Data
@AllArgsConstructor @Builder
@NoArgsConstructor @NoArgsConstructor
public class DeConv2DConfig extends BaseConvolutionConfig { public class DeConv2DConfig extends BaseConvolutionConfig {
public static final String NCHW = "NCHW"; public static final String NCHW = "NCHW";
public static final String NHWC = "NHWC"; public static final String NHWC = "NHWC";
@Builder.Default private long kH = -1L; @Builder.Default private long kH = -1L;
@Builder.Default private long kW = -1L; @Builder.Default private long kW = -1L;
@Builder.Default private long sH = 1L; @Builder.Default private long sH = 1L;
@ -43,8 +43,25 @@ public class DeConv2DConfig extends BaseConvolutionConfig {
@Builder.Default private boolean isSameMode = false; @Builder.Default private boolean isSameMode = false;
@Builder.Default private String dataFormat = NCHW; @Builder.Default private String dataFormat = NCHW;
public DeConv2DConfig(long kH, long kW, long sH, long sW, long pH, long pW, long dH, long dW, boolean isSameMode,
String dataFormat) {
this.kH = kH;
this.kW = kW;
this.sH = sH;
this.sW = sW;
this.pH = pH;
this.pW = pW;
this.dH = dH;
this.dW = dW;
this.isSameMode = isSameMode;
this.dataFormat = dataFormat;
validate();
}
@Override
public Map<String, Object> toProperties() { public Map<String, Object> toProperties() {
Map<String, Object> ret = new LinkedHashMap<>(); Map<String, Object> ret = new LinkedHashMap<>();
ret.put("kH", kH); ret.put("kH", kH);
ret.put("kW", kW); ret.put("kW", kW);
@ -58,4 +75,10 @@ public class DeConv2DConfig extends BaseConvolutionConfig {
ret.put("dataFormat", dataFormat); ret.put("dataFormat", dataFormat);
return ret; return ret;
} }
@Override
protected void validate() {
ConvConfigUtil.validate2D(kH, kW, sH, sW, pH, pW, dH, dW);
Preconditions.checkArgument(dataFormat != null, "Data format can't be null");
}
} }

View File

@ -16,17 +16,16 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution.config; package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
import lombok.AllArgsConstructor; import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.util.ConvConfigUtil;
import java.util.LinkedHashMap;
import java.util.Map;
@Builder
@Data @Data
@AllArgsConstructor @Builder
@NoArgsConstructor @NoArgsConstructor
public class DeConv3DConfig extends BaseConvolutionConfig { public class DeConv3DConfig extends BaseConvolutionConfig {
public static final String NCDHW = "NCDHW"; public static final String NCDHW = "NCDHW";
@ -47,7 +46,28 @@ public class DeConv3DConfig extends BaseConvolutionConfig {
@Builder.Default private boolean isSameMode = false; @Builder.Default private boolean isSameMode = false;
@Builder.Default private String dataFormat = NCDHW; @Builder.Default private String dataFormat = NCDHW;
public DeConv3DConfig(long kD, long kH, long kW, long sD, long sH, long sW, long pD, long pH, long pW, long dD,
long dH, long dW, boolean isSameMode, String dataFormat) {
this.kD = kD;
this.kH = kH;
this.kW = kW;
this.sD = sD;
this.sH = sH;
this.sW = sW;
this.pD = pD;
this.pH = pH;
this.pW = pW;
this.dD = dD;
this.dH = dH;
this.dW = dW;
this.isSameMode = isSameMode;
this.dataFormat = dataFormat;
validate();
}
@Override
public Map<String, Object> toProperties() { public Map<String, Object> toProperties() {
Map<String, Object> ret = new LinkedHashMap<>(); Map<String, Object> ret = new LinkedHashMap<>();
ret.put("kD", kD); ret.put("kD", kD);
@ -66,4 +86,10 @@ public class DeConv3DConfig extends BaseConvolutionConfig {
ret.put("dataFormat", dataFormat); ret.put("dataFormat", dataFormat);
return ret; return ret;
} }
@Override
protected void validate() {
ConvConfigUtil.validate3D(kH, kW, kD, sH, sW, sD, pH, pW, pD, dH, dW, dD);
Preconditions.checkArgument(dataFormat != null, "Data format can't be null");
}
} }

View File

@ -1,53 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
import lombok.Builder;
import lombok.Data;
import java.util.LinkedHashMap;
import java.util.Map;
@Builder
@Data
public class FullConv3DConfig extends BaseConvolutionConfig {
private long dT,dW,dH,pT,pW,pH,dilationT,dilationW,dilationH,aT,aW,aH;
private boolean biasUsed;
private String dataFormat;
public Map<String,Object> toProperties() {
Map<String,Object> ret = new LinkedHashMap<>();
ret.put("dT",dT);
ret.put("dW",dW);
ret.put("dH",dH);
ret.put("pT",pT);
ret.put("pW",pW);
ret.put("pH",pH);
ret.put("dD",dilationT);
ret.put("dW",dilationW);
ret.put("dH",dilationH);
ret.put("aT",aT);
ret.put("aW",aW);
ret.put("aH",aH);
ret.put("biasUsed",biasUsed);
ret.put("dataFormat",dataFormat);
return ret;
}
}

View File

@ -16,19 +16,31 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution.config; package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
import lombok.Builder;
import lombok.Data;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.util.ConvConfigUtil;
@Data @Data
@Builder @Builder
@NoArgsConstructor
public class LocalResponseNormalizationConfig extends BaseConvolutionConfig { public class LocalResponseNormalizationConfig extends BaseConvolutionConfig {
private double alpha, beta, bias; private double alpha, beta, bias;
private int depth; private int depth;
public LocalResponseNormalizationConfig(double alpha, double beta, double bias, int depth) {
this.alpha = alpha;
this.beta = beta;
this.bias = bias;
this.depth = depth;
validate();
}
@Override
public Map<String, Object> toProperties() { public Map<String, Object> toProperties() {
Map<String, Object> ret = new LinkedHashMap<>(); Map<String, Object> ret = new LinkedHashMap<>();
ret.put("alpha", alpha); ret.put("alpha", alpha);
@ -38,4 +50,9 @@ public class LocalResponseNormalizationConfig extends BaseConvolutionConfig {
return ret; return ret;
} }
@Override
protected void validate() {
ConvConfigUtil.validateLRN(alpha, beta, bias, depth);
}
} }

View File

@ -16,32 +16,32 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution.config; package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
import lombok.AllArgsConstructor; import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.Divisor;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.Pooling2DType;
import org.nd4j.linalg.util.ConvConfigUtil;
import java.util.LinkedHashMap;
import java.util.Map;
@Builder
@AllArgsConstructor
@Data @Data
@Builder
@NoArgsConstructor @NoArgsConstructor
public class Pooling2DConfig extends BaseConvolutionConfig { public class Pooling2DConfig extends BaseConvolutionConfig {
private long kH, kW; @Builder.Default private long kH = -1, kW = -1;
private long sH, sW; @Builder.Default private long sH = 1, sW = 1;
private long pH, pW; @Builder.Default private long pH = 0, pW = 0;
private long virtualHeight, virtualWidth;
/** /**
* Extra is an optional parameter mainly for use with pnorm right now. * Extra is an optional parameter mainly for use with pnorm right now.
* All pooling implementations take 9 parameters save pnorm. * All pooling implementations take 9 parameters save pnorm.
* Pnorm takes 10 and is cast to an int. * Pnorm takes 10 and is cast to an int.
*/ */
private double extra; private double extra;
private Pooling2D.Pooling2DType type; @Builder.Default
private Pooling2D.Pooling2DType type = Pooling2DType.MAX;
@Builder.Default @Builder.Default
private Pooling2D.Divisor divisor = Pooling2D.Divisor.EXCLUDE_PADDING; private Pooling2D.Divisor divisor = Pooling2D.Divisor.EXCLUDE_PADDING;
private boolean isSameMode; private boolean isSameMode;
@ -52,7 +52,26 @@ public class Pooling2DConfig extends BaseConvolutionConfig {
@Builder.Default @Builder.Default
private boolean isNHWC = false; private boolean isNHWC = false;
public Pooling2DConfig(long kH, long kW, long sH, long sW, long pH, long pW, double extra, Pooling2DType type,
Divisor divisor, boolean isSameMode, long dH, long dW, boolean isNHWC) {
this.kH = kH;
this.kW = kW;
this.sH = sH;
this.sW = sW;
this.pH = pH;
this.pW = pW;
this.extra = extra;
this.type = type;
this.divisor = divisor;
this.isSameMode = isSameMode;
this.dH = dH;
this.dW = dW;
this.isNHWC = isNHWC;
validate();
}
@Override
public Map<String, Object> toProperties() { public Map<String, Object> toProperties() {
Map<String, Object> ret = new LinkedHashMap<>(); Map<String, Object> ret = new LinkedHashMap<>();
ret.put("kH", kH); ret.put("kH", kH);
@ -61,8 +80,6 @@ public class Pooling2DConfig extends BaseConvolutionConfig {
ret.put("sW", sW); ret.put("sW", sW);
ret.put("pH", pH); ret.put("pH", pH);
ret.put("pW", pW); ret.put("pW", pW);
ret.put("virtualHeight", virtualHeight);
ret.put("virtualWidth", virtualWidth);
ret.put("extra", extra); ret.put("extra", extra);
ret.put("type", type.toString()); ret.put("type", type.toString());
ret.put("isSameMode", isSameMode); ret.put("isSameMode", isSameMode);
@ -72,4 +89,11 @@ public class Pooling2DConfig extends BaseConvolutionConfig {
return ret; return ret;
} }
@Override
protected void validate() {
ConvConfigUtil.validate2D(kH, kW, sH, sW, pH, pW, dH, dW);
//TODO check other args?
}
} }

View File

@ -16,23 +16,22 @@
package org.nd4j.linalg.api.ops.impl.layers.convolution.config; package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
import lombok.AllArgsConstructor; import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D.Pooling3DType;
import java.util.LinkedHashMap; import org.nd4j.linalg.util.ConvConfigUtil;
import java.util.Map;
@Data @Data
@Builder @Builder
@AllArgsConstructor
@NoArgsConstructor @NoArgsConstructor
public class Pooling3DConfig extends BaseConvolutionConfig { public class Pooling3DConfig extends BaseConvolutionConfig {
private long kD, kW, kH; // kernel @Builder.Default private long kD = -1, kW = -1, kH = -1; // kernel
private long sD, sW, sH; // strides @Builder.Default private long sD = 1, sW = 1, sH = 1; // strides
private long pD, pW, pH; // padding @Builder.Default private long pD = 0, pW = 0, pH = 0; // padding
// dilation // dilation
@Builder.Default @Builder.Default
private long dD = 1; private long dD = 1;
@ -40,10 +39,33 @@ public class Pooling3DConfig extends BaseConvolutionConfig {
private long dW = 1; private long dW = 1;
@Builder.Default @Builder.Default
private long dH = 1; private long dH = 1;
private Pooling3D.Pooling3DType type; @Builder.Default
private Pooling3D.Pooling3DType type = Pooling3DType.MAX;
private boolean isSameMode; private boolean isSameMode;
@Builder.Default private boolean isNCDHW = true; @Builder.Default private boolean isNCDHW = true;
public Pooling3DConfig(long kD, long kW, long kH, long sD, long sW, long sH, long pD, long pW, long pH, long dD,
long dW, long dH, Pooling3DType type, boolean isSameMode, boolean isNCDHW) {
this.kD = kD;
this.kW = kW;
this.kH = kH;
this.sD = sD;
this.sW = sW;
this.sH = sH;
this.pD = pD;
this.pW = pW;
this.pH = pH;
this.dD = dD;
this.dW = dW;
this.dH = dH;
this.type = type;
this.isSameMode = isSameMode;
this.isNCDHW = isNCDHW;
validate();
}
@Override
public Map<String, Object> toProperties() { public Map<String, Object> toProperties() {
Map<String, Object> ret = new LinkedHashMap<>(); Map<String, Object> ret = new LinkedHashMap<>();
ret.put("kD", kD); ret.put("kD", kD);
@ -63,4 +85,11 @@ public class Pooling3DConfig extends BaseConvolutionConfig {
return ret; return ret;
} }
@Override
protected void validate() {
ConvConfigUtil.validate3D(kH, kW, kD, sH, sW, sD, pH, pW, pD, dH, dW, dD);
//TODO check other args
}
} }

View File

@ -249,8 +249,6 @@ public class Convolution {
.isSameMode(isSameMode) .isSameMode(isSameMode)
.sH(sy) .sH(sy)
.sW(sx) .sW(sx)
.virtualHeight(virtualHeight)
.virtualWidth(virtualWidth)
.type(type) .type(type)
.divisor(divisor) .divisor(divisor)
.build()) .build())

View File

@ -0,0 +1,93 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.util;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.nd4j.base.Preconditions;
/**
* Class with utility methods for validating convolution op configurations like {@link org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig}
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class ConvConfigUtil {
/**
* Validate a 2D convolution's Kernel, Stride, Padding, and Dilation
*/
public static void validate2D(long kH, long kW, long sH, long sW, long pH, long pW, long dH, long dW){
Preconditions.checkArgument(kH != 0, "Kernel height can not be 0");
Preconditions.checkArgument(kW != 0, "Kernel width can not be 0");
Preconditions.checkArgument(sH > 0, "Stride height can not be negative or 0, got: %s", sH);
Preconditions.checkArgument(sW > 0, "Stride width can not be negative or 0, got: %s", sW);
Preconditions.checkArgument(pH >= 0, "Padding height can not be negative, got: %s", pH);
Preconditions.checkArgument(pW >= 0, "Padding width can not be negative, got: %s", pW);
Preconditions.checkArgument(dH > 0, "Dilation height can not be negative or 0, got: %s", dH);
Preconditions.checkArgument(dW > 0, "Dilation width can not be negative or 0, got: %s", dW);
}
/**
* Validate a 3D convolution's Kernel, Stride, Padding, and Dilation
*/
public static void validate3D(long kH, long kW, long kD, long sH, long sW, long sD, long pH, long pW, long pD, long dH, long dW, long dD){
Preconditions.checkArgument(kH != 0, "Kernel height can not be 0");
Preconditions.checkArgument(kW != 0, "Kernel width can not be 0");
Preconditions.checkArgument(kD != 0, "Kernel depth can not be 0");
Preconditions.checkArgument(sH > 0, "Stride height can not be negative or 0, got: %s", sH);
Preconditions.checkArgument(sW > 0, "Stride width can not be negative or 0, got: %s", sW);
Preconditions.checkArgument(sD > 0, "Stride depth can not be negative or 0, got: %s", sD);
Preconditions.checkArgument(pH >= 0, "Padding height can not be negative, got: %s", pH);
Preconditions.checkArgument(pW >= 0, "Padding width can not be negative, got: %s", pW);
Preconditions.checkArgument(pD >= 0, "Padding depth can not be negative, got: %s", pD);
Preconditions.checkArgument(dH > 0, "Dilation height can not be negative or 0, got: %s", dH);
Preconditions.checkArgument(dW > 0, "Dilation width can not be negative or 0, got: %s", dW);
Preconditions.checkArgument(dD > 0, "Dilation depth can not be negative or 0, got: %s", dD);
}
/**
* Validate a 3D convolution's Output Padding
*/
public static void validateExtra3D(long aH, long aW, long aD){
Preconditions.checkArgument(aH >= 0, "Output padding height can not be negative, got: %s", aH);
Preconditions.checkArgument(aW >= 0, "Output padding width can not be negative, got: %s", aW);
Preconditions.checkArgument(aD >= 0, "Output padding depth can not be negative, got: %s", aD);
}
/**
* Validate a 1D convolution's Kernel, Stride, and Padding
*/
public static void validate1D(long k, long s, long p){
Preconditions.checkArgument(k != 0, "Kernel can not be 0");
Preconditions.checkArgument(s > 0, "Stride can not be negative or 0, got: %s", s);
Preconditions.checkArgument(p >= 0, "Padding can not be negative, got: %s", p);
}
/**
* Validate a LocalResponseNormalizationConfig
*/
public static void validateLRN(double alpha, double beta, double bias, int depth) {
Preconditions.checkArgument(depth > 0, "Depth can not be 0 or negative, got: %s", depth);
}
}

View File

@ -14662,54 +14662,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
} }
// #endif // #endif
// #if NOT_EXCLUDED(OP_fullconv3d)
@Namespace("nd4j::ops") public static class fullconv3d extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public fullconv3d(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public fullconv3d(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public fullconv3d position(long position) {
return (fullconv3d)super.position(position);
}
public fullconv3d() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
@Namespace("nd4j::ops") public static class fullconv3d_bp extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public fullconv3d_bp(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public fullconv3d_bp(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public fullconv3d_bp position(long position) {
return (fullconv3d_bp)super.position(position);
}
public fullconv3d_bp() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
@Namespace("nd4j::ops") public static class fullconv3d_grad extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public fullconv3d_grad(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public fullconv3d_grad(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public fullconv3d_grad position(long position) {
return (fullconv3d_grad)super.position(position);
}
public fullconv3d_grad() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
/** /**
* This op implements im2col algorithm, widely used in convolution neural networks * This op implements im2col algorithm, widely used in convolution neural networks
* Input: 4D input expected * Input: 4D input expected

View File

@ -1124,7 +1124,7 @@ public class LayerOpValidation extends BaseOpValidation {
assertNull(err, err); assertNull(err, err);
} }
@Test(expected = IllegalStateException.class) @Test(expected = IllegalArgumentException.class)
public void exceptionThrown_WhenConv1DConfigInvalid() { public void exceptionThrown_WhenConv1DConfigInvalid() {
int nIn = 3; int nIn = 3;
int nOut = 4; int nOut = 4;
@ -1150,7 +1150,7 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test(expected = IllegalStateException.class) @Test(expected = IllegalArgumentException.class)
public void exceptionThrown_WhenConv2DConfigInvalid() { public void exceptionThrown_WhenConv2DConfigInvalid() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1171,7 +1171,7 @@ public class LayerOpValidation extends BaseOpValidation {
.build()); .build());
} }
@Test(expected = IllegalStateException.class) @Test(expected = IllegalArgumentException.class)
public void exceptionThrown_WhenConf3DInvalid() { public void exceptionThrown_WhenConf3DInvalid() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

View File

@ -0,0 +1,515 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.autodiff.samediff;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
public class ConvConfigTests {
@Test
public void testDeConv2D(){
DeConv2DConfig.builder().kH(2).kW(4).build();
try{
DeConv2DConfig.builder().kW(4).kH(0).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Kernel height"));
}
try{
DeConv2DConfig.builder().kH(4).kW(0).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Kernel width"));
}
try{
DeConv2DConfig.builder().kH(4).kW(3).sH(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Stride height"));
}
try{
DeConv2DConfig.builder().kH(4).kW(3).sW(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Stride width"));
}
try{
DeConv2DConfig.builder().kH(4).kW(3).pH(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Padding height"));
}
try{
DeConv2DConfig.builder().kH(4).kW(3).pW(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Padding width"));
}
try{
DeConv2DConfig.builder().kH(4).kW(3).dH(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Dilation height"));
}
try{
DeConv2DConfig.builder().kH(4).kW(3).dW(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Dilation width"));
}
}
@Test
public void testConv2D(){
Conv2DConfig.builder().kH(2).kW(4).build();
try{
Conv2DConfig.builder().kW(4).kH(0).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Kernel height"));
}
try{
Conv2DConfig.builder().kH(4).kW(0).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Kernel width"));
}
try{
Conv2DConfig.builder().kH(4).kW(3).sH(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Stride height"));
}
try{
Conv2DConfig.builder().kH(4).kW(3).sW(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Stride width"));
}
try{
Conv2DConfig.builder().kH(4).kW(3).pH(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Padding height"));
}
try{
Conv2DConfig.builder().kH(4).kW(3).pW(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Padding width"));
}
try{
Conv2DConfig.builder().kH(4).kW(3).dH(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Dilation height"));
}
try{
Conv2DConfig.builder().kH(4).kW(3).dW(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Dilation width"));
}
}
@Test
public void testPooling2D(){
Pooling2DConfig.builder().kH(2).kW(4).build();
try{
Pooling2DConfig.builder().kW(4).kH(0).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Kernel height"));
}
try{
Pooling2DConfig.builder().kH(4).kW(0).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Kernel width"));
}
try{
Pooling2DConfig.builder().kH(4).kW(3).sH(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Stride height"));
}
try{
Pooling2DConfig.builder().kH(4).kW(3).sW(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Stride width"));
}
try{
Pooling2DConfig.builder().kH(4).kW(3).pH(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Padding height"));
}
try{
Pooling2DConfig.builder().kH(4).kW(3).pW(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Padding width"));
}
try{
Pooling2DConfig.builder().kH(4).kW(3).dH(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Dilation height"));
}
try{
Pooling2DConfig.builder().kH(4).kW(3).dW(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Dilation width"));
}
}
@Test
public void testDeConv3D(){
DeConv3DConfig.builder().kH(2).kW(4).kD(3).build();
try{
DeConv3DConfig.builder().kW(4).kD(3).kH(0).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Kernel height"));
}
try{
DeConv3DConfig.builder().kH(4).kD(3).kW(0).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Kernel width"));
}
try{
DeConv3DConfig.builder().kH(4).kW(3).kD(0).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Kernel depth"));
}
try{
DeConv3DConfig.builder().kH(4).kW(3).kD(3).sH(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Stride height"));
}
try{
DeConv3DConfig.builder().kH(4).kW(3).kD(3).sW(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Stride width"));
}
try{
DeConv3DConfig.builder().kH(4).kW(3).kD(3).sD(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Stride depth"));
}
try{
DeConv3DConfig.builder().kH(4).kW(3).kD(3).pH(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Padding height"));
}
try{
DeConv3DConfig.builder().kH(4).kW(3).kD(3).pW(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Padding width"));
}
try{
DeConv3DConfig.builder().kH(4).kW(3).kD(3).pD(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Padding depth"));
}
try{
DeConv3DConfig.builder().kH(4).kW(3).kD(3).dH(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Dilation height"));
}
try{
DeConv3DConfig.builder().kH(4).kW(3).kD(3).dW(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Dilation width"));
}
try{
DeConv3DConfig.builder().kH(4).kW(3).kD(3).dD(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Dilation depth"));
}
}
@Test
public void testConv3D(){
Conv3DConfig.builder().kH(2).kW(4).kD(3).build();
try{
Conv3DConfig.builder().kW(4).kD(3).kH(0).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Kernel height"));
}
try{
Conv3DConfig.builder().kH(4).kD(3).kW(0).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Kernel width"));
}
try{
Conv3DConfig.builder().kH(4).kW(3).kD(0).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Kernel depth"));
}
try{
Conv3DConfig.builder().kH(4).kW(3).kD(3).sH(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Stride height"));
}
try{
Conv3DConfig.builder().kH(4).kW(3).kD(3).sW(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Stride width"));
}
try{
Conv3DConfig.builder().kH(4).kW(3).kD(3).sD(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Stride depth"));
}
try{
Conv3DConfig.builder().kH(4).kW(3).kD(3).pH(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Padding height"));
}
try{
Conv3DConfig.builder().kH(4).kW(3).kD(3).pW(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Padding width"));
}
try{
Conv3DConfig.builder().kH(4).kW(3).kD(3).pD(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Padding depth"));
}
try{
Conv3DConfig.builder().kH(4).kW(3).kD(3).dH(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Dilation height"));
}
try{
Conv3DConfig.builder().kH(4).kW(3).kD(3).dW(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Dilation width"));
}
try{
Conv3DConfig.builder().kH(4).kW(3).kD(3).dD(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Dilation depth"));
}
}
@Test
public void testPooling3D(){
Pooling3DConfig.builder().kH(2).kW(4).kD(3).build();
try{
Pooling3DConfig.builder().kW(4).kD(3).kH(0).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Kernel height"));
}
try{
Pooling3DConfig.builder().kH(4).kD(3).kW(0).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Kernel width"));
}
try{
Pooling3DConfig.builder().kH(4).kW(3).kD(0).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Kernel depth"));
}
try{
Pooling3DConfig.builder().kH(4).kW(3).kD(3).sH(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Stride height"));
}
try{
Pooling3DConfig.builder().kH(4).kW(3).kD(3).sW(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Stride width"));
}
try{
Pooling3DConfig.builder().kH(4).kW(3).kD(3).sD(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Stride depth"));
}
try{
Pooling3DConfig.builder().kH(4).kW(3).kD(3).pH(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Padding height"));
}
try{
Pooling3DConfig.builder().kH(4).kW(3).kD(3).pW(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Padding width"));
}
try{
Pooling3DConfig.builder().kH(4).kW(3).kD(3).pD(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Padding depth"));
}
try{
Pooling3DConfig.builder().kH(4).kW(3).kD(3).dH(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Dilation height"));
}
try{
Pooling3DConfig.builder().kH(4).kW(3).kD(3).dW(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Dilation width"));
}
try{
Pooling3DConfig.builder().kH(4).kW(3).kD(3).dD(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Dilation depth"));
}
}
@Test
public void testConv1D(){
Conv1DConfig.builder().k(2).build();
try{
Conv1DConfig.builder().k(0).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Kernel"));
}
try{
Conv1DConfig.builder().k(4).s(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Stride"));
}
try{
Conv1DConfig.builder().k(3).p(-2).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Padding"));
}
}
}