Arbiter generic JSON ser/de fixes (#237)
* Arbiter generic JSON ser/de fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Javadoc fix Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
f25e3e71e5
commit
7d85775934
|
@ -17,13 +17,14 @@
|
|||
package org.deeplearning4j.arbiter.optimize.parameter;
|
||||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.GenericDeserializer;
|
||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.GenericSerializer;
|
||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.FixedValueDeserializer;
|
||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.FixedValueSerializer;
|
||||
import org.deeplearning4j.arbiter.util.ObjectUtils;
|
||||
import org.nd4j.shade.jackson.annotation.JsonCreator;
|
||||
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
||||
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
|
||||
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
|
||||
|
||||
|
@ -37,9 +38,11 @@ import java.util.Map;
|
|||
* @param <T> Type of (fixed) value
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
@JsonSerialize(using = FixedValueSerializer.class)
|
||||
@JsonDeserialize(using = FixedValueDeserializer.class)
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||
public class FixedValue<T> implements ParameterSpace<T> {
|
||||
@JsonSerialize(using = GenericSerializer.class)
|
||||
@JsonDeserialize(using = GenericDeserializer.class)
|
||||
@Getter
|
||||
private Object value;
|
||||
private int index;
|
||||
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
||||
|
||||
import org.apache.commons.codec.binary.Base64;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
|
||||
import org.nd4j.shade.jackson.core.JsonParser;
|
||||
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
||||
import org.nd4j.shade.jackson.databind.DeserializationContext;
|
||||
import org.nd4j.shade.jackson.databind.JsonDeserializer;
|
||||
import org.nd4j.shade.jackson.databind.JsonNode;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectInputStream;
|
||||
|
||||
/**
|
||||
* A custom deserializer to be used in conjunction with {@link FixedValueSerializer}
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class FixedValueDeserializer extends JsonDeserializer<FixedValue> {
|
||||
@Override
|
||||
public FixedValue deserialize(JsonParser p, DeserializationContext deserializationContext) throws IOException, JsonProcessingException {
|
||||
JsonNode node = p.getCodec().readTree(p);
|
||||
String className = node.get("@valueclass").asText();
|
||||
Class<?> c;
|
||||
try {
|
||||
c = Class.forName(className);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
if(node.has("value")){
|
||||
//Number, String, Enum
|
||||
JsonNode valueNode = node.get("value");
|
||||
Object o = new ObjectMapper().treeToValue(valueNode, c);
|
||||
return new FixedValue<>(o);
|
||||
} else {
|
||||
//Everything else
|
||||
JsonNode valueNode = node.get("data");
|
||||
String data = valueNode.asText();
|
||||
|
||||
byte[] b = new Base64().decode(data);
|
||||
ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(b));
|
||||
try {
|
||||
Object o = ois.readObject();
|
||||
return new FixedValue<>(o);
|
||||
} catch (Throwable t) {
|
||||
throw new RuntimeException(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
||||
|
||||
import org.apache.commons.net.util.Base64;
|
||||
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
|
||||
import org.nd4j.shade.jackson.core.JsonGenerator;
|
||||
import org.nd4j.shade.jackson.core.type.WritableTypeId;
|
||||
import org.nd4j.shade.jackson.databind.JsonSerializer;
|
||||
import org.nd4j.shade.jackson.databind.SerializerProvider;
|
||||
import org.nd4j.shade.jackson.databind.jsontype.TypeSerializer;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectOutputStream;
|
||||
|
||||
import static org.nd4j.shade.jackson.core.JsonToken.START_OBJECT;
|
||||
|
||||
/**
|
||||
* A custom serializer to handle arbitrary object types
|
||||
* Uses standard JSON where safe (number, string, enumerations) or Java object serialization (bytes -> base64)
|
||||
* The latter is not an ideal approach, but Jackson doesn't support serialization/deserialization of arbitrary
|
||||
* objects very well
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class FixedValueSerializer extends JsonSerializer<FixedValue> {
|
||||
@Override
|
||||
public void serialize(FixedValue fixedValue, JsonGenerator j, SerializerProvider serializerProvider) throws IOException {
|
||||
Object o = fixedValue.getValue();
|
||||
|
||||
j.writeStringField("@valueclass", o.getClass().getName());
|
||||
if(o instanceof Number || o instanceof String || o instanceof Enum){
|
||||
j.writeObjectField("value", o);
|
||||
} else {
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
ObjectOutputStream oos = new ObjectOutputStream(baos);
|
||||
oos.writeObject(o);
|
||||
baos.close();
|
||||
byte[] b = baos.toByteArray();
|
||||
String base64 = new Base64().encodeToString(b);
|
||||
j.writeStringField("data", base64);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void serializeWithType(FixedValue value, JsonGenerator gen, SerializerProvider serializers, TypeSerializer typeSer) throws IOException {
|
||||
WritableTypeId typeId = typeSer.typeId(value, START_OBJECT);
|
||||
typeSer.writeTypePrefix(gen, typeId);
|
||||
serialize(value, gen, serializers);
|
||||
typeSer.writeTypeSuffix(gen, typeId);
|
||||
}
|
||||
}
|
|
@ -1,46 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
||||
|
||||
import org.nd4j.shade.jackson.core.JsonParser;
|
||||
import org.nd4j.shade.jackson.databind.DeserializationContext;
|
||||
import org.nd4j.shade.jackson.databind.JsonDeserializer;
|
||||
import org.nd4j.shade.jackson.databind.JsonNode;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Created by Alex on 15/02/2017.
|
||||
*/
|
||||
public class GenericDeserializer extends JsonDeserializer<Object> {
|
||||
@Override
|
||||
public Object deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
|
||||
JsonNode node = p.getCodec().readTree(p);
|
||||
String className = node.get("@class").asText();
|
||||
Class<?> c;
|
||||
try {
|
||||
c = Class.forName(className);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
JsonNode valueNode = node.get("value");
|
||||
Object o = new ObjectMapper().treeToValue(valueNode, c);
|
||||
return o;
|
||||
}
|
||||
}
|
|
@ -1,38 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
||||
|
||||
import org.nd4j.shade.jackson.core.JsonGenerator;
|
||||
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
||||
import org.nd4j.shade.jackson.databind.JsonSerializer;
|
||||
import org.nd4j.shade.jackson.databind.SerializerProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Created by Alex on 15/02/2017.
|
||||
*/
|
||||
public class GenericSerializer extends JsonSerializer<Object> {
|
||||
@Override
|
||||
public void serialize(Object o, JsonGenerator j, SerializerProvider serializerProvider)
|
||||
throws IOException, JsonProcessingException {
|
||||
j.writeStartObject();
|
||||
j.writeStringField("@class", o.getClass().getName());
|
||||
j.writeObjectField("value", o);
|
||||
j.writeEndObject();
|
||||
}
|
||||
}
|
|
@ -24,9 +24,6 @@ import org.nd4j.shade.jackson.databind.SerializationFeature;
|
|||
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
|
||||
import org.nd4j.shade.jackson.datatype.joda.JodaModule;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Created by Alex on 16/11/2016.
|
||||
*/
|
||||
|
@ -44,6 +41,7 @@ public class JsonMapper {
|
|||
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
|
||||
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
|
||||
mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
|
||||
mapper.setVisibility(PropertyAccessor.SETTER, JsonAutoDetect.Visibility.ANY);
|
||||
yamlMapper = new ObjectMapper(new YAMLFactory());
|
||||
yamlMapper.registerModule(new JodaModule());
|
||||
yamlMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
|
||||
|
|
|
@ -32,7 +32,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
|
|||
*/
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@NoArgsConstructor(access = AccessLevel.PROTECTED) //For Jackson JSON/YAML deserialization
|
||||
@NoArgsConstructor(access = AccessLevel.PUBLIC) //For Jackson JSON/YAML deserialization
|
||||
public abstract class BaseOutputLayerSpace<L extends BaseOutputLayer> extends FeedForwardLayerSpace<L> {
|
||||
|
||||
protected ParameterSpace<ILossFunction> lossFunction;
|
||||
|
|
|
@ -18,8 +18,11 @@ package org.deeplearning4j.arbiter.ui.data;
|
|||
|
||||
import lombok.Getter;
|
||||
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
|
||||
import org.deeplearning4j.arbiter.ui.misc.JsonMapper;
|
||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper;
|
||||
import org.deeplearning4j.arbiter.ui.module.ArbiterModule;
|
||||
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
*
|
||||
|
@ -64,7 +67,11 @@ public class GlobalConfigPersistable extends BaseJavaPersistable {
|
|||
|
||||
|
||||
public OptimizationConfiguration getOptimizationConfiguration(){
|
||||
return JsonMapper.fromJson(optimizationConfigJson, OptimizationConfiguration.class);
|
||||
try {
|
||||
return JsonMapper.getMapper().readValue(optimizationConfigJson, OptimizationConfiguration.class);
|
||||
} catch (IOException e){
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public int getCandidatesQueued(){
|
||||
|
|
|
@ -26,13 +26,14 @@ import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
|||
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
||||
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
|
||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper;
|
||||
import org.deeplearning4j.arbiter.ui.data.GlobalConfigPersistable;
|
||||
import org.deeplearning4j.arbiter.ui.data.ModelInfoPersistable;
|
||||
import org.deeplearning4j.arbiter.ui.misc.JsonMapper;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
@ -217,7 +218,11 @@ public class ArbiterStatusListener implements StatusListener {
|
|||
// }
|
||||
//TODO: cache global config, but we don't want to have outdated info (like uninitialized termination conditions)
|
||||
|
||||
ocJson = JsonMapper.asJson(r.getConfiguration());
|
||||
try {
|
||||
ocJson = JsonMapper.getMapper().writeValueAsString(r.getConfiguration());
|
||||
} catch (IOException e){
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
GlobalConfigPersistable p = new GlobalConfigPersistable.Builder()
|
||||
.sessionId(sessionId)
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.nd4j.linalg.primitives.Pair;
|
|||
import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer;
|
||||
import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
|
||||
import org.nd4j.shade.jackson.annotation.JsonInclude;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
|
||||
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
|
||||
|
||||
|
@ -58,7 +59,7 @@ public class LossL2 implements ILossFunction {
|
|||
*
|
||||
* @param weights Weights array (row vector). May be null.
|
||||
*/
|
||||
public LossL2(INDArray weights) {
|
||||
public LossL2(@JsonProperty("weights") INDArray weights) {
|
||||
if (weights != null && !weights.isRowVector()) {
|
||||
throw new IllegalArgumentException("Weights array must be a row vector");
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.lossfunctions.impl;
|
|||
import lombok.EqualsAndHashCode;
|
||||
import org.nd4j.linalg.activations.IActivation;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
|
||||
/**
|
||||
* Mean Squared Error loss function: L = 1/N sum_i (actual_i - predicted)^2
|
||||
|
@ -38,7 +39,7 @@ public class LossMSE extends LossL2 {
|
|||
*
|
||||
* @param weights Weights array (row vector). May be null.
|
||||
*/
|
||||
public LossMSE(INDArray weights) {
|
||||
public LossMSE(@JsonProperty("weights") INDArray weights) {
|
||||
super(weights);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue