More int/long compilation fixes (#22)

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2019-11-02 15:20:19 +11:00 committed by GitHub
parent d82877b18b
commit e9a7a13c00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 47 additions and 15 deletions

View File

@ -123,7 +123,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
}
@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma, INDArray beta,
public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta,
INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr layerWorkspaceMgr) {
this.eps = eps;
val miniBatch = (int) input.size(0);
@ -173,8 +173,8 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW,
dstStride[0], dstStride[1], dstStride[2], dstStride[3]));
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(gamma.data().dataType()), shape[0],
shape[1], shape.length > 2 ? shape[2] : 1, shape.length > 3 ? shape[3] : 1));
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(gamma.data().dataType()), (int)shape[0],
(int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
Allocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, epsilon, nextEpsilon, gamma,
@ -214,7 +214,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
@Override
public INDArray preOutput(INDArray x, boolean training, int[] shape, INDArray gamma, INDArray beta, INDArray mean,
public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean,
INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr) {
this.eps = eps;
final boolean isHalf = (x.dataType() == DataType.HALF);
@ -252,8 +252,8 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW,
dstStride[0], dstStride[1], dstStride[2], dstStride[3]));
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(mean.data().dataType()), shape[0],
shape[1], shape.length > 2 ? shape[2] : 1, shape.length > 3 ? shape[3] : 1));
checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(mean.data().dataType()), (int)shape[0],
(int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
Allocator allocator = AtomicAllocator.getInstance();
CudaContext context =

View File

@ -93,11 +93,11 @@ public class Glove implements Serializable {
VocabWord w1, INDArray wordVector, INDArray contextVector, double gradient) {
//gradient for word vectors
INDArray grad1 = contextVector.mul(gradient);
INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), ArrayUtil.toInts(syn0.shape()));
INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), syn0.shape());
wordVector.subi(update);
double w1Bias = bias.getDouble(w1.getIndex());
double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), ArrayUtil.toInts(bias.shape()));
double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), bias.shape());
double update2 = w1Bias - biasGradient;
bias.putScalar(w1.getIndex(), bias.getDouble(w1.getIndex()) - update2);
return new Pair<>(update, (float) update2);

View File

@ -0,0 +1,31 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.spark.impl.common.reduce;
import org.apache.spark.api.java.function.Function2;
import scala.Tuple2;
/**
* Add both elements of a {@code Tuple2<Integer,Double>}
*/
public class LongDoubleReduceFunction
implements Function2<Tuple2<Long, Double>, Tuple2<Long, Double>, Tuple2<Long, Double>> {
@Override
public Tuple2<Long, Double> call(Tuple2<Long, Double> f, Tuple2<Long, Double> s) throws Exception {
return new Tuple2<>(f._1() + s._1(), f._2() + s._2());
}
}

View File

@ -38,6 +38,7 @@ import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.impl.SparkListenable;
import org.deeplearning4j.spark.impl.common.reduce.IntDoubleReduceFunction;
import org.deeplearning4j.spark.impl.common.reduce.LongDoubleReduceFunction;
import org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn;
import org.deeplearning4j.spark.impl.graph.dataset.PairDataSetToMultiDataSetFn;
import org.deeplearning4j.spark.impl.graph.evaluation.IEvaluateMDSFlatMapFunction;
@ -374,11 +375,11 @@ public class SparkComputationGraph extends SparkListenable {
* in one go)
*/
public double calculateScore(JavaRDD<DataSet> data, boolean average, int minibatchSize) {
JavaRDD<Tuple2<Integer, Double>> rdd = data.mapPartitions(new ScoreFlatMapFunctionCGDataSet(conf.toJson(),
sc.broadcast(network.params(false)), minibatchSize));
JavaRDD<Tuple2<Long, Double>> rdd = data.mapPartitions(new ScoreFlatMapFunctionCGDataSet(conf.toJson(),
sc.broadcast(network.params()), minibatchSize));
//Reduce to a single tuple, with example count + sum of scores
Tuple2<Integer, Double> countAndSumScores = rdd.reduce(new IntDoubleReduceFunction());
Tuple2<Long, Double> countAndSumScores = rdd.reduce(new LongDoubleReduceFunction());
if (average) {
return countAndSumScores._2() / countAndSumScores._1();
} else {
@ -409,10 +410,10 @@ public class SparkComputationGraph extends SparkListenable {
* in one go)
*/
public double calculateScoreMultiDataSet(JavaRDD<MultiDataSet> data, boolean average, int minibatchSize) {
JavaRDD<Tuple2<Integer, Double>> rdd = data.mapPartitions(new ScoreFlatMapFunctionCGMultiDataSet(conf.toJson(),
sc.broadcast(network.params(false)), minibatchSize));
JavaRDD<Tuple2<Long, Double>> rdd = data.mapPartitions(new ScoreFlatMapFunctionCGMultiDataSet(conf.toJson(),
sc.broadcast(network.params()), minibatchSize));
//Reduce to a single tuple, with example count + sum of scores
Tuple2<Integer, Double> countAndSumScores = rdd.reduce(new IntDoubleReduceFunction());
Tuple2<Long, Double> countAndSumScores = rdd.reduce(new LongDoubleReduceFunction());
if (average) {
return countAndSumScores._2() / countAndSumScores._1();
} else {

View File

@ -47,7 +47,7 @@ public abstract class BaseDataSetIterator<T> implements DataSetIterator {
public int inputColumns() {
if (inputColumns == -1)
preloadDataSet();
return inputColumns;
return (int)inputColumns;
}
@Override