From 7d857759341a1fb5f511b394800c93c0529fb14f Mon Sep 17 00:00:00 2001 From: Alex Black Date: Thu, 5 Sep 2019 11:51:11 +1000 Subject: [PATCH] Arbiter generic JSON ser/de fixes (#237) * Arbiter generic JSON ser/de fixes Signed-off-by: AlexDBlack * Javadoc fix Signed-off-by: AlexDBlack --- .../optimize/parameter/FixedValue.java | 13 +++-- .../serde/jackson/FixedValueDeserializer.java | 52 +++++++++++++++++++ .../serde/jackson/FixedValueSerializer.java | 51 ++++++++++++++++++ .../serde/jackson/GenericDeserializer.java | 46 ---------------- .../serde/jackson/GenericSerializer.java | 38 -------------- .../optimize/serde/jackson/JsonMapper.java | 4 +- .../arbiter/layers/BaseOutputLayerSpace.java | 2 +- .../ui/data/GlobalConfigPersistable.java | 11 +++- .../ui/listener/ArbiterStatusListener.java | 9 +++- .../linalg/lossfunctions/impl/LossL2.java | 3 +- .../linalg/lossfunctions/impl/LossMSE.java | 3 +- 11 files changed, 133 insertions(+), 99 deletions(-) create mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueDeserializer.java create mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueSerializer.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/GenericDeserializer.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/GenericSerializer.java diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/FixedValue.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/FixedValue.java index 0be9613de..6482003e5 100644 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/FixedValue.java +++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/FixedValue.java @@ -17,13 +17,14 @@ package org.deeplearning4j.arbiter.optimize.parameter; import lombok.EqualsAndHashCode; +import lombok.Getter; import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.serde.jackson.GenericDeserializer; -import org.deeplearning4j.arbiter.optimize.serde.jackson.GenericSerializer; +import org.deeplearning4j.arbiter.optimize.serde.jackson.FixedValueDeserializer; +import org.deeplearning4j.arbiter.optimize.serde.jackson.FixedValueSerializer; import org.deeplearning4j.arbiter.util.ObjectUtils; import org.nd4j.shade.jackson.annotation.JsonCreator; -import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties; import org.nd4j.shade.jackson.annotation.JsonProperty; +import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; import org.nd4j.shade.jackson.databind.annotation.JsonSerialize; @@ -37,9 +38,11 @@ import java.util.Map; * @param Type of (fixed) value */ @EqualsAndHashCode +@JsonSerialize(using = FixedValueSerializer.class) +@JsonDeserialize(using = FixedValueDeserializer.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public class FixedValue implements ParameterSpace { - @JsonSerialize(using = GenericSerializer.class) - @JsonDeserialize(using = GenericDeserializer.class) + @Getter private Object value; private int index; diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueDeserializer.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueDeserializer.java new file mode 100644 index 000000000..24b76fd42 --- /dev/null +++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueDeserializer.java @@ -0,0 +1,52 @@ +package org.deeplearning4j.arbiter.optimize.serde.jackson; + +import org.apache.commons.codec.binary.Base64; +import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; +import org.nd4j.shade.jackson.core.JsonParser; +import org.nd4j.shade.jackson.core.JsonProcessingException; +import org.nd4j.shade.jackson.databind.DeserializationContext; +import org.nd4j.shade.jackson.databind.JsonDeserializer; +import org.nd4j.shade.jackson.databind.JsonNode; +import org.nd4j.shade.jackson.databind.ObjectMapper; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.ObjectInputStream; + +/** + * A custom deserializer to be used in conjunction with {@link FixedValueSerializer} + * @author Alex Black + */ +public class FixedValueDeserializer extends JsonDeserializer { + @Override + public FixedValue deserialize(JsonParser p, DeserializationContext deserializationContext) throws IOException, JsonProcessingException { + JsonNode node = p.getCodec().readTree(p); + String className = node.get("@valueclass").asText(); + Class c; + try { + c = Class.forName(className); + } catch (Exception e) { + throw new RuntimeException(e); + } + + if(node.has("value")){ + //Number, String, Enum + JsonNode valueNode = node.get("value"); + Object o = new ObjectMapper().treeToValue(valueNode, c); + return new FixedValue<>(o); + } else { + //Everything else + JsonNode valueNode = node.get("data"); + String data = valueNode.asText(); + + byte[] b = new Base64().decode(data); + ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(b)); + try { + Object o = ois.readObject(); + return new FixedValue<>(o); + } catch (Throwable t) { + throw new RuntimeException(t); + } + } + } +} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueSerializer.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueSerializer.java new file mode 100644 index 000000000..349177595 --- /dev/null +++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueSerializer.java @@ -0,0 +1,51 @@ +package org.deeplearning4j.arbiter.optimize.serde.jackson; + +import org.apache.commons.net.util.Base64; +import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; +import org.nd4j.shade.jackson.core.JsonGenerator; +import org.nd4j.shade.jackson.core.type.WritableTypeId; +import org.nd4j.shade.jackson.databind.JsonSerializer; +import org.nd4j.shade.jackson.databind.SerializerProvider; +import org.nd4j.shade.jackson.databind.jsontype.TypeSerializer; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectOutputStream; + +import static org.nd4j.shade.jackson.core.JsonToken.START_OBJECT; + +/** + * A custom serializer to handle arbitrary object types + * Uses standard JSON where safe (number, string, enumerations) or Java object serialization (bytes -> base64) + * The latter is not an ideal approach, but Jackson doesn't support serialization/deserialization of arbitrary + * objects very well + * + * @author Alex Black + */ +public class FixedValueSerializer extends JsonSerializer { + @Override + public void serialize(FixedValue fixedValue, JsonGenerator j, SerializerProvider serializerProvider) throws IOException { + Object o = fixedValue.getValue(); + + j.writeStringField("@valueclass", o.getClass().getName()); + if(o instanceof Number || o instanceof String || o instanceof Enum){ + j.writeObjectField("value", o); + } else { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ObjectOutputStream oos = new ObjectOutputStream(baos); + oos.writeObject(o); + baos.close(); + byte[] b = baos.toByteArray(); + String base64 = new Base64().encodeToString(b); + j.writeStringField("data", base64); + } + } + + @Override + public void serializeWithType(FixedValue value, JsonGenerator gen, SerializerProvider serializers, TypeSerializer typeSer) throws IOException { + WritableTypeId typeId = typeSer.typeId(value, START_OBJECT); + typeSer.writeTypePrefix(gen, typeId); + serialize(value, gen, serializers); + typeSer.writeTypeSuffix(gen, typeId); + } +} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/GenericDeserializer.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/GenericDeserializer.java deleted file mode 100644 index de35dba18..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/GenericDeserializer.java +++ /dev/null @@ -1,46 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.serde.jackson; - -import org.nd4j.shade.jackson.core.JsonParser; -import org.nd4j.shade.jackson.databind.DeserializationContext; -import org.nd4j.shade.jackson.databind.JsonDeserializer; -import org.nd4j.shade.jackson.databind.JsonNode; -import org.nd4j.shade.jackson.databind.ObjectMapper; - -import java.io.IOException; - -/** - * Created by Alex on 15/02/2017. - */ -public class GenericDeserializer extends JsonDeserializer { - @Override - public Object deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { - JsonNode node = p.getCodec().readTree(p); - String className = node.get("@class").asText(); - Class c; - try { - c = Class.forName(className); - } catch (Exception e) { - throw new RuntimeException(e); - } - - JsonNode valueNode = node.get("value"); - Object o = new ObjectMapper().treeToValue(valueNode, c); - return o; - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/GenericSerializer.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/GenericSerializer.java deleted file mode 100644 index 035ac7c50..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/GenericSerializer.java +++ /dev/null @@ -1,38 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.serde.jackson; - -import org.nd4j.shade.jackson.core.JsonGenerator; -import org.nd4j.shade.jackson.core.JsonProcessingException; -import org.nd4j.shade.jackson.databind.JsonSerializer; -import org.nd4j.shade.jackson.databind.SerializerProvider; - -import java.io.IOException; - -/** - * Created by Alex on 15/02/2017. - */ -public class GenericSerializer extends JsonSerializer { - @Override - public void serialize(Object o, JsonGenerator j, SerializerProvider serializerProvider) - throws IOException, JsonProcessingException { - j.writeStartObject(); - j.writeStringField("@class", o.getClass().getName()); - j.writeObjectField("value", o); - j.writeEndObject(); - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java index 8cfb07723..f30cab109 100644 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java +++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java @@ -24,9 +24,6 @@ import org.nd4j.shade.jackson.databind.SerializationFeature; import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory; import org.nd4j.shade.jackson.datatype.joda.JodaModule; -import java.util.Collections; -import java.util.Map; - /** * Created by Alex on 16/11/2016. */ @@ -44,6 +41,7 @@ public class JsonMapper { mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE); mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY); + mapper.setVisibility(PropertyAccessor.SETTER, JsonAutoDetect.Visibility.ANY); yamlMapper = new ObjectMapper(new YAMLFactory()); yamlMapper.registerModule(new JodaModule()); yamlMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseOutputLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseOutputLayerSpace.java index 3a72156e9..857f729ad 100644 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseOutputLayerSpace.java +++ b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseOutputLayerSpace.java @@ -32,7 +32,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; */ @Data @EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PROTECTED) //For Jackson JSON/YAML deserialization +@NoArgsConstructor(access = AccessLevel.PUBLIC) //For Jackson JSON/YAML deserialization public abstract class BaseOutputLayerSpace extends FeedForwardLayerSpace { protected ParameterSpace lossFunction; diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/GlobalConfigPersistable.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/GlobalConfigPersistable.java index d11c251e3..00a95a628 100644 --- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/GlobalConfigPersistable.java +++ b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/GlobalConfigPersistable.java @@ -18,8 +18,11 @@ package org.deeplearning4j.arbiter.ui.data; import lombok.Getter; import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; -import org.deeplearning4j.arbiter.ui.misc.JsonMapper; +import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper; import org.deeplearning4j.arbiter.ui.module.ArbiterModule; +import org.deeplearning4j.nn.conf.serde.JsonMappers; + +import java.io.IOException; /** * @@ -64,7 +67,11 @@ public class GlobalConfigPersistable extends BaseJavaPersistable { public OptimizationConfiguration getOptimizationConfiguration(){ - return JsonMapper.fromJson(optimizationConfigJson, OptimizationConfiguration.class); + try { + return JsonMapper.getMapper().readValue(optimizationConfigJson, OptimizationConfiguration.class); + } catch (IOException e){ + throw new RuntimeException(e); + } } public int getCandidatesQueued(){ diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/listener/ArbiterStatusListener.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/listener/ArbiterStatusListener.java index f323f8a7a..7802762c9 100644 --- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/listener/ArbiterStatusListener.java +++ b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/listener/ArbiterStatusListener.java @@ -26,13 +26,14 @@ import org.deeplearning4j.arbiter.optimize.api.OptimizationResult; import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo; import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener; +import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper; import org.deeplearning4j.arbiter.ui.data.GlobalConfigPersistable; import org.deeplearning4j.arbiter.ui.data.ModelInfoPersistable; -import org.deeplearning4j.arbiter.ui.misc.JsonMapper; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.linalg.primitives.Pair; +import java.io.IOException; import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; @@ -217,7 +218,11 @@ public class ArbiterStatusListener implements StatusListener { // } //TODO: cache global config, but we don't want to have outdated info (like uninitialized termination conditions) - ocJson = JsonMapper.asJson(r.getConfiguration()); + try { + ocJson = JsonMapper.getMapper().writeValueAsString(r.getConfiguration()); + } catch (IOException e){ + throw new RuntimeException(e); + } GlobalConfigPersistable p = new GlobalConfigPersistable.Builder() .sessionId(sessionId) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL2.java index e9b30328c..a9f3b2c18 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL2.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.primitives.Pair; import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer; import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer; import org.nd4j.shade.jackson.annotation.JsonInclude; +import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; import org.nd4j.shade.jackson.databind.annotation.JsonSerialize; @@ -58,7 +59,7 @@ public class LossL2 implements ILossFunction { * * @param weights Weights array (row vector). May be null. */ - public LossL2(INDArray weights) { + public LossL2(@JsonProperty("weights") INDArray weights) { if (weights != null && !weights.isRowVector()) { throw new IllegalArgumentException("Weights array must be a row vector"); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMSE.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMSE.java index 4fc5a7eec..bb64bb777 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMSE.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMSE.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.lossfunctions.impl; import lombok.EqualsAndHashCode; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.shade.jackson.annotation.JsonProperty; /** * Mean Squared Error loss function: L = 1/N sum_i (actual_i - predicted)^2 @@ -38,7 +39,7 @@ public class LossMSE extends LossL2 { * * @param weights Weights array (row vector). May be null. */ - public LossMSE(INDArray weights) { + public LossMSE(@JsonProperty("weights") INDArray weights) { super(weights); }