Arbiter generic JSON ser/de fixes (#237)

* Arbiter generic JSON ser/de fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Javadoc fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-09-05 11:51:11 +10:00 committed by GitHub
parent f25e3e71e5
commit 7d85775934
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 133 additions and 99 deletions

View File

@ -17,13 +17,14 @@
package org.deeplearning4j.arbiter.optimize.parameter;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.serde.jackson.GenericDeserializer;
import org.deeplearning4j.arbiter.optimize.serde.jackson.GenericSerializer;
import org.deeplearning4j.arbiter.optimize.serde.jackson.FixedValueDeserializer;
import org.deeplearning4j.arbiter.optimize.serde.jackson.FixedValueSerializer;
import org.deeplearning4j.arbiter.util.ObjectUtils;
import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
@ -37,9 +38,11 @@ import java.util.Map;
* @param <T> Type of (fixed) value
*/
@EqualsAndHashCode
@JsonSerialize(using = FixedValueSerializer.class)
@JsonDeserialize(using = FixedValueDeserializer.class)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public class FixedValue<T> implements ParameterSpace<T> {
@JsonSerialize(using = GenericSerializer.class)
@JsonDeserialize(using = GenericDeserializer.class)
@Getter
private Object value;
private int index;

View File

@ -0,0 +1,52 @@
package org.deeplearning4j.arbiter.optimize.serde.jackson;
import org.apache.commons.codec.binary.Base64;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.nd4j.shade.jackson.core.JsonParser;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.DeserializationContext;
import org.nd4j.shade.jackson.databind.JsonDeserializer;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
/**
* A custom deserializer to be used in conjunction with {@link FixedValueSerializer}
* @author Alex Black
*/
public class FixedValueDeserializer extends JsonDeserializer<FixedValue> {
@Override
public FixedValue deserialize(JsonParser p, DeserializationContext deserializationContext) throws IOException, JsonProcessingException {
JsonNode node = p.getCodec().readTree(p);
String className = node.get("@valueclass").asText();
Class<?> c;
try {
c = Class.forName(className);
} catch (Exception e) {
throw new RuntimeException(e);
}
if(node.has("value")){
//Number, String, Enum
JsonNode valueNode = node.get("value");
Object o = new ObjectMapper().treeToValue(valueNode, c);
return new FixedValue<>(o);
} else {
//Everything else
JsonNode valueNode = node.get("data");
String data = valueNode.asText();
byte[] b = new Base64().decode(data);
ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(b));
try {
Object o = ois.readObject();
return new FixedValue<>(o);
} catch (Throwable t) {
throw new RuntimeException(t);
}
}
}
}

View File

@ -0,0 +1,51 @@
package org.deeplearning4j.arbiter.optimize.serde.jackson;
import org.apache.commons.net.util.Base64;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.nd4j.shade.jackson.core.JsonGenerator;
import org.nd4j.shade.jackson.core.type.WritableTypeId;
import org.nd4j.shade.jackson.databind.JsonSerializer;
import org.nd4j.shade.jackson.databind.SerializerProvider;
import org.nd4j.shade.jackson.databind.jsontype.TypeSerializer;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import static org.nd4j.shade.jackson.core.JsonToken.START_OBJECT;
/**
* A custom serializer to handle arbitrary object types
* Uses standard JSON where safe (number, string, enumerations) or Java object serialization (bytes -> base64)
* The latter is not an ideal approach, but Jackson doesn't support serialization/deserialization of arbitrary
* objects very well
*
* @author Alex Black
*/
public class FixedValueSerializer extends JsonSerializer<FixedValue> {
@Override
public void serialize(FixedValue fixedValue, JsonGenerator j, SerializerProvider serializerProvider) throws IOException {
Object o = fixedValue.getValue();
j.writeStringField("@valueclass", o.getClass().getName());
if(o instanceof Number || o instanceof String || o instanceof Enum){
j.writeObjectField("value", o);
} else {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos);
oos.writeObject(o);
baos.close();
byte[] b = baos.toByteArray();
String base64 = new Base64().encodeToString(b);
j.writeStringField("data", base64);
}
}
@Override
public void serializeWithType(FixedValue value, JsonGenerator gen, SerializerProvider serializers, TypeSerializer typeSer) throws IOException {
WritableTypeId typeId = typeSer.typeId(value, START_OBJECT);
typeSer.writeTypePrefix(gen, typeId);
serialize(value, gen, serializers);
typeSer.writeTypeSuffix(gen, typeId);
}
}

View File

