Merge remote-tracking branch 'fork/master'
commit
a76a44e198
|
@ -17,13 +17,14 @@
|
||||||
package org.deeplearning4j.arbiter.optimize.parameter;
|
package org.deeplearning4j.arbiter.optimize.parameter;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.Getter;
|
||||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.GenericDeserializer;
|
import org.deeplearning4j.arbiter.optimize.serde.jackson.FixedValueDeserializer;
|
||||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.GenericSerializer;
|
import org.deeplearning4j.arbiter.optimize.serde.jackson.FixedValueSerializer;
|
||||||
import org.deeplearning4j.arbiter.util.ObjectUtils;
|
import org.deeplearning4j.arbiter.util.ObjectUtils;
|
||||||
import org.nd4j.shade.jackson.annotation.JsonCreator;
|
import org.nd4j.shade.jackson.annotation.JsonCreator;
|
||||||
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||||
|
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
||||||
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
|
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
|
||||||
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
|
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
|
||||||
|
|
||||||
|
@ -37,9 +38,11 @@ import java.util.Map;
|
||||||
* @param <T> Type of (fixed) value
|
* @param <T> Type of (fixed) value
|
||||||
*/
|
*/
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode
|
||||||
|
@JsonSerialize(using = FixedValueSerializer.class)
|
||||||
|
@JsonDeserialize(using = FixedValueDeserializer.class)
|
||||||
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
public class FixedValue<T> implements ParameterSpace<T> {
|
public class FixedValue<T> implements ParameterSpace<T> {
|
||||||
@JsonSerialize(using = GenericSerializer.class)
|
@Getter
|
||||||
@JsonDeserialize(using = GenericDeserializer.class)
|
|
||||||
private Object value;
|
private Object value;
|
||||||
private int index;
|
private int index;
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,52 @@
|
||||||
|
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
||||||
|
|
||||||
|
import org.apache.commons.codec.binary.Base64;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
|
||||||
|
import org.nd4j.shade.jackson.core.JsonParser;
|
||||||
|
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
||||||
|
import org.nd4j.shade.jackson.databind.DeserializationContext;
|
||||||
|
import org.nd4j.shade.jackson.databind.JsonDeserializer;
|
||||||
|
import org.nd4j.shade.jackson.databind.JsonNode;
|
||||||
|
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||||
|
|
||||||
|
import java.io.ByteArrayInputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.ObjectInputStream;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A custom deserializer to be used in conjunction with {@link FixedValueSerializer}
|
||||||
|
* @author Alex Black
|
||||||
|
*/
|
||||||
|
public class FixedValueDeserializer extends JsonDeserializer<FixedValue> {
|
||||||
|
@Override
|
||||||
|
public FixedValue deserialize(JsonParser p, DeserializationContext deserializationContext) throws IOException, JsonProcessingException {
|
||||||
|
JsonNode node = p.getCodec().readTree(p);
|
||||||
|
String className = node.get("@valueclass").asText();
|
||||||
|
Class<?> c;
|
||||||
|
try {
|
||||||
|
c = Class.forName(className);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(node.has("value")){
|
||||||
|
//Number, String, Enum
|
||||||
|
JsonNode valueNode = node.get("value");
|
||||||
|
Object o = new ObjectMapper().treeToValue(valueNode, c);
|
||||||
|
return new FixedValue<>(o);
|
||||||
|
} else {
|
||||||
|
//Everything else
|
||||||
|
JsonNode valueNode = node.get("data");
|
||||||
|
String data = valueNode.asText();
|
||||||
|
|
||||||
|
byte[] b = new Base64().decode(data);
|
||||||
|
ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(b));
|
||||||
|
try {
|
||||||
|
Object o = ois.readObject();
|
||||||
|
return new FixedValue<>(o);
|
||||||
|
} catch (Throwable t) {
|
||||||
|
throw new RuntimeException(t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,51 @@
|
||||||
|
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
||||||
|
|
||||||
|
import org.apache.commons.net.util.Base64;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
|
||||||
|
import org.nd4j.shade.jackson.core.JsonGenerator;
|
||||||
|
import org.nd4j.shade.jackson.core.type.WritableTypeId;
|
||||||
|
import org.nd4j.shade.jackson.databind.JsonSerializer;
|
||||||
|
import org.nd4j.shade.jackson.databind.SerializerProvider;
|
||||||
|
import org.nd4j.shade.jackson.databind.jsontype.TypeSerializer;
|
||||||
|
|
||||||
|
import java.io.ByteArrayOutputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.ObjectOutputStream;
|
||||||
|
|
||||||
|
import static org.nd4j.shade.jackson.core.JsonToken.START_OBJECT;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A custom serializer to handle arbitrary object types
|
||||||
|
* Uses standard JSON where safe (number, string, enumerations) or Java object serialization (bytes -> base64)
|
||||||
|
* The latter is not an ideal approach, but Jackson doesn't support serialization/deserialization of arbitrary
|
||||||
|
* objects very well
|
||||||
|
*
|
||||||
|
* @author Alex Black
|
||||||
|
*/
|
||||||
|
public class FixedValueSerializer extends JsonSerializer<FixedValue> {
|
||||||
|
@Override
|
||||||
|
public void serialize(FixedValue fixedValue, JsonGenerator j, SerializerProvider serializerProvider) throws IOException {
|
||||||
|
Object o = fixedValue.getValue();
|
||||||
|
|
||||||
|
j.writeStringField("@valueclass", o.getClass().getName());
|
||||||
|
if(o instanceof Number || o instanceof String || o instanceof Enum){
|
||||||
|
j.writeObjectField("value", o);
|
||||||
|
} else {
|
||||||
|
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||||
|
ObjectOutputStream oos = new ObjectOutputStream(baos);
|
||||||
|
oos.writeObject(o);
|
||||||
|
baos.close();
|
||||||
|
byte[] b = baos.toByteArray();
|
||||||
|
String base64 = new Base64().encodeToString(b);
|
||||||
|
j.writeStringField("data", base64);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void serializeWithType(FixedValue value, JsonGenerator gen, SerializerProvider serializers, TypeSerializer typeSer) throws IOException {
|
||||||
|
WritableTypeId typeId = typeSer.typeId(value, START_OBJECT);
|
||||||
|
typeSer.writeTypePrefix(gen, typeId);
|
||||||
|
serialize(value, gen, serializers);
|
||||||
|
typeSer.writeTypeSuffix(gen, typeId);
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,46 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
|
||||||
|
|
||||||
import org.nd4j.shade.jackson.core.JsonParser;
|
|
||||||
import org.nd4j.shade.jackson.databind.DeserializationContext;
|
|
||||||
import org.nd4j.shade.jackson.databind.JsonDeserializer;
|
|
||||||
import org.nd4j.shade.jackson.databind.JsonNode;
|
|
||||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Created by Alex on 15/02/2017.
|
|
||||||
*/
|
|
||||||
public class GenericDeserializer extends JsonDeserializer<Object> {
|
|
||||||
@Override
|
|
||||||
public Object deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
|
|
||||||
JsonNode node = p.getCodec().readTree(p);
|
|
||||||
String className = node.get("@class").asText();
|
|
||||||
Class<?> c;
|
|
||||||
try {
|
|
||||||
c = Class.forName(className);
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
JsonNode valueNode = node.get("value");
|
|
||||||
Object o = new ObjectMapper().treeToValue(valueNode, c);
|
|
||||||
return o;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,38 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
|
||||||
|
|
||||||
import org.nd4j.shade.jackson.core.JsonGenerator;
|
|
||||||
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
|
||||||
import org.nd4j.shade.jackson.databind.JsonSerializer;
|
|
||||||
import org.nd4j.shade.jackson.databind.SerializerProvider;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Created by Alex on 15/02/2017.
|
|
||||||
*/
|
|
||||||
public class GenericSerializer extends JsonSerializer<Object> {
|
|
||||||
@Override
|
|
||||||
public void serialize(Object o, JsonGenerator j, SerializerProvider serializerProvider)
|
|
||||||
throws IOException, JsonProcessingException {
|
|
||||||
j.writeStartObject();
|
|
||||||
j.writeStringField("@class", o.getClass().getName());
|
|
||||||
j.writeObjectField("value", o);
|
|
||||||
j.writeEndObject();
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -24,9 +24,6 @@ import org.nd4j.shade.jackson.databind.SerializationFeature;
|
||||||
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
|
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
|
||||||
import org.nd4j.shade.jackson.datatype.joda.JodaModule;
|
import org.nd4j.shade.jackson.datatype.joda.JodaModule;
|
||||||
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Created by Alex on 16/11/2016.
|
* Created by Alex on 16/11/2016.
|
||||||
*/
|
*/
|
||||||
|
@ -44,6 +41,7 @@ public class JsonMapper {
|
||||||
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
|
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
|
||||||
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
|
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
|
||||||
mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
|
mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
|
||||||
|
mapper.setVisibility(PropertyAccessor.SETTER, JsonAutoDetect.Visibility.ANY);
|
||||||
yamlMapper = new ObjectMapper(new YAMLFactory());
|
yamlMapper = new ObjectMapper(new YAMLFactory());
|
||||||
yamlMapper.registerModule(new JodaModule());
|
yamlMapper.registerModule(new JodaModule());
|
||||||
yamlMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
|
yamlMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
|
||||||
|
|
|
@ -32,7 +32,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
|
||||||
*/
|
*/
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@NoArgsConstructor(access = AccessLevel.PROTECTED) //For Jackson JSON/YAML deserialization
|
@NoArgsConstructor(access = AccessLevel.PUBLIC) //For Jackson JSON/YAML deserialization
|
||||||
public abstract class BaseOutputLayerSpace<L extends BaseOutputLayer> extends FeedForwardLayerSpace<L> {
|
public abstract class BaseOutputLayerSpace<L extends BaseOutputLayer> extends FeedForwardLayerSpace<L> {
|
||||||
|
|
||||||
protected ParameterSpace<ILossFunction> lossFunction;
|
protected ParameterSpace<ILossFunction> lossFunction;
|
||||||
|
|
|
@ -18,8 +18,11 @@ package org.deeplearning4j.arbiter.ui.data;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
|
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
|
||||||
import org.deeplearning4j.arbiter.ui.misc.JsonMapper;
|
import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper;
|
||||||
import org.deeplearning4j.arbiter.ui.module.ArbiterModule;
|
import org.deeplearning4j.arbiter.ui.module.ArbiterModule;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -64,7 +67,11 @@ public class GlobalConfigPersistable extends BaseJavaPersistable {
|
||||||
|
|
||||||
|
|
||||||
public OptimizationConfiguration getOptimizationConfiguration(){
|
public OptimizationConfiguration getOptimizationConfiguration(){
|
||||||
return JsonMapper.fromJson(optimizationConfigJson, OptimizationConfiguration.class);
|
try {
|
||||||
|
return JsonMapper.getMapper().readValue(optimizationConfigJson, OptimizationConfiguration.class);
|
||||||
|
} catch (IOException e){
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public int getCandidatesQueued(){
|
public int getCandidatesQueued(){
|
||||||
|
|
|
@ -26,13 +26,14 @@ import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
|
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
|
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
|
||||||
|
import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper;
|
||||||
import org.deeplearning4j.arbiter.ui.data.GlobalConfigPersistable;
|
import org.deeplearning4j.arbiter.ui.data.GlobalConfigPersistable;
|
||||||
import org.deeplearning4j.arbiter.ui.data.ModelInfoPersistable;
|
import org.deeplearning4j.arbiter.ui.data.ModelInfoPersistable;
|
||||||
import org.deeplearning4j.arbiter.ui.misc.JsonMapper;
|
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
@ -217,7 +218,11 @@ public class ArbiterStatusListener implements StatusListener {
|
||||||
// }
|
// }
|
||||||
//TODO: cache global config, but we don't want to have outdated info (like uninitialized termination conditions)
|
//TODO: cache global config, but we don't want to have outdated info (like uninitialized termination conditions)
|
||||||
|
|
||||||
ocJson = JsonMapper.asJson(r.getConfiguration());
|
try {
|
||||||
|
ocJson = JsonMapper.getMapper().writeValueAsString(r.getConfiguration());
|
||||||
|
} catch (IOException e){
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
|
||||||
GlobalConfigPersistable p = new GlobalConfigPersistable.Builder()
|
GlobalConfigPersistable p = new GlobalConfigPersistable.Builder()
|
||||||
.sessionId(sessionId)
|
.sessionId(sessionId)
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -183,12 +183,12 @@
|
||||||
</head>
|
</head>
|
||||||
<body class="bgcolor">
|
<body class="bgcolor">
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||||
<link id="bootstrap-style" href="/assets/webjars/bootstrap/2.3.2/css/bootstrap.min.css" rel="stylesheet">
|
<link id="bootstrap-style" href="/assets/webjars/bootstrap/4.3.1/dist/css/bootstrap.min.css" rel="stylesheet">
|
||||||
<script src="/assets/webjars/jquery/2.2.0/jquery.min.js"></script>
|
<script src="/assets/webjars/jquery/2.2.0/jquery.min.js"></script>
|
||||||
<link href="/assets/webjars/jquery-ui/1.10.2/themes/base/jquery-ui.css" rel="stylesheet">
|
<link href="/assets/webjars/jquery-ui/1.10.2/themes/base/jquery-ui.css" rel="stylesheet">
|
||||||
<script src="/assets/webjars/jquery-ui/1.10.2/ui/minified/jquery-ui.min.js"></script>
|
<script src="/assets/webjars/jquery-ui/1.10.2/ui/minified/jquery-ui.min.js"></script>
|
||||||
<script src="/assets/webjars/d3js/3.3.5/d3.min.js" charset="utf-8"></script>
|
<script src="/assets/webjars/d3js/3.3.5/d3.min.js" charset="utf-8"></script>
|
||||||
<script src="/assets/webjars/bootstrap/2.3.2/js/bootstrap.min.js"></script>
|
<script src="/assets/webjars/bootstrap/4.3.1/dist/js/bootstrap.min.js"></script>
|
||||||
<script src="/assets/dl4j-ui.js"></script>
|
<script src="/assets/dl4j-ui.js"></script>
|
||||||
|
|
||||||
<script>
|
<script>
|
||||||
|
|
|
@ -35,7 +35,7 @@ import org.datavec.spark.BaseSparkTest;
|
||||||
import org.datavec.spark.transform.AnalyzeSpark;
|
import org.datavec.spark.transform.AnalyzeSpark;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.graph.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
import org.nd4j.linalg.io.ClassPathResource;
|
||||||
|
|
||||||
|
|
|
@ -135,7 +135,7 @@ public class EncodingHandler implements MessageHandler {
|
||||||
iterations.get().incrementAndGet();
|
iterations.get().incrementAndGet();
|
||||||
|
|
||||||
if (boundary != null && atomicBoundary.get() < 0)
|
if (boundary != null && atomicBoundary.get() < 0)
|
||||||
atomicBoundary.compareAndSet(-1, (int) (updates.lengthLong() * boundary));
|
atomicBoundary.compareAndSet(-1, (int) (updates.length() * boundary));
|
||||||
|
|
||||||
INDArray encoded;
|
INDArray encoded;
|
||||||
|
|
||||||
|
@ -160,11 +160,11 @@ public class EncodingHandler implements MessageHandler {
|
||||||
double encLen = encoded.data().getInt(0);
|
double encLen = encoded.data().getInt(0);
|
||||||
|
|
||||||
// if updates are too dense - we fallback to bitmap encoding
|
// if updates are too dense - we fallback to bitmap encoding
|
||||||
if (encLen >= (updates.lengthLong() / 16)) {
|
if (encLen >= (updates.length() / 16)) {
|
||||||
log.debug("Switching back to bitmapEncoding: iteration {}, epoch {}, threshold {}, encoded length {}", iteration, epoch, currThreshold, encLen);
|
log.debug("Switching back to bitmapEncoding: iteration {}, epoch {}, threshold {}, encoded length {}", iteration, epoch, currThreshold, encLen);
|
||||||
bitmapMode.get().set(true);
|
bitmapMode.get().set(true);
|
||||||
|
|
||||||
DataBuffer buffer = Nd4j.getDataBufferFactory().createInt(updates.lengthLong() / 16 + 5);
|
DataBuffer buffer = Nd4j.getDataBufferFactory().createInt(updates.length() / 16 + 5);
|
||||||
encoded = Nd4j.createArrayFromShapeBuffer(buffer, updates.shapeInfoDataBuffer());
|
encoded = Nd4j.createArrayFromShapeBuffer(buffer, updates.shapeInfoDataBuffer());
|
||||||
|
|
||||||
Nd4j.getExecutioner().bitmapEncode(updates, encoded, currentThreshold.get().get());
|
Nd4j.getExecutioner().bitmapEncode(updates, encoded, currentThreshold.get().get());
|
||||||
|
@ -186,12 +186,12 @@ public class EncodingHandler implements MessageHandler {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
//Dense bitmap updates
|
//Dense bitmap updates
|
||||||
DataBuffer buffer = Nd4j.getDataBufferFactory().createInt(updates.lengthLong() / 16 + 5);
|
DataBuffer buffer = Nd4j.getDataBufferFactory().createInt(updates.length() / 16 + 5);
|
||||||
encoded = Nd4j.createArrayFromShapeBuffer(buffer, updates.shapeInfoDataBuffer());
|
encoded = Nd4j.createArrayFromShapeBuffer(buffer, updates.shapeInfoDataBuffer());
|
||||||
|
|
||||||
long values = Nd4j.getExecutioner().bitmapEncode(updates, encoded, currentThreshold.get().get());
|
long values = Nd4j.getExecutioner().bitmapEncode(updates, encoded, currentThreshold.get().get());
|
||||||
|
|
||||||
if (values < (updates.lengthLong() / 16 + 5) / 2) {
|
if (values < (updates.length() / 16 + 5) / 2) {
|
||||||
boolean current = bitmapMode.get().get();
|
boolean current = bitmapMode.get().get();
|
||||||
bitmapMode.get().set(false);
|
bitmapMode.get().set(false);
|
||||||
if(!current) {
|
if(!current) {
|
||||||
|
|
|
@ -63,7 +63,7 @@
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.fasterxml.jackson.module</groupId>
|
<groupId>com.fasterxml.jackson.module</groupId>
|
||||||
<artifactId>jackson-module-scala_2.11</artifactId>
|
<artifactId>jackson-module-scala_2.11</artifactId>
|
||||||
<version>${jackson.version}</version>
|
<version>2.6.7</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,9 @@ import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
import java.lang.reflect.Field;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Created by agibsonccc on 1/23/15.
|
* Created by agibsonccc on 1/23/15.
|
||||||
|
@ -37,7 +40,9 @@ public abstract class BaseSparkTest implements Serializable {
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void after() {
|
public void after() {
|
||||||
sc.close();
|
if(sc != null) {
|
||||||
|
sc.close();
|
||||||
|
}
|
||||||
sc = null;
|
sc = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,6 +53,30 @@ public abstract class BaseSparkTest implements Serializable {
|
||||||
public JavaSparkContext getContext() {
|
public JavaSparkContext getContext() {
|
||||||
if (sc != null)
|
if (sc != null)
|
||||||
return sc;
|
return sc;
|
||||||
|
|
||||||
|
//Ensure SPARK_USER environment variable is set for Spark tests
|
||||||
|
String u = System.getenv("SPARK_USER");
|
||||||
|
Map<String, String> env = System.getenv();
|
||||||
|
if(u == null || u.isEmpty()) {
|
||||||
|
try {
|
||||||
|
Class[] classes = Collections.class.getDeclaredClasses();
|
||||||
|
for (Class cl : classes) {
|
||||||
|
if ("java.util.Collections$UnmodifiableMap".equals(cl.getName())) {
|
||||||
|
Field field = cl.getDeclaredField("m");
|
||||||
|
field.setAccessible(true);
|
||||||
|
Object obj = field.get(env);
|
||||||
|
Map<String, String> map = (Map<String, String>) obj;
|
||||||
|
String user = System.getProperty("user.name");
|
||||||
|
if (user == null || user.isEmpty())
|
||||||
|
user = "user";
|
||||||
|
map.put("SPARK_USER", user);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// set to test mode
|
// set to test mode
|
||||||
SparkConf sparkConf = new SparkConf().setMaster("local[4]").set("spark.driver.host", "localhost")
|
SparkConf sparkConf = new SparkConf().setMaster("local[4]").set("spark.driver.host", "localhost")
|
||||||
.setAppName("sparktest")
|
.setAppName("sparktest")
|
||||||
|
|
|
@ -19,6 +19,10 @@ package org.deeplearning4j.spark;
|
||||||
import org.apache.spark.SparkConf;
|
import org.apache.spark.SparkConf;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
|
||||||
|
import java.lang.reflect.Field;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Created by Alex on 04/07/2017.
|
* Created by Alex on 04/07/2017.
|
||||||
*/
|
*/
|
||||||
|
@ -30,6 +34,31 @@ public class BaseSparkKryoTest extends BaseSparkTest {
|
||||||
return sc;
|
return sc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//Ensure SPARK_USER environment variable is set for Spark Kryo tests
|
||||||
|
String u = System.getenv("SPARK_USER");
|
||||||
|
if(u == null || u.isEmpty()){
|
||||||
|
try {
|
||||||
|
Class[] classes = Collections.class.getDeclaredClasses();
|
||||||
|
Map<String, String> env = System.getenv();
|
||||||
|
for (Class cl : classes) {
|
||||||
|
if ("java.util.Collections$UnmodifiableMap".equals(cl.getName())) {
|
||||||
|
Field field = cl.getDeclaredField("m");
|
||||||
|
field.setAccessible(true);
|
||||||
|
Object obj = field.get(env);
|
||||||
|
Map<String, String> map = (Map<String, String>) obj;
|
||||||
|
String user = System.getProperty("user.name");
|
||||||
|
if(user == null || user.isEmpty())
|
||||||
|
user = "user";
|
||||||
|
map.put("SPARK_USER", user);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (Exception e){
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").setAppName("sparktest");
|
SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").setAppName("sparktest");
|
||||||
|
|
||||||
sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
|
sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
|
||||||
|
|
|
@ -74,7 +74,9 @@ public abstract class BaseSparkTest implements Serializable {
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void after() {
|
public void after() {
|
||||||
sc.close();
|
if(sc != null) {
|
||||||
|
sc.close();
|
||||||
|
}
|
||||||
sc = null;
|
sc = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,23 +19,20 @@ package org.deeplearning4j.ui.module.remote;
|
||||||
import com.fasterxml.jackson.databind.JsonNode;
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.api.storage.*;
|
import org.deeplearning4j.api.storage.*;
|
||||||
import org.deeplearning4j.ui.api.FunctionType;
|
|
||||||
import org.deeplearning4j.ui.api.HttpMethod;
|
import org.deeplearning4j.ui.api.HttpMethod;
|
||||||
import org.deeplearning4j.ui.api.Route;
|
import org.deeplearning4j.ui.api.Route;
|
||||||
import org.deeplearning4j.ui.api.UIModule;
|
import org.deeplearning4j.ui.api.UIModule;
|
||||||
import org.deeplearning4j.ui.i18n.I18NResource;
|
import org.deeplearning4j.ui.i18n.I18NResource;
|
||||||
|
import play.mvc.Http;
|
||||||
import play.mvc.Result;
|
import play.mvc.Result;
|
||||||
import play.mvc.Results;
|
import play.mvc.Results;
|
||||||
|
|
||||||
import javax.xml.bind.DatatypeConverter;
|
import javax.xml.bind.DatatypeConverter;
|
||||||
import java.io.File;
|
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.atomic.AtomicBoolean;
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
|
|
||||||
import static play.mvc.Http.Context.Implicit.request;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* Used to receive UI updates remotely.
|
* Used to receive UI updates remotely.
|
||||||
|
@ -73,7 +70,7 @@ public class RemoteReceiverModule implements UIModule {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Route> getRoutes() {
|
public List<Route> getRoutes() {
|
||||||
Route r = new Route("/remoteReceive", HttpMethod.POST, FunctionType.Supplier, this::receiveData);
|
Route r = Route.request0Function("/remoteReceive", HttpMethod.POST, this::receiveData);
|
||||||
return Collections.singletonList(r);
|
return Collections.singletonList(r);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,7 +95,7 @@ public class RemoteReceiverModule implements UIModule {
|
||||||
return Collections.emptyList();
|
return Collections.emptyList();
|
||||||
}
|
}
|
||||||
|
|
||||||
private Result receiveData() {
|
private Result receiveData(Http.Request request) {
|
||||||
if (!enabled.get()) {
|
if (!enabled.get()) {
|
||||||
return Results.forbidden(
|
return Results.forbidden(
|
||||||
"UI server remote listening is currently disabled. Use UIServer.getInstance().enableRemoteListener()");
|
"UI server remote listening is currently disabled. Use UIServer.getInstance().enableRemoteListener()");
|
||||||
|
@ -109,7 +106,7 @@ public class RemoteReceiverModule implements UIModule {
|
||||||
"UI Server remote listener: no StatsStorage instance is set/available to store results");
|
"UI Server remote listener: no StatsStorage instance is set/available to store results");
|
||||||
}
|
}
|
||||||
|
|
||||||
JsonNode jn = request().body().asJson();
|
JsonNode jn = request.body().asJson();
|
||||||
JsonNode type = jn.get("type");
|
JsonNode type = jn.get("type");
|
||||||
JsonNode dataClass = jn.get("class");
|
JsonNode dataClass = jn.get("class");
|
||||||
JsonNode data = jn.get("data");
|
JsonNode data = jn.get("data");
|
||||||
|
|
|
@ -21,16 +21,16 @@ function selectStdevChart(fieldName) {
|
||||||
$("#stdevUpdates").attr("class", "active");
|
$("#stdevUpdates").attr("class", "active");
|
||||||
}
|
}
|
||||||
|
|
||||||
renderOverviewPage(false);
|
renderOverviewPage(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ---------- Render page ---------- */
|
/* ---------- Render page ---------- */
|
||||||
var lastUpdateTime = -1;
|
var lastUpdateTime = -1;
|
||||||
var lastUpdateSession = "";
|
var lastUpdateSession = "";
|
||||||
function renderOverviewPage(firstLoad) {
|
function renderOverviewPage(forceupdate) {
|
||||||
updateSessionWorkerSelect();
|
updateSessionWorkerSelect();
|
||||||
|
|
||||||
if(firstLoad || !lastUpdateSession || lastUpdateSession == "" || lastUpdateSession != currSession){
|
if(forceupdate || !lastUpdateSession || lastUpdateSession == "" || lastUpdateSession != currSession){
|
||||||
executeOverviewUpdate();
|
executeOverviewUpdate();
|
||||||
} else {
|
} else {
|
||||||
//Check last update time first - see if data has actually changed...
|
//Check last update time first - see if data has actually changed...
|
||||||
|
|
|
@ -109,7 +109,7 @@ namespace nd4j {
|
||||||
|
|
||||||
if (deviceId != previousDeviceId) {
|
if (deviceId != previousDeviceId) {
|
||||||
// discard existing stuff
|
// discard existing stuff
|
||||||
nd4j_printf("AffinityManager::setCurrentDevice() was invoked, releasing buffers\n", "");
|
//nd4j_printf("AffinityManager::setCurrentDevice() was invoked, releasing buffers\n", "");
|
||||||
LaunchContext::releaseBuffers();
|
LaunchContext::releaseBuffers();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -158,7 +158,7 @@ LaunchContext::LaunchContext() {
|
||||||
};
|
};
|
||||||
|
|
||||||
void LaunchContext::releaseBuffers() {
|
void LaunchContext::releaseBuffers() {
|
||||||
nd4j_printf("LaunchContext::releaseBuffers() was invoked\n", "");
|
//nd4j_printf("LaunchContext::releaseBuffers() was invoked\n", "");
|
||||||
contextBuffers.release();
|
contextBuffers.release();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -34,7 +34,7 @@ CUSTOM_OP_IMPL(sruCell, 4, 2, false, 0, 0) {
|
||||||
auto xt = INPUT_VARIABLE(0); // input [bS x inSize], bS - batch size, inSize - number of features
|
auto xt = INPUT_VARIABLE(0); // input [bS x inSize], bS - batch size, inSize - number of features
|
||||||
auto ct_1 = INPUT_VARIABLE(1); // previous cell state ct [bS x inSize], that is at previous time step t-1
|
auto ct_1 = INPUT_VARIABLE(1); // previous cell state ct [bS x inSize], that is at previous time step t-1
|
||||||
auto w = INPUT_VARIABLE(2); // weights [inSize x 3*inSize]
|
auto w = INPUT_VARIABLE(2); // weights [inSize x 3*inSize]
|
||||||
auto b = INPUT_VARIABLE(3); // biases [1 x 2*inSize]
|
auto b = INPUT_VARIABLE(3); // biases [2*inSize]
|
||||||
|
|
||||||
auto ht = OUTPUT_VARIABLE(0); // current cell output [bS x inSize], that is at current time step t
|
auto ht = OUTPUT_VARIABLE(0); // current cell output [bS x inSize], that is at current time step t
|
||||||
auto ct = OUTPUT_VARIABLE(1); // current cell state [bS x inSize], that is at current time step t
|
auto ct = OUTPUT_VARIABLE(1); // current cell state [bS x inSize], that is at current time step t
|
||||||
|
|
|
@ -6511,4 +6511,22 @@ public class SameDiff extends SDBaseOps {
|
||||||
public String generateNewVarName(String base, int argIndex) {
|
public String generateNewVarName(String base, int argIndex) {
|
||||||
return generateNewVarName(base, argIndex, true);
|
return generateNewVarName(base, argIndex, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns an unused variable name of the format <base>_#.
|
||||||
|
*
|
||||||
|
* Intended to be used for custom variables (like weights), arguments and op outputs should use {@link #generateNewVarName(String, int)}.
|
||||||
|
*/
|
||||||
|
public String generateDistinctCustomVariableName(String base){
|
||||||
|
if(!variables.containsKey(base))
|
||||||
|
return base;
|
||||||
|
|
||||||
|
int inc = 1;
|
||||||
|
|
||||||
|
while(variables.containsKey(base + "_" + inc)){
|
||||||
|
inc++;
|
||||||
|
}
|
||||||
|
|
||||||
|
return base + "_" + inc;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.autodiff.samediff.ops;
|
package org.nd4j.autodiff.samediff.ops;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.*;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.*;
|
||||||
|
@ -23,6 +24,15 @@ import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.*;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.GRUCellOutputs;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRUCellOutputs;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRULayerOutputs;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* SameDiff Recurrent Neural Network operations<br>
|
* SameDiff Recurrent Neural Network operations<br>
|
||||||
|
@ -39,90 +49,163 @@ public class SDRNN extends SDOps {
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The gru cell
|
* See {@link #gru(String, SDVariable, SDVariable, GRUWeights)}.
|
||||||
*
|
|
||||||
* @param configuration the configuration to use
|
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
public List<SDVariable> gru(GRUCellConfiguration configuration) {
|
public GRUCellOutputs gru(@NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) {
|
||||||
GRUCell c = new GRUCell(sd, configuration);
|
GRUCell c = new GRUCell(sd, x, hLast, weights);
|
||||||
return Arrays.asList(c.outputVariables());
|
return new GRUCellOutputs(c.outputVariables());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The gru cell
|
* The GRU cell. Does a single time step operation.
|
||||||
*
|
*
|
||||||
* @param baseName the base name for the gru cell
|
* @param baseName The base name for the gru cell
|
||||||
* @param configuration the configuration to use
|
* @param x Input, with shape [batchSize, inSize]
|
||||||
* @return
|
* @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits]
|
||||||
|
* @param weights The cell's weights.
|
||||||
|
* @return The cell's outputs.
|
||||||
*/
|
*/
|
||||||
public List<SDVariable> gru(String baseName, GRUCellConfiguration configuration) {
|
public GRUCellOutputs gru(String baseName, @NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) {
|
||||||
GRUCell c = new GRUCell(sd, configuration);
|
GRUCell c = new GRUCell(sd, x, hLast, weights);
|
||||||
return Arrays.asList(c.outputVariables(baseName));
|
return new GRUCellOutputs(c.outputVariables(baseName));
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* LSTM unit
|
|
||||||
*
|
|
||||||
* @param baseName the base name for outputs
|
|
||||||
* @param configuration the configuration to use
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable lstmCell(String baseName, LSTMCellConfiguration configuration) {
|
|
||||||
return new LSTMCell(sd, configuration).outputVariables(baseName)[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<SDVariable> lstmBlockCell(String name, LSTMBlockCellConfiguration configuration){
|
|
||||||
SDVariable[] v = new LSTMBlockCell(sd, configuration).outputVariables(name);
|
|
||||||
return Arrays.asList(v);
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<SDVariable> lstmLayer(String name, LSTMConfiguration configuration){
|
|
||||||
SDVariable[] v = new LSTMLayer(sd, configuration).outputVariables(name);
|
|
||||||
return Arrays.asList(v);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Simple recurrent unit
|
* See {@link #lstmCell(String, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}.
|
||||||
*
|
|
||||||
* @param configuration the configuration for the sru
|
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
public SDVariable sru(SRUConfiguration configuration) {
|
public LSTMCellOutputs lstmCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
|
||||||
return new SRU(sd, configuration).outputVariables()[0];
|
LSTMWeights weights, LSTMConfiguration config){
|
||||||
|
LSTMBlockCell c = new LSTMBlockCell(sd, x, cLast, yLast, weights, config);
|
||||||
|
return new LSTMCellOutputs(c.outputVariables());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Simiple recurrent unit
|
* The LSTM cell. Does a single time step operation.
|
||||||
*
|
*
|
||||||
* @param baseName the base name to use for output variables
|
* @param baseName The base name for the lstm cell
|
||||||
* @param configuration the configuration for the sru
|
* @param x Input, with shape [batchSize, inSize]
|
||||||
* @return
|
* @param cLast Previous cell state, with shape [batchSize, numUnits]
|
||||||
|
* @param yLast Previous cell output, with shape [batchSize, numUnits]
|
||||||
|
* @param weights The cell's weights.
|
||||||
|
* @param config The cell's config.
|
||||||
|
* @return The cell's outputs.
|
||||||
*/
|
*/
|
||||||
public SDVariable sru(String baseName, SRUConfiguration configuration) {
|
public LSTMCellOutputs lstmCell(String baseName, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
|
||||||
return new SRU(sd, configuration).outputVariables(baseName)[0];
|
@NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
|
||||||
|
LSTMBlockCell c = new LSTMBlockCell(sd, x, cLast, yLast, weights, config);
|
||||||
|
return new LSTMCellOutputs(c.outputVariables(baseName));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An sru cell
|
* See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}
|
||||||
*
|
|
||||||
* @param configuration the configuration for the sru cell
|
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
public SDVariable sruCell(SRUCellConfiguration configuration) {
|
public LSTMLayerOutputs lstmLayer(@NonNull SDVariable maxTSLength,
|
||||||
return new SRUCell(sd, configuration).outputVariables()[0];
|
@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
|
||||||
|
@NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
|
||||||
|
LSTMLayer c = new LSTMLayer(sd, maxTSLength, x, cLast, yLast, weights, config);
|
||||||
|
return new LSTMLayerOutputs(c.outputVariables(), config.getDataFormat());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An sru cell
|
* See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}
|
||||||
*
|
|
||||||
* @param baseName the base name to use for the output variables
|
|
||||||
* @param configuration the configuration for the sru cell
|
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
public SDVariable sruCell(String baseName, SRUCellConfiguration configuration) {
|
public LSTMLayerOutputs lstmLayer(int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
|
||||||
return new SRUCell(sd, configuration).outputVariables(baseName)[0];
|
@NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
|
||||||
|
return lstmLayer(
|
||||||
|
sd.scalar("lstm_max_ts_length", maxTSLength),
|
||||||
|
x, cLast, yLast, weights, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}
|
||||||
|
*/
|
||||||
|
public LSTMLayerOutputs lstmLayer(String baseName, int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
|
||||||
|
@NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
|
||||||
|
if(baseName != null) {
|
||||||
|
return lstmLayer(baseName,
|
||||||
|
sd.scalar(sd.generateDistinctCustomVariableName(baseName + "_max_ts_length"), maxTSLength),
|
||||||
|
x, cLast, yLast, weights, config);
|
||||||
|
} else {
|
||||||
|
return lstmLayer(maxTSLength, x, cLast, yLast, weights, config);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The LSTM layer. Does multiple time steps.
|
||||||
|
*
|
||||||
|
* Input shape depends on data format (in config):<br>
|
||||||
|
* TNS -> [timeSteps, batchSize, inSize]<br>
|
||||||
|
* NST -> [batchSize, inSize, timeSteps]<br>
|
||||||
|
* NTS -> [batchSize, timeSteps, inSize]<br>
|
||||||
|
*
|
||||||
|
* @param baseName The base name for the lstm layer
|
||||||
|
* @param x Input, with shape dependent on the data format (in config).
|
||||||
|
* @param cLast Previous/initial cell state, with shape [batchSize, numUnits]
|
||||||
|
* @param yLast Previous/initial cell output, with shape [batchSize, numUnits]
|
||||||
|
* @param weights The layer's weights.
|
||||||
|
* @param config The layer's config.
|
||||||
|
* @return The layer's outputs.
|
||||||
|
*/
|
||||||
|
public LSTMLayerOutputs lstmLayer(String baseName, @NonNull SDVariable maxTSLength,
|
||||||
|
@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
|
||||||
|
@NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
|
||||||
|
LSTMLayer c = new LSTMLayer(sd, maxTSLength, x, cLast, yLast, weights, config);
|
||||||
|
return new LSTMLayerOutputs(c.outputVariables(baseName), config.getDataFormat());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #sruCell(String, SDVariable, SDVariable, SRUWeights)}.
|
||||||
|
*/
|
||||||
|
public SRUCellOutputs sruCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) {
|
||||||
|
return new SRUCellOutputs(new SRUCell(sd, x, cLast, weights).outputVariables());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The SRU cell. Does a single time step operation.
|
||||||
|
*
|
||||||
|
* @param baseName The base name for the sru cell
|
||||||
|
* @param x Input, with shape [batchSize, inSize]
|
||||||
|
* @param cLast Previous cell state, with shape [batchSize, inSize]
|
||||||
|
* @param weights The cell's weights.
|
||||||
|
* @return The cell's outputs.
|
||||||
|
*/
|
||||||
|
public SRUCellOutputs sruCell(String baseName, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) {
|
||||||
|
return new SRUCellOutputs(new SRUCell(sd, x, cLast, weights).outputVariables(baseName));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)}
|
||||||
|
*/
|
||||||
|
public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) {
|
||||||
|
return sru(x, initialC, null, weights);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)}
|
||||||
|
*/
|
||||||
|
public SRULayerOutputs sru(String baseName, @NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) {
|
||||||
|
return sru(baseName, x, initialC, null, weights);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)}
|
||||||
|
*/
|
||||||
|
public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) {
|
||||||
|
return new SRULayerOutputs(new SRU(sd, x, initialC, mask, weights).outputVariables());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The SRU layer. Does a single time step operation.
|
||||||
|
*
|
||||||
|
* @param baseName The base name for the sru layer
|
||||||
|
* @param x Input, with shape [batchSize, inSize, timeSeriesLength]
|
||||||
|
* @param initialC Initial cell state, with shape [batchSize, inSize]
|
||||||
|
* @param mask An optional dropout mask, with shape [batchSize, inSize]
|
||||||
|
* @param weights The layer's weights.
|
||||||
|
* @return The layer's outputs.
|
||||||
|
*/
|
||||||
|
public SRULayerOutputs sru(String baseName, @NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) {
|
||||||
|
return new SRULayerOutputs(new SRU(sd, x, initialC, mask, weights).outputVariables(baseName));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -75,7 +75,6 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
|
||||||
@EqualsAndHashCode(callSuper = true,
|
@EqualsAndHashCode(callSuper = true,
|
||||||
exclude = {"auc", "auprc", "probAndLabel", "exactAllocBlockSize", "rocCurve", "prCurve", "axis"})
|
exclude = {"auc", "auprc", "probAndLabel", "exactAllocBlockSize", "rocCurve", "prCurve", "axis"})
|
||||||
@Data
|
@Data
|
||||||
@ToString(exclude = {"probAndLabel", "exactAllocBlockSize", "rocCurve", "prCurve"})
|
|
||||||
@JsonIgnoreProperties({"probAndLabel", "exactAllocBlockSize"})
|
@JsonIgnoreProperties({"probAndLabel", "exactAllocBlockSize"})
|
||||||
@JsonSerialize(using = ROCSerializer.class)
|
@JsonSerialize(using = ROCSerializer.class)
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY)
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY)
|
||||||
|
@ -824,6 +823,11 @@ public class ROC extends BaseEvaluation<ROC> {
|
||||||
return sb.toString();
|
return sb.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString(){
|
||||||
|
return stats();
|
||||||
|
}
|
||||||
|
|
||||||
public double scoreForMetric(Metric metric){
|
public double scoreForMetric(Metric metric){
|
||||||
switch (metric){
|
switch (metric){
|
||||||
case AUROC:
|
case AUROC:
|
||||||
|
|
|
@ -4641,17 +4641,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
return jvmShapeInfo.length;
|
return jvmShapeInfo.length;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns the total number of elements in the ndarray
|
|
||||||
*
|
|
||||||
* @return the number of elements in the ndarray
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
@Deprecated
|
|
||||||
public long lengthLong() {
|
|
||||||
return jvmShapeInfo.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray broadcast(INDArray result) {
|
public INDArray broadcast(INDArray result) {
|
||||||
Nd4j.getCompressor().autoDecompress(this);
|
Nd4j.getCompressor().autoDecompress(this);
|
||||||
|
|
|
@ -279,11 +279,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray {
|
||||||
return (int) length();
|
return (int) length();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long lengthLong() {
|
|
||||||
return length;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected void init(long[] shape) {
|
protected void init(long[] shape) {
|
||||||
|
|
||||||
if (shape.length == 1) {
|
if (shape.length == 1) {
|
||||||
|
|
|
@ -2377,17 +2377,7 @@ public interface INDArray extends Serializable, AutoCloseable {
|
||||||
* @return the number of elements in the ndarray
|
* @return the number of elements in the ndarray
|
||||||
*/
|
*/
|
||||||
long length();
|
long length();
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns the total number of elements in the ndarray
|
|
||||||
*
|
|
||||||
* @return the number of elements in the ndarray
|
|
||||||
* @deprecated use {@link #length()}
|
|
||||||
*/
|
|
||||||
@Deprecated
|
|
||||||
long lengthLong();
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Broadcasts this ndarray to be the specified shape
|
* Broadcasts this ndarray to be the specified shape
|
||||||
*
|
*
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -23,6 +24,7 @@ import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.GRUCellConfiguration;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.GRUCellConfiguration;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
import org.tensorflow.framework.NodeDef;
|
import org.tensorflow.framework.NodeDef;
|
||||||
|
@ -39,14 +41,15 @@ import java.util.Map;
|
||||||
*/
|
*/
|
||||||
public class GRUCell extends DynamicCustomOp {
|
public class GRUCell extends DynamicCustomOp {
|
||||||
|
|
||||||
private GRUCellConfiguration configuration;
|
@Getter
|
||||||
|
private GRUWeights weights;
|
||||||
|
|
||||||
public GRUCell() {
|
public GRUCell() {
|
||||||
}
|
}
|
||||||
|
|
||||||
public GRUCell(SameDiff sameDiff, GRUCellConfiguration configuration) {
|
public GRUCell(SameDiff sameDiff, SDVariable x, SDVariable hLast, GRUWeights weights) {
|
||||||
super(null, sameDiff, configuration.args());
|
super(null, sameDiff, weights.argsWithInputs(x, hLast));
|
||||||
this.configuration = configuration;
|
this.weights = weights;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -16,12 +16,15 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMBlockCellConfiguration;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
import org.tensorflow.framework.NodeDef;
|
import org.tensorflow.framework.NodeDef;
|
||||||
|
@ -49,10 +52,12 @@ import java.util.Map;
|
||||||
* 6: weights - cell peephole (t) connections to output gate, [numUnits]<br>
|
* 6: weights - cell peephole (t) connections to output gate, [numUnits]<br>
|
||||||
* 7: biases, shape [4*numUnits]<br>
|
* 7: biases, shape [4*numUnits]<br>
|
||||||
* <br>
|
* <br>
|
||||||
* Input integer arguments: set via {@link LSTMBlockCellConfiguration}<br>
|
* Weights are set via {@link LSTMWeights}.<br>
|
||||||
|
* <br>
|
||||||
|
* Input integer arguments: set via {@link LSTMConfiguration}<br>
|
||||||
* 0: if not zero, provide peephole connections<br>
|
* 0: if not zero, provide peephole connections<br>
|
||||||
* <br>
|
* <br>
|
||||||
* Input float arguments: set via {@link LSTMBlockCellConfiguration}<br>
|
* Input float arguments: set via {@link LSTMConfiguration}<br>
|
||||||
* 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training<br>
|
* 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training<br>
|
||||||
* 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped<br>
|
* 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped<br>
|
||||||
* <br>
|
* <br>
|
||||||
|
@ -69,15 +74,19 @@ import java.util.Map;
|
||||||
*/
|
*/
|
||||||
public class LSTMBlockCell extends DynamicCustomOp {
|
public class LSTMBlockCell extends DynamicCustomOp {
|
||||||
|
|
||||||
private LSTMBlockCellConfiguration configuration;
|
private LSTMConfiguration configuration;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private LSTMWeights weights;
|
||||||
|
|
||||||
public LSTMBlockCell() {
|
public LSTMBlockCell() {
|
||||||
}
|
}
|
||||||
|
|
||||||
public LSTMBlockCell(SameDiff sameDiff, LSTMBlockCellConfiguration configuration) {
|
public LSTMBlockCell(SameDiff sameDiff, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration) {
|
||||||
super(null, sameDiff, configuration.args());
|
super(null, sameDiff, weights.argsWithInputs(x, cLast, yLast));
|
||||||
this.configuration = configuration;
|
this.configuration = configuration;
|
||||||
addIArgument(configuration.iArgs());
|
this.weights = weights;
|
||||||
|
addIArgument(configuration.iArgs(false));
|
||||||
addTArgument(configuration.tArgs());
|
addTArgument(configuration.tArgs());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -97,12 +106,12 @@ public class LSTMBlockCell extends DynamicCustomOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||||
configuration = LSTMBlockCellConfiguration.builder()
|
configuration = LSTMConfiguration.builder()
|
||||||
.forgetBias(attributesForNode.get("forget_bias").getF())
|
.forgetBias(attributesForNode.get("forget_bias").getF())
|
||||||
.clippingCellValue(attributesForNode.get("cell_clip").getF())
|
.clippingCellValue(attributesForNode.get("cell_clip").getF())
|
||||||
.peepHole(attributesForNode.get("use_peephole").getB())
|
.peepHole(attributesForNode.get("use_peephole").getB())
|
||||||
.build();
|
.build();
|
||||||
addIArgument(configuration.iArgs());
|
addIArgument(configuration.iArgs(false));
|
||||||
addTArgument(configuration.tArgs());
|
addTArgument(configuration.tArgs());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,7 +122,7 @@ public class LSTMBlockCell extends DynamicCustomOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<String, Object> propertiesForFunction() {
|
public Map<String, Object> propertiesForFunction() {
|
||||||
return configuration.toProperties();
|
return configuration.toProperties(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -24,6 +25,7 @@ import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
import org.tensorflow.framework.NodeDef;
|
import org.tensorflow.framework.NodeDef;
|
||||||
|
@ -75,13 +77,17 @@ public class LSTMLayer extends DynamicCustomOp {
|
||||||
|
|
||||||
private LSTMConfiguration configuration;
|
private LSTMConfiguration configuration;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private LSTMWeights weights;
|
||||||
|
|
||||||
public LSTMLayer() {
|
public LSTMLayer() {
|
||||||
}
|
}
|
||||||
|
|
||||||
public LSTMLayer(@NonNull SameDiff sameDiff, @NonNull LSTMConfiguration configuration) {
|
public LSTMLayer(@NonNull SameDiff sameDiff, SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration) {
|
||||||
super(null, sameDiff, configuration.args());
|
super(null, sameDiff, weights.argsWithInputs(maxTSLength, x, cLast, yLast));
|
||||||
this.configuration = configuration;
|
this.configuration = configuration;
|
||||||
addIArgument(configuration.iArgs());
|
this.weights = weights;
|
||||||
|
addIArgument(configuration.iArgs(true));
|
||||||
addTArgument(configuration.tArgs());
|
addTArgument(configuration.tArgs());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -107,7 +113,7 @@ public class LSTMLayer extends DynamicCustomOp {
|
||||||
.peepHole(attributesForNode.get("use_peephole").getB())
|
.peepHole(attributesForNode.get("use_peephole").getB())
|
||||||
.dataFormat(RnnDataFormat.TNS) //Always time major for TF BlockLSTM
|
.dataFormat(RnnDataFormat.TNS) //Always time major for TF BlockLSTM
|
||||||
.build();
|
.build();
|
||||||
addIArgument(configuration.iArgs());
|
addIArgument(configuration.iArgs(true));
|
||||||
addTArgument(configuration.tArgs());
|
addTArgument(configuration.tArgs());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,7 +124,7 @@ public class LSTMLayer extends DynamicCustomOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<String, Object> propertiesForFunction() {
|
public Map<String, Object> propertiesForFunction() {
|
||||||
return configuration.toProperties();
|
return configuration.toProperties(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -16,11 +16,16 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.NonNull;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.SRUConfiguration;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
import org.tensorflow.framework.NodeDef;
|
import org.tensorflow.framework.NodeDef;
|
||||||
|
@ -34,13 +39,18 @@ import java.util.Map;
|
||||||
*/
|
*/
|
||||||
public class SRU extends DynamicCustomOp {
|
public class SRU extends DynamicCustomOp {
|
||||||
|
|
||||||
private SRUConfiguration configuration;
|
@Getter
|
||||||
|
private SRUWeights weights;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private SDVariable mask;
|
||||||
|
|
||||||
public SRU() { }
|
public SRU() { }
|
||||||
|
|
||||||
public SRU(SameDiff sameDiff, SRUConfiguration configuration) {
|
public SRU(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) {
|
||||||
super(null, sameDiff, configuration.args());
|
super(null, sameDiff, wrapFilterNull(x, weights.getWeights(), weights.getBias(), initialC, mask));
|
||||||
this.configuration = configuration;
|
this.mask = mask;
|
||||||
|
this.weights = weights;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -68,6 +78,4 @@ public class SRU extends DynamicCustomOp {
|
||||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,17 +16,18 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
import lombok.Getter;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.SRUCellConfiguration;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
import org.tensorflow.framework.NodeDef;
|
import org.tensorflow.framework.NodeDef;
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A simple recurrent unit cell.
|
* A simple recurrent unit cell.
|
||||||
*
|
*
|
||||||
|
@ -34,14 +35,15 @@ import java.util.Map;
|
||||||
*/
|
*/
|
||||||
public class SRUCell extends DynamicCustomOp {
|
public class SRUCell extends DynamicCustomOp {
|
||||||
|
|
||||||
private SRUCellConfiguration configuration;
|
@Getter
|
||||||
|
private SRUWeights weights;
|
||||||
|
|
||||||
public SRUCell() {
|
public SRUCell() {
|
||||||
}
|
}
|
||||||
|
|
||||||
public SRUCell(SameDiff sameDiff, SRUCellConfiguration configuration) {
|
public SRUCell(SameDiff sameDiff, SDVariable x, SDVariable cLast, SRUWeights weights) {
|
||||||
super(null, sameDiff, configuration.args());
|
super(null, sameDiff, weights.argsWithInputs(x, cLast));
|
||||||
this.configuration = configuration;
|
this.weights = weights;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -1,57 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
|
|
||||||
|
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
|
||||||
|
|
||||||
import java.util.LinkedHashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Builder
|
|
||||||
@Data
|
|
||||||
public class LSTMBlockCellConfiguration {
|
|
||||||
|
|
||||||
private boolean peepHole; //IArg(0)
|
|
||||||
private double forgetBias; //TArg(0)
|
|
||||||
private double clippingCellValue; //TArg(1)
|
|
||||||
|
|
||||||
private SDVariable xt, cLast, yLast, W, Wci, Wcf, Wco, b;
|
|
||||||
|
|
||||||
public Map<String,Object> toProperties() {
|
|
||||||
Map<String,Object> ret = new LinkedHashMap<>();
|
|
||||||
ret.put("peepHole",peepHole);
|
|
||||||
ret.put("clippingCellValue",clippingCellValue);
|
|
||||||
ret.put("forgetBias",forgetBias);
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
public SDVariable[] args() {
|
|
||||||
return new SDVariable[] {xt,cLast, yLast, W, Wci, Wcf, Wco, b};
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
public int[] iArgs() {
|
|
||||||
return new int[] {ArrayUtil.fromBoolean(peepHole)};
|
|
||||||
}
|
|
||||||
|
|
||||||
public double[] tArgs() {
|
|
||||||
return new double[] {forgetBias,clippingCellValue};
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -19,13 +19,15 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
|
|
||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* LSTM Configuration - for {@link org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer}
|
* LSTM Configuration - for {@link LSTMLayer} and {@link LSTMBlockCell}
|
||||||
*
|
*
|
||||||
* @author Alex Black
|
* @author Alex Black
|
||||||
*/
|
*/
|
||||||
|
@ -33,29 +35,41 @@ import java.util.Map;
|
||||||
@Data
|
@Data
|
||||||
public class LSTMConfiguration {
|
public class LSTMConfiguration {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Whether to provide peephole connections.
|
||||||
|
*/
|
||||||
private boolean peepHole; //IArg(0)
|
private boolean peepHole; //IArg(0)
|
||||||
@Builder.Default private RnnDataFormat dataFormat = RnnDataFormat.TNS; //IArg(1)
|
|
||||||
|
/**
|
||||||
|
* The data format of the input. Only used in {@link LSTMLayer}, ignored in {@link LSTMBlockCell}.
|
||||||
|
*/
|
||||||
|
@Builder.Default private RnnDataFormat dataFormat = RnnDataFormat.TNS; //IArg(1) (only for lstmBlock, not lstmBlockCell)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training.
|
||||||
|
*/
|
||||||
private double forgetBias; //TArg(0)
|
private double forgetBias; //TArg(0)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clipping value for cell state, if it is not equal to zero, then cell state is clipped.
|
||||||
|
*/
|
||||||
private double clippingCellValue; //TArg(1)
|
private double clippingCellValue; //TArg(1)
|
||||||
|
|
||||||
private SDVariable xt, cLast, yLast, W, Wci, Wcf, Wco, b;
|
public Map<String,Object> toProperties(boolean includeDataFormat) {
|
||||||
|
|
||||||
public Map<String,Object> toProperties() {
|
|
||||||
Map<String,Object> ret = new LinkedHashMap<>();
|
Map<String,Object> ret = new LinkedHashMap<>();
|
||||||
ret.put("peepHole",peepHole);
|
ret.put("peepHole",peepHole);
|
||||||
ret.put("clippingCellValue",clippingCellValue);
|
ret.put("clippingCellValue",clippingCellValue);
|
||||||
ret.put("forgetBias",forgetBias);
|
ret.put("forgetBias",forgetBias);
|
||||||
ret.put("dataFormat", dataFormat);
|
if(includeDataFormat)
|
||||||
|
ret.put("dataFormat", dataFormat);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
public SDVariable[] args() {
|
|
||||||
return new SDVariable[] {xt,cLast, yLast, W, Wci, Wcf, Wco, b};
|
|
||||||
}
|
|
||||||
|
|
||||||
|
public int[] iArgs(boolean includeDataFormat) {
|
||||||
public int[] iArgs() {
|
if(includeDataFormat) {
|
||||||
return new int[] {ArrayUtil.fromBoolean(peepHole), dataFormat.ordinal()};
|
return new int[]{ArrayUtil.fromBoolean(peepHole), dataFormat.ordinal()};
|
||||||
|
} else return new int[]{ArrayUtil.fromBoolean(peepHole)};
|
||||||
}
|
}
|
||||||
|
|
||||||
public double[] tArgs() {
|
public double[] tArgs() {
|
||||||
|
|
|
@ -1,44 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
|
|
||||||
|
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
@Builder
|
|
||||||
public class SRUCellConfiguration {
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
NDArray<T>* xt = INPUT_VARIABLE(0); // input [batchSize x inSize], batchSize - batch size, inSize - number of features
|
|
||||||
NDArray<T>* ct_1 = INPUT_VARIABLE(1); // previous cell state ct [batchSize x inSize], that is at previous time step t-1
|
|
||||||
NDArray<T>* w = INPUT_VARIABLE(2); // weights [inSize x 3*inSize]
|
|
||||||
NDArray<T>* b = INPUT_VARIABLE(3); // biases [1 x 2*inSize]
|
|
||||||
|
|
||||||
NDArray<T>* ht = OUTPUT_VARIABLE(0); // current cell output [batchSize x inSize], that is at current time step t
|
|
||||||
NDArray<T>* ct = OUTPUT_VARIABLE(1); // current cell state [batchSize x inSize], that is at current time step t
|
|
||||||
|
|
||||||
*/
|
|
||||||
private SDVariable xt,ct_1,w,b,h1,ct;
|
|
||||||
|
|
||||||
|
|
||||||
public SDVariable[] args() {
|
|
||||||
return new SDVariable[] {xt,ct_1,w,b,h1,ct};
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,38 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent.config;
|
|
||||||
|
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
@Builder
|
|
||||||
public class SRUConfiguration {
|
|
||||||
/**
|
|
||||||
* NDArray<T>* input = INPUT_VARIABLE(0); // X, input 3d tensor [bS x K x N], N - number of time steps, bS - batch size, K - number of features
|
|
||||||
NDArray<T>* weights = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x K]
|
|
||||||
NDArray<T>* bias = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 2*K]
|
|
||||||
NDArray<T>* init = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x K] at time t=0
|
|
||||||
|
|
||||||
*/
|
|
||||||
private SDVariable inputs,weights,bias,init;
|
|
||||||
|
|
||||||
public SDVariable[] args() {
|
|
||||||
return new SDVariable[] {inputs,weights,bias,init};
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
import lombok.Getter;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The outputs of a GRU cell ({@link GRUCell}.
|
||||||
|
*/
|
||||||
|
@Getter
|
||||||
|
public class GRUCellOutputs {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reset gate output [batchSize, numUnits].
|
||||||
|
*/
|
||||||
|
private SDVariable r;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Update gate output [batchSize, numUnits].
|
||||||
|
*/
|
||||||
|
private SDVariable u;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cell gate output [batchSize, numUnits].
|
||||||
|
*/
|
||||||
|
private SDVariable c;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Current cell output [batchSize, numUnits].
|
||||||
|
*/
|
||||||
|
private SDVariable h;
|
||||||
|
|
||||||
|
public GRUCellOutputs(SDVariable[] outputs){
|
||||||
|
Preconditions.checkArgument(outputs.length == 4,
|
||||||
|
"Must have 4 GRU cell outputs, got %s", outputs.length);
|
||||||
|
|
||||||
|
r = outputs[0];
|
||||||
|
u = outputs[1];
|
||||||
|
c = outputs[2];
|
||||||
|
h = outputs[3];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get all outputs returned by the cell.
|
||||||
|
*/
|
||||||
|
public List<SDVariable> getAllOutputs(){
|
||||||
|
return Arrays.asList(r, u, c, h);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get h, the output of the cell.
|
||||||
|
*
|
||||||
|
* Has shape [batchSize, numUnits].
|
||||||
|
*/
|
||||||
|
public SDVariable getOutput(){
|
||||||
|
return h;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,88 @@
|
||||||
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
import lombok.Getter;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The outputs of a LSTM cell ({@link LSTMBlockCell}.
|
||||||
|
*/
|
||||||
|
@Getter
|
||||||
|
public class LSTMCellOutputs {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Output - input modulation gate activations [batchSize, numUnits].
|
||||||
|
*/
|
||||||
|
private SDVariable i;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Activations, cell state (pre tanh) [batchSize, numUnits].
|
||||||
|
*/
|
||||||
|
private SDVariable c;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Output - forget gate activations [batchSize, numUnits].
|
||||||
|
*/
|
||||||
|
private SDVariable f;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Output - output gate activations [batchSize, numUnits].
|
||||||
|
*/
|
||||||
|
private SDVariable o;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Output - input gate activations [batchSize, numUnits].
|
||||||
|
*/
|
||||||
|
private SDVariable z;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cell state, post tanh [batchSize, numUnits].
|
||||||
|
*/
|
||||||
|
private SDVariable h;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Current cell output [batchSize, numUnits].
|
||||||
|
*/
|
||||||
|
private SDVariable y;
|
||||||
|
|
||||||
|
public LSTMCellOutputs(SDVariable[] outputs){
|
||||||
|
Preconditions.checkArgument(outputs.length == 7,
|
||||||
|
"Must have 7 LSTM cell outputs, got %s", outputs.length);
|
||||||
|
|
||||||
|
i = outputs[0];
|
||||||
|
c = outputs[1];
|
||||||
|
f = outputs[2];
|
||||||
|
o = outputs[3];
|
||||||
|
z = outputs[4];
|
||||||
|
h = outputs[5];
|
||||||
|
y = outputs[6];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get all outputs returned by the cell.
|
||||||
|
*/
|
||||||
|
public List<SDVariable> getAllOutputs(){
|
||||||
|
return Arrays.asList(i, c, f, o, z, h, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get y, the output of the cell.
|
||||||
|
*
|
||||||
|
* Has shape [batchSize, numUnits].
|
||||||
|
*/
|
||||||
|
public SDVariable getOutput(){
|
||||||
|
return y;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get c, the cell's state.
|
||||||
|
*
|
||||||
|
* Has shape [batchSize, numUnits].
|
||||||
|
*/
|
||||||
|
public SDVariable getState(){
|
||||||
|
return c;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,180 @@
|
||||||
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
import lombok.AccessLevel;
|
||||||
|
import lombok.Getter;
|
||||||
|
import org.nd4j.autodiff.samediff.SDIndex;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The outputs of a LSTM layer ({@link LSTMLayer}.
|
||||||
|
*/
|
||||||
|
@Getter
|
||||||
|
public class LSTMLayerOutputs {
|
||||||
|
|
||||||
|
private RnnDataFormat dataFormat;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Output - input modulation gate activations.
|
||||||
|
* Shape depends on data format (in layer config):<br>
|
||||||
|
* TNS -> [timeSteps, batchSize, numUnits]<br>
|
||||||
|
* NST -> [batchSize, numUnits, timeSteps]<br>
|
||||||
|
* NTS -> [batchSize, timeSteps, numUnits]<br>
|
||||||
|
*/
|
||||||
|
private SDVariable i;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Activations, cell state (pre tanh).
|
||||||
|
* Shape depends on data format (in layer config):<br>
|
||||||
|
* TNS -> [timeSteps, batchSize, numUnits]<br>
|
||||||
|
* NST -> [batchSize, numUnits, timeSteps]<br>
|
||||||
|
* NTS -> [batchSize, timeSteps, numUnits]<br>
|
||||||
|
*/
|
||||||
|
private SDVariable c;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Output - forget gate activations.
|
||||||
|
* Shape depends on data format (in layer config):<br>
|
||||||
|
* TNS -> [timeSteps, batchSize, numUnits]<br>
|
||||||
|
* NST -> [batchSize, numUnits, timeSteps]<br>
|
||||||
|
* NTS -> [batchSize, timeSteps, numUnits]<br>
|
||||||
|
*/
|
||||||
|
private SDVariable f;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Output - output gate activations.
|
||||||
|
* Shape depends on data format (in layer config):<br>
|
||||||
|
* TNS -> [timeSteps, batchSize, numUnits]<br>
|
||||||
|
* NST -> [batchSize, numUnits, timeSteps]<br>
|
||||||
|
* NTS -> [batchSize, timeSteps, numUnits]<br>
|
||||||
|
*/
|
||||||
|
private SDVariable o;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Output - input gate activations.
|
||||||
|
* Shape depends on data format (in layer config):<br>
|
||||||
|
* TNS -> [timeSteps, batchSize, numUnits]<br>
|
||||||
|
* NST -> [batchSize, numUnits, timeSteps]<br>
|
||||||
|
* NTS -> [batchSize, timeSteps, numUnits]<br>
|
||||||
|
*/
|
||||||
|
private SDVariable z;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cell state, post tanh.
|
||||||
|
* Shape depends on data format (in layer config):<br>
|
||||||
|
* TNS -> [timeSteps, batchSize, numUnits]<br>
|
||||||
|
* NST -> [batchSize, numUnits, timeSteps]<br>
|
||||||
|
* NTS -> [batchSize, timeSteps, numUnits]<br>
|
||||||
|
*/
|
||||||
|
private SDVariable h;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Current cell output.
|
||||||
|
* Shape depends on data format (in layer config):<br>
|
||||||
|
* TNS -> [timeSteps, batchSize, numUnits]<br>
|
||||||
|
* NST -> [batchSize, numUnits, timeSteps]<br>
|
||||||
|
* NTS -> [batchSize, timeSteps, numUnits]<br>
|
||||||
|
*/
|
||||||
|
private SDVariable y;
|
||||||
|
|
||||||
|
public LSTMLayerOutputs(SDVariable[] outputs, RnnDataFormat dataFormat){
|
||||||
|
Preconditions.checkArgument(outputs.length == 7,
|
||||||
|
"Must have 7 LSTM layer outputs, got %s", outputs.length);
|
||||||
|
|
||||||
|
i = outputs[0];
|
||||||
|
c = outputs[1];
|
||||||
|
f = outputs[2];
|
||||||
|
o = outputs[3];
|
||||||
|
z = outputs[4];
|
||||||
|
h = outputs[5];
|
||||||
|
y = outputs[6];
|
||||||
|
this.dataFormat = dataFormat;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get all outputs returned by the cell.
|
||||||
|
*/
|
||||||
|
public List<SDVariable> getAllOutputs(){
|
||||||
|
return Arrays.asList(i, c, f, o, z, h, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get y, the output of the cell for all time steps.
|
||||||
|
*
|
||||||
|
* Shape depends on data format (in layer config):<br>
|
||||||
|
* TNS -> [timeSteps, batchSize, numUnits]<br>
|
||||||
|
* NST -> [batchSize, numUnits, timeSteps]<br>
|
||||||
|
* NTS -> [batchSize, timeSteps, numUnits]<br>
|
||||||
|
*/
|
||||||
|
public SDVariable getOutput(){
|
||||||
|
return y;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get c, the cell's state for all time steps.
|
||||||
|
*
|
||||||
|
* Shape depends on data format (in layer config):<br>
|
||||||
|
* TNS -> [timeSteps, batchSize, numUnits]<br>
|
||||||
|
* NST -> [batchSize, numUnits, timeSteps]<br>
|
||||||
|
* NTS -> [batchSize, timeSteps, numUnits]<br>
|
||||||
|
*/
|
||||||
|
public SDVariable getState(){
|
||||||
|
return c;
|
||||||
|
}
|
||||||
|
|
||||||
|
private SDVariable lastOutput = null;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get y, the output of the cell, for the last time step.
|
||||||
|
*
|
||||||
|
* Has shape [batchSize, numUnits].
|
||||||
|
*/
|
||||||
|
public SDVariable getLastOutput(){
|
||||||
|
if(lastOutput != null)
|
||||||
|
return lastOutput;
|
||||||
|
|
||||||
|
switch (dataFormat){
|
||||||
|
case TNS:
|
||||||
|
lastOutput = getOutput().get(SDIndex.point(-1), SDIndex.all(), SDIndex.all());
|
||||||
|
break;
|
||||||
|
case NST:
|
||||||
|
lastOutput = getOutput().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
|
||||||
|
break;
|
||||||
|
case NTS:
|
||||||
|
lastOutput = getOutput().get(SDIndex.all(), SDIndex.point(-1), SDIndex.all());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return lastOutput;
|
||||||
|
}
|
||||||
|
|
||||||
|
private SDVariable lastState = null;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get c, the state of the cell, for the last time step.
|
||||||
|
*
|
||||||
|
* Has shape [batchSize, numUnits].
|
||||||
|
*/
|
||||||
|
public SDVariable getLastState(){
|
||||||
|
if(lastState != null)
|
||||||
|
return lastState;
|
||||||
|
|
||||||
|
switch (dataFormat){
|
||||||
|
case TNS:
|
||||||
|
lastState = getState().get(SDIndex.point(-1), SDIndex.all(), SDIndex.all());
|
||||||
|
break;
|
||||||
|
case NST:
|
||||||
|
lastState = getState().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
|
||||||
|
break;
|
||||||
|
case NTS:
|
||||||
|
lastState = getState().get(SDIndex.all(), SDIndex.point(-1), SDIndex.all());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return lastState;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,60 @@
|
||||||
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
import lombok.Getter;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The outputs of a GRU cell ({@link GRUCell}.
|
||||||
|
*/
|
||||||
|
@Getter
|
||||||
|
public class SRUCellOutputs {
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Current cell output [batchSize, numUnits].
|
||||||
|
*/
|
||||||
|
private SDVariable h;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Current cell state [batchSize, numUnits].
|
||||||
|
*/
|
||||||
|
private SDVariable c;
|
||||||
|
|
||||||
|
public SRUCellOutputs(SDVariable[] outputs){
|
||||||
|
Preconditions.checkArgument(outputs.length == 2,
|
||||||
|
"Must have 2 SRU cell outputs, got %s", outputs.length);
|
||||||
|
|
||||||
|
h = outputs[0];
|
||||||
|
c = outputs[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get all outputs returned by the cell.
|
||||||
|
*/
|
||||||
|
public List<SDVariable> getAllOutputs(){
|
||||||
|
return Arrays.asList(h, c);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get h, the output of the cell.
|
||||||
|
*
|
||||||
|
* Has shape [batchSize, inSize].
|
||||||
|
*/
|
||||||
|
public SDVariable getOutput(){
|
||||||
|
return h;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get c, the state of the cell.
|
||||||
|
*
|
||||||
|
* Has shape [batchSize, inSize].
|
||||||
|
*/
|
||||||
|
public SDVariable getState(){
|
||||||
|
return c;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,92 @@
|
||||||
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
import lombok.AccessLevel;
|
||||||
|
import lombok.Getter;
|
||||||
|
import org.nd4j.autodiff.samediff.SDIndex;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The outputs of a GRU cell ({@link GRUCell}.
|
||||||
|
*/
|
||||||
|
@Getter
|
||||||
|
public class SRULayerOutputs {
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Current cell output [batchSize, inSize, timeSeriesLength].
|
||||||
|
*/
|
||||||
|
private SDVariable h;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Current cell state [batchSize, inSize, timeSeriesLength].
|
||||||
|
*/
|
||||||
|
private SDVariable c;
|
||||||
|
|
||||||
|
public SRULayerOutputs(SDVariable[] outputs){
|
||||||
|
Preconditions.checkArgument(outputs.length == 2,
|
||||||
|
"Must have 2 SRU cell outputs, got %s", outputs.length);
|
||||||
|
|
||||||
|
h = outputs[0];
|
||||||
|
c = outputs[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get all outputs returned by the cell.
|
||||||
|
*/
|
||||||
|
public List<SDVariable> getAllOutputs(){
|
||||||
|
return Arrays.asList(h, c);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get h, the output of the cell.
|
||||||
|
*
|
||||||
|
* Has shape [batchSize, inSize, timeSeriesLength].
|
||||||
|
*/
|
||||||
|
public SDVariable getOutput(){
|
||||||
|
return h;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get c, the state of the cell.
|
||||||
|
*
|
||||||
|
* Has shape [batchSize, inSize, timeSeriesLength].
|
||||||
|
*/
|
||||||
|
public SDVariable getState(){
|
||||||
|
return c;
|
||||||
|
}
|
||||||
|
|
||||||
|
private SDVariable lastOutput = null;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get y, the output of the cell, for the last time step.
|
||||||
|
*
|
||||||
|
* Has shape [batchSize, inSize].
|
||||||
|
*/
|
||||||
|
public SDVariable getLastOutput(){
|
||||||
|
if(lastOutput != null)
|
||||||
|
return lastOutput;
|
||||||
|
|
||||||
|
lastOutput = getOutput().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
|
||||||
|
return lastOutput;
|
||||||
|
}
|
||||||
|
|
||||||
|
private SDVariable lastState = null;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get c, the state of the cell, for the last time step.
|
||||||
|
*
|
||||||
|
* Has shape [batchSize, inSize].
|
||||||
|
*/
|
||||||
|
public SDVariable getLastState(){
|
||||||
|
if(lastState != null)
|
||||||
|
return lastState;
|
||||||
|
|
||||||
|
lastOutput = getOutput().get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
|
||||||
|
return lastState;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,51 @@
|
||||||
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights;
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The weight configuration of a GRU cell. For {@link GRUCell}.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
public class GRUWeights extends RNNWeights {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reset and Update gate weights, with a shape of [inSize + numUnits, 2*numUnits].
|
||||||
|
*
|
||||||
|
* The reset weights are the [:, 0:numUnits] subset and the update weights are the [:, numUnits:2*numUnits] subset.
|
||||||
|
*/
|
||||||
|
@NonNull
|
||||||
|
private SDVariable ruWeight;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cell gate weights, with a shape of [inSize + numUnits, numUnits]
|
||||||
|
*/
|
||||||
|
@NonNull
|
||||||
|
private SDVariable cWeight;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reset and Update gate bias, with a shape of [2*numUnits]. May be null.
|
||||||
|
*
|
||||||
|
* The reset bias is the [0:numUnits] subset and the update bias is the [numUnits:2*numUnits] subset.
|
||||||
|
*/
|
||||||
|
@NonNull
|
||||||
|
private SDVariable ruBias;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cell gate bias, with a shape of [numUnits]. May be null.
|
||||||
|
*/
|
||||||
|
@NonNull
|
||||||
|
private SDVariable cBias;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SDVariable[] args() {
|
||||||
|
return filterNonNull(ruWeight, cWeight, ruBias, cBias);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,57 @@
|
||||||
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights;
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The weight configuration of a LSTM layer. For {@link LSTMLayer} and {@link LSTMBlockCell}.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
public class LSTMWeights extends RNNWeights {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Input to hidden weights and hidden to hidden weights, with a shape of [inSize + numUnits, 4*numUnits].
|
||||||
|
*
|
||||||
|
* Input to hidden and hidden to hidden are concatenated in dimension 0,
|
||||||
|
* so the input to hidden weights are [:inSize, :] and the hidden to hidden weights are [inSize:, :].
|
||||||
|
*/
|
||||||
|
@NonNull
|
||||||
|
private SDVariable weights;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cell peephole (t-1) connections to input modulation gate, with a shape of [numUnits].
|
||||||
|
*/
|
||||||
|
@NonNull
|
||||||
|
private SDVariable inputPeepholeWeights;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cell peephole (t-1) connections to forget gate, with a shape of [numUnits].
|
||||||
|
*/
|
||||||
|
@NonNull
|
||||||
|
private SDVariable forgetPeepholeWeights;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cell peephole (t) connections to output gate, with a shape of [numUnits].
|
||||||
|
*/
|
||||||
|
@NonNull
|
||||||
|
private SDVariable outputPeepholeWeights;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Input to hidden and hidden to hidden biases, with shape [1, 4*numUnits].
|
||||||
|
*/
|
||||||
|
@NonNull
|
||||||
|
private SDVariable bias;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SDVariable[] args() {
|
||||||
|
return filterNonNull(weights, inputPeepholeWeights, forgetPeepholeWeights, outputPeepholeWeights, bias);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,35 @@
|
||||||
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
|
|
||||||
|
public abstract class RNNWeights {
|
||||||
|
public abstract SDVariable[] args();
|
||||||
|
|
||||||
|
protected static SDVariable[] filterNonNull(SDVariable... args){
|
||||||
|
int count = 0;
|
||||||
|
for(SDVariable v : args){
|
||||||
|
if(v != null){
|
||||||
|
count++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SDVariable[] res = new SDVariable[count];
|
||||||
|
|
||||||
|
int i = 0;
|
||||||
|
|
||||||
|
for(SDVariable v : args){
|
||||||
|
if(v != null){
|
||||||
|
res[i] = v;
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
public SDVariable[] argsWithInputs(SDVariable... inputs){
|
||||||
|
return ArrayUtil.combine(inputs, args());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,37 @@
|
||||||
|
package org.nd4j.linalg.api.ops.impl.layers.recurrent.weights;
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The weight configuration of a SRU layer. For {@link SRU} and {@link SRUCell}.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
public class SRUWeights extends RNNWeights {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Weights, with shape [inSize, 3*inSize].
|
||||||
|
*/
|
||||||
|
@NonNull
|
||||||
|
private SDVariable weights;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Biases, with shape [2*inSize].
|
||||||
|
*/
|
||||||
|
@NonNull
|
||||||
|
private SDVariable bias;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SDVariable[] args() {
|
||||||
|
return new SDVariable[]{weights, bias};
|
||||||
|
}
|
||||||
|
}
|
|
@ -27,6 +27,7 @@ import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer;
|
import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer;
|
||||||
import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
|
import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
|
||||||
import org.nd4j.shade.jackson.annotation.JsonInclude;
|
import org.nd4j.shade.jackson.annotation.JsonInclude;
|
||||||
|
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||||
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
|
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
|
||||||
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
|
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
|
||||||
|
|
||||||
|
@ -58,7 +59,7 @@ public class LossL2 implements ILossFunction {
|
||||||
*
|
*
|
||||||
* @param weights Weights array (row vector). May be null.
|
* @param weights Weights array (row vector). May be null.
|
||||||
*/
|
*/
|
||||||
public LossL2(INDArray weights) {
|
public LossL2(@JsonProperty("weights") INDArray weights) {
|
||||||
if (weights != null && !weights.isRowVector()) {
|
if (weights != null && !weights.isRowVector()) {
|
||||||
throw new IllegalArgumentException("Weights array must be a row vector");
|
throw new IllegalArgumentException("Weights array must be a row vector");
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.lossfunctions.impl;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import org.nd4j.linalg.activations.IActivation;
|
import org.nd4j.linalg.activations.IActivation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Mean Squared Error loss function: L = 1/N sum_i (actual_i - predicted)^2
|
* Mean Squared Error loss function: L = 1/N sum_i (actual_i - predicted)^2
|
||||||
|
@ -38,7 +39,7 @@ public class LossMSE extends LossL2 {
|
||||||
*
|
*
|
||||||
* @param weights Weights array (row vector). May be null.
|
* @param weights Weights array (row vector). May be null.
|
||||||
*/
|
*/
|
||||||
public LossMSE(INDArray weights) {
|
public LossMSE(@JsonProperty("weights") INDArray weights) {
|
||||||
super(weights);
|
super(weights);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -664,7 +664,7 @@ public class JCublasNDArray extends BaseNDArray {
|
||||||
//if (1 < 0) {
|
//if (1 < 0) {
|
||||||
Nd4j.getExecutioner().commit();
|
Nd4j.getExecutioner().commit();
|
||||||
|
|
||||||
DataBuffer buffer = Nd4j.createBuffer(this.lengthLong(), false);
|
DataBuffer buffer = Nd4j.createBuffer(this.length(), false);
|
||||||
|
|
||||||
AllocationPoint pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer);
|
AllocationPoint pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer);
|
||||||
AllocationPoint pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data);
|
AllocationPoint pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data);
|
||||||
|
@ -686,10 +686,10 @@ public class JCublasNDArray extends BaseNDArray {
|
||||||
val perfD = PerformanceTracker.getInstance().helperStartTransaction();
|
val perfD = PerformanceTracker.getInstance().helperStartTransaction();
|
||||||
|
|
||||||
if (pointSrc.isActualOnDeviceSide()) {
|
if (pointSrc.isActualOnDeviceSide()) {
|
||||||
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getDevicePointer(), this.lengthLong() * Nd4j.sizeOfDataType(buffer.dataType()), CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0)
|
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getDevicePointer(), this.length() * Nd4j.sizeOfDataType(buffer.dataType()), CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0)
|
||||||
throw new ND4JIllegalStateException("memcpyAsync failed");
|
throw new ND4JIllegalStateException("memcpyAsync failed");
|
||||||
} else {
|
} else {
|
||||||
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getHostPointer(), this.lengthLong() * Nd4j.sizeOfDataType(buffer.dataType()), CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0)
|
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getHostPointer(), this.length() * Nd4j.sizeOfDataType(buffer.dataType()), CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0)
|
||||||
throw new ND4JIllegalStateException("memcpyAsync failed");
|
throw new ND4JIllegalStateException("memcpyAsync failed");
|
||||||
|
|
||||||
direction = MemcpyDirection.HOST_TO_DEVICE;
|
direction = MemcpyDirection.HOST_TO_DEVICE;
|
||||||
|
@ -738,7 +738,7 @@ public class JCublasNDArray extends BaseNDArray {
|
||||||
if (!this.isView()) {
|
if (!this.isView()) {
|
||||||
Nd4j.getExecutioner().commit();
|
Nd4j.getExecutioner().commit();
|
||||||
|
|
||||||
val buffer = Nd4j.createBuffer(this.dataType(), this.lengthLong(), false);
|
val buffer = Nd4j.createBuffer(this.dataType(), this.length(), false);
|
||||||
|
|
||||||
val pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer);
|
val pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer);
|
||||||
val pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data);
|
val pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data);
|
||||||
|
@ -749,10 +749,10 @@ public class JCublasNDArray extends BaseNDArray {
|
||||||
val perfD = PerformanceTracker.getInstance().helperStartTransaction();
|
val perfD = PerformanceTracker.getInstance().helperStartTransaction();
|
||||||
|
|
||||||
if (pointSrc.isActualOnDeviceSide()) {
|
if (pointSrc.isActualOnDeviceSide()) {
|
||||||
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getDevicePointer(), this.lengthLong() * Nd4j.sizeOfDataType(buffer.dataType()), CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0)
|
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getDevicePointer(), this.length() * Nd4j.sizeOfDataType(buffer.dataType()), CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0)
|
||||||
throw new ND4JIllegalStateException("memcpyAsync failed");
|
throw new ND4JIllegalStateException("memcpyAsync failed");
|
||||||
} else {
|
} else {
|
||||||
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getHostPointer(), this.lengthLong() * Nd4j.sizeOfDataType(buffer.dataType()), CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0)
|
if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getHostPointer(), this.length() * Nd4j.sizeOfDataType(buffer.dataType()), CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0)
|
||||||
throw new ND4JIllegalStateException("memcpyAsync failed");
|
throw new ND4JIllegalStateException("memcpyAsync failed");
|
||||||
|
|
||||||
direction = MemcpyDirection.HOST_TO_DEVICE;
|
direction = MemcpyDirection.HOST_TO_DEVICE;
|
||||||
|
|
|
@ -424,7 +424,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
|
|
||||||
val perfD = PerformanceTracker.getInstance().helperStartTransaction();
|
val perfD = PerformanceTracker.getInstance().helperStartTransaction();
|
||||||
|
|
||||||
nativeOps.memcpyAsync(point.getDevicePointer(), point.getHostPointer(), ret.lengthLong() * Nd4j.sizeOfDataType(ret.data().dataType()), CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream());
|
nativeOps.memcpyAsync(point.getDevicePointer(), point.getHostPointer(), ret.length() * Nd4j.sizeOfDataType(ret.data().dataType()), CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream());
|
||||||
context.getSpecialStream().synchronize();
|
context.getSpecialStream().synchronize();
|
||||||
|
|
||||||
if (nativeOps.lastErrorCode() != 0)
|
if (nativeOps.lastErrorCode() != 0)
|
||||||
|
@ -580,7 +580,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
if (true) {
|
if (true) {
|
||||||
Nd4j.getExecutioner().push();
|
Nd4j.getExecutioner().push();
|
||||||
|
|
||||||
long len = target.lengthLong();
|
long len = target.length();
|
||||||
|
|
||||||
AtomicAllocator allocator = AtomicAllocator.getInstance();
|
AtomicAllocator allocator = AtomicAllocator.getInstance();
|
||||||
|
|
||||||
|
@ -598,7 +598,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
if (arrays[i].elementWiseStride() != 1)
|
if (arrays[i].elementWiseStride() != 1)
|
||||||
throw new ND4JIllegalStateException("Native averaging is applicable only to continuous INDArrays");
|
throw new ND4JIllegalStateException("Native averaging is applicable only to continuous INDArrays");
|
||||||
|
|
||||||
if (arrays[i].lengthLong() != len)
|
if (arrays[i].length() != len)
|
||||||
throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
|
throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
|
||||||
|
|
||||||
AllocationPoint point = allocator.getAllocationPoint(arrays[i]);
|
AllocationPoint point = allocator.getAllocationPoint(arrays[i]);
|
||||||
|
@ -621,7 +621,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
|
|
||||||
return target;
|
return target;
|
||||||
} else {
|
} else {
|
||||||
long len = target.lengthLong();
|
long len = target.length();
|
||||||
|
|
||||||
Nd4j.getExecutioner().commit();
|
Nd4j.getExecutioner().commit();
|
||||||
|
|
||||||
|
@ -637,7 +637,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
if (arrays[i].elementWiseStride() != 1)
|
if (arrays[i].elementWiseStride() != 1)
|
||||||
throw new ND4JIllegalStateException("Native averaging is applicable only to continuous INDArrays");
|
throw new ND4JIllegalStateException("Native averaging is applicable only to continuous INDArrays");
|
||||||
|
|
||||||
if (arrays[i].lengthLong() != len)
|
if (arrays[i].length() != len)
|
||||||
throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
|
throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
|
||||||
|
|
||||||
((BaseCudaDataBuffer) arrays[i].data()).lazyAllocateHostPointer();
|
((BaseCudaDataBuffer) arrays[i].data()).lazyAllocateHostPointer();
|
||||||
|
@ -689,7 +689,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
|
|
||||||
Nd4j.getExecutioner().push();
|
Nd4j.getExecutioner().push();
|
||||||
|
|
||||||
long len = target != null ? target.lengthLong() : arrays[0].lengthLong();
|
long len = target != null ? target.length() : arrays[0].length();
|
||||||
|
|
||||||
AtomicAllocator allocator = AtomicAllocator.getInstance();
|
AtomicAllocator allocator = AtomicAllocator.getInstance();
|
||||||
|
|
||||||
|
@ -707,7 +707,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
if (arrays[i].elementWiseStride() != 1)
|
if (arrays[i].elementWiseStride() != 1)
|
||||||
throw new ND4JIllegalStateException("Native averaging is applicable only to continuous INDArrays");
|
throw new ND4JIllegalStateException("Native averaging is applicable only to continuous INDArrays");
|
||||||
|
|
||||||
if (arrays[i].lengthLong() != len)
|
if (arrays[i].length() != len)
|
||||||
throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
|
throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
|
||||||
|
|
||||||
AllocationPoint point = allocator.getAllocationPoint(arrays[i]);
|
AllocationPoint point = allocator.getAllocationPoint(arrays[i]);
|
||||||
|
@ -744,7 +744,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
/**
|
/**
|
||||||
* We expect all operations are complete at this point
|
* We expect all operations are complete at this point
|
||||||
*/
|
*/
|
||||||
long len = target == null ? arrays[0].lengthLong() : target.lengthLong();
|
long len = target == null ? arrays[0].length() : target.length();
|
||||||
|
|
||||||
val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext();
|
val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext();
|
||||||
|
|
||||||
|
@ -758,7 +758,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
if (arrays[i].elementWiseStride() != 1)
|
if (arrays[i].elementWiseStride() != 1)
|
||||||
throw new ND4JIllegalStateException("Native averaging is applicable only to continuous INDArrays");
|
throw new ND4JIllegalStateException("Native averaging is applicable only to continuous INDArrays");
|
||||||
|
|
||||||
if (arrays[i].lengthLong() != len)
|
if (arrays[i].length() != len)
|
||||||
throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
|
throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
|
||||||
|
|
||||||
((BaseCudaDataBuffer) arrays[i].data()).lazyAllocateHostPointer();
|
((BaseCudaDataBuffer) arrays[i].data()).lazyAllocateHostPointer();
|
||||||
|
@ -1303,7 +1303,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
int numTads = (int)(tensor.lengthLong() / tadLength);
|
int numTads = (int)(tensor.length() / tadLength);
|
||||||
INDArray[] result = new INDArray[numTads];
|
INDArray[] result = new INDArray[numTads];
|
||||||
|
|
||||||
long[] xPointers = new long[numTads];
|
long[] xPointers = new long[numTads];
|
||||||
|
@ -1378,7 +1378,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
new CudaPointer(0));
|
new CudaPointer(0));
|
||||||
|
|
||||||
// we're sending > 10m elements to radixSort
|
// we're sending > 10m elements to radixSort
|
||||||
boolean isRadix = !x.isView() && (x.lengthLong() > 1024 * 1024 * 10);
|
boolean isRadix = !x.isView() && (x.length() > 1024 * 1024 * 10);
|
||||||
INDArray tmpX = x;
|
INDArray tmpX = x;
|
||||||
|
|
||||||
// we need to guarantee all threads are finished here
|
// we need to guarantee all threads are finished here
|
||||||
|
|
|
@ -293,9 +293,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
Pointer yDevTadShapeInfo = null;
|
Pointer yDevTadShapeInfo = null;
|
||||||
|
|
||||||
if (op.y() != null) {
|
if (op.y() != null) {
|
||||||
if (dimension.length == 0 || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE )|| op.x().tensorAlongDimension(0, dimension).lengthLong() != op.y().lengthLong()) {
|
if (dimension.length == 0 || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE )|| op.x().tensorAlongDimension(0, dimension).length() != op.y().length()) {
|
||||||
if (!op.isComplexAccumulation() && op.x().lengthLong() != op.y().lengthLong())
|
if (!op.isComplexAccumulation() && op.x().length() != op.y().length())
|
||||||
throw new ND4JIllegalStateException("Op.X [" + op.x().lengthLong() + "] and Op.Y [" + op.y().lengthLong() + "] lengths should match");
|
throw new ND4JIllegalStateException("Op.X [" + op.x().length() + "] and Op.Y [" + op.y().length() + "] lengths should match");
|
||||||
|
|
||||||
if (!op.z().isScalar()) {
|
if (!op.z().isScalar()) {
|
||||||
Pair<DataBuffer, DataBuffer> yTadBuffers = tadManager.getTADOnlyShapeInfo(op.y(), dimension);
|
Pair<DataBuffer, DataBuffer> yTadBuffers = tadManager.getTADOnlyShapeInfo(op.y(), dimension);
|
||||||
|
@ -536,7 +536,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
} else {
|
} else {
|
||||||
if (op.y() != null) {
|
if (op.y() != null) {
|
||||||
//2 options here: either pairwise, equal sizes - OR every X TAD vs. entirety of Y
|
//2 options here: either pairwise, equal sizes - OR every X TAD vs. entirety of Y
|
||||||
if (op.x().lengthLong() == op.y().lengthLong()) {
|
if (op.x().length() == op.y().length()) {
|
||||||
//Pairwise
|
//Pairwise
|
||||||
if (!wholeDims && op.x().tensorsAlongDimension(dimension) != op.y().tensorsAlongDimension(dimension)) {
|
if (!wholeDims && op.x().tensorsAlongDimension(dimension) != op.y().tensorsAlongDimension(dimension)) {
|
||||||
throw new ND4JIllegalStateException("Number of TADs along dimension don't match: (x shape = " +
|
throw new ND4JIllegalStateException("Number of TADs along dimension don't match: (x shape = " +
|
||||||
|
@ -548,11 +548,11 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
throw new ND4JIllegalStateException("TAD vs TAD comparison requires dimension (or other comparison mode was supposed to be used?)");
|
throw new ND4JIllegalStateException("TAD vs TAD comparison requires dimension (or other comparison mode was supposed to be used?)");
|
||||||
|
|
||||||
//Every X TAD vs. entirety of Y
|
//Every X TAD vs. entirety of Y
|
||||||
val xTADSize = op.x().lengthLong() / op.x().tensorsAlongDimension(dimension);
|
val xTADSize = op.x().length() / op.x().tensorsAlongDimension(dimension);
|
||||||
|
|
||||||
if (xTADSize != op.y().length()) {
|
if (xTADSize != op.y().length()) {
|
||||||
throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution:" +
|
throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution:" +
|
||||||
" (x TAD size = " + xTADSize + ", y size = " + op.y().lengthLong());
|
" (x TAD size = " + xTADSize + ", y size = " + op.y().length());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -976,7 +976,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
if (op.y() != null) {
|
if (op.y() != null) {
|
||||||
//2 options here: either pairwise, equal sizes - OR every X TAD vs. entirety of Y
|
//2 options here: either pairwise, equal sizes - OR every X TAD vs. entirety of Y
|
||||||
if (op.x().lengthLong() == op.y().lengthLong()) {
|
if (op.x().length() == op.y().length()) {
|
||||||
//Pairwise
|
//Pairwise
|
||||||
if (op.x().tensorsAlongDimension(dimension) != op.y().tensorsAlongDimension(dimension)) {
|
if (op.x().tensorsAlongDimension(dimension) != op.y().tensorsAlongDimension(dimension)) {
|
||||||
throw new ND4JIllegalStateException("Number of TADs along dimension don't match: (x shape = " +
|
throw new ND4JIllegalStateException("Number of TADs along dimension don't match: (x shape = " +
|
||||||
|
@ -985,11 +985,11 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
//Every X TAD vs. entirety of Y
|
//Every X TAD vs. entirety of Y
|
||||||
val xTADSize = op.x().lengthLong() / op.x().tensorsAlongDimension(dimension);
|
val xTADSize = op.x().length() / op.x().tensorsAlongDimension(dimension);
|
||||||
|
|
||||||
if (xTADSize != op.y().length()) {
|
if (xTADSize != op.y().length()) {
|
||||||
throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution:" +
|
throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution:" +
|
||||||
" (x TAD size = " + xTADSize + ", y size = " + op.y().lengthLong());
|
" (x TAD size = " + xTADSize + ", y size = " + op.y().length());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2031,8 +2031,8 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
long compressedLength = buffer.getInt(0);
|
long compressedLength = buffer.getInt(0);
|
||||||
long originalLength = buffer.getInt(1);
|
long originalLength = buffer.getInt(1);
|
||||||
|
|
||||||
if (target.lengthLong() != originalLength)
|
if (target.length() != originalLength)
|
||||||
throw new ND4JIllegalStateException("originalLength ["+ originalLength+"] stored in encoded array doesn't match target length ["+ target.lengthLong()+"]");
|
throw new ND4JIllegalStateException("originalLength ["+ originalLength+"] stored in encoded array doesn't match target length ["+ target.length()+"]");
|
||||||
|
|
||||||
DataBuffer result = target.data();
|
DataBuffer result = target.data();
|
||||||
|
|
||||||
|
@ -2056,7 +2056,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long bitmapEncode(INDArray indArray, INDArray target, double threshold) {
|
public long bitmapEncode(INDArray indArray, INDArray target, double threshold) {
|
||||||
long length = indArray.lengthLong();
|
long length = indArray.length();
|
||||||
long tLen = target.data().length();
|
long tLen = target.data().length();
|
||||||
|
|
||||||
if (tLen != (length / 16 + 5))
|
if (tLen != (length / 16 + 5))
|
||||||
|
@ -2117,7 +2117,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
context.getBufferScalar(),
|
context.getBufferScalar(),
|
||||||
context.getBufferReduction());
|
context.getBufferReduction());
|
||||||
|
|
||||||
nativeOps.decodeBitmap(extras, AtomicAllocator.getInstance().getPointer(encoded.data(), context), target.lengthLong(), AtomicAllocator.getInstance().getPointer(target, context), (LongPointer) AtomicAllocator.getInstance().getHostPointer(target.shapeInfoDataBuffer()));
|
nativeOps.decodeBitmap(extras, AtomicAllocator.getInstance().getPointer(encoded.data(), context), target.length(), AtomicAllocator.getInstance().getPointer(target, context), (LongPointer) AtomicAllocator.getInstance().getHostPointer(target.shapeInfoDataBuffer()));
|
||||||
|
|
||||||
if (nativeOps.lastErrorCode() != 0)
|
if (nativeOps.lastErrorCode() != 0)
|
||||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||||
|
|
|
@ -655,7 +655,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio
|
||||||
op.setZ(ret);
|
op.setZ(ret);
|
||||||
} else {
|
} else {
|
||||||
// compare length
|
// compare length
|
||||||
if (op.z().lengthLong() != ArrayUtil.prodLong(retShape))
|
if (op.z().length() != ArrayUtil.prodLong(retShape))
|
||||||
throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(op.z().shape()) + "] doesn't match expected [" + Arrays.toString(retShape) + "]");
|
throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(op.z().shape()) + "] doesn't match expected [" + Arrays.toString(retShape) + "]");
|
||||||
|
|
||||||
ret = op.z();
|
ret = op.z();
|
||||||
|
|
|
@ -514,7 +514,7 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int numTads = (int)(tensor.lengthLong() / tadLength);
|
int numTads = (int)(tensor.length() / tadLength);
|
||||||
INDArray[] result = new INDArray[numTads];
|
INDArray[] result = new INDArray[numTads];
|
||||||
|
|
||||||
PointerPointer targets = new PointerPointer(numTads);
|
PointerPointer targets = new PointerPointer(numTads);
|
||||||
|
@ -693,7 +693,7 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
if (arrays.length == 1)
|
if (arrays.length == 1)
|
||||||
return target.addi(arrays[0]);
|
return target.addi(arrays[0]);
|
||||||
|
|
||||||
long len = target.lengthLong();
|
long len = target.length();
|
||||||
|
|
||||||
PointerPointer dataPointers = new PointerPointer(arrays.length);
|
PointerPointer dataPointers = new PointerPointer(arrays.length);
|
||||||
|
|
||||||
|
@ -703,7 +703,7 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
if (arrays[i].elementWiseStride() != 1)
|
if (arrays[i].elementWiseStride() != 1)
|
||||||
throw new ND4JIllegalStateException("Native accumulation is applicable only to continuous INDArrays");
|
throw new ND4JIllegalStateException("Native accumulation is applicable only to continuous INDArrays");
|
||||||
|
|
||||||
if (arrays[i].lengthLong() != len)
|
if (arrays[i].length() != len)
|
||||||
throw new ND4JIllegalStateException("All arrays should have equal length for accumulation");
|
throw new ND4JIllegalStateException("All arrays should have equal length for accumulation");
|
||||||
|
|
||||||
dataPointers.put(i, arrays[i].data().addressPointer());
|
dataPointers.put(i, arrays[i].data().addressPointer());
|
||||||
|
@ -744,7 +744,7 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
return target.assign(arrays[0]);
|
return target.assign(arrays[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
long len = target != null ? target.lengthLong() : arrays[0].length();
|
long len = target != null ? target.length() : arrays[0].length();
|
||||||
|
|
||||||
PointerPointer dataPointers = new PointerPointer(arrays.length);
|
PointerPointer dataPointers = new PointerPointer(arrays.length);
|
||||||
val firstType = arrays[0].dataType();
|
val firstType = arrays[0].dataType();
|
||||||
|
@ -757,7 +757,7 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
if (arrays[i].elementWiseStride() != 1)
|
if (arrays[i].elementWiseStride() != 1)
|
||||||
throw new ND4JIllegalStateException("Native averaging is applicable only to continuous INDArrays");
|
throw new ND4JIllegalStateException("Native averaging is applicable only to continuous INDArrays");
|
||||||
|
|
||||||
if (arrays[i].lengthLong() != len)
|
if (arrays[i].length() != len)
|
||||||
throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
|
throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
|
||||||
|
|
||||||
dataPointers.put(i, arrays[i].data().addressPointer());
|
dataPointers.put(i, arrays[i].data().addressPointer());
|
||||||
|
|
|
@ -303,11 +303,11 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
//Every X TAD vs. entirety of Y
|
//Every X TAD vs. entirety of Y
|
||||||
val xTADSize = op.x().lengthLong() / op.x().tensorsAlongDimension(dimension);
|
val xTADSize = op.x().length() / op.x().tensorsAlongDimension(dimension);
|
||||||
|
|
||||||
if (xTADSize != op.y().length()) {
|
if (xTADSize != op.y().length()) {
|
||||||
throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution:" +
|
throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution:" +
|
||||||
" (x TAD size = " + xTADSize + ", y size = " + op.y().lengthLong());
|
" (x TAD size = " + xTADSize + ", y size = " + op.y().length());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -329,7 +329,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
long xT = op.x().tensorsAlongDimension(dimension);
|
long xT = op.x().tensorsAlongDimension(dimension);
|
||||||
long yT = op.y().tensorsAlongDimension(dimension);
|
long yT = op.y().tensorsAlongDimension(dimension);
|
||||||
|
|
||||||
if (op.z().lengthLong() != xT * yT)
|
if (op.z().length() != xT * yT)
|
||||||
throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(op.z().shape()) + "] doesn't match expected [" + (xT * yT) + "]");
|
throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(op.z().shape()) + "] doesn't match expected [" + (xT * yT) + "]");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -358,7 +358,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
// we're going to check, if that's TAD vs TAD comparison or TAD vs full array. if later - we're going slightly different route
|
// we're going to check, if that's TAD vs TAD comparison or TAD vs full array. if later - we're going slightly different route
|
||||||
boolean tvf = false;
|
boolean tvf = false;
|
||||||
if (op.y() != null) {
|
if (op.y() != null) {
|
||||||
if (op.x().tensorAlongDimension(0, dimension).lengthLong() == op.y().lengthLong()) {
|
if (op.x().tensorAlongDimension(0, dimension).length() == op.y().length()) {
|
||||||
tvf = true;
|
tvf = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -366,10 +366,10 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
if (op.isComplexAccumulation()) {
|
if (op.isComplexAccumulation()) {
|
||||||
yTadBuffers = tadManager.getTADOnlyShapeInfo(op.y(), dimension);
|
yTadBuffers = tadManager.getTADOnlyShapeInfo(op.y(), dimension);
|
||||||
|
|
||||||
if (op.x().tensorAlongDimension(0, dimension).lengthLong() != op.y().tensorAlongDimension(0, dimension).lengthLong())
|
if (op.x().tensorAlongDimension(0, dimension).length() != op.y().tensorAlongDimension(0, dimension).length())
|
||||||
throw new ND4JIllegalStateException("Impossible to issue AllDistances operation: TAD lengths mismatch along given dimension: " +
|
throw new ND4JIllegalStateException("Impossible to issue AllDistances operation: TAD lengths mismatch along given dimension: " +
|
||||||
"x TAD length = " + op.x().tensorAlongDimension(0, dimension).lengthLong() + ", y TAD length " +
|
"x TAD length = " + op.x().tensorAlongDimension(0, dimension).length() + ", y TAD length " +
|
||||||
op.y().tensorAlongDimension(0, dimension).lengthLong());
|
op.y().tensorAlongDimension(0, dimension).length());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -659,7 +659,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
//validateDataType(Nd4j.dataType(), op);
|
//validateDataType(Nd4j.dataType(), op);
|
||||||
|
|
||||||
if (op.x().lengthLong() != op.z().lengthLong())
|
if (op.x().length() != op.z().length())
|
||||||
throw new ND4JIllegalStateException("op.X length should be equal to op.Z length: " +
|
throw new ND4JIllegalStateException("op.X length should be equal to op.Z length: " +
|
||||||
"x.length()=" + op.x().length() + ", z.length()=" + op.z().length() + " - x shape info = ["
|
"x.length()=" + op.x().length() + ", z.length()=" + op.z().length() + " - x shape info = ["
|
||||||
+ Arrays.toString(op.x().shapeInfoDataBuffer().asInt()) + "], z shape info = ["
|
+ Arrays.toString(op.x().shapeInfoDataBuffer().asInt()) + "], z shape info = ["
|
||||||
|
@ -1449,8 +1449,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
long originalLength = buffer.getInt(1);
|
long originalLength = buffer.getInt(1);
|
||||||
float threshold = buffer.getInt(2);
|
float threshold = buffer.getInt(2);
|
||||||
|
|
||||||
if (target.lengthLong() != originalLength)
|
if (target.length() != originalLength)
|
||||||
throw new ND4JIllegalStateException("originalLength ["+ originalLength+"] stored in encoded array doesn't match target length ["+ target.lengthLong()+"]");
|
throw new ND4JIllegalStateException("originalLength ["+ originalLength+"] stored in encoded array doesn't match target length ["+ target.length()+"]");
|
||||||
|
|
||||||
DataTypeEx typeDst = AbstractCompressor.getBufferTypeEx(target.data());
|
DataTypeEx typeDst = AbstractCompressor.getBufferTypeEx(target.data());
|
||||||
|
|
||||||
|
@ -1465,7 +1465,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long bitmapEncode(INDArray indArray, INDArray target, double threshold) {
|
public long bitmapEncode(INDArray indArray, INDArray target, double threshold) {
|
||||||
long length = indArray.lengthLong();
|
long length = indArray.length();
|
||||||
long tLen = target.data().length();
|
long tLen = target.data().length();
|
||||||
|
|
||||||
if (tLen != (length / 16 + 5))
|
if (tLen != (length / 16 + 5))
|
||||||
|
|
|
@ -16,14 +16,19 @@
|
||||||
|
|
||||||
package org.nd4j.autodiff.opvalidation;
|
package org.nd4j.autodiff.opvalidation;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.autodiff.samediff.SDIndex;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.GRUCellConfiguration;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.GRUCellConfiguration;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMBlockCellConfiguration;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
|
@ -59,23 +64,18 @@ public class RnnOpValidation extends BaseOpValidation {
|
||||||
SDVariable b = sd.constant(Nd4j.rand(DataType.FLOAT, 4*nOut));
|
SDVariable b = sd.constant(Nd4j.rand(DataType.FLOAT, 4*nOut));
|
||||||
|
|
||||||
double fb = 1.0;
|
double fb = 1.0;
|
||||||
LSTMBlockCellConfiguration conf = LSTMBlockCellConfiguration.builder()
|
LSTMConfiguration conf = LSTMConfiguration.builder()
|
||||||
.xt(x)
|
|
||||||
.cLast(cLast)
|
|
||||||
.yLast(yLast)
|
|
||||||
.W(W)
|
|
||||||
.Wci(Wci)
|
|
||||||
.Wcf(Wcf)
|
|
||||||
.Wco(Wco)
|
|
||||||
.b(b)
|
|
||||||
.peepHole(true)
|
.peepHole(true)
|
||||||
.forgetBias(fb)
|
.forgetBias(fb)
|
||||||
.clippingCellValue(0.0)
|
.clippingCellValue(0.0)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
List<SDVariable> v = sd.rnn().lstmBlockCell("lstm", conf); //Output order: i, c, f, o, z, h, y
|
LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b)
|
||||||
|
.inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build();
|
||||||
|
|
||||||
|
LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y
|
||||||
List<String> toExec = new ArrayList<>();
|
List<String> toExec = new ArrayList<>();
|
||||||
for(SDVariable sdv : v){
|
for(SDVariable sdv : v.getAllOutputs()){
|
||||||
toExec.add(sdv.getVarName());
|
toExec.add(sdv.getVarName());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -167,23 +167,18 @@ public class RnnOpValidation extends BaseOpValidation {
|
||||||
SDVariable b = sd.constant(Nd4j.zeros(DataType.FLOAT, 8));
|
SDVariable b = sd.constant(Nd4j.zeros(DataType.FLOAT, 8));
|
||||||
|
|
||||||
double fb = 1.0;
|
double fb = 1.0;
|
||||||
LSTMBlockCellConfiguration conf = LSTMBlockCellConfiguration.builder()
|
LSTMConfiguration conf = LSTMConfiguration.builder()
|
||||||
.xt(x)
|
|
||||||
.cLast(cLast)
|
|
||||||
.yLast(yLast)
|
|
||||||
.W(W)
|
|
||||||
.Wci(Wci)
|
|
||||||
.Wcf(Wcf)
|
|
||||||
.Wco(Wco)
|
|
||||||
.b(b)
|
|
||||||
.peepHole(false)
|
.peepHole(false)
|
||||||
.forgetBias(fb)
|
.forgetBias(fb)
|
||||||
.clippingCellValue(0.0)
|
.clippingCellValue(0.0)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
List<SDVariable> v = sd.rnn().lstmBlockCell("lstm", conf); //Output order: i, c, f, o, z, h, y
|
LSTMWeights weights = LSTMWeights.builder().weights(W).bias(b)
|
||||||
|
.inputPeepholeWeights(Wci).forgetPeepholeWeights(Wcf).outputPeepholeWeights(Wco).build();
|
||||||
|
|
||||||
|
LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y
|
||||||
List<String> toExec = new ArrayList<>();
|
List<String> toExec = new ArrayList<>();
|
||||||
for(SDVariable sdv : v){
|
for(SDVariable sdv : v.getAllOutputs()){
|
||||||
toExec.add(sdv.getVarName());
|
toExec.add(sdv.getVarName());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -228,16 +223,14 @@ public class RnnOpValidation extends BaseOpValidation {
|
||||||
SDVariable bc = sd.constant(Nd4j.rand(DataType.FLOAT, nOut));
|
SDVariable bc = sd.constant(Nd4j.rand(DataType.FLOAT, nOut));
|
||||||
|
|
||||||
double fb = 1.0;
|
double fb = 1.0;
|
||||||
GRUCellConfiguration conf = GRUCellConfiguration.builder()
|
GRUWeights weights = GRUWeights.builder()
|
||||||
.xt(x)
|
.ruWeight(Wru)
|
||||||
.hLast(hLast)
|
.cWeight(Wc)
|
||||||
.Wru(Wru)
|
.ruBias(bru)
|
||||||
.Wc(Wc)
|
.cBias(bc)
|
||||||
.bru(bru)
|
|
||||||
.bc(bc)
|
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
List<SDVariable> v = sd.rnn().gru("gru", conf);
|
List<SDVariable> v = sd.rnn().gru("gru", x, hLast, weights).getAllOutputs();
|
||||||
List<String> toExec = new ArrayList<>();
|
List<String> toExec = new ArrayList<>();
|
||||||
for(SDVariable sdv : v){
|
for(SDVariable sdv : v){
|
||||||
toExec.add(sdv.getVarName());
|
toExec.add(sdv.getVarName());
|
||||||
|
|
|
@ -5155,7 +5155,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
|
|
||||||
INDArray res = x.entropy(1);
|
INDArray res = x.entropy(1);
|
||||||
|
|
||||||
assertEquals(10, res.lengthLong());
|
assertEquals(10, res.length());
|
||||||
|
|
||||||
for (int t = 0; t < x.rows(); t++) {
|
for (int t = 0; t < x.rows(); t++) {
|
||||||
double exp = MathUtils.entropy(x.getRow(t).dup().data().asDouble());
|
double exp = MathUtils.entropy(x.getRow(t).dup().data().asDouble());
|
||||||
|
|
|
@ -415,7 +415,7 @@ public class ShufflesTests extends BaseNd4jTest {
|
||||||
|
|
||||||
for (int x = 0; x < newData.rows(); x++) {
|
for (int x = 0; x < newData.rows(); x++) {
|
||||||
INDArray row = newData.getRow(x);
|
INDArray row = newData.getRow(x);
|
||||||
for (int y = 0; y < row.lengthLong(); y++) {
|
for (int y = 0; y < row.length(); y++) {
|
||||||
if (Math.abs(row.getFloat(y) - newMap[x]) > Nd4j.EPS_THRESHOLD) {
|
if (Math.abs(row.getFloat(y) - newMap[x]) > Nd4j.EPS_THRESHOLD) {
|
||||||
System.out.print("Different data in a row");
|
System.out.print("Different data in a row");
|
||||||
return false;
|
return false;
|
||||||
|
@ -442,7 +442,7 @@ public class ShufflesTests extends BaseNd4jTest {
|
||||||
for (int x = 0; x < newData.rows(); x++) {
|
for (int x = 0; x < newData.rows(); x++) {
|
||||||
INDArray column = newData.getColumn(x);
|
INDArray column = newData.getColumn(x);
|
||||||
double val = column.getDouble(0);
|
double val = column.getDouble(0);
|
||||||
for (int y = 0; y < column.lengthLong(); y++) {
|
for (int y = 0; y < column.length(); y++) {
|
||||||
if (Math.abs(column.getFloat(y) - val) > Nd4j.EPS_THRESHOLD) {
|
if (Math.abs(column.getFloat(y) - val) > Nd4j.EPS_THRESHOLD) {
|
||||||
System.out.print("Different data in a column: " + column.getFloat(y));
|
System.out.print("Different data in a column: " + column.getFloat(y));
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -92,7 +92,7 @@ public class LapackTest extends BaseNd4jTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCholeskyU() {
|
public void testCholeskyU() {
|
||||||
INDArray A = Nd4j.create(new double[] {2, -1, 2, -1, 2, -1, 2, -1, 2,});
|
INDArray A = Nd4j.create(new double[] {3, -1, 2, -1, 3, -1, 2, -1, 3,});
|
||||||
A = A.reshape('f', 3, 3);
|
A = A.reshape('f', 3, 3);
|
||||||
INDArray O = Nd4j.create(A.dataType(), A.shape());
|
INDArray O = Nd4j.create(A.dataType(), A.shape());
|
||||||
Nd4j.copy(A, O);
|
Nd4j.copy(A, O);
|
||||||
|
|
|
@ -168,7 +168,7 @@ public class CompressionTests extends BaseNd4jTest {
|
||||||
INDArray decompressed = Nd4j.create(1, initial.length());
|
INDArray decompressed = Nd4j.create(1, initial.length());
|
||||||
Nd4j.getExecutioner().thresholdDecode(compressed, decompressed);
|
Nd4j.getExecutioner().thresholdDecode(compressed, decompressed);
|
||||||
|
|
||||||
log.info("Decompressed length: {}", decompressed.lengthLong());
|
log.info("Decompressed length: {}", decompressed.length());
|
||||||
|
|
||||||
assertEquals(exp_d, decompressed);
|
assertEquals(exp_d, decompressed);
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@ import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Simple pair implementation
|
* Simple pair implementation
|
||||||
|
@ -86,4 +87,10 @@ public class Pair<K, V> implements Serializable {
|
||||||
public static <T, E> Pair<T,E> pairOf(T key, E value) {
|
public static <T, E> Pair<T,E> pairOf(T key, E value) {
|
||||||
return new Pair<T, E>(key, value);
|
return new Pair<T, E>(key, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static <T> Pair<T, T> fromArray(T[] arr){
|
||||||
|
Preconditions.checkArgument(arr.length == 2,
|
||||||
|
"Can only create a pair from an array with two values, got %s", arr.length);
|
||||||
|
return new Pair<>(arr[0], arr[1]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue