Arbiter generic JSON ser/de fixes (#237)

* Arbiter generic JSON ser/de fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Javadoc fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-09-05 11:51:11 +10:00 committed by GitHub
parent f25e3e71e5
commit 7d85775934
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 133 additions and 99 deletions

View File

@ -17,13 +17,14 @@
package org.deeplearning4j.arbiter.optimize.parameter; package org.deeplearning4j.arbiter.optimize.parameter;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.serde.jackson.GenericDeserializer; import org.deeplearning4j.arbiter.optimize.serde.jackson.FixedValueDeserializer;
import org.deeplearning4j.arbiter.optimize.serde.jackson.GenericSerializer; import org.deeplearning4j.arbiter.optimize.serde.jackson.FixedValueSerializer;
import org.deeplearning4j.arbiter.util.ObjectUtils; import org.deeplearning4j.arbiter.util.ObjectUtils;
import org.nd4j.shade.jackson.annotation.JsonCreator; import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize; import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
@ -37,9 +38,11 @@ import java.util.Map;
* @param <T> Type of (fixed) value * @param <T> Type of (fixed) value
*/ */
@EqualsAndHashCode @EqualsAndHashCode
@JsonSerialize(using = FixedValueSerializer.class)
@JsonDeserialize(using = FixedValueDeserializer.class)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public class FixedValue<T> implements ParameterSpace<T> { public class FixedValue<T> implements ParameterSpace<T> {
@JsonSerialize(using = GenericSerializer.class) @Getter
@JsonDeserialize(using = GenericDeserializer.class)
private Object value; private Object value;
private int index; private int index;

View File

@ -0,0 +1,52 @@
package org.deeplearning4j.arbiter.optimize.serde.jackson;
import org.apache.commons.codec.binary.Base64;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.nd4j.shade.jackson.core.JsonParser;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.DeserializationContext;
import org.nd4j.shade.jackson.databind.JsonDeserializer;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
/**
* A custom deserializer to be used in conjunction with {@link FixedValueSerializer}
* @author Alex Black
*/
public class FixedValueDeserializer extends JsonDeserializer<FixedValue> {
@Override
public FixedValue deserialize(JsonParser p, DeserializationContext deserializationContext) throws IOException, JsonProcessingException {
JsonNode node = p.getCodec().readTree(p);
String className = node.get("@valueclass").asText();
Class<?> c;
try {
c = Class.forName(className);
} catch (Exception e) {
throw new RuntimeException(e);
}
if(node.has("value")){
//Number, String, Enum
JsonNode valueNode = node.get("value");
Object o = new ObjectMapper().treeToValue(valueNode, c);
return new FixedValue<>(o);
} else {
//Everything else
JsonNode valueNode = node.get("data");
String data = valueNode.asText();
byte[] b = new Base64().decode(data);
ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(b));
try {
Object o = ois.readObject();
return new FixedValue<>(o);
} catch (Throwable t) {
throw new RuntimeException(t);
}
}
}
}

View File

@ -0,0 +1,51 @@
package org.deeplearning4j.arbiter.optimize.serde.jackson;
import org.apache.commons.net.util.Base64;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.nd4j.shade.jackson.core.JsonGenerator;
import org.nd4j.shade.jackson.core.type.WritableTypeId;
import org.nd4j.shade.jackson.databind.JsonSerializer;
import org.nd4j.shade.jackson.databind.SerializerProvider;
import org.nd4j.shade.jackson.databind.jsontype.TypeSerializer;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import static org.nd4j.shade.jackson.core.JsonToken.START_OBJECT;
/**
* A custom serializer to handle arbitrary object types
* Uses standard JSON where safe (number, string, enumerations) or Java object serialization (bytes -> base64)
* The latter is not an ideal approach, but Jackson doesn't support serialization/deserialization of arbitrary
* objects very well
*
* @author Alex Black
*/
public class FixedValueSerializer extends JsonSerializer<FixedValue> {
@Override
public void serialize(FixedValue fixedValue, JsonGenerator j, SerializerProvider serializerProvider) throws IOException {
Object o = fixedValue.getValue();
j.writeStringField("@valueclass", o.getClass().getName());
if(o instanceof Number || o instanceof String || o instanceof Enum){
j.writeObjectField("value", o);
} else {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos);
oos.writeObject(o);
baos.close();
byte[] b = baos.toByteArray();
String base64 = new Base64().encodeToString(b);
j.writeStringField("data", base64);
}
}
@Override
public void serializeWithType(FixedValue value, JsonGenerator gen, SerializerProvider serializers, TypeSerializer typeSer) throws IOException {
WritableTypeId typeId = typeSer.typeId(value, START_OBJECT);
typeSer.writeTypePrefix(gen, typeId);
serialize(value, gen, serializers);
typeSer.writeTypeSuffix(gen, typeId);
}
}

View File

@ -1,46 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.serde.jackson;
import org.nd4j.shade.jackson.core.JsonParser;
import org.nd4j.shade.jackson.databind.DeserializationContext;
import org.nd4j.shade.jackson.databind.JsonDeserializer;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import java.io.IOException;
/**
* Created by Alex on 15/02/2017.
*/
public class GenericDeserializer extends JsonDeserializer<Object> {
@Override
public Object deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
JsonNode node = p.getCodec().readTree(p);
String className = node.get("@class").asText();
Class<?> c;
try {
c = Class.forName(className);
} catch (Exception e) {
throw new RuntimeException(e);
}
JsonNode valueNode = node.get("value");
Object o = new ObjectMapper().treeToValue(valueNode, c);
return o;
}
}

View File

@ -1,38 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.arbiter.optimize.serde.jackson;
import org.nd4j.shade.jackson.core.JsonGenerator;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.JsonSerializer;
import org.nd4j.shade.jackson.databind.SerializerProvider;
import java.io.IOException;
/**
* Created by Alex on 15/02/2017.
*/
public class GenericSerializer extends JsonSerializer<Object> {
@Override
public void serialize(Object o, JsonGenerator j, SerializerProvider serializerProvider)
throws IOException, JsonProcessingException {
j.writeStartObject();
j.writeStringField("@class", o.getClass().getName());
j.writeObjectField("value", o);
j.writeEndObject();
}
}

View File

@ -24,9 +24,6 @@ import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory; import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import org.nd4j.shade.jackson.datatype.joda.JodaModule; import org.nd4j.shade.jackson.datatype.joda.JodaModule;
import java.util.Collections;
import java.util.Map;
/** /**
* Created by Alex on 16/11/2016. * Created by Alex on 16/11/2016.
*/ */
@ -44,6 +41,7 @@ public class JsonMapper {
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE); mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY); mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
mapper.setVisibility(PropertyAccessor.SETTER, JsonAutoDetect.Visibility.ANY);
yamlMapper = new ObjectMapper(new YAMLFactory()); yamlMapper = new ObjectMapper(new YAMLFactory());
yamlMapper.registerModule(new JodaModule()); yamlMapper.registerModule(new JodaModule());
yamlMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); yamlMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);

View File

@ -32,7 +32,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
*/ */
@Data @Data
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@NoArgsConstructor(access = AccessLevel.PROTECTED) //For Jackson JSON/YAML deserialization @NoArgsConstructor(access = AccessLevel.PUBLIC) //For Jackson JSON/YAML deserialization
public abstract class BaseOutputLayerSpace<L extends BaseOutputLayer> extends FeedForwardLayerSpace<L> { public abstract class BaseOutputLayerSpace<L extends BaseOutputLayer> extends FeedForwardLayerSpace<L> {
protected ParameterSpace<ILossFunction> lossFunction; protected ParameterSpace<ILossFunction> lossFunction;

View File

@ -18,8 +18,11 @@ package org.deeplearning4j.arbiter.ui.data;
import lombok.Getter; import lombok.Getter;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
import org.deeplearning4j.arbiter.ui.misc.JsonMapper; import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper;
import org.deeplearning4j.arbiter.ui.module.ArbiterModule; import org.deeplearning4j.arbiter.ui.module.ArbiterModule;
import org.deeplearning4j.nn.conf.serde.JsonMappers;
import java.io.IOException;
/** /**
* *
@ -64,7 +67,11 @@ public class GlobalConfigPersistable extends BaseJavaPersistable {
public OptimizationConfiguration getOptimizationConfiguration(){ public OptimizationConfiguration getOptimizationConfiguration(){
return JsonMapper.fromJson(optimizationConfigJson, OptimizationConfiguration.class); try {
return JsonMapper.getMapper().readValue(optimizationConfigJson, OptimizationConfiguration.class);
} catch (IOException e){
throw new RuntimeException(e);
}
} }
public int getCandidatesQueued(){ public int getCandidatesQueued(){

View File

@ -26,13 +26,14 @@ import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo; import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener; import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper;
import org.deeplearning4j.arbiter.ui.data.GlobalConfigPersistable; import org.deeplearning4j.arbiter.ui.data.GlobalConfigPersistable;
import org.deeplearning4j.arbiter.ui.data.ModelInfoPersistable; import org.deeplearning4j.arbiter.ui.data.ModelInfoPersistable;
import org.deeplearning4j.arbiter.ui.misc.JsonMapper;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import java.io.IOException;
import java.util.Map; import java.util.Map;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
@ -217,7 +218,11 @@ public class ArbiterStatusListener implements StatusListener {
// } // }
//TODO: cache global config, but we don't want to have outdated info (like uninitialized termination conditions) //TODO: cache global config, but we don't want to have outdated info (like uninitialized termination conditions)
ocJson = JsonMapper.asJson(r.getConfiguration()); try {
ocJson = JsonMapper.getMapper().writeValueAsString(r.getConfiguration());
} catch (IOException e){
throw new RuntimeException(e);
}
GlobalConfigPersistable p = new GlobalConfigPersistable.Builder() GlobalConfigPersistable p = new GlobalConfigPersistable.Builder()
.sessionId(sessionId) .sessionId(sessionId)

View File

@ -27,6 +27,7 @@ import org.nd4j.linalg.primitives.Pair;
import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer; import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer;
import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer; import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize; import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
@ -58,7 +59,7 @@ public class LossL2 implements ILossFunction {
* *
* @param weights Weights array (row vector). May be null. * @param weights Weights array (row vector). May be null.
*/ */
public LossL2(INDArray weights) { public LossL2(@JsonProperty("weights") INDArray weights) {
if (weights != null && !weights.isRowVector()) { if (weights != null && !weights.isRowVector()) {
throw new IllegalArgumentException("Weights array must be a row vector"); throw new IllegalArgumentException("Weights array must be a row vector");
} }

View File

@ -19,6 +19,7 @@ package org.nd4j.linalg.lossfunctions.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonProperty;
/** /**
* Mean Squared Error loss function: L = 1/N sum_i (actual_i - predicted)^2 * Mean Squared Error loss function: L = 1/N sum_i (actual_i - predicted)^2
@ -38,7 +39,7 @@ public class LossMSE extends LossL2 {
* *
* @param weights Weights array (row vector). May be null. * @param weights Weights array (row vector). May be null.
*/ */
public LossMSE(INDArray weights) { public LossMSE(@JsonProperty("weights") INDArray weights) {
super(weights); super(weights);
} }