# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils related to keras model saving."""

# pylint: disable=g-bad-import-order, g-direct-tensorflow-import
import tensorflow.compat.v2 as tf
import keras

import copy
import os
from keras import backend
from keras import losses
from keras import optimizer_v1
from keras import optimizers
from keras.engine import base_layer_utils
from keras.utils import generic_utils
from keras.utils import version_utils
from keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.platform import tf_logging as logging
# pylint: enable=g-bad-import-order, g-direct-tensorflow-import


def extract_model_metrics(model):
  """Collects the metrics passed to a Keras model's `compile` call.

  Used when converting Keras models to Estimators and SavedModels.

  Args:
    model: A `tf.keras.Model` object.

  Returns:
    A dict from metric name to metric instance, or `None` when the model has
    no (or empty) `_compile_metrics`.
  """
  if not getattr(model, '_compile_metrics', None):
    return None
  # TODO(psv/kathywu): use this implementation in model to estimator flow.
  # `model.metrics` is deliberately avoided: it would also include metrics
  # registered through the `add_metric` API.
  metric_fns = model._compile_metric_functions  # pylint: disable=protected-access
  return {fn.name: fn for fn in metric_fns}


def model_call_inputs(model, keep_original_batch_size=False):
  """Returns the input signature the model's call function was saved with.

  Keras requires tensor inputs to be passed as the first call argument, so
  the positional part of the signature is a list holding a single (possibly
  nested) structure. For instance, a model fed
  {'feature1': <Tensor>, 'feature2': <Tensor>} yields the positional
  signature [{'feature1': TensorSpec, 'feature2': TensorSpec}].

  Args:
    model: Keras Model object.
    keep_original_batch_size: If `True`, keep the batch dimension recorded in
      the model's save spec. If `False` (the default), the returned spec uses
      a dynamic (`None`) batch dimension.

  Returns:
    An `(args, kwargs)` tuple of TensorSpecs describing the call inputs, with
    no `training` entry in `kwargs`; `(None, None)` when the model has no
    save spec.
  """
  dynamic = not keep_original_batch_size
  specs = model.save_spec(dynamic_batch=dynamic)
  if specs is not None:
    return _enforce_names_consistency(specs)
  return None, None


def raise_model_input_error(model):
  """Raises a `ValueError` explaining why `model` cannot be saved.

  Always raises. The message for `Sequential` models differs from the one for
  other (assumed subclassed) models.

  Args:
    model: The Keras model whose input signature could not be determined.

  Raises:
    ValueError: always.
  """
  if isinstance(model, keras.models.Sequential):
    raise ValueError(
        f'Model {model} cannot be saved because the input shape is not '
        'available. Please specify an input shape either by calling '
        '`build(input_shape)` directly, or by calling the model on actual '
        'data using `Model()`, `Model.fit()`, or `Model.predict()`.')

  # If the model is not a `Sequential`, it is intended to be a subclassed model.
  # Note the trailing space after "defined." so the sentences do not run
  # together once the adjacent literals are concatenated.
  raise ValueError(
      f'Model {model} cannot be saved either because the input shape is not '
      'available or because the forward pass of the model is not defined. '
      'To define a forward pass, please override `Model.call()`. To specify '
      'an input shape, either call `build(input_shape)` directly, or call '
      'the model on actual data using `Model()`, `Model.fit()`, or '
      '`Model.predict()`. If you have a custom training step, please make '
      'sure to invoke the forward pass in train step through '
      '`Model.__call__`, i.e. `model(inputs)`, as opposed to `model.call()`.')


def trace_model_call(model, input_signature=None):
  """Builds a concrete tf.function that exports `model`'s forward pass.

  The traced function always calls the model with `training=False` inside a
  saving call context and returns its outputs as a flat dict keyed by output
  name.

  Args:
    model: A Keras model.
    input_signature: Optional list of `tf.TensorSpec` objects describing the
      model inputs. When omitted, the signature of a `tf.function`-decorated
      `model.call` is used if present; otherwise it is derived from the
      model's save spec via `model_call_inputs`.

  Returns:
    A concrete function obtained from the wrapping tf.function.

  Raises:
    ValueError: if no input signature can be inferred from the model.
  """
  if input_signature is None and isinstance(
      model.call, tf.__internal__.function.Function):
    input_signature = model.call.input_signature

  if input_signature:
    trace_args, trace_kwargs = input_signature, {}
  else:
    trace_args, trace_kwargs = model_call_inputs(model)
    if trace_args is None:
      raise_model_input_error(model)

  @tf.function
  def _wrapped_model(*args, **kwargs):
    """A concrete tf.function that wraps the model's call function."""
    # The exported function always runs the model in inference mode.
    call_kwargs = dict(kwargs, training=False)
    with base_layer_utils.call_context().enter(
        model, inputs=None, build_graph=False, training=False, saving=True):
      outputs = model(*args, **call_kwargs)

    # The exported function must return a flat dict of named outputs.
    names = model.output_names  # Populated for functional models.
    if names is None:  # Subclassed models: synthesize names from outputs.
      from keras.engine import compile_utils  # pylint: disable=g-import-not-at-top
      names = compile_utils.create_pseudo_output_names(outputs)
    return dict(zip(names, tf.nest.flatten(outputs)))

  return _wrapped_model.get_concrete_function(*trace_args, **trace_kwargs)


def model_metadata(model, include_optimizer=True, require_config=True):
  """Builds the metadata dict stored alongside a saved Keras model.

  Args:
    model: The Keras model being saved.
    include_optimizer: Whether to attach a `training_config` (including the
      serialized optimizer) when the model has been compiled.
    require_config: If `True`, a `NotImplementedError` from
      `model.get_config()` is propagated; otherwise the `config` entry is
      simply omitted.

  Returns:
    A dict with `keras_version`, `backend` and `model_config` entries, plus
    `training_config` when the model was compiled with a Keras optimizer and
    `include_optimizer` is set.

  Raises:
    NotImplementedError: if `model.get_config()` is not implemented and
      `require_config` is `True`, or if the optimizer was restored from a
      SavedModel.
  """
  from keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
  from keras.optimizer_v2 import optimizer_v2  # pylint: disable=g-import-not-at-top

  model_config = {'class_name': model.__class__.__name__}
  try:
    model_config['config'] = model.get_config()
  except NotImplementedError:
    if require_config:
      raise

  metadata = {
      'keras_version': str(keras_version),
      'backend': backend.backend(),
      'model_config': model_config,
  }
  if not (model.optimizer and include_optimizer):
    return metadata

  optimizer = model.optimizer
  if isinstance(optimizer, optimizer_v1.TFOptimizer):
    # Raw TF optimizers expose no config/state, so nothing can be serialized.
    logging.warning(
        'TensorFlow optimizers do not '
        'make it possible to access '
        'optimizer attributes or optimizer state '
        'after instantiation. '
        'As a result, we cannot save the optimizer '
        'as part of the model save file. '
        'You will have to compile your model again after loading it. '
        'Prefer using a Keras optimizer instead '
        '(see keras.io/optimizers).')
  elif model._compile_was_called:  # pylint: disable=protected-access
    compile_args = model._get_compile_args(user_metrics=False)  # pylint: disable=protected-access
    compile_args.pop('optimizer', None)  # Serialized separately below.
    metadata['training_config'] = _serialize_nested_config(compile_args)
    if isinstance(optimizer, optimizer_v2.RestoredOptimizer):
      raise NotImplementedError(
          'Optimizers loaded from a SavedModel cannot be saved. '
          'If you are calling `model.save` or `tf.keras.models.save_model`, '
          'please set the `include_optimizer` option to `False`. For '
          '`tf.saved_model.save`, delete the optimizer from the model.')
    metadata['training_config']['optimizer_config'] = {
        'class_name': generic_utils.get_registered_name(optimizer.__class__),
        'config': optimizer.get_config(),
    }
  return metadata


def should_overwrite(filepath, overwrite):
  """Decides whether writing to `filepath` may proceed.

  Returns `True` immediately when `overwrite` is set or no regular file exists
  at `filepath`; otherwise defers to an interactive confirmation prompt.
  """
  if overwrite or not os.path.isfile(filepath):
    return True
  return ask_to_proceed_with_overwrite(filepath)


def compile_args_from_training_config(training_config, custom_objects=None):
  """Return `model.compile` arguments reconstructed from a training config.

  Args:
    training_config: Dict of serialized compile arguments. It must contain
      `'optimizer_config'` and `'loss_weights'`; `'loss'`, `'metrics'`,
      `'weighted_metrics'` and `'sample_weight_mode'` are optional.
    custom_objects: Optional dict of custom classes/functions made available
      while deserializing.

  Returns:
    Dict with `optimizer`, `loss`, `metrics`, `weighted_metrics`,
    `loss_weights` and `sample_weight_mode` keys, suitable for passing to
    `model.compile(**...)`.
  """
  if custom_objects is None:
    custom_objects = {}

  with generic_utils.CustomObjectScope(custom_objects):
    optimizer_config = training_config['optimizer_config']
    optimizer = optimizers.deserialize(optimizer_config)

    # Recover losses.
    loss = None
    loss_config = training_config.get('loss', None)
    if loss_config is not None:
      loss = _deserialize_nested_config(losses.deserialize, loss_config)

    # Recover metrics.
    metrics = None
    metrics_config = training_config.get('metrics', None)
    if metrics_config is not None:
      metrics = _deserialize_nested_config(_deserialize_metric, metrics_config)

    # Recover weighted metrics.
    weighted_metrics = None
    weighted_metrics_config = training_config.get('weighted_metrics', None)
    if weighted_metrics_config is not None:
      weighted_metrics = _deserialize_nested_config(_deserialize_metric,
                                                    weighted_metrics_config)

    # `training_config` is a dict, so the key must be looked up as a key.
    # The previous `hasattr(training_config, ...)` check inspected object
    # attributes, was always False, and silently dropped a saved value.
    sample_weight_mode = training_config.get('sample_weight_mode', None)
    loss_weights = training_config['loss_weights']

  return dict(
      optimizer=optimizer,
      loss=loss,
      metrics=metrics,
      weighted_metrics=weighted_metrics,
      loss_weights=loss_weights,
      sample_weight_mode=sample_weight_mode)


