parent d82877b18b
commit e9a7a13c00
CudnnBatchNormalizationHelper.java

@@ -123,7 +123,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
     }

     @Override
-    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma, INDArray beta,
+    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta,
                     INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr layerWorkspaceMgr) {
         this.eps = eps;
         val miniBatch = (int) input.size(0);
@@ -173,8 +173,8 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba

         checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW,
                         dstStride[0], dstStride[1], dstStride[2], dstStride[3]));
-        checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(gamma.data().dataType()), shape[0],
-                        shape[1], shape.length > 2 ? shape[2] : 1, shape.length > 3 ? shape[3] : 1));
+        checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(gamma.data().dataType()), (int)shape[0],
+                        (int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));

         Allocator allocator = AtomicAllocator.getInstance();
         CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, epsilon, nextEpsilon, gamma,
@@ -214,7 +214,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba


     @Override
-    public INDArray preOutput(INDArray x, boolean training, int[] shape, INDArray gamma, INDArray beta, INDArray mean,
+    public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean,
                     INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr) {
         this.eps = eps;
         final boolean isHalf = (x.dataType() == DataType.HALF);
@@ -252,8 +252,8 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
         checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW,
                         dstStride[0], dstStride[1], dstStride[2], dstStride[3]));

-        checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(mean.data().dataType()), shape[0],
-                        shape[1], shape.length > 2 ? shape[2] : 1, shape.length > 3 ? shape[3] : 1));
+        checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(mean.data().dataType()), (int)shape[0],
+                        (int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));

         Allocator allocator = AtomicAllocator.getInstance();
         CudaContext context =
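Note on the casts above: the shape parameter widens from int[] to long[] (ND4J shapes are evidently long-based now), but the cuDNN tensor-descriptor calls still take int dimensions, so each shape entry is narrowed back with a plain (int) cast. A minimal sketch of a guarded variant of that narrowing, assuming dimensions always fit in an int in practice (the helper name is hypothetical, not part of this commit):

    // Hypothetical guard: fail loudly instead of silently truncating a dimension
    // that exceeds Integer.MAX_VALUE before it reaches a cuDNN descriptor.
    private static int toDescriptorDim(long dim) {
        if (dim < 0 || dim > Integer.MAX_VALUE)
            throw new IllegalStateException("Dimension out of int range for cuDNN descriptor: " + dim);
        return (int) dim;
    }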
Glove.java

@@ -93,11 +93,11 @@ public class Glove implements Serializable {
                     VocabWord w1, INDArray wordVector, INDArray contextVector, double gradient) {
         //gradient for word vectors
         INDArray grad1 = contextVector.mul(gradient);
-        INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), ArrayUtil.toInts(syn0.shape()));
+        INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), syn0.shape());
         wordVector.subi(update);

         double w1Bias = bias.getDouble(w1.getIndex());
-        double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), ArrayUtil.toInts(bias.shape()));
+        double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), bias.shape());
         double update2 = w1Bias - biasGradient;
         bias.putScalar(w1.getIndex(), bias.getDouble(w1.getIndex()) - update2);
         return new Pair<>(update, (float) update2);
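The ArrayUtil.toInts bridge disappears because the AdaGrad updaters now accept the long[] that syn0.shape() and bias.shape() return. For reference, the scaling getGradient applies is AdaGrad's lr * g / (sqrt(G) + eps), with G the running sum of squared gradients; a self-contained scalar sketch of that assumed form, mirroring what biasAdaGrad does for the bias term:

    // Minimal scalar AdaGrad step (illustrative sketch, not the ND4J updater).
    final class AdaGradSketch {
        private double hist = 0.0; // running sum of squared gradients across calls
        double step(double grad, double learningRate, double eps) {
            hist += grad * grad;                                  // accumulate g^2
            return learningRate * grad / (Math.sqrt(hist) + eps); // scaled update
        }
    }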
LongDoubleReduceFunction.java (new file)

@@ -0,0 +1,31 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.spark.impl.common.reduce;
+
+import org.apache.spark.api.java.function.Function2;
+import scala.Tuple2;
+
+/**
+ * Add both elements of a {@code Tuple2<Long,Double>}
+ */
+public class LongDoubleReduceFunction
+                implements Function2<Tuple2<Long, Double>, Tuple2<Long, Double>, Tuple2<Long, Double>> {
+    @Override
+    public Tuple2<Long, Double> call(Tuple2<Long, Double> f, Tuple2<Long, Double> s) throws Exception {
+        return new Tuple2<>(f._1() + s._1(), f._2() + s._2());
+    }
+}
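A quick usage sketch for the new reducer, matching how SparkComputationGraph wires it in below (the wrapper class and method names are illustrative only):

    import org.apache.spark.api.java.JavaRDD;
    import scala.Tuple2;

    // Illustrative helper: reduce (exampleCount, scoreSum) pairs to grand
    // totals, then average, exactly as calculateScore does in the next hunk.
    class ScoreTotals {
        static double averageScore(JavaRDD<Tuple2<Long, Double>> countsAndSums) {
            Tuple2<Long, Double> totals = countsAndSums.reduce(new LongDoubleReduceFunction());
            return totals._2() / totals._1(); // double / long promotes to double
        }
    }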
SparkComputationGraph.java

@@ -38,6 +38,7 @@ import org.deeplearning4j.spark.api.TrainingMaster;
 import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
 import org.deeplearning4j.spark.impl.SparkListenable;
 import org.deeplearning4j.spark.impl.common.reduce.IntDoubleReduceFunction;
+import org.deeplearning4j.spark.impl.common.reduce.LongDoubleReduceFunction;
 import org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn;
 import org.deeplearning4j.spark.impl.graph.dataset.PairDataSetToMultiDataSetFn;
 import org.deeplearning4j.spark.impl.graph.evaluation.IEvaluateMDSFlatMapFunction;
@@ -374,11 +375,11 @@ public class SparkComputationGraph extends SparkListenable {
      * in one go)
      */
     public double calculateScore(JavaRDD<DataSet> data, boolean average, int minibatchSize) {
-        JavaRDD<Tuple2<Integer, Double>> rdd = data.mapPartitions(new ScoreFlatMapFunctionCGDataSet(conf.toJson(),
-                        sc.broadcast(network.params(false)), minibatchSize));
+        JavaRDD<Tuple2<Long, Double>> rdd = data.mapPartitions(new ScoreFlatMapFunctionCGDataSet(conf.toJson(),
+                        sc.broadcast(network.params()), minibatchSize));

         //Reduce to a single tuple, with example count + sum of scores
-        Tuple2<Integer, Double> countAndSumScores = rdd.reduce(new IntDoubleReduceFunction());
+        Tuple2<Long, Double> countAndSumScores = rdd.reduce(new LongDoubleReduceFunction());
         if (average) {
             return countAndSumScores._2() / countAndSumScores._1();
         } else {
@@ -409,10 +410,10 @@ public class SparkComputationGraph extends SparkListenable {
      * in one go)
      */
     public double calculateScoreMultiDataSet(JavaRDD<MultiDataSet> data, boolean average, int minibatchSize) {
-        JavaRDD<Tuple2<Integer, Double>> rdd = data.mapPartitions(new ScoreFlatMapFunctionCGMultiDataSet(conf.toJson(),
-                        sc.broadcast(network.params(false)), minibatchSize));
+        JavaRDD<Tuple2<Long, Double>> rdd = data.mapPartitions(new ScoreFlatMapFunctionCGMultiDataSet(conf.toJson(),
+                        sc.broadcast(network.params()), minibatchSize));
         //Reduce to a single tuple, with example count + sum of scores
-        Tuple2<Integer, Double> countAndSumScores = rdd.reduce(new IntDoubleReduceFunction());
+        Tuple2<Long, Double> countAndSumScores = rdd.reduce(new LongDoubleReduceFunction());
         if (average) {
             return countAndSumScores._2() / countAndSumScores._1();
         } else {
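The Integer-to-Long switch in both score methods presumably guards the reduced example count against int overflow on very large RDDs: a running int count wraps negative once it passes 2,147,483,647. A toy demonstration of the failure mode the widening avoids:

    int intCount = Integer.MAX_VALUE;
    intCount += 1;                        // wraps to -2147483648
    long longCount = Integer.MAX_VALUE;
    longCount += 1;                       // 2147483648, as intended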
BaseDataSetIterator.java

@@ -47,7 +47,7 @@ public abstract class BaseDataSetIterator<T> implements DataSetIterator {
     public int inputColumns() {
         if (inputColumns == -1)
             preloadDataSet();
-        return inputColumns;
+        return (int)inputColumns;
     }

     @Override
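The new cast implies the backing inputColumns field is now a long while the DataSetIterator contract still returns int. If silent truncation were a concern, Math.toIntExact would be the defensive alternative (an option, not what this commit does):

    return Math.toIntExact(inputColumns); // throws ArithmeticException on overflow instead of wrapping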