Source code for z2pack.io._encoding

"""Defines functions for encoding and decoding Z2Pack objects."""

import numbers
import contextlib
from functools import singledispatch
from collections.abc import Iterable

import numpy as np
from fsc.export import export

# This can create a circular import if it is imported by name (from ... import ...)
# If this is ever an issue, consider splitting the encoding by surface / line
from ..line import LineResult, WccLineData, OverlapLineData, EigenstateLineData
from ..surface._data import SurfaceData, LinePosition
from ..surface._result import SurfaceResult
from ..volume._data import VolumeData, SurfacePosition
from ..volume._result import VolumeResult


@export
@singledispatch
def encode(obj):
    """
    Convert a Z2Pack object into a JSON / msgpack - compatible type.

    This is the fallback of the single-dispatch function: it only runs when
    no encoder has been registered for the type of ``obj``.

    :param obj: Object to be encoded.
    :raises TypeError: Always, because no registered encoder matched ``obj``.
    """
    message = 'cannot JSONify {} object {}'.format(type(obj), obj)
    raise TypeError(message)
@encode.register(np.bool_)
def _(obj):
    # numpy booleans are turned into the built-in bool.
    return bool(obj)


@encode.register(numbers.Real)
def _(obj):
    # Every real number (integers included) is stored as a float.
    return float(obj)


@encode.register(numbers.Complex)
def _(obj):
    # Complex numbers are stored as a tagged dict of their two parts.
    real_part = encode(obj.real)
    imag_part = encode(obj.imag)
    return {'__complex__': True, 'real': real_part, 'imag': imag_part}


@encode.register(Iterable)
def _(obj):
    # Only the outer container is converted; the items are left untouched.
    return [*obj]


@encode.register(EigenstateLineData)
def _(obj):
    return {
        '__eigenstate_line_data__': True,
        'eigenstates': encode(obj.eigenstates),
    }


@encode.register(OverlapLineData)
def _(obj):
    return {
        '__overlap_line_data__': True,
        'overlaps': encode(obj.overlaps),
    }


@encode.register(WccLineData)
def _(obj):
    return {'__wcc_line_data__': True, 'wcc': encode(obj.wcc)}


@encode.register(LineResult)
def _(obj):
    # The control attributes are passed through without calling encode().
    return {
        '__line_result__': True,
        'data': encode(obj.data),
        'ctrl_convergence': obj.ctrl_convergence,
        'ctrl_states': obj.ctrl_states,
    }


@encode.register(LinePosition)
def _(obj):
    return {
        '__line_position__': True,
        't': obj.t,
        'result': encode(obj.result),
    }


@encode.register(SurfaceData)
def _(obj):
    return {'__surface_data__': True, 'lines': encode(obj.lines)}


@encode.register(SurfaceResult)
def _(obj):
    # The control attributes are passed through without calling encode().
    return {
        '__surface_result__': True,
        'data': encode(obj.data),
        'ctrl_convergence': obj.ctrl_convergence,
        'ctrl_states': obj.ctrl_states,
    }


@encode.register(SurfacePosition)
def _(obj):
    return {
        '__surface_position__': True,
        's': obj.s,
        'result': encode(obj.result),
    }


@encode.register(VolumeData)
def _(obj):
    return {'__volume_data__': True, 'surfaces': encode(obj.surfaces)}


@encode.register(VolumeResult)
def _(obj):
    # The control attributes are passed through without calling encode().
    return {
        '__volume_result__': True,
        'data': encode(obj.data),
        'ctrl_convergence': obj.ctrl_convergence,
        'ctrl_states': obj.ctrl_states,
    }


# --------------------------------------------------------------------- #
@export
@singledispatch
def decode(obj):
    """
    Convert a JSON / msgpack object back into the matching Z2Pack type.

    This is the fallback of the single-dispatch function: an object whose
    type has no registered decoder is handed back unchanged.

    :param obj: Object to be decoded.
    :returns: ``obj`` itself.
    """
    return obj
def _decode_result(result_cls, obj):
    """
    Build a ``result_cls`` instance from the dict ``obj``.

    Shared by the volume / surface / line result decoders. ``obj['data']`` is
    passed on as-is (NOTE(review): presumably it has already been decoded by
    the time this runs, e.g. through a bottom-up object hook - confirm).
    """
    # The states / convergence of the controls are set manually
    res = result_cls(obj['data'], [], [])
    res.ctrl_convergence = decode(obj['ctrl_convergence'])
    res.ctrl_states = decode(obj['ctrl_states'])
    return res


def decode_volume_result(obj):
    """
    Decodes a dict into a VolumeResult instance.
    """
    return _decode_result(VolumeResult, obj)


def decode_volume_data(obj):
    """
    Decodes a dict into a VolumeData instance.
    """
    return VolumeData(decode(obj['surfaces']))


def decode_surface_position(obj):
    """
    Decodes a dict into a SurfacePosition instance.
    """
    return SurfacePosition(s=obj['s'], result=decode(obj['result']))


def decode_surface_result(obj):
    """
    Decodes a dict into a SurfaceResult instance.
    """
    return _decode_result(SurfaceResult, obj)


def decode_surface_data(obj):
    """
    Decodes a dict into a SurfaceData instance.
    """
    return SurfaceData(decode(obj['lines']))


# Needed for legacy results.
def decode_surface_line(obj):
    """
    Decodes a dict into a LinePosition instance.
    """
    return decode_line_position(obj)


def decode_line_position(obj):
    """
    Decodes a dict into a LinePosition instance.
    """
    return LinePosition(obj['t'], decode(obj['result']))


def decode_line_result(obj):
    """
    Decodes a dict into a LineResult instance.
    """
    return _decode_result(LineResult, obj)


def decode_wcc_line_data(obj):
    """
    Decodes a dict into a WccLineData instance.
    """
    return WccLineData(obj['wcc'])


def decode_overlap_line_data(obj):
    """
    Decodes a dict into a OverlapLineData instance.
    """
    return OverlapLineData(obj['overlaps'])


def decode_eigenstate_line_data(obj):
    """
    Decodes a dict into a EigenstateLineData instance.
    """
    return EigenstateLineData(obj['eigenstates'])


def decode_complex(obj):
    """
    Decodes a dict into a complex number.
    """
    return complex(obj['real'], obj['imag'])


@decode.register(dict)
def _(obj):
    """
    Decodes a dict, dispatching on its special ``__<name>__`` marker key.

    A dict with exactly one key starting with ``'__'`` is handed to the
    module-level function ``decode_<name>``; any other dict is returned as a
    plain dict.

    :raises KeyError: If the single marker has no ``decode_<name>`` function.
    """
    # Keys may arrive as bytes; decode those individually, so that dicts with
    # mixed or non-string keys (e.g. integers) do not break the marker search.
    if any(isinstance(key, bytes) for key in obj):
        obj = {
            (key.decode('utf-8') if isinstance(key, bytes) else key): value
            for key, value in obj.items()
        }
    # Only string keys can be markers; other key types are skipped instead of
    # failing on the missing / incompatible ``startswith``.
    special_markers = [
        key for key in obj if isinstance(key, str) and key.startswith('__')
    ]
    if len(special_markers) == 1:
        # str.strip takes a set of characters, so this removes every leading
        # and trailing underscore: '__line_result__' -> 'line_result'.
        name = special_markers[0].strip('_')
        return globals()['decode_' + name](obj)
    return obj