More int/long compilation fixes (#22)
Signed-off-by: Alex Black <blacka101@gmail.com>
Parent: d82877b18b
Commit: e9a7a13c00
CudnnBatchNormalizationHelper.java

@@ -123,7 +123,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
     }
 
     @Override
-    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma, INDArray beta,
+    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta,
                     INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr layerWorkspaceMgr) {
         this.eps = eps;
         val miniBatch = (int) input.size(0);
@@ -173,8 +173,8 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
 
         checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW,
                         dstStride[0], dstStride[1], dstStride[2], dstStride[3]));
-        checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(gamma.data().dataType()), shape[0],
-                shape[1], shape.length > 2 ? shape[2] : 1, shape.length > 3 ? shape[3] : 1));
+        checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(gamma.data().dataType()), (int)shape[0],
+                (int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
 
         Allocator allocator = AtomicAllocator.getInstance();
         CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, epsilon, nextEpsilon, gamma,
@@ -214,7 +214,7 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
 
 
     @Override
-    public INDArray preOutput(INDArray x, boolean training, int[] shape, INDArray gamma, INDArray beta, INDArray mean,
+    public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean,
                     INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr) {
         this.eps = eps;
         final boolean isHalf = (x.dataType() == DataType.HALF);
@@ -252,8 +252,8 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba
         checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW,
                         dstStride[0], dstStride[1], dstStride[2], dstStride[3]));
 
-        checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(mean.data().dataType()), shape[0],
-                shape[1], shape.length > 2 ? shape[2] : 1, shape.length > 3 ? shape[3] : 1));
+        checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(mean.data().dataType()), (int)shape[0],
+                (int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1));
 
         Allocator allocator = AtomicAllocator.getInstance();
         CudaContext context =
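Note: the cudnn tensor-descriptor calls take C int dimensions, so the now-long[] shapes must be narrowed at the call site, as the (int) casts above do. A stricter variant (a sketch, not part of this commit) would fail fast on overflow rather than truncate silently:

    // Hypothetical helper: overflow-checked long -> int narrowing for
    // cudnn descriptor dimensions. Math.toIntExact (JDK 8+) throws
    // ArithmeticException instead of silently wrapping like an (int) cast.
    private static int toIntDim(long dim) {
        return Math.toIntExact(dim);
    }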
Glove.java

@@ -93,11 +93,11 @@ public class Glove implements Serializable {
                     VocabWord w1, INDArray wordVector, INDArray contextVector, double gradient) {
         //gradient for word vectors
         INDArray grad1 = contextVector.mul(gradient);
-        INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), ArrayUtil.toInts(syn0.shape()));
+        INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), syn0.shape());
         wordVector.subi(update);
 
         double w1Bias = bias.getDouble(w1.getIndex());
-        double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), ArrayUtil.toInts(bias.shape()));
+        double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), bias.shape());
         double update2 = w1Bias - biasGradient;
         bias.putScalar(w1.getIndex(), bias.getDouble(w1.getIndex()) - update2);
         return new Pair<>(update, (float) update2);
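Note: these call sites simplify because INDArray.shape() now returns long[] and the AdaGrad updater accepts it directly, so the lossy ArrayUtil.toInts round-trip can be dropped. The hazard the old conversion carried (illustrative arithmetic, not code from this commit):

    // A dimension larger than Integer.MAX_VALUE cannot survive an int cast:
    long bigDim = 3_000_000_000L;   // a legal long dimension
    int truncated = (int) bigDim;   // -1294967296 -- silent corruption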
LongDoubleReduceFunction.java (new file)

@@ -0,0 +1,31 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.spark.impl.common.reduce;
+
+import org.apache.spark.api.java.function.Function2;
+import scala.Tuple2;
+
+/**
+ * Add both elements of a {@code Tuple2<Long,Double>}
+ */
+public class LongDoubleReduceFunction
+                implements Function2<Tuple2<Long, Double>, Tuple2<Long, Double>, Tuple2<Long, Double>> {
+    @Override
+    public Tuple2<Long, Double> call(Tuple2<Long, Double> f, Tuple2<Long, Double> s) throws Exception {
+        return new Tuple2<>(f._1() + s._1(), f._2() + s._2());
+    }
+}
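Note: a minimal usage sketch for the new reducer (hypothetical driver code; assumes an existing JavaSparkContext sc and the usual java.util.Arrays import). Since the function is associative and commutative, Spark may apply it in any order across partitions:

    // Sum (exampleCount, scoreSum) pairs across an RDD.
    JavaRDD<Tuple2<Long, Double>> scores = sc.parallelize(Arrays.asList(
            new Tuple2<>(100L, 12.5), new Tuple2<>(250L, 30.0)));
    Tuple2<Long, Double> total = scores.reduce(new LongDoubleReduceFunction());
    // total = (350L, 42.5): total example count and summed score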
SparkComputationGraph.java

@@ -38,6 +38,7 @@ import org.deeplearning4j.spark.api.TrainingMaster;
 import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
 import org.deeplearning4j.spark.impl.SparkListenable;
 import org.deeplearning4j.spark.impl.common.reduce.IntDoubleReduceFunction;
+import org.deeplearning4j.spark.impl.common.reduce.LongDoubleReduceFunction;
 import org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn;
 import org.deeplearning4j.spark.impl.graph.dataset.PairDataSetToMultiDataSetFn;
 import org.deeplearning4j.spark.impl.graph.evaluation.IEvaluateMDSFlatMapFunction;
@@ -374,11 +375,11 @@ public class SparkComputationGraph extends SparkListenable {
      *                      in one go)
      */
     public double calculateScore(JavaRDD<DataSet> data, boolean average, int minibatchSize) {
-        JavaRDD<Tuple2<Integer, Double>> rdd = data.mapPartitions(new ScoreFlatMapFunctionCGDataSet(conf.toJson(),
-                        sc.broadcast(network.params(false)), minibatchSize));
+        JavaRDD<Tuple2<Long, Double>> rdd = data.mapPartitions(new ScoreFlatMapFunctionCGDataSet(conf.toJson(),
+                        sc.broadcast(network.params()), minibatchSize));
 
         //Reduce to a single tuple, with example count + sum of scores
-        Tuple2<Integer, Double> countAndSumScores = rdd.reduce(new IntDoubleReduceFunction());
+        Tuple2<Long, Double> countAndSumScores = rdd.reduce(new LongDoubleReduceFunction());
         if (average) {
             return countAndSumScores._2() / countAndSumScores._1();
         } else {
@@ -409,10 +410,10 @@ public class SparkComputationGraph extends SparkListenable {
      *                      in one go)
      */
     public double calculateScoreMultiDataSet(JavaRDD<MultiDataSet> data, boolean average, int minibatchSize) {
-        JavaRDD<Tuple2<Integer, Double>> rdd = data.mapPartitions(new ScoreFlatMapFunctionCGMultiDataSet(conf.toJson(),
-                        sc.broadcast(network.params(false)), minibatchSize));
+        JavaRDD<Tuple2<Long, Double>> rdd = data.mapPartitions(new ScoreFlatMapFunctionCGMultiDataSet(conf.toJson(),
+                        sc.broadcast(network.params()), minibatchSize));
         //Reduce to a single tuple, with example count + sum of scores
-        Tuple2<Integer, Double> countAndSumScores = rdd.reduce(new IntDoubleReduceFunction());
+        Tuple2<Long, Double> countAndSumScores = rdd.reduce(new LongDoubleReduceFunction());
         if (average) {
             return countAndSumScores._2() / countAndSumScores._1();
         } else {
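Note: widening the example counter from Integer to Long guards against overflow once a scoring run sees more than 2,147,483,647 examples; the score sum was already a double. The broadcast also changes from network.params(false) to network.params(), presumably tracking a signature change in ComputationGraph. The averaging code stays correct after the switch, since Double / Long promotes to floating-point division (illustrative):

    Tuple2<Long, Double> t = new Tuple2<>(4L, 10.0);
    double avg = t._2() / t._1();   // 2.5 -- not integer division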
BaseDataSetIterator.java

@@ -47,7 +47,7 @@ public abstract class BaseDataSetIterator<T> implements DataSetIterator {
     public int inputColumns() {
         if (inputColumns == -1)
             preloadDataSet();
-        return inputColumns;
+        return (int)inputColumns;
     }
 
     @Override
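Note: the (int) cast narrows what is presumably now a long field to satisfy the int return type the DataSetIterator interface requires. If silent truncation is a concern, an overflow-checked alternative (a sketch, not what this commit does):

    // Math.toIntExact throws ArithmeticException when the value
    // does not fit in an int, instead of wrapping silently.
    return Math.toIntExact(inputColumns);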