@ -1,46 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.serde.jackson;
import org.nd4j.shade.jackson.core.JsonParser;
import org.nd4j.shade.jackson.databind.DeserializationContext;
import org.nd4j.shade.jackson.databind.JsonDeserializer;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import java.io.IOException;
/**
* Created by Alex on 15/02/2017.
*/
public class GenericDeserializer extends JsonDeserializer<Object> {
@Override
public Object deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
JsonNode node = p.getCodec().readTree(p);
String className = node.get("@class").asText();
Class<?> c;
try {
c = Class.forName(className);
} catch (Exception e) {
throw new RuntimeException(e);
}
JsonNode valueNode = node.get("value");
Object o = new ObjectMapper().treeToValue(valueNode, c);
return o;
}
}

View File

@ -1,38 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.serde.jackson;
import org.nd4j.shade.jackson.core.JsonGenerator;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.JsonSerializer;
import org.nd4j.shade.jackson.databind.SerializerProvider;
import java.io.IOException;
/**
* Created by Alex on 15/02/2017.
*/
public class GenericSerializer extends JsonSerializer<Object> {
@Override
public void serialize(Object o, JsonGenerator j, SerializerProvider serializerProvider)
throws IOException, JsonProcessingException {
j.writeStartObject();
j.writeStringField("@class", o.getClass().getName());
j.writeObjectField("value", o);
j.writeEndObject();
}
}

View File

@ -24,9 +24,6 @@ import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import org.nd4j.shade.jackson.datatype.joda.JodaModule;
import java.util.Collections;
import java.util.Map;
/**
* Created by Alex on 16/11/2016.
*/
@ -44,6 +41,7 @@ public class JsonMapper {
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
mapper.setVisibility(PropertyAccessor.SETTER, JsonAutoDetect.Visibility.ANY);
yamlMapper = new ObjectMapper(new YAMLFactory());
yamlMapper.registerModule(new JodaModule());
yamlMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);

View File

@ -32,7 +32,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
*/
@Data
@EqualsAndHashCode(callSuper = true)
@NoArgsConstructor(access = AccessLevel.PROTECTED) //For Jackson JSON/YAML deserialization
@NoArgsConstructor(access = AccessLevel.PUBLIC) //For Jackson JSON/YAML deserialization
public abstract class BaseOutputLayerSpace<L extends BaseOutputLayer> extends FeedForwardLayerSpace<L> {
protected ParameterSpace<ILossFunction> lossFunction;

View File

@ -18,8 +18,11 @@ package org.deeplearning4j.arbiter.ui.data;
import lombok.Getter;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
import org.deeplearning4j.arbiter.ui.misc.JsonMapper;
import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper;
import org.deeplearning4j.arbiter.ui.module.ArbiterModule;
import org.deeplearning4j.nn.conf.serde.JsonMappers;
import java.io.IOException;
/**
*
@ -64,7 +67,11 @@ public class GlobalConfigPersistable extends BaseJavaPersistable {
public OptimizationConfiguration getOptimizationConfiguration(){
return JsonMapper.fromJson(optimizationConfigJson, OptimizationConfiguration.class);
try {
return JsonMapper.getMapper().readValue(optimizationConfigJson, OptimizationConfiguration.class);
} catch (IOException e){
throw new RuntimeException(e);
}
}
public int getCandidatesQueued(){

View File

@ -26,13 +26,14 @@ import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper;
import org.deeplearning4j.arbiter.ui.data.GlobalConfigPersistable;
import org.deeplearning4j.arbiter.ui.data.ModelInfoPersistable;
import org.deeplearning4j.arbiter.ui.misc.JsonMapper;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.primitives.Pair;
import java.io.IOException;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
@ -217,7 +218,11 @@ public class ArbiterStatusListener implements StatusListener {
// }
//TODO: cache global config, but we don't want to have outdated info (like uninitialized termination conditions)
ocJson = JsonMapper.asJson(r.getConfiguration());
try {
ocJson = JsonMapper.getMapper().writeValueAsString(r.getConfiguration());
} catch (IOException e){
throw new RuntimeException(e);
}
GlobalConfigPersistable p = new GlobalConfigPersistable.Builder()
.sessionId(sessionId)

View File

@ -27,6 +27,7 @@ import org.nd4j.linalg.primitives.Pair;
import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer;
import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
@ -58,7 +59,7 @@ public class LossL2 implements ILossFunction {
*
* @param weights Weights array (row vector). May be null.
*/
public LossL2(INDArray weights) {
public LossL2(@JsonProperty("weights") INDArray weights) {
if (weights != null && !weights.isRowVector()) {
throw new IllegalArgumentException("Weights array must be a row vector");
}

View File

@ -19,6 +19,7 @@ package org.nd4j.linalg.lossfunctions.impl;
import lombok.EqualsAndHashCode;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonProperty;
/**
* Mean Squared Error loss function: L = 1/N sum_i (actual_i - predicted)^2
@ -38,7 +39,7 @@ public class LossMSE extends LossL2 {
*
* @param weights Weights array (row vector). May be null.
*/
public LossMSE(INDArray weights) {
public LossMSE(@JsonProperty("weights") INDArray weights) {
super(weights);
}