def _deserialize_nested_config(deserialize_fn, config):
  """Deserializes arbitrary Keras `config` using `deserialize_fn`."""

  def _is_single_object(obj):
    if isinstance(obj, dict) and 'class_name' in obj:
      return True  # Serialized Keras object.
    if isinstance(obj, str):
      return True  # Serialized function or string.
    return False

  if config is None:
    return None
  if _is_single_object(config):
    return deserialize_fn(config)
  elif isinstance(config, dict):
    return {
        k: _deserialize_nested_config(deserialize_fn, v)
        for k, v in config.items()
    }
  elif isinstance(config, (tuple, list)):
    return [_deserialize_nested_config(deserialize_fn, obj) for obj in config]

  raise ValueError(
      'Saved configuration not understood. Configuration should be a '
      f'dictionary, string, tuple or list. Received: config={config}.')


def _serialize_nested_config(config):
  """Serializes every callable leaf of a nested structure of Keras objects.

  Non-callable leaves are returned unchanged; the structure itself is
  preserved by `tf.nest.map_structure`.
  """

  def _maybe_serialize(leaf):
    return (generic_utils.serialize_keras_object(leaf)
            if callable(leaf) else leaf)

  return tf.nest.map_structure(_maybe_serialize, config)


def _deserialize_metric(metric_config):
  """Deserializes a metric config, passing special metric strings through.

  `compile` resolves accuracy and cross-entropy shorthands itself based on the
  model's output shape, so those strings must stay as plain strings.
  """
  from keras import metrics as metrics_module  # pylint:disable=g-import-not-at-top
  passthrough = ('accuracy', 'acc', 'crossentropy', 'ce')
  return (metric_config if metric_config in passthrough
          else metrics_module.deserialize(metric_config))


def _enforce_names_consistency(specs):
  """Makes naming all-or-nothing across the leaves of `specs`.

  If some, but not all, flattened specs carry a non-`None` name, a copy of
  `specs` is returned in which every spec's name is cleared. Otherwise `specs`
  is returned as-is.
  """

  def _is_named(spec):
    return getattr(spec, 'name', None) is not None

  def _strip_name(spec):
    # Work on a copy so the caller's specs are never mutated.
    stripped = copy.deepcopy(spec)
    if hasattr(stripped, 'name'):
      stripped._name = None  # pylint:disable=protected-access
    return stripped

  named = [_is_named(leaf) for leaf in tf.nest.flatten(specs)]
  if any(named) and not all(named):
    return tf.nest.map_structure(_strip_name, specs)
  return specs


def try_build_compiled_arguments(model):
  """Best-effort build of a compiled model's loss and metric containers.

  Only acts on models that are not v1 layers/models and whose `outputs` are
  known. A failure while building is logged as a warning instead of raised;
  as the warning says, the compiled metrics stay empty until the model is
  trained or evaluated.

  Args:
    model: A compiled Keras model.
  """
  if (not version_utils.is_v1_layer_or_model(model) and
      model.outputs is not None):
    try:
      if not model.compiled_loss.built:
        model.compiled_loss.build(model.outputs)
      if not model.compiled_metrics.built:
        # NOTE(review): no targets are available here, so the model outputs
        # appear to stand in for both prediction and target arguments.
        model.compiled_metrics.build(model.outputs, model.outputs)
    except Exception:  # pylint: disable=broad-except
      # Deliberately best-effort, but don't swallow KeyboardInterrupt or
      # SystemExit the way a bare `except:` would.
      logging.warning(
          'Compiled the loaded model, but the compiled metrics have yet to '
          'be built. `model.compile_metrics` will be empty until you train '
          'or evaluate the model.')


def is_hdf5_filepath(filepath):
  """Returns whether `filepath` ends in `.h5`, `.keras` or `.hdf5`."""
  hdf5_suffixes = ('.h5', '.keras', '.hdf5')
  return filepath.endswith(hdf5_suffixes)
