/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.nn.params;

import lombok.val;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.indexing.NDArrayIndex;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

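/**
 * Parameter initializer for {@link VariationalAutoencoder} layers. All parameters are stored in a single
 * flattened row vector, laid out in the following order:
 * <ol>
 *     <li>Encoder layers: for each layer i, weights ("e0W", "e1W", ...) followed by biases ("e0b", "e1b", ...)</li>
 *     <li>Weights and biases for the mean of p(z|x): {@link #PZX_MEAN_W}, {@link #PZX_MEAN_B}</li>
 *     <li>Weights and biases for log(sigma^2) of p(z|x): {@link #PZX_LOGSTD2_W}, {@link #PZX_LOGSTD2_B}</li>
 *     <li>Decoder layers: for each layer i, weights ("d0W", "d1W", ...) followed by biases ("d0b", "d1b", ...)</li>
 *     <li>Weights and biases for p(x|z): {@link #PXZ_W}, {@link #PXZ_B}</li>
 * </ol>
 * Each weight matrix of shape [layerNIn, layerNOut] occupies layerNIn*layerNOut entries of the flattened
 * vector, in column-major ('f') order, immediately followed by layerNOut entries for the corresponding bias
 * row vector.
 */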
public class VariationalAutoencoderParamInitializer extends DefaultParamInitializer {

    private static final VariationalAutoencoderParamInitializer INSTANCE = new VariationalAutoencoderParamInitializer();

    public static VariationalAutoencoderParamInitializer getInstance() {
        return INSTANCE;
    }

    public static final String WEIGHT_KEY_SUFFIX = "W";
    public static final String BIAS_KEY_SUFFIX = "b";
    public static final String PZX_PREFIX = "pZX";
    public static final String PZX_MEAN_PREFIX = PZX_PREFIX + "Mean";
    public static final String PZX_LOGSTD2_PREFIX = PZX_PREFIX + "LogStd2";

    /** Prefix for encoder layer parameter keys: "e0W", "e0b", "e1W", ... */
    public static final String ENCODER_PREFIX = "e";
    /** Prefix for decoder layer parameter keys: "d0W", "d0b", "d1W", ... */
    public static final String DECODER_PREFIX = "d";

    /** Key for weight parameters connecting the last encoder layer and the mean values for p(z|data) */
    public static final String PZX_MEAN_W = PZX_MEAN_PREFIX + WEIGHT_KEY_SUFFIX;
    /** Key for bias parameters for the mean values for p(z|data) */
    public static final String PZX_MEAN_B = PZX_MEAN_PREFIX + BIAS_KEY_SUFFIX;
    /** Key for weight parameters connecting the last encoder layer and the log(sigma^2) values for p(z|data) */
    public static final String PZX_LOGSTD2_W = PZX_LOGSTD2_PREFIX + WEIGHT_KEY_SUFFIX;
    /** Key for bias parameters for log(sigma^2) in p(z|data) */
    public static final String PZX_LOGSTD2_B = PZX_LOGSTD2_PREFIX + BIAS_KEY_SUFFIX;

    public static final String PXZ_PREFIX = "pXZ";
    /** Key for weight parameters connecting the last decoder layer and p(data|z) (according to whatever
     * {@link org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution} is set for the VAE) */
    public static final String PXZ_W = PXZ_PREFIX + WEIGHT_KEY_SUFFIX;
    /** Key for bias parameters connecting the last decoder layer and p(data|z) (according to whatever
     * {@link org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution} is set for the VAE) */
    public static final String PXZ_B = PXZ_PREFIX + BIAS_KEY_SUFFIX;

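    /**
     * Count the total number of parameters for the given VAE configuration: one (weights + bias) block per
     * encoder layer, two (weights + bias) blocks of size nOut for the mean and log(sigma^2) of p(z|x), one
     * block per decoder layer, and a final block for the reconstruction distribution parameters of p(x|z).
     * As a worked example, for a hypothetical configuration with nIn = 784, encoder layers [256, 128],
     * nOut = 32, decoder layers [128, 256], and a Bernoulli reconstruction distribution (for which
     * distributionInputSize(nIn) == nIn):
     * <pre>
     * (784 + 1) * 256      = 200,960   //encoder layer 0: weights + biases
     * (256 + 1) * 128      =  32,896   //encoder layer 1
     * (128 + 1) * 2 * 32   =   8,256   //p(z|x): mean and log(sigma^2)
     * ( 32 + 1) * 128      =   4,224   //decoder layer 0
     * (128 + 1) * 256      =  33,024   //decoder layer 1
     * (256 + 1) * 784      = 201,488   //p(x|z)
     *                        -------
     *                        480,848   //total
     * </pre>
     */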
    @Override
    public long numParams(LayerConfiguration conf) {
        VariationalAutoencoder layer = (VariationalAutoencoder) conf;

        val nIn = layer.getNIn();
        val nOut = layer.getNOut();
        int[] encoderLayerSizes = layer.getEncoderLayerSizes();
        int[] decoderLayerSizes = layer.getDecoderLayerSizes();

        long paramCount = 0;    //long, not int: the total may exceed Integer.MAX_VALUE
        for (int i = 0; i < encoderLayerSizes.length; i++) {
            long encoderLayerIn;
            if (i == 0) {
                encoderLayerIn = nIn;
            } else {
                encoderLayerIn = encoderLayerSizes[i - 1];
            }
            paramCount += (encoderLayerIn + 1) * encoderLayerSizes[i]; //weights + bias
        }

        //Between the last encoder layer and the parameters for p(z|x):
        int lastEncLayerSize = encoderLayerSizes[encoderLayerSizes.length - 1];
        paramCount += (lastEncLayerSize + 1L) * 2 * nOut; //Mean and log(sigma^2) parameters, used in unsupervised training

        //Decoder:
        for (int i = 0; i < decoderLayerSizes.length; i++) {
            long decoderLayerNIn;
            if (i == 0) {
                decoderLayerNIn = nOut;
            } else {
                decoderLayerNIn = decoderLayerSizes[i - 1];
            }
            paramCount += (decoderLayerNIn + 1) * decoderLayerSizes[i]; //weights + bias
        }

        //Between last decoder layer and parameters for p(x|z):
        if (nIn > Integer.MAX_VALUE)
            throw new ND4JArraySizeException();
        val nDistributionParams = layer.getOutputDistribution().distributionInputSize((int) nIn);
        val lastDecLayerSize = decoderLayerSizes[decoderLayerSizes.length - 1];
        paramCount += (lastDecLayerSize + 1L) * nDistributionParams;

        return paramCount;
    }

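    /**
     * All parameter keys, in the same order as the corresponding parameters appear in the flattened
     * parameter vector. For example, for a (hypothetical) VAE with two encoder and two decoder layers:
     * <pre>{@code
     * [e0W, e0b, e1W, e1b, pZXMeanW, pZXMeanb, pZXLogStd2W, pZXLogStd2b, d0W, d0b, d1W, d1b, pXZW, pXZb]
     * }</pre>
     */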
    @Override
    public List<String> paramKeys(LayerConfiguration l) {
        VariationalAutoencoder layer = (VariationalAutoencoder) l;
        int[] encoderLayerSizes = layer.getEncoderLayerSizes();
        int[] decoderLayerSizes = layer.getDecoderLayerSizes();

        List<String> p = new ArrayList<>();

        for (int i = 0; i < encoderLayerSizes.length; i++) {
            p.add(ENCODER_PREFIX + i + WEIGHT_KEY_SUFFIX);
            p.add(ENCODER_PREFIX + i + BIAS_KEY_SUFFIX);
        }

        //Last encoder layer -> p(z|x)
        p.add(PZX_MEAN_W);
        p.add(PZX_MEAN_B);

        //Weights/biases for log(sigma^2) of p(z|x), used in unsupervised (pretrain) fitting
        p.add(PZX_LOGSTD2_W);
        p.add(PZX_LOGSTD2_B);

        for (int i = 0; i < decoderLayerSizes.length; i++) {
            p.add(DECODER_PREFIX + i + WEIGHT_KEY_SUFFIX);
            p.add(DECODER_PREFIX + i + BIAS_KEY_SUFFIX);
        }

        //Finally, p(x|z):
        p.add(PXZ_W);
        p.add(PXZ_B);

        return p;
    }

    @Override
    public List<String> weightKeys(LayerConfiguration layer) {
        List<String> out = new ArrayList<>();
        for (String s : paramKeys(layer)) {
            if (isWeightParam(layer, s)) {
                out.add(s);
            }
        }
        return out;
    }

    @Override
    public List<String> biasKeys(LayerConfiguration layer) {
        List<String> out = new ArrayList<>();
        for (String s : paramKeys(layer)) {
            if (isBiasParam(layer, s)) {
                out.add(s);
            }
        }
        return out;
    }

    @Override
    public boolean isWeightParam(LayerConfiguration layer, String key) {
        return key.endsWith(WEIGHT_KEY_SUFFIX);
    }

    @Override
    public boolean isBiasParam(LayerConfiguration layer, String key) {
        return key.endsWith(BIAS_KEY_SUFFIX);
    }

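    /**
     * Initialize (or, if initializeParams is false, just wrap) the parameters as views of the flattened
     * {@code paramsView} vector, so that updates to the flattened vector are reflected in the individual
     * parameter arrays and vice versa. A minimal usage sketch, assuming {@code conf} is a configured
     * {@link VariationalAutoencoder} layer configuration:
     * <pre>{@code
     * VariationalAutoencoderParamInitializer initializer = VariationalAutoencoderParamInitializer.getInstance();
     * INDArray view = Nd4j.create(1, initializer.numParams(conf));    //org.nd4j.linalg.factory.Nd4j
     * Map<String, INDArray> params = initializer.init(conf, view, true);
     * INDArray firstEncoderWeights = params.get("e0W");    //shape: [nIn, encoderLayerSizes[0]]
     * }</pre>
     */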
    @Override
    public Map<String, INDArray> init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams) {
        if (paramsView.length() != numParams(conf)) {
            throw new IllegalArgumentException("Incorrect paramsView length: Expected length " + numParams(conf)
                            + ", got length " + paramsView.length());
        }

        Map<String, INDArray> ret = new LinkedHashMap<>();
        VariationalAutoencoder layer = (VariationalAutoencoder) conf;

        val nIn = layer.getNIn();
        val nOut = layer.getNOut();
        int[] encoderLayerSizes = layer.getEncoderLayerSizes();
        int[] decoderLayerSizes = layer.getDecoderLayerSizes();

        IWeightInit weightInit = layer.getWeightInit();

        long soFar = 0;     //Offset into the flattened parameter vector
        for (int i = 0; i < encoderLayerSizes.length; i++) {
            long encoderLayerNIn;
            if (i == 0) {
                encoderLayerNIn = nIn;
            } else {
                encoderLayerNIn = encoderLayerSizes[i - 1];
            }
            val weightParamCount = encoderLayerNIn * encoderLayerSizes[i];
            INDArray weightView = paramsView.get(NDArrayIndex.interval(0,0,true),
                            NDArrayIndex.interval(soFar, soFar + weightParamCount));
            soFar += weightParamCount;
            INDArray biasView = paramsView.get(NDArrayIndex.interval(0,0,true),
                            NDArrayIndex.interval(soFar, soFar + encoderLayerSizes[i]));
            soFar += encoderLayerSizes[i];

            INDArray layerWeights = createWeightMatrix(encoderLayerNIn, encoderLayerSizes[i], weightInit,
                            weightView, initializeParams);
            INDArray layerBiases = createBias(encoderLayerSizes[i], 0.0, biasView, initializeParams); //TODO don't hardcode 0

            String sW = ENCODER_PREFIX + i + WEIGHT_KEY_SUFFIX;
            String sB = ENCODER_PREFIX + i + BIAS_KEY_SUFFIX;
            ret.put(sW, layerWeights);
            ret.put(sB, layerBiases);

            conf.addVariable(sW);
            conf.addVariable(sB);
        }

        //Last encoder layer -> p(z|x)
        val nWeightsPzx = encoderLayerSizes[encoderLayerSizes.length - 1] * nOut;
        INDArray pzxWeightsMean =
                        paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + nWeightsPzx));
        soFar += nWeightsPzx;
        INDArray pzxBiasMean = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + nOut));
        soFar += nOut;

        INDArray pzxWeightsMeanReshaped = createWeightMatrix(encoderLayerSizes[encoderLayerSizes.length - 1], nOut,
                        weightInit, pzxWeightsMean, initializeParams);
        INDArray pzxBiasMeanReshaped = createBias(nOut, 0.0, pzxBiasMean, initializeParams); //TODO don't hardcode 0

        ret.put(PZX_MEAN_W, pzxWeightsMeanReshaped);
        ret.put(PZX_MEAN_B, pzxBiasMeanReshaped);
        conf.addVariable(PZX_MEAN_W);
        conf.addVariable(PZX_MEAN_B);

        //Weights/biases for log(sigma^2) of p(z|x), used in unsupervised (pretrain) fitting
        INDArray pzxWeightsLogStdev2 =
                        paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + nWeightsPzx));
        soFar += nWeightsPzx;
        INDArray pzxBiasLogStdev2 = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + nOut));
        soFar += nOut;

        INDArray pzxWeightsLogStdev2Reshaped = createWeightMatrix(encoderLayerSizes[encoderLayerSizes.length - 1], nOut,
                        weightInit, pzxWeightsLogStdev2, initializeParams);
        INDArray pzxBiasLogStdev2Reshaped = createBias(nOut, 0.0, pzxBiasLogStdev2, initializeParams); //TODO don't hardcode 0

        ret.put(PZX_LOGSTD2_W, pzxWeightsLogStdev2Reshaped);
        ret.put(PZX_LOGSTD2_B, pzxBiasLogStdev2Reshaped);
        conf.addVariable(PZX_LOGSTD2_W);
        conf.addVariable(PZX_LOGSTD2_B);

        for (int i = 0; i < decoderLayerSizes.length; i++) {
            long decoderLayerNIn;
            if (i == 0) {
                decoderLayerNIn = nOut;
            } else {
                decoderLayerNIn = decoderLayerSizes[i - 1];
            }
            val weightParamCount = decoderLayerNIn * decoderLayerSizes[i];
            INDArray weightView = paramsView.get(NDArrayIndex.interval(0,0,true),
                            NDArrayIndex.interval(soFar, soFar + weightParamCount));
            soFar += weightParamCount;
            INDArray biasView = paramsView.get(NDArrayIndex.interval(0,0,true),
                            NDArrayIndex.interval(soFar, soFar + decoderLayerSizes[i]));
            soFar += decoderLayerSizes[i];

            INDArray layerWeights = createWeightMatrix(decoderLayerNIn, decoderLayerSizes[i], weightInit,
                            weightView, initializeParams);
            INDArray layerBiases = createBias(decoderLayerSizes[i], 0.0, biasView, initializeParams); //TODO don't hardcode 0

            String sW = DECODER_PREFIX + i + WEIGHT_KEY_SUFFIX;
            String sB = DECODER_PREFIX + i + BIAS_KEY_SUFFIX;
            ret.put(sW, layerWeights);
            ret.put(sB, layerBiases);
            conf.addVariable(sW);
            conf.addVariable(sB);
        }

        //Finally, p(x|z):
        if (nIn > Integer.MAX_VALUE)
            throw new ND4JArraySizeException();
        int nDistributionParams = layer.getOutputDistribution().distributionInputSize((int) nIn);
        int pxzWeightCount = decoderLayerSizes[decoderLayerSizes.length - 1] * nDistributionParams;
        INDArray pxzWeightView =
                        paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + pxzWeightCount));
        soFar += pxzWeightCount;
        INDArray pxzBiasView = paramsView.get(NDArrayIndex.interval(0,0,true),
                        NDArrayIndex.interval(soFar, soFar + nDistributionParams));

        INDArray pxzWeightsReshaped = createWeightMatrix(decoderLayerSizes[decoderLayerSizes.length - 1],
                        nDistributionParams, weightInit, pxzWeightView, initializeParams);
        INDArray pxzBiasReshaped = createBias(nDistributionParams, 0.0, pxzBiasView, initializeParams); //TODO don't hardcode 0

        ret.put(PXZ_W, pxzWeightsReshaped);
        ret.put(PXZ_B, pxzBiasReshaped);
        conf.addVariable(PXZ_W);
        conf.addVariable(PXZ_B);

        return ret;
    }

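    /**
     * Return the gradient arrays as reshaped segments of the flattened {@code gradientView} vector, using
     * the same layout and key order as {@link #init(LayerConfiguration, INDArray, boolean)}. No values are
     * initialized here; each segment of the flattened vector is only reshaped to the corresponding
     * parameter's shape.
     */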
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) {
        Map<String, INDArray> ret = new LinkedHashMap<>();
        VariationalAutoencoder layer = (VariationalAutoencoder) conf;

        val nIn = layer.getNIn();
        val nOut = layer.getNOut();
        int[] encoderLayerSizes = layer.getEncoderLayerSizes();
        int[] decoderLayerSizes = layer.getDecoderLayerSizes();

        long soFar = 0;     //Offset into the flattened gradient vector
        for (int i = 0; i < encoderLayerSizes.length; i++) {
            long encoderLayerNIn;
            if (i == 0) {
                encoderLayerNIn = nIn;
            } else {
                encoderLayerNIn = encoderLayerSizes[i - 1];
            }
            val weightParamCount = encoderLayerNIn * encoderLayerSizes[i];
            INDArray weightGradView = gradientView.get(NDArrayIndex.interval(0,0,true),
                            NDArrayIndex.interval(soFar, soFar + weightParamCount));
            soFar += weightParamCount;
            INDArray biasGradView = gradientView.get(NDArrayIndex.interval(0,0,true),
                            NDArrayIndex.interval(soFar, soFar + encoderLayerSizes[i]));
            soFar += encoderLayerSizes[i];

            INDArray layerWeights = weightGradView.reshape('f', encoderLayerNIn, encoderLayerSizes[i]);
            INDArray layerBiases = biasGradView; //Already correct shape (row vector)

            ret.put(ENCODER_PREFIX + i + WEIGHT_KEY_SUFFIX, layerWeights);
            ret.put(ENCODER_PREFIX + i + BIAS_KEY_SUFFIX, layerBiases);
        }

        //Last encoder layer -> p(z|x)
        val nWeightsPzx = encoderLayerSizes[encoderLayerSizes.length - 1] * nOut;
        INDArray pzxWeightsMean =
                        gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + nWeightsPzx));
        soFar += nWeightsPzx;
        INDArray pzxBiasMean = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + nOut));
        soFar += nOut;

        INDArray pzxWeightGradMeanReshaped =
                        pzxWeightsMean.reshape('f', encoderLayerSizes[encoderLayerSizes.length - 1], nOut);

        ret.put(PZX_MEAN_W, pzxWeightGradMeanReshaped);
        ret.put(PZX_MEAN_B, pzxBiasMean);

        //Gradients for log(sigma^2) of p(z|x):
        INDArray pzxWeightsLogStdev2 =
                        gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + nWeightsPzx));
        soFar += nWeightsPzx;
        INDArray pzxBiasLogStdev2 = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + nOut));
        soFar += nOut;

        INDArray pzxWeightsLogStdev2Reshaped = createWeightMatrix(encoderLayerSizes[encoderLayerSizes.length - 1], nOut,
                        null, pzxWeightsLogStdev2, false); //TODO - initializeParams = false, so this is just a reshape

        ret.put(PZX_LOGSTD2_W, pzxWeightsLogStdev2Reshaped);
        ret.put(PZX_LOGSTD2_B, pzxBiasLogStdev2);

        for (int i = 0; i < decoderLayerSizes.length; i++) {
            long decoderLayerNIn;
            if (i == 0) {
                decoderLayerNIn = nOut;
            } else {
                decoderLayerNIn = decoderLayerSizes[i - 1];
            }
            long weightParamCount = decoderLayerNIn * decoderLayerSizes[i];
            INDArray weightView = gradientView.get(NDArrayIndex.interval(0,0,true),
                            NDArrayIndex.interval(soFar, soFar + weightParamCount));
            soFar += weightParamCount;
            INDArray biasView = gradientView.get(NDArrayIndex.interval(0,0,true),
                            NDArrayIndex.interval(soFar, soFar + decoderLayerSizes[i]));
            soFar += decoderLayerSizes[i];

            INDArray layerWeights =
                            createWeightMatrix(decoderLayerNIn, decoderLayerSizes[i], null, weightView, false);
            INDArray layerBiases = createBias(decoderLayerSizes[i], 0.0, biasView, false); //TODO don't hardcode 0

            ret.put(DECODER_PREFIX + i + WEIGHT_KEY_SUFFIX, layerWeights);
            ret.put(DECODER_PREFIX + i + BIAS_KEY_SUFFIX, layerBiases);
        }

        //Finally, p(x|z):
        if (nIn > Integer.MAX_VALUE)
            throw new ND4JArraySizeException();
        int nDistributionParams = layer.getOutputDistribution().distributionInputSize((int) nIn);
        int pxzWeightCount = decoderLayerSizes[decoderLayerSizes.length - 1] * nDistributionParams;
        INDArray pxzWeightView =
                        gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar + pxzWeightCount));
        soFar += pxzWeightCount;
        INDArray pxzBiasView = gradientView.get(NDArrayIndex.interval(0,0,true),
                        NDArrayIndex.interval(soFar, soFar + nDistributionParams));

        INDArray pxzWeightsReshaped = createWeightMatrix(decoderLayerSizes[decoderLayerSizes.length - 1],
                        nDistributionParams, null, pxzWeightView, false);
        INDArray pxzBiasReshaped = createBias(nDistributionParams, 0.0, pxzBiasView, false);

        ret.put(PXZ_W, pxzWeightsReshaped);
        ret.put(PXZ_B, pxzBiasReshaped);

        return ret;
    }
}
|