Source code for dnikit_tensorflow._tensorflow._tensorflow_loading

#
# Copyright 2022 Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import typing as t

import tensorflow as tf

from dnikit.base import Model
import dnikit.typing as dt

from ._tensorflow_protocols import running_tf_1
if running_tf_1():
    from ._tf1_loading import load_tf_1_model_from_memory as tf_memory_load
    from ._tf1_loading import load_tf_1_model_from_path as tf_path_load
else:
    # Using type ignores here because the signature of these functions changes (input param)
    from ._tf2_loading import load_tf_2_model_from_memory as tf_memory_load  # type: ignore
    from ._tf2_loading import load_tf_2_model_from_path as tf_path_load  # type: ignore


def load_tf_model_from_memory(*,
                              session: t.Optional[tf.compat.v1.Session] = None,
                              model: t.Optional[tf.keras.models.Model] = None) -> Model:
    """
    Initialize a TensorFlow :class:`Model <dnikit.base.Model>` from a model loaded in ``memory``.

    This function is supported for both TF2 and TF1, but different parameters are required.
    For TF2, only pass parameter ``model``. For TF1, only pass parameter ``session``.

    Args:
        session: Pass only this parameter when running TensorFlow 1. This is the session that
            contains the graph to execute.
        model: Pass only this parameter when running TensorFlow 2. This is the TF Keras model.

    Returns:
        A TensorFlow :class:`Model <dnikit.base.Model>`.

    Raises:
        ValueError: if neither or both of ``session`` and ``model`` are passed, or if the
            parameter passed does not match the installed TensorFlow major version
            (``session`` for TF1, ``model`` for TF2).
    """
    # Query the installed TF major version once and reuse it for every decision below.
    is_tf_1 = running_tf_1()
    if is_tf_1:
        error_message = 'For TF1 (currently installed), please pass param `session`'
    else:
        error_message = 'For TF2 (currently installed), please pass param `model`'

    # Exactly one of the two parameters may be given; passing both is always an error.
    if session is not None and model is not None:
        raise ValueError(error_message + ' only.')

    if is_tf_1:
        # TF1 loads from the session holding the graph; `model` alone is not accepted.
        # This also covers the case where neither parameter was passed.
        if session is None:
            raise ValueError(error_message)
        return tf_memory_load(session)

    # TF2 loads from the Keras model; `session` alone is not accepted.
    # This also covers the case where neither parameter was passed.
    if model is None:
        raise ValueError(error_message)
    return tf_memory_load(model)
def load_tf_model_from_path(path: dt.PathOrStr) -> Model:
    """
    Initialize a TensorFlow :class:`Model <dnikit.base.Model>` from a model serialized at ``path``.

    The accepted serialization formats depend on which TensorFlow major version is running.

    TF2 supported formats:
        - TensorFlow Keras SavedModel
        - Keras whole models (h5)
        - Keras models with separate architecture and weights files

    TF1 supported formats:
        - TensorFlow SavedModel
        - TensorFlow checkpoint (pass the checkpoint prefix as the ``path`` param)
        - TensorFlow protobuf
        - Keras whole models
        - Keras models with separate architecture and weights files

    Note:
        The Keras loaders currently use ``tf.keras`` rather than native ``keras``, so issues
        might appear when loading models saved with native ``keras`` (not ``tf.keras``).
        In that case, load the model outside of DNIKit with ``keras`` and pass it to
        :func:`load_tf_model_from_memory <dnikit_tensorflow.load_tf_model_from_memory>`.

    Args:
        path: Model path (for a single model file) or a directory that contains all the
            model files.

    Returns:
        A DNIKit TensorFlow :class:`Model <dnikit.base.Model>`.
    """
    # Delegate to the version-specific path loader chosen by the conditional import
    # at module level (TF1 vs TF2).
    loaded = tf_path_load(path)
    return loaded